-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathnernet.py
105 lines (67 loc) · 2.61 KB
/
nernet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input
import numpy as np
# --- Inference models --------------------------------------------------------
# Encoder: maps the (padded) input token sequence `enc_inp` to the encoder's
# final LSTM states `enc_states`, which seed the decoder.
enc_model = Model([enc_inp], enc_states)

# decoder Model
# Given decoder input tokens plus an initial (h, c) state pair, return the
# decoder LSTM outputs and its final (h, c) states, so the caller can feed the
# states back in and decode step by step.
# The state inputs must be exactly as wide as the decoder LSTM; read the width
# from the layer instead of hard-coding it (was 400) so the two cannot drift.
# NOTE(review): assumes dec_lstm is a Keras LSTM layer built with
# return_state=True (it is called with initial_state= and yields 3 outputs).
_dec_units = dec_lstm.units
decoder_state_input_h = Input(shape=(_dec_units,))
decoder_state_input_c = Input(shape=(_dec_units,))
decoder_states_inputs = [decoder_state_input_h, decoder_state_input_c]
decoder_outputs, state_h, state_c = dec_lstm(dec_embed,
                                             initial_state=decoder_states_inputs)
decoder_states = [state_h, state_c]
dec_model = Model([dec_inp] + decoder_states_inputs,
                  [decoder_outputs] + decoder_states)
from keras.preprocessing.sequence import pad_sequences
# Interactive chat loop.
# Reads a line from the user, encodes it with `enc_model`, then greedily
# decodes a reply one token at a time with `dec_model` + `dense` until '<EOS>'
# is sampled or the reply grows too long. Typing 'q' (after clean_text) or
# closing stdin ends the session.
# Relies on names defined elsewhere in the file: clean_text, vocab, inv_vocab,
# dense, enc_model, dec_model, pad_sequences, np.

# Length the encoder input is padded/truncated to.
# NOTE(review): presumably must match the sequence length used in training.
_MAX_INPUT_LEN = 13
# Decoding stops once the reply contains more than this many words.
_MAX_REPLY_WORDS = 13

print("##########################################")
print("# start chatting ver. 1.0 #")
print("##########################################")
while True:
    try:
        prepro1 = input("you : ")
    except EOFError:
        # stdin closed (e.g. Ctrl-D): end the session instead of crashing.
        break
    prepro1 = clean_text(prepro1)
    # Check for the quit command before running the models, so 'q' itself
    # is not answered by the bot.
    if prepro1 == 'q':
        break

    # Encode: word -> vocab id, with unknown words mapped to '<OUT>'.
    unk_id = vocab['<OUT>']
    token_ids = [vocab.get(word, unk_id) for word in prepro1.split()]
    # Batch of one sequence, zero-padded at the end to the encoder length.
    txt = pad_sequences([token_ids], _MAX_INPUT_LEN, padding='post')
    # verbose=0: otherwise newer Keras prints a progress bar on every call.
    stat = enc_model.predict(txt, verbose=0)

    # Decode greedily, starting from the start-of-sequence token.
    empty_target_seq = np.zeros((1, 1))
    empty_target_seq[0, 0] = vocab['<SOS>']
    decoded_translation = ''
    while True:
        dec_outputs, h, c = dec_model.predict([empty_target_seq] + stat,
                                              verbose=0)
        decoder_concat_input = dense(dec_outputs)
        # Most likely token at the last time step.
        sampled_word_index = np.argmax(decoder_concat_input[0, -1, :])
        sampled_word = inv_vocab[sampled_word_index] + ' '
        if sampled_word == '<EOS> ':
            break
        decoded_translation += sampled_word
        if len(decoded_translation.split()) > _MAX_REPLY_WORDS:
            break
        # Feed the sampled token and the updated LSTM states into the next step.
        empty_target_seq = np.zeros((1, 1))
        empty_target_seq[0, 0] = sampled_word_index
        stat = [h, c]
    print("chatbot attention : ", decoded_translation)
    print("==============================================")