From c00dc84bdc9632eebbf7ec87512346468f6b6f25 Mon Sep 17 00:00:00 2001 From: "Lee, Kin Long Kelvin" Date: Thu, 10 Oct 2024 08:27:52 -0700 Subject: [PATCH 1/4] refactor: passing vesta=True to lattice creation --- matsciml/datasets/transforms/pbc.py | 2 +- matsciml/datasets/utils.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/matsciml/datasets/transforms/pbc.py b/matsciml/datasets/transforms/pbc.py index 09a8365a..14b827a2 100644 --- a/matsciml/datasets/transforms/pbc.py +++ b/matsciml/datasets/transforms/pbc.py @@ -60,7 +60,7 @@ def __call__(self, data: DataDict) -> DataDict: angles = torch.FloatTensor( tuple(angle * (180.0 / torch.pi) for angle in angles), ) - lattice = Lattice.from_parameters(*abc, *angles) + lattice = Lattice.from_parameters(*abc, *angles, vesta=True) structure = make_pymatgen_periodic_structure( data["atomic_numbers"], data["pos"], diff --git a/matsciml/datasets/utils.py b/matsciml/datasets/utils.py index b0410573..ba9986e3 100644 --- a/matsciml/datasets/utils.py +++ b/matsciml/datasets/utils.py @@ -644,7 +644,7 @@ def make_pymatgen_periodic_structure( "Unable to construct Lattice object without parameters:" f" Angles: {lat_angles}, ABC: {lat_abc}", ) - lattice = Lattice(*lat_abc, *lat_angles) + lattice = Lattice(*lat_abc, *lat_angles, vesta=True) structure = Structure( lattice, atomic_numbers, From a99f869a2a37a7b72f2d1dcee40d7a4893cd4ec8 Mon Sep 17 00:00:00 2001 From: "Lee, Kin Long Kelvin" Date: Thu, 10 Oct 2024 08:32:44 -0700 Subject: [PATCH 2/4] refactor: add case where serialized structure exists --- matsciml/datasets/transforms/pbc.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/matsciml/datasets/transforms/pbc.py b/matsciml/datasets/transforms/pbc.py index 14b827a2..15a441db 100644 --- a/matsciml/datasets/transforms/pbc.py +++ b/matsciml/datasets/transforms/pbc.py @@ -1,7 +1,7 @@ from __future__ import annotations import torch -from pymatgen.core import Lattice +from pymatgen.core 
import Lattice, Structure from matsciml.common.types import DataDict from matsciml.datasets.transforms.base import AbstractDataTransform @@ -37,6 +37,12 @@ def __init__(self, cutoff_radius: float, adaptive_cutoff: bool = False) -> None: def __call__(self, data: DataDict) -> DataDict: for key in ["atomic_numbers", "pos"]: assert key in data, f"{key} missing from data sample!" + if "structure" in data: + structure = data["structure"] + if isinstance(structure, Structure): + graph_props = calculate_periodic_shifts( + structure, self.cutoff_radius, self.adaptive_cutoff + ) if "cell" in data: # squeeze is used to make sure we remove empty dims lattice = Lattice(data["cell"].squeeze()) From 0dc94fdc97352fe33a48abc305053b4e0e3a6a5c Mon Sep 17 00:00:00 2001 From: "Lee, Kin Long Kelvin" Date: Thu, 10 Oct 2024 08:34:09 -0700 Subject: [PATCH 3/4] refactor: directly use structure if available for periodic properties --- matsciml/datasets/transforms/pbc.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/matsciml/datasets/transforms/pbc.py b/matsciml/datasets/transforms/pbc.py index 15a441db..53ca0142 100644 --- a/matsciml/datasets/transforms/pbc.py +++ b/matsciml/datasets/transforms/pbc.py @@ -37,12 +37,16 @@ def __init__(self, cutoff_radius: float, adaptive_cutoff: bool = False) -> None: def __call__(self, data: DataDict) -> DataDict: for key in ["atomic_numbers", "pos"]: assert key in data, f"{key} missing from data sample!" 
+ # if we have a pymatgen structure serialized already use it directly if "structure" in data: structure = data["structure"] if isinstance(structure, Structure): graph_props = calculate_periodic_shifts( structure, self.cutoff_radius, self.adaptive_cutoff ) + data.update(graph_props) + return data + # continue this branch if the structure doesn't qualify if "cell" in data: # squeeze is used to make sure we remove empty dims lattice = Lattice(data["cell"].squeeze()) From c37a84661e942d261a1bf11dbd10b81cfe0c681d Mon Sep 17 00:00:00 2001 From: "Lee, Kin Long Kelvin" Date: Thu, 10 Oct 2024 08:58:29 -0700 Subject: [PATCH 4/4] refactor: changing missing lattice parameters to runtime error --- matsciml/datasets/transforms/pbc.py | 38 ++++++++++++++++++++++++++++- 1 file changed, 37 insertions(+), 1 deletion(-) diff --git a/matsciml/datasets/transforms/pbc.py b/matsciml/datasets/transforms/pbc.py index 53ca0142..b8f9de0b 100644 --- a/matsciml/datasets/transforms/pbc.py +++ b/matsciml/datasets/transforms/pbc.py @@ -1,6 +1,7 @@ from __future__ import annotations import torch +import numpy as np from pymatgen.core import Lattice, Structure from matsciml.common.types import DataDict @@ -35,6 +36,38 @@ def __init__(self, cutoff_radius: float, adaptive_cutoff: bool = False) -> None: self.adaptive_cutoff = adaptive_cutoff def __call__(self, data: DataDict) -> DataDict: + """ + Given a data sample, generate graph edges with periodic boundary conditions + as specified by ``cutoff_radius`` and ``adaptive_cutoff``. + + This function has several nested conditions, depending on the availability + of data pertaining to periodic structures. First and foremost, if there + is a serialized ``pymatgen.core.Structure`` object, we will take that + directly and use it to compute the periodic shifts so as to minimize ambiguities. 
+ If there isn't one available, we then check for the presence of a ``cell`` + or lattice matrix, which we use to create a ``Lattice`` object that + is *then* used to create a ``pymatgen.core.Structure``. If a ``cell`` isn't + available, the final check is to look for keys related to lattice parameters, + and use those instead. + + Parameters + ---------- + data : DataDict + Data sample retrieved from a dataset. + + Returns + ------- + DataDict : DataDict + Data sample, now with updated key/values based on periodic + properties. See ``calculate_periodic_shifts`` for the additional + keys. + + Raises + ------ + RuntimeError: + If the final check for lattice parameters fails, there is nothing + we can base the periodic boundary calculation off of. + """ for key in ["atomic_numbers", "pos"]: assert key in data, f"{key} missing from data sample!" # if we have a pymatgen structure serialized already use it directly @@ -48,6 +81,9 @@ def __call__(self, data: DataDict) -> DataDict: return data # continue this branch if the structure doesn't qualify if "cell" in data: + assert isinstance( + data["cell"], (torch.Tensor, np.ndarray) + ), "Lattice matrix is not array-like." # squeeze is used to make sure we remove empty dims lattice = Lattice(data["cell"].squeeze()) else: @@ -58,7 +94,7 @@ def __call__(self, data: DataDict) -> DataDict: elif "lattice_params" in data: lattice_key = "lattice_params" else: - raise KeyError( + raise RuntimeError( "Data sample is missing lattice parameters. " "Ensure `lattice_features` or `lattice_params` is available" " in the data.",