diff --git a/train.py b/train.py
index d8f609c..5a0089f 100644
--- a/train.py
+++ b/train.py
@@ -42,8 +42,8 @@
 # Hyper-parameters.
 NUM_EPOCHS = 200
 NUM_HIDDEN = 50
-NUM_LAYERS = 1
-BATCH_SIZE = 1
+NUM_LAYERS = 2
+BATCH_SIZE = 4
 
 # Optimizer parameters.
 INITIAL_LEARNING_RATE = 1e-2
@@ -101,11 +101,12 @@ def main(argv):
   sequence_length_placeholder = tf.placeholder(tf.int32, [None])
 
   # Defining the cell.
-  cell = tf.contrib.rnn.LSTMCell(NUM_HIDDEN, state_is_tuple=True)
+  def lstm_cell():
+    return tf.contrib.rnn.LSTMCell(NUM_HIDDEN, state_is_tuple=True)
 
   # Stacking rnn cells.
-  stack = tf.contrib.rnn.MultiRNNCell([cell] * NUM_LAYERS,
-                                      state_is_tuple=True)
+  stack = tf.contrib.rnn.MultiRNNCell(
+      [lstm_cell() for _ in range(NUM_LAYERS)], state_is_tuple=True)
 
   # Creates a recurrent neural network.
   outputs, _ = tf.nn.dynamic_rnn(stack, inputs_placeholder, sequence_length_placeholder, dtype=tf.float32)
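
The key change in the second hunk is building a fresh LSTMCell per layer instead of repeating one cell object. In TF 1.x (1.2 and later), a cell owns its variables, so [cell] * NUM_LAYERS would make layer 2 try to reuse layer 1's kernel, which is shaped for the raw feature width rather than NUM_HIDDEN, and graph construction fails with a variable-sharing / shape error once NUM_LAYERS > 1. Below is a minimal standalone sketch of the pattern the patch adopts, not code from train.py; num_features = 13 and the dummy batch shapes are arbitrary illustrative values.

# Minimal standalone sketch (TF 1.x assumed); mirrors the patched construction,
# with illustrative values that are not taken from train.py.
import numpy as np
import tensorflow as tf

NUM_HIDDEN = 50
NUM_LAYERS = 2
num_features = 13  # arbitrary feature width for the sketch

inputs = tf.placeholder(tf.float32, [None, None, num_features])
seq_len = tf.placeholder(tf.int32, [None])

def lstm_cell():
  # A fresh cell per layer, so each layer gets its own variables.
  return tf.contrib.rnn.LSTMCell(NUM_HIDDEN, state_is_tuple=True)

# With a single shared cell object ([cell] * NUM_LAYERS), layer 2 would try to
# reuse layer 1's kernel (built for num_features-wide inputs) on NUM_HIDDEN-wide
# inputs and raise an error; one instance per layer avoids that.
stack = tf.contrib.rnn.MultiRNNCell(
    [lstm_cell() for _ in range(NUM_LAYERS)], state_is_tuple=True)

outputs, _ = tf.nn.dynamic_rnn(stack, inputs, seq_len, dtype=tf.float32)

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  batch = np.random.rand(4, 20, num_features).astype(np.float32)
  lengths = np.full(4, 20, dtype=np.int32)
  out = sess.run(outputs, {inputs: batch, seq_len: lengths})
  print(out.shape)  # (4, 20, 50): batch x max_time x NUM_HIDDEN
  # Each layer now owns distinct variables, e.g.
  # rnn/multi_rnn_cell/cell_0/... and rnn/multi_rnn_cell/cell_1/...
  for v in tf.trainable_variables():
    print(v.name)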