Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(tf/pt): add/refact lammps support for spin model #4216

Draft
wants to merge 27 commits into
base: devel
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 14 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
f34cbe1
feat(pt/tf): support spin lammps plugin
iProzd Sep 21, 2024
d5b544b
update typo
iProzd Sep 21, 2024
dd331fd
update pt backend
iProzd Sep 22, 2024
31bafb1
rm extend from pair-deepmd
iProzd Sep 22, 2024
15150f6
fix tf interface for spin
hztttt Sep 23, 2024
bdfe205
fix interface for multi model
hztttt Sep 23, 2024
be59313
support spin_norm & virtual_len in model graph and fix bug
hztttt Sep 25, 2024
ec7c16b
fix pt
iProzd Sep 28, 2024
6524e5e
Update pair_deepmd.cpp
iProzd Sep 28, 2024
2c66443
fix tensorflow bug
iProzd Oct 14, 2024
4f3d9d4
fix mag force bug
iProzd Oct 14, 2024
d24d7e7
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 15, 2024
593bf81
Update c_api.h
iProzd Oct 15, 2024
3466e34
Update c_api.h
iProzd Oct 15, 2024
c3a4f3e
extend sendlist nlist and other tensors but still bugs
CaRoLZhangxy Oct 18, 2024
e2e1e55
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 18, 2024
cf85275
revert `extend sendlist nlist`
iProzd Oct 21, 2024
1d6defe
fix spin communication in lammps
iProzd Oct 21, 2024
2a38025
Merge branch 'devel' into spin_lmp
iProzd Oct 21, 2024
e5c0ecf
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 21, 2024
85c934b
Update spin_model.py
iProzd Oct 22, 2024
35fd1c6
Update spin.py
iProzd Oct 22, 2024
11aeb17
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 22, 2024
6c5cb1d
add ut for spin c++
iProzd Oct 22, 2024
474a2b0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 22, 2024
114898f
bump version
iProzd Oct 22, 2024
fef13f5
Merge branch 'devel' into spin_lmp
iProzd Oct 26, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions deepmd/pt/model/model/spin_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -471,6 +471,7 @@ def forward_common_lower(
fparam: Optional[torch.Tensor] = None,
aparam: Optional[torch.Tensor] = None,
do_atomic_virial: bool = False,
comm_dict: Optional[Dict[str, torch.Tensor]] = None,
iProzd marked this conversation as resolved.
Show resolved Hide resolved
extra_nlist_sort: bool = False,
):
nframes, nloc = nlist.shape[:2]
iProzd marked this conversation as resolved.
Show resolved Hide resolved
Expand All @@ -492,6 +493,7 @@ def forward_common_lower(
fparam=fparam,
aparam=aparam,
do_atomic_virial=do_atomic_virial,
comm_dict=comm_dict,
extra_nlist_sort=extra_nlist_sort,
)
model_output_type = self.backbone_model.model_output_type()
Expand Down Expand Up @@ -607,6 +609,7 @@ def forward_lower(
fparam: Optional[torch.Tensor] = None,
aparam: Optional[torch.Tensor] = None,
do_atomic_virial: bool = False,
comm_dict: Optional[Dict[str, torch.Tensor]] = None,
iProzd marked this conversation as resolved.
Show resolved Hide resolved
):
iProzd marked this conversation as resolved.
Show resolved Hide resolved
model_ret = self.forward_common_lower(
extended_coord,
Expand All @@ -617,6 +620,7 @@ def forward_lower(
fparam=fparam,
aparam=aparam,
do_atomic_virial=do_atomic_virial,
comm_dict=comm_dict,
extra_nlist_sort=self.backbone_model.need_sorted_nlist_for_lower(),
)
model_predict = {}
Expand Down
4 changes: 4 additions & 0 deletions deepmd/tf/entrypoints/freeze.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,8 @@ def _make_node_names(
"o_atom_energy",
"o_atom_virial",
"spin_attr/ntypes_spin",
"spin_attr/virtual_len",
"spin_attr/spin_norm",
"fitting_attr/dfparam",
"fitting_attr/daparam",
"fitting_attr/aparam_nall",
Expand Down Expand Up @@ -259,6 +261,8 @@ def freeze_graph(
"train_attr/min_nbor_dist",
"fitting_attr/aparam_nall",
"spin_attr/ntypes_spin",
"spin_attr/virtual_len",
"spin_attr/spin_norm",
]
different_set = set(output_node) - set(input_node)
if different_set:
Expand Down
141 changes: 141 additions & 0 deletions source/api_c/include/c_api.h
iProzd marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,22 @@ extern void DP_DeepPotComputeNList(DP_DeepPot* dp,
double* atomic_energy,
double* atomic_virial);

extern void DP_DeepPotComputeNListSP(DP_DeepPot* dp,
const int natoms,
const double* coord,
const double* spin,
const int* atype,
const double* cell,
const int nghost,
const DP_Nlist* nlist,
const int ago,
double* energy,
double* force,
double* force_mag,
double* virial,
double* atomic_energy,
double* atomic_virial);
iProzd marked this conversation as resolved.
Show resolved Hide resolved

iProzd marked this conversation as resolved.
Show resolved Hide resolved
/**
* @brief Evaluate the energy, force and virial by using a DP with the neighbor
*list. (float version)
Expand Down Expand Up @@ -268,6 +284,22 @@ extern void DP_DeepPotComputeNListf(DP_DeepPot* dp,
float* atomic_energy,
float* atomic_virial);

extern void DP_DeepPotComputeNListfSP(DP_DeepPot* dp,
const int natoms,
const float* coord,
const float* spin,
const int* atype,
const float* cell,
const int nghost,
const DP_Nlist* nlist,
const int ago,
double* energy,
float* force,
float* force_mag,
float* virial,
float* atomic_energy,
float* atomic_virial);

/**
* @brief Evaluate the energy, force and virial by using a DP. (double version)
* @version 2
Expand Down Expand Up @@ -392,6 +424,25 @@ extern void DP_DeepPotComputeNList2(DP_DeepPot* dp,
double* atomic_energy,
double* atomic_virial);

extern void DP_DeepPotComputeNList2SP(DP_DeepPot* dp,
const int nframes,
const int natoms,
const double* coord,
const double* spin,
const int* atype,
const double* cell,
const int nghost,
const DP_Nlist* nlist,
const int ago,
const double* fparam,
const double* aparam,
double* energy,
double* force,
double* force_mag,
double* virial,
double* atomic_energy,
double* atomic_virial);

/**
* @brief Evaluate the energy, force and virial by using a DP with the neighbor
*list. (float version)
Expand Down Expand Up @@ -438,6 +489,25 @@ extern void DP_DeepPotComputeNListf2(DP_DeepPot* dp,
float* atomic_energy,
float* atomic_virial);

extern void DP_DeepPotComputeNListf2SP(DP_DeepPot* dp,
const int nframes,
const int natoms,
const float* coord,
const float* spin,
const int* atype,
const float* cell,
const int nghost,
const DP_Nlist* nlist,
const int ago,
const float* fparam,
const float* aparam,
double* energy,
float* force,
float* force_mag,
float* virial,
float* atomic_energy,
float* atomic_virial);

/**
* @brief Evaluate the energy, force and virial by using a DP with the mixed
*type. (double version)
Expand Down Expand Up @@ -734,6 +804,22 @@ extern void DP_DeepPotModelDeviComputeNList(DP_DeepPotModelDevi* dp,
double* atomic_energy,
double* atomic_virial);

extern void DP_DeepPotModelDeviComputeNListSP(DP_DeepPotModelDevi* dp,
const int natoms,
const double* coord,
const double* spin,
const int* atype,
const double* cell,
iProzd marked this conversation as resolved.
Show resolved Hide resolved
const int nghost,
const DP_Nlist* nlist,
const int ago,
double* energy,
double* force,
double* force_mag,
double* virial,
double* atomic_energy,
double* atomic_virial);

/**
* @brief Evaluate the energy, force and virial by using a DP model deviation
*with neighbor list. (float version)
Expand Down Expand Up @@ -771,6 +857,22 @@ extern void DP_DeepPotModelDeviComputeNListf(DP_DeepPotModelDevi* dp,
float* atomic_energy,
float* atomic_virial);

extern void DP_DeepPotModelDeviComputeNListfSP(DP_DeepPotModelDevi* dp,
const int natoms,
const float* coord,
const float* spin,
const int* atype,
const float* cell,
const int nghost,
const DP_Nlist* nlist,
const int ago,
double* energy,
float* force,
float* force_mag,
float* virial,
float* atomic_energy,
float* atomic_virial);

/**
* @brief Evaluate the energy, force and virial by using a DP model deviation
*with neighbor list. (double version)
Expand Down Expand Up @@ -816,6 +918,26 @@ void DP_DeepPotModelDeviComputeNList2(DP_DeepPotModelDevi* dp,
double* virial,
double* atomic_energy,
double* atomic_virial);

void DP_DeepPotModelDeviComputeNList2SP(DP_DeepPotModelDevi* dp,
const int nframes,
const int natoms,
const double* coord,
const double* spin,
const int* atype,
const double* cell,
const int nghost,
const DP_Nlist* nlist,
const int ago,
const double* fparam,
const double* aparam,
double* energy,
double* force,
double* force_mag,
double* virial,
double* atomic_energy,
double* atomic_virial);

iProzd marked this conversation as resolved.
Show resolved Hide resolved
iProzd marked this conversation as resolved.
Show resolved Hide resolved
/**
* @brief Evaluate the energy, force and virial by using a DP model deviation
*with neighbor list. (float version)
Expand Down Expand Up @@ -862,6 +984,25 @@ void DP_DeepPotModelDeviComputeNListf2(DP_DeepPotModelDevi* dp,
float* atomic_energy,
float* atomic_virial);

void DP_DeepPotModelDeviComputeNListf2SP(DP_DeepPotModelDevi* dp,
const int nframes,
const int natoms,
const float* coord,
const float* spin,
const int* atype,
const float* cell,
const int nghost,
const DP_Nlist* nlist,
const int ago,
const float* fparam,
const float* aparam,
double* energy,
float* force,
float* force_mag,
float* virial,
float* atomic_energy,
float* atomic_virial);

iProzd marked this conversation as resolved.
Show resolved Hide resolved
/**
* @brief Get the type map of a DP model deviation.
* @param[in] dp The DP model deviation to use.
Expand Down
Loading
Loading