# Copyright (c) 2019 Uber Technologies, Inc.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import os
import argparse

import h5py
import numpy as np
import scipy.misc  # imsave/imresize are deprecated (scipy >= 1.0) and removed in later releases; require Pillow

from general.util import string_or_gitresman_or_none

DEFAULT_ARCH_CHOICES = ['mnist']

def make_standard_parser(description='No description provided', arch_choices=DEFAULT_ARCH_CHOICES,
                         skip_train=False, skip_val=False):
'''Make a standard parser, probably good for many experiments.
Arguments:
description: just used for help
arch_choices: list of strings that may be specified when
selecting architecture type. For example, ('mnist', 'cifar')
would allow selection of different networks for each
dataset. architecture may also be toggled via the --conv and
--xprop switches. Default architecture is the first in the
list.
skip_train: if True, skip adding a train_h5 arg
skip_val: if True, skip adding a val_h5 arg
'''
    parser = argparse.ArgumentParser(description=description,
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
# Optimization
parser.add_argument('--opt', type=str, default='sgd', choices=('sgd', 'rmsprop', 'adam'), help='Which optimizer to use')
parser.add_argument('--lr', '-L', type=float, default=.001, help='learning rate')
parser.add_argument('--mom', '-M', type=float, default=.9, help='momentum (only has effect for sgd/rmsprop)')
parser.add_argument('--beta1', type=float, default=.9, help='beta1 for adam opt')
parser.add_argument('--beta2', type=float, default=.99, help='beta2 for adam opt')
parser.add_argument('--adameps', type=float, default=1e-8, help='epsilon for adam opt')
    parser.add_argument('--epochs', '-E', type=int, default=5, help='number of epochs.')
# Model
parser.add_argument('--arch', type=str, default=arch_choices[0],
choices=arch_choices, help='Which architecture to use (choices: %s).' % arch_choices)
parser.add_argument('--conv', '-C', action='store_true', help='Use a conv model.')
parser.add_argument('--xprop', '-X', action='store_true', help='Use an xprop model')
    parser.add_argument('--springprop', '-S', action='store_true', help='Use a springprop model.')
parser.add_argument('--springt', '-t', type=float, default=0.5, help='T value to use for springs')
parser.add_argument('--learncoords', '--lc', action='store_true', help='Learn coordinates (update them during training) instead of keeping them fixed.')
parser.add_argument('--l2', type=float, default=0.0, help='L2 regularization to apply to direct parameters.')
parser.add_argument('--l2i', type=float, default=0.0, help='L2 regularization to apply to indirect parameters.')
# Experimental setup
    parser.add_argument('--seed', type=int, default=0, help='random number seed for initial params and tf graph')
parser.add_argument('--minibatch', '-mb', type=int, default=256, help='minibatch size')
parser.add_argument('--test', action='store_true', help='Use test data instead of validation data (for final run).')
parser.add_argument('--shuffletrain', '--st', dest='shuffletrain', action='store_true', help='Shuffle training set each epoch.')
    parser.add_argument('--noshuffletrain', '--nst', dest='shuffletrain', action='store_false', help='Do not shuffle training set each epoch (the displayed default refers to the shuffletrain setting, not to this flag).')
parser.set_defaults(shuffletrain=True)
# Misc
parser.add_argument('--ipy', '-I', action='store_true', help='drop into embedded iPython for debugging.')
parser.add_argument('--nocolor', '--nc', action='store_true', help='Do not use color output (for scripts).')
parser.add_argument('--skipval', action='store_true', help='Skip validation set entirely.')
parser.add_argument('--verbose', '-V', action='store_true', help='Verbose mode (print some extra stuff)')
parser.add_argument('--cpu', action='store_true', help='Skip GPU assert (allows use of CPU but still uses GPU if possible)')
    # Saving and loading
parser.add_argument('--snapshot-to', type=str, default='net', help='Where to snapshot to. --snapshot-to NAME produces NAME_iter.h5 and NAME.json')
parser.add_argument('--snapshot-every', type=int, default=-1, help='Snapshot every N minibatches. 0 to disable snapshots, -1 to snapshot only on last iteration.')
parser.add_argument('--load', type=str, default=None, help='Snapshot to load from: specify as H5_FILE:MISC_FILE.')
    parser.add_argument('--output', '-O', type=string_or_gitresman_or_none, default='',
                        help='Directory to output TF results to. If None, the GIT_RESULTS_MANAGER_DIR environment variable is used when defined. If set to "skip", no output is written even if GIT_RESULTS_MANAGER_DIR is defined. Otherwise output is skipped.')
# Dataset
if not skip_train:
parser.add_argument('train_h5', type=str, help='Training set hdf5 file.')
if not skip_val:
parser.add_argument('val_h5', type=str, help='Validation set hdf5 file.')
return parser
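
# Hypothetical usage sketch (illustrative only; 'train.h5' and 'val.h5' are
# placeholder paths, not files shipped with this module):
#
#   parser = make_standard_parser('my experiment', arch_choices=['mnist', 'cifar'])
#   args = parser.parse_args(['--opt', 'adam', '--lr', '1e-4', 'train.h5', 'val.h5'])
#   print(args.arch, args.lr)   # -> mnist 0.0001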

def merge_dict_append(d1, d2):
    '''Merge two dicts with identical keys, concatenating their values into lists.'''
    assert list(d1.keys()) == list(d2.keys()), 'Two dictionaries to merge must have the same set of keys'
merged = {}
for kk in list(d1.keys()):
v1 = d1[kk] if isinstance(d1[kk], list) else [d1[kk]]
v2 = d2[kk] if isinstance(d2[kk], list) else [d2[kk]]
merged[kk] = v1 + v2
return merged
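
# Example (values computed by hand):
#   merge_dict_append({'a': 1, 'b': [2]}, {'a': 3, 'b': 4})
#   -> {'a': [1, 3], 'b': [2, 4]}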

def average_dict_values(dd):
    '''Weighted-average each entry of dd, using dd['weights'] as the weights.'''
    averaged = {}
    for kk in list(dd.keys()):
        assert len(dd[kk]) == len(dd['weights']), "lengths must match"
x1 = np.array(dd[kk])
x2 = np.array(dd['weights'])
averaged[kk] = np.sum(np.multiply(x1,x2)) / np.sum(x2)
return averaged
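
# Example (values computed by hand; note the 'weights' entry is averaged too):
#   average_dict_values({'loss': [1.0, 3.0], 'weights': [1.0, 3.0]})
#   -> {'loss': 2.5, 'weights': 2.5}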

def interpolate(z0, z1, t):
    '''Interpolate between noise vectors z0 and z1 for t in [0, 1] (t=1 gives z0,
    t=0 gives z1), rescaling the result so its norm linearly interpolates
    between the norms of the endpoints.'''
r0 = np.linalg.norm(z0)
r1 = np.linalg.norm(z1)
rscale = r0 * t + r1 * (1-t)
zt = z0 * t + z1 * (1-t)
rt = np.linalg.norm(zt)
zt /= (rt/rscale)
return zt
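
# Example (values computed by hand): with z0 = [1, 0], z1 = [0, 1] and t = 0.5,
# the raw midpoint [0.5, 0.5] has norm ~0.707; rescaling to the interpolated
# norm 1.0 yields approximately [0.707, 0.707].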

def image_separator(consolidated_images, nh=10, nw=10):
    # Inverse of exp.cgan.utils.merge: split an nh x nw grid (with one-pixel
    # dividers) back into a row-major list of individual images.
    h = (consolidated_images.shape[0] - nh + 1) // nh  # integer division: h and w are used as slice bounds
    w = (consolidated_images.shape[1] - nw + 1) // nw
    gray = len(consolidated_images.shape) == 2
    images = []
    for ii in range(nh):
        for jj in range(nw):
            start_h = ii * h if ii == 0 else ii * (h + 1)
            start_w = jj * w if jj == 0 else jj * (w + 1)
if gray:
_image = consolidated_images[start_h:start_h + h, start_w:start_w + w]
else:
_image = consolidated_images[start_h:start_h + h, start_w:start_w + w, :]
images.append(_image)
return images
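
# Example (illustrative): for a (649, 649) grayscale grid holding 10x10 tiles
# of 64x64 pixels separated by one-pixel dividers, image_separator(grid)
# returns a list of 100 arrays, each of shape (64, 64).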

def merge(images, size, black_divider=True):
    '''Tile a batch of images (shape N x H x W x C, C in {1, 3, 4}) into a
    size[0] x size[1] grid with one-pixel dividers (black if black_divider,
    else white).'''
    h, w = images.shape[1], images.shape[2]
    if images.shape[3] in (3, 4):
c = images.shape[3]
if black_divider:
img = np.zeros((h * size[0] + size[0]-1, w * size[1] + size[1] - 1, c))
else:
img = np.ones((h * size[0] + size[0]-1, w * size[1] + size[1] - 1, c))
for idx, image in enumerate(images):
i = idx % size[1]
j = idx // size[1]
start_w = i*w if i == 0 else i*(w+1)
start_h = j*h if j == 0 else j*(h+1)
img[start_h:start_h + h, start_w:start_w + w, :] = image
return img
elif images.shape[3]==1:
if black_divider:
img = np.zeros((h * size[0] + size[0] - 1, w * size[1] + size[1] - 1))
else:
img = np.ones((h * size[0] + size[0] - 1, w * size[1] + size[1] - 1))
for idx, image in enumerate(images):
i = idx % size[1]
j = idx // size[1]
start_w = i*w if i == 0 else i*(w+1)
start_h = j*h if j == 0 else j*(h+1)
img[start_h:start_h + h, start_w:start_w + w] = image[:,:,0]
return img
    else:
        raise ValueError('in merge(images, size) the images parameter must have '
                         'shape NxHxWx1, NxHxWx3, or NxHxWx4')
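
# Example (illustrative): merge(np.zeros((100, 64, 64, 3)), (10, 10)) returns a
# (649, 649, 3) array: a 10x10 grid of 64x64 tiles with one-pixel dividers.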

def save_images(images, size, image_path, black_divider=True):
    image = np.squeeze(merge(images, size, black_divider=black_divider))
    # scipy.misc.imsave requires Pillow and is unavailable in recent scipy releases.
    return scipy.misc.imsave(image_path, image)

def save_average_image(images, image_path):
    # Sum over the batch dimension, then rescale intensities to [0, 255].
    im = images.sum(0).astype(np.float64)  # float copy so the in-place ops below are safe for integer input
    im -= im.min()
    im *= 255.0 / im.max()
    image = np.squeeze(im)
    # Upscale 10x for visibility; imresize takes the output size as (height, width).
    image = scipy.misc.imresize(image, (im.shape[0] * 10, im.shape[1] * 10))
    return scipy.misc.imsave(image_path, image)

def load_sort_of_clevr():
data_dir = os.path.join("./data", "sort_of_clevr")
filename = os.path.join(data_dir,
'sort_of_clevr_2objs_10rad_50000imgs_64x.h5')
if not os.path.isfile(filename):
        print('{} does not exist. Try running ./data/sort_of_clevr_generator.py ?'.format(filename))
return
ff = h5py.File(filename, 'r')
    if 'test_x' in ff:
        train_x = ff['train_x']
        test_x = ff['test_x']
    else:
        # No test split stored in the file: hold out the last 10% of the training data as test.
        cutoff = int(ff['train_x'].shape[0] * 0.9)
        train_x = ff['train_x'][:cutoff]
        test_x = ff['train_x'][cutoff:]
    # Load fully into memory only if the arrays fit comfortably
    # (< ~1 GB assuming 4 bytes per element); otherwise keep the lazy h5py datasets.
    if train_x.size * 4 + test_x.size * 4 < 1e9:
        train_x, test_x = np.array(train_x), np.array(test_x)
return (train_x, test_x)
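
# Hypothetical usage sketch (assumes the generated h5 file exists under ./data):
#   train_x, test_x = load_sort_of_clevr()
#   print(train_x.shape, test_x.shape)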