# Source code for wbia.algo.verif.torch.models

# -*- coding: utf-8 -*-
import logging
import torch
import torchvision
import utool as ut

# Module setup via utool's injection hook. Rebinding ``print`` here shadows
# the builtin ``print`` for the rest of this module. ``rrr`` and ``profile``
# are utool helpers (presumably a module-reload function and a profiling
# decorator -- NOTE(review): confirm against utool.inject2).
print, rrr, profile = ut.inject2(__name__)
# Shared 'wbia' logger; used by ``visualize`` in this module.
logger = logging.getLogger('wbia')


class Siamese(torch.nn.Module):
    """
    Siamese network: a shared ResNet-50 trunk embeds two inputs and the model
    returns the L2 distance between the two embeddings.

    The ResNet-50 is loaded with ``pretrained=True``. Its final
    fully-connected layer is replaced by a freshly initialized
    ``Linear(in_features, embed_dim)`` projection. Both inputs pass through
    the same module, so every weight is shared between the branches.

    Args:
        embed_dim (int): width of the embedding produced by the replacement
            fully-connected layer. Defaults to 500 (the original value).

    Example:
        >>> # DISABLE_DOCTEST
        >>> from wbia.algo.verif.torch.models import *
        >>> self = Siamese()
    """

    def __init__(self, embed_dim=500):
        ut.super2(Siamese, self).__init__()
        # NOTE: with pretrained=True torchvision fetches ImageNet weights
        # (downloaded on first use).
        self.resnet = torchvision.models.resnet50(pretrained=True)
        # Input width of ResNet's original final layer, reused for the new head.
        self.num_fcin = self.resnet.fc.in_features
        # Replace the last layer of the resnet with an embedding projection.
        self.resnet.fc = torch.nn.Linear(self.num_fcin, embed_dim)
        self.pdist = torch.nn.PairwiseDistance(p=2)

    def forward(self, input1, input2):
        """
        Embed each input with the shared ResNet-50 and compare the results.

        Args:
            input1 (torch.Tensor): first input batch; presumably an image
                batch shaped ``(N, 3, H, W)`` as torchvision's ResNet expects
                (``visualize`` feeds ``(1, 3, 224, 224)``).
            input2 (torch.Tensor): second input batch, same shape as
                ``input1``.

        Returns:
            torch.Tensor: the Euclidean (p=2) distance between the embedding
            of each row of ``input1`` and the matching row of ``input2``.
        """
        output1 = self.resnet(input1)
        output2 = self.resnet(input2)
        return self.pdist(output1, output2)
def visualize(fpath='./pytorch_network.png'):
    """
    Draw the autograd graph of one :class:`Siamese` forward pass.

    A ``Siamese`` model is built and fed the same random ``(1, 3, 224, 224)``
    tensor on both branches. The backward graph hanging off the output is
    converted to a ``networkx.DiGraph``, written to ``fpath`` with
    ``wbia.plottool.dump_nx_ondisk``, and then opened with ``ut.startfile``.

    Node labels:
        * nodes exposing ``.variable`` (lightblue): parameter name and shape
        * tensors (orange): tensor shape
        * anything else: the autograd node's class name

    Args:
        fpath (str): where to write the rendered image. Defaults to
            ``'./pytorch_network.png'`` (the previously hard-coded path).

    Returns:
        networkx.DiGraph: the graph that was rendered, e.g. for interactive
        display via ``pt.show_nx(G, arrow_width=1)``.
    """
    import networkx as nx

    def size_to_str(size):
        # torch.Size -> "(d0, d1, ...)"
        return '(' + ', '.join('%d' % v for v in size) + ')'

    def make_nx(var, params):
        """Build a DiGraph of the autograd graph that produced tensor ``var``."""
        # Nodes exposing ``.variable`` hold the leaf parameter; map parameter
        # object identity back to its name for labelling.
        param_map = {id(v): k for k, v in params.items()}
        logger.info(param_map)
        node_attr = dict(
            style='filled',
            shape='box',
            align='left',
            fontsize='12',
            ranksep='0.1',
            height='0.2',
        )
        G = nx.DiGraph()
        # Keep references to visited objects (not just their ids). Tensors from
        # ``saved_tensors`` may be transient wrappers, and holding them keeps
        # the ids used as node keys from being reused mid-traversal.
        seen = set()
        # Explicit stack instead of recursion, so a very deep backward graph
        # cannot exceed Python's recursion limit.
        stack = [] if var.grad_fn is None else [var.grad_fn]
        while stack:
            node = stack.pop()
            if node in seen:
                continue
            seen.add(node)
            if torch.is_tensor(node):
                G.add_node(
                    id(node),
                    label=size_to_str(node.size()),
                    fillcolor='orange',
                    **node_attr,
                )
            elif hasattr(node, 'variable'):
                leaf = node.variable
                label = '%s\n %s' % (param_map.get(id(leaf)), size_to_str(leaf.size()))
                G.add_node(id(node), label=label, fillcolor='lightblue', **node_attr)
            else:
                G.add_node(id(node), label=str(type(node).__name__), **node_attr)
            # next_functions entries are (node, index) pairs; None marks inputs
            # that do not require grad.
            for next_fn in getattr(node, 'next_functions', ()):
                if next_fn[0] is not None:
                    G.add_edge(id(next_fn[0]), id(node))
                    stack.append(next_fn[0])
            for saved in getattr(node, 'saved_tensors', ()):
                G.add_edge(id(saved), id(node))
                stack.append(saved)
        return G

    inputs = torch.randn(1, 3, 224, 224)
    model = Siamese()
    # Plain tensors carry autograd history; the deprecated
    # torch.autograd.Variable wrapper is unnecessary.
    y = model(inputs, inputs)
    # Use the live Parameter objects. state_dict() returns detached copies whose
    # ids never match the graph's leaf parameters, so every label read None.
    params = dict(model.named_parameters())
    G = make_nx(y, params)

    import wbia.plottool as pt

    pt.dump_nx_ondisk(G, fpath)
    ut.startfile(fpath)
    return G