From 9c4447c9562bc11777ed15c692cb272581c36ebe Mon Sep 17 00:00:00 2001 From: Orion Cohen <27712051+orionarcher@users.noreply.github.com> Date: Thu, 31 Oct 2024 04:10:06 -0400 Subject: [PATCH] Allow custom mace model by specifying "model" in calculator kwargs" (#1017) * allow custom mace model by specifying "model" in calculator kwargs" * fix error in trying to turn None into path * Add support for ORB model * Specify more dependencies * remove orb implementation * add line * add line * remove device * set device * fix set device * fix test * fix linting * restore test * remove os and rely on pathlib only --------- Co-authored-by: J. George Co-authored-by: JaGeo --- src/atomate2/forcefields/utils.py | 19 ++++++++++++++++--- tests/forcefields/test_jobs.py | 4 ++-- 2 files changed, 18 insertions(+), 5 deletions(-) diff --git a/src/atomate2/forcefields/utils.py b/src/atomate2/forcefields/utils.py index 83d4280177..d8125421b9 100644 --- a/src/atomate2/forcefields/utils.py +++ b/src/atomate2/forcefields/utils.py @@ -4,6 +4,7 @@ import json from contextlib import contextmanager +from pathlib import Path from typing import TYPE_CHECKING from monty.json import MontyDecoder @@ -59,9 +60,21 @@ def ase_calculator(calculator_meta: str | dict, **kwargs: Any) -> Calculator | N calculator = PESCalculator(potential, **kwargs) elif calculator_name == MLFF.MACE: - from mace.calculators import mace_mp - - calculator = mace_mp(**kwargs) + from mace.calculators import MACECalculator, mace_mp + + model = kwargs.get("model") + if isinstance(model, str | Path) and Path(model).exists(): + model_path = model + device = kwargs.get("device") or "cpu" + if "device" in kwargs: + del kwargs["device"] + calculator = MACECalculator( + model_paths=model_path, + device=device, + **kwargs, + ) + else: + calculator = mace_mp(**kwargs) elif calculator_name == MLFF.GAP: from quippy.potential import Potential diff --git a/tests/forcefields/test_jobs.py b/tests/forcefields/test_jobs.py index 779310aef2..0dbb765311 100644 --- a/tests/forcefields/test_jobs.py +++ b/tests/forcefields/test_jobs.py @@ -277,7 +277,7 @@ def test_mace_relax_maker( # NOTE the test model is not trained on Si, so the energy is not accurate job = ForceFieldRelaxMaker( force_field_name="MACE", - calculator_kwargs={"model": model}, + calculator_kwargs={"model": model, "default_dtype": "float32"}, steps=25, optimizer_kwargs={"optimizer": "BFGSLineSearch"}, relax_cell=relax_cell, @@ -308,7 +308,7 @@ def test_mace_relax_maker( if fix_symmetry: # if symmetry is fixed, the symmetry should be the same or higher assert is_subgroup(symmetry_ops_init, symmetry_ops_final) - else: # if symmetry is not fixed, it can both increase or decrease + else: # if symmetry is not fixed, it can both increase or decrease or stay the same assert not is_subgroup(symmetry_ops_init, symmetry_ops_final) if relax_cell: