Skip to content

Commit

Permalink
Merge branch 'dev'
Browse files Browse the repository at this point in the history
  • Loading branch information
HellevdM committed May 6, 2024
2 parents b1152c6 + 1a91559 commit 34687de
Show file tree
Hide file tree
Showing 5 changed files with 31 additions and 12 deletions.
7 changes: 3 additions & 4 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
# Change Log
From v3.4.5 to v3.4.6
From v3.4.6 to v3.4.7

## Fixes

None.
- Prevent the QSPRpred scorer from crashing when an empty list of molecules is supplied. The scorer now returns an empty list of scores and outputs a warning.

## Changes

- For the generator CLI, environment variables are now read from the generator meta file automatically. Now unused arguments are removed from the CLI.
- Compatibility updates to make the package work with the latest QSPRpred scorers in version 3.0.0 and higher. Older scorers will still work if an older version is installed alongside DrugEx. Only the unit tests will fail since the models used there assume QSPPRpred v3.0.0 or later.
- update depency rdkit-pypi to rdkit

## Removed Features

Expand Down
2 changes: 1 addition & 1 deletion drugex/about.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import os

VERSION = "3.4.6"
VERSION = "3.4.7"

if os.path.exists(os.path.join(os.path.dirname(__file__), '_version.py')):
from ._version import version
Expand Down
3 changes: 3 additions & 0 deletions drugex/training/scorers/qsprpred.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ def __init__(self, model, invalids_score=0.0, modifier=None, **kwargs):
self.kwargs = kwargs

def getScores(self, mols, frags=None):
if len(mols) == 0:
logger.warning("No molecules to score. Returning empty list...")
return []
parsed_mols = []
if not isinstance(mols[0], str):
invalids = 0
Expand Down
29 changes: 23 additions & 6 deletions drugex/training/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,19 +159,25 @@ class TestScorer(TestCase):

def test_getScores(self):
scorer = getPredictor()
# test with invalid
mols = ["CCO", "XXXX"]
scores = scorer.getScores(mols)
self.assertEqual(len(scores), len(mols))
# test with empty
mols = []
scores = scorer.getScores(mols)
self.assertEqual(len(scores), len(mols))
# test with valid
mols = ["CCO", "CC"]
scores = scorer.getScores(mols)
self.assertEqual(len(scores), len(mols))
self.assertTrue(all([isinstance(score, float) and score > 0 for score in scores]))

# test directly with RDKit mols
mols = [Chem.MolFromSmiles("CCO"), Chem.MolFromSmiles("CC")]
scores = scorer.getScores(mols)
self.assertEqual(len(scores), len(mols))
self.assertTrue(all([isinstance(score, float) and score > 0 for score in scores]))

mols = ["CCO", "XXXX"] # test with invalid
scores = scorer.getScores(mols)
self.assertEqual(len(scores), len(mols))

class TrainingTestCase(TestCase):

Expand Down Expand Up @@ -394,7 +400,8 @@ def test_sequence_rnn(self):
self.assertTrue(type(monitor.getModel()) == collections.OrderedDict)
self.assertTrue(monitor.allMethodsExecuted())

pretrained.generate(num_samples=10, evaluator=environment, drop_invalid=False)
pretrained.generate(num_samples=10, evaluator=environment, drop_invalid=False, raw_scores=True)
pretrained.generate(num_samples=10, evaluator=environment, drop_invalid=False, raw_scores=False)

def test_graph_transformer(self):
"""
Expand Down Expand Up @@ -457,7 +464,13 @@ def test_graph_transformer(self):
pretrained.generate([
"c1ccncc1CCC",
"CCO"
], num_samples=1, evaluator=environment, drop_invalid=False)
], num_samples=2, evaluator=environment, drop_invalid=False, raw_scores=True)
pretrained.generate([
"c1ccncc1CCC",
"CCO"
], num_samples=2, evaluator=environment, drop_invalid=False, raw_scores=False)
pretrained.generate(input_dataset=pr_data_set_test, num_samples=20, evaluator=environment, drop_invalid=False, raw_scores=True)
pretrained.generate(input_dataset=pr_data_set_test, num_samples=20, evaluator=environment, drop_invalid=False, raw_scores=False)

def test_graph_transformer_scaffold(self):
"""
Expand Down Expand Up @@ -501,6 +514,10 @@ def test_graph_transformer_scaffold(self):
self.assertTrue(type(monitor.getModel()) == collections.OrderedDict)
self.assertTrue(monitor.allMethodsExecuted())

# generate molecules
pretrained.generate(input_dataset=data_set, num_samples=5, evaluator=environment, drop_invalid=False, raw_scores=True)
pretrained.generate(frags, num_samples=5, evaluator=environment, drop_invalid=False, raw_scores=False)

def test_sequence_transformer(self):
"""
Test fragment-based sequence transformer model.
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ dependencies = [
"torch >= 1.7.0",
"matplotlib >= 2.0",
"tqdm",
"rdkit-pypi",
"rdkit",
"joblib",
"gitpython",
"networkx",
Expand Down

0 comments on commit 34687de

Please sign in to comment.