Skip to content

Commit

Permalink
Early stopping callback added for termination based on validation per…
Browse files Browse the repository at this point in the history
…formance
  • Loading branch information
Kessler authored and Kessler committed Jun 19, 2019
1 parent 0f56fa9 commit 370abcd
Showing 1 changed file with 16 additions and 18 deletions.
34 changes: 16 additions & 18 deletions ecnet/models/mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
stderr = sys.stderr
sys.stderr = open(devnull, 'w')
from keras.backend import clear_session, reset_uids
from keras.callbacks import EarlyStopping
from keras.layers import Dense
from keras.losses import mean_squared_error
from keras.metrics import mae
Expand Down Expand Up @@ -109,31 +110,28 @@ def fit(self, l_x: array, l_y: array, v_x: array=None, v_y: array=None,
)

if v_x is not None and v_y is not None:
valid_mae_lowest = self._model.evaluate(v_x, v_y, verbose=v)[1]
steps = int(epochs / 250)
for e in range(steps):
h = self._model.fit(
l_x,
l_y,
validation_data=(v_x, v_y),
epochs=250,
verbose=v
)
valid_mae = h.history['val_mean_absolute_error'][-1]
if valid_mae < valid_mae_lowest:
valid_mae_lowest = valid_mae
elif valid_mae > (valid_mae_lowest + 0.05 * valid_mae_lowest):
logger.log('debug', 'Validation cutoff after {} epochs'
.format(e * 250), call_loc='MLP')
return

self._model.fit(
l_x,
l_y,
validation_data=(v_x, v_y),
callbacks=[EarlyStopping(
monitor='val_loss',
patience=250,
verbose=v,
mode='min',
restore_best_weights=True
)],
epochs=epochs,
verbose=v
)
else:
self._model.fit(
l_x,
l_y,
epochs=epochs,
verbose=v
)

logger.log('debug', 'Training complete after {} epochs'.format(epochs),
call_loc='MLP')

Expand Down

0 comments on commit 370abcd

Please sign in to comment.