Skip to content

Commit

Permalink
Merge pull request #34 from ECRL/dev
Browse files Browse the repository at this point in the history
Type checking, improved unit testing
  • Loading branch information
tjkessler authored Jun 3, 2019
2 parents 34f96a5 + 0fa367a commit 2083718
Show file tree
Hide file tree
Showing 73 changed files with 746 additions and 3,915 deletions.
2 changes: 1 addition & 1 deletion docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
# The short X.Y version
version = ''
# The full version, including alpha/beta/rc tags
release = '3.0.0'
release = '3.1.2'


# -- General configuration ---------------------------------------------------
Expand Down
4 changes: 2 additions & 2 deletions docs/usage/installation.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Installation

### Prerequisites
- Have Python 3.5/3.6 installed
- Have Python 3.7 installed
- Have the ability to install Python packages

### Install via pip
Expand All @@ -15,7 +15,7 @@ Alternatively, in a Windows or virtualenv environment:
pip install ecnet
```

Note: if multiple Python releases are installed on your system (e.g. 2.7 and 3.6), you may need to execute the correct version of pip. For Python 3.6, change **"pip install ecnet"** to **"pip3 install ecnet"**.
Note: if multiple Python releases are installed on your system (e.g. 2.7 and 3.7), you may need to execute the correct version of pip. For Python 3.7, change **"pip install ecnet"** to **"pip3 install ecnet"**.

To update your version of ECNet to the latest release version, use:
```
Expand Down
2 changes: 1 addition & 1 deletion ecnet/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
from ecnet.server import Server
__version__ = '3.1.1'
__version__ = '3.1.2'
20 changes: 12 additions & 8 deletions ecnet/models/mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# -*- coding: utf-8 -*-
#
# ecnet/models/mlp.py
# v.3.1.1
# v.3.1.2
# Developed in 2019 by Travis Kessler <[email protected]>
#
# Contains the "MultilayerPerceptron" (feed-forward neural network) class
Expand All @@ -15,6 +15,7 @@

# 3rd party imports
from tensorflow import get_default_graph, logging
from numpy import array
stderr = sys.stderr
sys.stderr = open(devnull, 'w')
from keras.backend import clear_session, reset_uids
Expand All @@ -36,7 +37,7 @@

class MultilayerPerceptron:

def __init__(self, filename='model.h5'):
def __init__(self, filename: str='model.h5'):
'''MultilayerPerceptron object: fits neural network to supplied inputs
and targets
Expand All @@ -54,7 +55,8 @@ def __init__(self, filename='model.h5'):
clear_session()
self._model = Sequential(name=filename.lower().replace('.h5', ''))

def add_layer(self, num_neurons, activation, input_dim=None):
def add_layer(self, num_neurons: int, activation: str,
input_dim: int=None):
'''Adds a fully-connected layer to the model
Args:
Expand All @@ -71,8 +73,10 @@ def add_layer(self, num_neurons, activation, input_dim=None):
input_shape=(input_dim,)
))

def fit(self, l_x, l_y, v_x=None, v_y=None, epochs=1500, lr=0.001,
beta_1=0.9, beta_2=0.999, epsilon=None, decay=0.0, v=0):
def fit(self, l_x: array, l_y: array, v_x: array=None, v_y: array=None,
epochs: int=1500, lr: float=0.001, beta_1: float=0.9,
beta_2: float=0.999, epsilon: float=0.0000001, decay: float=0.0,
v: int=0):
'''Fits neural network to supplied inputs and targets
Args:
Expand Down Expand Up @@ -133,7 +137,7 @@ def fit(self, l_x, l_y, v_x=None, v_y=None, epochs=1500, lr=0.001,
logger.log('debug', 'Training complete after {} epochs'.format(epochs),
call_loc='MLP')

def use(self, x):
def use(self, x: array) -> array:
'''Uses neural network to predict values for supplied data
Args:
Expand All @@ -146,7 +150,7 @@ def use(self, x):
with get_default_graph().as_default():
return self._model.predict(x)

def save(self, filename=None):
def save(self, filename: str=None):
'''Saves neural network to .h5 file
filename (str): if None, uses MultilayerPerceptron._filename;
Expand All @@ -165,7 +169,7 @@ def save(self, filename=None):
logger.log('debug', 'Model saved to {}'.format(filename),
call_loc='MLP')

def load(self, filename=None):
def load(self, filename: str=None):
'''Loads neural network from .h5 file
Args:
Expand Down
39 changes: 22 additions & 17 deletions ecnet/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# -*- coding: utf-8 -*-
#
# ecnet/server.py
# v.3.1.1
# v.3.1.2
# Developed in 2019 by Travis Kessler <[email protected]>
#
# Contains the "Server" class, which handles ECNet project creation, neural
Expand Down Expand Up @@ -32,8 +32,8 @@

class Server:

def __init__(self, model_config='config.yml', prj_file=None,
num_processes=1):
def __init__(self, model_config: str='config.yml', prj_file: str=None,
num_processes: int=1):
'''Server object: handles data loading, model creation, data-to-model
hand-off, data input parameter selection, hyperparameter tuning
Expand Down Expand Up @@ -72,7 +72,8 @@ def __init__(self, model_config='config.yml', prj_file=None,
self._vars = default_config()
save_config(self._vars, self._cf_file)

def load_data(self, filename, random=False, split=None, normalize=False):
def load_data(self, filename: str, random: bool=False, split: list=None,
normalize: bool=False):
'''Loads data from an ECNet-formatted CSV database
Args:
Expand All @@ -92,7 +93,8 @@ def load_data(self, filename, random=False, split=None, normalize=False):
self._df.create_sets(random, split)
self._sets = self._df.package_sets()

def create_project(self, project_name, num_pools=1, num_candidates=1):
def create_project(self, project_name: str, num_pools: int=1,
num_candidates: int=1):
'''Creates folder hierarchy for a new project
Args:
Expand All @@ -116,7 +118,7 @@ def create_project(self, project_name, num_pools=1, num_candidates=1):
logger.log('debug', 'Number of candidates/pool: {}'.format(
num_candidates), call_loc='PROJECT')

def remove_outliers(self, leaf_size=30, output_filename=None):
def remove_outliers(self, leaf_size: int=30, output_filename: str=None):
'''Removes any outliers from the currently-loaded data using
unsupervised outlier detection using local outlier factor
Expand All @@ -136,8 +138,8 @@ def remove_outliers(self, leaf_size=30, output_filename=None):
logger.log('info', 'Resulting database saved to {}'.format(
output_filename), call_loc='OUTLIERS')

def limit_inputs(self, limit_num, num_estimators=1000,
output_filename=None):
def limit_inputs(self, limit_num: int, num_estimators: int=1000,
output_filename: str=None):
'''Selects `limit_num` influential input parameters using random
forest regression
Expand All @@ -163,9 +165,10 @@ def limit_inputs(self, limit_num, num_estimators=1000,
logger.log('info', 'Resulting database saved to {}'.format(
output_filename), call_loc='LIMIT')

def tune_hyperparameters(self, num_employers, num_iterations,
shuffle=None, split=None, validate=True,
eval_set=None, eval_fn='rmse'):
def tune_hyperparameters(self, num_employers: int, num_iterations: int,
shuffle: bool=None, split: list=None,
validate: bool=True, eval_set: str=None,
eval_fn: str='rmse'):
'''Tunes neural network learning hyperparameters using an artificial
bee colony algorithm; tuned hyperparameters are saved to Server's
model configuration file
Expand Down Expand Up @@ -206,8 +209,9 @@ def tune_hyperparameters(self, num_employers, num_iterations,
)
save_config(self._vars, self._cf_file)

def train(self, shuffle=None, split=None, retrain=False,
validate=False, selection_set=None, selection_fn='rmse'):
def train(self, shuffle: str=None, split: list=None, retrain: bool=False,
validate: bool=False, selection_set: str=None,
selection_fn: str='rmse'):
'''Trains neural network(s) using currently-loaded data; single NN if
no project is created, all candidates if created
Expand Down Expand Up @@ -307,7 +311,7 @@ def train(self, shuffle=None, split=None, retrain=False,
pool_fp.replace('model.h5', 'data.d')
)

def use(self, dset=None, output_filename=None):
def use(self, dset: str=None, output_filename: str=None):
'''Uses trained neural network(s) to predict for specified set; single
NN if no project created, best pool candidates if created
Expand Down Expand Up @@ -341,7 +345,7 @@ def use(self, dset=None, output_filename=None):
call_loc='USE')
return results

def errors(self, *args, dset=None):
def errors(self, *args, dset: str=None):
'''Obtains various errors for specified set
Args:
Expand All @@ -365,7 +369,8 @@ def errors(self, *args, dset=None):
logger.log('debug', 'Errors: {}'.format(errors), call_loc='ERRORS')
return errors

def save_project(self, filename=None, clean_up=True, del_candidates=False):
def save_project(self, filename: str=None, clean_up: bool=True,
del_candidates: bool=False):
'''Saves current state of project to a .prj file
Args:
Expand Down Expand Up @@ -400,7 +405,7 @@ def save_project(self, filename=None, clean_up=True, del_candidates=False):
logger.log('info', 'Project saved to {}'.format(save_path),
call_loc='PROJECT')

