-
-
Notifications
You must be signed in to change notification settings - Fork 46
/
Copy pathmd_lstm.py
240 lines (191 loc) · 9.93 KB
/
md_lstm.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
import tensorflow as tf
from tensorflow.contrib.rnn import RNNCell, LSTMStateTuple
from tensorflow.contrib.rnn.python.ops.core_rnn_cell import _linear
from tensorflow.python.ops.rnn import dynamic_rnn
def ln(tensor, scope=None, epsilon=1e-5):
""" Layer normalizes a 2D tensor along its second axis """
assert (len(tensor.get_shape()) == 2)
m, v = tf.nn.moments(tensor, [1], keep_dims=True)
if not isinstance(scope, str):
scope = ''
with tf.variable_scope(scope + 'layer_norm'):
scale = tf.get_variable('scale',
shape=[tensor.get_shape()[1]],
initializer=tf.constant_initializer(1))
shift = tf.get_variable('shift',
shape=[tensor.get_shape()[1]],
initializer=tf.constant_initializer(0))
ln_initial = (tensor - m) / tf.sqrt(v + epsilon)
return ln_initial * scale + shift
class MultiDimensionalLSTMCell(RNNCell):
"""
Adapted from TF's BasicLSTMCell to use Layer Normalization.
Note that state_is_tuple is always True.
"""
def __init__(self, num_units, forget_bias=0.0, activation=tf.nn.tanh):
self._num_units = num_units
self._forget_bias = forget_bias
self._activation = activation
@property
def state_size(self):
return LSTMStateTuple(self._num_units, self._num_units)
@property
def output_size(self):
return self._num_units
def __call__(self, inputs, state, scope=None):
"""Long short-term memory cell (LSTM).
@param: inputs (batch,n)
@param state: the states and hidden unit of the two cells
"""
with tf.variable_scope(scope or type(self).__name__):
c1, c2, h1, h2 = state
# change bias argument to False since LN will add bias via shift
concat = _linear([inputs, h1, h2], 5 * self._num_units, False)
i, j, f1, f2, o = tf.split(value=concat, num_or_size_splits=5, axis=1)
# add layer normalization to each gate
i = ln(i, scope='i/')
j = ln(j, scope='j/')
f1 = ln(f1, scope='f1/')
f2 = ln(f2, scope='f2/')
o = ln(o, scope='o/')
new_c = (c1 * tf.nn.sigmoid(f1 + self._forget_bias) +
c2 * tf.nn.sigmoid(f2 + self._forget_bias) + tf.nn.sigmoid(i) *
self._activation(j))
# add layer_normalization in calculation of new hidden state
new_h = self._activation(ln(new_c, scope='new_h/')) * tf.nn.sigmoid(o)
new_state = LSTMStateTuple(new_c, new_h)
return new_h, new_state
def multi_dimensional_rnn_while_loop(rnn_size, input_data, sh, dims=None, scope_n="layer1"):
"""Implements naive multi dimension recurrent neural networks
@param rnn_size: the hidden units
@param input_data: the data to process of shape [batch,h,w,channels]
@param sh: [height,width] of the windows
@param dims: dimensions to reverse the input data,eg.
dims=[False,True,True,False] => true means reverse dimension
@param scope_n : the scope
returns [batch,h/sh[0],w/sh[1],rnn_size] the output of the lstm
"""
with tf.variable_scope("MultiDimensionalLSTMCell-" + scope_n):
# Create multidimensional cell with selected size
cell = MultiDimensionalLSTMCell(rnn_size)
# Get the shape of the input (batch_size, x, y, channels)
batch_size, X_dim, Y_dim, channels = input_data.shape.as_list()
# Window size
X_win, Y_win = sh
# Get the runtime batch size
batch_size_runtime = tf.shape(input_data)[0]
# If the input cannot be exactly sampled by the window, we patch it with zeros
if X_dim % X_win != 0:
# Get offset size
offset = tf.zeros([batch_size_runtime, X_win - (X_dim % X_win), Y_dim, channels])
# Concatenate X dimension
input_data = tf.concat(axis=1, values=[input_data, offset])
# Update shape value
X_dim = input_data.shape[1].value
# The same but for Y axis
if Y_dim % Y_win != 0:
# Get offset size
offset = tf.zeros([batch_size_runtime, X_dim, Y_win - (Y_dim % Y_win), channels])
# Concatenate Y dimension
input_data = tf.concat(axis=2, values=[input_data, offset])
# Update shape value
Y_dim = input_data.shape[2].value
# Get the steps to perform in X and Y axis
h, w = int(X_dim / X_win), int(Y_dim / Y_win)
# Get the number of features (total number of input values per step)
features = Y_win * X_win * channels
# Reshape input data to a tensor containing the step indexes and features inputs
# The batch size is inferred from the tensor size
x = tf.reshape(input_data, [batch_size_runtime, h, w, features])
# Reverse the selected dimensions
if dims is not None:
assert dims[0] is False and dims[3] is False
x = tf.reverse(x, dims)
# Reorder inputs to (h, w, batch_size, features)
x = tf.transpose(x, [1, 2, 0, 3])
# Reshape to a one dimensional tensor of (h*w*batch_size , features)
x = tf.reshape(x, [-1, features])
# Split tensor into h*w tensors of size (batch_size , features)
x = tf.split(axis=0, num_or_size_splits=h * w, value=x)
# Create an input tensor array (literally an array of tensors) to use inside the loop
inputs_ta = tf.TensorArray(dtype=tf.float32, size=h * w, name='input_ta')
# Unstack the input X in the tensor array
inputs_ta = inputs_ta.unstack(x)
# Create an input tensor array for the states
states_ta = tf.TensorArray(dtype=tf.float32, size=h * w + 1, name='state_ta', clear_after_read=False)
# And an other for the output
outputs_ta = tf.TensorArray(dtype=tf.float32, size=h * w, name='output_ta')
# initial cell hidden states
# Write to the last position of the array, the LSTMStateTuple filled with zeros
states_ta = states_ta.write(h * w, LSTMStateTuple(tf.zeros([batch_size_runtime, rnn_size], tf.float32),
tf.zeros([batch_size_runtime, rnn_size], tf.float32)))
# Function to get the sample skipping one row
def get_up(t_, w_):
return t_ - tf.constant(w_)
# Function to get the previous sample
def get_last(t_, w_):
return t_ - tf.constant(1)
# Controls the initial index
time = tf.constant(0)
zero = tf.constant(0)
# Body of the while loop operation that applies the MD LSTM
def body(time_, outputs_ta_, states_ta_):
# If the current position is less or equal than the width, we are in the first row
# and we need to read the zero state we added in row (h*w).
# If not, get the sample located at a width distance.
state_up = tf.cond(tf.less_equal(time_, tf.constant(w)),
lambda: states_ta_.read(h * w),
lambda: states_ta_.read(get_up(time_, w)))
# If it is the first step we read the zero state if not we read the immediate last
state_last = tf.cond(tf.less(zero, tf.mod(time_, tf.constant(w))),
lambda: states_ta_.read(get_last(time_, w)),
lambda: states_ta_.read(h * w))
# We build the input state in both dimensions
current_state = state_up[0], state_last[0], state_up[1], state_last[1]
# Now we calculate the output state and the cell output
out, state = cell(inputs_ta.read(time_), current_state)
# We write the output to the output tensor array
outputs_ta_ = outputs_ta_.write(time_, out)
# And save the output state to the state tensor array
states_ta_ = states_ta_.write(time_, state)
# Return outputs and incremented time step
return time_ + 1, outputs_ta_, states_ta_
# Loop output condition. The index, given by the time, should be less than the
# total number of steps defined within the image
def condition(time_, outputs_ta_, states_ta_):
return tf.less(time_, tf.constant(h * w))
# Run the looped operation
result, outputs_ta, states_ta = tf.while_loop(condition, body, [time, outputs_ta, states_ta],
parallel_iterations=1)
# Extract the output tensors from the processesed tensor array
outputs = outputs_ta.stack()
states = states_ta.stack()
# Reshape outputs to match the shape of the input
y = tf.reshape(outputs, [h, w, batch_size_runtime, rnn_size])
# Reorder te dimensions to match the input
y = tf.transpose(y, [2, 0, 1, 3])
# Reverse if selected
if dims is not None:
y = tf.reverse(y, dims)
# Return the output and the inner states
return y, states
def horizontal_standard_lstm(input_data, rnn_size):
# input is (b, h, w, c)
b, h, w, c = input_data.shape.as_list()
# transpose = swap h and w.
new_input_data = tf.reshape(input_data, (b * h, w, c)) # horizontal.
rnn_out, _ = dynamic_rnn(tf.contrib.rnn.LSTMCell(rnn_size),
inputs=new_input_data,
dtype=tf.float32)
rnn_out = tf.reshape(rnn_out, (b, h, w, rnn_size))
return rnn_out
def snake_standard_lstm(input_data, rnn_size):
# input is (b, h, w, c)
b, h, w, c = input_data.shape.as_list()
# transpose = swap h and w.
new_input_data = tf.reshape(input_data, (b, w * h, c)) # snake.
rnn_out, _ = dynamic_rnn(tf.contrib.rnn.LSTMCell(rnn_size),
inputs=new_input_data,
dtype=tf.float32)
rnn_out = tf.reshape(rnn_out, (b, h, w, rnn_size))
return rnn_out