diff --git a/ecnet/models/mlp.py b/ecnet/models/mlp.py index c33b287..ee44557 100644 --- a/ecnet/models/mlp.py +++ b/ecnet/models/mlp.py @@ -19,6 +19,7 @@ stderr = sys.stderr sys.stderr = open(devnull, 'w') from keras.backend import clear_session, reset_uids +from keras.callbacks import EarlyStopping from keras.layers import Dense from keras.losses import mean_squared_error from keras.metrics import mae @@ -109,24 +110,20 @@ def fit(self, l_x: array, l_y: array, v_x: array=None, v_y: array=None, ) if v_x is not None and v_y is not None: - valid_mae_lowest = self._model.evaluate(v_x, v_y, verbose=v)[1] - steps = int(epochs / 250) - for e in range(steps): - h = self._model.fit( - l_x, - l_y, - validation_data=(v_x, v_y), - epochs=250, - verbose=v - ) - valid_mae = h.history['val_mean_absolute_error'][-1] - if valid_mae < valid_mae_lowest: - valid_mae_lowest = valid_mae - elif valid_mae > (valid_mae_lowest + 0.05 * valid_mae_lowest): - logger.log('debug', 'Validation cutoff after {} epochs' - .format(e * 250), call_loc='MLP') - return - + self._model.fit( + l_x, + l_y, + validation_data=(v_x, v_y), + callbacks=[EarlyStopping( + monitor='val_loss', + patience=250, + verbose=v, + mode='min', + restore_best_weights=True + )], + epochs=epochs, + verbose=v + ) else: self._model.fit( l_x, @@ -134,6 +131,7 @@ def fit(self, l_x: array, l_y: array, v_x: array=None, v_y: array=None, epochs=epochs, verbose=v ) + logger.log('debug', 'Training complete after {} epochs'.format(epochs), call_loc='MLP')