Skip to content

Commit

Permalink
Merge pull request #38 from ECRL/dev
Browse files Browse the repository at this point in the history
Better MLP validation, moved multiprocessing checks
  • Loading branch information
tjkessler authored Jun 22, 2019
2 parents 0f56fa9 + 86305f3 commit 807f592
Show file tree
Hide file tree
Showing 16 changed files with 46 additions and 48 deletions.
2 changes: 1 addition & 1 deletion docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
# The short X.Y version
version = ''
# The full version, including alpha/beta/rc tags
release = '3.2.0'
release = '3.2.1'


# -- General configuration ---------------------------------------------------
Expand Down
2 changes: 1 addition & 1 deletion ecnet/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
from ecnet.server import Server
__version__ = '3.2.0'
__version__ = '3.2.1'
36 changes: 17 additions & 19 deletions ecnet/models/mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# -*- coding: utf-8 -*-
#
# ecnet/models/mlp.py
# v.3.2.0
# v.3.2.1
# Developed in 2019 by Travis Kessler <[email protected]>
#
# Contains the "MultilayerPerceptron" (feed-forward neural network) class
Expand All @@ -19,6 +19,7 @@
stderr = sys.stderr
sys.stderr = open(devnull, 'w')
from keras.backend import clear_session, reset_uids
from keras.callbacks import EarlyStopping
from keras.layers import Dense
from keras.losses import mean_squared_error
from keras.metrics import mae
Expand Down Expand Up @@ -109,31 +110,28 @@ def fit(self, l_x: array, l_y: array, v_x: array=None, v_y: array=None,
)

if v_x is not None and v_y is not None:
valid_mae_lowest = self._model.evaluate(v_x, v_y, verbose=v)[1]
steps = int(epochs / 250)
for e in range(steps):
h = self._model.fit(
l_x,
l_y,
validation_data=(v_x, v_y),
epochs=250,
verbose=v
)
valid_mae = h.history['val_mean_absolute_error'][-1]
if valid_mae < valid_mae_lowest:
valid_mae_lowest = valid_mae
elif valid_mae > (valid_mae_lowest + 0.05 * valid_mae_lowest):
logger.log('debug', 'Validation cutoff after {} epochs'
.format(e * 250), call_loc='MLP')
return

self._model.fit(
l_x,
l_y,
validation_data=(v_x, v_y),
callbacks=[EarlyStopping(
monitor='val_loss',
patience=250,
verbose=v,
mode='min',
restore_best_weights=True
)],
epochs=epochs,
verbose=v
)
else:
self._model.fit(
l_x,
l_y,
epochs=epochs,
verbose=v
)

logger.log('debug', 'Training complete after {} epochs'.format(epochs),
call_loc='MLP')

Expand Down
9 changes: 1 addition & 8 deletions ecnet/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# -*- coding: utf-8 -*-
#
# ecnet/server.py
# v.3.2.0
# v.3.2.1
# Developed in 2019 by Travis Kessler <[email protected]>
#
# Contains the "Server" class, which handles ECNet project creation, neural
Expand All @@ -24,10 +24,6 @@
resave_df, resave_model, save_config, save_df, save_project, train_model,\
use_model, use_project

# Stdlib imports
from multiprocessing import set_start_method
from os import name


class Server:

Expand All @@ -52,9 +48,6 @@ def __init__(self, model_config: str='config.yml', prj_file: str=None,

self._num_processes = num_processes

if name != 'nt':
set_start_method('spawn', force=True)

if prj_file is not None:
self._prj_name, self._num_pools, self._num_candidates, self._df,\
self._cf_file, self._vars = open_project(prj_file)
Expand Down
2 changes: 1 addition & 1 deletion ecnet/tasks/limit_inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# -*- coding: utf-8 -*-
#
# ecnet/tasks/limit_inputs.py
# v.3.2.0
# v.3.2.1
# Developed in 2019 by Travis Kessler <[email protected]>
#
# Contains functions for selecting influential input parameters
Expand Down
2 changes: 1 addition & 1 deletion ecnet/tasks/remove_outliers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# -*- coding: utf-8 -*-
#
# ecnet/tasks/remove_outliers.py
# v.3.2.0
# v.3.2.1
# Developed in 2019 by Travis Kessler <[email protected]>
#
# Contains function for removing outliers from ECNet DataFrame
Expand Down
16 changes: 8 additions & 8 deletions ecnet/tasks/training.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,17 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# ecnet/server.py
# v.3.2.0
# ecnet/tasks/training.py
# v.3.2.1
# Developed in 2019 by Travis Kessler <[email protected]>
#
# Contains the "Server" class, which handles ECNet project creation, neural
# network model creation, data hand-off to models, prediction error
# calculation, input parameter selection, hyperparameter tuning.
#
# For example scripts, refer to https://ecnet.readthedocs.io/en/latest/
# Contains function for project training (multiprocessed training)
#

# stdlib. imports
from multiprocessing import Pool
from multiprocessing import Pool, set_start_method
from operator import itemgetter
from os import name

# ECNet imports
from ecnet.utils.logging import logger
Expand Down Expand Up @@ -48,6 +45,9 @@ def train_project(prj_name: str, num_pools: int, num_candidates: int,
num_processes (int): number of concurrent processes used to train
'''

if name != 'nt':
set_start_method('spawn', force=True)

logger.log('info', 'Training {}x{} models'.format(
num_pools, num_candidates
), call_loc='TRAIN')
Expand Down
9 changes: 8 additions & 1 deletion ecnet/tasks/tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,16 @@
# -*- coding: utf-8 -*-
#
# ecnet/tasks/tuning.py
# v.3.2.0
# v.3.2.1
# Developed in 2019 by Travis Kessler <[email protected]>
#
# Contains functions/fitness functions for tuning hyperparameters
#

# stdlib. imports
from multiprocessing import set_start_method
from os import name

# 3rd party imports
from ecabc.abc import ABC

Expand Down Expand Up @@ -42,6 +46,9 @@ def tune_hyperparameters(df: DataFrame, vars: dict, num_employers: int,
dict: tuned hyperparameters
'''

if name != 'nt':
set_start_method('spawn', force=True)

logger.log('info', 'Tuning architecture/learning hyperparameters',
call_loc='TUNE')
logger.log('debug', 'Arguments:\n\t| num_employers:\t{}\n\t| '
Expand Down
2 changes: 1 addition & 1 deletion ecnet/tools/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# -*- coding: utf-8 -*-
#
# ecnet/tools/database.py
# v.3.2.0
# v.3.2.1
# Developed in 2019 by Travis Kessler <[email protected]>
#
# Contains functions for creating ECNet-formatted databases
Expand Down
2 changes: 1 addition & 1 deletion ecnet/tools/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# -*- coding: utf-8 -*-
#
# ecnet/tools/plotting.py
# v.3.2.0
# v.3.2.1
# Developed in 2019 by Travis Kessler <[email protected]>
#
# Contains functions/classes for creating various plots
Expand Down
2 changes: 1 addition & 1 deletion ecnet/tools/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# -*- coding: utf-8 -*-
#
# ecnet/tools/project.py
# v.3.2.0
# v.3.2.1
# Developed in 2019 by Travis Kessler <[email protected]>
#
# Contains functions for predicting data using pre-existing .prj files
Expand Down
2 changes: 1 addition & 1 deletion ecnet/utils/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# -*- coding: utf-8 -*-
#
# ecnet/utils/data_utils.py
# v.3.2.0
# v.3.2.1
# Developed in 2019 by Travis Kessler <[email protected]>
#
# Contains functions/classes for loading data, saving data, saving results
Expand Down
2 changes: 1 addition & 1 deletion ecnet/utils/error_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# -*- coding: utf-8 -*-
#
# ecnet/utils/error_utils.py
# v.3.2.0
# v.3.2.1
# Developed in 2019 by Travis Kessler <[email protected]>
#
# Contains functions for error calculations
Expand Down
2 changes: 1 addition & 1 deletion ecnet/utils/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# -*- coding: utf-8 -*-
#
# ecnet/utils/logging.py
# v.3.2.0
# v.3.2.1
# Developed in 2019 by Travis Kessler <[email protected]>
#
# Contains logger used by ECNet
Expand Down
2 changes: 1 addition & 1 deletion ecnet/utils/server_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# -*- coding: utf-8 -*-
#
# ecnet/utils/server_utils.py
# v.3.2.0
# v.3.2.1
# Developed in 2019 by Travis Kessler <[email protected]>
#
# Contains functions used by ecnet.Server
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

setup(
name='ecnet',
version='3.2.0',
version='3.2.1',
description='UMass Lowell Energy and Combustion Research Laboratory Neural'
' Network Software',
url='http://github.com/tjkessler/ecnet',
Expand Down

0 comments on commit 807f592

Please sign in to comment.