-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathvizu_knn.py
executable file
·130 lines (118 loc) · 6.2 KB
/
vizu_knn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
#!/usr/bin/env python
import matplotlib; matplotlib.use('Agg')
import numpy as np
import heapq
import os
import tempfile
from jinja2 import Environment, FileSystemLoader
import tsh.obsolete as tsh; logger = tsh.create_logger(__name__)
from utils import read_listfile, read_truthfile, read_weightsfile
# Directory that image/mask paths are made relative to before being re-rooted
# under 'image/' in the generated HTML. Site-specific; override via the
# `image_root` argument of get_samples_data.
_DEFAULT_IMAGE_ROOT = '/home/imp/kazmar/vt_project/Segmentation/Fine/MetaSys/'


def _check_same_ids(expected_ids, other_ids, source):
    """Raise ValueError unless `other_ids` equals `expected_ids` element-wise (same order, same length)."""
    if not np.array_equal(np.asarray(expected_ids), np.asarray(other_ids)):
        raise ValueError('Sample ids in %s do not match the list file' % source)


def _first_index_by_id(ids):
    """Map each id to the position of its first occurrence in `ids`."""
    index = {}
    for position, sample_id in enumerate(ids):
        index.setdefault(sample_id, position)
    return index


def _image_path(prefix, name, image_root):
    """Return 'image/<prefix/name relative to image_root>' for use in the HTML."""
    return os.path.join('image', os.path.relpath(os.path.join(prefix, name), image_root))


def get_samples_data(listname, dissimname, predname, propname, truthname, only_errors, k=5,
                     image_root=_DEFAULT_IMAGE_ROOT):
    """Collect per-sample data and its k nearest neighbors for the HTML report.

    Parameters
    ----------
    listname : str
        List file read with ``read_listfile``. Its metadata must provide
        'image_prefix' and 'mask_prefix'; its rows are accessed by the
        fields 'id', 'image' and 'mask'.
    dissimname : str
        Dissimilarity file read with ``read_weightsfile``. Its ids must equal
        the list file's ids in the same order; the matrix is indexed as
        ``dissim[row, col]`` using list-file row positions.
    predname, propname : str or None
        Optional prediction / propagated-prediction list files whose ids must
        match the list file row by row; their 'pred' field is reported.
        ``None`` means the value is reported as 'N/A'.
    truthname : str or None
        Optional truth file read with ``read_truthfile``. When given, only the
        ids it lists are reported, in its order; otherwise every row of the
        list file is reported (with truth 'N/A').
    only_errors : bool
        Keep only samples that have a truth value and whose prediction and/or
        propagated prediction (whichever were given) disagrees with it.
    k : int
        Number of neighbors reported per sample. The sample's own row is not
        excluded from the candidates.
    image_root : str
        Directory image and mask paths are made relative to.

    Returns
    -------
    list of dict
        One dict per reported sample with keys 'id', 'image', 'mask', 'expr',
        'neighbors', 'prediction', 'propagated' and 'truth'. 'neighbors' is a
        list (ascending dissimilarity) of dicts with keys 'id', 'dissim',
        'expr', 'prediction', 'propagated' and 'truth'.

    Raises
    ------
    ValueError
        If the dissimilarity, prediction or propagated file ids do not match
        the list file, or a truth id is not present in the list file.
    """
    meta, data = read_listfile(listname)
    _, dissim_ids, dissim = read_weightsfile(dissimname)
    data_ids = data['id']
    _check_same_ids(data_ids, dissim_ids, dissimname)

    pred = None
    if predname is not None:
        _, pred = read_listfile(predname)
        _check_same_ids(data_ids, pred['id'], predname)
    prop = None
    if propname is not None:
        _, prop = read_listfile(propname)
        _check_same_ids(data_ids, prop['id'], propname)

    # id -> truth value (first occurrence wins, as with the original mask[0] lookup).
    truth_by_id = {}
    if truthname is not None:
        _, truth_ids, truth = read_truthfile(truthname)
        for sample_id, value in zip(truth_ids, truth):
            truth_by_id.setdefault(sample_id, value)
        sample_ids = truth_ids
    else:
        # Without a truth file, report every sample of the list file.
        sample_ids = data_ids

    row_of = _first_index_by_id(data_ids)
    n_rows = len(data)
    samples = []
    for sample_id in sample_ids:
        if sample_id not in row_of:
            raise ValueError('Sample id %s from %s is not in %s' % (sample_id, truthname, listname))
        j = row_of[sample_id]
        d = data[j]
        sample = {
            'id': d['id'],
            'image': _image_path(meta['image_prefix'], d['image'], image_root),
            'mask': _image_path(meta['mask_prefix'], d['mask'], image_root),
            'expr': os.path.join('expr', 'expr%d.png' % d['id']),
        }

        # k smallest entries of row j (ties keep row order); row j itself is a candidate.
        neighbor_rows = heapq.nsmallest(k, range(n_rows), key=lambda col: dissim[j, col])
        neighbors = []
        for n in neighbor_rows:
            neighbor_id = data_ids[n]
            neighbors.append({
                'id': neighbor_id,
                'dissim': dissim[j, n],
                'expr': os.path.join('expr', 'expr%d.png' % neighbor_id),
                # pred/prop rows are aligned with data rows (checked above).
                'prediction': pred[n]['pred'] if pred is not None else 'N/A',
                'propagated': prop[n]['pred'] if prop is not None else 'N/A',
                'truth': truth_by_id.get(neighbor_id, 'N/A'),
            })
        sample['neighbors'] = neighbors
        sample['prediction'] = pred[j]['pred'] if pred is not None else 'N/A'
        sample['propagated'] = prop[j]['pred'] if prop is not None else 'N/A'

        if d['id'] in truth_by_id:
            sample['truth'] = truth_by_id[d['id']]
            if only_errors \
                    and (pred is None or sample['truth'] == sample['prediction']) \
                    and (prop is None or sample['truth'] == sample['propagated']):
                # Everything that was given agrees with the truth: not an error.
                continue
        else:
            sample['truth'] = 'N/A'
            if only_errors:
                # Cannot judge correctness without a truth value.
                continue
        samples.append(sample)
    return samples
