# -*- coding: utf-8 -*-
"""
TODO: separate out the tests and make this file just generate the demo data
"""
import logging
import itertools as it
import numpy as np
import utool as ut
from wbia.algo.graph.state import POSTV, NEGTV, INCMP, UNREV
from wbia.algo.graph.state import SAME, DIFF, NULL # NOQA
print, rrr, profile = ut.inject2(__name__)
logger = logging.getLogger('wbia')
def make_dummy_infr(annots_per_name):
    """
    Build an ``AnnotInference`` over synthetic annotations without a database.

    Args:
        annots_per_name (list of int): ``annots_per_name[i]`` is the number of
            annotations given to the i-th name.  Names are numbered from 1.

    Returns:
        wbia.AnnotInference: inference object constructed with ``ibs=None``,
            annotation ids ``0 .. total - 1`` and the generated name ids.
    """
    import wbia

    # Repeat every 1-based name id once per annotation that belongs to it.
    name_ids = []
    for name_id, num_annots in enumerate(annots_per_name, start=1):
        for _ in range(num_annots):
            name_ids.append(name_id)

    annot_ids = range(len(name_ids))
    return wbia.AnnotInference(
        None, annot_ids, nids=name_ids, autoinit=True, verbose=1
    )
def demodata_mtest_infr(state='empty'):
    """
    Build an ``AnnotInference`` over a subset of the ``PZ_MTEST`` database.

    The annotations are grouped by name id, the groups are shuffled with
    ``rng=321``, and every second group (``[1::2]``) is kept.

    Args:
        state (str): forwarded to ``infr.reset`` (default 'empty')

    Returns:
        wbia.AnnotInference: inference object over the selected annotations.
    """
    import wbia

    ibs = wbia.opendb(db='PZ_MTEST')
    annots = ibs.annots()

    # One group of annotations per name id.
    groups_by_name = annots.group_items(annots.nids)
    name_groups = list(groups_by_name.values())
    ut.shuffle(name_groups, rng=321)

    # Keep only the odd-indexed groups of the shuffled list.
    selected_groups = name_groups[1::2]
    test_aids = ut.flatten(selected_groups)

    infr = wbia.AnnotInference(ibs, test_aids, autoinit=True)
    infr.reset(state=state)
    return infr
def demodata_infr2(defaultdb='PZ_MTEST'):
    """
    Build a partially reviewed ``AnnotInference`` from the first 20 names.

    All names after the first four receive POSTV feedback between the pairs
    produced by ``ut.itertwo`` over their annotation ids, and NEGTV feedback
    between the first annotations of every pair of those names.  The first
    four names are left without feedback.

    Args:
        defaultdb (str): forwarded to ``wbia.opendb(defaultdb=...)``
            (default 'PZ_MTEST').  Previously this argument was overwritten
            with 'PZ_MTEST' and therefore ignored.

    Returns:
        wbia.AnnotInference: inference object with termination and refresh
            criteria initialized.
    """
    import wbia

    ibs = wbia.opendb(defaultdb=defaultdb)
    annots = ibs.annots()
    names = list(annots.group_items(annots.nids).values())[0:20]

    def dummy_phi(c, n):
        # Differences of a normalized saturating curve c*x / (c*x + 1)
        x = np.arange(n)
        phi = c * x / (c * x + 1)
        phi = phi / phi.sum()
        phi = np.diff(phi)
        return phi

    phis = {c: dummy_phi(c, 30) for c in range(1, 4)}
    aids = ut.flatten(names)
    infr = wbia.AnnotInference(ibs, aids, autoinit=True)
    infr.init_termination_criteria(phis)
    infr.init_refresh_criteria()

    # Partially review: only names[4:] get feedback.  (The original unpacked
    # names[0:4] into four unused locals, which raised ValueError whenever
    # fewer than four names were available.)
    reviewed_names = names[4:]
    for name in reviewed_names:
        for a, b in ut.itertwo(name.aids):
            infr.add_feedback((a, b), POSTV)

    for name1, name2 in it.combinations(reviewed_names, 2):
        infr.add_feedback((name1.aids[0], name2.aids[0]), NEGTV)
    return infr
