fix general_fitting, add xp type dict
Han Wang committed Oct 23, 2024
1 parent a261d67 commit b213475
Showing 3 changed files with 59 additions and 2 deletions.
27 changes: 27 additions & 0 deletions deepmd/dpmodel/common.py
@@ -50,6 +50,33 @@
DEFAULT_PRECISION = "float64"


def get_xp_precision(
xp: Any,
precision: str,
):
"""Get the precision from the API compatible namespace."""
if precision == "float16" or precision == "half":
return xp.float16
elif precision == "float32" or precision == "single":
return xp.float32
elif precision == "float64" or precision == "double":
return xp.float64
elif precision == "int32":
return xp.int32
elif precision == "int64":
return xp.int64
elif precision == "bool":
return bool
elif precision == "default":
return get_xp_precision(xp, RESERVED_PRECISON_DICT[PRECISION_DICT[precision]])
elif precision == "global":
return get_xp_precision(xp, RESERVED_PRECISON_DICT[GLOBAL_NP_FLOAT_PRECISION])
elif precision == "bfloat16":
return ml_dtypes.bfloat16
else:
raise ValueError(f"unsupported precision {precision} for {xp}")


class NativeOP(ABC):
"""The unit operation of a native model."""

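
For context, a minimal usage sketch of the new helper (not part of the commit), assuming numpy as the array backend reached through array_api_compat:

import array_api_compat
import numpy as np

from deepmd.dpmodel.common import get_xp_precision

x = np.zeros(3)
xp = array_api_compat.array_namespace(x)  # numpy-backed array-API namespace
print(get_xp_precision(xp, "single"))     # numpy.float32
print(get_xp_precision(xp, "double"))     # numpy.float64
print(get_xp_precision(xp, "global"))     # dtype behind GLOBAL_NP_FLOAT_PRECISION
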
7 changes: 5 additions & 2 deletions deepmd/dpmodel/fitting/general_fitting.py
Expand Up @@ -18,6 +18,7 @@
NativeOP,
)
from deepmd.dpmodel.common import (
+    get_xp_precision,
to_numpy_array,
)
from deepmd.dpmodel.utils import (
@@ -417,7 +418,9 @@ def _call_common(

# calculate the prediction
if not self.mixed_types:
-            outs = xp.zeros([nf, nloc, net_dim_out], dtype=self.prec)
+            outs = xp.zeros(
+                [nf, nloc, net_dim_out], dtype=get_xp_precision(xp, self.precision)
+            )
for type_i in range(self.ntypes):
mask = xp.tile(
xp.reshape((atype == type_i), [nf, nloc, 1]), (1, 1, net_dim_out)
@@ -443,4 +446,4 @@ def _call_common(
exclude_mask = self.emask.build_type_exclude_mask(atype)
# nf x nloc x nod
outs = outs * xp.astype(exclude_mask[:, :, None], outs.dtype)
-        return {self.var_name: outs.astype(GLOBAL_NP_FLOAT_PRECISION)}
+        return {self.var_name: xp.astype(outs, get_xp_precision(xp, "global"))}
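
The return-value change follows the array-API convention that dtype casts go through the namespace function xp.astype rather than a numpy-specific array method. A small illustrative sketch (not code from this commit), assuming numpy reached through array_api_compat:

import array_api_compat
import numpy as np

from deepmd.dpmodel.common import get_xp_precision

outs = np.zeros((1, 4, 1), dtype=np.float32)
xp = array_api_compat.array_namespace(outs)
# xp.astype works for any array-API-compliant backend; ndarray.astype does not.
outs = xp.astype(outs, get_xp_precision(xp, "global"))
print(outs.dtype)  # float64 when GLOBAL_NP_FLOAT_PRECISION is numpy.float64
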
27 changes: 27 additions & 0 deletions source/tests/common/test_common.py
@@ -0,0 +1,27 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import unittest

import array_api_compat
import ml_dtypes
import numpy as np

from deepmd.dpmodel.common import (
get_xp_precision,
)
from deepmd.env import (
GLOBAL_NP_FLOAT_PRECISION,
)


class TestGetXPPrecision(unittest.TestCase):
def test(self):
aa = np.zeros(3)
xp = array_api_compat.array_namespace(aa)
self.assertTrue(get_xp_precision(xp, "float16"), xp.float16)
self.assertTrue(get_xp_precision(xp, "float32"), xp.float32)
self.assertTrue(get_xp_precision(xp, "float64"), xp.float64)
self.assertTrue(get_xp_precision(xp, "single"), xp.float32)
self.assertTrue(get_xp_precision(xp, "double"), xp.float64)
self.assertTrue(get_xp_precision(xp, "global"), GLOBAL_NP_FLOAT_PRECISION)
self.assertTrue(get_xp_precision(xp, "default"), GLOBAL_NP_FLOAT_PRECISION)
self.assertTrue(get_xp_precision(xp, "bfloat16"), ml_dtypes.bfloat16)
