Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(tf/pt): add/refact lammps support for spin model #4216

Draft
wants to merge 27 commits into
base: devel
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 26 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
f34cbe1
feat(pt/tf): support spin lammps plugin
iProzd Sep 21, 2024
d5b544b
update typo
iProzd Sep 21, 2024
dd331fd
update pt backend
iProzd Sep 22, 2024
31bafb1
rm extend from pair-deepmd
iProzd Sep 22, 2024
15150f6
fix tf interface for spin
hztttt Sep 23, 2024
bdfe205
fix interface for multi model
hztttt Sep 23, 2024
be59313
support spin_norm & virtual_len in model graph and fix bug
hztttt Sep 25, 2024
ec7c16b
fix pt
iProzd Sep 28, 2024
6524e5e
Update pair_deepmd.cpp
iProzd Sep 28, 2024
2c66443
fix tensorflow bug
iProzd Oct 14, 2024
4f3d9d4
fix mag force bug
iProzd Oct 14, 2024
d24d7e7
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 15, 2024
593bf81
Update c_api.h
iProzd Oct 15, 2024
3466e34
Update c_api.h
iProzd Oct 15, 2024
c3a4f3e
extend sendlist nlist and other tensors but still bugs
CaRoLZhangxy Oct 18, 2024
e2e1e55
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 18, 2024
cf85275
revert `extend sendlist nlist`
iProzd Oct 21, 2024
1d6defe
fix spin communication in lammps
iProzd Oct 21, 2024
2a38025
Merge branch 'devel' into spin_lmp
iProzd Oct 21, 2024
e5c0ecf
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 21, 2024
85c934b
Update spin_model.py
iProzd Oct 22, 2024
35fd1c6
Update spin.py
iProzd Oct 22, 2024
11aeb17
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 22, 2024
6c5cb1d
add ut for spin c++
iProzd Oct 22, 2024
474a2b0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 22, 2024
114898f
bump version
iProzd Oct 22, 2024
fef13f5
Merge branch 'devel' into spin_lmp
iProzd Oct 26, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 32 additions & 6 deletions deepmd/pt/model/descriptor/repformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@
from deepmd.pt.utils.exclude_mask import (
PairExcludeMask,
)
from deepmd.pt.utils.spin import (
concat_switch_virtual,
)
from deepmd.pt.utils.utils import (
ActivationFn,
)
Expand Down Expand Up @@ -422,6 +425,7 @@ def forward(
atype_embd = extended_atype_embd
assert isinstance(atype_embd, torch.Tensor) # for jit
g1 = self.act(atype_embd)
ng1 = g1.shape[-1]
# nb x nloc x nnei x 1, nb x nloc x nnei x 3
if not self.direct_dist:
g2, h2 = torch.split(dmatrix, [1, 3], dim=-1)
Expand All @@ -448,10 +452,27 @@ def forward(
assert mapping is not None
g1_ext = torch.gather(g1, 1, mapping)
else:
n_padding = nall - nloc
g1 = torch.nn.functional.pad(
g1.squeeze(0), (0, 0, 0, n_padding), value=0.0
)
has_spin = "has_spin" in comm_dict
if not has_spin:
n_padding = nall - nloc
g1 = torch.nn.functional.pad(
g1.squeeze(0), (0, 0, 0, n_padding), value=0.0
)
real_nloc = nloc
real_nall = nall
else:
# for spin
real_nloc = nloc // 2
real_nall = nall // 2
real_n_padding = real_nall - real_nloc
g1_real, g1_virtual = torch.split(g1, [real_nloc, real_nloc], dim=1)
# mix_g1: nb x real_nloc x (ng1 * 2)
mix_g1 = torch.cat([g1_real, g1_virtual], dim=2)
# nb x real_nall x (ng1 * 2)
g1 = torch.nn.functional.pad(
mix_g1.squeeze(0), (0, 0, 0, real_n_padding), value=0.0
)
iProzd marked this conversation as resolved.
Show resolved Hide resolved

iProzd marked this conversation as resolved.
Show resolved Hide resolved
assert "send_list" in comm_dict
assert "send_proc" in comm_dict
assert "recv_proc" in comm_dict
Expand All @@ -466,10 +487,15 @@ def forward(
comm_dict["recv_num"],
g1,
comm_dict["communicator"],
torch.tensor(nloc), # pylint: disable=no-explicit-dtype,no-explicit-device
torch.tensor(nall - nloc), # pylint: disable=no-explicit-dtype,no-explicit-device
torch.tensor(real_nloc), # pylint: disable=no-explicit-dtype,no-explicit-device
torch.tensor(real_nall - real_nloc), # pylint: disable=no-explicit-dtype,no-explicit-device
)
g1_ext = ret[0].unsqueeze(0)
if has_spin:
g1_real_ext, g1_virtual_ext = torch.split(g1_ext, [ng1, ng1], dim=2)
g1_ext = concat_switch_virtual(
g1_real_ext, g1_virtual_ext, real_nloc
)
g1, g2, h2 = ll.forward(
g1_ext,
g2,
Expand Down
40 changes: 10 additions & 30 deletions deepmd/pt/model/model/spin_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@
from deepmd.pt.model.atomic_model import (
DPAtomicModel,
)
from deepmd.pt.utils.spin import (
concat_switch_virtual,
)
from deepmd.pt.utils.utils import (
to_torch_tensor,
)
Expand Down Expand Up @@ -79,15 +82,15 @@ def process_spin_input_lower(
self.virtual_scale_mask.to(extended_atype.device)
)[extended_atype].reshape([nframes, nall, 1])
virtual_extended_atype = extended_atype + self.ntypes_real
extended_coord_updated = self.concat_switch_virtual(
extended_coord_updated = concat_switch_virtual(
extended_coord, virtual_extended_coord, nloc
)
extended_atype_updated = self.concat_switch_virtual(
extended_atype_updated = concat_switch_virtual(
extended_atype, virtual_extended_atype, nloc
)
if mapping is not None:
virtual_mapping = mapping + nloc
mapping_updated = self.concat_switch_virtual(mapping, virtual_mapping, nloc)
mapping_updated = concat_switch_virtual(mapping, virtual_mapping, nloc)
else:
mapping_updated = None
# extend the nlist
Expand Down Expand Up @@ -203,33 +206,6 @@ def extend_nlist(extended_atype, nlist):
extended_nlist[second_part_index] -= nall - nloc
return extended_nlist

@staticmethod
def concat_switch_virtual(extended_tensor, extended_tensor_virtual, nloc: int):
"""
Concat real and virtual extended tensors, and switch all the local ones to the first nloc * 2 atoms.
- [:, :nloc]: original nloc real atoms.
- [:, nloc: nloc + nloc]: virtual atoms corresponding to nloc real atoms.
- [:, nloc + nloc: nloc + nall]: ghost real atoms.
- [:, nloc + nall: nall + nall]: virtual atoms corresponding to ghost real atoms.
"""
nframes, nall = extended_tensor.shape[:2]
out_shape = list(extended_tensor.shape)
out_shape[1] *= 2
extended_tensor_updated = torch.zeros(
out_shape,
dtype=extended_tensor.dtype,
device=extended_tensor.device,
)
extended_tensor_updated[:, :nloc] = extended_tensor[:, :nloc]
extended_tensor_updated[:, nloc : nloc + nloc] = extended_tensor_virtual[
:, :nloc
]
extended_tensor_updated[:, nloc + nloc : nloc + nall] = extended_tensor[
:, nloc:
]
extended_tensor_updated[:, nloc + nall :] = extended_tensor_virtual[:, nloc:]
return extended_tensor_updated.view(out_shape)

@staticmethod
def expand_aparam(aparam, nloc: int):
"""Expand the atom parameters for virtual atoms if necessary."""
Expand Down Expand Up @@ -469,6 +445,7 @@ def forward_common_lower(
fparam: Optional[torch.Tensor] = None,
aparam: Optional[torch.Tensor] = None,
do_atomic_virial: bool = False,
comm_dict: Optional[dict[str, torch.Tensor]] = None,
extra_nlist_sort: bool = False,
):
nframes, nloc = nlist.shape[:2]
Expand All @@ -490,6 +467,7 @@ def forward_common_lower(
fparam=fparam,
aparam=aparam,
do_atomic_virial=do_atomic_virial,
comm_dict=comm_dict,
extra_nlist_sort=extra_nlist_sort,
)
model_output_type = self.backbone_model.model_output_type()
Expand Down Expand Up @@ -605,6 +583,7 @@ def forward_lower(
fparam: Optional[torch.Tensor] = None,
aparam: Optional[torch.Tensor] = None,
do_atomic_virial: bool = False,
comm_dict: Optional[dict[str, torch.Tensor]] = None,
):
model_ret = self.forward_common_lower(
extended_coord,
Expand All @@ -615,6 +594,7 @@ def forward_lower(
fparam=fparam,
aparam=aparam,
do_atomic_virial=do_atomic_virial,
comm_dict=comm_dict,
extra_nlist_sort=self.backbone_model.need_sorted_nlist_for_lower(),
)
model_predict = {}
Expand Down
30 changes: 30 additions & 0 deletions deepmd/pt/utils/spin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# SPDX-License-Identifier: LGPL-3.0-or-later

import torch


def concat_switch_virtual(
    extended_tensor,
    extended_tensor_virtual,
    nloc: int,
):
    """Concat real and virtual extended tensors, and switch all the local ones to the first nloc * 2 atoms.

    The result along the atom axis (dim 1) is laid out as:
    - [:, :nloc]: original nloc real atoms.
    - [:, nloc: nloc + nloc]: virtual atoms corresponding to nloc real atoms.
    - [:, nloc + nloc: nloc + nall]: ghost real atoms.
    - [:, nloc + nall: nall + nall]: virtual atoms corresponding to ghost real atoms.

    Parameters
    ----------
    extended_tensor
        Tensor over the real atoms, shape [nframes, nall, ...].
    extended_tensor_virtual
        Tensor over the virtual atoms; must have the same shape as
        `extended_tensor`.
    nloc : int
        Number of local (non-ghost) atoms; the first `nloc` entries of
        dim 1 are the local atoms, the rest are ghosts.

    Returns
    -------
    torch.Tensor
        Tensor of shape [nframes, nall * 2, ...] with the layout above,
        on the same device and with the same dtype as the inputs.
    """
    # A single concatenation along the atom axis produces the target layout
    # directly. This replaces the original zero-filled allocation plus four
    # slice assignments, drops the unused `nframes` local, and removes the
    # trailing `.view(out_shape)`, which was a no-op (the tensor already had
    # exactly that shape).
    return torch.cat(
        [
            extended_tensor[:, :nloc],
            extended_tensor_virtual[:, :nloc],
            extended_tensor[:, nloc:],
            extended_tensor_virtual[:, nloc:],
        ],
        dim=1,
    )
4 changes: 4 additions & 0 deletions deepmd/tf/entrypoints/freeze.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,8 @@ def _make_node_names(
"o_atom_energy",
"o_atom_virial",
"spin_attr/ntypes_spin",
"spin_attr/virtual_len",
"spin_attr/spin_norm",
"fitting_attr/dfparam",
"fitting_attr/daparam",
"fitting_attr/aparam_nall",
Expand Down Expand Up @@ -258,6 +260,8 @@ def freeze_graph(
"train_attr/min_nbor_dist",
"fitting_attr/aparam_nall",
"spin_attr/ntypes_spin",
"spin_attr/virtual_len",
"spin_attr/spin_norm",
]
different_set = set(output_node) - set(input_node)
if different_set:
Expand Down
Loading
Loading