Source code for wbia.algo.verif.torch.train_main

# -*- coding: utf-8 -*-
import logging
from os.path import join  # NOQA
import cv2
import numpy as np
import torch
import torch.nn
import torch.utils.data
import utool as ut
import torchvision

print, rrr, profile = ut.inject2(__name__)
logger = logging.getLogger('wbia')


class LRSchedule(object):
    @staticmethod
    def exp(optimizer, epoch, init_lr=0.001, lr_decay_epoch=2):
        """Decay learning rate by a factor of 0.1 every lr_decay_epoch epochs."""
        lr = init_lr
        # epoch += 1
        if epoch % lr_decay_epoch == 0 and epoch != 0:
            lr *= 0.1

        if epoch % lr_decay_epoch == 0:
            logger.info('LR is set to {}'.format(lr))

        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        return optimizer, lr
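# Illustrative sketch (not part of the original module): one way LRSchedule.exp
# could be driven from a plain epoch loop, assuming a standard torch optimizer.
# The linear model and the loop body are hypothetical placeholders.
def _example_lr_schedule_usage(num_epochs=6):
    import torch.optim

    model = torch.nn.Linear(8, 2)  # hypothetical stand-in model
    optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
    for epoch in range(num_epochs):
        # Adjust the learning rate at the start of each epoch
        optimizer, lr = LRSchedule.exp(
            optimizer, epoch, init_lr=0.001, lr_decay_epoch=2
        )
        logger.info('epoch={} lr={}'.format(epoch, lr))
        # ... run one epoch of training with `optimizer` here ...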
def siam_vsone_train():
    r"""
    CommandLine:
        python -m wbia.algo.verif.torch.train_main siam_vsone_train

    Example:
        >>> # DISABLE_DOCTEST
        >>> from wbia.algo.verif.torch.train_main import *  # NOQA
        >>> siam_vsone_train()
    """
    # wrapper around the RF vsone problem
    from wbia.algo.verif import vsone

    # pblm = vsone.OneVsOneProblem.from_empty('PZ_MTEST')
    pblm = vsone.OneVsOneProblem.from_empty('GZ_Master1')
    ibs = pblm.infr.ibs
    pblm.load_samples()
    samples = pblm.samples
    samples.print_info()
    xval_kw = pblm.xval_kw.asdict()
    skf_list = pblm.samples.stratified_kfold_indices(**xval_kw)

    def load_dataset(subset_idx):
        aids1, aids2 = pblm.samples.aid_pairs[subset_idx].T
        labels = pblm.samples['match_state'].y_enc[subset_idx]
        # train only on positive-vs-negative (ignore incomparable)
        labels = (labels == 1).astype(np.int64)
        chip_config = {'resize_dim': 'wh', 'dim_size': (224, 224)}
        img1_fpaths = ibs.depc_annot.get(
            'chips', aids1, read_extern=False, colnames='img', config=chip_config
        )
        img2_fpaths = ibs.depc_annot.get(
            'chips', aids2, read_extern=False, colnames='img', config=chip_config
        )
        dataset = LabeledPairDataset(img1_fpaths, img2_fpaths, labels)
        return dataset

    learn_idx, test_idx = skf_list[0]
    # Split everything in the learning set into training / validation
    train_idx, val_idx = pblm.samples.subsplit_indices(learn_idx, n_splits=10)[0]

    train_dataset = load_dataset(train_idx)
    vali_dataset = load_dataset(val_idx)
    test_dataset = load_dataset(test_idx)

    logger.info('* len(train_dataset) = {}'.format(len(train_dataset)))
    logger.info('* len(vali_dataset) = {}'.format(len(vali_dataset)))
    logger.info('* len(test_dataset) = {}'.format(len(test_dataset)))

    from wbia.algo.verif.torch import gpu_util

    gpu_num = gpu_util.find_unused_gpu(min_memory=6000)
    use_cuda = gpu_num is not None

    data_kw = {}
    if use_cuda:
        data_kw = {'num_workers': 6, 'pin_memory': True}

    batch_size = 64
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True, **data_kw
    )
    vali_loader = torch.utils.data.DataLoader(
        vali_dataset, batch_size=batch_size, shuffle=True, **data_kw
    )
    test_loader = torch.utils.data.DataLoader(
        test_dataset, batch_size=batch_size, shuffle=False, **data_kw
    )

    from wbia.algo.verif.torch import fit_harness
    from wbia.algo.verif.torch import models
    from wbia.algo.verif.torch import netmath
    from wbia.algo.verif.torch import lr_schedule

    model = models.Siamese()
    criterion = netmath.Criterions.ContrastiveLoss(margin=1)
    lr_scheduler = lr_schedule.Exponential()
    optimizer_cls = netmath.Optimizers.Adam
    class_weights = train_dataset.class_weights()
    logger.info('class_weights = {!r}'.format(class_weights))

    harn = fit_harness.FitHarness(
        model=model,
        criterion=criterion,
        lr_scheduler=lr_scheduler,
        train_loader=train_loader,
        vali_loader=vali_loader,
        test_loader=test_loader,
        optimizer_cls=optimizer_cls,
        class_weights=class_weights,
        gpu_num=gpu_num,
    )
    harn.run()
class LabeledPairDataset(torch.utils.data.Dataset):
    """
    transform=transforms.Compose([
        transforms.Scale(224),
        transforms.ToTensor(),
        torchvision.transforms.Normalize([0.5, 0.5, 0.5],
                                         [0.225, 0.225, 0.225]),
    ])

    Ignore:
        >>> from wbia.algo.verif.torch.train_main import *
        >>> from wbia.algo.verif.vsone import *  # NOQA
        >>> pblm = OneVsOneProblem.from_empty('PZ_MTEST')
        >>> ibs = pblm.infr.ibs
        >>> pblm.load_samples()
        >>> samples = pblm.samples
        >>> samples.print_info()
        >>> xval_kw = pblm.xval_kw.asdict()
        >>> skf_list = pblm.samples.stratified_kfold_indices(**xval_kw)
        >>> train_idx, test_idx = skf_list[0]
        >>> aids1, aids2 = pblm.samples.aid_pairs[train_idx].T
        >>> labels = pblm.samples['match_state'].y_enc[train_idx]
        >>> labels = (labels == 1).astype(np.int64)
        >>> chip_config = {'resize_dim': 'wh', 'dim_size': (224, 224)}
        >>> img1_fpaths = ibs.depc_annot.get('chips', aids1, read_extern=False, colnames='img', config=chip_config)
        >>> img2_fpaths = ibs.depc_annot.get('chips', aids2, read_extern=False, colnames='img', config=chip_config)
        >>> self = LabeledPairDataset(img1_fpaths, img2_fpaths, labels)
        >>> img1, img2, label = self[0]
    """

    def __init__(self, img1_fpaths, img2_fpaths, labels, transform='default'):
        assert len(img1_fpaths) == len(img2_fpaths)
        assert len(labels) == len(img2_fpaths)
        self.img1_fpaths = img1_fpaths
        self.img2_fpaths = img2_fpaths
        self.labels = labels
        if transform == 'default':
            transform = torchvision.transforms.Compose(
                [
                    # torchvision.transforms.Scale(224),
                    torchvision.transforms.ToTensor(),
                    torchvision.transforms.Normalize(
                        [0.5, 0.5, 0.5], [0.225, 0.225, 0.225]
                    ),
                ]
            )
        self.transform = transform

    def class_weights(self):
        import pandas as pd

        label_freq = pd.value_counts(self.labels)
        class_weights = label_freq.median() / label_freq
        class_weights = class_weights.sort_index().values
        class_weights = torch.from_numpy(class_weights.astype(np.float32))
        return class_weights

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image1, image2, label)
        """
        fpath1 = self.img1_fpaths[index]
        fpath2 = self.img2_fpaths[index]
        label = self.labels[index]

        def loader(fpath):
            bgr_255 = cv2.imread(fpath)
            bgr_01 = bgr_255.astype(np.float32) / 255.0
            rgb_01 = cv2.cvtColor(bgr_01, cv2.COLOR_BGR2RGB)
            return rgb_01

        img1 = loader(fpath1)
        img2 = loader(fpath2)
        if self.transform:
            img1 = self.transform(img1)
            img2 = self.transform(img2)
        return img1, img2, label

    def __len__(self):
        return len(self.img1_fpaths)
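# Illustrative sketch (not part of the original module): constructing a
# LabeledPairDataset from two lists of chip image paths and feeding it to a
# torch DataLoader. The paths and labels below are hypothetical placeholders;
# in siam_vsone_train they come from ibs.depc_annot. Assumes the images exist
# on disk and share a common size so the default collate can stack them.
def _example_labeled_pair_dataset_usage():
    img1_fpaths = ['chip_a1.png', 'chip_a2.png']  # hypothetical paths
    img2_fpaths = ['chip_b1.png', 'chip_b2.png']  # hypothetical paths
    labels = np.array([1, 0], dtype=np.int64)  # 1 = positive pair, 0 = negative
    dataset = LabeledPairDataset(img1_fpaths, img2_fpaths, labels)
    loader = torch.utils.data.DataLoader(dataset, batch_size=2, shuffle=True)
    for img1, img2, label in loader:
        # img1 / img2 are (batch, 3, H, W) float tensors after the default
        # ToTensor + Normalize transform
        logger.info('shapes: {} {} {}'.format(img1.shape, img2.shape, label.shape))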