-
Notifications
You must be signed in to change notification settings - Fork 4
/
utils.py
executable file
·172 lines (153 loc) · 4.28 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
# NOTE: mostly from aditya-grover's UAE project
import numpy as np
from scipy.special import expit
import inspect
import importlib
import tensorflow as tf
import matplotlib
# Select the non-interactive Agg backend; called before importing pyplot so
# that pyplot picks it up (figures are only saved to files, never shown).
matplotlib.use('agg')
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import os
# Raise TensorFlow's C++ minimum log level to silence its log messages.
# NOTE(review): this runs after `import tensorflow` above, so messages emitted
# while TensorFlow itself loads may not be suppressed — confirm if that matters.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
from tensorflow.python.platform import flags
# Global command-line flag values; plot() reads FLAGS.datasource and FLAGS.outdir.
FLAGS = flags.FLAGS
def load_dynamic(class_name, module_name):
    """
    Import a module by its dotted name and fetch one attribute from it.

    :param class_name: string, name of the attribute (typically a class) to fetch
    :param module_name: string, dotted module path given to importlib.import_module
    :return: the object bound to ``class_name`` inside that module
    :raises ImportError: if the module cannot be imported
    :raises AttributeError: if the module has no attribute ``class_name``
    """
    module = importlib.import_module(module_name)
    return getattr(module, class_name)
def get_arglist(func):
    """
    List the names of a function's positional parameters.

    :param func: function handle
    :return: list of parameter names in declaration order; ``*args``,
        keyword-only parameters and ``**kwargs`` are not included
    """
    spec = inspect.getfullargspec(func)
    return spec.args
def get_args(arglist, config):
    """
    Pick the values for the given argument names out of a configuration.

    Names that are missing from ``config`` are skipped. This assumes the
    callee supplies a default value for each skipped name.

    :param arglist: list of argument names (strings)
    :param config: configuration container indexed as ``config[name]``
        (e.g. a dict)
    :return: dict mapping each name found in ``config`` to its value
    """
    args = {}
    for argname in arglist:
        try:
            args[argname] = config[argname]
        except LookupError:
            # Only lookup failures (missing key/index) are tolerated. Any other
            # error, such as a config that cannot be indexed at all, or
            # KeyboardInterrupt/SystemExit, propagates to the caller.
            continue
    return args
def sigmoid(x, gamma=1):
    """
    Numerically stable logistic sigmoid with a slope factor.

    Evaluates 1 / (1 + exp(-gamma * x)) elementwise via scipy.special.expit,
    which avoids overflow for large-magnitude inputs.

    :param x: scalar or array-like input
    :param gamma: multiplier applied to ``x`` before the sigmoid
    :return: sigmoid of ``gamma * x``
    """
    return expit(gamma * x)
def provide_unlabelled_data(data, batch_size=10):
    """
    Provide shuffled mini-batches of unlabelled data; data = X.

    Returns an endless generator. Each pass over the data draws a fresh
    permutation from NumPy's global random state (``np.random.shuffle``).
    It then yields consecutive batches of at most ``batch_size`` rows, so the
    last batch of a pass may be smaller.

    :param data: array-like whose first axis indexes examples
    :param batch_size: positive number of examples per batch
    :return: generator yielding ``(indices, X_batch)``. ``indices`` is a
        tuple of int row positions into ``data``. ``X_batch`` holds the
        corresponding rows. As with ``np.vstack`` of single rows, 1-D data is
        returned as a ``(len(indices), 1)`` column.
    :raises ValueError: if ``batch_size`` is not positive or ``data`` is empty
    """
    data = np.asarray(data)
    N = len(data)
    # A non-positive step makes every pass empty, and the generator would
    # then loop forever without yielding anything.
    if batch_size <= 0:
        raise ValueError('batch_size must be positive, got {}'.format(batch_size))
    if N == 0:
        raise ValueError('data must contain at least one example')

    def data_iterator():
        while True:
            idxs = np.arange(0, N)
            np.random.shuffle(idxs)
            for batch_idx in range(0, N, batch_size):
                batch = idxs[batch_idx:batch_idx + batch_size]
                # Integer-array indexing copies the selected rows. This gives
                # the same result as np.vstack over single-row slices, except
                # that 1-D input needs an explicit column axis.
                X_batch = data[batch]
                if X_batch.ndim == 1:
                    X_batch = X_batch[:, np.newaxis]
                yield tuple(batch.tolist()), X_batch
    return data_iterator()