def _open_project(self, prj_file):
def _open_project(self, prj_file: str):
'''Private method: if project file specified on Server.__init__, loads
the project
Expand Down
6 changes: 4 additions & 2 deletions ecnet/tasks/limit_inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# -*- coding: utf-8 -*-
#
# ecnet/tasks/limit_inputs.py
# v.3.1.1
# v.3.1.2
# Developed in 2019 by Travis Kessler <[email protected]>
#
# Contains functions for selecting influential input parameters
Expand All @@ -17,10 +17,12 @@
from ditto_lib.utils.dataframe import Attribute

# ECNet imports
from ecnet.utils.data_utils import DataFrame
from ecnet.utils.logging import logger


def limit_rforest(df, limit_num, num_estimators=1000, num_processes=1):
def limit_rforest(df: DataFrame, limit_num: int, num_estimators: int=1000,
num_processes: int=1) -> DataFrame:
'''Uses random forest regression to select input parameters
Args:
Expand Down
6 changes: 4 additions & 2 deletions ecnet/tasks/remove_outliers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# -*- coding: utf-8 -*-
#
# ecnet/tasks/remove_outliers.py
# v.3.1.1
# v.3.1.2
# Developed in 2019 by Travis Kessler <[email protected]>
#
# Contains function for removing outliers from ECNet DataFrame
Expand All @@ -18,10 +18,12 @@
from ditto_lib.utils.dataframe import Attribute

# ECNet imports
from ecnet.utils.data_utils import DataFrame
from ecnet.utils.logging import logger


def remove_outliers(df, leaf_size=40, num_processes=1):
def remove_outliers(df: DataFrame, leaf_size: int=40,
num_processes: int=1) -> DataFrame:
'''Unsupervised outlier detection using local outlier factor
Args:
Expand Down
15 changes: 9 additions & 6 deletions ecnet/tasks/tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# -*- coding: utf-8 -*-
#
# ecnet/tasks/tuning.py
# v.3.1.1
# v.3.1.2
# Developed in 2019 by Travis Kessler <[email protected]>
#
# Contains functions/fitness functions for tuning hyperparameters
Expand All @@ -12,13 +12,16 @@
from ecabc.abc import ABC

# ECNet imports
from ecnet.utils.data_utils import DataFrame
from ecnet.utils.logging import logger
from ecnet.utils.server_utils import default_config, train_model


def tune_hyperparameters(df, vars, num_employers, num_iterations,
num_processes=1, shuffle=None, split=None,
validate=True, eval_set=None, eval_fn='rmse'):
def tune_hyperparameters(df: DataFrame, vars: dict, num_employers: int,
num_iterations: int, num_processes: int=1,
shuffle: str=None, split: list=None,
validate: bool=True, eval_set: str=None,
eval_fn: str='rmse') -> dict:
'''Tunes neural network learning/architecture hyperparameters
Args:
Expand All @@ -27,7 +30,7 @@ def tune_hyperparameters(df, vars, num_employers, num_iterations,
num_employers (int): number of employer bees
num_iterations (int): number of search cycles for the colony
num_processes (int): number of parallel processes to utilize
shuffle (bool): if True, shuffles L/V/T data for all evals
shuffle (str): shuffles `train` or `all` sets if not None
split (list): if shuffle is True, [learn%, valid%, test%]
validate (bool): if True, uses periodic validation; otherwise, no
eval_set (str): set used to evaluate bee performance; `learn`, `valid`,
Expand Down Expand Up @@ -99,7 +102,7 @@ def tune_hyperparameters(df, vars, num_employers, num_iterations,
return vars


def tune_fitness_function(params, **kwargs):
def tune_fitness_function(params: dict, **kwargs):
'''Fitness function used by ABC
Args:
Expand Down
15 changes: 8 additions & 7 deletions ecnet/tools/conversions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# -*- coding: utf-8 -*-
#
# ecnet/tools/conversions.py
# v.3.1.0
# v.3.1.2
# Developed in 2019 by Travis Kessler <[email protected]>
#
# Contains functions for converting various chemical file formats
Expand All @@ -26,21 +26,21 @@
)


def get_smiles(name):
def get_smiles(name: str) -> list:
'''Queries PubChemPy for SMILES string for supplied molecule
Args:
name (str): name of the molecule (IUPAC, CAS, etc.)
Returns:
str or None: if molecule found, returns first idenitifying SMILES,
else None
list: list of all SMILES strings representing the given molecule
'''

return [m.isomeric_smiles for m in get_compounds(name, 'name')]


def smiles_to_descriptors(smiles_file, descriptors_csv, fingerprints=False):
def smiles_to_descriptors(smiles_file: str, descriptors_csv: str,
fingerprints: bool=False) -> list:
'''Generates QSPR descriptors from supplied SMILES file using
PaDEL-Descriptor
Expand Down Expand Up @@ -111,7 +111,7 @@ def smiles_to_descriptors(smiles_file, descriptors_csv, fingerprints=False):
return [row for row in reader]


def smiles_to_mdl(smiles_file, mdl_file):
def smiles_to_mdl(smiles_file: str, mdl_file: str):
'''Invoke Open Babel to generate an MDL file containing all supplied
molecules; requires Open Babel to be installed externally
Expand Down Expand Up @@ -151,7 +151,8 @@ def smiles_to_mdl(smiles_file, mdl_file):
continue


def mdl_to_descriptors(mdl_file, descriptors_csv, fingerprints=False):
def mdl_to_descriptors(mdl_file: str, descriptors_csv: str,
fingerprints: bool=False) -> list:
'''Generates QSPR descriptors from supplied MDL file using
PaDEL-Descriptor
Expand Down
Loading

0 comments on commit 2083718

Please sign in to comment.