def _nth_or_none(files, n):
    """Return files[n], or None when the optional file list was not given."""
    return files[n] if files is not None else None


def main():
    """Command-line entry point: render the k-nearest-neighbor HTML report.

    Each --list file is paired positionally with the --dissimilarities file
    and, when given, the --predictions, --propagated and --truth files at the
    same position. Samples from all list files are rendered into a single
    HTML page ('<input>-neighbors.html' or '<input>-errors.html', where
    <input> is the list file's base name, or 'all' for several list files)
    in the output directory.
    """
    import argparse
    parser = argparse.ArgumentParser(description='Creates HTML with k nearest neighbors.')
    parser.add_argument('-c', '--config', dest='config', required=False, action='store', default=None, help='Path to the config file')
    parser.add_argument('-l', '--list', dest='list', nargs='*', required=True, action='store', default=None, help='List file.')
    parser.add_argument('-d', '--dissimilarities', dest='dissim', nargs='*', required=True, action='store', default=None, help='Dissimilarities file.')
    parser.add_argument('-e', '--only-errors', dest='errors', required=False, action='store_true', default=False, help='Keep only samples incorrectly predicted/propagated.')
    parser.add_argument('-p', '--predictions', dest='pred', nargs='*', required=False, action='store', default=None, help='Predictions file.')
    parser.add_argument('-q', '--propagated', dest='prop', nargs='*', required=False, action='store', default=None, help='Propagated predictions file.')
    parser.add_argument('-t', '--truth', dest='truth', nargs='*', required=False, action='store', default=None, help='Truth file(s).')
    parser.add_argument('-o', '--output', dest='output', required=False, action='store', default=None, help='Output directory.')
    parser.add_argument('-k', '--neighbors', dest='k', type=int, required=False, action='store', default=10, help='Number of nearest neighbors to show.')
    opts = parser.parse_args()
    # Result unused here. NOTE(review): call kept in case read_config has
    # side effects on opts or logging — confirm before dropping it.
    tsh.read_config(opts, __file__)

    if not opts.list:
        logger.error('Nothing to do')
        return
    if opts.k < 1:
        parser.error('-k/--neighbors must be at least 1')
    # Files are paired with --list positionally, so the counts must agree.
    n_lists = len(opts.list)
    for option, files in (('--dissimilarities', opts.dissim), ('--predictions', opts.pred),
                          ('--propagated', opts.prop), ('--truth', opts.truth)):
        if files is not None and len(files) != n_lists:
            parser.error('%s expects %d file(s), one per --list file; got %d'
                         % (option, n_lists, len(files)))

    if opts.output is None:
        outdir = tempfile.mkdtemp(dir=os.curdir, prefix='out')
        logger.info('Output directory %s', outdir)
    else:
        outdir = opts.output
        if not os.path.exists(outdir):
            tsh.makedirs(outdir)

    samples = []
    for n, listname in enumerate(opts.list):
        logger.info('Processing %s', listname)
        samples += get_samples_data(listname, opts.dissim[n],
                                    _nth_or_none(opts.pred, n),
                                    _nth_or_none(opts.prop, n),
                                    _nth_or_none(opts.truth, n),
                                    opts.errors, k=opts.k)

    if n_lists == 1:
        inputname = os.path.basename(os.path.splitext(opts.list[0])[0])
        if opts.list[0].endswith('.gz'):
            # 'name.ext.gz' -> strip the inner extension as well.
            inputname = os.path.splitext(inputname)[0]
    else:
        inputname = 'all'
    suffix = 'errors' if opts.errors else 'neighbors'

    template_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'templates')
    env = Environment(loader=FileSystemLoader(template_dir))
    html = env.get_template('neighbors.html').render(
        title='Nearest neighbors for ' + inputname, k=opts.k, samples=samples)
    with open(os.path.join(outdir, inputname + '-' + suffix + '.html'), 'w') as out:
        out.write(html)


if __name__ == '__main__':
    main()