# NOTE: removed scraped GitHub page artifacts (UI chrome and gutter line
# numbers) that preceded the actual script and made the file unparsable.
"""Minimal tfa.seq2seq demo: Bahdanau attention decoding over a toy batch.

Embeds a single padded token sequence (mask-aware), wraps an LSTMCell with
BahdanauAttention via AttentionWrapper, and runs a BasicDecoder with a
TrainingSampler over an all-zero decoder input sequence.
"""
import numpy as np
import tensorflow as tf
import tensorflow_addons as tfa

num_units = 5
decoder_length = 21
# Token ids; 0 is the padding id (masked out by the Embedding below).
# Renamed from `input`, which shadowed the Python builtin.
input_ids = np.array([[2, 1, 2, 1, 0]])
batch_size = input_ids.shape[0]
outputs_size = 7  # number of output classes of the final projection layer

# mask_zero=True makes the embedding attach a Keras mask so downstream
# attention can ignore the padding position.
embedding = tf.keras.layers.Embedding(3, num_units, mask_zero=True)
encoder_output = embedding(input_ids)
# NOTE(review): `_keras_mask` is a private attribute — used here only for
# inspection; confirm it still exists in the TF version in use.
print(encoder_output._keras_mask)

cell = tf.keras.layers.LSTMCell(num_units)
attention_mechanism = tfa.seq2seq.BahdanauAttention(units=num_units, memory=encoder_output)
attention_cell = tfa.seq2seq.AttentionWrapper(cell, attention_mechanism,
                                              attention_layer_size=num_units)  # attention_layer_size
# The sampler plays the role of argmax (selects the next decoder input).
sampler = tfa.seq2seq.sampler.TrainingSampler()
output_layer = tf.keras.layers.Dense(outputs_size)  # logically the number of classes of a classification problem
decoder = tfa.seq2seq.BasicDecoder(cell=attention_cell, sampler=sampler, output_layer=output_layer)

# Explicit tf.float32 instead of the builtin `float` (TF maps `float` to
# float32, but being explicit avoids relying on that implicit conversion).
decoder_initial_state = attention_cell.get_initial_state(batch_size=batch_size,
                                                         dtype=tf.float32)
# cell_state[0] is used as the first attention query.
decoder_initial_state = decoder_initial_state.clone(
    cell_state=[tf.random.normal(shape=(batch_size, num_units)),
                tf.random.normal(shape=(batch_size, num_units))])

decoder_input_ = tf.constant(value=0, shape=(batch_size, decoder_length, num_units),
                             dtype=tf.float32)
# For BahdanauAttention the decoder `inputs=` content is meaningless, but the
# argument must still be supplied; hopefully it does not affect performance.
# sequence_length is meaningless for text generation.
outputs, _, _ = decoder(inputs=decoder_input_, initial_state=decoder_initial_state,
                        sequence_length=tf.constant([decoder_length]))
# print(outputs)