Skip to content

Commit

Permalink
Merge branch 'devel' into torchfix
Browse files Browse the repository at this point in the history
  • Loading branch information
njzjz authored Oct 23, 2024
2 parents b096290 + 911f41b commit b7f86c9
Show file tree
Hide file tree
Showing 33 changed files with 323 additions and 50 deletions.
6 changes: 3 additions & 3 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ repos:
exclude: ^source/3rdparty
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.6.9
rev: v0.7.0
hooks:
- id: ruff
args: ["--fix"]
Expand Down Expand Up @@ -60,7 +60,7 @@ repos:
- id: blacken-docs
# C++
- repo: https://github.com/pre-commit/mirrors-clang-format
rev: v19.1.1
rev: v19.1.2
hooks:
- id: clang-format
exclude: ^(source/3rdparty|source/lib/src/gpu/cudart/.+\.inc|.+\.ipynb$)
Expand All @@ -74,7 +74,7 @@ repos:
exclude: ^(source/3rdparty|\.github/workflows|\.clang-format)
# Shell
- repo: https://github.com/scop/pre-commit-shfmt
rev: v3.9.0-1
rev: v3.10.0-1
hooks:
- id: shfmt
# CMake
Expand Down
20 changes: 12 additions & 8 deletions deepmd/dpmodel/atomic_model/base_atomic_model.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import copy
import math
from typing import (
Optional,
)

import array_api_compat
import numpy as np

