# Source code for wbia.algo.graph.mixin_priority

# -*- coding: utf-8 -*-
import logging
import numpy as np
import utool as ut
import networkx as nx
from wbia import constants as const
from wbia.algo.graph import nx_utils as nxu
from wbia.algo.graph.state import POSTV, NEGTV
from wbia.algo.graph.state import SAME, DIFF, NULL  # NOQA
from wbia.algo.graph import mixin_loops

print, rrr, profile = ut.inject2(__name__)
logger = logging.getLogger('wbia')


ENABLE_PRIORITY_PCC_CORRECTION = True


class Priority(object):
    """
    Handles prioritization of edges for review.

    Mixin that wraps an ``infr.queue`` priority queue (higher priority pops
    first) and filters popped edges according to redundancy, confidence, and
    fix-mode settings.

    Example:
        >>> # ENABLE_DOCTEST
        >>> from wbia.algo.graph.mixin_priority import *  # NOQA
        >>> from wbia.algo.graph import demo
        >>> infr = demo.demodata_infr(num_pccs=20)
    """
[docs] def remaining_reviews(infr): assert infr.queue is not None return len(infr.queue)
    def _pop(infr, *args):
        """Wraps queue so ordering is deterministic.

        Priorities are stored negated (min-heap semantics); this undoes the
        negation so callers see the true priority.
        """
        (e, (p, _)) = infr.queue.pop(*args)
        return (e, -p)

    def _push(infr, edge, priority):
        """Wraps queue so ordering is deterministic.

        Stores ``-priority`` plus a tiebreaker so equal priorities pop in a
        reproducible order.
        """
        PSEUDO_RANDOM_TIEBREAKER = False
        if PSEUDO_RANDOM_TIEBREAKER:
            # Make it so tiebreakers have a pseudo-random order
            chaotic = int.from_bytes(ut.digest_data(edge, alg='sha1')[:8], 'big')
            tiebreaker = (chaotic,) + edge
        else:
            tiebreaker = edge
        infr.assert_edge(edge)
        # NOTE(review): dead code — the `if False and` guard disables this
        # KAIA filtering branch entirely.
        if False and mixin_loops.PRINCETON_KAIA_EDGE_LIST is not None:
            # logger.info('[Priority._push] FILTERING EDGES FOR KAIA')
            # Sanity check, make sure that one of the edges is in the tier 1 dataset
            include_filter_set = set(mixin_loops.PRINCETON_KAIA_EDGE_LIST)
            u, v = edge
            if u not in include_filter_set and v not in include_filter_set:
                return
        # tiebreaker = (chaotic(chaotic(u) + chaotic(v)), u, v)
        infr.queue[edge] = (-priority, tiebreaker)

    def _peek_many(infr, n):
        """Wraps queue so ordering is deterministic.

        Returns up to ``n`` (edge, priority) pairs with priorities un-negated.
        """
        return [(k, -p) for (k, (p, _)) in infr.queue.peek_many(n)]

    def _remove_edge_priority(infr, edges):
        # Drop the given edges from the queue (only those actually queued).
        if infr.queue is None:
            return
        edges_ = [edge for edge in edges if edge in infr.queue]
        if len(edges_) > 0:
            infr.print('removed priority from {} edges'.format(len(edges_)), 5)
            infr.queue.delete_items(edges_)

    def _reinstate_edge_priority(infr, edges):
        # Re-queue edges that are not already queued, scored by prob_match.
        if infr.queue is None:
            return
        edges_ = [edge for edge in edges if edge not in infr.queue]
        if len(edges_) > 0:
            # TODO: use whatever the current metric is
            metric = 'prob_match'
            infr.print('reprioritize {} edges'.format(len(edges_)), 5)
            priorities = infr.gen_edge_values(metric, edges_, default=1e-9)
            for edge, priority in zip(edges_, priorities):
                infr._push(edge, priority)

    def _increase_priority(infr, edges, amount=10):
        # Boost the priority of ``edges`` by ``amount`` over their base score.
        if infr.queue is None:
            return
        infr.print('increase priority of {} edges'.format(len(edges)), 5)
        metric = 'prob_match'
        priorities = infr.gen_edge_values(metric, edges, default=1e-9)
        for edge, base in zip(edges, priorities):
            infr.push(edge, base + amount)
