diff --git a/molfeat/calc/_serializable_classes.py b/molfeat/calc/_serializable_classes.py new file mode 100644 index 0000000..2a43b22 --- /dev/null +++ b/molfeat/calc/_serializable_classes.py @@ -0,0 +1,72 @@ +from typing import Optional +from typing import Dict +from typing import Any + +from rdkit.Chem import rdFingerprintGenerator + +SERIALIZABLE_CLASSES = {} + + +def register_custom_serializable_class(cls: type): + SERIALIZABLE_CLASSES[cls.__name__] = cls + return cls + + +@register_custom_serializable_class +class SerializableMorganFeatureAtomInvGen: + """A serializable wrapper class for `rdFingerprintGenerator.GetMorganFeatureAtomInvGen()`""" + + def __init__(self): + self._generator = rdFingerprintGenerator.GetMorganFeatureAtomInvGen() + + def __getstate__(self): + return None + + def __setstate__(self, state: Optional[None]): + self._generator = rdFingerprintGenerator.GetMorganFeatureAtomInvGen() + + def __deepcopy__(self, memo: Dict[int, Any]): + new_instance = SerializableMorganFeatureAtomInvGen() + memo[id(self)] = new_instance + return new_instance + + def __getattr__(self, name: str): + try: + generator = object.__getattribute__(self, "_generator") + except AttributeError: + raise AttributeError("'_generator' is not initialized") + + try: + return getattr(generator, name) + except AttributeError: + raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'") + + +@register_custom_serializable_class +class SerializableMorganFeatureBondInvGen: + """A serializable wrapper class for `rdFingerprintGenerator.GetMorganFeatureBondInvGen()`""" + + def __init__(self): + self._generator = rdFingerprintGenerator.GetMorganFeatureBondInvGen() + + def __getstate__(self): + return None + + def __setstate__(self, state: Optional[None]): + self._generator = rdFingerprintGenerator.GetMorganFeatureBondInvGen() + + def __deepcopy__(self, memo: Dict[int, Any]): + new_instance = SerializableMorganFeatureBondInvGen() + memo[id(self)] = new_instance + return new_instance + + def __getattr__(self, name: str): + try: + generator = object.__getattribute__(self, "_generator") + except AttributeError: + raise AttributeError("'_generator' is not initialized") + + try: + return getattr(generator, name) + except AttributeError: + raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'") diff --git a/molfeat/calc/fingerprints.py b/molfeat/calc/fingerprints.py index 7144436..3130b50 100644 --- a/molfeat/calc/fingerprints.py +++ b/molfeat/calc/fingerprints.py @@ -1,11 +1,10 @@ from typing import Union from typing import Optional -from functools import partial - import copy import datamol as dm from rdkit.Avalon import pyAvalonTools +from rdkit.Chem import rdFingerprintGenerator from rdkit.Chem import rdMolDescriptors from rdkit.Chem import rdReducedGraphs from rdkit.Chem import rdmolops @@ -14,19 +13,31 @@ from loguru import logger from molfeat.calc._mhfp import SECFP from molfeat.calc._map4 import MAP4 +from molfeat.calc._serializable_classes import ( + SerializableMorganFeatureAtomInvGen, + SERIALIZABLE_CLASSES, +) from molfeat.calc.base import SerializableCalculator from molfeat.utils.datatype import to_numpy, to_fp from molfeat.utils.commons import fold_count_fp +FP_GENERATORS = { + "ecfp": rdFingerprintGenerator.GetMorganGenerator, + "fcfp": rdFingerprintGenerator.GetMorganGenerator, + "topological": rdFingerprintGenerator.GetTopologicalTorsionGenerator, + "atompair": rdFingerprintGenerator.GetAtomPairGenerator, + "rdkit": rdFingerprintGenerator.GetRDKitFPGenerator, + "ecfp-count": rdFingerprintGenerator.GetMorganGenerator, + "fcfp-count": rdFingerprintGenerator.GetMorganGenerator, + "topological-count": rdFingerprintGenerator.GetTopologicalTorsionGenerator, + "atompair-count": rdFingerprintGenerator.GetAtomPairGenerator, + "rdkit-count": rdFingerprintGenerator.GetRDKitFPGenerator, +} + FP_FUNCS = { "maccs": rdMolDescriptors.GetMACCSKeysFingerprint, "avalon": pyAvalonTools.GetAvalonFP, - "ecfp": rdMolDescriptors.GetMorganFingerprintAsBitVect, - "fcfp": partial(rdMolDescriptors.GetMorganFingerprintAsBitVect, useFeatures=True), - "topological": rdMolDescriptors.GetHashedTopologicalTorsionFingerprintAsBitVect, - "atompair": rdMolDescriptors.GetHashedAtomPairFingerprintAsBitVect, - "rdkit": rdmolops.RDKFingerprint, "pattern": rdmolops.PatternFingerprint, "layered": rdmolops.LayeredFingerprint, "map4": MAP4, @@ -34,11 +45,7 @@ "erg": rdReducedGraphs.GetErGFingerprint, "estate": lambda x, **params: EStateFingerprinter.FingerprintMol(x)[0], "avalon-count": pyAvalonTools.GetAvalonCountFP, - "rdkit-count": rdmolops.UnfoldedRDKFingerprintCountBased, - "ecfp-count": rdMolDescriptors.GetHashedMorganFingerprint, - "fcfp-count": rdMolDescriptors.GetHashedMorganFingerprint, - "topological-count": rdMolDescriptors.GetHashedTopologicalTorsionFingerprint, - "atompair-count": rdMolDescriptors.GetHashedAtomPairFingerprint, + **FP_GENERATORS, } @@ -52,59 +59,60 @@ }, "ecfp": { "radius": 2, # ECFP4 - "nBits": 2048, - "invariants": [], - "fromAtoms": [], - "useChirality": False, + "fpSize": 2048, + "includeChirality": False, "useBondTypes": True, - "useFeatures": False, + "countSimulation": False, + "countBounds": None, + "atomInvariantsGenerator": None, + "bondInvariantsGenerator": None, }, "fcfp": { - "radius": 2, # FCFP4 - "nBits": 2048, - "invariants": [], # you may want to provide features invariance - "fromAtoms": [], - "useChirality": False, + "radius": 2, + "fpSize": 2048, + "includeChirality": False, "useBondTypes": True, - "useFeatures": True, + "countSimulation": False, + "countBounds": None, + "atomInvariantsGenerator": SerializableMorganFeatureAtomInvGen(), + "bondInvariantsGenerator": None, }, "topological": { - "nBits": 2048, - "targetSize": 4, - "fromAtoms": 0, - "ignoreAtoms": 0, - "atomInvariants": 0, - "nBitsPerEntry": 4, "includeChirality": False, + "torsionAtomCount": 4, + "countSimulation": True, + "countBounds": None, + "fpSize": 2048, + "atomInvariantsGenerator": None, }, "atompair": { - "nBits": 2048, - "minLength": 1, - "maxLength": 30, - "fromAtoms": 0, - "ignoreAtoms": 0, - "atomInvariants": 0, - "nBitsPerEntry": 4, + "minDistance": 1, + "maxDistance": 30, "includeChirality": False, "use2D": True, - "confId": -1, + "countSimulation": True, + "countBounds": None, + "fpSize": 2048, + "atomInvariantsGenerator": None, }, "rdkit": { "minPath": 1, "maxPath": 7, - "fpSize": 2048, - "nBitsPerHash": 2, "useHs": True, - "tgtDensity": 0.0, - "minSize": 128, "branchedPaths": True, "useBondOrder": True, - "atomInvariants": 0, - "fromAtoms": 0, - "atomBits": None, - "bitInfo": None, + "countSimulation": False, + "countBounds": None, + "fpSize": 2048, + "numBitsPerFeature": 2, + "atomInvariantsGenerator": None, + }, + "pattern": { + "fpSize": 2048, + "atomCounts": [], + "setOnlyBits": None, + "tautomerFingerprints": False, }, - "pattern": {"fpSize": 2048, "atomCounts": [], "setOnlyBits": None}, "layered": { "fpSize": 2048, "minPath": 1, @@ -139,36 +147,40 @@ # COUNTING FP "ecfp-count": { "radius": 2, # ECFP4 - "nBits": 2048, - "invariants": [], - "fromAtoms": [], - "useChirality": False, + "fpSize": 2048, + "includeChirality": False, "useBondTypes": True, - "useFeatures": False, "includeRedundantEnvironments": False, + "countBounds": None, + "atomInvariantsGenerator": None, + "bondInvariantsGenerator": None, }, "fcfp-count": { - "radius": 2, # FCFP4 - "nBits": 2048, - "invariants": [], # you may want to provide features invariance - "fromAtoms": [], - "useChirality": False, + "radius": 2, + "fpSize": 2048, + "includeChirality": False, "useBondTypes": True, - "useFeatures": True, "includeRedundantEnvironments": False, + "atomInvariantsGenerator": SerializableMorganFeatureAtomInvGen(), + "bondInvariantsGenerator": None, }, "topological-count": { - "nBits": 2048, - "targetSize": 4, - "fromAtoms": 0, - "ignoreAtoms": 0, - "atomInvariants": 0, "includeChirality": False, + "torsionAtomCount": 4, + "countSimulation": True, + "countBounds": None, + "fpSize": 2048, + "atomInvariantsGenerator": None, }, - "avalon-count": { - "nBits": 512, - "isQuery": False, - "bitFlags": pyAvalonTools.avalonSimilarityBits, + "atompair-count": { + "minDistance": 1, + "maxDistance": 30, + "includeChirality": False, + "use2D": True, + "countSimulation": True, + "countBounds": None, + "fpSize": 2048, + "atomInvariantsGenerator": None, }, "rdkit-count": { "minPath": 1, @@ -176,21 +188,11 @@ "useHs": True, "branchedPaths": True, "useBondOrder": True, - "atomInvariants": 0, - "fromAtoms": 0, - "atomBits": None, - "bitInfo": None, - }, - "atompair-count": { - "nBits": 2048, - "minLength": 1, - "maxLength": 30, - "fromAtoms": 0, - "ignoreAtoms": 0, - "atomInvariants": 0, - "includeChirality": False, - "use2D": True, - "confId": -1, + "countSimulation": False, + "countBounds": None, + "fpSize": 2048, + "numBitsPerFeature": 1, + "atomInvariantsGenerator": None, }, } @@ -303,7 +305,16 @@ def __call__(self, mol: Union[dm.Mol, str], raw: bool = False): props (np.ndarray): list of computed rdkit molecular descriptors """ mol = dm.to_mol(mol) - fp_val = FP_FUNCS[self.method](mol, **self.params) + + fp_func = FP_FUNCS[self.method] + if self.method in FP_GENERATORS: + fp_func = fp_func(**self.params) + if self.counting: + fp_val = fp_func.GetCountFingerprint(mol) + else: + fp_val = fp_func.GetFingerprint(mol) + else: + fp_val = fp_func(mol, **self.params) if self.counting: fp_val = fold_count_fp(fp_val, self._length) if not raw: @@ -321,12 +332,19 @@ def __getstate__(self): state["input_length"] = self.input_length state["method"] = self.method state["counting"] = self.counting - state["params"] = self.params + state["params"] = { + k: (v if v.__class__.__name__ not in SERIALIZABLE_CLASSES else v.__class__.__name__) + for k, v in self.params.items() + } return state def __setstate__(self, state: dict): """Set the state of the featurizer""" self.__dict__.update(state) + self.params = { + k: (v if v not in SERIALIZABLE_CLASSES else SERIALIZABLE_CLASSES[v]()) + for k, v in self.params.items() + } self._length = self._set_length(self.input_length) def to_state_dict(self): @@ -334,9 +352,14 @@ def to_state_dict(self): state_dict = super().to_state_dict() cur_params = self.params default_params = copy.deepcopy(FP_DEF_PARAMS[state_dict["args"]["method"]]) + state_dict["args"].update( { - k: cur_params[k] + k: ( + cur_params[k] + if cur_params[k].__class__.__name__ not in SERIALIZABLE_CLASSES + else cur_params[k].__class__.__name__ + ) for k in cur_params if (cur_params[k] != default_params[k] and cur_params[k] is not None) } diff --git a/molfeat/trans/fp.py b/molfeat/trans/fp.py index f1e6a17..256b714 100644 --- a/molfeat/trans/fp.py +++ b/molfeat/trans/fp.py @@ -1,14 +1,11 @@ -from typing import Callable -from typing import List -from typing import Optional -from typing import Union - -import re import copy -import numpy as np +import re +from typing import Callable, List, Optional, Union + import datamol as dm +import numpy as np -from molfeat.calc import get_calculator, FP_FUNCS +from molfeat.calc import FP_FUNCS, get_calculator from molfeat.trans.base import MoleculeTransformer from molfeat.utils import datatype from molfeat.utils.commons import _parse_to_evaluable_str diff --git a/tests/test_fp.py b/tests/test_fp.py index c77393f..243941f 100644 --- a/tests/test_fp.py +++ b/tests/test_fp.py @@ -52,6 +52,7 @@ class TestMolTransformer(ut.TestCase): "avalon", "rdkit", "ecfp", + "ecfp-count", "pharm2D", "desc2D", ] diff --git a/tests/test_state.py b/tests/test_state.py index 95a2458..f6199b4 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -259,7 +259,7 @@ def test_fp_state(): { "name": "FPCalculator", "module": "molfeat.calc.fingerprints", - "args": {"length": 512, "method": "ecfp", "counting": False, "nBits": 512}, + "args": {"length": 512, "method": "ecfp", "counting": False, "fpSize": 512}, "_molfeat_version": MOLFEAT_VERSION, }, { @@ -269,7 +269,8 @@ def test_fp_state(): "length": 241, "method": "fcfp-count", "counting": True, - "nBits": 241, + "fpSize": 241, + "atomInvariantsGenerator": "SerializableMorganFeatureAtomInvGen", }, "_molfeat_version": MOLFEAT_VERSION, },