def demo2():
    """
    Simulate an oracle-driven review session on a small synthetic graph.

    Builds demo data with ``demodata_infr(pcc_sizes=[5, 2, 4])``, clears all
    feedback, and then processes the edges yielded by
    ``infr._generate_reviews``, answering each one with
    ``infr.request_oracle_review``.  The first popped edge for which
    ``ut.allsame(infr.pos_graph.node_labels(*edge))`` holds is answered with
    ``accuracy=0`` so that the oracle makes a mistake early.

    When ``--viz`` is given, frames are saved as ``demo_NNNN.png`` under
    ``~/Desktop/demo``.  NOTE: that directory is created and passed to
    ``ut.remove_files_in_dir`` on every call, even without ``--viz``.

    Command-line flags (read via ``ut.get_argflag`` / ``ut.get_argval``):
        --viz: draw and save frames
        --target (int): stop after this review number; frames are drawn
            starting one review before it
        --start (int): only draw frames for review numbers >= this value
        --end (int): stop before performing this review number

    CommandLine:
        python -m wbia.algo.graph.demo demo2 --viz
        python -m wbia.algo.graph.demo demo2

    Example:
        >>> # DISABLE_DOCTEST
        >>> from wbia.algo.graph.demo import *  # NOQA
        >>> result = demo2()
        >>> print(result)
    """
    from os.path import join

    import wbia.plottool as pt
    from wbia.scripts.thesis import TMP_RC
    import matplotlib as mpl

    mpl.rcParams.update(TMP_RC)

    # ---- Synthetic data params
    params = {
        'redun.pos': 2,
        'redun.neg': 2,
    }
    oracle_accuracy = (0.85, 1.0)

    # --- draw params
    VISUALIZE = ut.get_argflag('--viz')
    QUIT_OR_EMEBED = 'quit'  # set to 'embed' to flag infr.EMBEDME at the target
    TARGET_REVIEW = ut.get_argval('--target', type_=int, default=None)
    START = ut.get_argval('--start', type_=int, default=None)
    END = ut.get_argval('--end', type_=int, default=None)
    SHOW_CANDIATE_POP = True

    # ------------------
    infr = demodata_infr(pcc_sizes=[5, 2, 4])
    infr.verbose = 100
    infr.ensure_cliques()
    infr.ensure_full()

    # Dummy scoring
    infr.init_simulation(oracle_accuracy=oracle_accuracy, name='demo2')

    dpath = ut.ensuredir(ut.truepath('~/Desktop/demo'))
    ut.remove_files_in_dir(dpath)

    fig_counter = it.count(0)

    # (line-count threshold, font size): later entries win, so the caption
    # font shrinks as the log text gets longer.
    fontsize_steps = [(10, 16), (12, 14), (14, 12), (18, 10), (23, 8)]

    def show_graph(infr, title, final=False, selected_edges=None):
        """Draw ``infr`` with its latest logs as caption and save a frame."""
        # ``final`` is accepted for call-site compatibility but is unused.
        if not VISUALIZE:
            return
        latest = '\n'.join(infr.latest_logs())
        showkw = dict(
            show_unreviewed_edges=True,
            show_inferred_same=False,
            show_inferred_diff=False,
            outof=(len(infr.aids)),
            selected_edges=selected_edges,
            show_labels=True,
            simple_labels=True,
            show_recent_review=False,
            reposition=False,
        )
        # The original made (and immediately discarded) ``infr.copy()`` here
        # and toggled ``infr.verbose`` around it; drawing the live object
        # directly is what actually happened.
        infr.show(pickable=True, verbose=0, **showkw)
        ax = pt.gca()
        pt.set_title(title, fontsize=20)
        fig = pt.gcf()

        num_lines = len(latest.split('\n'))
        fontsize = 22
        for min_lines, size in fontsize_steps:
            if num_lines > min_lines:
                fontsize = size

        # Put the log text below the axes as a left-aligned x-label
        pt.adjust_subplots(top=0.95, left=0, right=1, bottom=0.45, fig=fig)
        ax.set_xlabel('\n' + latest)
        xlabel = ax.get_xaxis().get_label()
        xlabel.set_horizontalalignment('left')
        xlabel.set_x(-0.6)
        xlabel.set_fontname('Inconsolata')
        xlabel.set_fontsize(fontsize)
        ax.set_aspect('equal')

        fpath = join(dpath, 'demo_{:04d}.png'.format(next(fig_counter)))
        fig.savefig(
            fpath,
            dpi=300,
            edgecolor='none',
        )

    # Result intentionally discarded (the original does the same before each
    # stage).
    infr.latest_logs()

    if VISUALIZE:
        infr.update_visual_attrs(groupby='name_label')
        infr.set_node_attrs('pin', 'true')
        node_dict = ut.nx_node_dict(infr.graph)
        logger.info(ut.repr4(node_dict[1]))

    if VISUALIZE:
        infr.latest_logs()
        # Pin Nodes into the target groundtruth position
        show_graph(infr, 'target-gt')

    # Forget everything so that the simulation starts from scratch
    logger.info(ut.repr4(infr.status()))
    infr.clear_feedback()
    infr.clear_name_labels()
    infr.clear_edges()
    logger.info(ut.repr4(infr.status()))
    infr.latest_logs()

    if VISUALIZE:
        infr.update_visual_attrs()

    infr.prioritize('prob_match')
    if VISUALIZE or TARGET_REVIEW is None or TARGET_REVIEW == 0:
        show_graph(infr, 'initial state')

    def on_new_candidate_edges(infr, edges):
        # hack updating visual attrs as a callback
        infr.update_visual_attrs()

    infr.on_new_candidate_edges = on_new_candidate_edges

    infr.params.update(**params)
    infr.refresh_candidate_edges()

    VIZ_ALL = VISUALIZE and TARGET_REVIEW is None and START is None
    logger.info('VIZ_ALL = %r' % (VIZ_ALL,))

    if VIZ_ALL or TARGET_REVIEW == 0:
        show_graph(infr, 'find-candidates')

    count = 1  # 1-based number of the review about to be made
    num_reviews = 0  # reviews actually submitted via add_feedback
    first = 1  # remaining number of forced oracle mistakes
    for edge, priority in infr._generate_reviews(data=True):
        msg = 'review #%d, priority=%.3f' % (count, priority)
        logger.info('\n----------')
        infr.print('pop edge {} with priority={:.3f}'.format(edge, priority))

        # Make the next review
        if START is not None:
            VIZ_ALL = count >= START

        if END is not None and count >= END:
            break

        infr.print(msg)

        if ut.allsame(infr.pos_graph.node_labels(*edge)) and first:
            # Have oracle make a mistake early
            feedback = infr.request_oracle_review(edge, accuracy=0)
            first -= 1
        else:
            feedback = infr.request_oracle_review(edge)

        AT_TARGET = TARGET_REVIEW is not None and count >= TARGET_REVIEW - 1

        if SHOW_CANDIATE_POP and (VIZ_ALL or AT_TARGET):
            infr.print(
                ut.repr2(infr.task_probs['match_state'][edge], precision=4, si=True)
            )
            infr.print('len(queue) = %r' % (len(infr.queue)))
            # Show edge selection
            infr.print('Oracle will predict: ' + feedback['evidence_decision'])
            show_graph(infr, 'pre' + msg, selected_edges=[edge])

        if count == TARGET_REVIEW:
            infr.EMBEDME = QUIT_OR_EMEBED == 'embed'
        infr.add_feedback(edge, **feedback)
        num_reviews += 1
        infr.print('len(queue) = %r' % (len(infr.queue)))

        # Show the result
        if VIZ_ALL or AT_TARGET:
            show_graph(infr, msg)

        if count == TARGET_REVIEW:
            break
        count += 1

    infr.print('status = ' + ut.repr4(infr.status(extended=False)))
    # Report the reviews that were actually made.  ``count`` is one too large
    # when the loop runs to completion or stops via --end.
    show_graph(infr, 'post-review (#reviews={})'.format(num_reviews), final=True)

    if not getattr(infr, 'EMBEDME', False):
        if ut.get_computer_name().lower() in ['hyrule', 'ooo']:
            pt.all_figures_tile(monitor_num=0, percent_w=0.5)
        else:
            pt.all_figures_tile()
        ut.show_if_requested()
# Viewpoint codes, listed in cyclic order (the last one wraps to the first).
valid_views = ['L', 'F', 'R', 'B']
# Maps each view to [previous, itself, next] in that cyclic order.
adjacent_views = {
    view: [prev_view, view, next_view]
    for prev_view, view, next_view in zip(
        valid_views[-1:] + valid_views[:-1],
        valid_views,
        valid_views[1:] + valid_views[:1],
    )
}
def get_edge_truth(infr, n1, n2):
    """
    Ground-truth code for the pair of nodes ``(n1, n2)``.

    Args:
        infr: inference object whose ``graph`` nodes carry an
            ``'orig_name_label'`` attribute and optionally ``'viewpoint'``
        n1: first node
        n2: second node

    Returns:
        int: 2 if both viewpoints are known and ``n1``'s view is not in
            ``adjacent_views`` of ``n2``'s view; otherwise 1 when the
            original name labels are equal and 0 when they differ.
    """
    node_dict = ut.nx_node_dict(infr.graph)
    label1 = node_dict[n1]['orig_name_label']
    label2 = node_dict[n2]['orig_name_label']
    try:
        is_comparable = (
            node_dict[n1]['viewpoint'] in adjacent_views[node_dict[n2]['viewpoint']]
        )
    except KeyError:
        # Missing or unknown viewpoint: treat the pair as comparable
        is_comparable = True
    is_same = label1 == label2
    return int(is_same) if is_comparable else 2
def apply_dummy_viewpoints(infr, transition_rate=0):
    """
    Assign a synthetic ``'viewpoint'`` attribute to every node of ``infr``.

    Nodes are visited group by group (grouped on their ``'name_label'``
    attribute via ``ut.group_pairs``) and each receives the current state of
    a single random walk over ``['L', 'F', 'R', 'B']`` that reverses
    direction at both ends.  The walk uses ``RandomState(42)``.

    Args:
        infr: inference object; modified via ``infr.set_node_attrs``
        transition_rate (float): probability that the walk advances before
            each node is labeled.  The default 0 reproduces the previous
            hard-coded behavior, where every node is labeled 'L'.
    """
    valid_views = ['L', 'F', 'R', 'B']
    rng = np.random.RandomState(42)

    class MarkovView(object):
        """Random walk over ``valid_views`` that bounces at both ends."""

        def __init__(self):
            self.dir_ = +1
            self.state = 0

        def __call__(self):
            return self.next_state()

        def next_state(self):
            # Turn around when the walk is about to step off either end
            if self.dir_ == -1 and self.state <= 0:
                self.dir_ = +1
            if self.dir_ == +1 and self.state >= len(valid_views) - 1:
                self.dir_ = -1
            # One random draw per node, even when transition_rate is 0
            if rng.rand() < transition_rate:
                self.state += self.dir_
            return valid_views[self.state]

    mkv = MarkovView()
    nid_to_aids = ut.group_pairs(
        [(n, d['name_label']) for n, d in infr.graph.nodes(data=True)]
    )
    grouped_nodes = list(nid_to_aids.values())
    node_to_view = {node: mkv() for nodes in grouped_nodes for node in nodes}
    infr.set_node_attrs('viewpoint', node_to_view)