[docs] def remove_internal_priority(infr, cc): if infr.queue is not None: infr._remove_edge_priority(nxu.edges_inside(infr.graph, cc))
[docs] def remove_external_priority(infr, cc): if infr.queue is not None: infr._remove_edge_priority(nxu.edges_outgoing(infr.graph, cc))
[docs] def remove_between_priority(infr, cc1, cc2): if infr.queue is not None: infr._remove_edge_priority(nxu.edges_cross(infr.graph, cc1, cc2))
[docs] def reinstate_between_priority(infr, cc1, cc2): if infr.queue is not None: # Reinstate the appropriate edges into the queue edges = nxu.edges_cross(infr.unreviewed_graph, cc1, cc2) infr._reinstate_edge_priority(edges)
[docs] def reinstate_internal_priority(infr, cc): if infr.queue is not None: # Reinstate the appropriate edges into the queue edges = nxu.edges_inside(infr.unreviewed_graph, cc) infr._reinstate_edge_priority(edges)
[docs] def reinstate_external_priority(infr, cc): if infr.queue is not None: # Reinstate the appropriate edges into the queue edges = nxu.edges_outgoing(infr.unreviewed_graph, cc) infr._reinstate_edge_priority(edges)
def _correct_priorities(infr, edge, priority): corrected_priority = None if not ENABLE_PRIORITY_PCC_CORRECTION: return corrected_priority if priority < 10: try: primary_task = 'match_state' decision_probs = infr.task_probs[primary_task][edge] except KeyError: decision_probs = None if decision_probs is not None: u, v = edge nid1, nid2 = infr.node_labels(u, v) if nid1 == nid2: if priority != decision_probs[NEGTV]: corrected_priority = decision_probs[NEGTV] else: if priority != decision_probs[POSTV]: corrected_priority = decision_probs[POSTV] return corrected_priority
[docs] @profile def prioritize( infr, metric=None, edges=None, scores=None, force_inconsistent=True, reset=False ): """ Adds edges to the priority queue Doctest: >>> from wbia.algo.graph.mixin_priority import * # NOQA >>> from wbia.algo.graph import demo >>> infr = demo.demodata_infr(num_pccs=7, size=5) >>> infr.ensure_cliques(meta_decision=SAME) >>> # Add a negative edge inside a PCC >>> ccs = list(infr.positive_components()) >>> edge1 = tuple(list(ccs[0])[0:2]) >>> edge2 = tuple(list(ccs[1])[0:2]) >>> infr.add_feedback(edge1, NEGTV) >>> infr.add_feedback(edge2, NEGTV) >>> num_new = infr.prioritize(reset=True) >>> order = infr._peek_many(np.inf) >>> scores = ut.take_column(order, 1) >>> assert scores[0] > 10 >>> assert len(scores) == num_new, 'should prioritize two hypotheis edges' >>> unrev_edges = set(infr.unreviewed_graph.edges()) >>> err_edges = set(ut.flatten(infr.nid_to_errors.values())) >>> edges = set(list(unrev_edges - err_edges)[0:2]) >>> edges.update(list(err_edges)[0:2]) >>> num_new = infr.prioritize(edges=edges, reset=True) >>> order2 = infr._peek_many(np.inf) >>> scores2 = np.array(ut.take_column(order2, 1)) >>> assert np.all(scores2[0:2] > 10) >>> assert np.all(scores2[2:] < 10) Example: import wbia infr = wbia.AnnotInference('PZ_MTEST', aids='all', autoinit='staging') infr.verbose = 1000 infr.load_published() incon_edges = set(ut.iflatten(infr.nid_to_errors.values())) assert len(incon_edges) > 0 edges = list(infr.find_pos_redun_candidate_edges()) assert len(set(incon_edges).intersection(set(edges))) == 0 infr.add_candidate_edges(edges) infr.prioritize() logger.info(ut.repr4(infr.status())) """ if reset or infr.queue is None: infr.queue = ut.PriorityQueue() low = 1e-9 if metric is None: metric = 'prob_match' # If edges are not explicilty specified get unreviewed and error edges # that are not redundant if edges is None: if scores is not None: raise ValueError('must provide edges with scores') unrev_edges = infr.unreviewed_graph.edges() edges = 
set(infr.filter_edges_flagged_as_redun(unrev_edges)) infr.print('ensuring {} edge(s) get priority'.format(len(edges)), 5) if infr.params['inference.enabled'] and force_inconsistent: # Ensure that maybe_error edges are always prioritized maybe_error_edges = set(infr.maybe_error_edges()) extra_edges = set(maybe_error_edges).difference(set(edges)) extra_edges = list(extra_edges) infr.print( 'ensuring {} inconsistent edge(s) get priority'.format(len(extra_edges)), 5, ) if scores is not None: pgen = list(infr.gen_edge_values(metric, extra_edges, default=low)) extra_scores = np.array(pgen) extra_scores[np.isnan(extra_scores)] = low scores = ut.aslist(scores) + ut.aslist(extra_scores) edges = ut.aslist(edges) + extra_edges # Ensure edges are in some arbitrary order edges = list(edges) # Ensure given scores do not have nan values if scores is None: pgen = infr.gen_edge_values(metric, edges, default=low) priorities = np.array(list(pgen)) priorities[np.isnan(priorities)] = low else: priorities = np.asarray(scores) if np.any(np.isnan(priorities)): priorities[np.isnan(priorities)] = low if infr.params['inference.enabled']: # Increase priority of any flagged maybe_error edges err_flags = [e in maybe_error_edges for e in edges] priorities[err_flags] += 10 # Push new items into the priority queue num_new = 0 for edge, priority in zip(edges, priorities): if edge not in infr.queue: num_new += 1 corrected_priority = infr._correct_priorities(edge, priority) if corrected_priority is not None: priority = corrected_priority infr._push(edge, priority) infr.print('added %d edges to the queue' % (num_new,), 1) return num_new
[docs] def push(infr, edge, priority=None): """ Push an edge back onto the queue """ if priority is None: priority = 'prob_match' if isinstance(priority, str): prob_match = infr.get_edge_attr(edge, priority, default=1e-9) priority = prob_match corrected_priority = infr._correct_priorities(edge, priority) if corrected_priority is not None: priority = corrected_priority # Use edge-nids to break ties for determenistic behavior infr._push(edge, priority)
    @profile
    def pop(infr):
        """
        Main interface to the priority queue used by the algorithm loops.

        Pops the highest priority edge from the queue.

        Returns:
            tuple or None: ``(edge, priority)`` for the next edge that passes
                all enabled filters, or None (bare return) when the queue is
                exhausted.
        """
        # The outer loop simulates recursive calls without using the stack
        SIZE_THRESH_ENABLED = False
        SIZE_THRESH = 10
        while True:
            try:
                edge, priority = infr._pop()
            except IndexError:
                # Nothing more to review!
                return
            else:
                # Re-prioritize positive or negative relative to PCCs
                corrected_priority = infr._correct_priorities(edge, priority)
                if corrected_priority is not None:
                    # Requeue under the corrected priority and pick again
                    infr.push(edge, corrected_priority)
                    continue
                if SIZE_THRESH_ENABLED and infr.phase > 0:
                    # Optionally skip within-PCC edges unless both PCCs are large
                    u, v = edge
                    nid1, nid2 = infr.node_labels(u, v)
                    cc1 = infr.pos_graph.component(nid1)
                    cc2 = infr.pos_graph.component(nid2)
                    size1 = len(cc1)
                    size2 = len(cc2)
                    if nid1 == nid2 and not (size1 > SIZE_THRESH and size2 > SIZE_THRESH):
                        continue
                if infr.params['redun.enabled']:
                    u, v = edge
                    nid1, nid2 = infr.node_labels(u, v)
                    pos_graph = infr.pos_graph
                    # NOTE(review): bare lookup with no visible effect beyond
                    # possibly raising for an unknown nid — confirm intent
                    pos_graph[nid1]
                    if nid1 == nid2:
                        if nid1 not in infr.nid_to_errors:
                            # skip edges that increase local connectivity beyond
                            # redundancy thresholds.
                            k_pos = infr.params['redun.pos']
                            # Much faster to compute local connectivity on subgraph
                            cc = infr.pos_graph.component(nid1)
                            pos_subgraph = infr.pos_graph.subgraph(cc)
                            pos_conn = nx.connectivity.local_edge_connectivity(
                                pos_subgraph, u, v, cutoff=k_pos
                            )
                            # Compute local connectivity
                            if pos_conn >= k_pos:
                                continue  # Loop instead of recursion
                                # return infr.pop()
                if infr.params['queue.conf.thresh'] is not None:
                    # Ignore reviews that would re-enforce a relationship that
                    # already has high confidence.
                    thresh_code = infr.params['queue.conf.thresh']
                    thresh = const.CONFIDENCE.CODE_TO_INT[thresh_code]
                    # FIXME: at the time of writing a hard coded priority of 10
                    # or more means that this is part of an inconsistency
                    if priority < 10:
                        u, v = edge
                        nid1, nid2 = infr.node_labels(u, v)
                        if nid1 == nid2:
                            if infr.confidently_connected(u, v, thresh):
                                # NOTE(review): the popped result is discarded
                                # and the current edge is still returned below;
                                # looks like this should `continue` — confirm
                                infr.pop()
                        else:
                            if infr.confidently_separated(u, v, thresh):
                                # NOTE(review): same discarded-pop pattern as above
                                infr.pop()
                if getattr(infr, 'fix_mode_split', False):
                    # only checking edges within a name
                    nid1, nid2 = infr.pos_graph.node_labels(*edge)
                    if nid1 != nid2:
                        continue  # Loop instead of recursion
                        # return infr.pop()
                if getattr(infr, 'fix_mode_merge', False):
                    # only checking edges within a name
                    nid1, nid2 = infr.pos_graph.node_labels(*edge)
                    if nid1 == nid2:
                        continue  # Loop instead of recursive (infr.pop())
                if getattr(infr, 'fix_mode_predict', False):
                    # No longer needed.
                    pred = infr.get_edge_data(edge).get('pred', None)
                    # only report cases where the prediction differs
                    # FIXME: at the time of writing a hard coded priority of 10
                    # or more means that this is part of an inconsistency
                    if priority < 10:
                        nid1, nid2 = infr.node_labels(*edge)
                        if nid1 == nid2:
                            u, v = edge
                            # Don't re-review confident CCs
                            thresh = const.CONFIDENCE.CODE_TO_INT['pretty_sure']
                            if infr.confidently_connected(u, v, thresh):
                                continue  # Loop instead of recursive (infr.pop())
                        if pred == POSTV and nid1 == nid2:
                            continue  # Loop instead of recursive (infr.pop())
                        if pred == NEGTV and nid1 != nid2:
                            continue  # Loop instead of recursive (infr.pop())
                    else:
                        logger.info('in error recover mode')
                infr.assert_edge(edge)
                return edge, priority
        # NOTE(review): unreachable — the `while True` loop always returns
        return
