Skip to content

Commit

Permalink
Merge pull request #29 from ECRL/dev
Browse files: browse the repository at this point in the history
Bug fixes, database creation improvements
  • Loading branch information
tjkessler authored Mar 27, 2019
2 parents 36848dc + c0eda14 commit f92b742
Show file tree
Hide file tree
Showing 15 changed files with 73 additions and 35 deletions.
2 changes: 1 addition & 1 deletion ecnet/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
from ecnet.server import Server
__version__ = '3.0.0'
__version__ = '3.0.1'
2 changes: 1 addition & 1 deletion ecnet/models/mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# -*- coding: utf-8 -*-
#
# ecnet/models/mlp.py
# v.3.0.0
# v.3.0.1
# Developed in 2019 by Travis Kessler <[email protected]>
#
# Contains the "MultilayerPerceptron" (feed-forward neural network) class
Expand Down
9 changes: 6 additions & 3 deletions ecnet/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# -*- coding: utf-8 -*-
#
# ecnet/server.py
# v.3.0.0
# v.3.0.1
# Developed in 2019 by Travis Kessler <[email protected]>
#
# Contains the "Server" class, which handles ECNet project creation, neural
Expand All @@ -23,9 +23,9 @@
save_config, save_df, train_model, use_model

# Stdlib imports
from os import listdir, makedirs, path, walk
from os import listdir, makedirs, name, path, walk
from operator import itemgetter
from multiprocessing import Pool
from multiprocessing import Pool, set_start_method
from shutil import rmtree
from zipfile import ZipFile, ZIP_DEFLATED

