Source code for wbia.algo.graph.mixin_simulation

# -*- coding: utf-8 -*-
"""
Mixin functionality for experiments, tests, and simulations.
This includes recording the measures used to generate plots in JC's thesis.
"""
# -*- coding: utf-8 -*-
import logging
import utool as ut
import ubelt as ub
import pandas as pd
import itertools as it
from wbia.algo.graph import nx_utils as nxu
from wbia.algo.graph.state import POSTV, NEGTV, INCMP, UNREV, UNKWN, NULL

print, rrr, profile = ut.inject2(__name__)
logger = logging.getLogger('wbia')


class SimulationHelpers(object):
    """
    Mixin with helpers for experiments, tests, and simulations.

    Every method receives the host inference object as ``infr`` (the name
    used in place of ``self``) and reads or writes attributes on it.
    """

    def init_simulation(
        infr,
        oracle_accuracy=1.0,
        k_redun=2,
        enable_autoreview=True,
        enable_inference=True,
        classifiers=None,
        match_state_thresh=None,
        pb_state_thresh=None,
        max_outer_loops=None,
        name=None,
    ):
        """
        Configure ``infr`` for a simulated review session.

        Args:
            oracle_accuracy: forwarded to :class:`UserOracle`; either a single
                value or a ``(normal, recovery)`` tuple.
            k_redun: stored as both ``redun.pos`` and ``redun.neg``.
            enable_autoreview: stored as ``autoreview.enabled``.
            enable_inference: stored as ``inference.enabled``.
            classifiers: stored as ``infr.verifiers``.
            match_state_thresh: per-state thresholds for the ``match_state``
                task; defaults to 1.0 for POSTV, NEGTV and INCMP.
            pb_state_thresh: thresholds for the ``photobomb_state`` task;
                defaults to ``{'pb': 0.5, 'notpb': 0.9}``.
            max_outer_loops: stored as ``algo.max_outer_loops``.
            name: stored as ``infr.name`` and used to seed the oracle.
        """
        infr.print('INIT SIMULATION', color='yellow')

        infr.name = name
        infr.simulation_mode = True

        infr.verifiers = classifiers

        infr.params['inference.enabled'] = enable_inference
        infr.params['autoreview.enabled'] = enable_autoreview

        infr.params['redun.pos'] = k_redun
        infr.params['redun.neg'] = k_redun

        # keeps track of edges where the decision != the groundtruth
        infr.mistake_edges = set()

        infr.queue = ut.PriorityQueue()

        infr.oracle = UserOracle(oracle_accuracy, rng=infr.name)

        if match_state_thresh is None:
            match_state_thresh = {
                POSTV: 1.0,
                NEGTV: 1.0,
                INCMP: 1.0,
            }

        if pb_state_thresh is None:
            pb_state_thresh = {
                'pb': 0.5,
                'notpb': 0.9,
            }

        infr.task_thresh = {
            'photobomb_state': pd.Series(pb_state_thresh),
            'match_state': pd.Series(match_state_thresh),
        }
        infr.params['algo.max_outer_loops'] = max_outer_loops

    def init_test_mode(infr):
        """
        Reset the test-mode bookkeeping on ``infr``.

        Creates an empty metrics list, zeroed counters in ``test_state``, a
        ground-truth positive graph containing every annotation as a node, and
        the ground-truth groupings derived from ``infr.orig_name_labels``.
        """
        from wbia.algo.graph import nx_dynamic_graph

        infr.print('init_test_mode')
        infr.test_mode = True
        infr.metrics_list = []
        infr.test_state = {
            'n_decision': 0,
            'n_algo': 0,
            'n_manual': 0,
            'n_true_merges': 0,
            'n_error_edges': 0,
            # Lazily created by ``_dynamic_test_callback`` on first use.
            'confusion': None,
        }
        infr.test_gt_pos_graph = nx_dynamic_graph.DynConnGraph()
        infr.test_gt_pos_graph.add_nodes_from(infr.aids)
        infr.nid_to_gt_cc = ut.group_items(infr.aids, infr.orig_name_labels)
        infr.node_truth = ut.dzip(infr.aids, infr.orig_name_labels)

        # Each ground-truth group of size k contributes k - 1 edges.
        infr.real_n_pcc_mst_edges = sum(
            len(cc) - 1 for cc in infr.nid_to_gt_cc.values()
        )
        infr.print(
            'real_n_pcc_mst_edges = %r' % (infr.real_n_pcc_mst_edges,), color='red'
        )

    def measure_error_edges(infr):
        """
        Yield ``(edge, error)`` for every reviewed edge whose decision is wrong.

        ``error`` is an ordered dict with the keys ``'real'`` (the edge's
        ``'truth'`` attribute) and ``'pred'`` (its ``'evidence_decision'``).
        Edges whose decision is UNREV are skipped.
        """
        for edge, data in infr.edges(data=True):
            true_state = data['truth']
            pred_state = data.get('evidence_decision', UNREV)
            if pred_state != UNREV:
                if true_state != pred_state:
                    error = ut.odict([('real', true_state), ('pred', pred_state)])
                    yield edge, error

    @profile
    def measure_metrics(infr):
        """
        Build a snapshot of the current test-mode measurements.

        Reads the confusion counts and counters kept in ``infr.test_state``
        together with the current graph state.

        Returns:
            dict: one flat record of counters, ratios, and the details of the
            most recent decision.

        Raises:
            AssertionError: if the number of error edges implied by the
                confusion counts differs from ``len(infr.mistake_edges)``.
        """
        n_true_merges = infr.test_state['n_true_merges']
        confusion = infr.test_state['confusion']

        postv_row = confusion[POSTV]
        columns = set(confusion.keys())
        reviewd_cols = columns - {UNREV}
        non_postv = reviewd_cols - {POSTV}
        non_negtv = reviewd_cols - {NEGTV}

        n_fn = sum(ut.take(postv_row, non_postv))
        n_fp = sum(ut.take(confusion[NEGTV], non_negtv))

        # Off-diagonal mass over the reviewed states, counted in both
        # directions for every unordered pair of states.
        n_error_edges = sum(
            confusion[r][c] + confusion[c][r]
            for r, c in ut.combinations(reviewd_cols, 2)
        )

        pred_n_pcc_mst_edges = n_true_merges

        # Find all annotations involved in a mistake
        assert n_error_edges == len(infr.mistake_edges)
        direct_mistake_aids = {a for edge in infr.mistake_edges for a in edge}
        mistake_nids = set(infr.node_labels(*direct_mistake_aids))
        mistake_aids = set(
            ut.flatten([infr.pos_graph.component(nid) for nid in mistake_nids])
        )

        pos_acc = pred_n_pcc_mst_edges / infr.real_n_pcc_mst_edges
        metrics = {
            'n_decision': infr.test_state['n_decision'],
            'n_manual': infr.test_state['n_manual'],
            'n_algo': infr.test_state['n_algo'],
            'phase': infr.loop_phase,
            'pos_acc': pos_acc,
            'n_merge_total': infr.real_n_pcc_mst_edges,
            'n_merge_remain': infr.real_n_pcc_mst_edges - n_true_merges,
            'n_true_merges': n_true_merges,
            'recovering': infr.is_recovering(),
            'merge_remain': 1 - pos_acc,
            'n_mistake_aids': len(mistake_aids),
            'frac_mistake_aids': len(mistake_aids) / len(infr.aids),
            'n_mistake_nids': len(mistake_nids),
            'n_errors': n_error_edges,
            'n_fn': n_fn,
            'n_fp': n_fp,
            'refresh_support': len(infr.refresh.manual_decisions),
            'pprob_any': infr.refresh.prob_any_remain(),
            'mu': infr.refresh._ewma,
            'test_action': infr.test_state['test_action'],
            'action': infr.test_state.get('action', None),
            'user_id': infr.test_state['user_id'],
            'pred_decision': infr.test_state['pred_decision'],
            'true_decision': infr.test_state['true_decision'],
            'n_neg_redun': infr.neg_redun_metagraph.number_of_edges(),
        }

        return metrics

    def _print_previous_loop_statistics(infr, count):
        """
        Print summary histograms over the last ``count`` recorded metrics.

        Reports how many consecutive runs of recovering decisions occurred and
        how many decisions they contained, followed by histograms of the test
        actions, the inference actions (only when inference is enabled), the
        predicted decisions, and the user ids.
        """
        history = infr.metrics_list[-count:]

        # Run-length encode the 'recovering' flag, then keep the lengths of
        # the runs where it was True.
        recover_blocks = ut.group_items(
            [
                (k, sum(1 for i in g))
                for k, g in it.groupby(ut.take_column(history, 'recovering'))
            ]
        ).get(True, [])
        infr.print(
            ('Recovery mode entered {} times, ' 'made {} recovery decisions.').format(
                len(recover_blocks), sum(recover_blocks)
            ),
            color='green',
        )
        testaction_hist = ut.dict_hist(ut.take_column(history, 'test_action'))
        infr.print(
            'Test Action Histogram: {}'.format(ut.repr4(testaction_hist, si=True)),
            color='yellow',
        )
        if infr.params['inference.enabled']:
            action_hist = ut.dict_hist(
                ut.emap(frozenset, ut.take_column(history, 'action'))
            )
            infr.print(
                'Inference Action Histogram: {}'.format(
                    ub.repr2(action_hist, si=True)
                ),
                color='yellow',
            )
        infr.print(
            'Decision Histogram: {}'.format(
                ut.repr2(ut.dict_hist(ut.take_column(history, 'pred_decision')), si=True)
            ),
            color='yellow',
        )
        infr.print(
            'User Histogram: {}'.format(
                ut.repr2(ut.dict_hist(ut.take_column(history, 'user_id')), si=True)
            ),
            color='yellow',
        )

    @profile
    def _dynamic_test_callback(infr, edge, decision, prev_decision, user_id):
        """
        Update the test-mode bookkeeping after a decision was made on ``edge``.

        Maintains the ground-truth positive graph, the ``n_true_merges``
        counter, the confusion counts, and the set of mistake edges, and
        records a label describing the decision in
        ``infr.test_state['test_action']``.

        Args:
            edge: the pair of nodes that was reviewed.
            decision: the decision that was just made.
            prev_decision: the decision previously stored for the edge
                (UNREV when it had not been reviewed).
            user_id: who made the decision; must start with ``'algo'`` or
                ``'user'``, or be exactly ``'oracle'``.

        Raises:
            AssertionError: if ``user_id`` is not recognised, or if the
                decision could not be classified into a test action.
        """
        was_gt_pos = infr.test_gt_pos_graph.has_edge(*edge)

        true_decision = infr.edge_truth[edge]

        was_within_pred = infr.pos_graph.are_nodes_connected(*edge)
        was_within_gt = infr.test_gt_pos_graph.are_nodes_connected(*edge)
        was_reviewed = prev_decision != UNREV
        is_within_gt = was_within_gt
        was_correct = prev_decision == true_decision

        is_correct = true_decision == decision

        def test_print(x, **kw):
            infr.print('[ACTION] ' + x, level=2, **kw)

        if decision == POSTV:
            if is_correct:
                if not was_gt_pos:
                    infr.test_gt_pos_graph.add_edge(*edge)
        elif was_gt_pos:
            # A correct positive edge was overwritten by a non-positive one.
            test_print('UNDID GOOD POSITIVE EDGE', color='red')
            infr.test_gt_pos_graph.remove_edge(*edge)
            is_within_gt = infr.test_gt_pos_graph.are_nodes_connected(*edge)

        split_gt = is_within_gt != was_within_gt
        if split_gt:
            test_print('SPLIT A GOOD MERGE', color='red')
            infr.test_state['n_true_merges'] -= 1

        confusion = infr.test_state['confusion']
        if confusion is None:
            # initialize dynamic confusion matrix
            # (nested dicts: pandas takes a really long time doing this)
            states = (POSTV, NEGTV, INCMP, UNREV, UNKWN)
            confusion = {r: {c: 0 for c in states} for r in states}
            infr.test_state['confusion'] = confusion

        if was_reviewed:
            # Move the edge from its previous cell to the new one.
            confusion[true_decision][prev_decision] -= 1
            confusion[true_decision][decision] += 1
        else:
            confusion[true_decision][decision] += 1

        test_action = None
        action_color = None

        if is_correct:
            # CORRECT DECISION
            if was_reviewed:
                if prev_decision == decision:
                    test_action = 'correct duplicate'
                    action_color = 'yellow'
                else:
                    infr.mistake_edges.remove(edge)
                    test_action = 'correction'
                    action_color = 'green'
                    if decision == POSTV:
                        if not was_within_gt:
                            test_action = 'correction redid merge'
                            action_color = 'green'
                            infr.test_state['n_true_merges'] += 1
            else:
                if decision == POSTV:
                    if not was_within_gt:
                        test_action = 'correct merge'
                        action_color = 'green'
                        infr.test_state['n_true_merges'] += 1
                    else:
                        test_action = 'correct redundant positive'
                        action_color = 'blue'
                else:
                    if decision == NEGTV:
                        test_action = 'correct negative'
                        action_color = 'cyan'
                    else:
                        test_action = 'correct uninferrable'
                        action_color = 'cyan'
        else:
            action_color = 'red'
            # INCORRECT DECISION
            infr.mistake_edges.add(edge)
            if was_reviewed:
                if prev_decision == decision:
                    test_action = 'incorrect duplicate'
                elif was_correct:
                    test_action = 'incorrect undid good edge'
            else:
                if decision == POSTV:
                    if was_within_pred:
                        test_action = 'incorrect redundant merge'
                    else:
                        test_action = 'incorrect new merge'
                else:
                    test_action = 'incorrect new mistake'

        infr.test_state['test_action'] = test_action
        infr.test_state['pred_decision'] = decision
        infr.test_state['true_decision'] = true_decision
        infr.test_state['user_id'] = user_id
        infr.test_state['recovering'] = infr.recover_graph.has_node(
            edge[0]
        ) or infr.recover_graph.has_node(edge[1])

        infr.test_state['n_decision'] += 1
        if user_id.startswith('algo'):
            infr.test_state['n_algo'] += 1
        elif user_id.startswith('user') or user_id == 'oracle':
            infr.test_state['n_manual'] += 1
        else:
            raise AssertionError('unknown user_id=%r' % (user_id,))

        # Check before printing: concatenating None inside ``test_print``
        # would otherwise hide this message behind a TypeError.
        if test_action is None:
            raise AssertionError('what happened?')
        test_print(test_action, color=action_color)