def make_demo_infr(ccs, edges=[], nodes=[], infer=True):
    """
    Deprecated in favor of demodata_infr

    Builds a graph in which every component in ``ccs`` is joined by a path
    of POSTV edges, adds the extra ``nodes`` and ``edges``, and wraps the
    result in an ``AnnotInference``.

    Args:
        ccs (list of list): node ids of each positive component
        edges (list): extra edges passed to ``add_edges_from``
        nodes (list): extra nodes passed to ``add_nodes_from``
        infer (bool): forwarded to ``AnnotInference.from_netx``

    Returns:
        wbia.AnnotInference: inference object with display attributes set.
    """
    # NOTE: the list defaults are only read here, never mutated.
    import wbia
    import networkx as nx

    if nx.__version__.startswith('1'):
        # networkx 1.x only provides add_path as a Graph method
        nx.add_path = nx.Graph.add_path

    graph = wbia.AnnotInference._graph_cls()
    graph.add_nodes_from(nodes)
    for component in ccs:
        if len(component) == 1:
            graph.add_nodes_from(component)
        nx.add_path(graph, component, evidence_decision=POSTV, meta_decision=NULL)
    graph.add_edges_from(edges)

    infr = wbia.AnnotInference.from_netx(graph, infer=infer)
    infr.verbose = 3
    infr.relabel_using_reviews(rectify=False)

    graph_attrs = infr.graph.graph
    graph_attrs['dark_background'] = False
    graph_attrs['ignore_labels'] = True

    node_attrs = [('width', 40), ('height', 40), ('fixed_size', True)]
    for attr_name, attr_value in node_attrs:
        infr.set_node_attrs(attr_name, attr_value)
    return infr
