diff --git a/deepmd/dpmodel/common.py b/deepmd/dpmodel/common.py index d9d57d2d6c..5c75229e49 100644 --- a/deepmd/dpmodel/common.py +++ b/deepmd/dpmodel/common.py @@ -50,6 +50,33 @@ DEFAULT_PRECISION = "float64" +def get_xp_precision( + xp: Any, + precision: str, +): + """Get the precision from the API compatible namespace.""" + if precision == "float16" or precision == "half": + return xp.float16 + elif precision == "float32" or precision == "single": + return xp.float32 + elif precision == "float64" or precision == "double": + return xp.float64 + elif precision == "int32": + return xp.int32 + elif precision == "int64": + return xp.int64 + elif precision == "bool": + return bool + elif precision == "default": + return get_xp_precision(xp, RESERVED_PRECISON_DICT[PRECISION_DICT[precision]]) + elif precision == "global": + return get_xp_precision(xp, RESERVED_PRECISON_DICT[GLOBAL_NP_FLOAT_PRECISION]) + elif precision == "bfloat16": + return ml_dtypes.bfloat16 + else: + raise ValueError(f"unsupported precision {precision} for {xp}") + + class NativeOP(ABC): """The unit operation of a native model.""" diff --git a/deepmd/dpmodel/fitting/general_fitting.py b/deepmd/dpmodel/fitting/general_fitting.py index 8e9e228787..016f947206 100644 --- a/deepmd/dpmodel/fitting/general_fitting.py +++ b/deepmd/dpmodel/fitting/general_fitting.py @@ -18,6 +18,7 @@ NativeOP, ) from deepmd.dpmodel.common import ( + get_xp_precision, to_numpy_array, ) from deepmd.dpmodel.utils import ( @@ -417,7 +418,9 @@ def _call_common( # calcualte the prediction if not self.mixed_types: - outs = xp.zeros([nf, nloc, net_dim_out], dtype=self.prec) + outs = xp.zeros( + [nf, nloc, net_dim_out], dtype=get_xp_precision(xp, self.precision) + ) for type_i in range(self.ntypes): mask = xp.tile( xp.reshape((atype == type_i), [nf, nloc, 1]), (1, 1, net_dim_out) @@ -443,4 +446,4 @@ def _call_common( exclude_mask = self.emask.build_type_exclude_mask(atype) # nf x nloc x nod outs = outs * xp.astype(exclude_mask[:, :, None], outs.dtype) - return {self.var_name: outs.astype(GLOBAL_NP_FLOAT_PRECISION)} + return {self.var_name: xp.astype(outs, get_xp_precision(xp, "global"))} diff --git a/source/tests/common/test_common.py b/source/tests/common/test_common.py new file mode 100644 index 0000000000..478c512fed --- /dev/null +++ b/source/tests/common/test_common.py @@ -0,0 +1,27 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import unittest + +import array_api_compat +import ml_dtypes +import numpy as np + +from deepmd.dpmodel.common import ( + get_xp_precision, +) +from deepmd.env import ( + GLOBAL_NP_FLOAT_PRECISION, +) + + +class TestGetXPPrecision(unittest.TestCase): + def test(self): + aa = np.zeros(3) + xp = array_api_compat.array_namespace(aa) + self.assertTrue(get_xp_precision(xp, "float16"), xp.float16) + self.assertTrue(get_xp_precision(xp, "float32"), xp.float32) + self.assertTrue(get_xp_precision(xp, "float64"), xp.float64) + self.assertTrue(get_xp_precision(xp, "single"), xp.float32) + self.assertTrue(get_xp_precision(xp, "double"), xp.float64) + self.assertTrue(get_xp_precision(xp, "global"), GLOBAL_NP_FLOAT_PRECISION) + self.assertTrue(get_xp_precision(xp, "default"), GLOBAL_NP_FLOAT_PRECISION) + self.assertTrue(get_xp_precision(xp, "bfloat16"), ml_dtypes.bfloat16)