Source code for wbia.algo.verif.torch.fit_harness

# -*- coding: utf-8 -*-
import logging
import torch
from collections import defaultdict
from os.path import join
from torch.autograd import Variable
from wbia.algo.verif.torch import netmath
# tensorboard_logger is optional; the ``if tensorboard_logger:`` checks below
# expect it to be None when the package is not installed
try:
    import tensorboard_logger
except ImportError:
    tensorboard_logger = None
import utool as ut

# from wbia.algo.verif.torch import gpu_util

(print, rrr, profile) = ut.inject2(__name__)
logger = logging.getLogger('wbia')


class FitHarness(object):
    def __init__(
        harn,
        model,
        train_loader,
        vali_loader=None,
        test_loader=None,
        criterion='cross_entropy',
        lr_scheduler='exp',
        optimizer_cls='Adam',
        class_weights=None,
        gpu_num=None,
        workdir=None,
    ):
        harn.workdir = workdir
        harn.train_loader = train_loader
        harn.vali_loader = vali_loader
        harn.test_loader = test_loader
        harn.model = model
        # Resolve string names into netmath objects. The lookups must happen
        # here because attributes of the resolved objects (the scheduler's
        # ``init_lr``, the criterion's ``margin``) are used later.
        harn.optimizer_cls = netmath.Optimizers.lookup(optimizer_cls)
        harn.criterion = netmath.Criterions.lookup(criterion)
        harn.lr_scheduler = netmath.LRSchedules.lookup(lr_scheduler)
        harn.class_weights = class_weights
        harn.gpu_num = gpu_num
        harn.use_cuda = harn.gpu_num is not None
        # harn.model = torch.nn.DataParallel(model, device_ids=[0, 1]).cuda()
        harn.config = {
            'maxIterations': 10000,
            'displayInterval': 1,
            'vali_displayInterval': 1,
            'model_dir': '.',
            'margin': 1.0,
        }
        harn.lr = harn.lr_scheduler.init_lr
        harn.epoch = 0
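    # Usage sketch (illustrative, not from the original source). ``MyModel``
    # and the loaders are hypothetical stand-ins for whatever the caller
    # supplies, and the string names assume entries in the netmath lookup
    # tables:
    #
    #     model = MyModel()
    #     harn = FitHarness(model, train_loader, vali_loader=vali_loader,
    #                       criterion='cross_entropy', lr_scheduler='exp',
    #                       optimizer_cls='Adam', gpu_num=None)
    #     harn.run()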
    def log(harn, msg):
        logger.info(msg)
    def log_value(harn, key, value, n_iter):
        if False:
            logger.info('{}={} @ {}'.format(key, value, n_iter))
        if tensorboard_logger:
            tensorboard_logger.log_value(key, value, n_iter)
    def _to_xpu(harn, *args):
        """Puts data on the GPU if available"""
        if harn.use_cuda:
            args = [Variable(item.cuda(harn.gpu_num)) for item in args]
        else:
            args = [Variable(item) for item in args]
        return args
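    # NOTE: ``Variable`` and ``item.cuda(gpu_num)`` are the pre-0.4 torch API
    # this module was written against. A sketch of the modern (torch >= 0.4)
    # equivalent, not part of the original source:
    #
    #     device = (torch.device('cuda', harn.gpu_num) if harn.use_cuda
    #               else torch.device('cpu'))
    #     return [item.to(device) for item in args]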
    def run(harn):
        harn.log('Begin training')

        if tensorboard_logger:
            harn.log('Initializing tensorboard')
            # TODO: can we run tensorboard as a subprocess that dies when we
            # die, or does it need to be run externally?
            #   tensorboard --logdir runs
            #   http://aretha:6006
            tensorboard_logger.configure('runs/wbia', flush_secs=2)

        if harn.use_cuda:
            harn.log('Fitting model on GPU({})'.format(harn.gpu_num))
            harn.model.cuda(harn.gpu_num)
        else:
            harn.log('Fitting model on the CPU')

        if harn.class_weights is not None:
            (harn.class_weights,) = harn._to_xpu(harn.class_weights)

        lr = harn.lr_scheduler(harn.epoch)
        harn.optimizer = harn.optimizer_cls(harn.model.parameters(), lr=lr)

        # train loop
        while not harn.check_termination():
            harn.train_epoch()
            if harn.vali_loader:
                harn.validation_epoch()
            harn.save_snapshot()
            harn.epoch += 1
    def check_termination(harn):
        # check for termination
        if harn.epoch > harn.config['maxIterations']:
            harn.log('Maximum number of epochs reached, terminating ...')
            return True
        return False
    def train_epoch(harn):
        ave_metrics = defaultdict(lambda: 0)

        # change the learning rate (modifies the optimizer in place)
        lr = harn.lr_scheduler(harn.epoch, harn.optimizer)

        # train batch
        for batch_idx, input_batch in enumerate(harn.train_loader):
            input_batch = harn._to_xpu(*input_batch)

            # logger.info('Begin batch {}'.format(batch_idx))
            t_cur_metrics = harn.train_batch(input_batch)

            for k, v in t_cur_metrics.items():
                ave_metrics[k] += v

            # display training info
            if (batch_idx + 1) % harn.config['displayInterval'] == 0:
                for k in ave_metrics.keys():
                    ave_metrics[k] /= harn.config['displayInterval']

                n_train = len(harn.train_loader)
                harn.log(
                    'Epoch {0}: {1} / {2} | lr:{3} - tloss:{4:.5f} acc:{5:.2f} | sdis:{6:.3f} ddis:{7:.3f}'.format(
                        harn.epoch,
                        batch_idx,
                        n_train,
                        lr,
                        ave_metrics['loss'],
                        ave_metrics['accuracy'],
                        ave_metrics['pos_dist'],
                        ave_metrics['neg_dist'],
                    )
                )

                iter_idx = harn.epoch * n_train + batch_idx
                for key, value in ave_metrics.items():
                    harn.log_value('train ' + key, value, iter_idx)

                # diagnoseGradients(model.parameters())
                for k in ave_metrics.keys():
                    ave_metrics[k] = 0
    def validation_epoch(harn):
        ave_metrics = defaultdict(lambda: 0)
        final_metrics = ave_metrics.copy()

        for vali_idx, input_batch in enumerate(harn.vali_loader):
            input_batch = harn._to_xpu(*input_batch)

            # logger.info('Begin batch {}'.format(vali_idx))
            v_cur_metrics = harn.validation_batch(input_batch)

            for k, v in v_cur_metrics.items():
                ave_metrics[k] += v
                final_metrics[k] += v

            if (vali_idx + 1) % harn.config['vali_displayInterval'] == 0:
                for k in ave_metrics.keys():
                    ave_metrics[k] /= harn.config['vali_displayInterval']

                harn.log(
                    'Epoch {0}: {1} / {2} | vloss:{3:.5f} acc:{4:.2f} | sdis:{5:.3f} ddis:{6:.3f}'.format(
                        harn.epoch,
                        vali_idx,
                        len(harn.vali_loader),
                        ave_metrics['loss'],
                        ave_metrics['accuracy'],
                        ave_metrics['pos_dist'],
                        ave_metrics['neg_dist'],
                    )
                )

                for k in ave_metrics.keys():
                    ave_metrics[k] = 0

        for k in final_metrics.keys():
            final_metrics[k] /= len(harn.vali_loader)
        harn.log(
            'Epoch {0}: final vloss:{1:.5f} acc:{2:.2f} | sdis:{3:.3f} ddis:{4:.3f}'.format(
                harn.epoch,
                final_metrics['loss'],
                final_metrics['accuracy'],
                final_metrics['pos_dist'],
                final_metrics['neg_dist'],
            )
        )

        iter_idx = harn.epoch * len(harn.vali_loader) + vali_idx
        for key, value in final_metrics.items():
            harn.log_value('validation ' + key, value, iter_idx)
    # def display_metrics():
    #     pass
    def load_snapshot(harn, load_path):
        snapshot = torch.load(load_path)
        harn.model.load_state_dict(snapshot['model_state_dict'])
        harn.epoch = snapshot['epoch']
        harn.log('Model loaded from {}'.format(load_path))
    def save_snapshot(harn):
        # save snapshot
        save_path = join(
            harn.config['model_dir'], 'snapshot_epoch_{}.pt'.format(harn.epoch)
        )
        snapshot = {
            'epoch': harn.epoch,
            'model_state_dict': harn.model.state_dict(),
        }
        torch.save(snapshot, save_path)
        harn.log('Snapshot saved to {}'.format(save_path))
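    # Resume sketch (illustrative, not from the original source): training can
    # be restarted from a snapshot written by ``save_snapshot``; the epoch
    # number in the filename is a placeholder:
    #
    #     harn.load_snapshot(join(harn.config['model_dir'],
    #                             'snapshot_epoch_10.pt'))
    #     harn.run()  # continues from the restored harn.epoch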
    def train_batch(harn, input_batch):
        """
        https://github.com/meetshah1995/pytorch-semseg/blob/master/train.py
        """
        harn.model.train(True)
        *inputs, label = input_batch

        # Forward prop through the model
        output = harn.model(*inputs)

        # Compute the loss
        loss = harn.criterion(output, label, weight=harn.class_weights)

        # Measure train accuracy and other informative metrics
        t_metrics = harn._measure_metrics(output, label, loss)

        # Backprop and learn
        harn.optimizer.zero_grad()
        loss.backward()
        harn.optimizer.step()

        return t_metrics
    def validation_batch(harn, input_batch):
        harn.model.train(False)
        *inputs, label = input_batch

        output = harn.model(*inputs)
        loss = harn.criterion(output, label, weight=harn.class_weights)

        # Measure validation accuracy and other informative metrics
        v_metrics = harn._measure_metrics(output, label, loss)

        return v_metrics
    def _measure_metrics(harn, output, label, loss):
        metrics = netmath.Metrics._siamese_metrics(
            output, label, margin=harn.criterion.margin
        )

        assert 'loss' not in metrics, 'cannot compute loss as an extra metric'
        loss_sum = loss.data.sum()
        inf = float('inf')
        if loss_sum == inf or loss_sum == -inf:
            harn.log('WARNING: received an inf loss, setting loss value to 0')
            loss_value = 0
        else:
            # 0-dim tensor indexing from the pre-0.4 torch API
            loss_value = loss.data[0]
        metrics['loss'] = loss_value
        # metrics = {
        #     'tpr': netmath.Metrics.tpr(output, label)
        # }
        return metrics
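    # The harness assumes ``harn.criterion`` is a callable taking
    # ``(output, label, weight=...)`` and exposing a ``margin`` attribute.
    # A minimal compatible sketch (illustrative only; the real criteria live
    # in ``netmath.Criterions``):
    #
    #     class ContrastiveLoss(object):
    #         margin = 1.0
    #
    #         def __call__(self, dist, label, weight=None):
    #             # label is 1 for matching pairs, 0 otherwise; weight is
    #             # ignored in this sketch
    #             y = label.float()
    #             hinge = torch.clamp(self.margin - dist, min=0)
    #             return (y * dist ** 2 + (1 - y) * hinge ** 2).mean()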