class UserOracle(object):
    """
    Simulated reviewer that answers with the truth, or with a random wrong
    answer, depending on a configured accuracy.
    """

    def __init__(oracle, accuracy, rng):
        """
        Args:
            accuracy: a single value used in every mode, or a tuple
                ``(normal_accuracy, recover_accuracy)``.
            rng: seed or generator passed to ``ut.ensure_rng``; a string is
                first reduced to the sum of its character ordinals.
        """
        if isinstance(rng, str):
            rng = sum(map(ord, rng))
        rng = ut.ensure_rng(rng, impl='python')

        if isinstance(accuracy, tuple):
            oracle.normal_accuracy = accuracy[0]
            oracle.recover_accuracy = accuracy[1]
        else:
            oracle.normal_accuracy = accuracy
            oracle.recover_accuracy = accuracy
        oracle.rng = rng
        oracle.states = {POSTV, NEGTV, INCMP}

    def review(oracle, edge, truth, infr, accuracy=None):
        """
        Produce a simulated review of ``edge``.

        Args:
            edge: the edge under review (not used to compute the answer).
            truth: the correct decision for the edge.
            infr: inference object; queried for ``is_recovering()`` and used
                to print a message when the oracle errs.
            accuracy: overrides the stored accuracy when given; otherwise the
                recovery accuracy is used while ``infr`` is recovering and
                the normal accuracy is used at other times.

        Returns:
            dict: feedback record whose ``'evidence_decision'`` is either
            ``truth`` or a wrong state. A wrong answer is never INCMP.
        """
        feedback = {
            'user_id': 'user:oracle',
            'confidence': 'absolutely_sure',
            'evidence_decision': None,
            'meta_decision': NULL,
            'timestamp_s1': ut.get_timestamp('int', isutc=True),
            'timestamp_c1': ut.get_timestamp('int', isutc=True),
            'timestamp_c2': ut.get_timestamp('int', isutc=True),
            'tags': [],
        }
        is_recovering = infr.is_recovering()

        if accuracy is None:
            if is_recovering:
                accuracy = oracle.recover_accuracy
            else:
                accuracy = oracle.normal_accuracy

        # The oracle can get anything where the hardness is less than its
        # accuracy
        hardness = oracle.rng.random()
        error = accuracy < hardness

        if error:
            # Put the candidates in a fixed order before drawing: iteration
            # order of a set is not guaranteed to be the same between
            # interpreter runs, which would make a seeded oracle
            # irreproducible whenever more than one wrong answer exists.
            error_options = sorted(oracle.states - {truth} - {INCMP}, key=str)
            observed = oracle.rng.choice(error_options)
        else:
            observed = truth

        if accuracy < 1.0:
            feedback['confidence'] = 'pretty_sure'
        if accuracy < 0.5:
            feedback['confidence'] = 'guessing'
        feedback['evidence_decision'] = observed

        if error:
            infr.print(
                'ORACLE ERROR real={} pred={} acc={:.2f} hard={:.2f}'.format(
                    truth, observed, accuracy, hardness
                ),
                2,
                color='red',
            )
        return feedback