-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdatagen.py
47 lines (36 loc) · 1.34 KB
/
datagen.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
from utils import setUpLogger
logger = setUpLogger()
from pad_autoencoder import DataPipeline, autoencoder, getIds, inputModel, simpleInputs, getIds
import numpy as np
from tensorflow import keras
from scraperTools import getDataFromFile
import random
##### Initialise model & data #############
# Training driver: builds a new pad autoencoder (or resumes from a saved
# checkpoint) and trains it in rounds of EPOCHS_PER_ITERATION epochs, saving a
# checkpoint named modelFileName + <round index> (e.g. "model-v3") after each round.
modelFileName = 'model-v'
Model = keras.Model
ids = getIds()
batchSize = 32
validation_split = 0.1
folderPath = "./model data/pads/"
opt = keras.optimizers.Adam(learning_rate=0.001, clipvalue=0.5)

dataPipeline = DataPipeline(ids, folderPath, batchSize=batchSize, validation_split=validation_split)
# Presumably a (training, validation) pair of Keras-compatible data sources,
# split by DataPipeline according to validation_split -- confirm in pad_autoencoder.
trainData, valData = dataPipeline.dataGenerators()
# NOTE(review): assumed to persist the pipeline's metadata so it can be reloaded later.
dataPipeline.save('dataPipeline_metaData.data')

START_NEW_MODEL = -1       # sentinel value for `lower`: build and compile a fresh model
EPOCHS_PER_ITERATION = 10  # epochs trained between consecutive checkpoints
lower = 2   # index of the checkpoint to resume from; set to START_NEW_MODEL to start a new model
upper = 10  # exclusive bound: the last checkpoint written is modelFileName + str(upper - 1)

if lower == START_NEW_MODEL:
    # Only build the graph when it is actually used; when resuming, the
    # architecture and weights come from the saved checkpoint instead.
    inputs = simpleInputs()
    pad_out = autoencoder(inputs)
    model = Model(inputs, pad_out)
    # keras.utils.plot_model(model, "model.png", show_shapes=True)
    model.compile(optimizer=opt, loss='mse')
else:
    # load_model compiles by default with the saved optimizer/loss, so no compile() here.
    model = keras.models.load_model(modelFileName + str(lower))

if lower + 1 >= upper:
    logger.warning('Nothing to train: lower=%d, upper=%d', lower, upper)

for i in range(lower + 1, upper):
    logger.info('Beginning iteration %d', i)
    model.fit(x=trainData, validation_data=valData, epochs=EPOCHS_PER_ITERATION)
    # Checkpoint after every round so training can be resumed by setting lower = i.
    model.save(modelFileName + str(i))
    logger.info('Finished iteration %d', i)