-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathnn.py
203 lines (154 loc) · 7.28 KB
/
nn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
# coding=utf-8
# Copyright 2018 The THUMT Authors
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
import code
from tensorflow.python import debug as tfdbg
def linear(inputs, output_size, bias, concat=True, dtype=None, scope=None):
    """
    Affine projection over the last axis of one or more tensors

    Every input is flattened to 2-D and multiplied by a weight matrix. With
    ``concat'' the inputs are joined and share one matrix; otherwise each
    input gets its own matrix and the products are summed. A bias vector is
    optionally added, and the result is reshaped back so that only the last
    dimension changes.

    :param inputs: a Tensor or a list/tuple of Tensors of shape
        [..., input_size]; each input_size is read from the static shape
    :param output_size: an integer, the size of the new last dimension
    :param bias: a boolean value indicating whether to add a bias term
    :param concat: if True, concatenate all inputs and create a single
        ``matrix'' variable; otherwise create ``matrix_0'', ``matrix_1'', ...
    :param dtype: an optional instance of tf.DType for the created variables
    :param scope: the scope of this layer, the default value is ``linear''
    :returns: a Tensor whose leading dimensions follow the first input and
        whose last dimension is output_size
    :raises RuntimeError: if the number of inputs and of input sizes differ
        (cannot actually trigger as written, because the sizes are derived
        from the inputs themselves)
    """
    with tf.variable_scope(scope, default_name="linear", values=[inputs]):
        tensors = inputs if isinstance(inputs, (list, tuple)) else [inputs]
        in_dims = [t.get_shape()[-1].value for t in tensors]

        if len(tensors) != len(in_dims):
            raise RuntimeError("inputs and input_size unmatched!")

        # Remember the dynamic leading dims so they can be restored at the end.
        output_shape = tf.concat([tf.shape(tensors[0])[:-1], [output_size]],
                                 axis=0)

        # Collapse all leading dimensions so plain 2-D matmuls can be used.
        flat = [tf.reshape(t, [-1, t.shape[-1].value]) for t in tensors]

        if concat:
            joined = tf.concat(flat, 1)
            weight = tf.get_variable("matrix", [sum(in_dims), output_size],
                                     dtype=dtype)
            products = [tf.matmul(joined, weight)]
        else:
            products = []
            for idx, (piece, dim) in enumerate(zip(flat, in_dims)):
                weight = tf.get_variable("matrix_%d" % idx,
                                         [dim, output_size], dtype=dtype)
                products.append(tf.matmul(piece, weight))

        projected = tf.add_n(products)

        if bias:
            bias_var = tf.get_variable("bias", [output_size], dtype=dtype)
            projected = tf.nn.bias_add(projected, bias_var)

        return tf.reshape(projected, output_shape)
def maxout(inputs, output_size, maxpart=2, use_bias=True, concat=True,
           dtype=None, scope=None, apply_max=False):
    """
    Maxout layer

    By default this returns only the linear projection, whose last dimension
    is ``output_size * maxpart''; the max-over-pieces reduction is skipped.
    This matches what the function has always returned, because the reduction
    code previously sat unreachable behind an early ``return''. Existing
    callers therefore keep the same output shape. Pass ``apply_max=True'' to
    perform the actual maxout reduction.

    :param inputs: see the corresponding description of ``linear''
    :param output_size: the number of maxout units
    :param maxpart: an integer, the number of pieces each unit takes the
        maximum over, the default value is 2
    :param use_bias: a boolean value indicate whether to use bias term
    :param concat: concat all tensors if inputs is a list of tensors
    :param dtype: an optional instance of tf.Dtype
    :param scope: the scope of this layer, the default value is ``maxout''
    :param apply_max: if True, reshape the projection to
        [..., output_size, maxpart] and take the max over the last axis;
        the default value is False (no reduction)
    :returns: a Tensor with last dimension ``output_size * maxpart'' by
        default, or ``output_size'' when apply_max is True
    :raises RuntimeError: see the corresponding description of ``linear''
    """
    candidate = linear(inputs, output_size * maxpart, use_bias, concat,
                       dtype=dtype, scope=scope or "maxout")
    if not apply_max:
        return candidate

    # Split the last axis into (unit, piece) and keep the largest piece.
    shape = tf.concat([tf.shape(candidate)[:-1], [output_size, maxpart]],
                      axis=0)
    value = tf.reshape(candidate, shape)
    return tf.reduce_max(value, -1)
def layer_norm(inputs, epsilon=1e-6, dtype=None, scope=None):
    """
    Layer Normalization over the last axis

    Each position is standardized to zero mean and unit variance across its
    last dimension, then scaled and shifted by the learned per-channel
    variables ``scale'' (initialized to ones) and ``offset'' (initialized to
    zeros).

    :param inputs: A Tensor of shape [..., channel_size]; channel_size is
        read from the static shape
    :param epsilon: A floating number added to the variance before the
        reciprocal square root
    :param dtype: An optional instance of tf.DType for the variable scope
    :param scope: An optional string, the default value is ``layer_norm''
    :returns: A Tensor with the same shape as inputs
    """
    with tf.variable_scope(scope, default_name="layer_norm", values=[inputs],
                           dtype=dtype):
        depth = inputs.get_shape().as_list()[-1]
        gamma = tf.get_variable("scale", shape=[depth],
                                initializer=tf.ones_initializer())
        beta = tf.get_variable("offset", shape=[depth],
                               initializer=tf.zeros_initializer())

        mu = tf.reduce_mean(inputs, axis=-1, keep_dims=True)
        centered = inputs - mu
        sigma_sq = tf.reduce_mean(tf.square(centered), axis=-1,
                                  keep_dims=True)
        normalized = centered * tf.rsqrt(sigma_sq + epsilon)

        return normalized * gamma + beta
