-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #31 from ECRL/dev
Updates to DataFrame, DataPoint structure
- Loading branch information
Showing
63 changed files
with
4,301 additions
and
210 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,2 @@ | ||
from ecnet.server import Server | ||
__version__ = '3.0.1' | ||
__version__ = '3.1.0' |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,7 +2,7 @@ | |
# -*- coding: utf-8 -*- | ||
# | ||
# ecnet/models/mlp.py | ||
# v.3.0.1 | ||
# v.3.1.0 | ||
# Developed in 2019 by Travis Kessler <[email protected]> | ||
# | ||
# Contains the "MultilayerPerceptron" (feed-forward neural network) class | ||
|
@@ -14,7 +14,7 @@ | |
import sys | ||
|
||
# 3rd party imports | ||
from tensorflow import get_default_graph | ||
from tensorflow import get_default_graph, logging | ||
stderr = sys.stderr | ||
sys.stderr = open(devnull, 'w') | ||
from keras.backend import clear_session, reset_uids | ||
|
@@ -29,6 +29,7 @@ | |
from ecnet.utils.logging import logger | ||
|
||
environ['TF_CPP_MIN_LOG_LEVEL'] = '2' | ||
logging.set_verbosity(logging.ERROR) | ||
|
||
H5_EXT = compile(r'.*\.h5', flags=IGNORECASE) | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,14 +2,14 @@ | |
# -*- coding: utf-8 -*- | ||
# | ||
# ecnet/server.py | ||
# v.3.0.1 | ||
# v.3.1.0 | ||
# Developed in 2019 by Travis Kessler <[email protected]> | ||
# | ||
# Contains the "Server" class, which handles ECNet project creation, neural | ||
# network model creation, data hand-off to models, prediction error | ||
# calculation, input parameter selection, hyperparameter tuning. | ||
# | ||
# For example scripts, refer to https://github.com/ecrl/ecnet/examples | ||
# For example scripts, refer to https://ecnet.readthedocs.io/en/latest/ | ||
# | ||
|
||
# ECNet imports | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,7 +2,7 @@ | |
# -*- coding: utf-8 -*- | ||
# | ||
# ecnet/tasks/limit_inputs.py | ||
# v.3.0.1 | ||
# v.3.1.0 | ||
# Developed in 2019 by Travis Kessler <[email protected]> | ||
# | ||
# Contains functions for selecting influential input parameters | ||
|
@@ -39,27 +39,31 @@ def limit_rforest(df, limit_num, num_estimators=1000, num_processes=1): | |
ditto_logger.file_level = logger.file_level | ||
ditto_logger.default_call_loc('LIMIT') | ||
item_collection = ItemCollection(df._filename) | ||
for inp_name in df.input_names: | ||
for inp_name in df._input_names: | ||
item_collection.add_attribute(Attribute(inp_name)) | ||
for pt in df.data_points: | ||
item_collection.add_item(pt.id, deepcopy(pt.inputs)) | ||
for tar_name in df.target_names: | ||
item_collection.add_item( | ||
pt.id, | ||
deepcopy([getattr(pt, i) for i in df._input_names]) | ||
) | ||
for tar_name in df._target_names: | ||
item_collection.add_attribute(Attribute(tar_name, is_descriptor=False)) | ||
for pt in df.data_points: | ||
for idx, tar in enumerate(pt.targets): | ||
target_vals = [getattr(pt, t) for t in df._target_names] | ||
for idx, tar in enumerate(target_vals): | ||
item_collection.set_item_attribute( | ||
pt.id, tar, df.target_names[idx] | ||
pt.id, tar, df._target_names[idx] | ||
) | ||
item_collection.strip() | ||
params = [param[0] for param in random_forest_regressor( | ||
item_collection.dataframe, | ||
target_attribute=df.target_names[0], | ||
target_attribute=df._target_names[0], | ||
n_components=limit_num, | ||
n_estimators=num_estimators, | ||
n_jobs=num_processes | ||
)] | ||
for idx, param in enumerate(params): | ||
for tn in df.target_names: | ||
for tn in df._target_names: | ||
if tn == param: | ||
del params[idx] | ||
break | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,7 +2,7 @@ | |
# -*- coding: utf-8 -*- | ||
# | ||
# ecnet/tasks/remove_outliers.py | ||
# v.3.0.1 | ||
# v.3.1.0 | ||
# Developed in 2019 by Travis Kessler <[email protected]> | ||
# | ||
# Contains function for removing outliers from ECNet DataFrame | ||
|
@@ -40,10 +40,13 @@ def remove_outliers(df, leaf_size=40, num_processes=1): | |
ditto_logger.file_level = logger.file_level | ||
ditto_logger.default_call_loc('OUTLIERS') | ||
item_collection = ItemCollection(df._filename) | ||
for inp_name in df.input_names: | ||
for inp_name in df._input_names: | ||
item_collection.add_attribute(Attribute(inp_name)) | ||
for pt in df.data_points: | ||
item_collection.add_item(pt.id, deepcopy(pt.inputs)) | ||
item_collection.add_item( | ||
pt.id, | ||
deepcopy([getattr(pt, i) for i in df._input_names]) | ||
) | ||
item_collection.strip() | ||
outliers = local_outlier_factor( | ||
item_collection.dataframe, | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,7 +2,7 @@ | |
# -*- coding: utf-8 -*- | ||
# | ||
# ecnet/tasks/tuning.py | ||
# v.3.0.1 | ||
# v.3.1.0 | ||
# Developed in 2019 by Travis Kessler <[email protected]> | ||
# | ||
# Contains functions/fitness functions for tuning hyperparameters | ||
|
@@ -117,6 +117,7 @@ def tune_fitness_function(params, **kwargs): | |
vars['epsilon'] = params[3] | ||
vars['learning_rate'] = params[4] | ||
vars['hidden_layers'] = kwargs['hidden_layers'] | ||
vars['epochs'] = 2000 | ||
for l_idx in range(len(vars['hidden_layers'])): | ||
vars['hidden_layers'][l_idx][0] = params[5 + l_idx] | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,7 +2,7 @@ | |
# -*- coding: utf-8 -*- | ||
# | ||
# ecnet/tools/conversions.py | ||
# v.3.0.1 | ||
# v.3.1.0 | ||
# Developed in 2019 by Travis Kessler <[email protected]> | ||
# | ||
# Contains functions for converting various chemical file formats | ||
|
@@ -30,18 +30,14 @@ def get_smiles(name): | |
'''Queries PubChemPy for SMILES string for supplied molecule | ||
Args: | ||
name (str): name of the molecule | ||
name (str): name of the molecule (IUPAC, CAS, etc.) | ||
Returns: | ||
str or None: if molecule found, returns first identifying SMILES, | ||
else None | ||
''' | ||
|
||
smiles = [m.isomeric_smiles for m in get_compounds(name, 'name')] | ||
if len(smiles) == 0: | ||
return '' | ||
else: | ||
return smiles[0] | ||
return [m.isomeric_smiles for m in get_compounds(name, 'name')] | ||
|
||
|
||
def smiles_to_mdl(smiles_file, mdl_file): | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,7 +2,7 @@ | |
# -*- coding: utf-8 -*- | ||
# | ||
# ecnet/tools/database.py | ||
# v.3.0.1 | ||
# v.3.1.0 | ||
# Developed in 2019 by Travis Kessler <[email protected]> | ||
# | ||
# Contains functions for creating ECNet-formatted databases | ||
|
@@ -20,7 +20,8 @@ | |
|
||
def create_db(input_txt, output_name, id_prefix='', targets=None, form='name', | ||
smiles_file='mols.smi', mdl_file='mols.mdl', | ||
desc_file='descriptors.csv', clean_up=True, fingerprints=False): | ||
desc_file='descriptors.csv', clean_up=True, fingerprints=False, | ||
extra_strings={}): | ||
'''Create an ECNet-formatted database from either molecule names or SMILES | ||
strings | ||
|
@@ -39,18 +40,32 @@ def create_db(input_txt, output_name, id_prefix='', targets=None, form='name', | |
processing except for the input text files and output database | ||
fingerprints (bool): if True, generates molecular fingerprints instead | ||
of QSPR descriptors | ||
extra_strings (dict): if additional STRING headers are desired, | ||
specify them with {'Str Name': [str1, str2 ...]} | ||
''' | ||
|
||
input_data = _read_txt(input_txt) | ||
if form == 'name': | ||
input_names = deepcopy(input_data) | ||
for i, d in enumerate(input_data): | ||
input_data[i] = get_smiles(d) | ||
smiles = get_smiles(d) | ||
if len(smiles) == 0: | ||
raise IndexError('SMILES not found for {}'.format(d)) | ||
input_data[i] = smiles[0] | ||
with open(smiles_file, 'w') as smi_file: | ||
for d in input_data: | ||
smi_file.write(d + '\n') | ||
elif form == 'smiles': | ||
input_names = ['' for _ in range(len(input_data))] | ||
if 'Compound Name' not in extra_strings: | ||
input_names = ['' for _ in range(len(input_data))] | ||
else: | ||
input_names = extra_strings['Compound Name'] | ||
if len(input_names) != len(input_data): | ||
raise IndexError('Number of supplied names does not equal the ' | ||
'number of supplied molecules: {}, {}'.format( | ||
len(input_names), len(input_data) | ||
)) | ||
del extra_strings['Compound Name'] | ||
smiles_file = input_txt | ||
|
||
else: | ||
|
@@ -81,27 +96,36 @@ def create_db(input_txt, output_name, id_prefix='', targets=None, form='name', | |
is_valid = True | ||
for row in desc[1:]: | ||
if row[ds] == '' or row[ds] is None: | ||
row[ds] = 0 | ||
is_valid = False | ||
break | ||
if is_valid: | ||
valid_keys.append(ds) | ||
desc_keys = valid_keys | ||
|
||
rows = [] | ||
type_row = ['DATAID', 'ASSIGNMENT', 'STRING', 'STRING', 'TARGET'] | ||
|
||
type_row = ['DATAID', 'ASSIGNMENT', 'STRING', 'STRING'] | ||
type_row.extend(['STRING' for string in extra_strings]) | ||
type_row.append('TARGET') | ||
type_row.extend(['INPUT' for _ in range(len(desc_keys))]) | ||
title_row = ['DATAID', 'ASSIGNMENT', 'Compound Name', 'SMILES', 'Target'] | ||
title_row.extend(desc_keys) | ||
rows.append(type_row) | ||
|
||
title_row = ['DATAID', 'ASSIGNMENT', 'Compound Name', 'SMILES'] | ||
title_row.extend([string for string in extra_strings]) | ||
title_row.append('TARGET') | ||
title_row.extend(desc_keys) | ||
rows.append(title_row) | ||
|
||
for idx, name in enumerate(input_names): | ||
|
||
mol_row = [ | ||
'{}'.format(id_prefix) + '%04d' % (idx + 1), | ||
'L', | ||
name, | ||
input_data[idx], | ||
target_data[idx] | ||
input_data[idx] | ||
] | ||
mol_row.extend([extra_strings[s][idx] for s in extra_strings]) | ||
mol_row.append(target_data[idx]) | ||
mol_row.extend([desc[idx][k] for k in desc_keys]) | ||
rows.append(mol_row) | ||
|
||
|
Oops, something went wrong.