-
Notifications
You must be signed in to change notification settings - Fork 5
/
run_evaluation.py
38 lines (33 loc) · 1.43 KB
/
run_evaluation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
#!/usr/bin/env python
# coding=utf-8
from __future__ import division, print_function, unicode_literals
from dae import ex
# Baseline evaluation: each dataset is evaluated with its default
# (non-"_train_multi") network for every em.k value, using 10 EM iterations
# and seed 1337. Each run is dumped to Results/<dataset>_10_<k>.pickle.
for dataset in ('bars', 'corners', 'shapes', 'multi_mnist', 'mnist_shape',
                'simple_superpos'):
    for k in (2, 3, 5, 12):
        cfg = {
            'dataset.name': dataset,
            'net_filename': 'Networks/best_{}_dae.h5'.format(dataset),
            'em.k': k,
            'em.nr_iters': 10,
            'em.dump_results': 'Results/{}_10_{}.pickle'.format(dataset, k),
            'seed': 1337,
        }
        ex.run_command('evaluate', config_updates=cfg)
# Longer runs on the 'bars' dataset (20 EM iterations instead of 10, seed 42)
# for the convergence plot. Each run is dumped to
# Results/bars_20_<k>.pickle.
convergence_ds = 'bars'
for k in (2, 3, 5, 12):
    cfg = {
        'dataset.name': convergence_ds,
        'net_filename': 'Networks/best_{}_dae.h5'.format(convergence_ds),
        'em.k': k,
        'em.nr_iters': 20,
        'em.dump_results': 'Results/{}_20_{}.pickle'.format(convergence_ds, k),
        'seed': 42,
    }
    ex.run_command('evaluate', config_updates=cfg)
# Results for the multi-object trained networks (the "_train_multi"
# checkpoints). These runs use 10 EM iterations, em.e_step='max' and seed 23,
# and are dumped to Results/<dataset>_10_<k>_train_multi.pickle.
# Note: 'simple_superpos' is not evaluated in this section.
for dataset in ('bars', 'corners', 'shapes', 'multi_mnist', 'mnist_shape'):
    for k in (2, 3, 5, 12):
        cfg = {
            'dataset.name': dataset,
            'net_filename':
                'Networks/best_{}_dae_train_multi.h5'.format(dataset),
            'em.k': k,
            'em.nr_iters': 10,
            'em.e_step': 'max',
            'em.dump_results':
                'Results/{}_10_{}_train_multi.pickle'.format(dataset, k),
            'seed': 23,
        }
        ex.run_command('evaluate', config_updates=cfg)