def smoothed_softmax_cross_entropy_with_logits(**kwargs):
    """
    Softmax cross-entropy with optional label smoothing

    All arguments are passed by keyword:

    :param logits: a Tensor of unnormalized scores; the number of classes is
        read from axis 1 (presumably shape [num_examples, vocab_size] --
        TODO confirm against callers)
    :param labels: integer class ids; flattened to 1-D before use
    :param smoothing: the label smoothing amount; ``None'' or 0 disables
        smoothing and a plain sparse cross-entropy is returned
    :param normalize: if exactly ``False'', the raw smoothed cross-entropy is
        returned; any other value (including omitting it) subtracts the
        constant best-achievable cross-entropy for the soft targets
    :param scope: an optional name scope, the default value is
        ``smoothed_softmax_cross_entropy_with_logits''
    :returns: a Tensor of per-example losses
    :raises ValueError: if ``logits'' or ``labels'' is missing
    """
    logits = kwargs.get("logits")
    labels = kwargs.get("labels")
    smoothing = kwargs.get("smoothing") or 0.0
    normalize = kwargs.get("normalize")
    scope = kwargs.get("scope")

    if logits is None or labels is None:
        raise ValueError("Both logits and labels must be provided")

    with tf.name_scope(scope or "smoothed_softmax_cross_entropy_with_logits",
                       values=[logits, labels]):
        labels = tf.reshape(labels, [-1])

        if not smoothing:
            return tf.nn.sparse_softmax_cross_entropy_with_logits(
                logits=logits,
                labels=labels
            )

        # Label smoothing: the true class gets probability p, and each of the
        # remaining (vocab_size - 1) classes gets an equal share q.
        vocab_size = tf.shape(logits)[1]
        n = tf.to_float(vocab_size - 1)
        p = 1.0 - smoothing
        q = smoothing / n
        soft_targets = tf.one_hot(tf.cast(labels, tf.int32), depth=vocab_size,
                                  on_value=p, off_value=q)
        xentropy = tf.nn.softmax_cross_entropy_with_logits(
            logits=logits, labels=soft_targets)

        if normalize is False:
            return xentropy

        # The normalizing constant is the best cross-entropy value achievable
        # with soft targets. It does not depend on the logits' values, so
        # subtracting it only improves readability of the loss; gradients
        # are unchanged.
        normalizing = -(p * tf.log(p) + n * q * tf.log(q + 1e-20))
        return xentropy - normalizing
def smoothed_sigmoid_cross_entropy_with_logits(**kwargs):
    """
    Multi-label sigmoid cross-entropy against multi-hot targets

    Despite the function name, no label smoothing is applied. The targets are
    a 0/1 multi-hot encoding built from the class ids in ``tes''.

    All arguments are passed by keyword:

    :param logits: a Tensor of unnormalized scores; the number of classes is
        read from axis 1 (presumably shape [batch, vocab_size] -- TODO
        confirm against callers)
    :param labels: must be provided and is attached to the name scope, but
        is not otherwise used
    :param tes: a Tensor of integer class ids. Every id appearing in a row
        becomes a 1 in that row's target (presumably shape [batch, k] --
        TODO confirm). NOTE(review): any padding id present in a row is also
        marked positive.
    :param scope: an optional name scope, the default value is
        ``smoothed_sigmoid_cross_entropy_with_logits''
    :returns: element-wise losses with the same shape as ``logits''
    :raises ValueError: if ``logits'', ``labels'' or ``tes'' is missing
    """
    logits = kwargs.get("logits")
    labels = kwargs.get("labels")
    tes = kwargs.get("tes")
    scope = kwargs.get("scope")

    if logits is None or labels is None:
        raise ValueError("Both logits and labels must be provided")
    # Fail early with a clear message instead of deep inside the graph ops.
    if tes is None:
        raise ValueError("tes (target class ids) must be provided")

    with tf.name_scope(scope or "smoothed_sigmoid_cross_entropy_with_logits",
                       values=[logits, labels]):
        vocab_size = tf.shape(logits)[1]
        # tf.one_hot accepts ids of any rank, so one vectorized call yields
        # the same [batch, k, vocab_size] tensor that a per-row tf.map_fn
        # would stack together.
        multi_one_hot = tf.one_hot(tf.cast(tes, tf.int32), depth=vocab_size,
                                   dtype=tf.float32)
        # Collapse the k one-hot rows of each example into one multi-hot row.
        multi_hot = tf.reduce_max(multi_one_hot, axis=1)
        return tf.nn.sigmoid_cross_entropy_with_logits(logits=logits,
                                                       labels=multi_hot)