from deepmd.dpmodel.common import (
NativeOP,
to_numpy_array,
)
from deepmd.dpmodel.output_def import (
FittingOutputDef,
Expand Down Expand Up @@ -172,17 +174,18 @@ def forward_common_atomic(
ret_dict["mask"][ff,ii] == 0 indicating the ii-th atom of the ff-th frame is virtual.
"""
xp = array_api_compat.array_namespace(extended_coord, extended_atype, nlist)
_, nloc, _ = nlist.shape
atype = extended_atype[:, :nloc]
if self.pair_excl is not None:
pair_mask = self.pair_excl.build_type_exclude_mask(nlist, extended_atype)
# exclude neighbors in the nlist
nlist = np.where(pair_mask == 1, nlist, -1)
nlist = xp.where(pair_mask == 1, nlist, -1)

ext_atom_mask = self.make_atom_mask(extended_atype)
ret_dict = self.forward_atomic(
extended_coord,
np.where(ext_atom_mask, extended_atype, 0),
xp.where(ext_atom_mask, extended_atype, 0),
nlist,
mapping=mapping,
fparam=fparam,
Expand All @@ -191,13 +194,13 @@ def forward_common_atomic(
ret_dict = self.apply_out_stat(ret_dict, atype)

# nf x nloc
atom_mask = ext_atom_mask[:, :nloc].astype(np.int32)
atom_mask = ext_atom_mask[:, :nloc].astype(xp.int32)
if self.atom_excl is not None:
atom_mask *= self.atom_excl.build_type_exclude_mask(atype)

for kk in ret_dict.keys():
out_shape = ret_dict[kk].shape
out_shape2 = np.prod(out_shape[2:])
out_shape2 = math.prod(out_shape[2:])
ret_dict[kk] = (
ret_dict[kk].reshape([out_shape[0], out_shape[1], out_shape2])
* atom_mask[:, :, None]
Expand Down Expand Up @@ -232,14 +235,15 @@ def serialize(self) -> dict:
"rcond": self.rcond,
"preset_out_bias": self.preset_out_bias,
"@variables": {
"out_bias": self.out_bias,
"out_std": self.out_std,
"out_bias": to_numpy_array(self.out_bias),
"out_std": to_numpy_array(self.out_std),
},
}

@classmethod
def deserialize(cls, data: dict) -> "BaseAtomicModel":
data = copy.deepcopy(data)
# do not deep copy Descriptor and Fitting class
data = data.copy()
variables = data.pop("@variables")
obj = cls(**data)
for kk in variables.keys():
Expand Down
10 changes: 8 additions & 2 deletions deepmd/dpmodel/atomic_model/dp_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,14 +169,20 @@ def serialize(self) -> dict:
)
return dd

# for subclass overriden
base_descriptor_cls = BaseDescriptor
"""The base descriptor class."""
base_fitting_cls = BaseFitting
"""The base fitting class."""

@classmethod
def deserialize(cls, data) -> "DPAtomicModel":
data = copy.deepcopy(data)
check_version_compatibility(data.pop("@version", 1), 2, 2)
data.pop("@class")
data.pop("type")
descriptor_obj = BaseDescriptor.deserialize(data.pop("descriptor"))
fitting_obj = BaseFitting.deserialize(data.pop("fitting"))
descriptor_obj = cls.base_descriptor_cls.deserialize(data.pop("descriptor"))
fitting_obj = cls.base_fitting_cls.deserialize(data.pop("fitting"))
data["descriptor"] = descriptor_obj
data["fitting"] = fitting_obj
obj = super().deserialize(data)
Expand Down
35 changes: 16 additions & 19 deletions deepmd/dpmodel/model/make_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
Optional,
)

import array_api_compat
import numpy as np

from deepmd.dpmodel.atomic_model.base_atomic_model import (
Expand Down Expand Up @@ -75,7 +76,8 @@ def __init__(
else:
self.atomic_model: T_AtomicModel = T_AtomicModel(*args, **kwargs)
self.precision_dict = PRECISION_DICT
self.reverse_precision_dict = RESERVED_PRECISON_DICT
# not supported by flax
# self.reverse_precision_dict = RESERVED_PRECISON_DICT
self.global_np_float_precision = GLOBAL_NP_FLOAT_PRECISION
self.global_ener_float_precision = GLOBAL_ENER_FLOAT_PRECISION

Expand Down Expand Up @@ -253,9 +255,7 @@ def input_type_cast(
str,
]:
"""Cast the input data to global float type."""
input_prec = self.reverse_precision_dict[
self.precision_dict[coord.dtype.name]
]
input_prec = RESERVED_PRECISON_DICT[self.precision_dict[coord.dtype.name]]
###
### type checking would not pass jit, convert to coord prec anyway
###
Expand All @@ -264,10 +264,7 @@ def input_type_cast(
for vv in [box, fparam, aparam]
]
box, fparam, aparam = _lst
if (
input_prec
== self.reverse_precision_dict[self.global_np_float_precision]
):
if input_prec == RESERVED_PRECISON_DICT[self.global_np_float_precision]:
return coord, box, fparam, aparam, input_prec
else:
pp = self.global_np_float_precision
Expand All @@ -286,8 +283,7 @@ def output_type_cast(
) -> dict[str, np.ndarray]:
"""Convert the model output to the input prec."""
do_cast = (
input_prec
!= self.reverse_precision_dict[self.global_np_float_precision]
input_prec != RESERVED_PRECISON_DICT[self.global_np_float_precision]
)
pp = self.precision_dict[input_prec]
odef = self.model_output_def()
Expand Down Expand Up @@ -366,17 +362,18 @@ def _format_nlist(
nnei: int,
extra_nlist_sort: bool = False,
):
xp = array_api_compat.array_namespace(extended_coord, nlist)
n_nf, n_nloc, n_nnei = nlist.shape
extended_coord = extended_coord.reshape([n_nf, -1, 3])
nall = extended_coord.shape[1]
rcut = self.get_rcut()

if n_nnei < nnei:
# make a copy before revise
ret = np.concatenate(
ret = xp.concat(
[
nlist,
-1 * np.ones([n_nf, n_nloc, nnei - n_nnei], dtype=nlist.dtype),
-1 * xp.ones([n_nf, n_nloc, nnei - n_nnei], dtype=nlist.dtype),
],
axis=-1,
)
Expand All @@ -385,16 +382,16 @@ def _format_nlist(
n_nf, n_nloc, n_nnei = nlist.shape
# make a copy before revise
m_real_nei = nlist >= 0
ret = np.where(m_real_nei, nlist, 0)
ret = xp.where(m_real_nei, nlist, 0)
coord0 = extended_coord[:, :n_nloc, :]
index = ret.reshape(n_nf, n_nloc * n_nnei, 1).repeat(3, axis=2)
coord1 = np.take_along_axis(extended_coord, index, axis=1)
coord1 = xp.take_along_axis(extended_coord, index, axis=1)
coord1 = coord1.reshape(n_nf, n_nloc, n_nnei, 3)
rr = np.linalg.norm(coord0[:, :, None, :] - coord1, axis=-1)
rr = np.where(m_real_nei, rr, float("inf"))
rr, ret_mapping = np.sort(rr, axis=-1), np.argsort(rr, axis=-1)
ret = np.take_along_axis(ret, ret_mapping, axis=2)
ret = np.where(rr > rcut, -1, ret)
rr = xp.linalg.norm(coord0[:, :, None, :] - coord1, axis=-1)
rr = xp.where(m_real_nei, rr, float("inf"))
rr, ret_mapping = xp.sort(rr, axis=-1), xp.argsort(rr, axis=-1)
ret = xp.take_along_axis(ret, ret_mapping, axis=2)
ret = xp.where(rr > rcut, -1, ret)
ret = ret[..., :nnei]
# not extra_nlist_sort and n_nnei <= nnei:
elif n_nnei == nnei:
Expand Down
4 changes: 3 additions & 1 deletion deepmd/dpmodel/model/transform_output.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# SPDX-License-Identifier: LGPL-3.0-or-later

import array_api_compat
import numpy as np

from deepmd.dpmodel.common import (
Expand All @@ -23,6 +24,7 @@ def fit_output_to_model_output(
the model output.
"""
xp = array_api_compat.get_namespace(coord_ext)
model_ret = dict(fit_ret.items())
for kk, vv in fit_ret.items():
vdef = fit_output_def[kk]
Expand All @@ -31,7 +33,7 @@ def fit_output_to_model_output(
if vdef.reducible:
kk_redu = get_reduce_name(kk)
# cast to energy prec brefore reduction
model_ret[kk_redu] = np.sum(
model_ret[kk_redu] = xp.sum(
vv.astype(GLOBAL_ENER_FLOAT_PRECISION), axis=atom_axis
)
if vdef.r_differentiable:
Expand Down
1 change: 1 addition & 0 deletions deepmd/jax/atomic_model/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
18 changes: 18 additions & 0 deletions deepmd/jax/atomic_model/base_atomic_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from deepmd.jax.common import (
to_jax_array,
)
from deepmd.jax.utils.exclude_mask import (
AtomExcludeMask,
PairExcludeMask,
)


def base_atomic_model_set_attr(name, value):
if name in {"out_bias", "out_std"}:
value = to_jax_array(value)
elif name == "pair_excl" and value is not None:
value = PairExcludeMask(value.ntypes, value.exclude_types)
elif name == "atom_excl" and value is not None:
value = AtomExcludeMask(value.ntypes, value.exclude_types)
return value
30 changes: 30 additions & 0 deletions deepmd/jax/atomic_model/dp_atomic_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
Any,
)

from deepmd.dpmodel.atomic_model.dp_atomic_model import DPAtomicModel as DPAtomicModelDP
from deepmd.jax.atomic_model.base_atomic_model import (
base_atomic_model_set_attr,
)
from deepmd.jax.common import (
flax_module,
)
from deepmd.jax.descriptor.base_descriptor import (
BaseDescriptor,
)
from deepmd.jax.fitting.base_fitting import (
BaseFitting,
)


@flax_module
class DPAtomicModel(DPAtomicModelDP):
base_descriptor_cls = BaseDescriptor
"""The base descriptor class."""
base_fitting_cls = BaseFitting
"""The base fitting class."""

def __setattr__(self, name: str, value: Any) -> None:
value = base_atomic_model_set_attr(name, value)
return super().__setattr__(name, value)
11 changes: 11 additions & 0 deletions deepmd/jax/descriptor/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,12 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from deepmd.jax.descriptor.dpa1 import (
DescrptDPA1,
)
from deepmd.jax.descriptor.se_e2_a import (
DescrptSeA,
)

__all__ = [
"DescrptSeA",
"DescrptDPA1",
]
9 changes: 9 additions & 0 deletions deepmd/jax/descriptor/base_descriptor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from deepmd.dpmodel.descriptor.make_base_descriptor import (
make_base_descriptor,
)
from deepmd.jax.env import (
jnp,
)

BaseDescriptor = make_base_descriptor(jnp.ndarray)
5 changes: 5 additions & 0 deletions deepmd/jax/descriptor/dpa1.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@
flax_module,
to_jax_array,
)
from deepmd.jax.descriptor.base_descriptor import (
BaseDescriptor,
)
from deepmd.jax.utils.exclude_mask import (
PairExcludeMask,
)
Expand Down Expand Up @@ -76,6 +79,8 @@ def __setattr__(self, name: str, value: Any) -> None:
return super().__setattr__(name, value)


@BaseDescriptor.register("dpa1")
@BaseDescriptor.register("se_atten")
@flax_module
class DescrptDPA1(DescrptDPA1DP):
def __setattr__(self, name: str, value: Any) -> None:
Expand Down
5 changes: 5 additions & 0 deletions deepmd/jax/descriptor/se_e2_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@
flax_module,
to_jax_array,
)
from deepmd.jax.descriptor.base_descriptor import (
BaseDescriptor,
)
from deepmd.jax.utils.exclude_mask import (
PairExcludeMask,
)
Expand All @@ -16,6 +19,8 @@
)


@BaseDescriptor.register("se_e2_a")
@BaseDescriptor.register("se_a")
@flax_module
class DescrptSeA(DescrptSeADP):
def __setattr__(self, name: str, value: Any) -> None:
Expand Down
9 changes: 9 additions & 0 deletions deepmd/jax/fitting/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,10 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from deepmd.jax.fitting.fitting import (
DOSFittingNet,
EnergyFittingNet,
)

__all__ = [
"EnergyFittingNet",
"DOSFittingNet",
]
9 changes: 9 additions & 0 deletions deepmd/jax/fitting/base_fitting.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from deepmd.dpmodel.fitting.make_base_fitting import (
make_base_fitting,
)
from deepmd.jax.env import (
jnp,
)

BaseFitting = make_base_fitting(jnp.ndarray)
Loading

0 comments on commit b7f86c9

Please sign in to comment.