@profile
def demodata_infr(**kwargs):
    """
    Generate a synthetic ``AnnotInference`` graph with known ground truth.

    Positive connected components (PCCs) are generated one at a time with
    ``nx_utils.random_k_edge_connected_graph``; edges missing from each PCC
    are then randomly marked NEGTV / UNREV / INCMP, and finally edges between
    different PCCs are randomly added.  All randomness comes from a single
    ``RandomState(0)``, so the order of the random draws below matters for
    reproducibility.

    Kwargs:
        num_pccs (int): number of PCCs when sizes are not given (default 16)
        pcc_size_mean | pcc_size | size (float): mean PCC size (default 5)
        pcc_size_std | size_std (float): std of the PCC size (default 0)
        pcc_sizes (list of int): explicit PCC sizes; overrides ``num_pccs``
        ccs (list of list): explicit node ids of every PCC; overrides
            ``pcc_sizes`` and ``num_pccs``
        pos_redun (int or list of int): candidate positive connectivities,
            one is drawn per PCC and capped at ``size - 1``
            (default [1, 2, 3])
        p_incon (float): target probability that a PCC receives at least one
            NEGTV edge (default 0)
        p_incomp (float): target probability that a PCC receives at least
            one INCMP edge (default 0)
        n_incon (int): maximum number of NEGTV edges within a PCC (default 3)
        ignore_pair (bool): if True no edges between PCCs are generated
            (default False)
        p_pair_neg (float): target probability that a pair of PCCs is joined
            by at least one NEGTV edge (default 0.4)
        p_pair_incmp (float): same, for INCMP edges (default 0.2)
        p_pair_unrev (float): same, for UNREV edges (default 0)
        infer (bool): forwarded to ``AnnotInference.from_netx`` (default True)

    Returns:
        wbia.AnnotInference: inference object with ``edge_truth``,
            ``dummy_verif``, ``verifiers['match_state']`` and ``demokw`` set.

    CommandLine:
        python -m wbia.algo.graph.demo demodata_infr --show
        python -m wbia.algo.graph.demo demodata_infr --num_pccs=25
        python -m wbia.algo.graph.demo demodata_infr --profile --num_pccs=100

    Ignore:
        >>> from wbia.algo.graph.demo import *  # NOQA
        >>> from wbia.algo.graph import demo
        >>> import networkx as nx
        >>> kwargs = dict(num_pccs=6, p_incon=.5, size_std=2)
        >>> kwargs = ut.argparse_dict(kwargs)
        >>> infr = demo.demodata_infr(**kwargs)
        >>> pccs = list(infr.positive_components())
        >>> assert len(pccs) == kwargs['num_pccs']
        >>> nonfull_pccs = [cc for cc in pccs if len(cc) > 1 and nx.is_empty(nx.complement(infr.pos_graph.subgraph(cc)))]
        >>> expected_n_incon = len(nonfull_pccs) * kwargs['p_incon']
        >>> n_incon = len(list(infr.inconsistent_components()))
        >>> # TODO can test that we our sample num incon agrees with pop mean
        >>> #sample_mean = n_incon / len(nonfull_pccs)
        >>> #pop_mean = kwargs['p_incon']
        >>> print('status = ' + ut.repr4(infr.status(extended=True)))
        >>> ut.quit_if_noshow()
        >>> infr.show(pickable=True, groupby='name_label')
        >>> ut.show_if_requested()

    Ignore:
        kwargs = {
            'ccs': [[1, 2, 3], [4, 5]]
        }
    """
    import networkx as nx
    import vtool as vt
    from wbia.algo.graph import nx_utils

    def kwalias(*args):
        # kwalias(key1, key2, ..., default): first key found in kwargs wins
        params = args[0:-1]
        default = args[-1]
        for key in params:
            if key in kwargs:
                return kwargs[key]
        return default

    num_pccs = kwalias('num_pccs', 16)
    size_mean = kwalias('pcc_size_mean', 'pcc_size', 'size', 5)
    size_std = kwalias('pcc_size_std', 'size_std', 0)
    p_pcc_incon = kwargs.get('p_incon', 0)
    p_pcc_incomp = kwargs.get('p_incomp', 0)
    pcc_sizes = kwalias('pcc_sizes', None)

    pos_redun = kwalias('pos_redun', [1, 2, 3])
    pos_redun = ut.ensure_iterable(pos_redun)

    # number of maximum inconsistent edges per pcc
    max_n_incon = kwargs.get('n_incon', 3)

    rng = np.random.RandomState(0)
    counter = 1  # next unused node id when ids are generated automatically

    if pcc_sizes is None:
        pcc_sizes = [
            int(randn(size_mean, size_std, rng=rng, a_min=1)) for _ in range(num_pccs)
        ]
    else:
        num_pccs = len(pcc_sizes)

    if 'ccs' in kwargs:
        # Overwrites other options
        pcc_sizes = list(map(len, kwargs['ccs']))
        num_pccs = len(pcc_sizes)

    # Attributes of the edges added inside a PCC, keyed by sampled state.
    # Any other state means "leave the edge out of the graph".
    intra_attrs_lookup = {
        # inconsistent edge (TODO: truth could be INCMP or POSTV)
        0: {'evidence_decision': NEGTV, 'truth': POSTV},
        # unreviewed edge (TODO: truth could be INCMP or POSTV)
        1: {'evidence_decision': UNREV, 'truth': POSTV},
        # incomparable edge
        2: {'evidence_decision': INCMP, 'truth': INCMP},
    }
    edge_density = 0.1  # ``p`` of the random graph generator
    p_edge_unrev = 0.1

    new_ccs = []
    pcc_iter = list(enumerate(pcc_sizes))
    pcc_iter = ut.ProgIter(pcc_iter, enabled=num_pccs > 20, label='make pos-demo')
    for i, size in pcc_iter:
        want_connectivity = rng.choice(pos_redun)
        want_connectivity = min(size - 1, want_connectivity)

        # Create basic graph of positive edges with desired connectivity
        g = nx_utils.random_k_edge_connected_graph(
            size, k=want_connectivity, p=edge_density, rng=rng
        )
        nx.set_edge_attributes(g, name='evidence_decision', values=POSTV)
        nx.set_edge_attributes(g, name='truth', values=POSTV)
        assert nx.is_connected(g)

        # Relabel graph with non-conflicting names
        if 'ccs' in kwargs:
            g = nx.relabel_nodes(g, dict(enumerate(kwargs['ccs'][i])))
        else:
            # Make sure nodes do not conflict with others.  Exactly len(g)
            # fresh ids are needed (the original range had one spare,
            # unused entry).
            g = nx.relabel_nodes(
                g, dict(enumerate(range(counter, counter + len(g))))
            )
            counter += len(g)

        # `p_incon` is the probability that the PCC has at least one
        # inconsistent edge, i.e. 1 - P(all N complement edges consistent).
        # Hence each edge is inconsistent with probability
        # 1 - (1 - p_incon) ** (1 / N).
        complement_edges = ut.estarmap(nx_utils.e_, nx_utils.complement_edges(g))
        if len(complement_edges) > 0:
            n_complement = len(complement_edges)
            p_edge_inconn = 1 - (1 - p_pcc_incon) ** (1 / n_complement)
            p_edge_notcomp = 1 - (1 - p_pcc_incomp) ** (1 / n_complement)
            probs = np.array([p_edge_inconn, p_edge_unrev, p_edge_notcomp])
            # NOTE: nothing renormalizes ``probs`` if they sum to more than 1
            pcumsum = probs.cumsum()
            # Determine which mutually exclusive state each complement edge
            # is in; a state of len(probs) means "no edge".
            states = np.searchsorted(pcumsum, rng.rand(n_complement))

            incon_idxs = np.where(states == 0)[0]
            if len(incon_idxs) > max_n_incon:
                logger.info('max_n_incon = %r' % (max_n_incon,))
                chosen = rng.choice(incon_idxs, max_n_incon, replace=False)
                # Demote the surplus inconsistent edges to "no edge"
                states[np.setdiff1d(incon_idxs, chosen)] = len(probs)

            grouped_edges = ut.group_items(complement_edges, states)
            for state, edges in grouped_edges.items():
                attrs = intra_attrs_lookup.get(state)
                if attrs is None:
                    continue
                # Add in candidate edges
                for (u, v) in edges:
                    g.add_edge(u, v, **attrs)
        new_ccs.append(g)

    pos_g = nx.union_all(new_ccs)
    assert len(new_ccs) == len(list(nx.connected_components(pos_g)))
    assert num_pccs == len(new_ccs)

    # Add edges between the PCCS
    neg_edges = []

    if not kwalias('ignore_pair', False):
        logger.info('making pairs')

        pair_attrs_lookup = {
            0: {'evidence_decision': NEGTV, 'truth': NEGTV},
            1: {'evidence_decision': INCMP, 'truth': INCMP},
            2: {'evidence_decision': UNREV, 'truth': NEGTV},  # could be incomp or neg
        }

        # These are the probabilities that one edge has this state
        p_pair_neg = kwalias('p_pair_neg', 0.4)
        p_pair_incmp = kwalias('p_pair_incmp', 0.2)
        p_pair_unrev = kwalias('p_pair_unrev', 0)

        cc_combos = (
            (list(g1.nodes()), list(g2.nodes()))
            for (g1, g2) in it.combinations(new_ccs, 2)
        )
        valid_cc_combos = [(cc1, cc2) for cc1, cc2 in cc_combos if len(cc1) and len(cc2)]
        for cc1, cc2 in ut.ProgIter(valid_cc_combos, label='make neg-demo'):
            possible_edges = ut.estarmap(nx_utils.e_, it.product(cc1, cc2))
            # probability that any edge between these PCCs is negative
            n_edges = len(possible_edges)
            p_edge_neg = 1 - (1 - p_pair_neg) ** (1 / n_edges)
            p_edge_incmp = 1 - (1 - p_pair_incmp) ** (1 / n_edges)
            p_edge_unrev = 1 - (1 - p_pair_unrev) ** (1 / n_edges)

            # Create event space with sizes proportional to probabilities
            pcumsum = np.cumsum([p_edge_neg, p_edge_incmp, p_edge_unrev])
            # Roll dice for each of the edge to see which state it lands on
            possible_pstate = rng.rand(len(possible_edges))
            states = np.searchsorted(pcumsum, possible_pstate)

            # Edges that landed beyond the last interval get no edge at all
            flags = states < len(pcumsum)
            stateful_states = states.compress(flags)
            stateful_edges = ut.compress(possible_edges, flags)

            unique_states, groupxs_list = vt.group_indices(stateful_states)
            for state, groupxs in zip(unique_states, groupxs_list):
                # Add in candidate edges
                edges = ut.take(stateful_edges, groupxs)
                attrs = pair_attrs_lookup[state]
                for (u, v) in edges:
                    neg_edges.append((u, v, attrs))
        logger.info('Made {} neg_edges between PCCS'.format(len(neg_edges)))
    else:
        logger.info('ignoring pairs')

    import wbia

    G = wbia.AnnotInference._graph_cls()
    G.add_nodes_from(pos_g.nodes(data=True))
    G.add_edges_from(pos_g.edges(data=True))
    G.add_edges_from(neg_edges)
    infr = wbia.AnnotInference.from_netx(G, infer=kwargs.get('infer', True))
    infr.verbose = 3

    infr.relabel_using_reviews(rectify=False)

    # Display attributes
    fontsize = 12
    fontname = 'sans'
    splines = 'spline'
    infr.set_node_attrs('shape', 'circle')
    infr.graph.graph['ignore_labels'] = True
    infr.graph.graph['dark_background'] = False
    infr.graph.graph['fontname'] = fontname
    infr.graph.graph['fontsize'] = fontsize
    infr.graph.graph['splines'] = splines
    infr.set_node_attrs('width', 29)
    infr.set_node_attrs('height', 29)
    infr.set_node_attrs('fontsize', fontsize)
    infr.set_node_attrs('fontname', fontname)
    infr.set_node_attrs('fixed_size', True)

    # Set synthetic ground-truth attributes for testing
    infr.edge_truth = infr.get_edge_attrs('truth')

    # Make synthetic verif
    infr.dummy_verif = DummyVerif(infr)
    infr.verifiers = {}
    infr.verifiers['match_state'] = infr.dummy_verif
    infr.demokw = kwargs
    return infr
