From c4929fcf31ff185e731b604f077220b1edf0401f Mon Sep 17 00:00:00 2001 From: martin-sicho Date: Tue, 19 Mar 2024 09:27:34 +0100 Subject: [PATCH 1/5] prevent qsprpred scorer error on empty list --- CHANGELOG.md | 7 +++---- drugex/training/scorers/qsprpred.py | 3 +++ drugex/training/tests.py | 14 ++++++++++---- 3 files changed, 16 insertions(+), 8 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index a5ce23e..b3b9235 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,14 +1,13 @@ # Change Log -From v3.4.5 to v3.4.6 +From v3.4.6 to v3.4.7 ## Fixes -None. +- Prevent the QSPRpred scorer from crashing when an empty list of molecules is supplied. The scorer now returns an empty list of scores and outputs a warning. ## Changes -- For the generator CLI, environment variables are now read from the generator meta file automatically. Now unused arguments are removed from the CLI. -- Compatibility updates to make the package work with the latest QSPRpred scorers in version 3.0.0 and higher. Older scorers will still work if an older version is installed alongside DrugEx. Only the unit tests will fail since the models used there assume QSPPRpred v3.0.0 or later. +None. ## Removed Features diff --git a/drugex/training/scorers/qsprpred.py b/drugex/training/scorers/qsprpred.py index 4a7b226..bc0f14f 100644 --- a/drugex/training/scorers/qsprpred.py +++ b/drugex/training/scorers/qsprpred.py @@ -14,6 +14,9 @@ def __init__(self, model, invalids_score=0.0, modifier=None, **kwargs): self.kwargs = kwargs def getScores(self, mols, frags=None): + if len(mols) == 0: + logger.warning("No molecules to score. Returning empty list...") + return [] parsed_mols = [] if not isinstance(mols[0], str): invalids = 0 diff --git a/drugex/training/tests.py b/drugex/training/tests.py index 06da3bc..7ed4595 100644 --- a/drugex/training/tests.py +++ b/drugex/training/tests.py @@ -159,19 +159,25 @@ class TestScorer(TestCase): def test_getScores(self): scorer = getPredictor() + # test with invalid + mols = ["CCO", "XXXX"] + scores = scorer.getScores(mols) + self.assertEqual(len(scores), len(mols)) + # test with empty + mols = [] + scores = scorer.getScores(mols) + self.assertEqual(len(scores), len(mols)) + # test with valid mols = ["CCO", "CC"] scores = scorer.getScores(mols) self.assertEqual(len(scores), len(mols)) self.assertTrue(all([isinstance(score, float) and score > 0 for score in scores])) - + # test directly with RDKit mols mols = [Chem.MolFromSmiles("CCO"), Chem.MolFromSmiles("CC")] scores = scorer.getScores(mols) self.assertEqual(len(scores), len(mols)) self.assertTrue(all([isinstance(score, float) and score > 0 for score in scores])) - mols = ["CCO", "XXXX"] # test with invalid - scores = scorer.getScores(mols) - self.assertEqual(len(scores), len(mols)) class TrainingTestCase(TestCase): From cdebf7877d4f4f50cb3d6b82a632da70b3d5951f Mon Sep 17 00:00:00 2001 From: martin-sicho Date: Tue, 19 Mar 2024 09:31:29 +0100 Subject: [PATCH 2/5] add scorer tests --- drugex/training/tests.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/drugex/training/tests.py b/drugex/training/tests.py index 7ed4595..d5b16cc 100644 --- a/drugex/training/tests.py +++ b/drugex/training/tests.py @@ -400,7 +400,8 @@ def test_sequence_rnn(self): self.assertTrue(type(monitor.getModel()) == collections.OrderedDict) self.assertTrue(monitor.allMethodsExecuted()) - pretrained.generate(num_samples=10, evaluator=environment, drop_invalid=False) + pretrained.generate(num_samples=10, evaluator=environment, drop_invalid=False, raw_scores=True) + pretrained.generate(num_samples=10, evaluator=environment, drop_invalid=False, raw_scores=False) def test_graph_transformer(self): """ @@ -463,7 +464,13 @@ def test_graph_transformer(self): pretrained.generate([ "c1ccncc1CCC", "CCO" - ], num_samples=1, evaluator=environment, drop_invalid=False) + ], num_samples=2, evaluator=environment, drop_invalid=False, raw_scores=True) + pretrained.generate([ + "c1ccncc1CCC", + "CCO" + ], num_samples=2, evaluator=environment, drop_invalid=False, raw_scores=False) + pretrained.generate(input_dataset=pr_data_set_test, num_samples=20, evaluator=environment, drop_invalid=False, raw_scores=True) + pretrained.generate(input_dataset=pr_data_set_test, num_samples=20, evaluator=environment, drop_invalid=False, raw_scores=False) def test_graph_transformer_scaffold(self): """ @@ -507,6 +514,10 @@ def test_graph_transformer_scaffold(self): self.assertTrue(type(monitor.getModel()) == collections.OrderedDict) self.assertTrue(monitor.allMethodsExecuted()) + # generate molecules + pretrained.generate(input_dataset=data_set, num_samples=5, evaluator=environment, drop_invalid=False, raw_scores=True) + pretrained.generate(frags, num_samples=5, evaluator=environment, drop_invalid=False, raw_scores=False) + def test_sequence_transformer(self): """ Test fragment-based sequence transformer model. From 46ac8144bd1aac94fa29441ea99febcc73702607 Mon Sep 17 00:00:00 2001 From: Helle van den Maagdenberg Date: Fri, 22 Mar 2024 11:16:49 +0100 Subject: [PATCH 3/5] update to rdkit from rdkit-pypi --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 52efbe7..cb532b8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,7 +37,7 @@ dependencies = [ "torch >= 1.7.0", "matplotlib >= 2.0", "tqdm", - "rdkit-pypi", + "rdkit", "joblib", "gitpython", "networkx", From 775cc3a684fbdc63f86602f1e8e9bf90ced71555 Mon Sep 17 00:00:00 2001 From: Helle van den Maagdenberg Date: Fri, 22 Mar 2024 12:35:53 +0100 Subject: [PATCH 4/5] update changelog --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index b3b9235..c44134e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,7 +7,7 @@ From v3.4.6 to v3.4.7 ## Changes -None. +- update depency rdkit-pypi to rdkit ## Removed Features From 1a91559507f8a77b923d4cfd0b019f768aee3441 Mon Sep 17 00:00:00 2001 From: Helle van den Maagdenberg Date: Mon, 6 May 2024 14:35:17 +0200 Subject: [PATCH 5/5] bump up version for the next release --- drugex/about.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/drugex/about.py b/drugex/about.py index c365960..b1de983 100644 --- a/drugex/about.py +++ b/drugex/about.py @@ -1,6 +1,6 @@ import os -VERSION = "3.4.6" +VERSION = "3.4.7" if os.path.exists(os.path.join(os.path.dirname(__file__), '_version.py')): from ._version import version