Expand Down Expand Up @@ -53,6 +53,9 @@ def __init__(self, model_config='config.yml', prj_file=None,

self._num_processes = num_processes

if name != 'nt':
set_start_method('spawn', force=True)

if prj_file is not None:
self._open_project(prj_file)
return
Expand Down
3 changes: 2 additions & 1 deletion ecnet/tasks/limit_inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# -*- coding: utf-8 -*-
#
# ecnet/tasks/limit_inputs.py
# v.3.0.0
# v.3.0.1
# Developed in 2019 by Travis Kessler <[email protected]>
#
# Contains functions for selecting influential input parameters
Expand Down Expand Up @@ -37,6 +37,7 @@ def limit_rforest(df, limit_num, num_estimators=1000, num_processes=1):
if logger.file_level != 'disable':
ditto_logger.log_dir = logger.log_dir
ditto_logger.file_level = logger.file_level
ditto_logger.default_call_loc('LIMIT')
item_collection = ItemCollection(df._filename)
for inp_name in df.input_names:
item_collection.add_attribute(Attribute(inp_name))
Expand Down
3 changes: 2 additions & 1 deletion ecnet/tasks/remove_outliers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# -*- coding: utf-8 -*-
#
# ecnet/tasks/remove_outliers.py
# v.3.0.0
# v.3.0.1
# Developed in 2019 by Travis Kessler <[email protected]>
#
# Contains function for removing outliers from ECNet DataFrame
Expand Down Expand Up @@ -38,6 +38,7 @@ def remove_outliers(df, leaf_size=40, num_processes=1):
if logger.file_level != 'disable':
ditto_logger.log_dir = logger.log_dir
ditto_logger.file_level = logger.file_level
ditto_logger.default_call_loc('OUTLIERS')
item_collection = ItemCollection(df._filename)
for inp_name in df.input_names:
item_collection.add_attribute(Attribute(inp_name))
Expand Down
5 changes: 3 additions & 2 deletions ecnet/tasks/tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# -*- coding: utf-8 -*-
#
# ecnet/tasks/tuning.py
# v.3.0.0
# v.3.0.1
# Developed in 2019 by Travis Kessler <[email protected]>
#
# Contains functions/fitness functions for tuning hyperparameters
Expand Down Expand Up @@ -73,6 +73,7 @@ def tune_hyperparameters(df, vars, num_employers, num_iterations,
if logger.file_level != 'disable':
abc._logger.log_dir = logger.log_dir
abc._logger.file_level = logger.file_level
abc._logger.default_call_loc('TUNE')
abc.create_employers()
for i in range(num_iterations):
logger.log('info', 'Iteration {}'.format(i + 1), call_loc='TUNE')
Expand All @@ -92,7 +93,7 @@ def tune_hyperparameters(df, vars, num_employers, num_iterations,
vars['beta_2'] = params[1]
vars['decay'] = params[2]
vars['epsilon'] = params[3]
vars['learning_date'] = params[4]
vars['learning_rate'] = params[4]
for l_idx in range(len(vars['hidden_layers'])):
vars['hidden_layers'][l_idx][0] = params[5 + l_idx]
return vars
Expand Down
51 changes: 34 additions & 17 deletions ecnet/tools/conversions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# -*- coding: utf-8 -*-
#
# ecnet/tools/conversions.py
# v.3.0.0
# v.3.0.1
# Developed in 2019 by Travis Kessler <[email protected]>
#
# Contains functions for converting various chemical file formats
Expand Down Expand Up @@ -39,7 +39,7 @@ def get_smiles(name):

smiles = [m.isomeric_smiles for m in get_compounds(name, 'name')]
if len(smiles) == 0:
raise IndexError('PubChem entry not found for {}'.format(name))
return ''
else:
return smiles[0]

Expand Down Expand Up @@ -84,13 +84,15 @@ def smiles_to_mdl(smiles_file, mdl_file):
continue


def mdl_to_descriptors(mdl_file, descriptors_csv):
def mdl_to_descriptors(mdl_file, descriptors_csv, fingerprints=False):
'''Generates QSPR descriptors from supplied MDL file using
PaDEL-Descriptor
Args:
mdl_file (str): path to source MDL file
descriptors_csv (str): path to resulting CSV file w/ descriptors
fingerprints (bool): if True, generates molecular fingerprints instead
of QSPR descriptors
Returns:
list: list of dicts, where each dict is a molecule populated with
Expand All @@ -105,20 +107,35 @@ def mdl_to_descriptors(mdl_file, descriptors_csv):
dn = open(devnull, 'w')
for attempt in range(3):
try:
call([
'java',
'-jar',
_PADEL_PATH,
'-2d',
'-3d',
'-retainorder',
'-retain3d',
'-dir',
mdl_file,
'-file',
descriptors_csv
], stdout=dn, stderr=dn, timeout=3600)
break
if fingerprints:
call([
'java',
'-jar',
_PADEL_PATH,
'-fingerprints',
'-retainorder',
'-retain3d',
'-dir',
mdl_file,
'-file',
descriptors_csv
], stdout=dn, stderr=dn, timeout=600)
break
else:
call([
'java',
'-jar',
_PADEL_PATH,
'-2d',
'-3d',
'-retainorder',
'-retain3d',
'-dir',
mdl_file,
'-file',
descriptors_csv
], stdout=dn, stderr=dn, timeout=600)
break
except Exception as e:
if attempt == 2:
raise e
Expand Down
8 changes: 5 additions & 3 deletions ecnet/tools/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# -*- coding: utf-8 -*-
#
# ecnet/tools/database.py
# v.3.0.0
# v.3.0.1
# Developed in 2019 by Travis Kessler <[email protected]>
#
# Contains functions for creating ECNet-formatted databases
Expand All @@ -20,7 +20,7 @@

def create_db(input_txt, output_name, id_prefix='', targets=None, form='name',
smiles_file='mols.smi', mdl_file='mols.mdl',
desc_file='descriptors.csv', clean_up=True):
desc_file='descriptors.csv', clean_up=True, fingerprints=False):
'''Create an ECNet-formatted database from either molecule names or SMILES
strings
Expand All @@ -37,6 +37,8 @@ def create_db(input_txt, output_name, id_prefix='', targets=None, form='name',
desc_file (str): name of descriptors file generated by PaDEL-Descriptor
clean_up (bool): if True, cleans up all files generated during
processing except for the input text files and output database
fingerprints (bool): if True, generates molecular fingerprints instead
of QSPR descriptors
'''

input_data = _read_txt(input_txt)
Expand Down Expand Up @@ -67,7 +69,7 @@ def create_db(input_txt, output_name, id_prefix='', targets=None, form='name',
target_data = [0 for _ in range(len(input_data))]

smiles_to_mdl(smiles_file, mdl_file)
desc = mdl_to_descriptors(mdl_file, desc_file)
desc = mdl_to_descriptors(mdl_file, desc_file, fingerprints)
desc_keys = list(desc[0].keys())
try:
desc_keys.remove('Name')
Expand Down
2 changes: 1 addition & 1 deletion ecnet/tools/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# -*- coding: utf-8 -*-
#
# ecnet/tools/project.py
# v.3.0.0
# v.3.0.1
# Developed in 2019 by Travis Kessler <[email protected]>
#
# Contains functions for predicting data using pre-existing .prj files
Expand Down
2 changes: 1 addition & 1 deletion ecnet/utils/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# -*- coding: utf-8 -*-
#
# ecnet/utils/data_utils.py
# v.3.0.0
# v.3.0.1
# Developed in 2019 by Travis Kessler <[email protected]>
#
# Contains functions/classes for loading data, saving data, saving results
Expand Down
2 changes: 1 addition & 1 deletion ecnet/utils/error_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# -*- coding: utf-8 -*-
#
# ecnet/utils/error_utils.py
# v.3.0.0
# v.3.0.1
# Developed in 2019 by Travis Kessler <[email protected]>
#
# Contains functions for error calculations
Expand Down
2 changes: 1 addition & 1 deletion ecnet/utils/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# -*- coding: utf-8 -*-
#
# ecnet/utils/logging.py
# v.3.0.0
# v.3.0.1
# Developed in 2019 by Travis Kessler <[email protected]>
#
# Contains logger used by ECNet
Expand Down
2 changes: 1 addition & 1 deletion ecnet/utils/server_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# -*- coding: utf-8 -*-
#
# ecnet/utils/server_utils.py
# v.3.0.0
# v.3.0.1
# Developed in 2019 by Travis Kessler <[email protected]>
#
# Contains functions used by ecnet.Server
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

setup(
name='ecnet',
version='3.0.0',
version='3.0.1',
description='UMass Lowell Energy and Combustion Research Laboratory Neural'
' Network Software',
url='http://github.com/tjkessler/ecnet',
Expand Down
13 changes: 13 additions & 0 deletions tests/test_create_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,20 @@ def from_smiles():
)


def fingerprints():
    '''Creates a database from SMILES strings with fingerprints enabled

    Passes 'mols_smiles.txt' as the input text file and
    'db_with_fingerprints.csv' as the output name to create_db, with targets
    taken from 'mols_targets.txt' and fingerprints=True (molecular
    fingerprints instead of QSPR descriptors)
    '''

    settings = {
        'targets': 'mols_targets.txt',
        'form': 'smiles',
        'fingerprints': True
    }
    print('Creating database with fingerprints...')
    create_db('mols_smiles.txt', 'db_with_fingerprints.csv', **settings)


if __name__ == '__main__':

    # When executed as a script, run every database-creation routine in the
    # same order: names first, then SMILES, then SMILES with fingerprints
    for build_db in (from_names, from_smiles, fingerprints):
        build_db()

0 comments on commit f92b742

Please sign in to comment.