-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #40 from ECRL/dev
Bug fixes, enhancements
- Loading branch information
Showing
36 changed files
with
2,432 additions
and
245 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Large diffs are not rendered by default.
Oops, something went wrong.
Large diffs are not rendered by default.
Oops, something went wrong.
Large diffs are not rendered by default.
Oops, something went wrong.
Large diffs are not rendered by default.
Oops, something went wrong.
Large diffs are not rendered by default.
Oops, something went wrong.
Large diffs are not rendered by default.
Oops, something went wrong.
Large diffs are not rendered by default.
Oops, something went wrong.
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,2 @@ | ||
from ecnet.server import Server | ||
__version__ = '3.2.1' | ||
__version__ = '3.2.2' |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,7 +2,7 @@ | |
# -*- coding: utf-8 -*- | ||
# | ||
# ecnet/models/mlp.py | ||
# v.3.2.1 | ||
# v.3.2.2 | ||
# Developed in 2019 by Travis Kessler <[email protected]> | ||
# | ||
# Contains the "MultilayerPerceptron" (feed-forward neural network) class | ||
|
@@ -77,7 +77,7 @@ def add_layer(self, num_neurons: int, activation: str, | |
def fit(self, l_x: array, l_y: array, v_x: array=None, v_y: array=None, | ||
epochs: int=1500, lr: float=0.001, beta_1: float=0.9, | ||
beta_2: float=0.999, epsilon: float=0.0000001, decay: float=0.0, | ||
v: int=0): | ||
v: int=0, batch_size: int=32): | ||
'''Fits neural network to supplied inputs and targets | ||
Args: | ||
|
@@ -95,6 +95,7 @@ def fit(self, l_x: array, l_y: array, v_x: array=None, v_y: array=None, | |
epsilon (float): epsilon value for Adam optimizer | ||
decay (float): learning rate decay for Adam optimizer | ||
v (int): verbose training, `0` for no printing, `1` for printing | ||
batch_size (int): number of learning samples per batch | ||
''' | ||
|
||
self._model.compile( | ||
|
@@ -110,7 +111,7 @@ def fit(self, l_x: array, l_y: array, v_x: array=None, v_y: array=None, | |
) | ||
|
||
if v_x is not None and v_y is not None: | ||
self._model.fit( | ||
history = self._model.fit( | ||
l_x, | ||
l_y, | ||
validation_data=(v_x, v_y), | ||
|
@@ -122,14 +123,17 @@ def fit(self, l_x: array, l_y: array, v_x: array=None, v_y: array=None, | |
restore_best_weights=True | ||
)], | ||
epochs=epochs, | ||
verbose=v | ||
verbose=v, | ||
batch_size=batch_size | ||
) | ||
epochs = len(history.history['loss']) | ||
else: | ||
self._model.fit( | ||
l_x, | ||
l_y, | ||
epochs=epochs, | ||
verbose=v | ||
verbose=v, | ||
batch_size=batch_size | ||
) | ||
|
||
logger.log('debug', 'Training complete after {} epochs'.format(epochs), | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,7 +2,7 @@ | |
# -*- coding: utf-8 -*- | ||
# | ||
# ecnet/server.py | ||
# v.3.2.1 | ||
# v.3.2.2 | ||
# Developed in 2019 by Travis Kessler <[email protected]> | ||
# | ||
# Contains the "Server" class, which handles ECNet project creation, neural | ||
|
@@ -14,15 +14,14 @@ | |
|
||
# ECNet imports | ||
from ecnet.tasks.limit_inputs import limit_rforest | ||
from ecnet.tasks.remove_outliers import remove_outliers | ||
from ecnet.tasks.training import train_project | ||
from ecnet.tasks.tuning import tune_hyperparameters | ||
from ecnet.utils.data_utils import DataFrame, save_results | ||
from ecnet.utils.logging import logger | ||
from ecnet.utils.server_utils import create_project, default_config,\ | ||
get_candidate_path, get_error, get_y, open_config, open_df, open_project,\ | ||
resave_df, resave_model, save_config, save_df, save_project, train_model,\ | ||
use_model, use_project | ||
from ecnet.utils.server_utils import check_config, create_project,\ | ||
default_config, get_candidate_path, get_error, get_y, open_config,\ | ||
open_df, open_project, resave_df, resave_model, save_config, save_df,\ | ||
save_project, train_model, use_model, use_project | ||
|
||
|
||
class Server: | ||
|
@@ -51,6 +50,8 @@ def __init__(self, model_config: str='config.yml', prj_file: str=None, | |
if prj_file is not None: | ||
self._prj_name, self._num_pools, self._num_candidates, self._df,\ | ||
self._cf_file, self._vars = open_project(prj_file) | ||
check_config(self._vars) | ||
self._sets = self._df.package_sets() | ||
logger.log('info', 'Opened project {}'.format(prj_file), | ||
call_loc='INIT') | ||
return | ||
|
@@ -61,6 +62,7 @@ def __init__(self, model_config: str='config.yml', prj_file: str=None, | |
self._vars = {} | ||
try: | ||
self._vars.update(open_config(self._cf_file)) | ||
check_config(self._vars) | ||
except FileNotFoundError: | ||
logger.log('warn', '{} not found, generating default config' | ||
.format(model_config), call_loc='INIT') | ||
|
@@ -109,42 +111,31 @@ def create_project(self, project_name: str, num_pools: int=1, | |
logger.log('debug', 'Number of candidates/pool: {}'.format( | ||
num_candidates), call_loc='PROJECT') | ||
|
||
def remove_outliers(self, leaf_size: int=30, output_filename: str=None): | ||
'''Removes any outliers from the currently-loaded data using | ||
unsupervised outlier detection using local outlier factor | ||
Args: | ||
leaf_size (int): used by nearest-neighbor algorithm as the number | ||
of points at which to switch to brute force | ||
output_filename (str): if not None, database w/o outliers is saved | ||
here | ||
''' | ||
|
||
self._df = remove_outliers(self._df, leaf_size, self._num_processes) | ||
self._sets = self._df.package_sets() | ||
if output_filename is not None: | ||
self._df.save(output_filename) | ||
logger.log('info', 'Resulting database saved to {}'.format( | ||
output_filename), call_loc='OUTLIERS') | ||
|
||
def limit_inputs(self, limit_num: int, num_estimators: int=1000, | ||
output_filename: str=None): | ||
def limit_inputs(self, limit_num: int, num_estimators: int=None, | ||
output_filename: str=None, **kwargs) -> list: | ||
'''Selects `limit_num` influential input parameters using random | ||
forest regression | ||
Args: | ||
limit_num (int): desired number of inputs | ||
num_estimators (int): number of trees in the RFR algorithm | ||
num_estimators (int): number of trees in the RFR algorithm; | ||
defaults to the total number of inputs | ||
output_filename (str): if not None, new limited database is saved | ||
here | ||
**kwargs: any argument accepted by | ||
sklearn.ensemble.RandomForestRegressor | ||
Returns: | ||
list: [(feature, importance), ..., (feature, importance)] | ||
''' | ||
|
||
self._df = limit_rforest( | ||
result = limit_rforest( | ||
self._df, | ||
limit_num, | ||
num_estimators, | ||
self._num_processes | ||
) | ||
self._df.set_inputs([r[0] for r in result]) | ||
self._sets = self._df.package_sets() | ||
if output_filename is not None: | ||
self._df.save(output_filename) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,3 @@ | ||
import ecnet.tasks.limit_inputs | ||
import ecnet.tasks.remove_outliers | ||
import ecnet.tasks.training | ||
import ecnet.tasks.tuning |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,80 +2,65 @@ | |
# -*- coding: utf-8 -*- | ||
# | ||
# ecnet/tasks/limit_inputs.py | ||
# v.3.2.1 | ||
# v.3.2.2 | ||
# Developed in 2019 by Travis Kessler <[email protected]> | ||
# | ||
# Contains functions for selecting influential input parameters | ||
# | ||
|
||
# Stdlib. imports | ||
from copy import deepcopy | ||
|
||
# 3rd party imports | ||
from ditto_lib.itemcollection import ItemCollection | ||
from ditto_lib.tasks.random_forest import random_forest_regressor | ||
from ditto_lib.utils.logging import logger as ditto_logger | ||
from ditto_lib.utils.dataframe import Attribute | ||
from sklearn.ensemble import RandomForestRegressor | ||
from numpy import concatenate, ravel | ||
|
||
# ECNet imports | ||
from ecnet.utils.data_utils import DataFrame | ||
from ecnet.utils.logging import logger | ||
|
||
|
||
def limit_rforest(df: DataFrame, limit_num: int, num_estimators: int=1000, | ||
num_processes: int=1) -> DataFrame: | ||
def limit_rforest(df: DataFrame, limit_num: int, num_estimators: int=None, | ||
num_processes: int=1, **kwargs) -> list: | ||
'''Uses random forest regression to select input parameters | ||
Args: | ||
df (ecnet.utils.data_utils.DataFrame): loaded data | ||
limit_num (int): desired number of input parameters | ||
num_estimators (int): number of trees used by RFR algorithm | ||
num_processes (int): number of parallel jobs for RFR algorithm | ||
**kwargs: any argument accepted by | ||
sklearn.ensemble.RandomForestRegressor | ||
Returns: | ||
ecnet.utils.data_utils.DataFrame: limited data | ||
list: [(feature, importance), ..., (feature, importance)] | ||
''' | ||
|
||
logger.log('info', 'Finding {} most influential input parameters' | ||
.format(limit_num), call_loc='LIMIT') | ||
|
||
pd = df.package_sets() | ||
X = concatenate((pd.learn_x, pd.valid_x, pd.test_x)) | ||
y = ravel(concatenate((pd.learn_y, pd.valid_y, pd.test_y))) | ||
|
||
if num_estimators is None: | ||
num_estimators = len(X[0]) | ||
|
||
logger.log('debug', 'Number of estimators: {}'.format(num_estimators), | ||
call_loc='LIMIT') | ||
|
||
ditto_logger.stream_level = logger.stream_level | ||
if logger.file_level != 'disable': | ||
ditto_logger.log_dir = logger.log_dir | ||
ditto_logger.file_level = logger.file_level | ||
ditto_logger.default_call_loc('LIMIT') | ||
item_collection = ItemCollection(df._filename) | ||
for inp_name in df._input_names: | ||
item_collection.add_attribute(Attribute(inp_name)) | ||
for pt in df.data_points: | ||
item_collection.add_item( | ||
pt.id, | ||
deepcopy([getattr(pt, i) for i in df._input_names]) | ||
) | ||
for tar_name in df._target_names: | ||
item_collection.add_attribute(Attribute(tar_name, is_descriptor=False)) | ||
for pt in df.data_points: | ||
target_vals = [getattr(pt, t) for t in df._target_names] | ||
for idx, tar in enumerate(target_vals): | ||
item_collection.set_item_attribute( | ||
pt.id, tar, df._target_names[idx] | ||
) | ||
item_collection.strip() | ||
params = [param[0] for param in random_forest_regressor( | ||
item_collection.dataframe, | ||
target_attribute=df._target_names[0], | ||
n_components=limit_num, | ||
regr = RandomForestRegressor( | ||
n_jobs=num_processes, | ||
n_estimators=num_estimators, | ||
n_jobs=num_processes | ||
)] | ||
for idx, param in enumerate(params): | ||
for tn in df._target_names: | ||
if tn == param: | ||
del params[idx] | ||
break | ||
|
||
logger.log('debug', 'Selected parameters: {}'.format(params), | ||
call_loc='LIMIT') | ||
df.set_inputs(params) | ||
return df | ||
**kwargs | ||
) | ||
regr.fit(X, y) | ||
importances = regr.feature_importances_ | ||
result = [] | ||
for idx, name in enumerate(df._input_names): | ||
result.append((name, importances[idx])) | ||
result = sorted(result, key=lambda t: t[1], reverse=True)[:limit_num] | ||
logger.log('debug', 'Selected parameters: {}'.format( | ||
[r[0] for r in result] | ||
), call_loc='LIMIT') | ||
return result |
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,7 +2,7 @@ | |
# -*- coding: utf-8 -*- | ||
# | ||
# ecnet/tasks/training.py | ||
# v.3.2.1 | ||
# v.3.2.2 | ||
# Developed in 2019 by Travis Kessler <[email protected]> | ||
# | ||
# Contains function for project training (multiprocessed training) | ||
|
Oops, something went wrong.