Skip to content

Commit

Permalink
feat: adding test to compare pbc calculation from pymatgen and ase
Browse files Browse the repository at this point in the history
  • Loading branch information
melo-gonzo committed Oct 22, 2024
1 parent 4d5eb4c commit a0af875
Showing 1 changed file with 74 additions and 0 deletions.
74 changes: 74 additions & 0 deletions matsciml/datasets/tests/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,3 +197,77 @@ def test_ase_periodic(backend):
batch = next(iter(loader))
# check if periodic properties transform was applied
assert "unit_offsets" in batch


def test_pbc_backend_equivalence_easy():
from ase.build import molecule
from pymatgen.io.ase import AseAtomsAdaptor

atoms = molecule(
"H2O", cell=[[1, 0, 0], [0, 1, 0], [0, 0, 1]], pbc=(True, True, True)
)
structure = AseAtomsAdaptor.get_structure(atoms)

data = {}
coords = torch.from_numpy(structure.cart_coords).float()
data["pos"] = coords
atom_numbers = torch.LongTensor(structure.atomic_numbers)
data["atomic_numbers"] = atom_numbers
data["natoms"] = len(atom_numbers)
lattice_params = torch.FloatTensor(
structure.lattice.abc
+ tuple(a * (torch.pi / 180.0) for a in structure.lattice.angles),
)
lattice_features = {
"lattice_params": lattice_params,
}
data["lattice_features"] = lattice_features

ase_trans = transforms.PeriodicPropertiesTransform(
cutoff_radius=6.0, adaptive_cutoff=True, backend="ase"
)

pymatgen_trans = transforms.PeriodicPropertiesTransform(
cutoff_radius=6.0, adaptive_cutoff=True, backend="pymatgen"
)

ase_result = ase_trans(data)
pymatgen_result = pymatgen_trans(data)

ase_wiring = torch.vstack([ase_result["src_nodes"], ase_result["dst_nodes"]])
pymatgen_wiring = torch.vstack(
[pymatgen_result["src_nodes"], pymatgen_result["dst_nodes"]]
)
equivalence = ase_wiring == pymatgen_wiring
# basically checking if src -> dst node wiring is equivalent between the two approaches
assert torch.all(equivalence)


def test_pbc_backend_equivalence_hard():
ase_trans = transforms.PeriodicPropertiesTransform(
cutoff_radius=6.0, adaptive_cutoff=True, backend="ase"
)

pymatgen_trans = transforms.PeriodicPropertiesTransform(
cutoff_radius=6.0, adaptive_cutoff=True, backend="pymatgen"
)

dm = MatSciMLDataModule.from_devset(
"S2EFDataset",
batch_size=1,
)

dm.setup()
loader = dm.train_dataloader()
batch = next(iter(loader))
batch["atomic_numbers"] = batch["atomic_numbers"].squeeze(0)

ase_result = ase_trans(batch)
pymatgen_result = pymatgen_trans(batch)
ase_wiring = torch.vstack([ase_result["src_nodes"], ase_result["dst_nodes"]])
pymatgen_wiring = torch.vstack(
[pymatgen_result["src_nodes"], pymatgen_result["dst_nodes"]]
)
equivalence = ase_wiring == pymatgen_wiring
# basically checking if src -> dst node wiring is equivalent between the two approaches
assert torch.all(equivalence)

0 comments on commit a0af875

Please sign in to comment.