-
Notifications
You must be signed in to change notification settings - Fork 93
/
Copy pathex9_4_1_my_layer.py
45 lines (33 loc) · 1.12 KB
/
ex9_4_1_my_layer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
import keras
from keras import backend as K
from keras.engine.topology import Layer
import numpy as np
from keras import initializers
# Initializer objects used by SFC.build: 'glorot_uniform' for the kernel
# ``w`` and 'zeros' for the bias ``b``.
igu, iz = map(initializers.get, ('glorot_uniform', 'zeros'))
class SFC(Layer):
    """SFC: simplified fully connected layer.

    Computes ``K.dot(x, w) + b`` with no activation. ``w`` has shape
    ``(input_features, No)`` and is initialized with ``igu``; ``b`` has
    shape ``(No,)`` and is initialized with ``iz``.

    Args:
        No: number of output units.
        **kwargs: forwarded unchanged to the Keras ``Layer`` base class
            (e.g. ``input_shape``, ``name``).
    """

    def __init__(self, No, **kwargs):
        self.No = No
        super().__init__(**kwargs)

    def build(self, inshape):
        """Create the kernel ``w`` and bias ``b`` for inputs of shape *inshape*."""
        # The last axis is the feature axis; for (batch, features) input
        # this is the same as inshape[1].
        # Keyword arguments keep the call valid whether add_weight's first
        # positional parameter is ``name`` (older Keras) or ``shape`` (Keras 3).
        self.w = self.add_weight(name="w", shape=(inshape[-1], self.No),
                                 initializer=igu)
        self.b = self.add_weight(name="b", shape=(self.No,), initializer=iz)
        super().build(inshape)

    def call(self, x):
        """Return the affine map ``K.dot(x, w) + b``."""
        return K.dot(x, self.w) + self.b

    def compute_output_shape(self, inshape):
        """Return *inshape* with its last (feature) dimension replaced by ``No``."""
        return tuple(inshape[:-1]) + (self.No,)

    def get_config(self):
        """Return the layer config including ``No``.

        Without this, ``No`` is not serialized and a saved model containing
        this layer cannot be re-created from its config.
        """
        config = super().get_config()
        config.update({"No": self.No})
        return config
def main():
    """Fit an SFC(1) model to y = 2x + 1 and compare held-out predictions.

    Trains on x = 0, 1 for 1000 epochs (SGD optimizer, MSE loss) and then
    prints the targets and the model's predictions for x = 2, 3, 4.
    """
    # Shape the data as (n, 1) columns so it matches input_shape=(1,)
    # explicitly, instead of relying on Keras to expand 1-D arrays.
    x = np.array([0, 1, 2, 3, 4]).reshape(-1, 1)
    y = x * 2 + 1
    model = keras.models.Sequential()
    model.add(SFC(1, input_shape=(1,)))
    model.compile('SGD', 'mse')
    # Only the first two points are used for training; the last three are
    # held out and compared against the model's predictions below.
    model.fit(x[:2], y[:2], epochs=1000, verbose=0)
    print('Targets:', y[2:].flatten())
    print('Predictions:', model.predict(x[2:]).flatten())
# Run the demo only when executed as a script, not when imported.
if __name__ == '__main__':
    main()