[docs] def peek(infr): return infr.peek_many(n=1)[0]
[docs] def peek_many(infr, n): """ Peeks at the top n edges in the queue. Example: >>> # ENABLE_DOCTEST >>> from wbia.algo.graph.mixin_priority import * # NOQA >>> from wbia.algo.graph import demo >>> infr = demo.demodata_infr(num_pccs=7, size=5) >>> infr.refresh_candidate_edges() >>> infr.peek_many(50) """ # Do pops that may invalidate pos redun edges internal to PCCs items = [] count = 0 # Pop the top n edges off the queue while len(infr.queue) > 0 and count < n: items.append(infr.pop()) count += 1 # Push them back because we are just peeking # (although we may have invalidated things based on local connectivity) for edge, priority in items: infr.push(edge, priority) return items
[docs] def confidently_connected(infr, u, v, thresh=2): """ Checks if u and v are conneted by edges above a confidence threshold """ def satisfied(G, child, edge): decision = infr.edge_decision(edge) if decision != POSTV: return False data = G.get_edge_data(*edge) conf = data.get('confidence', 'unspecified') conf_int = const.CONFIDENCE.CODE_TO_INT[conf] conf_int = 0 if conf_int is None else conf_int return conf_int >= thresh for node in ut.bfs_conditional( infr.graph, u, yield_if=satisfied, continue_if=satisfied ): if node == v: return True return False
    def confidently_separated(infr, u, v, thresh=2):
        """
        Checks if u and v are separated by a confident negative edge.

        Performs a DFS that may cross exactly one NEGTV edge whose
        confidence meets ``thresh``; returns True if ``v`` is reached.

        Doctest:
            >>> from wbia.algo.graph.mixin_priority import *  # NOQA
            >>> from wbia.algo.graph import demo
            >>> infr = demo.make_demo_infr(ccs=[(1, 2), (3, 4), (5, 6), (7, 8)])
            >>> infr.add_feedback((1, 5), NEGTV)
            >>> infr.add_feedback((5, 8), NEGTV)
            >>> infr.add_feedback((6, 3), NEGTV)
            >>> u, v = (1, 4)
            >>> thresh = 0
            >>> assert not infr.confidently_separated(u, v, thresh)
            >>> infr.add_feedback((2, 3), NEGTV)
            >>> assert not infr.confidently_separated(u, v, thresh)
        """

        def can_cross(G, edge, n_negs):
            """
            DFS state condition

            Args:
                edge (tuple): the edge we are trying to cross
                n_negs (int): the number of negative edges crossed so far

            Returns:
                flag, new_state
                - flag (bool): True if the edge can be crossed
                - new_state: new state for future decisions in this path.
            """
            decision = infr.edge_decision(edge)
            # only cross positive or negative edges
            if decision in {POSTV, NEGTV}:
                # only cross a negative edge once
                willcross = decision == NEGTV
                if willcross and n_negs == 0:
                    data = G.get_edge_data(*edge)
                    # only cross edges above a threshold
                    conf = data.get('confidence', 'unspecified')
                    conf_int = const.CONFIDENCE.CODE_TO_INT[conf]
                    conf_int = 0 if conf_int is None else conf_int
                    flag = conf_int >= thresh
                    num = n_negs + willcross
                    return flag, num
            # NOTE(review): POSTV edges always fall through to False here, so
            # the DFS never actually crosses positive edges — confirm whether
            # that is intended (the doctest only asserts negative results).
            return False, n_negs

        # need to do DFS check for this. Make DFS only allowed to
        # cross a negative edge once.
        # def dfs_cond_rec(G, parent, state, visited=None):
        #     if visited is None:
        #         visited = set()
        #     visited.add(parent)
        #     for child in G.neighbors(parent):
        #         if child not in visited:
        #             edge = (parent, child)
        #             flag, new_state = can_cross(G, edge, state)
        #             if flag:
        #                 yield child
        #                 yield from dfs_cond_rec(G, child, new_state, visited)

        # need to do DFS check for this. Make DFS only allowed to
        # cross a negative edge once.
        def dfs_cond_stack(G, source, state):
            # stack based version
            visited = {source}
            stack = [(source, iter(G[source]), state)]
            while stack:
                parent, children, state = stack[-1]
                try:
                    child = next(children)
                    if child not in visited:
                        edge = (parent, child)
                        flag, new_state = can_cross(G, edge, state)
                        if flag:
                            yield child
                            visited.add(child)
                            stack.append((child, iter(G[child]), new_state))
                except StopIteration:
                    stack.pop()

        for node in dfs_cond_stack(infr.graph, u, 0):
            if node == v:
                return True
        return False
[docs] def generate_reviews(infr, pos_redun=None, neg_redun=None, data=False): """ Dynamic generator that yeilds high priority reviews """ if pos_redun is not None: infr.params['redun.pos'] = pos_redun if neg_redun is not None: infr.params['redun.neg'] = neg_redun infr.prioritize() return infr._generate_reviews(data=data)
def _generate_reviews(infr, data=False): if data: while True: yield from infr.pop() else: while True: edge, priority = infr.pop() yield edge