Skip to content

Commit

Permalink
Allow custom mace model by specifying "model" in calculator kwargs" (#…
Browse files Browse the repository at this point in the history
…1017)

* allow custom mace model by specifying "model" in calculator kwargs"

* fix error in trying to turn None into path

* Add support for ORB model

* Specify more dependencies

* remove orb implementation

* add line

* add line

* remove device

* set device

* fix set device

* fix test

* fix linting

* restore test

* remove os and rely on pathlib only

---------

Co-authored-by: J. George <[email protected]>
Co-authored-by: JaGeo <[email protected]>
  • Loading branch information
3 people authored Oct 31, 2024
1 parent 42bc7b8 commit 9c4447c
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 5 deletions.
19 changes: 16 additions & 3 deletions src/atomate2/forcefields/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import json
from contextlib import contextmanager
from pathlib import Path
from typing import TYPE_CHECKING

from monty.json import MontyDecoder
Expand Down Expand Up @@ -59,9 +60,21 @@ def ase_calculator(calculator_meta: str | dict, **kwargs: Any) -> Calculator | N
calculator = PESCalculator(potential, **kwargs)

elif calculator_name == MLFF.MACE:
from mace.calculators import mace_mp

calculator = mace_mp(**kwargs)
from mace.calculators import MACECalculator, mace_mp

model = kwargs.get("model")
if isinstance(model, str | Path) and Path(model).exists():
model_path = model
device = kwargs.get("device") or "cpu"
if "device" in kwargs:
del kwargs["device"]
calculator = MACECalculator(
model_paths=model_path,
device=device,
**kwargs,
)
else:
calculator = mace_mp(**kwargs)

elif calculator_name == MLFF.GAP:
from quippy.potential import Potential
Expand Down
4 changes: 2 additions & 2 deletions tests/forcefields/test_jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,7 @@ def test_mace_relax_maker(
# NOTE the test model is not trained on Si, so the energy is not accurate
job = ForceFieldRelaxMaker(
force_field_name="MACE",
calculator_kwargs={"model": model},
calculator_kwargs={"model": model, "default_dtype": "float32"},
steps=25,
optimizer_kwargs={"optimizer": "BFGSLineSearch"},
relax_cell=relax_cell,
Expand Down Expand Up @@ -308,7 +308,7 @@ def test_mace_relax_maker(

if fix_symmetry: # if symmetry is fixed, the symmetry should be the same or higher
assert is_subgroup(symmetry_ops_init, symmetry_ops_final)
else: # if symmetry is not fixed, it can both increase or decrease
else: # if symmetry is not fixed, it can both increase or decrease or stay the same
assert not is_subgroup(symmetry_ops_init, symmetry_ops_final)

if relax_cell:
Expand Down

0 comments on commit 9c4447c

Please sign in to comment.