From eac87f650ffd0825e0794f8f7319d39285ad4060 Mon Sep 17 00:00:00 2001 From: Lily Wang <31115101+lilyminium@users.noreply.github.com> Date: Fri, 20 Sep 2024 13:01:50 +1000 Subject: [PATCH] Add multiple molecule support (#137) * add failing tests * fix parametrize spelling sigh * fix single-ion charge * add multi molecule support * mark tests as fail -- current models weren't trained with 0 bonds as feature * make test stronger * update changelog --- docs/CHANGELOG.md | 1 + openff/nagl/molecule/_dgl/utils.py | 5 +- openff/nagl/nn/_models.py | 91 +++++++++++++++++++++++++- openff/nagl/tests/nn/test_model.py | 76 +++++++++++++++++++++ openff/nagl/tests/utils/test_openff.py | 35 +++++++++- openff/nagl/toolkits/openff.py | 44 +++++++++++++ 6 files changed, 249 insertions(+), 3 deletions(-) diff --git a/docs/CHANGELOG.md b/docs/CHANGELOG.md index 45b63634..20ee08ae 100644 --- a/docs/CHANGELOG.md +++ b/docs/CHANGELOG.md @@ -21,6 +21,7 @@ The rules for this file: ### Added - General linear fit target and example (PR #131) +- Support for multiple molecules (PR #137, Issue #136) ### Changed - Removed unused, undocumented code paths, and updated docs (PR #132) diff --git a/openff/nagl/molecule/_dgl/utils.py b/openff/nagl/molecule/_dgl/utils.py index b4219b2d..d274da19 100644 --- a/openff/nagl/molecule/_dgl/utils.py +++ b/openff/nagl/molecule/_dgl/utils.py @@ -28,7 +28,10 @@ def openff_molecule_to_base_dgl_graph( from openff.nagl.toolkits.openff import get_openff_molecule_bond_indices bonds = get_openff_molecule_bond_indices(molecule) - indices_a, indices_b = map(list, zip(*bonds)) + if bonds: + indices_a, indices_b = map(list, zip(*bonds)) + else: + indices_a, indices_b = [], [] indices_a = torch.tensor(indices_a, dtype=torch.int32) indices_b = torch.tensor(indices_b, dtype=torch.int32) diff --git a/openff/nagl/nn/_models.py b/openff/nagl/nn/_models.py index 7a3a8138..6e6d1603 100644 --- a/openff/nagl/nn/_models.py +++ b/openff/nagl/nn/_models.py @@ -1,4 +1,4 @@ -import copy 
def compute_properties(
    self,
    molecule: "Molecule",
    as_numpy: bool = True,
    check_domains: bool = False,
    error_if_unsupported: bool = True,
    check_lookup_table: bool = True,
) -> Dict[str, torch.Tensor]:
    """
    Compute the trained property for a molecule.

    ``molecule`` may contain several disconnected components (for example
    an ion pair). Each connected component is evaluated on its own with
    :meth:`_compute_properties`, and the per-atom values are scattered back
    into arrays whose first axis follows the atom order of ``molecule``.
    A molecule with at most one component is passed to
    :meth:`_compute_properties` unchanged.

    Parameters
    ----------
    molecule: :class:`~openff.toolkit.topology.Molecule`
        The molecule to compute the property for.
    as_numpy: bool
        Whether to return the result as a numpy array.
        If ``False``, the result will be a ``torch.Tensor``.
    check_domains: bool
        Whether to check if the molecule is similar
        to the training data.
    error_if_unsupported: bool
        Whether to raise an error if the molecule
        is not represented in the training data.
        This is only used if ``check_domains`` is ``True``.
        If ``False``, a warning will be raised instead.
    check_lookup_table: bool
        Whether to check a lookup table for the property values.
        If ``False`` or if the molecule is not in the lookup
        table, the property will be computed using the model.

    Returns
    -------
    result: Dict[str, torch.Tensor] or Dict[str, numpy.ndarray]
        One entry per property name returned for the components.

    Raises
    ------
    ValueError
        If the atom indices of the components overlap, or if they do not
        cover every atom of ``molecule`` for some property.
    """
    import numpy as np

    # split up molecule in case it's fragments
    from openff.nagl.toolkits.openff import split_up_molecule

    options = dict(
        as_numpy=as_numpy,
        check_domains=check_domains,
        error_if_unsupported=error_if_unsupported,
        check_lookup_table=check_lookup_table,
    )

    # Ask for the indices explicitly rather than relying on the default.
    fragments, all_indices = split_up_molecule(molecule, return_indices=True)

    # Nothing to recombine: hand over the molecule exactly as given. This
    # avoids the networkx round trip for ordinary single molecules and
    # avoids indexing into an empty result list for an empty molecule.
    if len(fragments) <= 1:
        return self._compute_properties(molecule, **options)

    # TODO: this assumes atom-wise properties
    # we should add support for bond-wise/more general properties
    results = [
        self._compute_properties(fragment, **options)
        for fragment in fragments
    ]

    n_atoms = molecule.n_atoms

    # Allocate one output buffer per property, modelled on the first
    # component's value: same container type and dtype, first axis widened
    # to the whole molecule, any trailing axes kept.
    combined_results = {}
    for property_name, value in results[0].items():
        shape = (n_atoms, *value.shape[1:])
        if isinstance(value, torch.Tensor):
            # new_empty keeps the dtype and device of the fragment tensor
            combined_results[property_name] = value.new_empty(shape)
        else:
            combined_results[property_name] = np.empty(
                shape, dtype=value.dtype
            )

    # NOTE(review): assumes each index list is ordered like the atoms of
    # the corresponding fragment -- confirm against split_up_molecule.
    written = {property_name: set() for property_name in combined_results}
    for result, indices in zip(results, all_indices):
        indices = list(indices)
        for property_name, value in result.items():
            # check *before* writing so a bad split never clobbers data
            if written[property_name].intersection(indices):
                raise ValueError(
                    "Overlapping indices in the fragments"
                )
            combined_results[property_name][indices] = value
            written[property_name].update(indices)

    # The buffers start uninitialised, so every atom must have been
    # written. Raise explicitly: an ``assert`` disappears under ``-O``.
    expected_indices = set(range(n_atoms))
    for property_name, covered in written.items():
        if covered != expected_indices:
            raise ValueError(
                f"Missing indices for property {property_name}: "
                f"{expected_indices - covered}"
            )
    return combined_results
@pytest.mark.xfail(reason="Model does not include 0 bonds as feature")
def test_assign_partial_charges_to_hcl_salt(self, model):
    """A two-ion "molecule" should get the formal charge on each ion."""
    salt = Molecule.from_mapped_smiles("[Cl-:1].[H+:2]")
    assert salt.n_atoms == 2

    predicted = model.compute_property(salt, as_numpy=True).flatten()
    # atom 0 is the chloride, atom 1 the proton (see the atom map above)
    assert np.isclose(predicted[0], -1.)
    assert np.isclose(predicted[1], 1.)


@pytest.mark.skipif(not RDKIT_AVAILABLE, reason="requires rdkit")
@pytest.mark.parametrize(
    "smiles, expected_formal_charges", [
        ("CCCn1cc[n+](C)c1.C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F", [1, -1]),
    ]
)
def test_multimolecule_smiles(self, model, smiles, expected_formal_charges):
    """
    Charges of a multi-component SMILES must match, atom for atom,
    the charges obtained when each component is charged on its own.
    """
    from rdkit import Chem

    combined = Molecule.from_smiles(smiles)
    all_charges = model.compute_property(combined, as_numpy=True)

    # Let RDKit tell us which atoms of the combined molecule belong to
    # which component; ``atom_groups`` is filled in by GetMolFrags.
    atom_groups = []
    rd_fragments = Chem.GetMolFrags(
        combined.to_rdkit(),
        asMols=True,
        fragsMolAtomMapping=atom_groups,
    )
    assert len(atom_groups) == len(expected_formal_charges)

    # Order the components by their lowest atom index. The test assumes
    # this matches the left-to-right order of the "."-separated SMILES.
    lowest_atoms = [min(group) for group in atom_groups]
    order = np.argsort(lowest_atoms)
    atom_groups = [atom_groups[k] for k in order]
    rd_fragments = [rd_fragments[k] for k in order]

    # test individually assigned charges *and* sums are correct
    # and we didn't muddle indices somehow
    component_smiles = smiles.split(".")
    for i, (group, rd_fragment) in enumerate(zip(atom_groups, rd_fragments)):
        formal_charge = expected_formal_charges[i]
        group_charges = all_charges[list(group)]
        assert np.allclose(sum(group_charges), formal_charge)

        reference_mol = Molecule.from_smiles(
            component_smiles[i],
            allow_undefined_stereo=True
        )
        reference_charges = model.compute_property(
            reference_mol, as_numpy=True
        )
        fragment_mol = Molecule.from_rdkit(rd_fragment)

        matched, atom_map = Molecule.are_isomorphic(
            reference_mol, fragment_mol, return_atom_map=True
        )
        assert matched
        # NOTE(review): atom_map is taken to map an atom index of the first
        # argument (reference_mol) to the matching atom index of the second
        # (fragment_mol) -- confirm against the toolkit docs. Walking the
        # keys in ascending order lists the fragment charges in
        # reference_mol atom order.
        charges_in_reference_order = [
            group_charges[atom_map[key]]
            for key in sorted(atom_map)
        ]
        assert np.allclose(reference_charges, charges_in_reference_order)
def _connected_component_indices(graph) -> list[list[int]]:
    """Return the node labels of each connected component of ``graph``.

    ``networkx.connected_components`` yields plain sets, whose iteration
    order is unspecified. Each component is therefore sorted, and the
    components themselves are ordered by their lowest node label, so the
    result is deterministic.
    """
    import networkx as nx

    # components are disjoint, so comparing the lists lexicographically
    # orders them by their first (lowest) label
    return sorted(sorted(component) for component in nx.connected_components(graph))


def _extract_component_graph(graph, indices: list[int]):
    """Copy the nodes in ``indices``, and the edges between them, out of ``graph``.

    Nodes are inserted in the order given by ``indices`` and relabelled
    ``0..len(indices) - 1`` in that same order, so both the iteration order
    and the labels of the new graph line up with ``indices``.
    """
    relabel = {old: new for new, old in enumerate(indices)}

    component = graph.__class__()
    component.graph.update(graph.graph)
    component.add_nodes_from(
        (relabel[old], dict(graph.nodes[old])) for old in indices
    )
    component.add_edges_from(
        (relabel[u], relabel[v], dict(data))
        for u, v, data in graph.subgraph(indices).edges(data=True)
    )
    return component


def _split_up_molecule_into_indices(
    molecule: "Molecule",
) -> list[list[int]]:
    """Return the sorted atom indices of each connected component of ``molecule``.

    Components are ordered by their lowest atom index.
    """
    return _connected_component_indices(molecule.to_networkx())


def split_up_molecule(
    molecule: "Molecule",
    return_indices: bool = True
) -> "list[Molecule] | tuple[list[Molecule], list[list[int]]]":
    """
    Split up a molecule into its connected components.

    Components are ordered by their lowest atom index in ``molecule``.
    Within a component, the node graph handed to ``molecule_from_networkx``
    is numbered in ascending order of the original atom indices.

    Parameters
    ----------
    molecule: openff.toolkit.topology.Molecule
        The molecule to split up.
    return_indices: bool, default=True
        If the indices of the atoms in each component should be returned.

    Returns
    -------
    components: List[openff.toolkit.topology.Molecule]
        The connected components of the molecule.
    indices: List[List[int]]
        Only returned if ``return_indices`` is ``True``.
        ``indices[i]`` holds, in ascending order, the atom indices
        in ``molecule`` that make up ``components[i]``.
    """
    graph = molecule.to_networkx()
    indices = _connected_component_indices(graph)

    fragments = [
        molecule_from_networkx(_extract_component_graph(graph, ix))
        for ix in indices
    ]

    if return_indices:
        return fragments, indices
    return fragments