-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathutils.py
303 lines (261 loc) · 12.1 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
import tensorflow as tf
from keras.engine import Layer, InputSpec
from keras import backend as K
from keras.layers import Conv2D, Dense, Flatten
from keras.layers.pooling import _GlobalPooling2D
class AdaInstanceNormalization(Layer):
    """Adaptive instance normalization layer.

    The layer is called on a list ``[content, style_mean, style_std]``.
    ``content`` is normalized with its own mean and variance, computed over
    axes 1 and 2 with the reduced axes kept, and the result is multiplied by
    ``style_std`` and shifted by ``style_mean``.

    Input b and g should be 1x1xC (presumably per sample, so that they
    broadcast against ``content`` -- confirm against the caller).

    # Arguments
        axis, momentum, center, scale: Stored on the layer and returned by
            `get_config`; `call` does not read them.
        epsilon: Float passed to `tf.nn.batch_normalization` as the
            variance epsilon.
    """

    def __init__(self,
                 axis=-1,
                 momentum=0.99,
                 epsilon=1e-5,
                 center=True,
                 scale=True,
                 **kwargs):
        super(AdaInstanceNormalization, self).__init__(**kwargs)
        self.axis = axis
        self.momentum = momentum
        self.epsilon = epsilon
        self.center = center
        self.scale = scale

    def build(self, input_shape):
        # The layer owns no weights; only the base-class bookkeeping runs.
        super(AdaInstanceNormalization, self).build(input_shape)

    @staticmethod
    def _spatial_moments(content):
        """Return (mean, variance) of `content` over axes 1 and 2."""
        reduction_axes = [1, 2]
        try:
            # Original spelling first, so behavior is unchanged wherever the
            # layer already worked.
            return tf.nn.moments(content, reduction_axes, keep_dims=True)
        except TypeError:
            # TensorFlow releases that dropped `keep_dims` reject it as an
            # unexpected keyword; they take `keepdims` instead.
            return tf.nn.moments(content, reduction_axes, keepdims=True)

    def call(self, inputs, training=None):
        """Normalize `inputs[0]` and apply `inputs[2]` as scale, `inputs[1]` as offset.

        # Raises
            ValueError: if `inputs` is not a list/tuple of at least three
                tensors.
        """
        if not isinstance(inputs, (list, tuple)) or len(inputs) < 3:
            raise ValueError(
                'AdaInstanceNormalization must be called on a list of three '
                'tensors: [content, style_mean, style_std].')
        content, style_mean, style_std = inputs[0], inputs[1], inputs[2]

        content_mean, content_var = self._spatial_moments(content)
        # batch_normalization computes
        #   (content - mean) / sqrt(var + epsilon) * scale + offset
        # with offset=style_mean and scale=style_std.
        return tf.nn.batch_normalization(content,
                                         content_mean,
                                         content_var,
                                         style_mean,
                                         style_std,
                                         self.epsilon)

    def get_config(self):
        """Return the layer config, including the constructor arguments."""
        config = {
            'axis': self.axis,
            'momentum': self.momentum,
            'epsilon': self.epsilon,
            'center': self.center,
            'scale': self.scale
        }
        base_config = super(AdaInstanceNormalization, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

    def compute_output_shape(self, input_shape):
        # The output has the shape of the first input (the content tensor).
        content_shape = input_shape[0]
        return content_shape
class DenseSN(Dense):
    """Dense layer whose kernel is rescaled by a spectral-norm estimate.

    Borrowed from https://github.com/IShengFang/SpectralNormalizationKeras

    On every call the kernel is divided by ``sigma``, obtained from one
    power-iteration step that starts from the stored vector ``self.u``.
    """

    def __init__(self, units, **kwargs):
        super(DenseSN, self).__init__(units, **kwargs)

    def build(self, input_shape):
        """Create the kernel, the optional bias and the power-iteration vector."""
        assert len(input_shape) >= 2
        fan_in = input_shape[-1]

        self.kernel = self.add_weight(
            shape=(fan_in, self.units),
            initializer=self.kernel_initializer,
            name='kernel',
            regularizer=self.kernel_regularizer,
            constraint=self.kernel_constraint)

        if not self.use_bias:
            self.bias = None
        else:
            self.bias = self.add_weight(
                shape=(self.units,),
                initializer=self.bias_initializer,
                name='bias',
                regularizer=self.bias_regularizer,
                constraint=self.bias_constraint)

        # Non-trainable row vector with one entry per kernel output column.
        self.u = self.add_weight(
            shape=(1, self.kernel.shape.as_list()[-1]),
            initializer=initializers.RandomNormal(0, 1),
            name='sn',
            trainable=False)

        self.input_spec = InputSpec(min_ndim=2, axes={-1: fan_in})
        self.built = True

    def call(self, inputs, training=None):
        """Apply the dense transform using the normalized kernel."""

        def _unit_norm(vec, eps=1e-5):
            # Divide by the L2 norm; eps is added to the norm itself.
            return vec / (K.sum(vec ** 2) ** 0.5 + eps)

        kernel_shape = self.kernel.shape.as_list()
        # View the kernel as a 2D matrix with the output units as columns.
        kernel_2d = K.reshape(self.kernel, [-1, kernel_shape[-1]])

        # One power-iteration step starting from the stored vector.
        v_hat = _unit_norm(K.dot(self.u, K.transpose(kernel_2d)))
        u_hat = _unit_norm(K.dot(v_hat, kernel_2d))

        # sigma = v W u^T
        sigma = K.dot(K.dot(v_hat, kernel_2d), K.transpose(u_hat))
        kernel_sn = kernel_2d / sigma

        if training in {0, False}:
            # Explicit inference: leave self.u untouched.
            kernel_sn = K.reshape(kernel_sn, kernel_shape)
        else:
            # Make the reshape depend on storing the refreshed vector.
            with tf.control_dependencies([self.u.assign(u_hat)]):
                kernel_sn = K.reshape(kernel_sn, kernel_shape)

        result = K.dot(inputs, kernel_sn)
        if self.use_bias:
            result = K.bias_add(result, self.bias, data_format='channels_last')
        if self.activation is None:
            return result
        return self.activation(result)
class ConvSN2D(Conv2D):
    """2D convolution whose kernel is rescaled by a spectral-norm estimate.

    Borrowed from https://github.com/IShengFang/SpectralNormalizationKeras

    On every call the kernel is flattened to 2D and divided by ``sigma``,
    obtained from one power-iteration step that starts from ``self.u``.
    """

    def __init__(self, filters, kernel_size, **kwargs):
        super(ConvSN2D, self).__init__(filters, kernel_size, **kwargs)

    def build(self, input_shape):
        """Create the kernel, the optional bias and the power-iteration vector."""
        channel_axis = 1 if self.data_format == 'channels_first' else -1
        in_channels = input_shape[channel_axis]
        if in_channels is None:
            raise ValueError('The channel dimension of the inputs '
                             'should be defined. Found `None`.')

        self.kernel = self.add_weight(
            shape=self.kernel_size + (in_channels, self.filters),
            initializer=self.kernel_initializer,
            name='kernel',
            regularizer=self.kernel_regularizer,
            constraint=self.kernel_constraint,
            trainable=True)

        if not self.use_bias:
            self.bias = None
        else:
            self.bias = self.add_weight(
                shape=(self.filters,),
                initializer=self.bias_initializer,
                name='bias',
                regularizer=self.bias_regularizer,
                constraint=self.bias_constraint,
                trainable=True)

        # Non-trainable row vector with one entry per kernel output channel.
        self.u = self.add_weight(
            shape=(1, self.kernel.shape.as_list()[-1]),
            initializer=initializers.RandomNormal(0, 1),
            name='sn',
            trainable=False)

        # Input spec: rank + 2 dimensions with a fixed channel count.
        self.input_spec = InputSpec(ndim=self.rank + 2,
                                    axes={channel_axis: in_channels})
        self.built = True

    def call(self, inputs, training=None):
        """Convolve `inputs` with the normalized kernel."""

        def _unit_norm(vec, eps=1e-5):
            # Here eps sits inside the square root, next to the squared sum.
            return vec / K.sqrt(K.sum(K.square(vec)) + eps)

        kernel_shape = self.kernel.shape.as_list()
        # Collapse all but the last kernel axis into rows.
        kernel_2d = K.reshape(self.kernel, [-1, kernel_shape[-1]])

        # A single power-iteration step is performed per call.
        v_hat = _unit_norm(K.dot(self.u, K.transpose(kernel_2d)))
        u_hat = _unit_norm(K.dot(v_hat, kernel_2d))

        # sigma = v W u^T
        sigma = K.dot(K.dot(v_hat, kernel_2d), K.transpose(u_hat))
        kernel_sn = kernel_2d / sigma

        if training in {0, False}:
            # Explicit inference: leave self.u untouched.
            kernel_sn = K.reshape(kernel_sn, kernel_shape)
        else:
            # Make the reshape depend on storing the refreshed vector.
            with tf.control_dependencies([self.u.assign(u_hat)]):
                kernel_sn = K.reshape(kernel_sn, kernel_shape)

        result = K.conv2d(inputs,
                          kernel_sn,
                          strides=self.strides,
                          padding=self.padding,
                          data_format=self.data_format,
                          dilation_rate=self.dilation_rate)
        if self.use_bias:
            result = K.bias_add(result,
                                self.bias,
                                data_format=self.data_format)
        if self.activation is None:
            return result
        return self.activation(result)
class GlobalSumPooling2D(_GlobalPooling2D):
    """Global sum pooling over the two spatial axes of a 4D tensor.

    # Arguments
        data_format: A string, either `channels_last` (default) or
            `channels_first`, giving the ordering of the input dimensions.
            `channels_last` means inputs shaped
            `(batch, height, width, channels)`; `channels_first` means
            inputs shaped `(batch, channels, height, width)`.
            It defaults to the `image_data_format` value found in your
            Keras config file at `~/.keras/keras.json`.
            If you never set it, then it will be "channels_last".

    # Input shape
        - If `data_format='channels_last'`:
            4D tensor with shape `(batch_size, rows, cols, channels)`
        - If `data_format='channels_first'`:
            4D tensor with shape `(batch_size, channels, rows, cols)`

    # Output shape
        2D tensor with shape `(batch_size, channels)`
    """

    def call(self, inputs):
        # Pick the two spatial axes for the configured layout.
        spatial_axes = [1, 2] if self.data_format == 'channels_last' else [2, 3]
        return K.sum(inputs, axis=spatial_axes)
from keras.engine.topology import Layer
from keras import initializers
class Bias(Layer):
    """Custom Keras layer that adds a learned bias vector to its input.

    One bias value is created per entry of the last input axis and added
    with `K.bias_add` using the `channels_last` layout.

    # Arguments
        initializer: Initializer for the bias weight; passed unchanged to
            `add_weight`. Defaults to `'zeros'`.
    """

    def __init__(self, initializer='zeros', **kwargs):
        super(Bias, self).__init__(**kwargs)
        self.initializer = initializer

    def build(self, input_shape):
        # The weight name embeds the layer name and keeps the historical
        # `_W` suffix, so weights saved earlier keep the same name.
        self.bias = self.add_weight(name='{}_W'.format(self.name),
                                    shape=(input_shape[-1],),
                                    initializer=self.initializer)
        super(Bias, self).build(input_shape)

    def call(self, x):
        return K.bias_add(x, self.bias, data_format='channels_last')

    def get_config(self):
        """Return the layer config, including the bias initializer.

        Without this entry a layer rebuilt from its config would fall back
        to the default `'zeros'` initializer.
        """
        config = {
            'initializer': initializers.serialize(
                initializers.get(self.initializer))
        }
        base_config = super(Bias, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
class SelfAttention(Layer):
    """Self-attention layer over the spatial positions of a 4D tensor.

    Three 1x1 convolutions with bias produce `f` and `g` (each with
    `channels // 8` filters) and `h` (with `channels` filters). The
    attention map is `softmax(g . f^T)` over the last axis, and the layer
    returns `x + gamma * o`, where `o` is the attention-weighted `h`
    reshaped to the shape of `x`, and `gamma` is a trainable weight of
    shape `[1]` initialized to zero.

    NOTE(review): `K.conv2d` / `K.bias_add` are called without an explicit
    `data_format`, while the input spec and the flattening treat the last
    axis as channels -- presumably the Keras image data format is expected
    to be `channels_last`; confirm.

    # Arguments
        channels: Integer; used as the input-channel dimension of every
            kernel and as the filter count of `h`.
        kernel_initializer: Initializer for the three convolution kernels.
    """

    def __init__(self, channels, kernel_initializer='he_normal', **kwargs):
        super(SelfAttention, self).__init__(**kwargs)
        self.channels = channels
        self.filters_f_g = self.channels // 8
        self.filters_h = self.channels
        self.kernel_initializer = kernel_initializer

    def build(self, input_shape):
        """Create gamma plus the kernels and biases of the f, g, h projections."""
        kernel_shape_f_g = (1, 1) + (self.channels, self.filters_f_g)
        kernel_shape_h = (1, 1) + (self.channels, self.filters_h)

        # Weight creation order is kept as-is so the weight list matches
        # previously saved weights.
        self.gamma = self.add_weight(name='gamma', shape=[1],
                                     initializer='zeros', trainable=True)
        self.kernel_f = self.add_weight(shape=kernel_shape_f_g,
                                        initializer=self.kernel_initializer,
                                        name='kernel_f')
        self.kernel_g = self.add_weight(shape=kernel_shape_f_g,
                                        initializer=self.kernel_initializer,
                                        name='kernel_g')
        self.kernel_h = self.add_weight(shape=kernel_shape_h,
                                        initializer=self.kernel_initializer,
                                        name='kernel_h')
        self.bias_f = self.add_weight(shape=(self.filters_f_g,),
                                      initializer='zeros', name='bias_f')
        self.bias_g = self.add_weight(shape=(self.filters_f_g,),
                                      initializer='zeros', name='bias_g')
        self.bias_h = self.add_weight(shape=(self.filters_h,),
                                      initializer='zeros', name='bias_h')

        # Set input spec.
        self.input_spec = InputSpec(ndim=4, axes={3: input_shape[-1]})
        super(SelfAttention, self).build(input_shape)

    def call(self, x):
        def hw_flatten(t):
            # Merge axes 1 and 2 into one axis: [bs, h, w, c] -> [bs, h*w, c].
            t_shape = K.shape(t)
            return K.reshape(
                t, shape=[t_shape[0], t_shape[1] * t_shape[2], t_shape[-1]])

        f = K.conv2d(x, kernel=self.kernel_f, strides=(1, 1), padding='same')  # [bs, h, w, c']
        f = K.bias_add(f, self.bias_f)
        g = K.conv2d(x, kernel=self.kernel_g, strides=(1, 1), padding='same')  # [bs, h, w, c']
        g = K.bias_add(g, self.bias_g)
        h = K.conv2d(x, kernel=self.kernel_h, strides=(1, 1), padding='same')  # [bs, h, w, c]
        h = K.bias_add(h, self.bias_h)

        s = tf.matmul(hw_flatten(g), hw_flatten(f), transpose_b=True)  # [bs, N, N]
        beta = K.softmax(s, axis=-1)  # attention map

        o = K.batch_dot(beta, hw_flatten(h))  # [bs, N, C]
        o = K.reshape(o, shape=K.shape(x))  # [bs, h, w, C]

        # Residual connection scaled by gamma.
        return x + self.gamma * o

    def compute_output_shape(self, input_shape):
        return input_shape

    def get_config(self):
        """Return the layer config, including the constructor arguments.

        `channels` is a required constructor argument, so it must be part
        of the config for the layer to be rebuilt from it.
        """
        config = {
            'channels': self.channels,
            'kernel_initializer': initializers.serialize(
                initializers.get(self.kernel_initializer))
        }
        base_config = super(SelfAttention, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))