Skip to content

Commit

Permalink
Add multiple molecule support (#137)
Browse files Browse the repository at this point in the history
* add failing tests

* fix parametrize spelling sigh

* fix single-ion charge

* add multi molecule support

* mark tests as fail -- current models weren't trained with 0 bonds as feature

* make test stronger

* update changelog
  • Loading branch information
lilyminium authored Sep 20, 2024
1 parent a4d7d35 commit eac87f6
Show file tree
Hide file tree
Showing 6 changed files with 249 additions and 3 deletions.
1 change: 1 addition & 0 deletions docs/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ The rules for this file:

### Added
- General linear fit target and example (PR #131)
- Support for multiple molecules (PR #137, Issue #136)

### Changed
- Removed unused, undocumented code paths, and updated docs (PR #132)
Expand Down
5 changes: 4 additions & 1 deletion openff/nagl/molecule/_dgl/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,10 @@ def openff_molecule_to_base_dgl_graph(
from openff.nagl.toolkits.openff import get_openff_molecule_bond_indices

bonds = get_openff_molecule_bond_indices(molecule)
indices_a, indices_b = map(list, zip(*bonds))
if bonds:
indices_a, indices_b = map(list, zip(*bonds))
else:
indices_a, indices_b = [], []
indices_a = torch.tensor(indices_a, dtype=torch.int32)
indices_b = torch.tensor(indices_b, dtype=torch.int32)

Expand Down
91 changes: 90 additions & 1 deletion openff/nagl/nn/_models.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import copy
import collections
import logging
import types
from typing import TYPE_CHECKING, Tuple, Dict, Union, Callable, Literal, Optional
Expand Down Expand Up @@ -165,6 +165,95 @@ def compute_properties(
"""
Compute the trained property for a molecule.
Parameters
----------
molecule: :class:`~openff.toolkit.topology.Molecule`
The molecule to compute the property for.
as_numpy: bool
Whether to return the result as a numpy array.
If ``False``, the result will be a ``torch.Tensor``.
check_domains: bool
Whether to check if the molecule is similar
to the training data.
error_if_unsupported: bool
Whether to raise an error if the molecule
is not represented in the training data.
This is only used if ``check_domains`` is ``True``.
If ``False``, a warning will be raised instead.
check_lookup_table: bool
Whether to check a lookup table for the property values.
If ``False`` or if the molecule is not in the lookup
table, the property will be computed using the model.
Returns
-------
result: Dict[str, torch.Tensor] or Dict[str, numpy.ndarray]
"""
import numpy as np

# split up molecule in case it's fragments
from openff.nagl.toolkits.openff import split_up_molecule

fragments, all_indices = split_up_molecule(molecule)
# TODO: this assumes atom-wise properties
# we should add support for bond-wise/more general properties

results = [
self._compute_properties(
fragment,
as_numpy=as_numpy,
check_domains=check_domains,
error_if_unsupported=error_if_unsupported,
check_lookup_table=check_lookup_table,
)
for fragment in fragments
]

# combine the results
combined_results = {}

if as_numpy:
tensor = np.empty
else:
tensor = torch.empty
for property_name, value in results[0].items():
combined_results[property_name] = tensor(
molecule.n_atoms,
dtype=value.dtype
)

seen_indices = collections.defaultdict(set)

for result, indices in zip(results, all_indices):
for property_name, value in result.items():
combined_results[property_name][indices] = value
if seen_indices[property_name] & set(indices):
raise ValueError(
"Overlapping indices in the fragments"
)
seen_indices[property_name].update(indices)

expected_indices = list(range(molecule.n_atoms))
for property_name, seen_indices in seen_indices.items():
assert sorted(seen_indices) == expected_indices, (
f"Missing indices for property {property_name}: "
f"{set(expected_indices) - seen_indices}"
)
return combined_results



def _compute_properties(
self,
molecule: "Molecule",
as_numpy: bool = True,
check_domains: bool = False,
error_if_unsupported: bool = True,
check_lookup_table: bool = True,
) -> Dict[str, torch.Tensor]:
"""
Compute the trained property for a molecule.
Parameters
----------
molecule: :class:`~openff.toolkit.topology.Molecule`
Expand Down
76 changes: 76 additions & 0 deletions openff/nagl/tests/nn/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from openff.units import unit
from openff.toolkit.topology import Molecule
from openff.toolkit.utils.toolkits import RDKIT_AVAILABLE

from openff.nagl.nn.gcn._sage import SAGEConvStack
from openff.nagl.nn._containers import ConvolutionModule, ReadoutModule
Expand Down Expand Up @@ -424,3 +425,78 @@ def test_compute_property(
assert charges.dtype == np.float32

assert_allclose(charges, expected_charges, atol=1e-5)

@pytest.mark.xfail(reason="Model does not include 0 bonds as feature")
def test_assign_partial_charges_to_ion(self, model):
mol = Molecule.from_smiles("[Cl-]")
assert mol.n_atoms == 1

charges = model.compute_property(mol, as_numpy=True).flatten()
assert np.isclose(charges[-1], -1.)

@pytest.mark.xfail(reason="Model does not include 0 bonds as feature")
def test_assign_partial_charges_to_hcl_salt(self, model):
mol = Molecule.from_mapped_smiles("[Cl-:1].[H+:2]")
assert mol.n_atoms == 2

charges = model.compute_property(mol, as_numpy=True).flatten()
assert np.isclose(charges[0], -1.)
assert np.isclose(charges[1], 1.)

@pytest.mark.skipif(not RDKIT_AVAILABLE, reason="requires rdkit")
@pytest.mark.parametrize(
"smiles, expected_formal_charges", [
("CCCn1cc[n+](C)c1.C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F", [1, -1]),
]
)
def test_multimolecule_smiles(self, model, smiles, expected_formal_charges):
from rdkit import Chem

mol = Molecule.from_smiles(smiles)
charges = model.compute_property(mol, as_numpy=True)

# work out which charges belong to which molecule
rdmol = mol.to_rdkit()
# assume lowest atoms are in order of left to right
fragment_indices = []
fragments = Chem.GetMolFrags(
rdmol,
asMols=True,
fragsMolAtomMapping=fragment_indices
)
assert len(fragment_indices) == len(expected_formal_charges)

# sort to get lowest atoms
min_indices = [min(indices) for indices in fragment_indices]
argsorted = np.argsort(min_indices)
fragment_indices = [fragment_indices[i] for i in argsorted]
fragments = [fragments[i] for i in argsorted]

# test individually assigned charges *and* sums are correct
# and we didn't muddle indices somehow
individual_smiles = smiles.split(".")
for i, indices in enumerate(fragment_indices):
expected_charge = expected_formal_charges[i]
fragment_charges = charges[list(indices)]
assert np.allclose(sum(fragment_charges), expected_charge)

individual_mol = Molecule.from_smiles(
individual_smiles[i],
allow_undefined_stereo=True
)
individual_charges = model.compute_property(individual_mol, as_numpy=True)
mol_fragment = Molecule.from_rdkit(fragments[i])

# remap to fragment charges
is_iso, atom_mapping = Molecule.are_isomorphic(
individual_mol, mol_fragment, return_atom_map=True
)
assert is_iso
# atom_mapping has k:v of mol2_index: mol_fragment_index
remapped_fragment_charges = [
fragment_charges[v]
for _, v in sorted(atom_mapping.items())
]
assert np.allclose(individual_charges, remapped_fragment_charges)


35 changes: 34 additions & 1 deletion openff/nagl/tests/utils/test_openff.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import numpy as np
import pytest
from numpy.testing import assert_allclose
from openff.toolkit.topology import Molecule
from openff.nagl.toolkits import NAGLRDKitToolkitWrapper
from openff.toolkit import RDKitToolkitWrapper
from openff.toolkit.utils.toolkit_registry import toolkit_registry_manager, ToolkitRegistry
Expand All @@ -21,6 +22,7 @@
molecule_from_networkx,
_molecule_from_dict,
_molecule_to_dict,
split_up_molecule,
)
from openff.nagl.utils._utils import transform_coordinates

Expand Down Expand Up @@ -290,4 +292,35 @@ def test_molecule_to_dict(openff_methane_uncharged):
def test_molecule_from_dict(openff_methane_uncharged):
graph = _molecule_to_dict(openff_methane_uncharged)
molecule = _molecule_from_dict(graph)
assert molecule.is_isomorphic_with(openff_methane_uncharged)
assert molecule.is_isomorphic_with(openff_methane_uncharged)

def test_split_up_molecule():
# "N.c1ccccc1.C.CCN"
mapped_smiles = (
"[H:17][c:4]1[c:3]([c:2]([c:7]([c:6]([c:5]1[H:18])[H:19])[H:20])[H:15])[H:16]"
".[H:21][C:8]([H:22])([H:23])[H:24]"
".[H:25][C:9]([H:26])([H:27])[C:10]([H:28])([H:29])[N:11]([H:30])[H:31]"
".[H:12][N:1]([H:13])[H:14]"
)
molecule = Molecule.from_mapped_smiles(mapped_smiles)

fragments, indices = split_up_molecule(molecule, return_indices=True)
assert len(fragments) == 4

# check order
n = Molecule.from_smiles("N")
benzene = Molecule.from_smiles("c1ccccc1")
ethanamine = Molecule.from_smiles("CCN")
methane = Molecule.from_smiles("C")

assert fragments[0].is_isomorphic_with(n)
assert fragments[1].is_isomorphic_with(benzene)
assert fragments[2].is_isomorphic_with(methane)
assert fragments[3].is_isomorphic_with(ethanamine)

assert indices[0] == [0, 11, 12, 13]
assert indices[1] == [1, 2, 3, 4, 5, 6, 14, 15, 16, 17, 18, 19]
assert indices[2] == [7, 20, 21, 22, 23]
assert indices[3] == [8, 9, 10, 24, 25, 26, 27, 28, 29, 30]


44 changes: 44 additions & 0 deletions openff/nagl/toolkits/openff.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,50 @@ def is_conformer_identical(
return rmsd.m_as(unit.angstrom) < atol


def _split_up_molecule_into_indices(
molecule: "Molecule",
) -> list[list[int]]:
import networkx as nx

graph = molecule.to_networkx()
return list(map(list, nx.connected_components(graph)))

def split_up_molecule(
molecule: "Molecule",
return_indices: bool = True
) -> list["Molecule"]:
"""
Split up a molecule into its connected components.
Parameters
----------
molecule: openff.toolkit.topology.Molecule
The molecule to split up.
return_indices: bool, default=True
If the indices of the atoms in each component should be returned.
Returns
-------
components: List[openff.toolkit.topology.Molecule]
The connected components of the molecule.
"""
import networkx as nx

graph = molecule.to_networkx()
indices = list(map(list, nx.connected_components(graph)))

fragments = []
for ix in indices:
subgraph = nx.convert_node_labels_to_integers(graph.subgraph(ix))
fragment = molecule_from_networkx(subgraph)
fragments.append(fragment)

if return_indices:
return fragments, indices
return fragments



def normalize_molecule(
molecule: "Molecule",
max_iter: int = 200,
Expand Down

0 comments on commit eac87f6

Please sign in to comment.