def provide_data(data, batch_size=10):
    """
    Provide shuffled mini-batches of labelled data; data = (X, y).

    Each pass over the data draws a fresh permutation from NumPy's global
    random state (``np.random.shuffle``). It then yields consecutive batches
    of at most ``batch_size`` examples, so the last batch of a pass may be
    smaller. X and y rows are always selected with the same indices.

    :param data: pair ``(X, y)`` of array-likes whose first axes index examples
    :param batch_size: positive number of examples per batch
    :return: ``(iterator, X_mean, X_std, y_mean, y_std)``.
        ``iterator`` is an endless generator yielding
        ``(indices, X_batch, y_batch)``, where ``indices`` is a tuple of int
        row positions. As with ``np.vstack`` of single rows, 1-D inputs are
        returned as ``(len(indices), 1)`` columns. The statistics are means
        and standard deviations along axis 0 of the full X and y. They are
        only returned; the batches are not normalized.
    :raises ValueError: if ``batch_size`` is not positive, X is empty, or X
        and y have different numbers of examples
    """
    X = np.asarray(data[0])
    y = np.asarray(data[1])
    N = len(X)
    # A non-positive step makes every pass empty, and the generator would
    # then loop forever without yielding anything.
    if batch_size <= 0:
        raise ValueError('batch_size must be positive, got {}'.format(batch_size))
    if N == 0:
        raise ValueError('data must contain at least one example')
    if len(y) != N:
        raise ValueError('X and y must have the same number of examples, '
                         'got {} and {}'.format(N, len(y)))
    X_mean = np.mean(X, axis=0)
    X_std = np.std(X, axis=0)
    y_mean = np.mean(y, axis=0)
    y_std = np.std(y, axis=0)

    def _rows(arr, batch):
        # Copy the selected rows. This matches np.vstack over single-row
        # slices: 1-D input comes back as a (len(batch), 1) column.
        rows = arr[batch]
        return rows[:, np.newaxis] if rows.ndim == 1 else rows

    def data_iterator():
        while True:
            idxs = np.arange(0, N)
            np.random.shuffle(idxs)
            for batch_idx in range(0, N, batch_size):
                batch = idxs[batch_idx:batch_idx + batch_size]
                yield tuple(batch.tolist()), _rows(X, batch), _rows(y, batch)
    return data_iterator(), X_mean, X_std, y_mean, y_std
def plot(samples, m=4, n=None, px=28, title=None):
    """
    Plot samples on an n-row by m-column grid and save the figure.

    :param samples: iterable of flattened images; sample ``i`` is drawn in
        grid cell ``i``
    :param m: number of grid columns; also the figure width passed to figsize
    :param n: number of grid rows (defaults to ``m``); also the figure height
    :param px: pixels per side of each grayscale sample. It is overridden by
        ``FLAGS.datasource``: 'celebA' samples are drawn as 64x64 RGB and
        'svhn'/'cifar10' samples as 32x32 RGB.
    :param title: file name inside ``FLAGS.outdir`` that the figure is saved
        under; defaults to 'samples'
    :return: the matplotlib Figure (already closed)
    """
    if n is None:
        n = m
    # TODO: the dataset-specific image shapes are hard-coded here.
    # Decide the image shape and colormap once, rather than per sample.
    datasource = FLAGS.datasource
    if datasource == 'celebA':
        shape, cmap = (64, 64, 3), None
    elif datasource in ('svhn', 'cifar10'):
        shape, cmap = (32, 32, 3), None
    else:
        shape, cmap = (px, px), 'Greys'

    fig = plt.figure(figsize=(m, n))
    gs = gridspec.GridSpec(n, m)
    gs.update(wspace=0.05, hspace=0.05)
    for i, sample in enumerate(samples):
        ax = fig.add_subplot(gs[i])
        ax.axis('off')
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_aspect('equal')
        ax.imshow(sample.reshape(shape), cmap=cmap)
    if title is None:
        title = 'samples'
    fig.savefig(os.path.join(FLAGS.outdir, title))
    # Close this specific figure so repeated calls don't accumulate open figures.
    plt.close(fig)
    return fig
def get_activation_fn(activation):
    """
    Map an activation name to the corresponding TensorFlow function.

    :param activation: one of 'tanh', 'sigmoid', 'softplus' or 'leakyrelu';
        any other value selects ReLU
    :return: the TensorFlow activation callable (not applied to anything)
    """
    if activation == 'tanh':
        return tf.tanh
    if activation == 'sigmoid':
        return tf.sigmoid
    if activation == 'softplus':
        return tf.nn.softplus
    if activation == 'leakyrelu':
        return tf.nn.leaky_relu
    # Unrecognised names fall back to ReLU.
    return tf.nn.relu
def get_optimizer_fn(optimizer):
    """
    Map an optimizer name to the corresponding TensorFlow optimizer class.

    :param optimizer: one of 'sgd', 'momentum' or 'rmsprop'; any other value
        selects Adam
    :return: the optimizer class itself (uninstantiated); the caller
        constructs it with its own hyperparameters
    """
    if optimizer == 'sgd':
        return tf.train.GradientDescentOptimizer
    if optimizer == 'momentum':
        return tf.train.MomentumOptimizer
    if optimizer == 'rmsprop':
        return tf.train.RMSPropOptimizer
    # Unrecognised names fall back to Adam.
    return tf.train.AdamOptimizer