def randn(mean=0, std=1, shape=(), a_max=None, a_min=None, rng=None):
    """
    Draw normally distributed values, optionally clipped to a range.

    Args:
        mean (float): mean of the distribution (default 0)
        std (float): standard deviation of the distribution (default 1)
        shape (sequence of int): output shape; the empty default yields a
            single scalar
        a_max (float): upper clipping bound, or None for no upper bound
        a_min (float): lower clipping bound, or None for no lower bound
        rng (np.random.RandomState): random source; if None the global
            ``np.random`` state is used (previously None raised
            ``AttributeError``)

    Returns:
        float or ndarray: samples of ``randn * std + mean``, clipped to
            ``[a_min, a_max]`` when at least one bound is given.
    """
    if rng is None:
        rng = np.random
    a = (rng.randn(*shape) * std) + mean
    # np.clip needs at least one bound, so skip it when both are None
    if a_max is not None or a_min is not None:
        a = np.clip(a, a_min, a_max)
    return a
class DummyVerif(object):
    """
    generates dummy scores between edges (not necessarily in the graph)

    Probabilities are drawn with ``randn`` from a per-truth-state normal
    distribution (``dummy_params``) and cached in
    ``infr.task_probs['match_state']`` so that an edge always gets the same
    answer.

    CommandLine:
        python -m wbia.algo.graph.demo DummyVerif:1

    Example:
        >>> # ENABLE_DOCTEST
        >>> from wbia.algo.graph.demo import *  # NOQA
        >>> from wbia.algo.graph import demo
        >>> import networkx as nx
        >>> kwargs = dict(num_pccs=6, p_incon=.5, size_std=2)
        >>> infr = demo.demodata_infr(**kwargs)
        >>> infr.dummy_verif.predict_edges([(1, 2)])
        >>> infr.dummy_verif.predict_edges([(1, 21)])
        >>> assert len(infr.dummy_verif.infr.task_probs['match_state']) == 2
    """

    def __init__(verif, infr):
        verif.rng = np.random.RandomState(4033913)
        # Parameters of the POSTV-probability distribution for each true state
        verif.dummy_params = {
            NEGTV: {'mean': 0.2, 'std': 0.25},
            POSTV: {'mean': 0.85, 'std': 0.2},
            INCMP: {'mean': 0.15, 'std': 0.1},
        }
        verif.score_dist = randn

        verif.infr = infr
        verif.orig_nodes = set(infr.aids)
        # node -> original name label, and original name label -> set of nodes
        verif.orig_labels = infr.get_node_attrs('orig_name_label')
        verif.orig_groups = ut.invert_dict(verif.orig_labels, False)
        verif.orig_groups = ut.map_vals(set, verif.orig_groups)

    def show_score_probs(verif):
        """
        Plot a histogram of sampled scores for every truth state.

        CommandLine:
            python -m wbia.algo.graph.demo DummyVerif.show_score_probs --show

        Example:
            >>> # ENABLE_DOCTEST
            >>> from wbia.algo.graph.demo import *  # NOQA
            >>> import wbia
            >>> infr = wbia.AnnotInference(None)
            >>> verif = DummyVerif(infr)
            >>> verif.show_score_probs()
            >>> ut.show_if_requested()
        """
        import wbia.plottool as pt

        dist = verif.score_dist
        n = 100000
        truth_colors = verif.infr._get_truth_colors()
        for key, dist_params in verif.dummy_params.items():
            probs = dist(shape=[n], rng=verif.rng, a_max=1, a_min=0, **dist_params)
            color = truth_colors[key]
            pt.plt.hist(probs, bins=100, label=key, alpha=0.8, color=color)
        pt.legend()

    def dummy_ranker(verif, u, K=10):
        """
        simulates the ranking algorithm. Order is defined using the dummy vsone
        scores, but tests are only applied to randomly selected gt and gf
        pairs. So, you usually will get a gt result, but you might not if all
        the scores are bad.

        Args:
            u: query node
            K (int): maximum number of edges to return (default 10)

        Returns:
            list: up to ``K`` edges ``infr.e_(u, v)``, ordered by decreasing
                predicted score.
        """
        infr = verif.infr

        nid = verif.orig_labels[u]
        others = verif.orig_groups[nid]
        # Sort so that the random choices below do not depend on set order
        others_gt = sorted(others - {u})
        others_gf = sorted(verif.orig_nodes - others)

        rng = verif.rng

        vs_list = []
        k_gt = min(len(others_gt), max(1, K // 2))
        k_gf = min(len(others_gf), max(1, K * 4))
        if k_gt > 0:
            gt = rng.choice(others_gt, k_gt, replace=False)
            vs_list.append(gt)
        if k_gf > 0:
            gf = rng.choice(others_gf, k_gf, replace=False)
            vs_list.append(gf)

        u_edges = [infr.e_(u, v) for v in it.chain.from_iterable(vs_list)]
        # Score with this verifier itself; going through ``infr.dummy_verif``
        # breaks when this instance is not the one registered on ``infr``.
        u_probs = np.array(verif.predict_edges(u_edges))

        # NOTE: argsort is not guaranteed to order ties deterministically
        sortx = np.argsort(u_probs)[::-1][0:K]
        ranked_edges = [u_edges[idx] for idx in sortx]
        return ranked_edges

    def find_candidate_edges(verif, K=10):
        """
        Collect the ranked edges of every node in the graph.

        Args:
            K (int): forwarded to ``dummy_ranker`` (default 10)

        Returns:
            set: union of the edges returned by ``dummy_ranker`` per node.

        Example:
            >>> # ENABLE_DOCTEST
            >>> from wbia.algo.graph.demo import *  # NOQA
            >>> from wbia.algo.graph import demo
            >>> import networkx as nx
            >>> kwargs = dict(num_pccs=40, size=2)
            >>> infr = demo.demodata_infr(**kwargs)
            >>> edges = list(infr.dummy_verif.find_candidate_edges(K=100))
            >>> scores = np.array(infr.dummy_verif.predict_edges(edges))
        """
        new_edges = []
        nodes = list(verif.infr.graph.nodes())
        for u in nodes:
            new_edges.extend(verif.dummy_ranker(u, K=K))
        new_edges = set(new_edges)
        return new_edges

    def _get_truth(verif, edge):
        """
        Return the true state of ``edge``.

        Uses ``infr.edge_truth`` when the edge is listed there; otherwise the
        edge is POSTV if both endpoints share ``'orig_name_label'`` and NEGTV
        if they do not.
        """
        infr = verif.infr
        if edge in infr.edge_truth:
            return infr.edge_truth[edge]
        node_dict = ut.nx_node_dict(infr.graph)
        nid1 = node_dict[edge[0]]['orig_name_label']
        nid2 = node_dict[edge[1]]['orig_name_label']
        return POSTV if nid1 == nid2 else NEGTV

    def predict_proba_df(verif, edges):
        """
        Return (and cache) dummy match-state probabilities for ``edges``.

        Args:
            edges (iterable): pairs of nodes; normalized with ``infr.e_``

        Returns:
            pd.DataFrame: one row per edge, indexed by ``('aid1', 'aid2')``,
                with one column for each of POSTV, NEGTV and INCMP.

        CommandLine:
            python -m wbia.algo.graph.demo DummyVerif.predict_edges

        Example:
            >>> # ENABLE_DOCTEST
            >>> from wbia.algo.graph.demo import *  # NOQA
            >>> from wbia.algo.graph import demo
            >>> import networkx as nx
            >>> kwargs = dict(num_pccs=40, size=2)
            >>> infr = demo.demodata_infr(**kwargs)
            >>> verif = infr.dummy_verif
            >>> edges = list(infr.graph.edges())
            >>> probs = verif.predict_proba_df(edges)
            >>> #print('scores = %r' % (scores,))
            >>> #hashid = ut.hash_data(scores)
            >>> #print('hashid = %r' % (hashid,))
            >>> #assert hashid == 'cdlkytilfeqgmtsihvhqwffmhczqmpil'
        """
        infr = verif.infr
        edges = list(it.starmap(verif.infr.e_, edges))
        prob_cache = infr.task_probs['match_state']
        is_miss = np.array([e not in prob_cache for e in edges])
        if np.any(is_miss):
            # Only edges that were never scored before consume random numbers
            miss_edges = ut.compress(edges, is_miss)
            miss_truths = [verif._get_truth(edge) for edge in miss_edges]
            grouped_edges = ut.group_items(miss_edges, miss_truths, sorted_=False)
            # Need to make this deterministic too: visit truth states in
            # sorted order so the rng is consumed in a fixed sequence.
            states = [POSTV, NEGTV, INCMP]
            for key in sorted(grouped_edges.keys()):
                group = grouped_edges[key]
                # POSTV probability, drawn from the distribution of the
                # edge's true state
                probs0 = randn(
                    shape=[len(group)],
                    rng=verif.rng,
                    a_max=1,
                    a_min=0,
                    **verif.dummy_params[key],
                )
                # Just randomly assign other probs
                probs1 = verif.rng.rand(len(group)) * (1 - probs0)
                probs2 = 1 - (probs0 + probs1)
                for edge, probs in zip(group, zip(probs0, probs1, probs2)):
                    prob_cache[edge] = ut.dzip(states, probs)

        from wbia.algo.graph import nx_utils as nxu
        import pandas as pd

        probs = pd.DataFrame(
            ut.take(prob_cache, edges),
            index=nxu.ensure_multi_index(edges, ('aid1', 'aid2')),
        )
        return probs

    def predict_edges(verif, edges):
        """
        Return only the POSTV column of ``predict_proba_df(edges)``.
        """
        pos_scores = verif.predict_proba_df(edges)[POSTV]
        return pos_scores