-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #38 from ECRL/dev
Better MLP validation, moved multiprocessing checks
- Loading branch information
Showing
16 changed files
with
46 additions
and
48 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,2 @@ | ||
from ecnet.server import Server | ||
__version__ = '3.2.0' | ||
__version__ = '3.2.1' |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,7 +2,7 @@ | |
# -*- coding: utf-8 -*- | ||
# | ||
# ecnet/models/mlp.py | ||
# v.3.2.0 | ||
# v.3.2.1 | ||
# Developed in 2019 by Travis Kessler <[email protected]> | ||
# | ||
# Contains the "MultilayerPerceptron" (feed-forward neural network) class | ||
|
@@ -19,6 +19,7 @@ | |
stderr = sys.stderr | ||
sys.stderr = open(devnull, 'w') | ||
from keras.backend import clear_session, reset_uids | ||
from keras.callbacks import EarlyStopping | ||
from keras.layers import Dense | ||
from keras.losses import mean_squared_error | ||
from keras.metrics import mae | ||
|
@@ -109,31 +110,28 @@ def fit(self, l_x: array, l_y: array, v_x: array=None, v_y: array=None, | |
) | ||
|
||
if v_x is not None and v_y is not None: | ||
valid_mae_lowest = self._model.evaluate(v_x, v_y, verbose=v)[1] | ||
steps = int(epochs / 250) | ||
for e in range(steps): | ||
h = self._model.fit( | ||
l_x, | ||
l_y, | ||
validation_data=(v_x, v_y), | ||
epochs=250, | ||
verbose=v | ||
) | ||
valid_mae = h.history['val_mean_absolute_error'][-1] | ||
if valid_mae < valid_mae_lowest: | ||
valid_mae_lowest = valid_mae | ||
elif valid_mae > (valid_mae_lowest + 0.05 * valid_mae_lowest): | ||
logger.log('debug', 'Validation cutoff after {} epochs' | ||
.format(e * 250), call_loc='MLP') | ||
return | ||
|
||
self._model.fit( | ||
l_x, | ||
l_y, | ||
validation_data=(v_x, v_y), | ||
callbacks=[EarlyStopping( | ||
monitor='val_loss', | ||
patience=250, | ||
verbose=v, | ||
mode='min', | ||
restore_best_weights=True | ||
)], | ||
epochs=epochs, | ||
verbose=v | ||
) | ||
else: | ||
self._model.fit( | ||
l_x, | ||
l_y, | ||
epochs=epochs, | ||
verbose=v | ||
) | ||
|
||
logger.log('debug', 'Training complete after {} epochs'.format(epochs), | ||
call_loc='MLP') | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,7 +2,7 @@ | |
# -*- coding: utf-8 -*- | ||
# | ||
# ecnet/server.py | ||
# v.3.2.0 | ||
# v.3.2.1 | ||
# Developed in 2019 by Travis Kessler <[email protected]> | ||
# | ||
# Contains the "Server" class, which handles ECNet project creation, neural | ||
|
@@ -24,10 +24,6 @@ | |
resave_df, resave_model, save_config, save_df, save_project, train_model,\ | ||
use_model, use_project | ||
|
||
# Stdlib imports | ||
from multiprocessing import set_start_method | ||
from os import name | ||
|
||
|
||
class Server: | ||
|
||
|
@@ -52,9 +48,6 @@ def __init__(self, model_config: str='config.yml', prj_file: str=None, | |
|
||
self._num_processes = num_processes | ||
|
||
if name != 'nt': | ||
set_start_method('spawn', force=True) | ||
|
||
if prj_file is not None: | ||
self._prj_name, self._num_pools, self._num_candidates, self._df,\ | ||
self._cf_file, self._vars = open_project(prj_file) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,7 +2,7 @@ | |
# -*- coding: utf-8 -*- | ||
# | ||
# ecnet/tasks/limit_inputs.py | ||
# v.3.2.0 | ||
# v.3.2.1 | ||
# Developed in 2019 by Travis Kessler <[email protected]> | ||
# | ||
# Contains functions for selecting influential input parameters | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,7 +2,7 @@ | |
# -*- coding: utf-8 -*- | ||
# | ||
# ecnet/tasks/remove_outliers.py | ||
# v.3.2.0 | ||
# v.3.2.1 | ||
# Developed in 2019 by Travis Kessler <[email protected]> | ||
# | ||
# Contains function for removing outliers from ECNet DataFrame | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,20 +1,17 @@ | ||
#!/usr/bin/env python | ||
# -*- coding: utf-8 -*- | ||
# | ||
# ecnet/server.py | ||
# v.3.2.0 | ||
# ecnet/tasks/training.py | ||
# v.3.2.1 | ||
# Developed in 2019 by Travis Kessler <[email protected]> | ||
# | ||
# Contains the "Server" class, which handles ECNet project creation, neural | ||
# network model creation, data hand-off to models, prediction error | ||
# calculation, input parameter selection, hyperparameter tuning. | ||
# | ||
# For example scripts, refer to https://ecnet.readthedocs.io/en/latest/ | ||
# Contains function for project training (multiprocessed training) | ||
# | ||
|
||
# stdlib. imports | ||
from multiprocessing import Pool | ||
from multiprocessing import Pool, set_start_method | ||
from operator import itemgetter | ||
from os import name | ||
|
||
# ECNet imports | ||
from ecnet.utils.logging import logger | ||
|
@@ -48,6 +45,9 @@ def train_project(prj_name: str, num_pools: int, num_candidates: int, | |
num_processes (int): number of concurrent processes used to train | ||
''' | ||
|
||
if name != 'nt': | ||
set_start_method('spawn', force=True) | ||
|
||
logger.log('info', 'Training {}x{} models'.format( | ||
num_pools, num_candidates | ||
), call_loc='TRAIN') | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,12 +2,16 @@ | |
# -*- coding: utf-8 -*- | ||
# | ||
# ecnet/tasks/tuning.py | ||
# v.3.2.0 | ||
# v.3.2.1 | ||
# Developed in 2019 by Travis Kessler <[email protected]> | ||
# | ||
# Contains functions/fitness functions for tuning hyperparameters | ||
# | ||
|
||
# stdlib. imports | ||
from multiprocessing import set_start_method | ||
from os import name | ||
|
||
# 3rd party imports | ||
from ecabc.abc import ABC | ||
|
||
|
@@ -42,6 +46,9 @@ def tune_hyperparameters(df: DataFrame, vars: dict, num_employers: int, | |
dict: tuned hyperparameters | ||
''' | ||
|
||
if name != 'nt': | ||
set_start_method('spawn', force=True) | ||
|
||
logger.log('info', 'Tuning architecture/learning hyperparameters', | ||
call_loc='TUNE') | ||
logger.log('debug', 'Arguments:\n\t| num_employers:\t{}\n\t| ' | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,7 +2,7 @@ | |
# -*- coding: utf-8 -*- | ||
# | ||
# ecnet/tools/database.py | ||
# v.3.2.0 | ||
# v.3.2.1 | ||
# Developed in 2019 by Travis Kessler <[email protected]> | ||
# | ||
# Contains functions for creating ECNet-formatted databases | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,7 +2,7 @@ | |
# -*- coding: utf-8 -*- | ||
# | ||
# ecnet/tools/plotting.py | ||
# v.3.2.0 | ||
# v.3.2.1 | ||
# Developed in 2019 by Travis Kessler <[email protected]> | ||
# | ||
# Contains functions/classes for creating various plots | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,7 +2,7 @@ | |
# -*- coding: utf-8 -*- | ||
# | ||
# ecnet/tools/project.py | ||
# v.3.2.0 | ||
# v.3.2.1 | ||
# Developed in 2019 by Travis Kessler <[email protected]> | ||
# | ||
# Contains functions for predicting data using pre-existing .prj files | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,7 +2,7 @@ | |
# -*- coding: utf-8 -*- | ||
# | ||
# ecnet/utils/data_utils.py | ||
# v.3.2.0 | ||
# v.3.2.1 | ||
# Developed in 2019 by Travis Kessler <[email protected]> | ||
# | ||
# Contains functions/classes for loading data, saving data, saving results | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,7 +2,7 @@ | |
# -*- coding: utf-8 -*- | ||
# | ||
# ecnet/utils/error_utils.py | ||
# v.3.2.0 | ||
# v.3.2.1 | ||
# Developed in 2019 by Travis Kessler <[email protected]> | ||
# | ||
# Contains functions for error calculations | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,7 +2,7 @@ | |
# -*- coding: utf-8 -*- | ||
# | ||
# ecnet/utils/logging.py | ||
# v.3.2.0 | ||
# v.3.2.1 | ||
# Developed in 2019 by Travis Kessler <[email protected]> | ||
# | ||
# Contains logger used by ECNet | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,7 +2,7 @@ | |
# -*- coding: utf-8 -*- | ||
# | ||
# ecnet/utils/server_utils.py | ||
# v.3.2.0 | ||
# v.3.2.1 | ||
# Developed in 2019 by Travis Kessler <[email protected]> | ||
# | ||
# Contains functions used by ecnet.Server | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters