# -*- coding: utf-8 -*-
import logging
import wbia.plottool as pt
import utool as ut
from wbia.algo.verif import vsone
from wbia.scripts._thesis_helpers import DBInputs
from wbia.scripts.thesis import Sampler  # NOQA
from wbia.scripts._thesis_helpers import Tabular, upper_one, ave_str
from wbia.scripts._thesis_helpers import dbname_to_species_nice
from wbia.scripts._thesis_helpers import TMP_RC, W, H, DPI
from wbia.algo.graph.state import POSTV, NEGTV, INCMP, UNREV, UNKWN  # NOQA
import numpy as np  # NOQA
import pandas as pd
import ubelt as ub  # NOQA
import itertools as it
import matplotlib as mpl
from os.path import basename, join, splitext, exists  # NOQA
import wbia.constants as const
import vtool as vt

(print, rrr, profile) = ut.inject2(__name__)
logger = logging.getLogger('wbia')


CLF = 'VAMP'
LNBNN = 'LNBNN'


def review_pz():
    import wbia

    ibs = wbia.opendb('GZ_Master1')
    infr = wbia.AnnotInference(ibs, aids='all')
    infr.reset_feedback('staging', apply=True)
    infr.relabel_using_reviews(rectify=True)
    # infr.apply_nondynamic_update()
    logger.info(ut.repr4(infr.status()))

    infr.wbia_delta_info()

    infr.match_state_delta()
    infr.get_wbia_name_delta()

    infr.relabel_using_reviews(rectify=True)
    infr.write_wbia_annotmatch_feedback()
    infr.write_wbia_name_assignment()


@ut.reloadable_class
class GraphExpt(DBInputs):
    r"""
    TODO:
        - [ ] Experimental analysis of duration of each phase and state of
          graph.

        - [ ] Experimental analysis of phase 3, including how far we can get
          with automatic decision making and do we discover new merges?  If
          there are potential merges, can we run phase iii with exactly the
          same ordering as before: ordering by probability for automatically
          decidable and then by positive probability for others.  This should
          work for phase 3 and therefore allow a clean combination of the
          three phases and our termination criteria.  I just thought of this
          so don't really have it written cleanly above.

        - [ ] Experimental analysis of choice of automatic decision
          thresholds. By lowering the threshold we increase the risk of
          mistakes. Each mistake costs some number of manual reviews (perhaps
          2-3), but if the frequency of errors is low then we could be saving
          ourselves a lot of manual reviews.

        \item OTHER SPECIES

    CommandLine:
        python -m wbia GraphExpt.measure all PZ_MTEST

    Ignore:
        >>> from wbia.scripts.postdoc import *
        >>> self = GraphExpt('PZ_MTEST')
        >>> self._precollect()
        >>> self._setup()
    """

    base_dpath = ut.truepath('~/Desktop/graph_expt')

    def _precollect(self):
        if self.ibs is None:
            _GraphExpt = ut.fix_super_reload(GraphExpt, self)
            super(_GraphExpt, self)._precollect()

        # Split data into a training and testing set
        ibs = self.ibs
        annots = ibs.annots(self.aids_pool)
        names = list(annots.group_items(annots.nids).values())
        ut.shuffle(names, rng=321)
        train_names, test_names = names[0::2], names[1::2]
        train_aids, test_aids = map(ut.flatten, (train_names, test_names))

        self.test_train = train_aids, test_aids

        params = {}
        self.pblm = vsone.OneVsOneProblem.from_aids(ibs, train_aids, **params)

        # ut.get_nonconflicting_path(dpath, suffix='_old')
        self.const_dials = {
            # 'oracle_accuracy' : (0.98, 1.0),
            # 'oracle_accuracy' : (0.98, .98),
            'oracle_accuracy': (0.99, 0.99),
            'k_redun': 2,
            'max_outer_loops': np.inf,
            # 'max_outer_loops' : 1,
        }

        config = ut.dict_union(self.const_dials)
        cfg_prefix = '{}_{}'.format(len(test_aids), len(train_aids))
        self._setup_links(cfg_prefix, config)

    def _setup(self):
        """
        python -m wbia GraphExpt._setup

        Example:
            >>> # DISABLE_DOCTEST
            >>> from wbia.scripts.postdoc import *
            >>> #self = GraphExpt('GZ_Master1')
            >>> self = GraphExpt('PZ_MTEST')
            >>> self = GraphExpt('PZ_Master1')
            >>> self._setup()
        """
        self._precollect()
        train_aids, test_aids = self.test_train

        task_key = 'match_state'
        pblm = self.pblm
        data_key = pblm.default_data_key
        clf_key = pblm.default_clf_key

        pblm.eval_data_keys = [data_key]
        pblm.setup(with_simple=False)
        pblm.learn_evaluation_classifiers()

        res = pblm.task_combo_res[task_key][clf_key][data_key]
        # pblm.report_evaluation()

        # TODO: need more principled way of selecting thresholds
        # graph_thresh = res.get_pos_threshes('fpr', 0.01)
        graph_thresh = res.get_pos_threshes('fpr', 0.001)
        # rankclf_thresh = res.get_pos_threshes(fpr=0.01)

        # Load or create the deploy classifiers
        clf_dpath = ut.ensuredir((self.dpath, 'clf'))
        classifiers = pblm.ensure_deploy_classifiers(dpath=clf_dpath)

        sim_params = {
            'test_aids': test_aids,
            'train_aids': train_aids,
            'classifiers': classifiers,
            'graph_thresh': graph_thresh,
            # 'rankclf_thresh': rankclf_thresh,
            'const_dials': self.const_dials,
        }
        self.pblm = pblm
        self.sim_params = sim_params
        return sim_params

    def measure_all(self):
        self.measure_graphsim()

    @profile
    def measure_graphsim(self):
        """
        CommandLine:
            python -m wbia GraphExpt.measure graphsim GZ_Master1 1

        Ignore:
            >>> from wbia.scripts.postdoc import *
            >>> #self = GraphExpt('PZ_MTEST')
            >>> #self = GraphExpt('GZ_Master1')
            >>> self = GraphExpt.measure('graphsim', 'PZ_Master1')
            >>> self = GraphExpt.measure('graphsim', 'GZ_Master1')
            >>> self = GraphExpt.measure('graphsim', 'PZ_MTEST')
        """
        import wbia

        self.ensure_setup()

        ibs = self.ibs
        sim_params = self.sim_params
        classifiers = sim_params['classifiers']
        test_aids = sim_params['test_aids']

        graph_thresh = sim_params['graph_thresh']

        const_dials = sim_params['const_dials']

        sim_results = {}
        verbose = 1

        # ----------
        # Graph test
        dials1 = ut.dict_union(
            const_dials,
            {
                'name': 'graph',
                'enable_inference': True,
                'match_state_thresh': graph_thresh,
            },
        )
        infr1 = wbia.AnnotInference(
            ibs=ibs, aids=test_aids, autoinit=True, verbose=verbose
        )
        infr1.enable_auto_prioritize_nonpos = True
        infr1.params['refresh.window'] = 20
        infr1.params['refresh.thresh'] = 0.052
        infr1.params['refresh.patience'] = 72
        infr1.params['redun.enforce_pos'] = True
        infr1.params['redun.enforce_neg'] = True
        infr1.init_simulation(classifiers=classifiers, **dials1)
        infr1.init_test_mode()
        infr1.reset(state='empty')

        # if False:
        #     infr = infr1
        #     infr.init_refresh()
        #     n_prioritized = infr.refresh_candidate_edges()
        #     gen = infr.lnbnn_priority_gen(use_refresh=True)
        #     next(gen)
        #     edge = (25, 118)

        list(infr1.main_gen())
        # infr1.main_loop()

        sim_results['graph'] = self._collect_sim_results(infr1, dials1)

        # ------------
        # Dump experiment output to disk
        expt_name = 'graphsim'
        self.expt_results[expt_name] = sim_results
        ut.ensuredir(self.dpath)
        ut.save_data(join(self.dpath, expt_name + '.pkl'), sim_results)

    def _collect_sim_results(self, infr, dials):
        pred_confusion = pd.DataFrame(infr.test_state['confusion'])
        pred_confusion.index.name = 'real'
        pred_confusion.columns.name = 'pred'
        logger.info('Edge confusion')
        logger.info(pred_confusion)

        expt_data = {
            'real_ccs': list(infr.nid_to_gt_cc.values()),
            'pred_ccs': list(infr.pos_graph.connected_components()),
            'graph': infr.graph.copy(),
            'dials': dials,
            'refresh_thresh': infr.refresh._prob_any_remain_thresh,
            'metrics': infr.metrics_list,
        }
        return expt_data

    def draw_graphsim(self):
        """
        CommandLine:
            python -m wbia GraphExpt.measure graphsim GZ_Master1
            python -m wbia GraphExpt.draw graphsim GZ_Master1 --diskshow

            python -m wbia GraphExpt.draw graphsim PZ_MTEST --diskshow
            python -m wbia GraphExpt.draw graphsim GZ_Master1 --diskshow
            python -m wbia GraphExpt.draw graphsim PZ_Master1 --diskshow

        Ignore:
            >>> from wbia.scripts.postdoc import *
            >>> self = GraphExpt('GZ_Master1')
            >>> self = GraphExpt('PZ_MTEST')
        """
        sim_results = self.ensure_results('graphsim')

        metric_nice = {
            'n_errors': '# errors',
            'n_manual': '# manual reviews',
            'frac_mistake_aids': 'fraction error annots',
            'merge_remain': 'fraction of merges remain',
        }

        # keys = ['ranking', 'rank+clf', 'graph']
        # keycols = ['red', 'orange', 'b']
        keys = ['graph']
        keycols = ['b']
        colors = ut.dzip(keys, keycols)

        dfs = {k: pd.DataFrame(v['metrics']) for k, v in sim_results.items()}
        n_aids = sim_results['graph']['graph'].number_of_nodes()
        df = dfs['graph']
        df['frac_mistake_aids'] = df.n_mistake_aids / n_aids

        # mdf = pd.concat(dfs.values(), keys=dfs.keys())

        import xarray as xr

        panel = xr.concat(
            [xr.DataArray(df, dims=('ts', 'metric')) for df in dfs.values()],
            dim=pd.Index(list(dfs.keys()), name='key'),
        )

        xmax = panel.sel(metric='n_manual').values.max()
        xpad = (1.01 * xmax) - xmax

        pnum_ = pt.make_pnum_nextgen(nSubplots=2)

        mpl.rcParams.update(TMP_RC)

        fnum = 1
        pt.figure(fnum=fnum, pnum=pnum_())
        ax = pt.gca()
        xkey, ykey = 'n_manual', 'merge_remain'
        datas = panel.sel(metric=[xkey, ykey])
        for key in keys:
            ax.plot(*datas.sel(key=key).values.T, label=key, color=colors[key])
        ax.set_ylim(0, 1)
        ax.set_xlim(-xpad, xmax + xpad)
        ax.set_xlabel(metric_nice[xkey])
        ax.set_ylabel(metric_nice[ykey])
        ax.legend()

        pt.figure(fnum=fnum, pnum=pnum_())
        ax = pt.gca()
        xkey, ykey = 'n_manual', 'frac_mistake_aids'
        datas = panel.sel(metric=[xkey, ykey])
        for key in keys:
            ax.plot(*datas.sel(key=key).values.T, label=key, color=colors[key])
        ax.set_ylim(0, datas.T[1].max() * 1.01)
        ax.set_xlim(-xpad, xmax + xpad)
        ax.set_xlabel(metric_nice[xkey])
        ax.set_ylabel(metric_nice[ykey])
        ax.legend()

        fig = pt.gcf()  # NOQA
        fig.set_size_inches([W, H * 0.75])
        pt.adjust_subplots(wspace=0.25, fig=fig)

        fpath = join(self.dpath, 'simulation.png')
        vt.imwrite(fpath, pt.render_figure_to_image(fig, dpi=DPI))
        if ut.get_argflag('--diskshow'):
            ut.startfile(fpath)

    def draw_graphsim2(self):
        """
        CommandLine:
            python -m wbia GraphExpt.draw graphsim2 --db PZ_MTEST --diskshow
            python -m wbia GraphExpt.draw graphsim2 GZ_Master1 --diskshow
            python -m wbia GraphExpt.draw graphsim2 PZ_Master1 --diskshow

        Example:
            >>> # DISABLE_DOCTEST
            >>> from wbia.scripts.thesis import *
            >>> dbname = ut.get_argval('--db', default='GZ_Master1')
            >>> self = GraphExpt(dbname)
            >>> self.draw_graphsim2()
            >>> ut.show_if_requested()
        """
        mpl.rcParams.update(TMP_RC)
        sim_results = self.ensure_results('graphsim')

        expt_data = sim_results['graph']
        metrics_df = pd.DataFrame.from_dict(expt_data['metrics'])

        # n_aids = sim_results['graph']['graph'].number_of_nodes()
        # metrics_df['frac_mistake_aids'] = metrics_df.n_mistake_aids / n_aids

        fnum = 1  # NOQA
        default_flags = {
            'phase': True,
            'pred': False,
            'user': True,
            'real': True,
            'error': 0,
            'recover': 1,
        }

        def plot_intervals(flags, color=None, low=0, high=1):
            ax = pt.gca()
            idxs = np.where(flags)[0]
            ranges = ut.group_consecutives(idxs)
            bounds = [(min(a), max(a)) for a in ranges if len(a) > 0]
            xdata_ = xdata.values

            xs, ys = [xdata_[0]], [low]
            for a, b in bounds:
                x1, x2 = xdata_[a], xdata_[b]
                # if x1 == x2:
                x1 -= 0.5
                x2 += 0.5
                xs.extend([x1, x1, x2, x2])
                ys.extend([low, high, high, low])
            xs.append(xdata_[-1])
            ys.append(low)
            ax.fill_between(xs, ys, low, alpha=0.6, color=color)

        def overlay_actions(ymax=1, kw=None):
            """
            Draws indicators that detail the algorithm state at given
            timestamps.
            """
            phase = metrics_df['phase'].map(lambda x: x.split('_')[0])
            is_correct = (
                metrics_df['test_action'].map(lambda x: x.startswith('correct')).values
            )
            recovering = metrics_df['recovering'].values
            is_auto = metrics_df['user_id'].map(lambda x: x.startswith('algo:')).values
            ppos = metrics_df['pred_decision'].map(lambda x: x == POSTV).values
            rpos = metrics_df['true_decision'].map(lambda x: x == POSTV).values
            # ymax = max(metrics_df['n_errors'])

            if kw is None:
                kw = default_flags

            num = sum(kw.values())
            steps = np.linspace(0, 1, num + 1) * ymax
            i = -1

            def stacked_interval(data, color, i):
                plot_intervals(data, color, low=steps[i], high=steps[i + 1])

            if kw.get('user', False):
                i += 1
                pt.absolute_text(
                    (0.2, steps[i : i + 2].mean()), 'user(algo=gold,manual=blue)'
                )
                stacked_interval(is_auto, 'gold', i)
                stacked_interval(~is_auto, 'blue', i)

            if kw.get('pred', False):
                i += 1
                pt.absolute_text((0.2, steps[i : i + 2].mean()), 'pred_pos')
                stacked_interval(ppos, 'aqua', i)
                # stacked_interval(~ppos, 'salmon', i)

            if kw.get('real', False):
                i += 1
                pt.absolute_text((0.2, steps[i : i + 2].mean()), 'real_merge')
                stacked_interval(rpos, 'lime', i)
                # stacked_interval(~ppos, 'salmon', i)

            if kw.get('error', False):
                i += 1
                pt.absolute_text((0.2, steps[i : i + 2].mean()), 'is_error')
                # stacked_interval(is_correct, 'blue', i)
                stacked_interval(~is_correct, 'red', i)

            if kw.get('recover', False):
                i += 1
                pt.absolute_text((0.2, steps[i : i + 2].mean()), 'is_recovering')
                stacked_interval(recovering, 'orange', i)

            if kw.get('phase', False):
                i += 1
                pt.absolute_text(
                    (0.2, steps[i : i + 2].mean()), 'phase(1=yellow, 2=aqua, 3=pink)'
                )
                stacked_interval(phase == 'ranking', 'yellow', i)
                stacked_interval(phase == 'posredun', 'aqua', i)
                stacked_interval(phase == 'negredun', 'pink', i)
                # stacked_interval(phase == 'ranking', 'red', i)
                # stacked_interval(phase == 'posredun', 'green', i)
                # stacked_interval(phase == 'negredun', 'blue', i)

        def accuracy_plot(xdata, xlabel):
            ydatas = ut.odict([('Graph', metrics_df['merge_remain'])])
            pt.multi_plot(
                xdata,
                ydatas,
                marker='',
                markersize=1,
                xlabel=xlabel,
                ylabel='fraction of merge remaining',
                ymin=0,
                rcParams=TMP_RC,
                use_legend=True,
                fnum=1,
                pnum=pnum_(),
            )

        def error_plot(xdata, xlabel):
            # ykeys = ['n_errors']
            ykeys = ['frac_mistake_aids']
            pt.multi_plot(
                xdata,
                metrics_df[ykeys].values.T,
                xlabel=xlabel,
                ylabel='fraction error annots',
                marker='',
                markersize=1,
                ymin=0,
                rcParams=TMP_RC,
                fnum=1,
                pnum=pnum_(),
                use_legend=False,
            )

        def refresh_plot(xdata, xlabel):
            pt.multi_plot(
                xdata,
                [metrics_df['pprob_any']],
                label_list=['P(C=1)'],
                xlabel=xlabel,
                ylabel='refresh criteria',
                marker='',
                ymin=0,
                ymax=1,
                rcParams=TMP_RC,
                fnum=1,
                pnum=pnum_(),
                use_legend=False,
            )
            ax = pt.gca()
            thresh = expt_data['refresh_thresh']
            ax.plot(
                [min(xdata), max(xdata)], [thresh, thresh], '-g', label='refresh thresh'
            )
            ax.legend()

        def error_breakdown_plot(xdata, xlabel):
            ykeys = ['n_fn', 'n_fp']
            pt.multi_plot(
                xdata,
                metrics_df[ykeys].values.T,
                label_list=ykeys,
                xlabel=xlabel,
                ylabel='# of errors',
                marker='x',
                markersize=1,
                ymin=0,
                rcParams=TMP_RC,
                ymax=max(metrics_df['n_errors']),
                fnum=1,
                pnum=pnum_(),
                use_legend=True,
            )

        def neg_redun_plot(xdata, xlabel):
            n_pred = len(sim_results['graph']['pred_ccs'])
            z = (n_pred * (n_pred - 1)) / 2

            metrics_df['p_neg_redun'] = metrics_df['n_neg_redun'] / z
            metrics_df['p_neg_redun1'] = metrics_df['n_neg_redun1'] / z
            ykeys = ['p_neg_redun', 'p_neg_redun1']
            pt.multi_plot(
                xdata,
                metrics_df[ykeys].values.T,
                label_list=ykeys,
                xlabel=xlabel,
                ylabel='% neg-redun-meta-edges',
                marker='x',
                markersize=1,
                ymin=0,
                rcParams=TMP_RC,
                ymax=max(metrics_df['p_neg_redun1']),
                fnum=1,
                pnum=pnum_(),
                use_legend=True,
            )

        pnum_ = pt.make_pnum_nextgen(nRows=2, nSubplots=6)

        # --- ROW 1 ---
        xdata = metrics_df['n_decision']
        xlabel = '# decisions'

        accuracy_plot(xdata, xlabel)
        # overlay_actions(1)

        error_plot(xdata, xlabel)
        overlay_actions(max(metrics_df['frac_mistake_aids']))
        # overlay_actions(max(metrics_df['n_errors']))

        # refresh_plot(xdata, xlabel)
        # overlay_actions(1, {'phase': True})

        # error_breakdown_plot(xdata, xlabel)
        neg_redun_plot(xdata, xlabel)

        # --- ROW 2 ---
        xdata = metrics_df['n_manual']
        xlabel = '# manual reviews'

        accuracy_plot(xdata, xlabel)
        # overlay_actions(1)

        error_plot(xdata, xlabel)
        overlay_actions(max(metrics_df['frac_mistake_aids']))
        # overlay_actions(max(metrics_df['n_errors']))

        # refresh_plot(xdata, xlabel)
        # overlay_actions(1, {'phase': True})

        # error_breakdown_plot(xdata, xlabel)
        neg_redun_plot(xdata, xlabel)

        # fpath = join(self.dpath, expt_name + '2' + '.png')
        # fig = pt.gcf()  # NOQA
        # fig.set_size_inches([W * 1.5, H * 1.1])
        # vt.imwrite(fpath, pt.render_figure_to_image(fig, dpi=DPI))
        # if ut.get_argflag('--diskshow'):
        #     ut.startfile(fpath)
        # fig.save_fig

        # if 1:
        #     pt.figure(fnum=fnum, pnum=(2, 2, 4))
        #     overlay_actions(ymax=1)

        pt.set_figtitle(self.dbname)
        fig = pt.gcf()  # NOQA
        fig.set_size_inches([W * 2, H * 2.5])
        fig.suptitle(self.dbname)
        pt.adjust_subplots(hspace=0.25, wspace=0.25, fig=fig)
        fpath = join(self.dpath, 'graphsim2.png')
        fig.savefig(fpath, dpi=DPI)
        # vt.imwrite(fpath, pt.render_figure_to_image(fig, dpi=DPI))
        if ut.get_argflag('--diskshow'):
            ut.startfile(fpath)


def draw_match_states():
    import wbia

    infr = wbia.AnnotInference('PZ_Master1', 'all')

    if infr.ibs.dbname == 'PZ_Master1':
        # [UUID('0cb1ebf5-2a4f-4b80-b172-1b449b8370cf'),
        #  UUID('cd644b73-7978-4a5f-b570-09bb631daa75')]
        chosen = {
            POSTV: (17095, 17225),
            NEGTV: (3966, 5080),
            INCMP: (3197, 8455),
        }
    else:
        infr.reset_feedback('staging')
        chosen = {
            POSTV: list(infr.pos_graph.edges())[0],
            NEGTV: list(infr.neg_graph.edges())[0],
            INCMP: list(infr.incmp_graph.edges())[0],
        }
    import wbia.plottool as pt
    import vtool as vt

    for key, edge in chosen.items():
        match = infr._make_matches_from(
            [edge], config={'match_config': {'ratio_thresh': 0.7}}
        )[0]
        with pt.RenderingContext(dpi=300) as ctx:
            match.show(heatmask=True, show_ell=False, show_ori=False, show_lines=False)
        vt.imwrite('matchstate_' + key + '.jpg', ctx.image)


def entropy_potential(infr, u, v, decision):
    """
    Returns the number of edges this edge would invalidate

    from wbia.algo.graph import demo
    infr = demo.demodata_infr(pcc_sizes=[5, 2, 4, 2, 2, 1, 1, 1])
    infr.refresh_candidate_edges()
    infr.params['redun.neg'] = 1
    infr.params['redun.pos'] = 1
    infr.apply_nondynamic_update()

    ut.qtensure()
    infr.show(show_cand=True, groupby='name_label')

    u, v = 1, 7
    decision = 'positive'
    """
    nid1, nid2 = infr.pos_graph.node_labels(u, v)

    # Cases for K=1
    if decision == 'positive' and nid1 == nid2:
        # The actual reduction is the number previously needed to make the cc
        # k-edge-connected vs how many it needs now.
        #
        # In the same CC does nothing
        # (unless k > 1, in which case check edge connectivity)
        return 0
    elif decision == 'positive' and nid1 != nid2:
        # Between two PCCs reduces the number of PCCs by one
        n_ccs = infr.pos_graph.number_of_components()

        # Find the needed negative redundancy when apart
        if infr.neg_redun_metagraph.has_node(nid1):
            neg_redun_set1 = set(infr.neg_redun_metagraph.neighbors(nid1))
        else:
            neg_redun_set1 = set()

        if infr.neg_redun_metagraph.has_node(nid2):
            neg_redun_set2 = set(infr.neg_redun_metagraph.neighbors(nid2))
        else:
            neg_redun_set2 = set()

        # The number of negative edges needed before we place this edge is
        # the number of PCCs that each PCC doesn't have a negative edge to yet
        n_neg_need1 = n_ccs - len(neg_redun_set1) - 1
        n_neg_need2 = n_ccs - len(neg_redun_set2) - 1
        n_neg_need_before = n_neg_need1 + n_neg_need2

        # After we join them we take the union of their negative redundancy
        # (really we should check if it changes after)
        # and this is now the new number of negative edges that would be needed
        neg_redun_after = neg_redun_set1.union(neg_redun_set2) - {nid1, nid2}
        n_neg_need_after = (n_ccs - 2) - len(neg_redun_after)

        neg_entropy = n_neg_need_before - n_neg_need_after  # NOQA


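def _demo_entropy_potential_arithmetic():
    """
    Hedged sketch (not part of the original module): reproduces the
    negative-redundancy arithmetic from entropy_potential with made-up
    numbers, assuming 5 PCCs and no existing neg-redundant neighbors.
    """
    n_ccs = 5
    neg_redun_set1 = set()
    neg_redun_set2 = set()
    # Before the merge, each separate PCC needs a negative edge to every
    # other PCC except itself
    n_neg_need_before = (n_ccs - len(neg_redun_set1) - 1) + (
        n_ccs - len(neg_redun_set2) - 1
    )
    # After the merge, the combined PCC only needs edges to the remaining PCCs
    neg_redun_after = neg_redun_set1.union(neg_redun_set2)
    n_neg_need_after = (n_ccs - 2) - len(neg_redun_after)
    # 8 needed before, 3 after: the positive decision obviates 5 neg edges
    assert n_neg_need_before - n_neg_need_after == 5

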
def _find_good_match_states(infr, ibs, edges):
    pos_edges = list(infr.pos_graph.edges())
    timedelta = ibs.get_annot_pair_timedelta(*zip(*edges))
    edges = ut.take(pos_edges, ut.argsort(timedelta))[::-1]
    infr.qt_edge_reviewer(edges)

    neg_edges = ut.shuffle(list(infr.neg_graph.edges()))
    infr.qt_edge_reviewer(neg_edges)

    if infr.incomp_graph.number_of_edges() > 0:
        incmp_edges = list(infr.incomp_graph.edges())
        if False:
            ibs = infr.ibs
            # a1, a2 = map(ibs.annots, zip(*incmp_edges))
            # q1 = np.array(ut.replace_nones(a1.qual, np.nan))
            # q2 = np.array(ut.replace_nones(a2.qual, np.nan))

            # edges = ut.compress(incmp_edges,
            #                     ((q1 > 3) | np.isnan(q1)) &
            #                     ((q2 > 3) | np.isnan(q2)))

            # a = ibs.annots(asarray=True)
            # flags = [t is not None and 'right' == t for t in a.viewpoint_code]
            # r = a.compress(flags)
            # flags = [q is not None and q > 4 for q in r.qual]

            rights = ibs.filter_annots_general(
                view='right',
                minqual='excellent',
                require_quality=True,
                require_viewpoint=True,
            )
            lefts = ibs.filter_annots_general(
                view='left',
                minqual='excellent',
                require_quality=True,
                require_viewpoint=True,
            )

            if False:
                edges = list(infr._make_rankings(3197, rights))
                infr.qt_edge_reviewer(edges)

            edges = list(ut.random_product((rights, lefts), num=10, rng=0))
            infr.qt_edge_reviewer(edges)

        for edge in incmp_edges:
            infr._make_matches_from([edge])[0]
            # infr._debug_edge_gt(edge)


def prepare_cdfs(cdfs, labels):
    cdfs = vt.pad_vstack(cdfs, fill_value=1)
    # Sort so the best is on top
    sortx = np.lexsort(cdfs.T[::-1])[::-1]
    cdfs = cdfs[sortx]
    labels = ut.take(labels, sortx)
    return cdfs, labels


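def _demo_prepare_cdfs():
    """
    Hedged usage sketch for prepare_cdfs (illustrative arrays only, not from
    any experiment): ragged CMC curves are padded with 1.0 so they align,
    then rows are sorted so the best curve (highest rank-1 value) is first.
    """
    cdfs = [
        np.array([0.55, 0.70, 0.80]),
        np.array([0.65, 0.78]),  # shorter curve; padded with 1.0 at rank 3
    ]
    cdfs, labels = prepare_cdfs(cdfs, ['ranking', 'rank+clf'])
    # The curve with the higher rank-1 score comes out on top
    assert labels[0] == 'rank+clf'
    return cdfs, labels

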
def plot_cmcs(cdfs, labels, fnum=1, pnum=(1, 1, 1), ymin=0.4):
    cdfs, labels = prepare_cdfs(cdfs, labels)
    # Truncate to 20 ranks
    num_ranks = min(cdfs.shape[-1], 20)
    xdata = np.arange(1, num_ranks + 1)
    cdfs_trunc = cdfs[:, 0:num_ranks]
    label_list = [
        '%6.3f%% - %s' % (cdf[0] * 100, lbl) for cdf, lbl in zip(cdfs_trunc, labels)
    ]

    # ymin = .4
    num_yticks = (10 - int(ymin * 10)) + 1

    pt.multi_plot(
        xdata,
        cdfs_trunc,
        label_list=label_list,
        xlabel='rank',
        ylabel='match probability',
        use_legend=True,
        legend_loc='lower right',
        num_yticks=num_yticks,
        ymax=1,
        ymin=ymin,
        ypad=0.005,
        xmin=0.9,
        num_xticks=5,
        xmax=num_ranks + 1 - 0.5,
        pnum=pnum,
        fnum=fnum,
        rcParams=TMP_RC,
    )
    return pt.gcf()


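def _demo_plot_cmcs():
    """
    Hedged usage sketch for plot_cmcs (synthetic curves only; assumes a
    working matplotlib backend is configured).
    """
    cdfs = [
        np.linspace(0.55, 1.0, 30),  # stand-in for an LNBNN ranking curve
        np.linspace(0.65, 1.0, 30),  # stand-in for a rank+clf curve
    ]
    fig = plot_cmcs(cdfs, ['ranking', 'rank+clf'], fnum=1, ymin=0.4)
    return fig

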
@ut.reloadable_class
class VerifierExpt(DBInputs):
    """
    Collect data from experiments to visualize

    python -m wbia VerifierExpt.measure all PZ_Master1.GZ_Master1,GIRM_Master1,MantaMatcher,RotanTurtles,humpbacks_fb,LF_ALL
    python -m wbia VerifierExpt.measure all GIRM_Master1,PZ_Master1,LF_ALL
    python -m wbia VerifierExpt.measure all LF_ALL
    python -m wbia VerifierExpt.measure all PZ_Master1

    python -m wbia VerifierExpt.measure all MantaMatcher
    python -m wbia VerifierExpt.draw all MantaMatcher

    python -m wbia VerifierExpt.draw rerank PZ_Master1

    python -m wbia VerifierExpt.measure all RotanTurtles
    python -m wbia VerifierExpt.draw all RotanTurtles

    Ignore:
        >>> from wbia.scripts.postdoc import *
        >>> fpath = ut.glob(ut.truepath('~/Desktop/mtest_plots'), '*.pkl')[0]
        >>> self = ut.load_data(fpath)
    """

    # base_dpath = ut.truepath('~/Desktop/pair_expts')
    base_dpath = ut.truepath('~/latex/crall-iccvw-2017/figures')

    agg_dbnames = [
        'PZ_Master1',
        'GZ_Master1',
        # 'LF_ALL',
        'MantaMatcher',
        'RotanTurtles',
        'humpbacks_fb',
        'GIRM_Master1',
    ]

    task_nice_lookup = {
        'match_state': const.EVIDENCE_DECISION.CODE_TO_NICE,
        'photobomb_state': {'pb': 'Photobomb', 'notpb': 'Not Photobomb'},
    }

    def _setup(self, quick=False):
        r"""
        CommandLine:
            python -m wbia VerifierExpt._setup --db GZ_Master1

            python -m wbia VerifierExpt._setup --db PZ_Master1 --eval
            python -m wbia VerifierExpt._setup --db PZ_MTEST
            python -m wbia VerifierExpt._setup --db PZ_PB_RF_TRAIN

            python -m wbia VerifierExpt.measure_all --db PZ_PB_RF_TRAIN

            python -m wbia VerifierExpt.measure all GZ_Master1
            python -m wbia VerifierExpt.measure all RotanTurtles --show

        Example:
            >>> # DISABLE_DOCTEST
            >>> from wbia.scripts.postdoc import *
            >>> dbname = ut.get_argval('--db', default='GZ_Master1')
            >>> self = VerifierExpt(dbname)
            >>> self._setup()

        Ignore:
            from wbia.scripts.postdoc import *
            self = VerifierExpt('PZ_Master1')

            from wbia.scripts.postdoc import *
            self = VerifierExpt('PZ_PB_RF_TRAIN')

            from wbia.scripts.postdoc import *
            self = VerifierExpt('LF_ALL')

            self = VerifierExpt('RotanTurtles')

            task = pblm.samples.subtasks['match_state']
            ind_df = task.indicator_df
            dist = ibs.get_annotedge_viewdist(ind_df.index.tolist())
            np.all(ind_df[dist > 1]['notcomp'])

            self.ibs.print_annot_stats(aids, prefix='P')
        """
        self._precollect()
        logger.info('VerifierExpt _setup()')

        ibs = self.ibs

        aids = self.aids_pool

        # pblm = vsone.OneVsOneProblem.from_aids(ibs, aids, sample_method='random')
        pblm = vsone.OneVsOneProblem.from_aids(
            ibs,
            aids,
            sample_method='lnbnn+random',
            # sample_method='random',
            n_splits=10,
        )

        data_key = 'learn(sum)'  # tests without global features
        # data_key = 'learn(sum,glob)'  # tests with global features
        # data_key = pblm.default_data_key  # same as learn(sum,glob)

        clf_key = pblm.default_clf_key

        pblm.eval_task_keys = ['match_state']

        # test with and without globals
        pblm.eval_data_keys = ['learn(sum)', 'learn(sum,glob)']
        # pblm.eval_data_keys = [data_key]

        pblm.eval_clf_keys = [clf_key]

        ibs = pblm.infr.ibs
        # pblm.samples.print_info()

        species_code = ibs.get_database_species(pblm.infr.aids)[0]
        if species_code == 'zebra_plains':
            species = 'Plains Zebras'
        elif species_code == 'zebra_grevys':
            species = "Grévy's Zebras"
        else:
            species = species_code

        self.pblm = pblm
        self.species = species
        self.data_key = data_key
        self.clf_key = clf_key

        if quick:
            return

        pblm.setup_evaluation(with_simple=True)
        pblm.report_evaluation()
        self.eval_task_keys = pblm.eval_task_keys

        cfg_prefix = '{}'.format(len(pblm.samples))
        config = pblm.hyper_params
        self._setup_links(cfg_prefix, config)
        logger.info('Finished setup')

    @classmethod
    def agg_dbstats(cls):
        """
        CommandLine:
            python -m wbia VerifierExpt.agg_dbstats
            python -m wbia VerifierExpt.measure_dbstats

        Example:
            >>> # DISABLE_DOCTEST
            >>> from wbia.scripts.postdoc import *  # NOQA
            >>> result = VerifierExpt.agg_dbstats()
            >>> print(result)
        """
        dfs = []
        for dbname in cls.agg_dbnames:
            self = cls(dbname)
            info = self.ensure_results('dbstats', nocompute=False)
            sample_info = self.ensure_results('sample_info', nocompute=False)

            # info = self.measure_dbstats()
            outinfo = info['outinfo']

            task = sample_info['subtasks']['match_state']
            y_ind = task.indicator_df
            outinfo['Positive'] = (y_ind[POSTV]).sum()
            outinfo['Negative'] = (y_ind[NEGTV]).sum()
            outinfo['Incomparable'] = (y_ind[INCMP]).sum()
            if outinfo['Database'] == 'mantas':
                outinfo['Database'] = 'manta rays'
            dfs.append(outinfo)
            # labels.append(self.species_nice.capitalize())

        df = pd.DataFrame(dfs)
        logger.info('df =\n{!r}'.format(df))
        df = df.set_index('Database')
        df.index.name = None

        tabular = Tabular(df, colfmt='numeric')
        tabular.theadify = 16
        enc_text = tabular.as_tabular()
        logger.info(enc_text)

        ut.write_to(join(cls.base_dpath, 'agg-dbstats.tex'), enc_text)
        _ = ut.render_latex(
            enc_text,
            dpath=self.base_dpath,
            fname='agg-dbstats',
            preamb_extra=['\\usepackage{makecell}'],
        )
        # ut.startfile(_)

    @classmethod
    def agg_results(cls, task_key):
        """
        python -m wbia VerifierExpt.agg_results
        python -m wbia VerifierExpt.agg_results --link link-paper-final GZ_Master1,LF_ALL,MantaMatcher,RotanTurtles,humpbacks_fb,GIRM_Master1

        Example:
            >>> # DISABLE_DOCTEST
            >>> from wbia.scripts.postdoc import *  # NOQA
            >>> task_key = 'match_state'
            >>> result = VerifierExpt.agg_results(task_key)
            >>> print(result)
        """
        cls.agg_dbstats()
        dbnames = cls.agg_dbnames

        all_results = ut.odict([])
        for dbname in cls.agg_dbnames:
            self = cls(dbname)
            info = self.ensure_results('all')
            all_results[dbname] = info

        rerank_results = ut.odict([])
        for dbname in cls.agg_dbnames:
            self = cls(dbname)
            info = self.ensure_results('rerank')
            rerank_results[dbname] = info

        rank_curves = ub.AutoOrderedDict()

        rank1_cmc_table = pd.DataFrame(columns=[LNBNN, CLF])
        rank5_cmc_table = pd.DataFrame(columns=[LNBNN, CLF])

        n_dbs = len(all_results)
        color_cycle = mpl.rcParams['axes.prop_cycle'].by_key()['color'][:n_dbs]
        color_cycle = ['r', 'b', 'purple', 'orange', 'deeppink', 'g']
        markers = pt.distinct_markers(n_dbs)
        dbprops = ub.AutoDict()
        for n, dbname in enumerate(dbnames):
            dbprops[dbname]['color'] = color_cycle[n]
            dbprops[dbname]['marker'] = markers[n]

        def highlight_metric(metric, data1, data2):
            # Highlight the bigger one for each metric
            for d1, d2 in it.permutations([data1, data2], 2):
                text = '{:.3f}'.format(d1[metric])
                if d1[metric] >= d2[metric]:
                    d1[metric + '_tex'] = '\\mathbf{' + text + '}'
                    d1[metric + '_text'] = text + '*'
                else:
                    d1[metric + '_tex'] = text
                    d1[metric + '_text'] = text

        for dbname in dbnames:
            results = all_results[dbname]
            data_key = results['data_key']
            clf_key = results['clf_key']

            lnbnn_data = results['lnbnn_data']
            task_combo_res = results['task_combo_res']
            res = task_combo_res[task_key][clf_key][data_key]

            nice = dbname_to_species_nice(dbname)

            # ranking results
            results = rerank_results[dbname]
            cdfs, infos = list(zip(*results))
            lnbnn_cdf, clf_cdf = cdfs
            cdfs = {
                CLF: clf_cdf,
                LNBNN: lnbnn_cdf,
            }

            rank1_cmc_table.loc[nice, LNBNN] = lnbnn_cdf[0]
            rank1_cmc_table.loc[nice, CLF] = clf_cdf[0]

            rank5_cmc_table.loc[nice, LNBNN] = lnbnn_cdf[4]
            rank5_cmc_table.loc[nice, CLF] = clf_cdf[4]

            # Check the ROC for only things in the top of the LNBNN ranked lists
            # nums = [1, 2, 3, 4, 5, 10, 20, np.inf]
            nums = [1, 5, np.inf]
            for num in nums:
                ranks = lnbnn_data['rank_lnbnn_1vM'].values
                sub_data = lnbnn_data[ranks <= num]
                scores = sub_data['score_lnbnn_1vM'].values
                y = sub_data[POSTV].values
                probs = res.probs_df[POSTV].loc[sub_data.index].values

                cfsm_vsm = vt.ConfusionMetrics().fit(scores, y)
                cfsm_clf = vt.ConfusionMetrics().fit(probs, y)

                algo_confusions = {LNBNN: cfsm_vsm, CLF: cfsm_clf}
                datas = []
                for algo in {LNBNN, CLF}:
                    cfms = algo_confusions[algo]
                    data = {
                        'dbname': dbname,
                        'species': nice,
                        'fpr': cfms.fpr,
                        'tpr': cfms.tpr,
                        'auc': cfms.auc,
                        'cmc0': cdfs[algo][0],
                        'cmc': cdfs[algo],
                        'color': dbprops[dbname]['color'],
                        'marker': dbprops[dbname]['marker'],
                        'tpr@fpr=0': cfms.get_metric_at_metric(
                            'tpr', 'fpr', 0, tiebreaker='minthresh'
                        ),
                        'thresh@fpr=0': cfms.get_metric_at_metric(
                            'thresh', 'fpr', 0, tiebreaker='minthresh'
                        ),
                    }
                    rank_curves[num][algo][dbname] = data
                    datas.append(data)

                # Highlight the bigger one for each metric
                highlight_metric('auc', *datas)
                highlight_metric('tpr@fpr=0', *datas)
                highlight_metric('cmc0', *datas)

        rank_auc_tables = ut.ddict(lambda: pd.DataFrame(columns=[LNBNN, CLF]))
        rank_tpr_tables = ut.ddict(lambda: pd.DataFrame(columns=[LNBNN, CLF]))
        rank_tpr_thresh_tables = ut.ddict(lambda: pd.DataFrame(columns=[LNBNN, CLF]))
        for num in rank_curves.keys():
            rank_auc_df = rank_auc_tables[num]
            rank_auc_df.index.name = 'AUC@rank<={}'.format(num)
            rank_tpr_df = rank_tpr_tables[num]
            rank_tpr_df.index.name = 'tpr@fpr=0&rank<={}'.format(num)
            rank_thesh_df = rank_tpr_thresh_tables[num]
            rank_thesh_df.index.name = 'thresh@fpr=0&rank<={}'.format(num)
            for algo in rank_curves[num].keys():
                for dbname in rank_curves[num][algo].keys():
                    data = rank_curves[num][algo][dbname]
                    nice = data['species']
                    rank_auc_df.loc[nice, algo] = data['auc']
                    rank_tpr_df.loc[nice, algo] = data['tpr@fpr=0']
                    rank_thesh_df.loc[nice, algo] = data['thresh@fpr=0']

        from utool.experimental.pandas_highlight import to_string_monkey

        nums = [1]
        for rank in nums:
            logger.info('-----')
            logger.info('AUC at rank = {!r}'.format(rank))
            rank_auc_df = rank_auc_tables[rank]
            logger.info(to_string_monkey(rank_auc_df, 'all'))

        logger.info('===============')
        for rank in nums:
            logger.info('-----')
            logger.info('TPR at rank = {!r}'.format(rank))
            rank_tpr_df = rank_tpr_tables[rank]
            logger.info(to_string_monkey(rank_tpr_df, 'all'))

        def _bf_best(df):
            df = df.copy()
            for rx in range(len(df)):
                col = df.iloc[rx]
                for cx in ut.argmax(col.values, multi=True):
                    val = df.iloc[rx, cx]
                    df.iloc[rx, cx] = '\\mathbf{{{:.3f}}}'.format(val)
            return df

        if True:
            # Tables
            rank1_auc_table = rank_auc_tables[1]
            rank1_tpr_table = rank_tpr_tables[1]
            # all_stats = pd.concat(ut.emap(_bf_best, [auc_table, rank1_cmc_table, rank5_cmc_table]), axis=1)
            column_parts = [
                ('Rank $1$ AUC', rank1_auc_table),
                ('Rank $1$ TPR', rank1_tpr_table),
                ('Pos. @ Rank $1$', rank1_cmc_table),
            ]
            all_stats = pd.concat(
                ut.emap(_bf_best, ut.take_column(column_parts, 1)), axis=1
            )
            all_stats.index.name = None
            colfmt = 'l|' + '|'.join(['rr'] * len(column_parts))
            multi_header = (
                [None]
                + [(2, 'c|', name) for name in ut.take_column(column_parts, 0)[0:-1]]
                + [(2, 'c', name) for name in ut.take_column(column_parts, 0)[-1:]]
            )

            from wbia.scripts import _thesis_helpers

            tabular = _thesis_helpers.Tabular(all_stats, colfmt=colfmt, escape=False)
            tabular.add_multicolumn_header(multi_header)
            tabular.precision = 3

            tex_text = tabular.as_tabular()

            # HACKS
            import re

            num_pat = ut.named_field('num', r'[0-9]*\.?[0-9]*')
            tex_text = re.sub(
                re.escape('\\mathbf{$') + num_pat + re.escape('$}'),
                '$\\mathbf{' + ut.bref_field('num') + '}$',
                tex_text,
            )
            logger.info(tex_text)
            # tex_text = tex_text.replace('\\mathbf{$', '$\\mathbf{')
            # tex_text = tex_text.replace('$}', '}$')
            ut.write_to(join(cls.base_dpath, 'agg-results-all.tex'), tex_text)
            _ = ut.render_latex(
                tex_text,
                dpath=cls.base_dpath,
                fname='agg-results-all',
                preamb_extra=['\\usepackage{makecell}'],
            )
            # ut.startfile(_)

        if True:
            # Tables
            rank1_auc_table = rank_auc_tables[1]
            rank1_tpr_table = rank_tpr_tables[1]
            logger.info(
                '\nrank1_auc_table =\n{}'.format(to_string_monkey(rank1_auc_table, 'all'))
            )
            logger.info(
                '\nrank1_tpr_table =\n{}'.format(to_string_monkey(rank1_tpr_table, 'all'))
            )
            logger.info(
                '\nrank1_cmc_table =\n{}'.format(to_string_monkey(rank1_cmc_table, 'all'))
            )

            # Tables
            rank1_auc_table = rank_auc_tables[1]
            rank1_tpr_table = rank_tpr_tables[1]
            # all_stats = pd.concat(ut.emap(_bf_best, [auc_table, rank1_cmc_table, rank5_cmc_table]), axis=1)
            column_parts = [
                ('Rank $1$ AUC', rank1_auc_table),
                # ('Rank $1$ TPR', rank1_tpr_table),
                ('Pos. @ Rank $1$', rank1_cmc_table),
            ]
            all_stats = pd.concat(
                ut.emap(_bf_best, ut.take_column(column_parts, 1)), axis=1
            )
            all_stats.index.name = None
            colfmt = 'l|' + '|'.join(['rr'] * len(column_parts))
            multi_header = (
                [None]
                + [(2, 'c|', name) for name in ut.take_column(column_parts, 0)[0:-1]]
                + [(2, 'c', name) for name in ut.take_column(column_parts, 0)[-1:]]
            )

            from wbia.scripts import _thesis_helpers

            tabular = _thesis_helpers.Tabular(all_stats, colfmt=colfmt, escape=False)
            tabular.add_multicolumn_header(multi_header)
            tabular.precision = 3

            tex_text = tabular.as_tabular()

            # HACKS
            import re

            num_pat = ut.named_field('num', r'[0-9]*\.?[0-9]*')
            tex_text = re.sub(
                re.escape('\\mathbf{$') + num_pat + re.escape('$}'),
                '$\\mathbf{' + ut.bref_field('num') + '}$',
                tex_text,
            )
            logger.info(tex_text)
            # tex_text = tex_text.replace('\\mathbf{$', '$\\mathbf{')
            # tex_text = tex_text.replace('$}', '}$')
            ut.write_to(join(cls.base_dpath, 'agg-results.tex'), tex_text)
            _ = ut.render_latex(
                tex_text,
                dpath=cls.base_dpath,
                fname='agg-results',
                preamb_extra=['\\usepackage{makecell}'],
            )
            # ut.startfile(_)

        method = 2

        if method == 2:
            mpl.rcParams['text.usetex'] = True
            mpl.rcParams['text.latex.unicode'] = True
            # mpl.rcParams['axes.labelsize'] = 12
            mpl.rcParams['legend.fontsize'] = 12
            mpl.rcParams['xtick.color'] = 'k'
            mpl.rcParams['ytick.color'] = 'k'
            mpl.rcParams['axes.labelcolor'] = 'k'
            # mpl.rcParams['text.color'] = 'k'

            nums = [1, np.inf]
            nums = [1]
            for num in nums:
                chunked_dbnames = list(ub.chunks(dbnames, 2))
                for fnum, dbname_chunk in enumerate(chunked_dbnames, start=1):
                    fig = pt.figure(fnum=fnum)  # NOQA
                    fig.clf()
                    ax = pt.gca()
                    for dbname in dbname_chunk:
                        data1 = rank_curves[num][CLF][dbname]
                        data2 = rank_curves[num][LNBNN][dbname]

                        data1['label'] = 'TPR=${tpr}$ {algo} {species}'.format(
                            algo=CLF, tpr=data1['tpr@fpr=0_tex'], species=data1['species']
                        )
                        data1['ls'] = '-'
                        data1['chunk_marker'] = '^'
                        data1['color'] = dbprops[dbname]['color']

                        data2['label'] = 'TPR=${tpr}$ {algo} {species}'.format(
                            algo=LNBNN,
                            tpr=data2['tpr@fpr=0_tex'],
                            species=data2['species'],
                        )
                        data2['ls'] = '--'
                        data2['chunk_marker'] = 'v'
                        data2['color'] = dbprops[dbname]['color']

                        for d in [data1, data2]:
                            ax.plot(
                                d['fpr'], d['tpr'], d['ls'], color=d['color'], zorder=10
                            )

                        for d in [data1, data2]:
                            ax.plot(
                                0,
                                d['tpr@fpr=0'],
                                d['ls'],
                                marker=d['chunk_marker'],
                                markeredgecolor='k',
                                markersize=8,
                                # fillstyle='none',
                                color=d['color'],
                                label=d['label'],
                                zorder=100,
                            )

                    ax.set_xlabel('false positive rate')
                    ax.set_ylabel('true positive rate')
                    ax.set_ylim(0, 1)
                    ax.set_xlim(-0.05, 0.5)
                    # ax.set_title('ROC with ranks $<= {}$'.format(num))
                    ax.legend(loc='lower right')
                    pt.adjust_subplots(top=0.8, bottom=0.2, left=0.12, right=0.9)
                    fig.set_size_inches([W * 0.7, H])
                    fname = 'agg_roc_rank_{}_chunk_{}_{}.png'.format(num, fnum, task_key)
                    fig_fpath = join(str(cls.base_dpath), fname)
                    vt.imwrite(fig_fpath, pt.render_figure_to_image(fig, dpi=DPI))

            chunked_dbnames = list(ub.chunks(dbnames, 2))
            for fnum, dbname_chunk in enumerate(chunked_dbnames, start=1):
                fig = pt.figure(fnum=fnum)  # NOQA
                fig.clf()
                ax = pt.gca()
                for dbname in dbname_chunk:
                    data1 = rank_curves[num][CLF][dbname]
                    data2 = rank_curves[num][LNBNN][dbname]
                    data1['label'] = 'pos@rank1=${cmc0}$ {algo} {species}'.format(
                        algo=CLF, cmc0=data1['cmc0_tex'], species=data1['species']
                    )
                    data1['ls'] = '-'
                    data1['chunk_marker'] = '^'
                    data1['color'] = dbprops[dbname]['color']

                    data2['label'] = 'pos@rank1=${cmc0}$ {algo} {species}'.format(
                        algo=LNBNN, cmc0=data2['cmc0_tex'], species=data2['species']
                    )
                    data2['ls'] = '--'
                    data2['chunk_marker'] = 'v'
                    data2['color'] = dbprops[dbname]['color']

                    for d in [data1, data2]:
                        ax.plot(d['fpr'], d['tpr'], d['ls'], color=d['color'])

                    for d in [data1, data2]:
                        ax.plot(
                            d['cmc'],
                            d['ls'],
                            # marker=d['chunk_marker'],
                            # markeredgecolor='k',
                            # markersize=8,
                            # fillstyle='none',
                            color=d['color'],
                            label=d['label'],
                        )

                ax.set_xlabel('rank')
                ax.set_ylabel('match probability')
                ax.set_ylim(0, 1)
                ax.set_xlim(1, 20)
                ax.set_xticks([1, 5, 10, 15, 20])
                # ax.set_title('ROC with ranks $<= {}$'.format(num))
                ax.legend(loc='lower right')
                pt.adjust_subplots(top=0.8, bottom=0.2, left=0.12, right=0.9)
                fig.set_size_inches([W * 0.7, H])
                fname = 'agg_cmc_chunk_{}_{}.png'.format(fnum, task_key)
                fig_fpath = join(str(cls.base_dpath), fname)
                vt.imwrite(fig_fpath, pt.render_figure_to_image(fig, dpi=DPI))

        if method == 1:
            # Does going from rank 1 to rank inf generally improve deltas?
            # -rank_tpr_tables[np.inf].diff(axis=1) - -rank_tpr_tables[1].diff(axis=1)
            mpl.rcParams['text.usetex'] = True
            mpl.rcParams['text.latex.unicode'] = True
            # mpl.rcParams['axes.labelsize'] = 12
            mpl.rcParams['legend.fontsize'] = 12
            mpl.rcParams['xtick.color'] = 'k'
            mpl.rcParams['ytick.color'] = 'k'
            mpl.rcParams['axes.labelcolor'] = 'k'
            # mpl.rcParams['text.color'] = 'k'

            def method1_roc(roc_curves, algo, other):
                ax = pt.gca()
                for dbname in dbnames:
                    data = roc_curves[algo][dbname]
                    ax.plot(data['fpr'], data['tpr'], color=data['color'])

                for dbname in dbnames:
                    data = roc_curves[algo][dbname]
                    other_data = roc_curves[other][dbname]
                    other_tpr = other_data['tpr@fpr=0']
                    species = data['species']
                    tpr = data['tpr@fpr=0']
                    tpr_text = '{:.3f}'.format(tpr)
                    if tpr >= other_tpr:
                        if mpl.rcParams['text.usetex']:
                            tpr_text = '\\mathbf{' + tpr_text + '}'
                        else:
                            tpr_text = tpr_text + '*'
                    label = 'TPR=${tpr}$ {species}'.format(tpr=tpr_text, species=species)
                    ax.plot(
                        0,
                        data['tpr@fpr=0'],
                        marker=data['marker'],
                        label=label,
                        color=data['color'],
                    )

                if algo:
                    algo = algo.rstrip() + ' '
                algo = ''
                ax.set_xlabel(algo + 'false positive rate')
                ax.set_ylabel('true positive rate')
                ax.set_ylim(0, 1)
                ax.set_xlim(-0.005, 0.5)
                # ax.set_title('%s ROC for %s' % (target_class.title(), self.species))
                ax.legend(loc='lower right')
                pt.adjust_subplots(top=0.8, bottom=0.2, left=0.12, right=0.9)
                fig.set_size_inches([W * 0.7, H])

            nums = [1, np.inf]
            # nums = [1]
            for num in nums:
                algos = {CLF, LNBNN}
                for fnum, algo in enumerate(algos, start=1):
                    roc_curves = rank_curves[num]
                    other = next(iter(algos - {algo}))
                    fig = pt.figure(fnum=fnum)  # NOQA
                    method1_roc(roc_curves, algo, other)
                    fname = 'agg_roc_rank_{}_{}_{}.png'.format(num, algo, task_key)
                    fig_fpath = join(str(cls.base_dpath), fname)
                    vt.imwrite(fig_fpath, pt.render_figure_to_image(fig, dpi=DPI))

            # -------------
            mpl.rcParams['text.usetex'] = True
            mpl.rcParams['text.latex.unicode'] = True
            mpl.rcParams['xtick.color'] = 'k'
            mpl.rcParams['ytick.color'] = 'k'
            mpl.rcParams['axes.labelcolor'] = 'k'
            mpl.rcParams['text.color'] = 'k'

            def method1_cmc(cmc_curves):
                ax = pt.gca()
                color_cycle = mpl.rcParams['axes.prop_cycle'].by_key()['color']
                markers = pt.distinct_markers(len(cmc_curves))
                for data, marker, color in zip(cmc_curves.values(), markers, color_cycle):
                    species = data['species']
                    if mpl.rcParams['text.usetex']:
                        cmc0_text = data['cmc0_tex']
                        label = 'pos@rank1=${}$ {species}'.format(
                            cmc0_text, species=species
                        )
                    else:
                        cmc0_text = data['cmc0_text']
                        label = 'pos@rank1={} {species}'.format(
                            cmc0_text, species=species
                        )
                    ranks = np.arange(1, len(data['cmc']) + 1)
                    ax.plot(ranks, data['cmc'], marker=marker, color=color, label=label)

                ax.set_xlabel('rank')
                ax.set_ylabel('match probability')
                ax.set_ylim(0, 1)
                ax.set_xlim(1, 20)
                ax.set_xticks([1, 5, 10, 15, 20])
                # ax.set_title('%s ROC for %s' % (target_class.title(), self.species))
                ax.legend(loc='lower right')
                pt.adjust_subplots(top=0.8, bottom=0.2, left=0.12, right=0.9)
                fig.set_size_inches([W * 0.7, H])

            fig = pt.figure(fnum=1)  # NOQA
            # num doesn't actually matter here
            num = 1
            cmc_curves = rank_curves[num][CLF]
            method1_cmc(cmc_curves)
            fname = 'agg_cmc_clf_{}.png'.format(task_key)
            fig_fpath = join(str(cls.base_dpath), fname)
            vt.imwrite(fig_fpath, pt.render_figure_to_image(fig, dpi=DPI))

            fig = pt.figure(fnum=2)  # NOQA
            cmc_curves = rank_curves[num][LNBNN]
            method1_cmc(cmc_curves)
            fname = 'agg_cmc_lnbnn_{}.png'.format(task_key)
            fig_fpath = join(str(cls.base_dpath), fname)
            vt.imwrite(fig_fpath, pt.render_figure_to_image(fig, dpi=DPI))

        if True:
            # Agg metrics
            agg_y_pred = []
            agg_y_true = []
            agg_sample_weight = []
            agg_class_names = None
            for dbname, results in all_results.items():
                task_combo_res = results['task_combo_res']
                res = task_combo_res[task_key][clf_key][data_key]
                res.augment_if_needed()
                y_true = res.y_test_enc
                incmp_enc = ut.aslist(res.class_names).index(INCMP)
                if sum(y_true == incmp_enc) < 500:
                    continue

                # Find auto thresholds
                logger.info('-----')
                logger.info('dbname = {!r}'.format(dbname))
                for k in range(res.y_test_bin.shape[1]):
                    class_k_truth = res.y_test_bin.T[k]
                    class_k_probs = res.clf_probs.T[k]
                    cfms_ovr = vt.ConfusionMetrics().fit(class_k_probs, class_k_truth)
                    # auc = sklearn.metrics.roc_auc_score(class_k_truth, class_k_probs)
                    state = res.class_names[k]
                    # for state, cfms_ovr in res.confusions_ovr():
                    if state == POSTV:
                        continue
                    tpr = cfms_ovr.get_metric_at_metric(
                        'tpr', 'fpr', 0, tiebreaker='minthresh'
                    )
                    # thresh = cfsm_scores_rank.get_metric_at_metric(
                    #     'thresh', 'fpr', 0, tiebreaker='minthresh')
                    logger.info('state = {!r}'.format(state))
                    logger.info('tpr = {:.3f}'.format(tpr))
                    logger.info('+--')
                logger.info('-----')

                # aggregate results
                y_pred = res.clf_probs.argmax(axis=1)
                agg_y_true.extend(y_true.tolist())
                agg_y_pred.extend(y_pred.tolist())
                agg_sample_weight.extend(res.sample_weight.tolist())
                assert (
                    agg_class_names is None or agg_class_names == res.class_names
                ), 'classes are inconsistent'
                agg_class_names = res.class_names

            from wbia.algo.verif import sklearn_utils

            agg_report = sklearn_utils.classification_report2(
                agg_y_true, agg_y_pred, agg_class_names, agg_sample_weight, verbose=False
            )
            metric_df = agg_report['metrics']
            confusion_df = agg_report['confusion']
            # multiclass_mcc = agg_report['mcc']
            # df.loc['combined', 'MCC'] = multiclass_mcc

            multiclass_mcc = agg_report['mcc']
            metric_df.loc['combined', 'mcc'] = multiclass_mcc

            logger.info(metric_df)
            logger.info(confusion_df)

            dpath = str(self.base_dpath)
            confusion_fname = 'agg_confusion_{}'.format(task_key)
            metrics_fname = 'agg_eval_metrics_{}'.format(task_key)

            # df = self.task_confusion[task_key]
            df = confusion_df.copy()
            df = df.rename_axis(self.task_nice_lookup[task_key], 0)
            df = df.rename_axis(self.task_nice_lookup[task_key], 1)
            df.columns.name = None
            df.index.name = 'Real'

            colfmt = '|l|' + 'r' * (len(df) - 1) + '|l|'
            tabular = Tabular(df, colfmt=colfmt, hline=True)
            tabular.groupxs = [list(range(len(df) - 1)), [len(df) - 1]]
            tabular.add_multicolumn_header([None, (3, 'c|', 'Predicted'), None])

            latex_str = tabular.as_tabular()

            sum_pred = df.index[-1]
            sum_real = df.columns[-1]
            latex_str = latex_str.replace(sum_pred, r'$\sum$ predicted')
            latex_str = latex_str.replace(sum_real, r'$\sum$ real')
            confusion_tex = ut.align(latex_str, '&', pos=None)
            logger.info(confusion_tex)

            ut.render_latex(confusion_tex, dpath=self.base_dpath, fname=confusion_fname)

            df = metric_df
            # df = self.task_metrics[task_key]
            df = df.rename_axis(self.task_nice_lookup[task_key], 0)
            df = df.rename_axis({'mcc': 'MCC'}, 1)
            df = df.rename_axis({'combined': 'Combined'}, 1)
            df = df.drop(['markedness', 'bookmaker', 'fpr'], axis=1)
            df.index.name = None
            df.columns.name = None
            df['support'] = df['support'].astype(np.int)
            df.columns = ut.emap(upper_one, df.columns)

            import re

            tabular = Tabular(df, colfmt='numeric')
            top, header, mid, bot = tabular.as_parts()
            lines = mid[0].split('\n')
            newmid = [lines[0:-1], lines[-1:]]
            tabular.parts = (top, header, newmid, bot)
            latex_str = tabular.as_tabular()
            latex_str = re.sub(' -0.00 ', ' 0.00 ', latex_str)
            metrics_tex = latex_str
            logger.info(metrics_tex)

            confusion_tex = confusion_tex.replace('Incomparable', 'Incomp.')
            confusion_tex = confusion_tex.replace('predicted', 'pred')
            metrics_tex = metrics_tex.replace('Incomparable', 'Incomp.')

            ut.write_to(join(dpath, confusion_fname + '.tex'), confusion_tex)
            ut.write_to(join(dpath, metrics_fname + '.tex'), metrics_tex)

            ut.render_latex(confusion_tex, dpath=dpath, fname=confusion_fname)
            ut.render_latex(metrics_tex, dpath=dpath, fname=metrics_fname)

        old_cmc = rank1_cmc_table[LNBNN]
        new_cmc = rank1_cmc_table[CLF]
        cmc_diff = new_cmc - old_cmc
        cmc_change = cmc_diff / old_cmc
        improved = cmc_diff > 0
        logger.info(
            '{} / {} datasets saw CMC improvement'.format(sum(improved), len(cmc_diff))
        )
        logger.info('CMC average absolute diff: {}'.format(cmc_diff.mean()))
        logger.info('CMC average percent change: {}'.format(cmc_change.mean()))

        logger.info('Average AUC:\n{}'.format(rank1_auc_table.mean(axis=0)))
        logger.info('Average TPR:\n{}'.format(rank1_tpr_table.mean(axis=0)))

        old_tpr = rank1_tpr_table[LNBNN]
        new_tpr = rank1_tpr_table[CLF]
        tpr_diff = new_tpr - old_tpr
        tpr_change = tpr_diff / old_tpr
        improved = tpr_diff > 0
        logger.info(
            '{} / {} datasets saw TPR improvement'.format(sum(improved), len(tpr_diff))
        )
        logger.info('TPR average absolute diff: {}'.format(tpr_diff.mean()))
        logger.info('TPR average percent change: {}'.format(tpr_change.mean()))

    @profile
    def measure_dbstats(self):
        """
        python -m wbia VerifierExpt.measure dbstats GZ_Master1
        python -m wbia VerifierExpt.measure dbstats PZ_Master1
        python -m wbia VerifierExpt.measure dbstats MantaMatcher
        python -m wbia VerifierExpt.measure dbstats RotanTurtles

        Ignore:
            >>> from wbia.scripts.postdoc import *
            >>> #self = VerifierExpt('GZ_Master1')
            >>> self = VerifierExpt('MantaMatcher')
        """
        if self.ibs is None:
            self._precollect()
        ibs = self.ibs
        # self.ibs.print_annot_stats(self.aids_pool)

        # encattr = 'static_encounter'
        encattr = 'encounter_text'
        # encattr = 'aids'
        annots = ibs.annots(self.aids_pool)

        encounters = annots.group2(getattr(annots, encattr))
        nids = ut.take_column(encounters.nids, 0)
        nid_to_enc = ut.group_items(encounters, nids)

        single_encs = {nid: e for nid, e in nid_to_enc.items() if len(e) == 1}
        multi_encs = {
            nid: self.ibs._annot_groups(e)
            for nid, e in nid_to_enc.items()
            if len(e) > 1
        }
        multi_annots = ibs.annots(ut.flatten(ut.flatten(multi_encs.values())))
        single_annots = ibs.annots(ut.flatten(ut.flatten(single_encs.values())))

        def annot_stats(annots, encattr):
            encounters = annots.group2(getattr(annots, encattr))
            nid_to_enc = ut.group_items(encounters, ut.take_column(encounters.nids, 0))
            nid_to_nenc = ut.map_vals(len, nid_to_enc)
            n_enc_per_name = list(nid_to_nenc.values())
            n_annot_per_enc = ut.lmap(len, encounters)

            enc_deltas = []
            for encs_ in nid_to_enc.values():
                times = [np.mean(a.image_unixtimes_asfloat) for a in encs_]
                for tup in ut.combinations(times, 2):
                    delta = max(tup) - min(tup)
                    enc_deltas.append(delta)
                # pass
                # delta = times.max() - times.min()
                # enc_deltas.append(delta)

            annot_info = ut.odict()
            annot_info['n_names'] = len(nid_to_enc)
            annot_info['n_annots'] = len(annots)
            annot_info['n_encs'] = len(encounters)
            annot_info['enc_time_deltas'] = ut.get_stats(enc_deltas)
            annot_info['n_enc_per_name'] = ut.get_stats(n_enc_per_name)
            annot_info['n_annot_per_enc'] = ut.get_stats(n_annot_per_enc)
            # logger.info(ut.repr4(annot_info, si=True, nl=1, precision=2))
            return annot_info

        enc_info = ut.odict()
        enc_info['all'] = annot_stats(annots, encattr)
        del enc_info['all']['enc_time_deltas']
        enc_info['multi'] = annot_stats(multi_annots, encattr)
        enc_info['single'] = annot_stats(single_annots, encattr)
        del enc_info['single']['n_encs']
        del enc_info['single']['n_enc_per_name']
        del enc_info['single']['enc_time_deltas']

        qual_info = ut.dict_hist(annots.quality_texts)
        qual_info['None'] = qual_info.pop('UNKNOWN', 0)
        qual_info['None'] += qual_info.pop(None, 0)

        view_info = ut.dict_hist(annots.viewpoint_code)
        view_info['None'] = view_info.pop('unknown', 0)
        view_info['None'] += view_info.pop(None, 0)

        info = ut.odict([])
        info['species_nice'] = self.species_nice
        info['enc'] = enc_info
        info['qual'] = qual_info
        info['view'] = view_info

        logger.info('Annotation Pool DBStats')
        logger.info(ut.repr4(info, si=True, nl=3, precision=2))

        def _ave_str2(d):
            try:
                return ave_str(*ut.take(d, ['mean', 'std']))
            except Exception:
                return 0

        outinfo = ut.odict(
            [
                ('Database', info['species_nice']),
                ('Annots', enc_info['all']['n_annots']),
                ('Names (singleton)', enc_info['single']['n_names']),
                ('Names (resighted)', enc_info['multi']['n_names']),
                (
                    'Enc per name (resighted)',
                    _ave_str2(enc_info['multi']['n_enc_per_name']),
                ),
                ('Annots per encounter', _ave_str2(enc_info['all']['n_annot_per_enc'])),
            ]
        )
        info['outinfo'] = outinfo

        df = pd.DataFrame([outinfo])
        df = df.set_index('Database')
        df.index.name = None
        df.index = ut.emap(upper_one, df.index)

        tabular = Tabular(df, colfmt='numeric')
        tabular.theadify = 16
        enc_text = tabular.as_tabular()
        logger.info(enc_text)

        # ut.render_latex(enc_text, dpath=self.dpath, fname='dbstats',
        #                 preamb_extra=['\\usepackage{makecell}'])
        # ut.startfile(_)

        # expt_name = ut.get_stack_frame().f_code.co_name.replace('measure_', '')
        expt_name = 'dbstats'
        self.expt_results[expt_name] = info

        ut.ensuredir(self.dpath)
        ut.save_data(join(self.dpath, expt_name + '.pkl'), info)
        return info

    def measure_all(self):
        r"""
        CommandLine:
            python -m wbia VerifierExpt.measure all GZ_Master1,MantaMatcher,RotanTurtles,LF_ALL
            python -m wbia VerifierExpt.measure all GZ_Master1

        Ignore:
            from wbia.scripts.postdoc import *
            self = VerifierExpt('PZ_MTEST')
            self.measure_all()
        """
        self._setup()
        pblm = self.pblm

        expt_name = 'sample_info'
        results = {
            'graph': pblm.infr.graph,
            'aid_pool': self.aids_pool,
            'pblm_aids': pblm.infr.aids,
            'encoded_labels2d': pblm.samples.encoded_2d(),
            'subtasks': pblm.samples.subtasks,
            'multihist': pblm.samples.make_histogram(),
        }
        self.expt_results[expt_name] = results
        ut.save_data(join(str(self.dpath), expt_name + '.pkl'), results)

        # importance = {
        #     task_key: pblm.feature_importance(task_key=task_key)
        #     for task_key in pblm.eval_task_keys
        # }

        task = pblm.samples['match_state']
        scores = pblm.samples.simple_scores['score_lnbnn_1vM']
        lnbnn_ranks = pblm.samples.simple_scores['rank_lnbnn_1vM']
        y = task.indicator_df[task.default_class_name]
        lnbnn_data = pd.concat([scores, lnbnn_ranks, y], axis=1)

        results = {
            'lnbnn_data': lnbnn_data,
            'task_combo_res': self.pblm.task_combo_res,
            # 'importance': importance,
            'data_key': self.data_key,
            'clf_key': self.clf_key,
        }
        expt_name = 'all'
        self.expt_results[expt_name] = results
        ut.save_data(join(str(self.dpath), expt_name + '.pkl'), results)

        task_key = 'match_state'
        self.measure_hard_cases(task_key)

        self.measure_dbstats()

        self.measure_rerank()

        if ut.get_argflag('--draw'):
            self.draw_all()

    def draw_all(self):
        r"""
        CommandLine:
            python -m wbia VerifierExpt.draw_all --db PZ_MTEST
            python -m wbia VerifierExpt.draw_all --db PZ_PB_RF_TRAIN
            python -m wbia VerifierExpt.draw_all --db GZ_Master1
            python -m wbia VerifierExpt.draw_all --db PZ_Master1

        Example:
            >>> # DISABLE_DOCTEST
            >>> from wbia.scripts.postdoc import *
            >>> dbname = ut.get_argval('--db', default='PZ_MTEST')
            >>> dbnames = ut.get_argval('--dbs', type_=list, default=[dbname])
            >>> for dbname in dbnames:
            >>>     print('dbname = %r' % (dbname,))
            >>>     self = VerifierExpt(dbname)
            >>>     self.draw_all()
        """
        results = self.ensure_results('all')
        eval_task_keys = set(results['task_combo_res'].keys())
        logger.info('eval_task_keys = {!r}'.format(eval_task_keys))

        task_key = 'match_state'

        if ut.get_argflag('--cases'):
            self.draw_hard_cases(task_key)

        self.write_sample_info()

        self.draw_roc(task_key)
        self.draw_rerank()

        self.write_metrics(task_key)

        self.draw_class_score_hist()
        self.draw_mcc_thresh(task_key)

    def draw_roc(self, task_key='match_state'):
        """
        python -m wbia VerifierExpt.draw roc GZ_Master1 photobomb_state
        python -m wbia VerifierExpt.draw roc GZ_Master1 match_state
        python -m wbia VerifierExpt.draw roc PZ_MTEST
        """
        mpl.rcParams.update(TMP_RC)

        results = self.ensure_results('all')
        data_key = results['data_key']
        clf_key = results['clf_key']
        task_combo_res = results['task_combo_res']
        lnbnn_data = results['lnbnn_data']

        task_key = 'match_state'

        scores = lnbnn_data['score_lnbnn_1vM'].values
        y = lnbnn_data[POSTV].values

        # task_key = 'match_state'
        target_class = POSTV
        res = task_combo_res[task_key][clf_key][data_key]

        cfsm_vsm = vt.ConfusionMetrics().fit(scores, y)
        cfsm_clf = res.confusions(target_class)

        roc_curves = [
            {
                'label': LNBNN,
                'fpr': cfsm_vsm.fpr,
                'tpr': cfsm_vsm.tpr,
                'auc': cfsm_vsm.auc,
            },
            {'label': CLF, 'fpr': cfsm_clf.fpr, 'tpr': cfsm_clf.tpr, 'auc': cfsm_clf.auc},
        ]

        rank_clf_roc_curve = ut.ddict(list)
        rank_lnbnn_roc_curve = ut.ddict(list)

        roc_info_lines = []

        # Check the ROC for only things in the top of the LNBNN ranked lists
        if True:
            rank_auc_df = pd.DataFrame()
            rank_auc_df.index.name = '<=rank'

            nums = [1, 2, 3, 4, 5, 10, 20, np.inf]
            for num in nums:
                ranks = lnbnn_data['rank_lnbnn_1vM'].values
                sub_data = lnbnn_data[ranks <= num]

                scores = sub_data['score_lnbnn_1vM'].values
                y = sub_data[POSTV].values
                probs = res.probs_df[POSTV].loc[sub_data.index].values

                cfsm_scores_rank = vt.ConfusionMetrics().fit(scores, y)
                cfsm_probs_rank = vt.ConfusionMetrics().fit(probs, y)

                # if num == np.inf:
                #     num = 'inf'
                rank_auc_df.loc[num, LNBNN] = cfsm_scores_rank.auc
                rank_auc_df.loc[num, CLF] = cfsm_probs_rank.auc

                rank_lnbnn_roc_curve[num] = {
                    'label': LNBNN,
                    'fpr': cfsm_scores_rank.fpr,
                    'tpr': cfsm_scores_rank.tpr,
                    'auc': cfsm_scores_rank.auc,
                    'tpr@fpr=0': cfsm_scores_rank.get_metric_at_metric(
                        'tpr', 'fpr', 0, tiebreaker='minthresh'
                    ),
                    'thresh@fpr=0': cfsm_scores_rank.get_metric_at_metric(
                        'thresh', 'fpr', 0, tiebreaker='minthresh'
                    ),
                }

                rank_clf_roc_curve[num] = {
                    'label': CLF,
                    'fpr': cfsm_probs_rank.fpr,
                    'tpr': cfsm_probs_rank.tpr,
                    'auc': cfsm_probs_rank.auc,
                    'tpr@fpr=0': cfsm_probs_rank.get_metric_at_metric(
                        'tpr', 'fpr', 0, tiebreaker='minthresh'
                    ),
                    'thresh@fpr=0': cfsm_probs_rank.get_metric_at_metric(
                        'thresh', 'fpr', 0, tiebreaker='minthresh'
                    ),
                }

            auc_text = 'AUC when restricting to the top `num` LNBNN ranks:'
            auc_text += '\n' + str(rank_auc_df)
            logger.info(auc_text)
            roc_info_lines += [auc_text]

        if True:
            tpr_info = []
            at_metric = 'tpr'
            for at_value in [0.25, 0.5, 0.75]:
                info = ut.odict()
                for want_metric in ['fpr', 'n_false_pos', 'n_true_pos', 'thresh']:
                    key = '{}_@_{}={:.3f}'.format(want_metric, at_metric, at_value)
                    info[key] = cfsm_clf.get_metric_at_metric(
                        want_metric, at_metric, at_value, tiebreaker='minthresh'
                    )
                    if key.startswith('n_'):
                        info[key] = int(info[key])
                tpr_info += [(ut.repr4(info, align=True, precision=8))]

            tpr_text = 'Metric TPR relationships\n' + '\n'.join(tpr_info)
            logger.info(tpr_text)
            roc_info_lines += [tpr_text]

            fpr_info = []
            at_metric = 'fpr'
            for at_value in [0, 0.001, 0.01, 0.1]:
                info = ut.odict()
                for want_metric in ['tpr', 'n_false_pos', 'n_true_pos', 'thresh']:
                    key = '{}_@_{}={:.3f}'.format(want_metric, at_metric, at_value)
                    info[key] = cfsm_clf.get_metric_at_metric(
                        want_metric, at_metric, at_value, tiebreaker='minthresh'
                    )
                    if key.startswith('n_'):
                        info[key] = int(info[key])
                fpr_info += [(ut.repr4(info, align=True, precision=8))]

            fpr_text = 'Metric FPR relationships\n' + '\n'.join(fpr_info)
            logger.info(fpr_text)
            roc_info_lines += [fpr_text]

        roc_info_text = '\n\n'.join(roc_info_lines)
        ut.writeto(join(self.dpath, 'roc_info.txt'), roc_info_text)
        # logger.info(roc_info_text)

        fig = pt.figure(fnum=1)  # NOQA
        ax = pt.gca()
        for data in roc_curves:
            ax.plot(
                data['fpr'],
                data['tpr'],
                label='AUC={:.3f} {}'.format(data['auc'], data['label']),
            )
        ax.set_xlabel('false positive rate')
        ax.set_ylabel('true positive rate')
        # ax.set_title('%s ROC for %s' % (target_class.title(), self.species))
        ax.legend()
        pt.adjust_subplots(top=0.8, bottom=0.2, left=0.12, right=0.9)
        fig.set_size_inches([W, H])

        fname = 'roc_{}.png'.format(task_key)
        fig_fpath = join(str(self.dpath), fname)
        vt.imwrite(fig_fpath, pt.render_figure_to_image(fig, dpi=DPI))
        logger.info('wrote roc figure to fig_fpath= {!r}'.format(fig_fpath))

        for num in [1, 2, 5, np.inf]:
            roc_curves_ = [rank_clf_roc_curve[num], rank_lnbnn_roc_curve[num]]
            fig = pt.figure(fnum=1)  # NOQA
            ax = pt.gca()
            for data in roc_curves_:
                ax.plot(
                    data['fpr'],
                    data['tpr'],
                    label='AUC={:.3f} TPR={:.3f} {}'.format(
                        data['auc'], data['tpr@fpr=0'], data['label']
                    ),
                )
            ax.set_xlabel('false positive rate')
            ax.set_ylabel('true positive rate')
            ax.set_title('ROC@rank<={num}'.format(num=num))
            ax.legend()
            pt.adjust_subplots(top=0.8, bottom=0.2, left=0.12, right=0.9)
            fig.set_size_inches([W, H])

            fname = 'rank_{}_roc_{}.png'.format(num, task_key)
            fig_fpath = join(str(self.dpath), fname)
            vt.imwrite(fig_fpath, pt.render_figure_to_image(fig, dpi=DPI))
            logger.info('wrote roc figure to fig_fpath= {!r}'.format(fig_fpath))

    def draw_rerank(self):
        mpl.rcParams.update(TMP_RC)

        expt_name = 'rerank'
        results = self.ensure_results(expt_name)

        cdfs, infos = list(zip(*results))
        lnbnn_cdf = cdfs[0]
        clf_cdf = cdfs[1]
        fig = pt.figure(fnum=1)
        plot_cmcs([lnbnn_cdf, clf_cdf], ['ranking', 'rank+clf'], fnum=1, ymin=0)
        fig.set_size_inches([W, H * 0.6])

        qsizes = ut.take_column(infos, 'qsize')
        dsizes = ut.take_column(infos, 'dsize')
        assert ut.allsame(qsizes) and ut.allsame(dsizes)
        nonvaried_text = 'qsize={}, dsize={}'.format(qsizes[0], dsizes[0])
        pt.relative_text('lowerleft', nonvaried_text, ax=pt.gca())

        fpath = join(str(self.dpath), expt_name + '.png')
        vt.imwrite(fpath, pt.render_figure_to_image(fig, dpi=DPI))
        if ut.get_argflag('--diskshow'):
            ut.startfile(fpath)
        return fpath

    def measure_rerank(self):
        """
        Example:
            >>> # DISABLE_DOCTEST
            >>> from wbia.scripts.postdoc import *
            >>> defaultdb = 'PZ_Master1'
            >>> defaultdb = 'GZ_Master1'
            >>> self = VerifierExpt(defaultdb)
            >>> self._setup()
            >>> self.measure_rerank()
        """
        if getattr(self, 'pblm', None) is None:
            self._setup()

        pblm = self.pblm
        infr = pblm.infr
        ibs = pblm.infr.ibs

        # NOTE: this is not the aids_pool for PZ_Master1
        aids = pblm.infr.aids

        # These are not guaranteed to be comparable
        if ibs.dbname == 'RotanTurtles':
            # HACK
            viewpoint_aware = True
        else:
            viewpoint_aware = False

        from wbia.scripts import thesis

        qaids, daids_list, info_list = thesis.Sampler._varied_inputs(
            ibs, aids, viewpoint_aware=viewpoint_aware
        )
        daids = daids_list[0]
        info = info_list[0]

        # ---------------------------
        # Execute the ranking algorithm
        qaids = sorted(qaids)
        daids = sorted(daids)
        cfgdict = pblm._make_lnbnn_pcfg()
        qreq_ = ibs.new_query_request(qaids, daids, cfgdict=cfgdict)
        cm_list = qreq_.execute()
        cm_list = [cm.extend_results(qreq_) for cm in cm_list]

        # ---------------------------
        # Measure LNBNN rank probabilities
        top = 20
        rerank_pairs = []
        for cm in cm_list:
            pairs = [infr.e_(cm.qaid, daid) for daid in cm.get_top_aids(top)]
            rerank_pairs.extend(pairs)
        rerank_pairs = list(set(rerank_pairs))

        # ---------------------------
        # Re-rank those top matches with the verifier
        verif = pblm._make_evaluation_verifiers()['match_state']
        # verif = infr.learn_evaluation_verifiers()['match_state']
        probs = verif.predict_proba_df(rerank_pairs)
        pos_probs = probs[POSTV]

        clf_name_ranks = []
        lnbnn_name_ranks = []
        infr = pblm.infr
        for cm in cm_list:
            daids = cm.get_top_aids(top)
            edges = [infr.e_(cm.qaid, daid) for daid in daids]
            dnids = cm.dnid_list[ut.take(cm.daid2_idx, daids)]
            scores = pos_probs.loc[edges].values

            sortx = np.argsort(scores)[::-1]
            clf_ranks = np.where(cm.qnid == dnids[sortx])[0]
            if len(clf_ranks) == 0:
                clf_rank = len(cm.unique_nids) - 1
            else:
                clf_rank = clf_ranks[0]
            lnbnn_rank = cm.get_name_ranks([cm.qnid])[0]
            clf_name_ranks.append(clf_rank)
            lnbnn_name_ranks.append(lnbnn_rank)

        bins = np.arange(len(qreq_.dnids))
        hist = np.histogram(lnbnn_name_ranks, bins=bins)[0]
        lnbnn_cdf = np.cumsum(hist) / sum(hist)

        bins = np.arange(len(qreq_.dnids))
        hist = np.histogram(clf_name_ranks, bins=bins)[0]
        clf_cdf = np.cumsum(hist) / sum(hist)

        results = [
            (lnbnn_cdf, ut.update_dict(info.copy(), {'pcfg': cfgdict})),
            (clf_cdf, ut.update_dict(info.copy(), {'pcfg': cfgdict})),
        ]

        expt_name = 'rerank'
        self.expt_results[expt_name] = results
        ut.save_data(join(str(self.dpath), expt_name + '.pkl'), results)
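    # --- Illustrative sketch (added commentary; not in the original) --------
    # measure_rerank reduces each query to the rank of its correct name and
    # then histograms those ranks into a cumulative curve (the CMC curve that
    # draw_rerank plots). The helper below isolates that step with
    # hypothetical inputs.
    @staticmethod
    def _sketch_rank_cdf(name_ranks, n_names):
        """
        Example:
            >>> # rank 0 means the correct name was retrieved first
            >>> cdf = VerifierExpt._sketch_rank_cdf([0, 0, 1, 0, 3], n_names=10)
            >>> assert cdf[0] == 3 / 5  # 60% of queries correct at rank 1
        """
        bins = np.arange(n_names)
        hist = np.histogram(name_ranks, bins=bins)[0]
        # fraction of queries whose correct name appears at rank <= k
        return np.cumsum(hist) / sum(hist)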
    def measure_hard_cases(self, task_key):
        """
        Find a failure case for each class

        CommandLine:
            python -m wbia VerifierExpt.measure hard_cases GZ_Master1 match_state
            python -m wbia VerifierExpt.measure hard_cases GZ_Master1 photobomb_state
            python -m wbia VerifierExpt.draw hard_cases GZ_Master1 match_state
            python -m wbia VerifierExpt.draw hard_cases GZ_Master1 photobomb_state

            python -m wbia VerifierExpt.measure hard_cases PZ_Master1 match_state
            python -m wbia VerifierExpt.measure hard_cases PZ_Master1 photobomb_state
            python -m wbia VerifierExpt.draw hard_cases PZ_Master1 match_state
            python -m wbia VerifierExpt.draw hard_cases PZ_Master1 photobomb_state

            python -m wbia VerifierExpt.measure hard_cases PZ_MTEST match_state
            python -m wbia VerifierExpt.draw hard_cases PZ_MTEST photobomb_state

            python -m wbia VerifierExpt.draw hard_cases RotanTurtles match_state
            python -m wbia VerifierExpt.draw hard_cases MantaMatcher match_state

        Ignore:
            >>> task_key = 'match_state'
            >>> task_key = 'photobomb_state'
            >>> from wbia.scripts.postdoc import *
            >>> self = VerifierExpt('GZ_Master1')
            >>> self._setup()
        """
        if getattr(self, 'pblm', None) is None:
            logger.info('Need to setup before measuring hard cases')
            self._setup()

        logger.info('Measuring hard cases')

        pblm = self.pblm

        front = mid = back = 8

        res = pblm.task_combo_res[task_key][self.clf_key][self.data_key]

        logger.info('task_key = %r' % (task_key,))
        if task_key == 'photobomb_state':
            method = res.get_thresholds('mcc', 'maximize')
            logger.info('Using thresholds: ' + ut.repr4(method))
        else:
            method = 'argmax'
            logger.info('Using argmax')

        case_df = res.hardness_analysis(pblm.samples, pblm.infr, method=method)
        # group = case_df.sort_values(['real_conf', 'easiness'])
        case_df = case_df.sort_values(['easiness'])

        # failure_cases = case_df[(case_df['real_conf'] > 0) & case_df['failed']]
        failure_cases = case_df[case_df['failed']]
        if len(failure_cases) == 0:
            logger.info('No reviewed failures exist. Do pblm.qt_review_hardcases')

        logger.info('There are {} failure cases'.format(len(failure_cases)))
        logger.info(
            'With average hardness {}'.format(
                ut.repr2(
                    ut.stats_dict(failure_cases['hardness']), strkeys=True, precision=2
                )
            )
        )

        cases = []
        for (pred, real), group in failure_cases.groupby(['pred', 'real']):
            group = group.sort_values(['easiness'])
            flags = ut.flag_percentile_parts(group['easiness'], front, mid, back)
            subgroup = group[flags]

            logger.info(
                'Selected {} r({})-p({}) cases'.format(
                    len(subgroup), res.class_names[real], res.class_names[pred]
                )
            )
            # ut.take_percentile_parts(group['easiness'], front, mid, back)
            # Prefer examples we have manually reviewed before
            # group = group.sort_values(['real_conf', 'easiness'])
            # subgroup = group[0:num_top]

            for idx, case in subgroup.iterrows():
                edge = tuple(ut.take(case, ['aid1', 'aid2']))
                cases.append(
                    {
                        'edge': edge,
                        'real': res.class_names[case['real']],
                        'pred': res.class_names[case['pred']],
                        'failed': case['failed'],
                        'easiness': case['easiness'],
                        'real_conf': case['real_conf'],
                        'probs': res.probs_df.loc[edge].to_dict(),
                        'edge_data': pblm.infr.get_edge_data(edge),
                    }
                )

        logger.info('Selected %d cases in total' % (len(cases),))

        # Augment cases with their one-vs-one matches
        infr = pblm.infr
        data_key = self.data_key
        config = pblm.feat_extract_info[data_key][0]
        edges = [case['edge'] for case in cases]
        matches = infr._make_matches_from(edges, config=config)

        def _prep_annot(annot):
            # Load data needed for plot into annot dictionary
            annot['aid']
            annot['rchip']
            annot['kpts']
            # Cast the lazy dict to a real one
            return {k: annot[k] for k in annot.evaluated_keys()}

        for case, match in zip(cases, matches):
            # store its chip fpath and other required info
            match.annot1 = _prep_annot(match.annot1)
            match.annot2 = _prep_annot(match.annot2)
            case['match'] = match

        fpath = join(str(self.dpath), task_key + '_hard_cases.pkl')
        ut.save_data(fpath, cases)
        logger.info('Hard case space on disk: {}'.format(ut.get_file_nBytes_str(fpath)))

        # if False:
        #     ybin_df = res.target_bin_df
        #     flags = ybin_df['pb'].values
        #     pb_edges = ybin_df[flags].index.tolist()
        #     matches = infr._exec_pairwise_match(pb_edges, config)
        #     prefix = 'training_'
        #     subdir = 'temp_cases_{}'.format(task_key)
        #     dpath = join(str(self.dpath), subdir)
        #     ut.ensuredir(dpath)
        #     tbl = pblm.infr.ibs.db.get_table_as_pandas('annotmatch')
        #     tagged_tbl = tbl[~pd.isnull(tbl['annotmatch_tag_text']).values]
        #     ttext = tagged_tbl['annotmatch_tag_text']
        #     flags = ['photobomb' in t.split(';') for t in ttext]
        #     pb_table = tagged_tbl[flags]
        #     am_pb_edges = set(
        #         ut.estarmap(infr.e_, zip(pb_table.annot_rowid1.tolist(),
        #                                  pb_table.annot_rowid2.tolist())))
        #     # missing = am_pb_edges - set(pb_edges)
        #     # matches = infr._exec_pairwise_match(missing, config)
        #     # prefix = 'missing_'
        #     # infr.relabel_using_reviews()
        #     # infr.apply_nondynamic_update()
        #     # infr.verbose = 100
        #     # for edge in missing:
        #     #     logger.info(edge[0] in infr.aids)
        #     #     logger.info(edge[1] in infr.aids)
        #     # fix = [
        #     #     (1184, 1185),
        #     #     (1376, 1378),
        #     #     (1377, 1378),
        #     # ]
        #     # fb = infr.current_feedback(edge).copy()
        #     # fb = ut.dict_subset(fb, ['decision', 'tags', 'confidence'],
        #     #                     default=None)
        #     # fb['user_id'] = 'jon_fixam'
        #     # fb['confidence'] = 'pretty_sure'
        #     # fb['tags'] += ['photobomb']
        #     # infr.add_feedback(edge, **fb)
        #     for c, match in enumerate(ut.ProgIter(matches)):
        #         edge = match.annot1['aid'], match.annot2['aid']
        #         fig = pt.figure(fnum=1, clf=True)
        #         ax = pt.gca()
        #         # Draw with feature overlay
        #         match.show(ax, vert=False, heatmask=True, show_lines=True,
        #                    show_ell=False, show_ori=False, show_eig=False,
        #                    line_lw=1, line_alpha=.1,
        #                    modifysize=True)
        #         fname = prefix + '_'.join(ut.emap(str, edge))
        #         ax.set_xlabel(fname)
        #         fpath = join(str(dpath), fname + '.jpg')
        #         vt.imwrite(fpath, pt.render_figure_to_image(fig, dpi=DPI))
        #     # visualize real photobomb cases
        return cases
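    # --- Illustrative sketch (added commentary; not in the original) --------
    # measure_hard_cases keeps a spread of failures rather than only the most
    # extreme ones via ``ut.flag_percentile_parts``. The helper below is a
    # hedged approximation of that selection (front, middle, and back slices
    # of a sorted sequence), not utool's actual implementation.
    @staticmethod
    def _sketch_flag_parts(values, front, mid, back):
        """
        Example:
            >>> flags = VerifierExpt._sketch_flag_parts(list(range(100)), 2, 2, 2)
            >>> assert flags.sum() == 6
        """
        n = len(values)
        flags = np.zeros(n, dtype=bool)
        flags[:front] = True               # easiest of the failures
        lo = max(0, n // 2 - mid // 2)
        flags[lo:lo + mid] = True          # middling failures
        flags[max(0, n - back):] = True    # hardest failures
        return flags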
    def draw_hard_cases(self, task_key='match_state'):
        """
        draw hard cases with and without overlay

        CommandLine:
            python -m wbia VerifierExpt.draw hard_cases GZ_Master1 match_state
            python -m wbia VerifierExpt.draw hard_cases PZ_Master1 match_state
            python -m wbia VerifierExpt.draw hard_cases PZ_Master1 photobomb_state
            python -m wbia VerifierExpt.draw hard_cases GZ_Master1 photobomb_state
            python -m wbia VerifierExpt.draw hard_cases RotanTurtles match_state

        Example:
            >>> # DISABLE_DOCTEST
            >>> from wbia.scripts.postdoc import *
            >>> self = VerifierExpt('PZ_MTEST')
            >>> task_key = 'match_state'
            >>> self.draw_hard_cases(task_key)
        """
        REWORK = True
        if REWORK:
            # HACK
            if self.ibs is None:
                self._precollect()

        cases = self.ensure_results(task_key + '_hard_cases')
        logger.info('Loaded {} {} hard cases'.format(len(cases), task_key))

        subdir = 'cases_{}'.format(task_key)
        dpath = join(str(self.dpath), subdir)
        # ut.delete(dpath)
        ut.ensuredir(dpath)
        code_to_nice = self.task_nice_lookup[task_key]

        mpl.rcParams.update(TMP_RC)

        prog = ut.ProgIter(cases, 'draw {} hard case'.format(task_key), bs=False)
        for case in prog:
            aid1, aid2 = case['edge']
            match = case['match']
            real_name, pred_name = case['real'], case['pred']
            real_nice, pred_nice = ut.take(code_to_nice, [real_name, pred_name])
            if real_nice != 'Negative':
                continue
            fname = 'fail_{}_{}_{}_{}'.format(real_name, pred_name, aid1, aid2)
            # Build x-label
            _probs = case['probs']
            probs = ut.odict()
            for k, v in code_to_nice.items():
                if k in _probs:
                    probs[v] = _probs[k]
            probstr = ut.repr2(probs, precision=2, strkeys=True, nobr=True)
            xlabel = 'real={}, pred={},\n{}'.format(real_nice, pred_nice, probstr)
            fig = pt.figure(fnum=1000, clf=True)
            ax = pt.gca()
            # if REWORK:
            #     ibs = self.ibs
            #     annots = ibs.annots([aid1, aid2])
            #     imgs = ibs.images(annots.gids)
            #     xlabel += '\nimg: ' + '-vs-'.join(map(repr, imgs.gnames))
            #     xlabel += '\nname: ' + '-vs-'.join(map(repr, annots.name))
            #     import datetime
            #     delta = ut.get_timedelta_str(datetime.timedelta(
            #         seconds=np.diff(annots.image_unixtimes_asfloat)[0]))
            #     xlabel += '\ntimeΔ: ' + delta
            #     xlabel += '\nedge: ' + str(tuple(annots.aids))
            if REWORK:
                ibs = self.ibs
                match.annot1['rchip'] = ibs.annots(
                    match.annot1['aid'], config={}
                ).rchip[0]
                match.annot2['rchip'] = ibs.annots(
                    match.annot2['aid'], config={}
                ).rchip[0]
                # match.annot1['rchip'] = ibs.annots(match.annot1['aid'],
                #     config={'medianblur': True, 'adapt_eq': True}).rchip[0]
                # match.annot2['rchip'] = ibs.annots(match.annot2['aid'],
                #     config={'medianblur': True, 'adapt_eq': True}).rchip[0]
            # Draw with feature overlay
            match.show(
                ax,
                vert=False,
                heatmask=True,
                show_lines=False,
                # show_lines=True, line_lw=1, line_alpha=.1,
                # ell_alpha=.3,
                show_ell=False,
                show_ori=False,
                show_eig=False,
                modifysize=True,
            )
            ax.set_xlabel(xlabel)
            ax.get_xaxis().get_label().set_fontsize(24)
            fpath = join(str(dpath), fname + '.jpg')
            vt.imwrite(fpath, pt.render_figure_to_image(fig, dpi=DPI))
    def write_metrics(self, task_key='match_state'):
        """
        Writes confusion matrices

        CommandLine:
            python -m wbia VerifierExpt.draw metrics PZ_PB_RF_TRAIN match_state
            python -m wbia VerifierExpt.draw metrics GZ_Master1 photobomb_state

            python -m wbia VerifierExpt.draw metrics PZ_Master1,GZ_Master1 photobomb_state,match_state

        Ignore:
            >>> from wbia.scripts.postdoc import *
            >>> self = VerifierExpt('PZ_Master1')
            >>> task_key = 'match_state'
        """
        results = self.ensure_results('all')
        task_combo_res = results['task_combo_res']
        data_key = results['data_key']
        clf_key = results['clf_key']

        res = task_combo_res[task_key][clf_key][data_key]
        res.augment_if_needed()
        pred_enc = res.clf_probs.argmax(axis=1)
        y_pred = pred_enc
        y_true = res.y_test_enc
        sample_weight = res.sample_weight
        target_names = res.class_names

        from wbia.algo.verif import sklearn_utils

        report = sklearn_utils.classification_report2(
            y_true, y_pred, target_names, sample_weight, verbose=False
        )
        metric_df = report['metrics']
        confusion_df = report['confusion']
        multiclass_mcc = report['mcc']
        metric_df.loc['combined', 'mcc'] = multiclass_mcc

        logger.info(metric_df)
        logger.info(confusion_df)

        # df = self.task_confusion[task_key]
        df = confusion_df
        df = df.rename(index=self.task_nice_lookup[task_key])
        df = df.rename(columns=self.task_nice_lookup[task_key])
        df.index.name = None
        df.columns.name = None

        colfmt = '|l|' + 'r' * (len(df) - 1) + '|l|'
        tabular = Tabular(df, colfmt=colfmt, hline=True)
        tabular.groupxs = [list(range(len(df) - 1)), [len(df) - 1]]
        latex_str = tabular.as_tabular()

        sum_pred = df.index[-1]
        sum_real = df.columns[-1]
        latex_str = latex_str.replace(sum_pred, r'$\sum$ predicted')
        latex_str = latex_str.replace(sum_real, r'$\sum$ real')
        confusion_tex = ut.align(latex_str, '&', pos=None)
        logger.info(confusion_tex)

        df = metric_df
        # df = self.task_metrics[task_key]
        df = df.rename(index=self.task_nice_lookup[task_key])
        df = df.rename(columns={'mcc': 'MCC'})
        df = df.rename(index={'combined': 'Combined'})
        df = df.drop(['markedness', 'bookmaker', 'fpr'], axis=1)
        df.index.name = None
        df.columns.name = None
        df['support'] = df['support'].astype(int)  # np.int is deprecated
        df.columns = ut.emap(upper_one, df.columns)

        import re

        tabular = Tabular(df, colfmt='numeric')
        top, header, mid, bot = tabular.as_parts()
        lines = mid[0].split('\n')
        newmid = [lines[0:-1], lines[-1:]]
        tabular.parts = (top, header, newmid, bot)
        latex_str = tabular.as_tabular()
        latex_str = re.sub(' -0.00 ', ' 0.00 ', latex_str)
        metrics_tex = latex_str
        logger.info(metrics_tex)

        dpath = str(self.dpath)
        confusion_fname = 'confusion_{}'.format(task_key)
        metrics_fname = 'eval_metrics_{}'.format(task_key)

        confusion_tex = confusion_tex.replace('Incomparable', 'Incomp.')
        confusion_tex = confusion_tex.replace('predicted', 'pred')

        ut.write_to(join(dpath, confusion_fname + '.tex'), confusion_tex)
        ut.write_to(join(dpath, metrics_fname + '.tex'), metrics_tex)

        fpath1 = ut.render_latex(confusion_tex, dpath=dpath, fname=confusion_fname)
        fpath2 = ut.render_latex(metrics_tex, dpath=dpath, fname=metrics_fname)
        return fpath1, fpath2
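    # --- Illustrative sketch (added commentary; not in the original) --------
    # The MCC value written into the metrics table comes from
    # sklearn_utils.classification_report2. For reference, in the binary case
    # the Matthews correlation coefficient is
    #     MCC = (TP*TN - FP*FN) / sqrt((TP+FP)(TP+FN)(TN+FP)(TN+FN))
    # The helper below evaluates that formula on hypothetical counts.
    @staticmethod
    def _sketch_binary_mcc(tp, fp, tn, fn):
        """
        Example:
            >>> round(VerifierExpt._sketch_binary_mcc(90, 10, 80, 20), 3)
            0.704
        """
        num = tp * tn - fp * fn
        den = np.sqrt(float((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn)))
        return num / den if den else 0.0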
    def write_sample_info(self):
        """
        python -m wbia VerifierExpt.draw sample_info GZ_Master1
        """
        results = self.ensure_results('sample_info')
        # results['aid_pool']
        # results['encoded_labels2d']
        # results['multihist']
        import wbia

        infr = wbia.AnnotInference.from_netx(results['graph'])
        info = ut.odict()
        info['n_names'] = infr.pos_graph.number_of_components()
        info['n_aids'] = len(results['pblm_aids'])
        info['known_n_incomparable'] = infr.incomp_graph.number_of_edges()

        subtasks = results['subtasks']

        task = subtasks['match_state']
        flags = task.encoded_df == task.class_names.tolist().index(INCMP)
        incomp_edges = task.encoded_df[flags.values].index.tolist()
        nid_edges = [infr.pos_graph.node_labels(*e) for e in incomp_edges]
        nid_edges = vt.ensure_shape(np.array(nid_edges), (None, 2))

        n_true = nid_edges.T[0] == nid_edges.T[1]
        info['incomp_info'] = {
            'inside_pcc': n_true.sum(),
            'between_pcc': (~n_true).sum(),
        }

        for task_key, task in subtasks.items():
            info[task_key + '_hist'] = task.make_histogram()

        info_str = ut.repr4(info)
        fname = 'sample_info.txt'
        ut.write_to(join(str(self.dpath), fname), info_str)
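    # --- Illustrative sketch (added commentary; not in the original) --------
    # write_sample_info splits the incomparable edges into those whose
    # endpoints share a name (inside one positive connected component) and
    # those that span two names. The helper below isolates that computation
    # on hypothetical name-id pairs.
    @staticmethod
    def _sketch_incomp_split(nid_edges):
        """
        Example:
            >>> nid_edges = np.array([[1, 1], [1, 2], [3, 3]])
            >>> VerifierExpt._sketch_incomp_split(nid_edges)
            {'inside_pcc': 2, 'between_pcc': 1}
        """
        n_true = nid_edges.T[0] == nid_edges.T[1]
        return {
            'inside_pcc': int(n_true.sum()),
            'between_pcc': int((~n_true).sum()),
        }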
    def measure_thresh(self, pblm):
        task_key = 'match_state'
        res = pblm.task_combo_res[task_key][self.clf_key][self.data_key]
        infr = pblm.infr

        truth_colors = infr._get_truth_colors()

        cfms = res.confusions(POSTV)
        fig = pt.figure(fnum=1, doclf=True)  # NOQA
        ax = pt.gca()
        ax.plot(cfms.thresholds, cfms.n_fp, label='positive', color=truth_colors[POSTV])

        cfms = res.confusions(NEGTV)
        ax.plot(cfms.thresholds, cfms.n_fp, label='negative', color=truth_colors[NEGTV])

        # cfms = res.confusions(INCMP)
        # if len(cfms.thresholds) == 1:
        #     cfms.thresholds = [0, 1]
        #     cfms.n_fp = np.array(cfms.n_fp.tolist() * 2)
        # ax.plot(cfms.thresholds, cfms.n_fp, label='incomparable',
        #         color=pt.color_funcs.darken_rgb(truth_colors[INCMP], .15))
        ax.set_xlabel('thresholds')
        ax.set_ylabel('n_fp')
        ax.set_ylim(0, 20)
        ax.legend()

        cfms.plot_vs('fpr', 'thresholds')
    def _draw_score_hist(self, freqs, xlabel, fnum):
        """helper"""
        bins, freq0, freq1 = ut.take(freqs, ['bins', 'neg_freq', 'pos_freq'])
        width = np.diff(bins)[0]
        xlim = (bins[0] - (width / 2), bins[-1] + (width / 2))
        fig = pt.multi_plot(
            bins,
            (freq0, freq1),
            label_list=('negative', 'positive'),
            color_list=(pt.FALSE_RED, pt.TRUE_BLUE),
            kind='bar',
            width=width,
            alpha=0.7,
            edgecolor='none',
            xlabel=xlabel,
            ylabel='frequency',
            fnum=fnum,
            pnum=(1, 1, 1),
            rcParams=TMP_RC,
            stacked=True,
            ytickformat='%.3f',
            xlim=xlim,
            # title='LNBNN positive separation'
        )
        pt.adjust_subplots(top=0.8, bottom=0.2, left=0.12, right=0.9)
        fig.set_size_inches([W, H])
        return fig
    def draw_class_score_hist(self):
        """Plots distribution of positive and negative scores"""
        task_key = 'match_state'

        results = self.ensure_results('all')
        task_combo_res = results['task_combo_res']
        data_key = results['data_key']
        clf_key = results['clf_key']

        res = task_combo_res[task_key][clf_key][data_key]

        y = res.target_bin_df[POSTV]
        scores = res.probs_df[POSTV]
        bins = np.linspace(0, 1, 100)
        pos_freq = np.histogram(scores[y], bins)[0]
        neg_freq = np.histogram(scores[~y], bins)[0]
        pos_freq = pos_freq / pos_freq.sum()
        neg_freq = neg_freq / neg_freq.sum()
        score_hist_pos = {'bins': bins, 'pos_freq': pos_freq, 'neg_freq': neg_freq}

        lnbnn_data = results['lnbnn_data']
        scores = lnbnn_data['score_lnbnn_1vM'].values
        y = lnbnn_data[POSTV].values
        # Clip the histogram range near the 95th percentile of the scores so
        # a few very large scores do not stretch the x-axis
        maxbin = scores[scores.argsort()][-max(1, int(len(scores) * 0.05))]
        bins = np.linspace(0, max(maxbin, 10), 100)
        pos_freq = np.histogram(scores[y], bins)[0]
        neg_freq = np.histogram(scores[~y], bins)[0]
        pos_freq = pos_freq / pos_freq.sum()
        neg_freq = neg_freq / neg_freq.sum()
        score_hist_lnbnn = {'bins': bins, 'pos_freq': pos_freq, 'neg_freq': neg_freq}

        fig1 = self._draw_score_hist(score_hist_pos, 'positive probability', 1)
        fig2 = self._draw_score_hist(score_hist_lnbnn, 'LNBNN score', 2)

        fname = 'score_hist_pos_{}.png'.format(data_key)
        vt.imwrite(join(str(self.dpath), fname), pt.render_figure_to_image(fig1, dpi=DPI))

        fname = 'score_hist_lnbnn.png'
        vt.imwrite(join(str(self.dpath), fname), pt.render_figure_to_image(fig2, dpi=DPI))
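    # --- Illustrative sketch (added commentary; not in the original) --------
    # The ``maxbin`` expression in draw_class_score_hist clips the LNBNN
    # histogram near the 95th percentile. A more direct equivalent with numpy
    # (hypothetical helper, same intent):
    @staticmethod
    def _sketch_clipped_bins(scores, clip_frac=0.05, nbins=100):
        """
        Example:
            >>> scores = np.linspace(0, 50, 100)  # hypothetical LNBNN scores
            >>> bins = VerifierExpt._sketch_clipped_bins(scores)
            >>> assert bins[-1] < scores.max()
        """
        k = max(1, int(len(scores) * clip_frac))
        maxbin = np.sort(scores)[-k]  # score near the (1 - clip_frac) quantile
        return np.linspace(0, max(maxbin, 10), nbins)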
    def draw_mcc_thresh(self, task_key):
        """
        python -m wbia VerifierExpt.draw mcc_thresh GZ_Master1 match_state
        python -m wbia VerifierExpt.draw mcc_thresh PZ_Master1 match_state

        python -m wbia VerifierExpt.draw mcc_thresh GZ_Master1 photobomb_state
        python -m wbia VerifierExpt.draw mcc_thresh PZ_Master1 photobomb_state
        """
        mpl.rcParams.update(TMP_RC)

        results = self.ensure_results('all')
        data_key = results['data_key']
        clf_key = results['clf_key']
        task_combo_res = results['task_combo_res']

        code_to_nice = self.task_nice_lookup[task_key]

        if task_key == 'photobomb_state':
            classes = ['pb']
        elif task_key == 'match_state':
            classes = [POSTV, NEGTV, INCMP]

        res = task_combo_res[task_key][clf_key][data_key]

        mcc_curves = []
        for class_name in classes:
            c1 = res.confusions(class_name)
            if len(c1.thresholds) <= 2:
                continue
            class_nice = code_to_nice[class_name]
            idx = c1.mcc.argmax()
            t = c1.thresholds[idx]
            mcc = c1.mcc[idx]
            mcc_curves += [
                {
                    'label': class_nice + ', t={:.3f}, mcc={:.3f}'.format(t, mcc),
                    'thresh': c1.thresholds,
                    'mcc': c1.mcc,
                },
            ]

        fig = pt.figure(fnum=1)  # NOQA
        ax = pt.gca()
        for data in mcc_curves:
            ax.plot(data['thresh'], data['mcc'], label='%s' % (data['label'],))
        ax.set_xlabel('threshold')
        ax.set_ylabel('MCC')
        # ax.set_title('%s ROC for %s' % (target_class.title(), self.species))
        ax.legend()
        pt.adjust_subplots(top=0.8, bottom=0.2, left=0.12, right=0.9)
        fig.set_size_inches([W, H])

        fname = 'mcc_thresh_{}.png'.format(task_key)
        fig_fpath = join(str(self.dpath), fname)
        vt.imwrite(fig_fpath, pt.render_figure_to_image(fig, dpi=DPI))

        if ut.get_argflag('--diskshow'):
            ut.startfile(fig_fpath)
    @classmethod
    def draw_tagged_pair(cls):
        import wbia

        # ibs = wbia.opendb(defaultdb='GZ_Master1')
        ibs = wbia.opendb(defaultdb='PZ_Master1')

        query_tag = 'leftrightface'

        rowids = ibs._get_all_annotmatch_rowids()

        texts = ['' if t is None else t for t in ibs.get_annotmatch_tag_text(rowids)]
        tags = [[] if t is None else t.split(';') for t in texts]
        logger.info(ut.repr4(ut.dict_hist(ut.flatten(tags))))

        flags = [query_tag in t.lower() for t in texts]
        filtered_rowids = ut.compress(rowids, flags)
        edges = ibs.get_annotmatch_aids(filtered_rowids)

        # The facematch leftright side example
        # edge = (5161, 5245)
        edge = edges[0]
        # for edge in ut.InteractiveIter(edges):
        infr = wbia.AnnotInference(ibs=ibs, aids=edge, verbose=10)
        infr.reset_feedback('annotmatch', apply=True)
        match = infr._exec_pairwise_match([edge])[0]

        if False:
            # Fix the example tags
            infr.add_feedback(
                edge,
                'match',
                tags=['facematch', 'leftrightface'],
                user_id='qt-hack',
                confidence='pretty_sure',
            )
            infr.write_wbia_staging_feedback()
            infr.write_wbia_annotmatch_feedback()

        # THE DEPCACHE IS BROKEN FOR ANNOTMATCH APPARENTLY! >:(
        # Redo matches
        feat_keys = ['vecs', 'kpts', '_feats', 'flann']
        match.annot1._mutable = True
        match.annot2._mutable = True
        for key in feat_keys:
            if key in match.annot1:
                del match.annot1[key]
            if key in match.annot2:
                del match.annot2[key]
        match.apply_all({})

        fig = pt.figure(fnum=1, clf=True)
        ax = pt.gca()

        mpl.rcParams.update(TMP_RC)
        match.show(
            ax,
            vert=False,
            heatmask=True,
            show_lines=False,
            show_ell=False,
            show_ori=False,
            show_eig=False,
            # ell_alpha=.3,
            modifysize=True,
        )
        # ax.set_xlabel(xlabel)

        self = cls()
        fname = 'custom_match_{}_{}_{}'.format(query_tag, *edge)
        dpath = ut.truepath(self.base_dpath)
        fpath = join(str(dpath), fname + '.jpg')
        vt.imwrite(fpath, pt.render_figure_to_image(fig, dpi=DPI))
    def custom_single_hard_case(self):
        """
        Example:
            >>> # DISABLE_DOCTEST
            >>> from wbia.scripts.postdoc import *
            >>> defaultdb = 'PZ_PB_RF_TRAIN'
            >>> #defaultdb = 'GZ_Master1'
            >>> defaultdb = 'PZ_MTEST'
            >>> self = VerifierExpt.collect(defaultdb)
            >>> self.dbname = 'PZ_PB_RF_TRAIN'
        """
        task_key = 'match_state'
        edge = (383, 503)
        for _case in self.hard_cases[task_key]:
            if _case['edge'] == edge:
                case = _case
                break

        import wbia

        ibs = wbia.opendb(self.dbname)
        from wbia import core_annots

        config = {
            'augment_orientation': True,
            'ratio_thresh': 0.8,
        }
        config['checks'] = 80
        config['sver_xy_thresh'] = 0.02
        config['sver_ori_thresh'] = 3
        config['Knorm'] = 3
        config['symmetric'] = True
        config = ut.hashdict(config)

        aid1, aid2 = case['edge']
        real_name = case['real']
        pred_name = case['pred']
        code_to_nice = self.task_nice_lookup[task_key]
        real_nice, pred_nice = ut.take(code_to_nice, [real_name, pred_name])
        fname = 'fail_{}_{}_{}_{}'.format(real_nice, pred_nice, aid1, aid2)

        # Draw case
        probs = case['probs']  # stored as a plain dict by measure_hard_cases
        order = list(code_to_nice.values())
        order = ut.setintersect(order, probs.keys())
        probs = ut.map_dict_keys(code_to_nice, probs)
        probstr = ut.repr2(probs, precision=2, strkeys=True, nobr=True, key_order=order)
        xlabel = 'real={}, pred={},\n{}'.format(real_nice, pred_nice, probstr)

        match_list = ibs.depc.get(
            'pairwise_match', ([aid1], [aid2]), 'match', config=config
        )
        match = match_list[0]
        configured_lazy_annots = core_annots.make_configured_annots(
            ibs, [aid1], [aid2], config, config, preload=True
        )
        match.annot1 = configured_lazy_annots[config][aid1]
        match.annot2 = configured_lazy_annots[config][aid2]
        match.config = config

        fig = pt.figure(fnum=1, clf=True)
        ax = pt.gca()

        mpl.rcParams.update(TMP_RC)
        match.show(
            ax,
            vert=False,
            heatmask=True,
            show_lines=False,
            show_ell=False,
            show_ori=False,
            show_eig=False,
            # ell_alpha=.3,
            modifysize=True,
        )
        ax.set_xlabel(xlabel)

        subdir = 'cases_{}'.format(task_key)
        dpath = join(str(self.dpath), subdir)
        fpath = join(str(dpath), fname + '_custom.jpg')
        vt.imwrite(fpath, pt.render_figure_to_image(fig, dpi=DPI))