Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(tf/pt): add/refact lammps support for spin model #4216

Draft
wants to merge 27 commits into
base: devel
Choose a base branch
from

Conversation

iProzd
Copy link
Collaborator

@iProzd iProzd commented Oct 15, 2024

Summary by CodeRabbit

  • New Features

    • Enhanced computational capabilities for the SpinModel and SpinEnergyModel classes, allowing for additional contextual data during processing.
    • New overloads for the compute and computew methods in the DeepPotPT class, enabling calculations involving force magnitudes and spin configurations.
    • Introduced a new function for concatenating real and virtual tensors, improving tensor management in spin-related computations.
    • Updated handling of spin states in the DescrptBlockRepformers class for better tensor processing.
    • Added a new method in the DeepPotTF class to support computations involving spin attributes.
    • Enhanced node management in the TensorFlow graph freezing process, allowing for broader model compatibility.
    • Introduced new functions in the C API for computations involving spin inputs, expanding API capabilities.
    • Significant enhancements to the DeepPot and DeepPotModelDevi classes to support spin in their computation methods.
    • Comprehensive unit tests for the DeepPot class focusing on spin functionality.
  • Bug Fixes

    • Improved handling of parameters in the computation methods to ensure accurate calculations and data flow.

@iProzd iProzd marked this pull request as draft October 15, 2024 16:00
Copy link
Contributor

coderabbitai bot commented Oct 15, 2024

📝 Walkthrough
📝 Walkthrough

Walkthrough

The pull request introduces modifications to the SpinModel and SpinEnergyModel classes by adding an optional parameter comm_dict to their forward_common_lower methods. Additionally, the DeepPotPT class is updated with new overloads for the compute and computew methods to include parameters for force_mag and spin. The DescrptBlockRepformers class is also updated to handle spin states, and a new utility function concat_switch_virtual is introduced to facilitate tensor operations related to spin configurations. Furthermore, several other classes and methods across the codebase are enhanced to support computations involving spin, improving the overall functionality of the framework.

Changes

File Path Change Summary
deepmd/pt/model/model/spin_model.py Updated forward_common_lower methods in SpinModel and SpinEnergyModel to include an optional comm_dict parameter; removed concat_switch_virtual.
source/api_cc/src/DeepPotPT.cc Added overloads for compute and computew methods in DeepPotPT to handle new parameters force_mag and spin.
source/api_cc/include/DeepPotPT.h Updated method signatures for compute and computew in DeepPotPT to include new parameters for force_mag and spin.
deepmd/pt/model/descriptor/repformers.py Modified forward method in DescrptBlockRepformers to handle spin states and introduced conditional logic for tensor dimensions.
deepmd/pt/utils/spin.py Added new function concat_switch_virtual for concatenating real and virtual tensors while managing local atom configurations.
deepmd/tf/entrypoints/freeze.py Enhanced node management in TensorFlow graph freezing; added new nodes for "ener" model type and improved error handling.
source/api_cc/src/DeepPotTF.cc Introduced new methods for spin support in DeepPotTF, including updates to existing methods for handling force magnitude and spin.
source/api_c/include/c_api.h Added new API functions to support spin in deep potential computations, enhancing the C API for better integration.
source/api_c/src/c_api.cc Implemented new functions for spin handling in existing computation functions, updating internal logic to accommodate spin parameters.
source/api_cc/include/DeepPot.h Enhanced DeepPot and DeepPotBase classes with new methods for handling magnetic forces and spins in computations.
source/api_cc/include/DeepPotTF.h Updated DeepPotTF class with new overloads for compute and computew methods to include parameters for force_mag and spin.
source/api_c/tests/test_deeppot_dpa1_pt_spin.cc Added unit tests for DeepPot class focusing on spin functionality, ensuring the correctness of computations.
source/lmp/pair_deepmd.cpp Enhanced PairDeepMD class to improve handling of spin states and associated forces in computation methods.

Possibly related PRs

Suggested reviewers

  • njzjz
  • wanghan-iapcm
  • anyangml

Thank you for using CodeRabbit. We offer it for free to the OSS community and would appreciate your support in helping us grow. If you find it useful, would you consider giving us a shout-out on your favorite social media?

❤️ Share
🪧 Tips

Chat

There are 3 ways to chat with CodeRabbit:

  • Review comments: Directly reply to a review comment made by CodeRabbit. Example:
    • I pushed a fix in commit <commit_id>, please review it.
    • Generate unit testing code for this file.
    • Open a follow-up GitHub issue for this discussion.
  • Files and specific lines of code (under the "Files changed" tab): Tag @coderabbitai in a new review comment at the desired location with your query. Examples:
    • @coderabbitai generate unit testing code for this file.
    • @coderabbitai modularize this function.
  • PR comments: Tag @coderabbitai in a new PR comment to ask questions about the PR branch. For the best results, please provide a very specific query, as very limited context is provided in this mode. Examples:
    • @coderabbitai gather interesting stats about this repository and render them as a table. Additionally, render a pie chart showing the language distribution in the codebase.
    • @coderabbitai read src/utils.ts and generate unit testing code.
    • @coderabbitai read the files in the src/scheduler package and generate a class diagram using mermaid and a README in the markdown format.
    • @coderabbitai help me debug CodeRabbit configuration file.

Note: Be mindful of the bot's finite context window. It's strongly recommended to break down tasks such as reading entire modules into smaller chunks. For a focused discussion, use review comments to chat about specific files and their changes, instead of using the PR comments.

CodeRabbit Commands (Invoked using PR comments)

  • @coderabbitai pause to pause the reviews on a PR.
  • @coderabbitai resume to resume the paused reviews.
  • @coderabbitai review to trigger an incremental review. This is useful when automatic reviews are disabled for the repository.
  • @coderabbitai full review to do a full review from scratch and review all the files again.
  • @coderabbitai summary to regenerate the summary of the PR.
  • @coderabbitai resolve resolve all the CodeRabbit review comments.
  • @coderabbitai configuration to show the current CodeRabbit configuration for the repository.
  • @coderabbitai help to get help.

Other keywords and placeholders

  • Add @coderabbitai ignore anywhere in the PR description to prevent this PR from being reviewed.
  • Add @coderabbitai summary to generate the high-level summary at a specific location in the PR description.
  • Add @coderabbitai anywhere in the PR title to generate the title automatically.

CodeRabbit Configuration File (.coderabbit.yaml)

  • You can programmatically configure CodeRabbit by adding a .coderabbit.yaml file to the root of your repository.
  • Please see the configuration documentation for more information.
  • If your editor has YAML language server enabled, you can add the path at the top of this file to enable auto-completion and validation: # yaml-language-server: $schema=https://coderabbit.ai/integrations/schema.v2.json

Documentation and Community

  • Visit our Documentation for detailed information on how to use CodeRabbit.
  • Join our Discord Community to get help, request features, and share feedback.
  • Follow us on X/Twitter for updates and announcements.

Comment on lines +378 to +379
// std::vector<double> virtual_len;
// std::vector<double> spin_norm;

Check notice

Code scanning / CodeQL

Commented-out code

This comment appears to contain commented-out code.
std::vector<int> extend_numneigh;
std::vector<std::vector<int>> extend_neigh;
std::vector<int*> extend_firstneigh;
// std::vector<double> extend_dcoord;

Check notice

Code scanning / CodeQL

Commented-out code

This comment appears to contain commented-out code.
select_real_atoms_coord(dcoord, datype, aparam_, nghost_real, fwd_map,
bkw_map, nall_real, nloc_real, coord, atype, aparam,
nghost, ntypes, 1, daparam, nall, aparam_nall);
int nloc = nall_real - nghost_real;

Check notice

Code scanning / CodeQL

Unused local variable

Variable nloc is not used.
Comment on lines +462 to +466
// spin model not suported yet
// torch::Tensor flat_virial_ = virial_.toTensor().view({-1}).to(floatType);
// torch::Tensor cpu_virial_ = flat_virial_.to(torch::kCPU);
// virial.assign(cpu_virial_.data_ptr<VALUETYPE>(),
// cpu_virial_.data_ptr<VALUETYPE>() + cpu_virial_.numel());

Check notice

Code scanning / CodeQL

Commented-out code

This comment appears to contain commented-out code.
// cpu_atom_virial_.data_ptr<VALUETYPE>(),
// cpu_atom_virial_.data_ptr<VALUETYPE>() + cpu_atom_virial_.numel());
atom_energy.resize(static_cast<size_t>(nframes) * fwd_map.size());
// atom_virial.resize(static_cast<size_t>(nframes) * fwd_map.size() * 9);

Check notice

Code scanning / CodeQL

Commented-out code

This comment appears to contain commented-out code.
Comment on lines +748 to +752
// spin model not suported yet
// torch::Tensor flat_virial_ = virial_.toTensor().view({-1}).to(floatType);
// torch::Tensor cpu_virial_ = flat_virial_.to(torch::kCPU);
// virial.assign(cpu_virial_.data_ptr<VALUETYPE>(),
// cpu_virial_.data_ptr<VALUETYPE>() + cpu_virial_.numel());

Check notice

Code scanning / CodeQL

Commented-out code

This comment appears to contain commented-out code.
// virial.assign(cpu_virial_.data_ptr<VALUETYPE>(),
// cpu_virial_.data_ptr<VALUETYPE>() + cpu_virial_.numel());
if (atomic) {
// c10::IValue atom_virial_ = outputs.at("atom_virial");

Check notice

Code scanning / CodeQL

Commented-out code

This comment appears to contain commented-out code.
}

if (dtype == tensorflow::DT_DOUBLE) {
int ret = session_input_tensors<double>(

Check notice

Code scanning / CodeQL

Unused local variable

Variable ret is not used.
nframes, nghost_real);
}
} else {
int ret = session_input_tensors<float>(

Check notice

Code scanning / CodeQL

Unused local variable

Variable ret is not used.
@@ -1198,7 +1245,12 @@ void PairDeepMD::settings(int narg, char **arg) {
}
}

comm_reverse = numb_models * 3;
// comm_reverse = numb_models * 3;

Check notice

Code scanning / CodeQL

Commented-out code

This comment appears to contain commented-out code.
Comment on lines +378 to +379
// std::vector<double> virtual_len;
// std::vector<double> spin_norm;

Check notice

Code scanning / CodeQL

Commented-out code Note

This comment appears to contain commented-out code.
std::vector<int> extend_numneigh;
std::vector<std::vector<int>> extend_neigh;
std::vector<int*> extend_firstneigh;
// std::vector<double> extend_dcoord;

Check notice

Code scanning / CodeQL

Commented-out code Note

This comment appears to contain commented-out code.
select_real_atoms_coord(dcoord, datype, aparam_, nghost_real, fwd_map,
bkw_map, nall_real, nloc_real, coord, atype, aparam,
nghost, ntypes, 1, daparam, nall, aparam_nall);
int nloc = nall_real - nghost_real;

Check notice

Code scanning / CodeQL

Unused local variable Note

Variable nloc is not used.
Comment on lines +462 to +466
// spin model not suported yet
// torch::Tensor flat_virial_ = virial_.toTensor().view({-1}).to(floatType);
// torch::Tensor cpu_virial_ = flat_virial_.to(torch::kCPU);
// virial.assign(cpu_virial_.data_ptr<VALUETYPE>(),
// cpu_virial_.data_ptr<VALUETYPE>() + cpu_virial_.numel());

Check notice

Code scanning / CodeQL

Commented-out code Note

This comment appears to contain commented-out code.
// cpu_atom_virial_.data_ptr<VALUETYPE>(),
// cpu_atom_virial_.data_ptr<VALUETYPE>() + cpu_atom_virial_.numel());
atom_energy.resize(static_cast<size_t>(nframes) * fwd_map.size());
// atom_virial.resize(static_cast<size_t>(nframes) * fwd_map.size() * 9);

Check notice

Code scanning / CodeQL

Commented-out code Note

This comment appears to contain commented-out code.
Comment on lines +748 to +752
// spin model not suported yet
// torch::Tensor flat_virial_ = virial_.toTensor().view({-1}).to(floatType);
// torch::Tensor cpu_virial_ = flat_virial_.to(torch::kCPU);
// virial.assign(cpu_virial_.data_ptr<VALUETYPE>(),
// cpu_virial_.data_ptr<VALUETYPE>() + cpu_virial_.numel());

Check notice

Code scanning / CodeQL

Commented-out code Note

This comment appears to contain commented-out code.
// virial.assign(cpu_virial_.data_ptr<VALUETYPE>(),
// cpu_virial_.data_ptr<VALUETYPE>() + cpu_virial_.numel());
if (atomic) {
// c10::IValue atom_virial_ = outputs.at("atom_virial");

Check notice

Code scanning / CodeQL

Commented-out code Note

This comment appears to contain commented-out code.
int ret = session_input_tensors<double>(
input_tensors, dcoord, ntypes, datype, dbox, nlist, fparam, aparam,
atommap, nghost_real, ago, "", aparam_nall);
assert(nloc_real == ret);

Check notice

Code scanning / CodeQL

Unused local variable Note

Variable ret is not used.
int ret = session_input_tensors<float>(
input_tensors, dcoord, ntypes, datype, dbox, nlist, fparam, aparam,
atommap, nghost_real, ago, "", aparam_nall);
assert(nloc_real == ret);

Check notice

Code scanning / CodeQL

Unused local variable Note

Variable ret is not used.
@@ -1198,7 +1245,12 @@
}
}

comm_reverse = numb_models * 3;
// comm_reverse = numb_models * 3;

Check notice

Code scanning / CodeQL

Commented-out code Note

This comment appears to contain commented-out code.
Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 34

🧹 Outside diff range and nitpick comments (13)
deepmd/tf/entrypoints/freeze.py (2)

264-265: LGTM! Consider grouping related optional nodes.

The addition of spin_attr/virtual_len and spin_attr/spin_norm to the optional_node list is a good practice. It ensures backward compatibility with models that don't have spin attributes.

For improved readability, consider grouping related optional nodes together. You could refactor the optional_node list as follows:

optional_node = [
    "train_attr/min_nbor_dist",
    "fitting_attr/aparam_nall",
    # Spin-related attributes
    "spin_attr/ntypes_spin",
    "spin_attr/virtual_len",
    "spin_attr/spin_norm",
]

This grouping makes it easier to identify and manage related optional nodes.


Line range hint 1-389: Overall changes look good. Consider updating documentation.

The changes to add support for spin-related computations in the freezing process are well-implemented and consistent with the PR objectives. The modifications are minimal and targeted, which is a good practice for reducing the risk of introducing bugs.

To ensure that users and developers are aware of the new spin-related functionality, consider updating the module or function docstrings to mention the new spin attributes. For example, you could add a note to the freeze_graph function docstring:

def freeze_graph(...):
    """Freeze the single graph with chosen out_suffix.

    ...

    Notes
    -----
    This function now supports freezing of spin-related attributes
    (virtual_len and spin_norm) for the "ener" model type.
    """
    ...

This documentation update will help users understand the new capabilities added by these changes.

deepmd/pt/model/model/spin_model.py (2)

Line range hint 474-497: LGTM! Consider adding documentation for new parameters.

The addition of comm_dict and extra_nlist_sort parameters enhances the flexibility of the forward_common_lower method. However, it would be beneficial to document their purpose and usage in the method's docstring.

Consider adding documentation for the new parameters:

def forward_common_lower(
    self,
    extended_coord,
    extended_atype,
    extended_spin,
    nlist,
    mapping: Optional[torch.Tensor] = None,
    fparam: Optional[torch.Tensor] = None,
    aparam: Optional[torch.Tensor] = None,
    do_atomic_virial: bool = False,
    comm_dict: Optional[Dict[str, torch.Tensor]] = None,
    extra_nlist_sort: bool = False,
):
    """
    ...
    Parameters:
    ...
    comm_dict: Optional dictionary for additional communication data.
    extra_nlist_sort: Boolean flag for additional neighbor list sorting.
    ...
    """

Line range hint 612-624: LGTM! Consider adding documentation for the new parameter.

The addition of the comm_dict parameter and its usage are consistent with the changes in the parent class. The extra_nlist_sort parameter is now properly utilized.

Consider adding documentation for the new parameter:

@torch.jit.export
def forward_lower(
    self,
    extended_coord,
    extended_atype,
    extended_spin,
    nlist,
    mapping: Optional[torch.Tensor] = None,
    fparam: Optional[torch.Tensor] = None,
    aparam: Optional[torch.Tensor] = None,
    do_atomic_virial: bool = False,
    comm_dict: Optional[Dict[str, torch.Tensor]] = None,
):
    """
    ...
    Parameters:
    ...
    comm_dict: Optional dictionary for additional communication data.
    ...
    """
source/api_cc/include/DeepPotTF.h (5)

118-134: Add documentation for the new compute method overload

The newly added compute method overload at lines 118-134 introduces additional parameters force_mag and spin. To maintain code clarity and assist users, please add accompanying documentation that describes the purpose of this method and explains each parameter.


282-313: Add documentation for the new computew method overloads

The overloaded computew methods at lines 282-313 include new parameters force_mag and spin. To maintain consistency and aid users, please provide detailed documentation for these methods, explaining the purpose of the new parameters and any changes in the method's behavior.


359-359: Provide parameter names in the cum_sum method declaration

At line 359, the cum_sum method is declared without parameter names:

void cum_sum(std::map<int, int>&, std::map<int, int>&);

For better code readability and maintainability, please include descriptive parameter names:

void cum_sum(std::map<int, int>& input_map, std::map<int, int>& output_map);

378-379: Remove or clarify commented-out code

The lines 378-379 contain commented-out declarations:

// std::vector<double> virtual_len;
// std::vector<double> spin_norm;

If these variables are no longer needed, please remove them to keep the code clean. If they are intended for future use, add comments explaining their purpose.


131-131: Pass small integers by value instead of const reference

In the method parameters at lines 131, 294, and 310, the integer ago is passed as const int& ago. Since int is a small built-in type, passing it by value is more efficient and avoids the overhead of referencing. Please consider changing the parameter to:

const int ago

Also applies to: 294-294, 310-310

source/api_cc/src/DeepPotPT.cc (2)

446-447: Typo in comment: 'suported' should be 'supported'

There's a typographical error in the comment on line 446. Correct the spelling of "suported" to "supported" for clarity.

Apply this diff to fix the typo:

-// spin model not suported yet
+// Spin model not supported yet

463-466: Remove unnecessary commented-out code

The block of commented-out code related to virial_ is unnecessary since it's already indicated that the spin model does not support virial calculations yet. Removing this block can improve code readability.

Apply this diff to remove the unnecessary code:

-// torch::Tensor flat_virial_ = virial_.toTensor().view({-1}).to(floatType);
-// torch::Tensor cpu_virial_ = flat_virial_.to(torch::kCPU);
-// virial.assign(cpu_virial_.data_ptr<VALUETYPE>(),
-//               cpu_virial_.data_ptr<VALUETYPE>() + cpu_virial_.numel());
source/lmp/pair_deepmd.cpp (1)

1249-1253: Clarify the computation of comm_reverse

You conditionally set comm_reverse based on atom->sp_flag:

if (atom->sp_flag) {
  comm_reverse = numb_models * 3 * 2;
} else {
  comm_reverse = numb_models * 3;
}

Consider adding comments or refactoring for clarity to indicate why the factor is * 2 when atom->sp_flag is true. This will enhance maintainability and readability.

source/api_c/include/c_api.h (1)

233-247: Add documentation for the new API function

The function DP_DeepPotComputeNListSP lacks a documentation comment block. Please add a documentation block similar to other functions, describing the purpose, parameters, and any important notes.

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Files that changed from the base of the PR and between 6fe8dde and 3466e34.

📒 Files selected for processing (13)
  • deepmd/pt/model/model/spin_model.py (4 hunks)
  • deepmd/tf/entrypoints/freeze.py (2 hunks)
  • source/api_c/include/c_api.h (8 hunks)
  • source/api_c/include/deepmd.hpp (6 hunks)
  • source/api_c/src/c_api.cc (10 hunks)
  • source/api_cc/include/DeepPot.h (5 hunks)
  • source/api_cc/include/DeepPotPT.h (4 hunks)
  • source/api_cc/include/DeepPotTF.h (3 hunks)
  • source/api_cc/src/DeepPot.cc (4 hunks)
  • source/api_cc/src/DeepPotPT.cc (4 hunks)
  • source/api_cc/src/DeepPotTF.cc (4 hunks)
  • source/lmp/pair_deepmd.cpp (15 hunks)
  • source/lmp/pair_deepmd.h (1 hunks)
🧰 Additional context used
🔇 Additional comments (41)
deepmd/tf/entrypoints/freeze.py (1)

127-128: LGTM! Verify impact on other parts of the codebase.

The addition of spin_attr/virtual_len and spin_attr/spin_norm nodes for the "ener" model type is consistent with the PR objectives to add support for spin models. This change looks good and should enable the freezing of spin-related attributes.

To ensure these changes don't have unintended consequences, please run the following script to check for any other occurrences of these new attributes in the codebase:

This will help verify that these new attributes are consistently used across the project.

✅ Verification successful

Verified! No unintended occurrences of new spin attributes found elsewhere.

The newly added spin_attr/virtual_len and spin_attr/spin_norm nodes are exclusively present in deepmd/tf/entrypoints/freeze.py. Their implementation is confined to this file, ensuring that there are no unintended side effects in other parts of the codebase.

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Check for other occurrences of new spin attributes

echo "Searching for 'spin_attr/virtual_len':"
rg "spin_attr/virtual_len" --type py

echo "\nSearching for 'spin_attr/spin_norm':"
rg "spin_attr/spin_norm" --type py

Length of output: 521

deepmd/pt/model/model/spin_model.py (1)

Line range hint 1-638: Overall, the changes look good and enhance the model's flexibility.

The additions of comm_dict and extra_nlist_sort parameters in both SpinModel and SpinEnergyModel classes are well-implemented and consistent. These changes provide more flexibility in the forward pass and neighbor list sorting. The code maintains good quality and there are no apparent issues introduced by these changes.

To further improve the code:

  1. Consider adding documentation for the new parameters in both forward_common_lower and forward_lower methods.
  2. Ensure that the usage of comm_dict is properly documented in the broader context of the DeepMD framework, explaining when and how it should be used.
source/lmp/pair_deepmd.h (1)

78-78: LGTM!

The addition of all_force_mag as a member variable appears appropriate.

source/api_cc/include/DeepPotTF.h (3)

126-126: Confirm the correctness of spin data type

At line 126, spin is declared as const std::vector<VALUETYPE>& spin. Since spin values might be integers in some models, ensure that VALUETYPE is the appropriate data type for representing spins in all cases. If spins are always floating-point numbers, this is acceptable; otherwise, consider using a more suitable type.


289-289: Consistent data types for spin across methods

In the computew methods, spin is declared as const std::vector<double>& at line 289 and const std::vector<float>& at line 305. Ensure that the usage of double and float for spin aligns with the rest of the method parameters and that any casting or precision considerations are properly handled.

Also applies to: 305-305


368-370: Check for duplication of get_vector method

The get_vector template method is declared at lines 368-370. Ensure that this declaration does not duplicate an existing method and that it is necessary for the extended functionality. If it's an intentional addition, consider providing documentation to explain its purpose.

source/api_cc/include/DeepPotPT.h (1)

77-90: Verify consistency of parameter ordering across method overloads

Ensure that the ordering of parameters, particularly the placement of spin and force_mag, is consistent with existing methods and logical grouping. Consistent parameter ordering prevents confusion and potential bugs.

source/api_cc/src/DeepPotPT.cc (5)

658-798: Consistent handling of optional tensors

Ensure that optional tensors like box_Tensor, fparam_tensor, and aparam_tensor are consistently handled across different compute methods. This prevents unexpected behaviors when these optional parameters are absent.

Please verify that all optional tensors are checked for validity before use to avoid null pointer dereferences.


Line range hint 446-793: Correct inconsistency in handling do_message_passing and nghost

Throughout the new compute methods, ensure that the conditional checks for do_message_passing and nghost are consistent and correctly implemented. This is crucial for parallel computations and message passing between processes.

Review the conditional logic to confirm that all possible states of do_message_passing and nghost are appropriately handled.


755-767: Commented-out code may indicate incomplete implementation

The commented-out sections related to atom_virial_ suggest that the spin model's support for atomic virials is incomplete.

Confirm whether the implementation of atomic virials with the spin model is pending. If so, consider adding a TODO comment to track this for future development.


468-474: ⚠️ Potential issue

Initialize force_mag vector before use

The force_mag vector should be initialized appropriately before being used in the select_map function to avoid potential undefined behavior.

Apply this diff to ensure force_mag is correctly initialized:

 force.resize(static_cast<size_t>(nframes) * fwd_map.size() * 3);
+force_mag.resize(static_cast<size_t>(nframes) * fwd_map.size() * 3);
 select_map<VALUETYPE>(force, dforce, bkw_map, 3, nframes, fwd_map.size(),
                       nall_real);
 select_map<VALUETYPE>(force_mag, dforce_mag, bkw_map, 3, nframes,
                       fwd_map.size(), nall_real);

Likely invalid or redundant comment.


503-535: Ensure explicit instantiation of template methods for required types

The explicit template instantiations at the end of the file should cover all the necessary data type combinations used in the codebase. Verify that the methods are instantiated for all required VALUETYPE and ENERGYVTYPE combinations.

Run the following script to check for all instantiations:

Ensure that the output includes all necessary type combinations.

Also applies to: 772-798

✅ Verification successful

All required VALUETYPE and ENERGYVTYPE combinations are explicitly instantiated.

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Check for explicit template instantiations of compute methods.

# Search for lines with 'template void DeepPotPT::compute<'
rg --type cpp 'template void DeepPotPT::compute<' -A 2

Length of output: 1885


Script:

#!/bin/bash
# Description: Identify all unique VALUETYPEs used in DeepPotPT::compute method calls.

# Extract all compute method calls and capture the VALUETYPE template argument
rg --type cpp 'DeepPotPT::compute<' | \
    sed -n "s/.*DeepPotPT::compute<\([^,>]*\),.*/\1/p" | \
    sort | uniq

Length of output: 117

source/lmp/pair_deepmd.cpp (6)

582-586: Verify the updated deep_pot.compute method signature

The deep_pot.compute function is called with additional parameters dforce_mag and dspin_const. Ensure that the deep_pot.compute method signature has been updated accordingly and all dependencies are properly handled.


606-608: Ensure consistency in deep_pot.compute calls with spin parameters

When calling deep_pot.compute, you are passing dspin. Verify that this parameter is correctly handled within the function and that all spin-related computations are accurate.


1407-1423: Ensure correct packing of reverse communication buffer with spin data

In pack_reverse_comm, you have added code to handle spin data:

if (atom->sp_flag) {
  // Code handling spin data
} else {
  // Original code
}

Verify that the buffer buf is being correctly populated, and that the size calculations match the expected data to prevent communication errors.


1436-1454: Validate unpacking of reverse communication buffer with spin data

Similar to packing, in unpack_reverse_comm, ensure that the unpacking logic correctly aligns with the packing structure, particularly with the added spin data. Mismatches here can lead to incorrect force calculations.


771-783: Check calculations for spin-related force deviations

In the block handling spin forces:

if (atom->sp_flag) {
  // Calculations for std_fm, tmp_avg_fm, etc.
}

Verify that the statistical calculations for std_fm, all_fm_min, all_fm_max, and all_fm_avg are accurate and that they properly represent the deviations intended.


657-693: ⚠️ Potential issue

Handle possible issues with model deviation calculations

In the multi-model deviation computation block, there are nested conditionals and numerous variables like all_force_mag, all_force, all_energy, etc. Ensure that:

  • All vectors are properly initialized before use.
  • The indexing aligns correctly across different models.
  • Memory management is correctly handled to prevent leaks.
source/api_cc/src/DeepPotTF.cc (1)

1429-1441: Check for potential out-of-bounds access in extend_atype assignment

In the loop updating extend_atype, there is a risk of out-of-bounds access when calculating indices with new_idx_map[ii] + nloc and new_idx_map[ii] + nghost. Ensure that these calculated indices do not exceed the size of extend_atype.

Run the following script to check index bounds:

Ensure that extend_atype is appropriately sized to accommodate these indices or adjust the index calculations to prevent out-of-bounds access.

source/api_cc/src/DeepPot.cc (6)

221-242: Correct implementation of spin support in DeepPot::compute method

The new overload of the compute method correctly includes the spin parameters dspin_ and dforce_mag_. The parameters are appropriately passed to dp->computew, ensuring that spin calculations are integrated seamlessly.


244-262: Proper extension of compute method for vector energies with spin

The overloaded compute method that returns vector energies now includes dspin_ and dforce_mag_ parameters. The adjustments ensure that spin data is consistently handled across computations.


264-319: Accurate template instantiations for spin-inclusive compute methods

The explicit template instantiations for both double and float types are correctly provided for the new compute methods with spin support. This ensures type safety and availability of these methods for different numerical precisions.


488-596: Extension of compute methods with atomic outputs and spin support

The compute methods that include atomic energies and virials have been appropriately extended to incorporate spin parameters. The modifications correctly pass dspin_ and dforce_mag_ to dp->computew, enabling detailed spin-aware computations at the atomic level.


954-1012: Integration of spin support in DeepPotModelDevi::compute method

The DeepPotModelDevi::compute method now includes support for spin parameters, adding dspin_ and dforce_mag_ to the computation of energies, forces, and virials across multiple models. The loop over numb_models ensures consistent processing of spin data.


1073-1140: Enhanced compute method with spin and atomic outputs in DeepPotModelDevi

The overloaded compute method in DeepPotModelDevi has been correctly expanded to handle spin parameters and compute atomic energies and virials. This comprehensive extension ensures that all relevant spin-related data is accurately calculated for each model.

source/api_c/include/c_api.h (6)

287-302: Add documentation for the new API function

Please add a documentation block for DP_DeepPotComputeNListfSP to maintain consistency with other functions in the API.


427-445: Add documentation for the new API function

Please add a documentation block for DP_DeepPotComputeNList2SP to maintain consistency with other functions.


492-510: Add documentation for the new API function

Please add a documentation block for DP_DeepPotComputeNListf2SP to maintain consistency.


807-822: Add documentation for the new API function

Please add a documentation block for DP_DeepPotModelDeviComputeNListSP similar to other functions in the API.


860-875: Add documentation for the new API function

Please add a documentation block for DP_DeepPotModelDeviComputeNListfSP to maintain consistency.


987-1005: Add 'extern' keyword to function declaration

The function DP_DeepPotModelDeviComputeNListf2SP lacks the extern keyword in its declaration. Please add extern to maintain consistency.

Here is the suggested change:

-void DP_DeepPotModelDeviComputeNListf2SP(DP_DeepPotModelDevi* dp,
+extern void DP_DeepPotModelDeviComputeNListf2SP(DP_DeepPotModelDevi* dp,
source/api_c/src/c_api.cc (10)

419-454: LGTM!

The template instantiations for the new DP_DeepPotComputeNList_variant_sp function look good.


1304-1322: LGTM!

The new DP_DeepPotComputeNListfSP function for the float type looks good.


1382-1403: LGTM!

The new DP_DeepPotComputeNList2SP function for handling multiple frames with spin looks good.


1426-1447: LGTM!

The new DP_DeepPotComputeNListf2SP function for the float type with multiple frames looks good.


1595-1613: LGTM!

The new DP_DeepPotModelDeviComputeNListSP function for the model deviation with spin looks good.


1633-1651: LGTM!

The new DP_DeepPotModelDeviComputeNListfSP function for the float type model deviation with spin looks good.


1674-1695: LGTM!

The new DP_DeepPotModelDeviComputeNList2SP function for handling multiple frames in the model deviation with spin looks good.


1718-1739: LGTM!

The new DP_DeepPotModelDeviComputeNListf2SP function for the float type model deviation with multiple frames and spin looks good.


842-879: LGTM!

The template instantiations for the new DP_DeepPotModelDeviComputeNList_variant_sp function look good.


1266-1284: Verify the new function is called correctly from all relevant code paths.

The new DP_DeepPotComputeNListSP function looks good. However, ensure it is being called correctly from all the relevant code paths that need to compute energy, force, force_mag, virial, etc. with spin.

Run the following script to verify the function usage:

source/api_cc/include/DeepPotTF.h Show resolved Hide resolved
source/api_cc/include/DeepPotTF.h Show resolved Hide resolved
source/api_cc/include/DeepPotTF.h Show resolved Hide resolved
source/api_cc/include/DeepPotPT.h Show resolved Hide resolved
source/api_cc/include/DeepPotPT.h Show resolved Hide resolved
source/api_c/include/deepmd.hpp Show resolved Hide resolved
source/api_c/include/deepmd.hpp Show resolved Hide resolved
source/api_c/include/deepmd.hpp Show resolved Hide resolved
source/api_c/include/deepmd.hpp Show resolved Hide resolved
source/api_c/include/deepmd.hpp Show resolved Hide resolved
Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 35

🧹 Outside diff range and nitpick comments (23)
deepmd/tf/entrypoints/freeze.py (1)

264-265: LGTM. Consider grouping related attributes.

The addition of "spin_attr/virtual_len" and "spin_attr/spin_norm" to the optional node list is consistent with the earlier changes and provides flexibility in the freezing process.

For improved readability, consider grouping related attributes together. You could move the "spin_attr/ntypes_spin" node (currently on line 263) next to the newly added spin attributes:

optional_node = [
    "train_attr/min_nbor_dist",
    "fitting_attr/aparam_nall",
    "spin_attr/ntypes_spin",
    "spin_attr/virtual_len",
    "spin_attr/spin_norm",
]

This grouping makes it easier to identify all spin-related attributes at a glance.

deepmd/pt/model/model/spin_model.py (2)

Line range hint 474-497: LGTM! Consider adding documentation for the new parameters.

The addition of comm_dict and extra_nlist_sort parameters enhances the flexibility of the forward_common_lower method. These changes appear to be well-integrated into the existing code.

Consider adding documentation for these new parameters to explain their purpose and usage:

  1. comm_dict: Describe its role in communication and when it should be provided.
  2. extra_nlist_sort: Explain when this boolean flag should be set to True and its impact on the neighbor list sorting.

Line range hint 612-624: LGTM! Consider adding documentation for the new parameter.

The addition of the comm_dict parameter and the dynamic setting of extra_nlist_sort enhance the method's flexibility and maintain consistency with the backbone model's requirements.

Consider adding documentation for the new comm_dict parameter to explain its purpose and usage in the context of the SpinEnergyModel.

source/lmp/pair_deepmd.h (1)

78-78: Consider adding documentation for all_force_mag.

To improve code readability and maintainability, consider adding a brief comment explaining the purpose and usage of the all_force_mag member variable, similar to the existing comments for other member variables.

source/api_cc/include/DeepPotTF.h (1)

378-380: Consider removing commented-out code or adding context

The variables virtual_len and spin_norm are commented out. If they are no longer needed, it would be cleaner to remove them entirely. If they are reserved for future use, consider adding a comment explaining their intended purpose to provide context for other developers.

source/api_cc/include/DeepPotPT.h (1)

Line range hint 356-379: Update documentation for computew overloads combining mixed precision and new parameters.

The overloads of computew methods that handle mixed precision now include force_mag and spin. The documentation does not reflect these changes. To ensure proper usage and understanding, please update the documentation to include descriptions of the new parameters and any implications for mixed precision computations.

source/api_cc/src/DeepPotPT.cc (3)

788-798: Consider handling atom_virial for the spin model

In the compute method starting at line 660, the code related to atom_virial is commented out due to the spin model not being supported yet. If the spin model now supports atom_virial, consider uncommenting and implementing this functionality. If not, provide a clearer comment indicating when this feature will be available.

Clarify the status of atom_virial support in the spin model for better code maintainability.


Line range hint 643-644: Possible code simplification in device checks

In the blocks starting at lines 643-644 and 685-686, there's a check for gpu_enabled to set the device. Consider simplifying the code by initializing device based on the gpu_enabled flag directly during declaration to improve readability.

Example:

torch::Device device = gpu_enabled ? torch::Device(torch::kCUDA, gpu_id) : torch::Device(torch::kCPU);

Also applies to: 685-686


316-331: Consider documenting the new compute method overloads

The newly added compute methods with additional parameters may benefit from inline documentation or comments explaining their purpose, parameters, and usage. This can help future developers understand the differences between overloads and when to use each one.

Add Doxygen-style comments or inline explanations for better clarity.

Also applies to: 660-672

source/lmp/pair_deepmd.cpp (2)

582-586: Consistent use of constant references

In the code:

const vector<double> &dcoord_const = dcoord;
const vector<double> &dspin_const = dspin;
deep_pot.compute(dener, dforce, dforce_mag, dvirial, dcoord_const, dspin_const, dtype, dbox, nghost, lmp_list, ago, fparam, daparam);

While dspin_const and dcoord_const are defined as constant references, consider passing dcoord and dspin directly to deep_pot.compute if they are not modified within the function. This can simplify the code and reduce unnecessary variable declarations.


771-775: Potential misalignment in relative standard deviation computation

In the computation of relative standard deviation:

if (out_rel == 1) {
  deep_pot_model_devi.compute_relative_std_f(std_fm, tmp_avg_fm, eps);
}

Consider whether out_rel should control the relative standard deviation computation for both force and force magnitude (std_f and std_fm). If independent control is desired, introduce a separate flag (e.g., out_rel_fm) for clarity and flexibility.

source/api_c/include/c_api.h (7)

233-247: Add missing documentation for the new function DP_DeepPotComputeNListSP.

The newly added function DP_DeepPotComputeNListSP lacks documentation. Consistent documentation is essential for maintainability and readability. Please add a detailed comment describing the function's purpose, parameters, and expected behavior, similar to the existing function comments.


287-302: Provide documentation for DP_DeepPotComputeNListfSP to maintain consistency.

The function DP_DeepPotComputeNListfSP is missing accompanying documentation. Adding a comprehensive comment will help users understand the function's usage and maintain consistency with the rest of the API.


427-445: Include function comments for DP_DeepPotComputeNList2SP for clarity.

The function DP_DeepPotComputeNList2SP lacks descriptive comments. Please add documentation outlining the function's purpose, parameters, and any special considerations, following the style of existing documented functions.


492-510: Document the new function DP_DeepPotComputeNListf2SP.

To maintain the API's usability, please provide documentation for DP_DeepPotComputeNListf2SP. This should include details about the function's role, its parameters, and any important notes for users.


807-822: Add missing documentation for DP_DeepPotModelDeviComputeNListSP.

The function DP_DeepPotModelDeviComputeNListSP is introduced without accompanying comments. Consistent and thorough documentation aids in code comprehension and maintenance. Please include a comment block describing this function.


860-875: Provide documentation for DP_DeepPotModelDeviComputeNListfSP.

The new function DP_DeepPotModelDeviComputeNListfSP should have a descriptive comment explaining its usage, parameters, and any important details. This ensures consistency and aids other developers.


233-247: Ensure consistent parameter naming for natoms and natom.

There is inconsistency in the naming of the atom count parameter across the new function declarations. Some functions use natoms, while others use natom. For clarity and consistency, it's advisable to standardize on one naming convention.

Apply this diff to rename natom to natoms:

 const int nframes,
-const int natom,
+const int natoms,

Also applies to: 287-302, 427-445, 492-510, 807-822, 860-875, 922-940, 987-1005

source/api_c/src/c_api.cc (2)

1266-1284: Add documentation for the new function DP_DeepPotComputeNListSP

The newly introduced function lacks comments explaining its purpose and usage. Adding documentation will enhance code readability and help other developers understand how to use the spin-supporting API correctly.


Line range hint 1674-1740: Plan for extending multi-frame support with spin in model deviation computations

Similar to earlier functions, DP_DeepPotModelDeviComputeNList2SP currently does not support nframes > 1. If there's a need for processing multiple frames with spin support in model deviation computations, consider implementing this feature.

source/api_c/include/deepmd.hpp (3)

160-179: Add documentation for the _DP_DeepPotComputeNListSP function template

The newly added _DP_DeepPotComputeNListSP function template lacks documentation. Including comments to explain the purpose, parameters, and usage of this function will improve code readability and maintainability.


1141-1194: Include documentation for overloaded compute methods with spin support

The newly added overloaded compute methods that include spin parameters lack explanatory comments. Adding documentation will help other developers understand the purpose and usage of these methods.


1916-2004: Add documentation for spin-aware compute methods in DeepPotModelDevi

The compute methods that handle spin in the DeepPotModelDevi class lack descriptive comments. Providing detailed documentation will aid in understanding the methods' functionalities and parameters.

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Files that changed from the base of the PR and between 6fe8dde and 3466e34.

📒 Files selected for processing (13)
  • deepmd/pt/model/model/spin_model.py (4 hunks)
  • deepmd/tf/entrypoints/freeze.py (2 hunks)
  • source/api_c/include/c_api.h (8 hunks)
  • source/api_c/include/deepmd.hpp (6 hunks)
  • source/api_c/src/c_api.cc (10 hunks)
  • source/api_cc/include/DeepPot.h (5 hunks)
  • source/api_cc/include/DeepPotPT.h (4 hunks)
  • source/api_cc/include/DeepPotTF.h (3 hunks)
  • source/api_cc/src/DeepPot.cc (4 hunks)
  • source/api_cc/src/DeepPotPT.cc (4 hunks)
  • source/api_cc/src/DeepPotTF.cc (4 hunks)
  • source/lmp/pair_deepmd.cpp (15 hunks)
  • source/lmp/pair_deepmd.h (1 hunks)
🧰 Additional context used
🔇 Additional comments (27)
deepmd/tf/entrypoints/freeze.py (1)

127-128: LGTM. Can you provide more context on the spin attributes?

The addition of "spin_attr/virtual_len" and "spin_attr/spin_norm" to the list of nodes for the "ener" model type looks good. These changes align with the PR objective of adding support for the spin model.

Could you provide some additional context on what these spin attributes represent and how they are used in the energy calculations?

source/lmp/pair_deepmd.h (1)

78-78: Addition of all_force_mag member variable looks good.

The new member variable std::vector<std::vector<double> > all_force_mag; appropriately extends the PairDeepMD class to store force magnitudes, enhancing its functionality for simulations involving spin handling.

source/api_cc/include/DeepPotPT.h (1)

Line range hint 77-379: Confirm consistency of method signatures across all overloads.

The additions of force_mag and spin parameters across multiple compute and computew method overloads appear consistent. However, please verify that all method signatures are correctly aligned and that the parameter order is maintained uniformly. This helps prevent potential mismatches or confusion during method invocation.

source/api_cc/include/DeepPot.h (5)

162-177: Update method documentation to include force_mag and spin parameters


393-420: Update method documentation to include force_mag and spin parameters


523-554: Update method documentation to include force_mag and spin parameters


827-840: Update documentation for compute methods with added force_mag and spin parameters

The compute methods in the DeepPotModelDevi class have been extended with force_mag and spin parameters. Please update the documentation to reflect these additions for consistency and clarity.


881-897: Update documentation for compute methods with added force_mag and spin parameters

source/api_cc/src/DeepPotPT.cc (3)

341-342: Ensure consistency in TensorOptions dtype settings

In the conditional blocks at lines 341-342 and 683-684, the TensorOptions are set for float32 when VALUETYPE is float. Verify that this pattern is consistently applied throughout the code to prevent any unintended type mismatches.

Also applies to: 683-684


362-365: Validate dimensions of spin_wrapped_Tensor

When creating spin_wrapped_Tensor, ensure that the dimensions match the expected input of the model. Mismatches in tensor dimensions can lead to runtime errors or incorrect computations.

Also applies to: 678-691


315-535: Ensure consistent handling of the spin parameter

The new overloaded compute method includes spin as an input parameter. Please verify that all necessary validations and error handling for spin are implemented. For instance, check if spin has the correct dimensions and whether additional checks are needed for its usage.

Run the following script to search for handling of the spin parameter:

source/lmp/pair_deepmd.cpp (5)

505-505: Initialization of dforce_mag

The vector dforce_mag is declared:

vector<double> dforce_mag(nall * 3);

Ensure that this vector is properly utilized and filled with the correct force magnitude data before being used in subsequent computations to prevent unintended results.


606-608: Ensure correct function overload is called

In the call to deep_pot.compute:

deep_pot.compute(dener, dforce, dforce_mag, dvirial, deatom, dvatom, dcoord, dspin, dtype, dbox, nghost, lmp_list, ago, fparam, daparam);

Verify that the deep_pot.compute function correctly handles the additional dforce_mag and dspin parameters and that the appropriate overload or implementation is being used.


1250-1253: Verify comm_reverse calculation for data communication

The communication reverse size is set as:

if (atom->sp_flag) {
  comm_reverse = numb_models * 3 * 2;
} else {
  comm_reverse = numb_models * 3;
}

Ensure that doubling the size (* 2) when atom->sp_flag is true correctly accounts for the additional spin-related data that needs to be communicated. Validate that this change is consistent throughout the communication routines.


1406-1423: Correct buffer packing in pack_reverse_comm with spin data

In pack_reverse_comm, when atom->sp_flag is true, additional data related to force magnitude is packed:

if (atom->sp_flag) {
  for (i = first; i < last; i++) {
    for (int dd = 0; dd < numb_models; ++dd) {
      buf[m++] = all_force[dd][3 * i + 0];
      buf[m++] = all_force[dd][3 * i + 1];
      buf[m++] = all_force[dd][3 * i + 2];
      buf[m++] = all_force_mag[dd][3 * i + 0];
      buf[m++] = all_force_mag[dd][3 * i + 1];
      buf[m++] = all_force_mag[dd][3 * i + 2];
    }
  }
} else {
  // Existing packing logic
}

Ensure that:

  • The buffer buf is adequately sized to hold the additional data.
  • The order of packing matches the expected order during unpacking.
  • There are no buffer overflows or memory issues.

1435-1454: Consistent buffer unpacking in unpack_reverse_comm

In unpack_reverse_comm, ensure that the unpacking logic accurately mirrors the packing logic:

if (atom->sp_flag) {
  for (i = 0; i < n; i++) {
    j = list[i];
    for (int dd = 0; dd < numb_models; ++dd) {
      all_force[dd][3 * j + 0] += buf[m++];
      all_force[dd][3 * j + 1] += buf[m++];
      all_force[dd][3 * j + 2] += buf[m++];
      all_force_mag[dd][3 * j + 0] += buf[m++];
      all_force_mag[dd][3 * j + 1] += buf[m++];
      all_force_mag[dd][3 * j + 2] += buf[m++];
    }
  }
} else {
  // Existing unpacking logic
}

Verify that:

  • The increments of m are consistent with those in pack_reverse_comm.
  • The data is assigned to the correct indices in all_force and all_force_mag.
  • There are no mismatches that could lead to incorrect force calculations.
source/api_cc/src/DeepPotTF.cc (1)

511-516: Implementation of get_vector method is correct

The get_vector method is correctly implemented and integrates well with the existing code.

source/api_cc/src/DeepPot.cc (5)

221-242: Integration of Spin Support in DeepPot::compute Method

The new overload of the compute method correctly incorporates spin support by introducing the dforce_mag_ and dspin_ parameters. The function calls dp->computew with the appropriate arguments, ensuring that spin effects are accounted for in the computation.


244-262: Addition of Vectorized compute Method with Spin Support

The added vectorized overload of the compute method extends spin support to handle multiple energies and forces. The implementation correctly passes the new dforce_mag_ and dspin_ parameters to dp->computew, enabling spin-aware computations for vector inputs.


264-291: Correct Template Instantiations for New Overloads

The template instantiations for the new compute methods are properly defined for both double and float types. This ensures that the methods supporting spin with different precision levels are available and correctly linked.


954-1012: Extension of DeepPotModelDevi::compute with Spin Parameters

The new compute method overload in DeepPotModelDevi class appropriately adds the dspin_ and all_force_mag parameters to support spin calculations. The loop correctly iterates over all models, invoking the compute method with the new parameters, which ensures consistent spin support across all models.


1073-1140: Addition of Atom-Wise Outputs with Spin Support in DeepPotModelDevi::compute

The extended compute method now includes atom-wise energy and virial outputs alongside spin support. This implementation correctly handles the additional data structures all_atom_energy and all_atom_virial, ensuring that detailed per-atom information is computed and stored for each model.

source/api_c/include/c_api.h (1)

233-247: Verify that memory allocation and deallocation are handled appropriately.

Ensure that any dynamic memory allocations within these functions are properly managed to prevent memory leaks.

To check for dynamic memory allocations in these functions, run:

Also applies to: 287-302, 427-445, 492-510, 807-822, 860-875, 922-940, 987-1005

source/api_c/src/c_api.cc (2)

1304-1323: Ensure consistent naming conventions for spin-related functions

The function DP_DeepPotComputeNListfSP introduces spin support for single precision. Verify that the naming aligns with the project's conventions for indicating single-precision and spin-supporting functions to maintain consistency across the API.

To verify naming consistency, you can search for similar function names:

#!/bin/bash
# Description: Check for naming patterns of spin-supporting functions.

# Test: Search for function names ending with 'SP'. Expect: Consistent naming conventions.
rg --type cpp -o 'void DP_.*SP\('

376-377: Verify the correctness of spin vector initialization

Ensure that the size calculation for the spin_ vector is accurate. The current initialization uses nframes * natoms * 3, which should match the dimensions of the spin input array. Verify that the spin data is provided correctly for all frames and atoms.

To confirm, you can check for inconsistencies in the codebase with:

source/api_c/include/deepmd.hpp (2)

181-227: Ensure consistency in function naming and existence

The specialized functions _DP_DeepPotComputeNListSP<double> and _DP_DeepPotComputeNListSP<float> invoke DP_DeepPotComputeNList2SP and DP_DeepPotComputeNListf2SP respectively. Please verify that these functions are correctly implemented in the C API and follow the established naming conventions.


1045-1049: Verify the size calculation for force_mag

The force_mag vector is resized to nframes * natoms * 3, similar to the force vector. Confirm that this size is appropriate for storing the magnitude of the forces and that it aligns with the output from _DP_DeepPotComputeNListSP.

source/api_cc/include/DeepPotTF.h Show resolved Hide resolved
source/api_cc/include/DeepPotTF.h Outdated Show resolved Hide resolved
source/api_cc/include/DeepPotTF.h Outdated Show resolved Hide resolved
source/api_cc/include/DeepPotTF.h Show resolved Hide resolved
source/api_cc/include/DeepPotTF.h Show resolved Hide resolved
source/api_c/src/c_api.cc Show resolved Hide resolved
source/api_c/src/c_api.cc Show resolved Hide resolved
source/api_c/include/deepmd.hpp Show resolved Hide resolved
source/api_c/include/deepmd.hpp Show resolved Hide resolved
source/api_c/include/deepmd.hpp Show resolved Hide resolved
Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 21

🧹 Outside diff range and nitpick comments (25)
deepmd/tf/entrypoints/freeze.py (1)

264-265: LGTM. Consider using a constant for spin attribute names.

The addition of "spin_attr/virtual_len" and "spin_attr/spin_norm" to the optional_node list is a good practice. It allows the freezing process to handle graphs that may or may not include these spin-related nodes.

For improved maintainability, consider defining these spin attribute names as constants at the module level. This would ensure consistency and make it easier to update if needed in the future.

Example:

SPIN_ATTR_VIRTUAL_LEN = "spin_attr/virtual_len"
SPIN_ATTR_SPIN_NORM = "spin_attr/spin_norm"

# Then use these constants in both the `_make_node_names` function and the `optional_node` list
deepmd/pt/model/model/spin_model.py (2)

474-475: Consider adding documentation for new parameters.

The new parameters comm_dict and extra_nlist_sort have been added to the forward_common_lower method. To improve code maintainability and usability, consider adding docstring comments explaining the purpose and expected usage of these parameters.

Here's a suggested docstring addition:

def forward_common_lower(
    self,
    extended_coord,
    extended_atype,
    extended_spin,
    nlist,
    mapping: Optional[torch.Tensor] = None,
    fparam: Optional[torch.Tensor] = None,
    aparam: Optional[torch.Tensor] = None,
    do_atomic_virial: bool = False,
    comm_dict: Optional[Dict[str, torch.Tensor]] = None,
    extra_nlist_sort: bool = False,
):
    """
    Forward pass for the lower part of the spin model.

    ...

    Parameters:
    ...
    comm_dict : Optional[Dict[str, torch.Tensor]], optional
        Dictionary for additional communication data, by default None
    extra_nlist_sort : bool, optional
        Whether to perform extra sorting on neighbor lists, by default False
    """
    # ... rest of the method

Line range hint 612-624: LGTM. Consider adding documentation for the new parameter.

The changes look good and are consistent with the modifications in the SpinModel class. The automatic setting of extra_nlist_sort based on the backbone model's requirements is a nice touch for maintaining consistency.

To improve code documentation, consider adding a description for the new comm_dict parameter in the method's docstring.

Here's a suggested docstring addition:

@torch.jit.export
def forward_lower(
    self,
    extended_coord,
    extended_atype,
    extended_spin,
    nlist,
    mapping: Optional[torch.Tensor] = None,
    fparam: Optional[torch.Tensor] = None,
    aparam: Optional[torch.Tensor] = None,
    do_atomic_virial: bool = False,
    comm_dict: Optional[Dict[str, torch.Tensor]] = None,
):
    """
    Lower-level forward pass for the spin energy model.

    ...

    Parameters:
    ...
    comm_dict : Optional[Dict[str, torch.Tensor]], optional
        Dictionary for additional communication data, by default None
    """
    # ... rest of the method
source/api_cc/include/DeepPot.h (5)

146-177: Add missing documentation for new computew methods with force_mag and spin.

The newly added computew methods in the DeepPotBase class lack documentation comments explaining the purpose and usage of the new parameters force_mag and spin. Adding appropriate documentation will improve code readability and maintain consistency with the rest of the codebase.


393-420: Provide documentation for new compute methods with force_mag and spin.

The new template compute methods in the DeepPot class include additional parameters force_mag and spin but lack accompanying documentation. Please add comments to explain the functionality and usage of these methods, maintaining consistency with existing documentation.


523-554: Add documentation for compute methods including force_mag and spin.

The newly introduced compute methods in the DeepPot class that handle atomic energies and virials with additional force_mag and spin parameters lack documentation comments. Providing detailed documentation will enhance clarity and assist other developers in understanding these methods.


827-840: Document new compute methods in DeepPotModelDevi with force_mag and spin.

The DeepPotModelDevi class has new compute methods that include force_mag and spin parameters but lack documentation. Please add comments to explain their purpose and usage, ensuring consistency across the codebase.


881-897: Provide documentation for new compute methods including force_mag and spin.

Please add documentation comments for the newly added methods in the DeepPotModelDevi class to maintain clarity and consistency with the rest of the code.

source/api_cc/src/DeepPotPT.cc (1)

446-446: Fix typographical error in comments

The word "suported" is misspelled. Please correct it to "supported" in the affected comment lines.

Also applies to: 462-462, 477-477, 733-733, 748-748

source/lmp/pair_deepmd.cpp (2)

505-505: Initialization of dforce_mag could be optimized

The vector dforce_mag is initialized with zero values. If the subsequent computation fully populates this vector, the initialization may be unnecessary.

You might consider delaying the initialization until it's necessary or using resize instead of initializing with zeros:

-vector<double> dforce_mag(nall * 3);
+vector<double> dforce_mag;
+dforce_mag.resize(nall * 3);

582-586: Avoid unnecessary copying by passing variables directly

The variables dcoord_const and dspin_const are references to dcoord and dspin, respectively. Since these variables are not modified within the compute call, passing dcoord and dspin directly can simplify the code.

Apply this diff to remove unnecessary variables:

-const vector<double> &dcoord_const = dcoord;
-const vector<double> &dspin_const = dspin;
-deep_pot.compute(dener, dforce, dforce_mag, dvirial, dcoord_const,
-                 dspin_const, dtype, dbox, nghost, lmp_list, ago,
+deep_pot.compute(dener, dforce, dforce_mag, dvirial, dcoord,
+                 dspin, dtype, dbox, nghost, lmp_list, ago,
                  fparam, daparam);
source/api_cc/src/DeepPot.cc (1)

954-1012: Document the usage of dspin_ in DeepPotModelDevi::compute

Consider adding comments or documentation to explain how the dspin_ parameter affects computations within DeepPotModelDevi::compute. This will enhance code readability and assist future contributors in understanding the spin-related computations.

source/api_c/include/c_api.h (9)

233-247: Add missing documentation for DP_DeepPotComputeNListSP function

The function DP_DeepPotComputeNListSP is missing a @brief documentation block. To maintain consistency and help users understand its purpose and usage, please add appropriate documentation similar to other functions in the API.


287-302: Add missing documentation for DP_DeepPotComputeNListfSP function

The function DP_DeepPotComputeNListfSP lacks accompanying documentation. Please include a @brief comment to describe the function's purpose, parameters, and usage.


427-445: Add missing documentation for DP_DeepPotComputeNList2SP function

Please add a @brief documentation block for the DP_DeepPotComputeNList2SP function. This will aid users in understanding its functionality and how it differs from existing functions.


492-510: Add missing documentation for DP_DeepPotComputeNListf2SP function

The function DP_DeepPotComputeNListf2SP is missing documentation. Including a @brief comment will provide clarity on its usage and maintain consistency across the API.


807-822: Add missing documentation for DP_DeepPotModelDeviComputeNListSP function

The DP_DeepPotModelDeviComputeNListSP function lacks a @brief documentation block. Adding documentation will help users understand the function's purpose and parameters.


860-875: Add missing documentation for DP_DeepPotModelDeviComputeNListfSP function

Please include a @brief documentation comment for the DP_DeepPotModelDeviComputeNListfSP function to explain its functionality and usage.


921-940: Add missing documentation for DP_DeepPotModelDeviComputeNList2SP function

The function DP_DeepPotModelDeviComputeNList2SP is missing a documentation block. Providing a @brief comment will enhance understandability and maintain consistency.


987-1005: Add missing documentation for DP_DeepPotModelDeviComputeNListf2SP function

The DP_DeepPotModelDeviComputeNListf2SP function lacks accompanying documentation. Please add a @brief comment to describe its purpose and usage.


Line range hint 233-1005: Ensure consistent use of extern keyword in function declarations

There is inconsistency in the use of the extern keyword across function declarations. Some functions, such as DP_DeepPotModelDeviComputeNList2, are declared without extern, while others include it. To maintain consistency and clarity in the API, please ensure that all public function declarations use the extern keyword appropriately.

source/api_c/src/c_api.cc (4)

354-418: Ensure Consistent Naming Convention for Spin Functions

The newly introduced function DP_DeepPotComputeNList_variant_sp adds spin support. To maintain consistency with existing naming conventions in the codebase, consider renaming the function to use the SP suffix, such as DP_DeepPotComputeNList_variantSP.


758-880: Maintain Consistent Naming for Multi-Model Spin Functions

The function DP_DeepPotModelDeviComputeNList_variant_sp adds spin support for multi-model computations. For consistency, consider renaming it to DP_DeepPotModelDeviComputeNList_variantSP, aligning with the naming convention used elsewhere in the codebase.


778-780: Improve Exception Message for Unsupported Frame Count

The exception message "nframes > 1 not supported yet" is thrown when nframes exceeds 1. Providing additional context or guidance could enhance user understanding. Consider updating the message to indicate whether support for multiple frames is planned or suggest alternative approaches.


1266-1284: Document New Spin Functions for Clarity

The newly added functions (DP_DeepPotComputeNListSP, DP_DeepPotComputeNListfSP, DP_DeepPotComputeNList2SP, DP_DeepPotComputeNListf2SP) introduce spin support. To aid users:

  • Provide detailed documentation for each function.
  • Explain the purpose of the spin parameter.
  • Highlight any differences from their non-spin counterparts.

Also applies to: 1304-1323, 1382-1403, 1426-1447

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Files that changed from the base of the PR and between 6fe8dde and 3466e34.

📒 Files selected for processing (13)
  • deepmd/pt/model/model/spin_model.py (4 hunks)
  • deepmd/tf/entrypoints/freeze.py (2 hunks)
  • source/api_c/include/c_api.h (8 hunks)
  • source/api_c/include/deepmd.hpp (6 hunks)
  • source/api_c/src/c_api.cc (10 hunks)
  • source/api_cc/include/DeepPot.h (5 hunks)
  • source/api_cc/include/DeepPotPT.h (4 hunks)
  • source/api_cc/include/DeepPotTF.h (3 hunks)
  • source/api_cc/src/DeepPot.cc (4 hunks)
  • source/api_cc/src/DeepPotPT.cc (4 hunks)
  • source/api_cc/src/DeepPotTF.cc (4 hunks)
  • source/lmp/pair_deepmd.cpp (15 hunks)
  • source/lmp/pair_deepmd.h (1 hunks)
🧰 Additional context used
🪛 GitHub Check: CodeQL
source/api_cc/include/DeepPotTF.h

[notice] 378-379: Commented-out code
This comment appears to contain commented-out code.


[notice] 385-385: Commented-out code
This comment appears to contain commented-out code.

source/api_cc/src/DeepPotPT.cc

[notice] 356-356: Unused local variable
Variable nloc is not used.


[notice] 462-466: Commented-out code
This comment appears to contain commented-out code.


[notice] 495-495: Commented-out code
This comment appears to contain commented-out code.


[notice] 687-687: Unused local variable
Variable nframes is not used.


[notice] 748-752: Commented-out code
This comment appears to contain commented-out code.


[notice] 754-754: Commented-out code
This comment appears to contain commented-out code.

source/api_cc/src/DeepPotTF.cc

[notice] 893-893: Unused local variable
Variable ret is not used.


[notice] 905-905: Unused local variable
Variable ret is not used.

source/lmp/pair_deepmd.cpp

[notice] 1248-1248: Commented-out code
This comment appears to contain commented-out code.

🔇 Additional comments (25)
deepmd/tf/entrypoints/freeze.py (2)

127-128: LGTM. Can you provide more context on the spin attributes?

The addition of "spin_attr/virtual_len" and "spin_attr/spin_norm" nodes for the "ener" model type is consistent with the PR objectives of adding support for spin models.

Could you provide more information about these spin attributes and how they are used in the energy calculations?


131-131: LGTM. Can you elaborate on the 'aparam_nall' attribute?

The addition of the "fitting_attr/aparam_nall" node for the "ener" model type is in line with the PR objectives.

Could you provide more details about this fitting attribute and its role in the energy calculations or model fitting process?

deepmd/pt/model/model/spin_model.py (1)

Line range hint 1-624: Overall, the changes look good and enhance the model's flexibility.

The modifications to both SpinModel and SpinEnergyModel classes consistently add support for additional communication data (comm_dict) and neighbor list sorting (extra_nlist_sort). These changes likely improve the flexibility and potentially the performance of the spin models.

The implementation appears sound, with good consistency between the two classes. The main suggestion for improvement is to add documentation for the new parameters to enhance code maintainability and usability.

source/lmp/pair_deepmd.h (1)

78-78: Variable all_force_mag Added Correctly

The new member variable all_force_mag is properly declared and follows the existing naming conventions in the class. Its addition aligns with the class structure and should integrate well with the existing code.

source/api_cc/include/DeepPotPT.h (2)

132-148: Duplicate comment: Missing documentation for the new compute method overload

As with the previous compute method overload, this version also lacks documentation. Please include documentation comments for this method to maintain consistency.


315-327: Duplicate comment: Missing documentation for the new computew method overload

This computew method overload also lacks documentation. Please ensure all public methods are properly documented.

source/api_cc/include/DeepPot.h (5)

146-177: Method signatures correctly extend the interface for spin support.

The addition of force_mag and spin parameters to the computew methods appropriately extends the functionality of the DeepPotBase class to support spin models. The parameter types and ordering are consistent with existing methods.


393-420: New methods appropriately extend DeepPot functionality for spin models.

The additions of compute methods with force_mag and spin parameters are consistent with the class design and correctly extend the DeepPot interface to support spin-related computations.


523-554: Method implementations are consistent and correctly extend functionality.

The added methods properly extend the existing functionality to support spin models, maintaining consistency in parameter types and ordering.


827-840: Methods correctly extend functionality for model deviation with spin support.

The added methods in DeepPotModelDevi appropriately extend the class to handle spin models in model deviation calculations.


881-897: Method additions are consistent and extend functionality appropriately.

The new methods correctly extend DeepPotModelDevi to support spin models in computations involving atomic energies and virials.

source/api_cc/src/DeepPotPT.cc (2)

314-536: Implementation of new compute function overload with force_mag and spin

The new overload of the compute function correctly introduces the force_mag and spin parameters. The integration aligns well with the existing code structure and extends the functionality appropriately.

🧰 Tools
🪛 GitHub Check: CodeQL

[notice] 356-356: Unused local variable
Variable nloc is not used.


[notice] 462-466: Commented-out code
This comment appears to contain commented-out code.


[notice] 495-495: Commented-out code
This comment appears to contain commented-out code.


658-798: Implementation of compute function for spin model

The addition of the new compute function overload for the spin model is implemented correctly. The function handles the spin and force_mag parameters appropriately, enhancing the model's capabilities.

🧰 Tools
🪛 GitHub Check: CodeQL

[notice] 687-687: Unused local variable
Variable nframes is not used.


[notice] 748-752: Commented-out code
This comment appears to contain commented-out code.


[notice] 754-754: Commented-out code
This comment appears to contain commented-out code.

source/lmp/pair_deepmd.cpp (2)

1248-1253: Clarify the use of comm_reverse and remove commented-out code

The line // comm_reverse = numb_models * 3; is commented out, and new logic is introduced based on atom->sp_flag. If the old assignment is no longer needed, consider removing it to keep the code clean.

🧰 Tools
🪛 GitHub Check: CodeQL

[notice] 1248-1248: Commented-out code
This comment appears to contain commented-out code.


Line range hint 1406-1454: Possible buffer overflow in communication routines

In both pack_reverse_comm and unpack_reverse_comm, ensure that the buffer buf is sufficiently sized to handle the additional data when atom->sp_flag is true. The added loops increase the amount of data being packed and unpacked.

Run the following script to check buffer sizes in communication routines:

source/api_cc/src/DeepPotTF.cc (2)

511-516: Addition of get_vector Method Correctly Implements Vector Retrieval

The newly added get_vector method extends the DeepPotTF class to retrieve vectors from the session using the provided name. The implementation follows the existing code conventions and templates, ensuring type safety and consistency.


828-1019: Implementation of compute Method with Spin Support

The overloaded compute method effectively adds support for spin calculations. It integrates spin-related parameters and modifies the computational logic to accommodate spin interactions. The method handles the extended data appropriately and maintains compatibility with the existing architecture.

🧰 Tools
🪛 GitHub Check: CodeQL

[notice] 893-893: Unused local variable
Variable ret is not used.


[notice] 905-905: Unused local variable
Variable ret is not used.

source/api_cc/src/DeepPot.cc (5)

221-242: LGTM: Addition of compute functions with spin support

The newly added compute function overloads correctly incorporate the dspin_ parameter to support spin calculations. The implementation follows existing code patterns and maintains consistency.


244-262: LGTM: Overloaded compute function for vector energies with spin

The overloaded compute function handling vectors correctly adds spin support. Parameters are appropriately passed, and the function aligns with the established interface.


488-596: LGTM: Extended compute functions with atomic outputs and spin support

The new compute functions that include datom_energy_, datom_virial_, and spin parameters are implemented consistently with the existing code structure. This extends functionality while maintaining code consistency.


264-319: Ensure template specializations are correctly instantiated

Please verify that the new template specializations for the compute function with spin support are correctly instantiated for both double and float types and that they are properly utilized elsewhere in the codebase.

Run the following script to confirm that instantiations are correctly defined:


1073-1140: Verify correctness of computations involving spin in model deviations

Ensure that the computations in DeepPotModelDevi::compute correctly handle spin when calculating energy, force, and virial deviations across models. It's important that the spin contributions are accurately accounted for in deviation calculations.

Run the following script to check for consistency in deviation computations:

source/api_c/include/c_api.h (1)

921-940: Verify potential impact on backward compatibility due to function updates

The addition of the new function DP_DeepPotModelDeviComputeNList2SP may suggest an update or replacement of the existing DP_DeepPotModelDeviComputeNList2 function. Please verify whether the original function remains available and unmodified. If any existing functions have altered signatures or have been deprecated, consider the impact on users who rely on the previous API and provide guidance or versioning to handle backward compatibility.

source/api_c/src/c_api.cc (2)

1270-1270: Confirm Const Correctness of Spin Parameters

The spin parameter is declared as a const pointer (e.g., const double* spin). Ensure that the data pointed to by spin is not modified within the function. If modifications are necessary, remove the const qualifier.

Review the implementation to ensure spin is treated as read-only.

Also applies to: 1309-1309, 1387-1387, 1430-1430


376-377: Verify the Size of the Spin Vector Initialization

Ensure that the spin vector spin_ is initialized with the correct size of nframes * natoms * 3. A mismatch in size could lead to out-of-bounds access or incorrect computations.

Run the following script to confirm the size of the spin vector:

✅ Verification successful

[/run_scripts]


Verified: Spin Vector Initialization Sizes are Correct

The initialization of spin_ with size nframes * natoms * 3 at lines 376-377 is correct. Additionally, the initialization with natoms * 3 at line 783 is appropriate within its context where nframes is restricted to 1.

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Verify the size of the spin vector initialization
# Expected: The size matches nframes * natoms * 3

grep -rn 'std::vector<VALUETYPE> spin_' source/api_c/src/ | grep 'spin'

# Check for correct size calculation in the code

Length of output: 255


Script:

#!/bin/bash
# Description: Extract lines around line 783 in c_api.cc to understand the context of spin_ initialization

# Extract 10 lines before and after line 783
sed -n '773,793p' source/api_c/src/c_api.cc

Length of output: 886

source/lmp/pair_deepmd.h Show resolved Hide resolved
source/api_cc/include/DeepPotTF.h Show resolved Hide resolved
source/api_cc/include/DeepPotTF.h Outdated Show resolved Hide resolved
source/api_cc/include/DeepPotTF.h Show resolved Hide resolved
source/api_cc/include/DeepPotTF.h Show resolved Hide resolved
source/api_c/include/deepmd.hpp Show resolved Hide resolved
source/api_c/include/deepmd.hpp Show resolved Hide resolved
source/api_c/include/deepmd.hpp Show resolved Hide resolved
source/api_c/include/deepmd.hpp Show resolved Hide resolved
source/api_c/include/deepmd.hpp Show resolved Hide resolved
Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 26

🧹 Outside diff range and nitpick comments (15)
deepmd/tf/entrypoints/freeze.py (1)

264-265: LGTM. Consider extracting spin attributes to a constant.

The addition of "spin_attr/virtual_len" and "spin_attr/spin_norm" to the optional_node list is consistent with the changes in the _make_node_names function. This improves the robustness of the freezing process for different model configurations.

For better code readability and maintainability, consider extracting the spin attribute names to a constant at the top of the file:

SPIN_ATTR_NODES = [
    "spin_attr/ntypes_spin",
    "spin_attr/virtual_len",
    "spin_attr/spin_norm",
]

Then, you can use this constant in both the _make_node_names function and the optional_node list:

optional_node = [
    "train_attr/min_nbor_dist",
    "fitting_attr/aparam_nall",
    *SPIN_ATTR_NODES,
]

This approach would make it easier to manage these related attributes in the future.

deepmd/pt/model/model/spin_model.py (1)

Line range hint 474-497: LGTM! Consider adding documentation for new parameters.

The changes to forward_common_lower look good. The addition of comm_dict and extra_nlist_sort parameters enhances the method's flexibility.

Consider adding docstring comments to explain the purpose and usage of these new parameters:

  • comm_dict: Explain its role in communication and any expected key-value pairs.
  • extra_nlist_sort: Describe when this should be set to True and its impact on the neighbor list.
source/api_cc/include/DeepPotTF.h (1)

384-384: Consider using smart pointers for extend_firstneigh

The member variable extend_firstneigh is defined as std::vector<int*>. Using raw pointers can lead to memory management issues. Consider using smart pointers or ensuring that the ownership and lifecycle of these pointers are well-managed.

source/api_cc/include/DeepPotPT.h (1)

Line range hint 356-371: Document the new computew method overload with force_mag and spin

The computew method overload at lines 356-371 introduces additional parameters force_mag and spin. Adding comprehensive documentation for this method, including details about all parameters and their purposes, will improve clarity for users and maintain consistency across the API.

source/api_cc/include/DeepPot.h (1)

523-554: Provide default values where appropriate

Consider providing default values for the new parameters force_mag and spin if they are optional. This can help maintain backward compatibility and ease of use for existing code that does not require these parameters.

source/lmp/pair_deepmd.cpp (3)

657-666: Ensure the model deviation functionality is thoroughly tested.

The added code introduces support for model deviation calculations when atom->sp_flag is not set. Ensure that this new functionality is thoroughly tested, covering scenarios with and without eflag_atom or cvflag_atom set.

Consider adding unit tests or integration tests to cover the new model deviation functionality and ensure its correctness.


677-695: Ensure the model deviation functionality is thoroughly tested with spin interactions.

The added code introduces support for model deviation calculations when atom->sp_flag is set, indicating spin interactions. Ensure that this new functionality is thoroughly tested, covering scenarios with and without eflag_atom or cvflag_atom set, and verifying the correctness of the spin-related computations.

Consider adding unit tests or integration tests to cover the new model deviation functionality with spin interactions and ensure its correctness.


848-848: Implement support for spin atomic force output.

The comment indicates that support for spin atomic force output is needed. Implement this functionality to ensure that the atomic forces associated with spin interactions are correctly output when out_each is set to 1.

Do you want me to provide a code snippet that implements the spin atomic force output? I can help generate the necessary code to gather and output the spin atomic forces.

source/api_cc/src/DeepPot.cc (2)

244-263: Add unit tests for the new spin-aware 'compute' methods

The new overloads of the compute method now support spin interactions. To ensure correctness and prevent future regressions, consider adding unit tests that specifically test these new methods with various spin configurations.


1073-1140: Update documentation to reflect new spin capabilities

The addition of spin support in the compute methods is a significant change. Remember to update the documentation and any relevant API references to inform users about the new parameters and functionality.

source/api_c/src/c_api.cc (5)

1266-1284: Suggest adding unit tests for DP_DeepPotComputeNListSP

The new function DP_DeepPotComputeNListSP introduces spin support into the public API. To ensure its correctness and prevent future regressions, consider adding unit tests covering various scenarios, including edge cases.

Would you like assistance in generating unit tests for this function?


1304-1323: Suggest adding unit tests for DP_DeepPotComputeNListfSP

Similarly, the DP_DeepPotComputeNListfSP function extends spin support for single-precision floats. Adding corresponding unit tests would help validate its functionality.

Would you like assistance in creating unit tests for this function?


1425-1448: Suggest adding unit tests for DP_DeepPotComputeNListf2SP

The addition of DP_DeepPotComputeNListf2SP introduces spin support for multi-frame computations with single-precision floats. Adding unit tests will help ensure its reliability.

Would you like assistance in developing unit tests for this function?


1633-1652: Suggest adding unit tests for DP_DeepPotModelDeviComputeNListfSP

To ensure the correctness of the new DP_DeepPotModelDeviComputeNListfSP function, consider adding unit tests that cover various scenarios with single-precision floats.

Would you like help in creating unit tests for this function?


1718-1740: Suggest adding unit tests for DP_DeepPotModelDeviComputeNListf2SP

Adding unit tests for the DP_DeepPotModelDeviComputeNListf2SP function will help verify its functionality with single-precision floats in multi-frame computations.

Do you need assistance in developing unit tests for this function?

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Files that changed from the base of the PR and between 6fe8dde and 3466e34.

📒 Files selected for processing (13)
  • deepmd/pt/model/model/spin_model.py (4 hunks)
  • deepmd/tf/entrypoints/freeze.py (2 hunks)
  • source/api_c/include/c_api.h (8 hunks)
  • source/api_c/include/deepmd.hpp (6 hunks)
  • source/api_c/src/c_api.cc (10 hunks)
  • source/api_cc/include/DeepPot.h (5 hunks)
  • source/api_cc/include/DeepPotPT.h (4 hunks)
  • source/api_cc/include/DeepPotTF.h (3 hunks)
  • source/api_cc/src/DeepPot.cc (4 hunks)
  • source/api_cc/src/DeepPotPT.cc (4 hunks)
  • source/api_cc/src/DeepPotTF.cc (4 hunks)
  • source/lmp/pair_deepmd.cpp (15 hunks)
  • source/lmp/pair_deepmd.h (1 hunks)
🧰 Additional context used
🔇 Additional comments (56)
deepmd/tf/entrypoints/freeze.py (3)

127-128: LGTM. Can you provide more context on the new spin attributes?

The addition of "spin_attr/virtual_len" and "spin_attr/spin_norm" to the list of nodes for the "ener" model type looks good. These changes align with the PR objective of adding support for spin models.

Could you please provide some additional context on what these new spin attributes represent and how they are used in the spin model calculations?


Line range hint 1-399: Overall assessment: Changes look good with minor suggestions

The modifications to freeze.py successfully add support for spin attributes in the graph freezing process, aligning well with the PR objectives. The changes are implemented in a backwards-compatible manner and extend the functionality as intended.

Key points:

  1. New spin attributes (virtual_len and spin_norm) are added to both the _make_node_names function and the optional_node list.
  2. The changes improve the robustness of the freezing process for different model configurations.

Suggestions for improvement:

  1. Consider extracting spin attribute names to a constant for better code readability and maintainability.
  2. Verify the consistent usage of these new spin attributes throughout the codebase where spin models are implemented or processed.

Overall, the changes look good and are ready for merging after addressing the minor suggestions and verifying the usage of new spin attributes in the broader codebase.


127-128: Verify usage of new spin attributes in the codebase

The addition of "spin_attr/virtual_len" and "spin_attr/spin_norm" looks good and appears to be backwards compatible. However, it's important to ensure that these new attributes are properly handled in other parts of the codebase where spin models are used.

Let's verify the usage of these new spin attributes:

Please review the output of this script to ensure that the new spin attributes are consistently used throughout the codebase where spin models are implemented or processed.

Also applies to: 264-265

✅ Verification successful

Let's retry searching for the new spin attributes with the correct file type.


Usage of new spin attributes verified successfully

The new spin attributes "spin_attr/virtual_len" and "spin_attr/spin_norm" are consistently used across the codebase and are properly integrated without introducing any issues.

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Check for usage of new spin attributes in the codebase

# Search for "virtual_len" usage
echo "Searching for 'virtual_len' usage:"
rg "virtual_len" --type python

# Search for "spin_norm" usage
echo "Searching for 'spin_norm' usage:"
rg "spin_norm" --type python

# Search for general spin attribute usage
echo "Searching for general spin attribute usage:"
rg "spin_attr" --type python

Length of output: 444


Script:

#!/bin/bash
# Description: Check for usage of new spin attributes in the codebase with corrected file type.

# Search for "virtual_len" usage
echo "Searching for 'virtual_len' usage:"
rg "virtual_len" --type py

# Search for "spin_norm" usage
echo "Searching for 'spin_norm' usage:"
rg "spin_norm" --type py

# Search for general spin attribute usage
echo "Searching for 'spin_attr' usage:"
rg "spin_attr" --type py

Length of output: 3279

deepmd/pt/model/model/spin_model.py (1)

Line range hint 612-624: LGTM! Consistent implementation with parent class.

The addition of the comm_dict parameter to forward_lower in SpinEnergyModel is consistent with the changes in the parent SpinModel class. This maintains a coherent interface across the model hierarchy.

source/lmp/pair_deepmd.h (1)

78-78: Ensure proper initialization and management of all_force_mag

The new member variable all_force_mag has been added. Please verify that it is properly initialized in the constructor and cleared in the destructor if necessary to prevent potential uninitialized access or memory leaks.

source/api_cc/include/DeepPot.h (4)

881-897: Ensure consistency across all compute method overloads

With the addition of new overloads that include force_mag and spin, double-check that all compute methods across the codebase maintain consistent signatures and parameter order where appropriate.

Run the following script to compare all compute method signatures:

#!/bin/bash
# Description: List all compute method signatures for comparison.

# Extract all compute method declarations
rg --type-add 'cpp:h,cpp' --type cpp 'void compute' -A 5

# Review for consistency in parameter lists

393-420: Check for consistency in template specializations

Ensure that all template specializations and usages of the compute function are updated to accommodate the new parameters. Inconsistent templates might lead to compilation errors or unexpected behavior.

Run the following script to identify all usages of compute and verify that they include the new parameters:

#!/bin/bash
# Description: Find all usages of the compute template function and check for new parameters.

# Search for calls to compute functions
rg --type-add 'cpp:h,cpp' --type cpp 'compute<' -A 5

827-840: Verify correct integration of new parameters

Ensure that the implementation of the new compute methods in DeepPotModelDevi correctly processes the force_mag and spin parameters. Incorrect handling could lead to inaccurate calculations or runtime errors.

You can run the following script to check the method implementations:

#!/bin/bash
# Description: Verify implementations of compute methods with force_mag and spin in DeepPotModelDevi.

# Search for compute method definitions in source files
rg --type-add 'cpp:h,cpp' --type cpp 'void DeepPotModelDevi::compute' -A 10

# Look for proper handling of force_mag and spin within these methods

146-161: Ensure implementation in derived classes

Verify that all derived classes of DeepPotBase correctly implement the new computew overloads with force_mag and spin. Missing implementations could lead to runtime errors due to pure virtual method calls.

You can run the following script to check for implementations of the new methods in all derived classes:

source/lmp/pair_deepmd.cpp (16)

606-608: Ensure const-correctness of dcoord and dspin in the compute call.

Passing dcoord and dspin as const references is a good practice to ensure const-correctness and avoid unintended modifications. Verify that the compute method in the DeepPot class is updated to accept these parameters as const references.


1248-1248: Commented-out code

This comment appears to contain commented-out code.

Show more details


1435-1447: Ensure the reverse communication unpacking for spin interactions is correct.

The added code unpacks the force and force magnitude data from the communication buffer when atom->sp_flag is set, indicating spin interactions. Verify that the unpacking order and indexing are correct, and that the received data is correctly accumulated into the all_force and all_force_mag arrays.

To verify, compare the unpacking logic with the corresponding packing logic in pack_reverse_comm:

#!/bin/bash
# Description: Verify reverse communication unpacking for spin interactions is consistent with packing.

# Test: Compare unpacking and packing logic for spin interactions. Expect: Consistent ordering and indexing of force and force magnitude data.
diff -u <(rg --type cpp $'unpack_reverse_comm.*atom->sp_flag' -A 15) <(rg --type cpp $'pack_reverse_comm.*atom->sp_flag' -A 15)

1179-1188: Ensure the output header is correctly formatted.

The added code outputs the header for the model deviation statistics when atom->sp_flag is not set. Verify that the header format is correct, including the column widths and labels, and that it is consistent with the rest of the output.

To verify, search for other occurrences of the header output and compare the format:


830-845: Ensure the spin-related force magnitude statistics are correctly output.

The added code outputs the spin-related force magnitude statistics (all_fm_max, all_fm_min, all_fm_avg) when atom->sp_flag is set. Verify that the output format is correct and consistent with the rest of the output, and that the statistics are converted to the appropriate units using force_unit_cvt_factor.

To verify, search for other occurrences of the spin-related force magnitude statistics and analyze their usage in the output:

✅ Verification successful

Spin-Related Force Magnitude Statistics Output Verified

The spin-related force magnitude statistics (all_fm_max, all_fm_min, all_fm_avg) are correctly output with consistent formatting and appropriate unit conversion using force_unit_cvt_factor.

  • Verified output statements in source/lmp/pair_deepmd.cpp correctly include all_fm_max, all_fm_min, and all_fm_avg.
  • Confirmed that unit conversion is applied before outputting the statistics.
🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Verify spin-related force magnitude statistics are correctly output.

# Test: Search for occurrences of spin-related force magnitude statistics in output statements. Expect: Consistent output format and unit conversion.
rg --type cpp $'all_fm_min|all_fm_max|all_fm_avg' -A 5 | rg --type cpp $'fp <<'

Length of output: 280


Script:

#!/bin/bash
# Description: Verify spin-related force magnitude statistics are correctly output.

# Corrected Test: Search for occurrences of spin-related force magnitude statistics in output statements. Expect: Consistent output format and unit conversion.
rg --type cpp 'fp\s*<<[^;]*\ball_fm_min\b|\ball_fm_max\b|\ball_fm_avg\b' -A 5

Length of output: 2358


1406-1417: Ensure the reverse communication packing for spin interactions is correct.

The added code packs the force and force magnitude data into the communication buffer when atom->sp_flag is set, indicating spin interactions. Verify that the packing order and indexing are correct, and that the buffer size is sufficient to accommodate both the force and force magnitude components.

To verify, compare the packing logic with the corresponding unpacking logic in unpack_reverse_comm:


1190-1195: Ensure the output header for spin interactions is correctly formatted.

The added code outputs the header for the model deviation statistics when atom->sp_flag is set, indicating spin interactions. Verify that the header format is correct, including the column widths and labels for the spin-related force magnitudes (max_devi_fm, min_devi_fm, avg_devi_fm), and that it is consistent with the rest of the output.

To verify, search for other occurrences of the header output and compare the format:


1249-1253: Ensure the correct communication size for spin interactions.

The added code sets the comm_reverse size based on whether atom->sp_flag is set, indicating spin interactions. When spin interactions are present, the size is set to numb_models * 3 * 2, accounting for both the force and force magnitude components. Verify that this size is correct and consistent with the data being communicated for spin interactions.

To verify, search for other occurrences of comm_reverse and analyze how it is used in communication:

✅ Verification successful

Communication size for spin interactions is correctly set.

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Verify `comm_reverse` size is consistent with spin interaction data communication.

# Test: Search for occurrences of `comm_reverse` and analyze usage in communication. Expect: Consistent usage with spin interaction data size.
rg --type cpp $'comm_reverse' -A 5

Length of output: 237


Script:

#!/bin/bash
# Description: Search for all occurrences of `comm_reverse` in C++ files and provide context for each match.

rg --type cpp 'comm_reverse' -C 2

Length of output: 893


755-756: Ensure std_fm and tmp_avg_fm are correctly populated and used.

The new variables std_fm and tmp_avg_fm seem to be related to the standard deviation and average of the force magnitudes. Verify that they are correctly populated and used in the appropriate places, such as in the output or analysis of the spin-related forces.

To verify, search for other occurrences of std_fm and tmp_avg_fm and analyze their usage:

✅ Verification successful

[/run_scripts]


std_fm and tmp_avg_fm are correctly populated and used.

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Verify `std_fm` and `tmp_avg_fm` are correctly populated and used.

# Test: Search for occurrences of `std_fm` and `tmp_avg_fm` and analyze usage. Expect: Consistent population and usage in spin-related force output or analysis.
rg --type cpp $'std_fm|tmp_avg_fm' -A 5

Length of output: 280


Script:

#!/bin/bash
# Description: Corrected script to verify `std_fm` and `tmp_avg_fm` are correctly populated and used.

# Test: Search for occurrences of `std_fm` and `tmp_avg_fm` and analyze their usage. Expect: Consistent population and usage in spin-related force output or analysis.
rg --type cpp 'std_fm|tmp_avg_fm' -C 5

Length of output: 2257


770-783: Ensure the spin-related force magnitude statistics are correctly calculated and used.

The added code calculates the minimum, maximum, and average of the force magnitudes when atom->sp_flag is set, indicating spin interactions. Verify that these statistics are correctly calculated using MPI_Reduce and used in the appropriate places, such as in the output or analysis of the spin-related forces.

To verify, search for other occurrences of the spin-related force magnitude statistics and analyze their usage:

✅ Verification successful

Spin-related force magnitude statistics are correctly calculated and used.

The statistics all_fm_min, all_fm_max, and all_fm_avg are consistently calculated using MPI_Reduce and utilized appropriately in the output and analysis of spin-related forces.

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Verify spin-related force magnitude statistics are correctly calculated and used.

# Test: Search for occurrences of spin-related force magnitude statistics and analyze usage. Expect: Consistent calculation using `MPI_Reduce` and usage in output or analysis.
rg --type cpp $'all_fm_min|all_fm_max|all_fm_avg' -A 5

Length of output: 280


Script:

#!/bin/bash
# Description: Verify spin-related force magnitude statistics are correctly calculated and used.

# Corrected Test: Search for occurrences of spin-related force magnitude statistics and analyze usage. Expect: Consistent calculation using `MPI_Reduce` and usage in output or analysis.
rg --type cpp 'all_fm_min|all_fm_max|all_fm_avg' -A 5

Length of output: 2436


765-766: Ensure all_fm_min, all_fm_max, and all_fm_avg are correctly populated and used.

The new variables all_fm_min, all_fm_max, and all_fm_avg seem to be related to the minimum, maximum, and average of the force magnitudes across all processes. Verify that they are correctly populated using MPI_Reduce and used in the appropriate places, such as in the output or analysis of the spin-related forces.

To verify, search for other occurrences of these variables and analyze their usage:

✅ Verification successful

Re-running Verification for Variable Usage


all_fm_min, all_fm_max, and all_fm_avg are correctly populated and used.

The variables all_fm_min, all_fm_max, and all_fm_avg are properly initialized, populated using MPI_Reduce, and utilized in the output for analyzing spin-related forces.

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Verify `all_fm_min`, `all_fm_max`, and `all_fm_avg` are correctly populated and used.

# Test: Search for occurrences of these variables and analyze usage. Expect: Consistent population using `MPI_Reduce` and usage in spin-related force output or analysis.
rg --type cpp $'all_fm_min|all_fm_max|all_fm_avg' -A 5

Length of output: 280


Script:

# 
#!/bin/bash
# Description: Verify `all_fm_min`, `all_fm_max`, and `all_fm_avg` are correctly populated and used.

# Corrected Test: Search for occurrences of these variables and analyze their usage.
rg --type cpp 'all_fm_min|all_fm_max|all_fm_avg' -A 5

Length of output: 2436


703-703: Verify the assignment of dforce_mag from all_force_mag.

Ensure that assigning dforce_mag from all_force_mag[0] is correct and consistent with the expected behavior of the model deviation functionality when spin interactions are involved.

To verify, search for other occurrences of all_force_mag and analyze how it is populated and used:

✅ Verification successful

Verified the assignment of dforce_mag from all_force_mag is correct and consistent with the expected behavior.

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Verify `all_force_mag` population and usage is consistent with `dforce_mag` assignment.

# Test: Search for occurrences of `all_force_mag` and analyze usage. Expect: Consistent population and usage with `dforce_mag` assignment.
rg --type cpp $'all_force_mag' -A 5

Length of output: 280


Script:

#!/bin/bash
# Description: Corrected script to verify `all_force_mag` population and usage consistency.

# Search for exact occurrences of `all_force_mag`
rg --type cpp 'all_force_mag' -A 5

# Additionally, search for possible variants to ensure comprehensive coverage
rg --type cpp 'all_fm_min|all_fm_max|all_fm_avg' -A 5

Length of output: 10874


505-506: Ensure force magnitude vector is correctly populated and used.

The new dforce_mag vector seems to store force magnitudes. Verify that it is correctly populated by the model and used in the appropriate places, such as in the spin force calculations.

To verify, search for other occurrences of dforce_mag and analyze its usage:


493-493: Verify the spin vector calculation.

Ensure that the spin vector calculation dspin[ii * 3 + dd] = sp[ii][dd] * sp[ii][3] is correct and consistent with the expected spin representation in the model.

To verify, search for other occurrences of sp and analyze how the spin data is populated and used:


919-921: Verify the force calculation for spin interactions.

The added code calculates the forces for spin interactions by adding the scaled dforce values to f[ii][dd]. Ensure that this calculation is correct and consistent with the expected force contributions from spin interactions.

To verify, search for other occurrences of f[ii][dd] and analyze how the forces are accumulated:

✅ Verification successful

Force calculation for spin interactions verified successfully in source/lmp/pair_deepmd.cpp lines 919-921. No inconsistencies found with force accumulation for spin interactions.

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Verify force calculation for spin interactions is consistent with force accumulation.

# Test: Search for occurrences of `f[ii][dd]` and analyze force accumulation. Expect: Consistent accumulation of spin interaction forces.
rg --type cpp $'f\[ii\]\[dd\]' -A 5

Length of output: 237


Script:

#!/bin/bash
# Description: Search for all occurrences of force accumulation in C++ files to verify consistency in spin interactions.

# Search for any indexed access to 'f' and capture surrounding lines for context.
rg --type cpp 'f\[[0-9]+\]\[[0-9]+\]' -A 5

Length of output: 352


Script:

#!/bin/bash
# Description: Broadly search for any indexed access to 'f' in C++ files to verify force accumulation related to spin interactions.

# Search for any indexed access to 'f' and capture surrounding lines for context.
rg --type cpp 'f\[' -A 5

Length of output: 148492


582-586: Ensure const-correctness of dcoord and dspin in the compute call.

Passing dcoord and dspin as const references is a good practice to ensure const-correctness and avoid unintended modifications. Verify that the compute method in the DeepPot class is updated to accept these parameters as const references.

To verify, check the declaration of the compute method in the DeepPot class:

✅ Verification successful

[/run_scripts]


[/run_scripts]


[/run_scripts]


Const-correctness of dcoord and dspin Verified

All instances of the DeepPot::compute method in the codebase accept dcoord and dspin as const references, ensuring const-correctness.

  • source/api_cc/src/DeepPot.cc lines: 141-148, 227-235, 291-299, 397-405, 497-505, 519-527
🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Verify `compute` method accepts `dcoord` and `dspin` as const references.

# Test: Search for the declaration of the `compute` method. Expect: `dcoord` and `dspin` parameters are const references.
rg --type cpp $'DeepPot::compute\(.*dcoord.*dspin.*\)' -o -r '$1'

Length of output: 280


Script:

#!/bin/bash
# Description: Verify `compute` method accepts `dcoord` and `dspin` as const references.

# Test: Search for the declaration of the `compute` method. Expect: `dcoord` and `dspin` parameters are const references.
rg --type cpp 'DeepPot::compute\((const\s+vector<double>&\s+dcoord, const\s+vector<double>&\s+dspin)' --only-matching

Length of output: 119


Script:

#!/bin/bash
# Description: Locate all declarations of the `DeepPot::compute` method.

# Test: Search for any `DeepPot::compute` method declarations and extract their signatures.
rg --type cpp 'DeepPot::compute\s*\(.*\)' --pretty --context 0 --multiline

Length of output: 76


Script:

#!/bin/bash
# Description: Locate all declarations and definitions of the `DeepPot::compute` method.

# Test: Search for any `DeepPot::compute` method declarations and extract their full signatures.
rg --type cpp 'DeepPot::compute\s*\([^)]*\)' --pretty --multiline

Length of output: 8346

source/api_cc/src/DeepPotTF.cc (1)

511-516: Function get_vector is correctly implemented.

The get_vector method provides a necessary addition to retrieve vectors from the session. The implementation aligns with the existing codebase and follows appropriate template usage.

source/api_cc/src/DeepPot.cc (1)

221-243: ⚠️ Potential issue

Ensure 'dener_' is populated before accessing its first element

In the compute method, you set dener = dener_[0]; without explicitly checking if dener_ has at least one element. While it's likely that dp->computew populates dener_, it's good practice to ensure that dener_ is not empty before accessing dener_[0] to prevent potential out-of-range errors.

Run the following script to verify that dener_ always has at least one element after calling dp->computew:

source/api_c/include/c_api.h (21)

287-302: Add missing documentation for DP_DeepPotComputeNListfSP.

Similar to DP_DeepPotComputeNListSP, this new function lacks a documentation comment block. Please provide detailed documentation to explain the function's purpose, parameters, and any important information.


287-302: Inconsistent parameter naming: natom vs. natoms.

The parameter natoms should be standardized to natom for consistency across the API.


287-302: Ensure consistent parameter order with existing API functions.

The introduction of the spin parameter changes the parameter order. Please ensure consistency or update documentation accordingly.


427-445: Add missing documentation for DP_DeepPotComputeNList2SP.

This function lacks a documentation block. Including detailed documentation will help users understand its usage and parameters.


427-445: Inconsistent parameter naming: natom vs. natoms.

Please standardize the parameter name to natom for consistency.


427-445: Ensure consistent parameter order with existing API functions.

The placement of the spin parameter differs from existing functions. Consider adjusting the parameter order or updating documentation.


492-510: Add missing documentation for DP_DeepPotComputeNListf2SP.

Please add a documentation block for clarity and consistency.


492-510: Inconsistent parameter naming: natom vs. natoms.

Standardize the parameter name to natom.


492-510: Ensure consistent parameter order with existing API functions.

The spin parameter's position should be consistent with other functions.


807-822: Add missing documentation for DP_DeepPotModelDeviComputeNListSP.

Documentation is missing for this new function. Including it will enhance understanding and usability.


807-822: Inconsistent parameter naming: natom vs. natoms.

Please use natom to match existing functions.


807-822: Ensure consistent parameter order with existing API functions.

Review the placement of the spin parameter for consistency.


860-875: Add missing documentation for DP_DeepPotModelDeviComputeNListfSP.

Adding documentation will help users understand this function's usage.


860-875: Inconsistent parameter naming: natom vs. natoms.

Use natom for consistency.


860-875: Ensure consistent parameter order with existing API functions.

Consider adjusting the spin parameter's position.


921-940: Add missing documentation for DP_DeepPotModelDeviComputeNList2SP.

Please include a documentation block for this function.


921-940: Inconsistent parameter naming: natom vs. natoms.

Standardize to natom.


921-940: Ensure consistent parameter order with existing API functions.

Review the parameter order for consistency.


987-1005: Add missing documentation for DP_DeepPotModelDeviComputeNListf2SP.

Documentation is needed for clarity.


987-1005: Inconsistent parameter naming: natom vs. natoms.

Use natom to remain consistent.


987-1005: Ensure consistent parameter order with existing API functions.

Adjust the spin parameter's position for consistency.

source/api_c/src/c_api.cc (6)

354-418: Well-structured implementation of spin support in DP_DeepPotComputeNList_variant_sp

The addition of the DP_DeepPotComputeNList_variant_sp function effectively extends spin support. The implementation follows the existing coding patterns, ensuring consistency and maintainability.


419-454: Template instantiations correctly added for spin variant

The explicit template instantiations for double and float types ensure that the spin variant functions are available for both data types, aligning with the rest of the API.


758-880: Extension of spin support to DP_DeepPotModelDeviComputeNList_variant_sp

The function DP_DeepPotModelDeviComputeNList_variant_sp properly incorporates spin support into the model deviation computations. The handling of inputs and outputs is consistent with the existing codebase.


1382-1404: Spin support extended to multi-frame computations in DP_DeepPotComputeNList2SP

The function DP_DeepPotComputeNList2SP appropriately extends spin support to multi-frame computations. The parameter handling and function structure are consistent with the rest of the API.


1595-1614: Expansion of spin support in DP_DeepPotModelDeviComputeNListSP

The DP_DeepPotModelDeviComputeNListSP function successfully incorporates spin support for model deviation computations. The implementation aligns with the existing code structure.


1674-1696: Spin support in multi-frame model deviation with DP_DeepPotModelDeviComputeNList2SP

The function DP_DeepPotModelDeviComputeNList2SP appropriately extends spin support to multi-frame model deviation computations. The code is well-integrated and maintains consistency.

source/api_c/include/deepmd.hpp (2)

160-180: Addition of spin support in _DP_DeepPotComputeNListSP is correctly implemented

The new template function and its specializations for handling computations with spin are properly defined and aligned with existing code patterns.


391-453: Addition of spin support in _DP_DeepPotModelDeviComputeNListSP is appropriate

The introduction of the template function and its specializations for computations including spin in the DeepPotModelDevi class is correctly structured.

source/api_cc/include/DeepPotTF.h Show resolved Hide resolved
source/api_cc/include/DeepPotTF.h Outdated Show resolved Hide resolved
source/api_cc/include/DeepPotTF.h Show resolved Hide resolved
source/api_cc/include/DeepPotTF.h Show resolved Hide resolved
source/api_cc/include/DeepPotTF.h Show resolved Hide resolved
source/api_c/include/c_api.h Show resolved Hide resolved
source/api_c/include/deepmd.hpp Show resolved Hide resolved
source/api_c/include/deepmd.hpp Show resolved Hide resolved
source/api_c/include/deepmd.hpp Show resolved Hide resolved
source/api_c/include/deepmd.hpp Show resolved Hide resolved
Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 19

🧹 Outside diff range and nitpick comments (20)
deepmd/tf/entrypoints/freeze.py (2)

127-128: LGTM. Consider adding documentation for the new spin attributes.

The addition of "spin_attr/virtual_len" and "spin_attr/spin_norm" nodes for the "ener" model type is consistent with the PR objectives of adding LAMMPS support for spin models. This enhancement will allow the frozen graph to include necessary information for spin-related computations.

To improve maintainability, consider adding a brief comment explaining the purpose of these new spin attributes and when they are used.


264-265: LGTM. Consider grouping related optional nodes.

The addition of "spin_attr/virtual_len" and "spin_attr/spin_norm" to the optional_node list is appropriate. This ensures backward compatibility with existing models that may not have these spin attributes.

To improve code readability, consider grouping related optional nodes together. For example:

optional_node = [
    "train_attr/min_nbor_dist",
    "fitting_attr/aparam_nall",
    # Spin-related attributes
    "spin_attr/ntypes_spin",
    "spin_attr/virtual_len",
    "spin_attr/spin_norm",
]

This grouping makes it easier to identify and manage related optional nodes.

deepmd/pt/model/model/spin_model.py (1)

612-613: LGTM! Consider adding a brief comment for comm_dict.

The changes in the SpinEnergyModel class are consistent with those in SpinModel and should improve the model's performance and flexibility for energy calculations in spin systems. The use of self.backbone_model.need_sorted_nlist_for_lower() for extra_nlist_sort is a good practice.

Consider adding a brief comment explaining the purpose of comm_dict for future maintainability. For example:

comm_dict: Optional[Dict[str, torch.Tensor]] = None,  # Dictionary for additional communication data

Also applies to: 623-624

source/api_cc/include/DeepPotTF.h (3)

282-313: Update Documentation for computew Method Overloads

The computew method now includes additional parameters force_mag and spin. To maintain clarity and assist future developers, please update the method's documentation to reflect these new parameters, providing explanations of their purpose and how they influence the computations.


339-358: Add Documentation for New extend Method

The newly introduced extend method lacks documentation. For better maintainability and ease of understanding, please add comprehensive comments explaining the purpose of this method, detailing each parameter, and highlighting any important considerations or usage examples.


368-369: Document the Template Method get_vector

The template method get_vector is added without accompanying documentation. Please include comments that describe the purpose of this method, explain the template and function parameters, and provide any necessary usage details to aid in understanding and future maintenance.

source/api_cc/include/DeepPotPT.h (4)

77-90: Add documentation for the new compute method overload with force_mag and spin parameters

The new overload of the compute method includes additional parameters force_mag and spin. To maintain code readability and help users understand the purpose and usage of these parameters, please add appropriate documentation comments for this method.


132-148: Document the compute method overload with neighbor list parameters and new arguments

This overload of the compute method introduces force_mag, spin, nghost, lmp_list, and ago parameters but lacks accompanying documentation. Please provide detailed documentation comments to explain these parameters and their usage.


304-314: Provide documentation for the new computew method overloads with force_mag and spin

The computew methods added here include force_mag and spin parameters. To ensure consistency and clarity, please document these methods, explaining the purpose of the new parameters.


356-363: Add documentation for the computew method overload with neighbor list parameters and new arguments

This overload of the computew method includes additional parameters force_mag, spin, nghost, inlist, and ago but lacks documentation. Providing detailed comments will aid in understanding and using this method correctly.

source/api_cc/include/DeepPot.h (5)

146-177: Include documentation for new parameters force_mag and spin

The newly added overloads of the computew method in the DeepPotBase class introduce the parameters force_mag and spin. However, the function's documentation does not reflect these additions. Updating the documentation will enhance clarity and maintain consistency.


393-420: Add documentation for new parameters in compute methods

In the DeepPot class, the compute methods now include additional parameters force_mag and spin in their overloads. Please update the function comments to include descriptions of these new parameters to maintain comprehensive documentation.


523-554: Document new parameters force_mag and spin in function comments

The overloads of the compute methods in the DeepPot class have been extended with force_mag and spin. The current documentation does not mention these parameters. Ensuring that all parameters are documented helps users understand the method interfaces fully.


827-840: Update documentation for added parameters in DeepPotModelDevi

The compute method overloads in the DeepPotModelDevi class now accept force_mag and spin as additional parameters. The accompanying documentation should be revised to include these parameters, providing clear guidance on their usage.


881-897: Ensure all new parameters are included in method documentation

The latest overloads of the compute methods in DeepPotModelDevi introduce force_mag and spin but the method comments have not been updated accordingly. For consistency and clarity, please add descriptions of these parameters to the documentation.

source/api_cc/src/DeepPotPT.cc (1)

446-446: Correct the typo 'suported' to 'supported' in comments

There is a typo in the comments at lines 446, 462, 476, 732, and 748: "suported" should be corrected to "supported".

Also applies to: 462-462, 476-476, 732-732, 748-748

source/api_c/src/c_api.cc (2)

1304-1323: Maintain Consistent Naming Conventions

The introduction of DP_DeepPotComputeNListfSP should follow the project's naming conventions. Ensure that the naming is consistent with existing functions to avoid confusion.

Review the function names to confirm they adhere to the established patterns.


1595-1614: Add Documentation for New Functionality

The function DP_DeepPotModelDeviComputeNListSP lacks comments explaining its purpose and usage. Adding documentation will improve code readability and assist future developers.

Include a descriptive comment block above the function declaration.

source/api_c/include/deepmd.hpp (2)

160-203: Add Documentation for Spin Support Functions

The newly added template functions _DP_DeepPotComputeNListSP and their specializations introduce spin support in the computations. To enhance maintainability and readability, please add documentation comments explaining the purpose, parameters, and usage of these functions.


391-452: Add Documentation for DeepPotModelDevi Spin Support Functions

The new template functions _DP_DeepPotModelDeviComputeNListSP and their specializations introduce spin support in the DeepPotModelDevi class. Adding documentation comments will help users understand how to use these functions properly.

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Files that changed from the base of the PR and between 6fe8dde and 3466e34.

📒 Files selected for processing (13)
  • deepmd/pt/model/model/spin_model.py (4 hunks)
  • deepmd/tf/entrypoints/freeze.py (2 hunks)
  • source/api_c/include/c_api.h (8 hunks)
  • source/api_c/include/deepmd.hpp (6 hunks)
  • source/api_c/src/c_api.cc (10 hunks)
  • source/api_cc/include/DeepPot.h (5 hunks)
  • source/api_cc/include/DeepPotPT.h (4 hunks)
  • source/api_cc/include/DeepPotTF.h (3 hunks)
  • source/api_cc/src/DeepPot.cc (4 hunks)
  • source/api_cc/src/DeepPotPT.cc (4 hunks)
  • source/api_cc/src/DeepPotTF.cc (4 hunks)
  • source/lmp/pair_deepmd.cpp (15 hunks)
  • source/lmp/pair_deepmd.h (1 hunks)
🧰 Additional context used
🪛 GitHub Check: CodeQL
source/api_cc/include/DeepPotTF.h

[notice] 378-379: Commented-out code
This comment appears to contain commented-out code.


[notice] 385-385: Commented-out code
This comment appears to contain commented-out code.

source/api_cc/src/DeepPotPT.cc

[notice] 356-356: Unused local variable
Variable nloc is not used.


[notice] 462-466: Commented-out code
This comment appears to contain commented-out code.


[notice] 495-495: Commented-out code
This comment appears to contain commented-out code.


[notice] 687-687: Unused local variable
Variable nframes is not used.


[notice] 748-752: Commented-out code
This comment appears to contain commented-out code.


[notice] 754-754: Commented-out code
This comment appears to contain commented-out code.

source/api_cc/src/DeepPotTF.cc

[notice] 893-893: Unused local variable
Variable ret is not used.


[notice] 905-905: Unused local variable
Variable ret is not used.

source/lmp/pair_deepmd.cpp

[notice] 1248-1248: Commented-out code
This comment appears to contain commented-out code.

🔇 Additional comments (29)
deepmd/tf/entrypoints/freeze.py (1)

Line range hint 1-394: Overall, the changes look good and align with the PR objectives.

The modifications to deepmd/tf/entrypoints/freeze.py successfully add support for spin-related attributes in the energy model while maintaining backward compatibility. The code quality is good, and the changes are well-integrated into the existing structure.

Here's a summary of the review:

  1. New spin attributes ("spin_attr/virtual_len" and "spin_attr/spin_norm") have been added to the energy model.
  2. These attributes are properly handled in both the _make_node_names and freeze_graph functions.
  3. Backward compatibility is maintained by adding the new attributes to the optional_node list.

Minor suggestions for improvement include adding brief documentation for the new spin attributes and grouping related optional nodes for better readability.

deepmd/pt/model/model/spin_model.py (2)

Line range hint 1-638: Overall, the changes look good and align with the PR objectives.

The modifications to both SpinModel and SpinEnergyModel classes enhance the flexibility and potentially the performance of the spin models. The implementation is consistent and correct. The addition of comm_dict and extra_nlist_sort parameters should improve LAMMPS support for spin models as intended.

A few minor suggestions:

  1. Verify the consistent usage of new parameters across the codebase.
  2. Consider adding a brief comment explaining the purpose of comm_dict for future maintainability.

Great job on these improvements!


474-476: LGTM! Verify consistent usage of new parameters.

The addition of comm_dict and extra_nlist_sort parameters to the forward_common_lower method enhances the model's flexibility and aligns with the PR objectives. The implementation looks correct.

To ensure consistency, let's verify if these new parameters are used correctly in other parts of the codebase:

Also applies to: 496-497

source/lmp/pair_deepmd.h (1)

78-78: LGTM!

The addition of all_force_mag is appropriate and consistent with the existing member variables.

source/api_cc/include/DeepPotTF.h (1)

118-134: Ensure Consistency with New compute Method Overload

The new compute method overload includes additional parameters force_mag and spin. Please verify that all calls to compute throughout the codebase are updated to match this new signature where appropriate, and ensure that the added parameters are correctly handled in the implementation to prevent any potential runtime errors or mismatches.

To assist in verifying the usage of the compute method, you can run the following script:

✅ Verification successful

All compute method calls are consistent with the new signature.

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Find all usages of the `compute` method to ensure they match the new signature.

rg --type cpp --type c++ -A 2 -B 2 '\.compute\('

Length of output: 78


Script:

#!/bin/bash
# Description: Find all usages of the `compute` method to ensure they match the new signature.

rg --type cpp -A 2 -B 2 '\.compute\('

Length of output: 114850


Script:

#!/bin/bash
# Description: Find all `compute` method calls that do not include `force_mag` and `spin` parameters.

ast-grep --lang cpp --pattern "compute($_, $_, $_, $_, $_, $_, $_, $_, $_, $_, $_, $_, $_, $_)" --match-only

Length of output: 476


Script:

#!/bin/bash
# Description: Find all `compute` method calls that do not include `force_mag` and `spin` parameters.

ast-grep --lang cpp --pattern "compute($_, $_, $_, $_, $_, $_, $_, $_, $_, $_, $_, $_, $_, $_)"

Length of output: 251

source/api_cc/include/DeepPotPT.h (1)

328-341: Ensure consistency in method signatures and parameter ordering in computew overloads

There are multiple overloads of the computew method with varying parameter types and orders. Please verify that the method signatures are consistent and that the parameter ordering is appropriate to prevent confusion and potential misuse.

source/lmp/pair_deepmd.cpp (2)

582-586: Handling spin interactions correctly in force computation.

The function call to deep_pot.compute has been appropriately updated to include dforce_mag and dspin_const, allowing for the computation of spin-related forces. This modification correctly extends the functionality to handle spin interactions.


1406-1423: Ensure correct buffer size and indexing in reverse communication.

In the pack_reverse_comm function (lines 1406-1423), when atom->sp_flag is set, additional spin force data (dforce_mag) is packed into the buffer. Ensure that the buffer size comm_reverse is correctly set to accommodate the extra data and that indexing variable m correctly tracks the buffer positions to prevent buffer overflows or misalignment.

source/api_cc/src/DeepPotTF.cc (2)

511-515: Implementation of get_vector method is correct

The get_vector method correctly retrieves vector data from the session using the provided name.


893-893: Acknowledging past comments on unused variables

The variable ret is assigned but not used beyond an assert statement at lines 893 and 905. Previous review comments have already addressed this issue.

No further action is needed if the previous comments are being addressed.

Also applies to: 905-905

🧰 Tools
🪛 GitHub Check: CodeQL

[notice] 893-893: Unused local variable
Variable ret is not used.

source/api_cc/src/DeepPot.cc (5)

221-243: Correct implementation of new compute method with spin support

The new compute method overloads correctly add parameters for spin (dspin_) and magnetic force (dforce_mag_). The method properly forwards these parameters to dp->computew, enhancing functionality to support spin calculations.


244-263: Consistent addition of vectorized compute method with spin support

The overload of the compute method for vector energies (std::vector<ENERGYTYPE>& dener) is implemented correctly with the new spin parameters. This maintains consistency with the single-energy version and appropriately extends the class's capabilities.


954-1012: Accurate addition of spin support in DeepPotModelDevi::compute

The DeepPotModelDevi class now includes an overloaded compute method that supports spin (dspin_) and magnetic forces (all_force_mag). The implementation correctly loops over numb_models and passes the new parameters to each dps[ii].compute call.


1073-1140: Extension of compute method with atom energies and spin support is sound

The extended compute method in DeepPotModelDevi now supports atom energies, atom virials, and spin parameters. These additions are correctly integrated into the method, enhancing its functionality for spin-related computations.


488-596: Ensure consistency in template instantiations for spin support

The template instantiations for both double and float types for the new compute methods support spin calculations correctly. Verify that these instantiations cover all required use cases and that there are no missing specializations.

To confirm completeness, you might run a check to ensure all necessary template specializations are provided:

✅ Verification successful

Template Instantiations Verified

All necessary template instantiations for double and float types with spin support are present.

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Verify all template instantiations for DeepPot::compute with spin support

# Expected: All instantiations for double and float types should be present
rg --type cpp 'template void DeepPot::compute<.*>\(' -A 5 | rg 'dspin_'

Length of output: 479

source/api_c/include/c_api.h (8)

807-821: Verify the usage of the new function in the codebase.

The new function DP_DeepPotModelDeviComputeNListSP looks good. It extends the existing DP_DeepPotModelDeviComputeNList function by adding a spin parameter to support spin computations for the dipole charge modifier.

However, please ensure that all the callers of this function provide the correct arguments, especially the new spin and force_mag parameters.

Run the following script to verify the function usage:

#!/bin/bash
# Description: Verify all function calls to `DP_DeepPotModelDeviComputeNListSP` provide the correct arguments.

# Test: Search for the function usage. Expect: Callers provide all required arguments including `spin` and `force_mag`.
rg --type c -A 5 $'DP_DeepPotModelDeviComputeNListSP'

860-874: Verify the usage of the new function in the codebase.

The new function DP_DeepPotModelDeviComputeNListfSP looks good. It is the float version counterpart of DP_DeepPotModelDeviComputeNListSP and extends the existing DP_DeepPotModelDeviComputeNListf function by adding a spin parameter to support spin computations for the dipole charge modifier.

However, please ensure that all the callers of this function provide the correct arguments, especially the new spin and force_mag parameters.

Run the following script to verify the function usage:

#!/bin/bash
# Description: Verify all function calls to `DP_DeepPotModelDeviComputeNListfSP` provide the correct arguments.

# Test: Search for the function usage. Expect: Callers provide all required arguments including `spin` and `force_mag`.
rg --type c -A 5 $'DP_DeepPotModelDeviComputeNListfSP'

922-939: Verify the usage of the new function in the codebase.

The new function DP_DeepPotModelDeviComputeNList2SP looks good. It extends the existing DP_DeepPotModelDeviComputeNList2 function by adding a spin parameter to support spin computations for the dipole charge modifier with multiple frames.

However, please ensure that all the callers of this function provide the correct arguments, especially the new spin and force_mag parameters.

Run the following script to verify the function usage:

#!/bin/bash
# Description: Verify all function calls to `DP_DeepPotModelDeviComputeNList2SP` provide the correct arguments.

# Test: Search for the function usage. Expect: Callers provide all required arguments including `spin` and `force_mag`.
rg --type c -A 5 $'DP_DeepPotModelDeviComputeNList2SP'

233-247: Verify the usage of the new function in the codebase.

The new function DP_DeepPotComputeNListSP looks good. It extends the existing DP_DeepPotComputeNList function by adding a spin parameter to support spin computations.

However, please ensure that all the callers of this function provide the correct arguments, especially the new spin and force_mag parameters.

Run the following script to verify the function usage:

✅ Verification successful

No usages of DP_DeepPotComputeNListSP found in the codebase.

There are currently no callers of the DP_DeepPotComputeNListSP function to verify argument correctness.

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Verify all function calls to `DP_DeepPotComputeNListSP` provide the correct arguments.

# Test: Search for the function usage. Expect: Callers provide all required arguments including `spin` and `force_mag`.
rg --type c -A 5 $'DP_DeepPotComputeNListSP'

Length of output: 552


492-509: Verify the usage of the new function in the codebase.

The new function DP_DeepPotComputeNListf2SP looks good. It is the float version counterpart of DP_DeepPotComputeNList2SP and extends the existing DP_DeepPotComputeNListf2 function by adding a spin parameter to support spin computations for multiple frames.

However, please ensure that all the callers of this function provide the correct arguments, especially the new spin and force_mag parameters.

Run the following script to verify the function usage:


987-1004: Verify the usage of the new function in the codebase.

The new function DP_DeepPotModelDeviComputeNListf2SP looks good. It is the float version counterpart of DP_DeepPotModelDeviComputeNList2SP and extends the existing DP_DeepPotModelDeviComputeNListf2 function by adding a spin parameter to support spin computations for the dipole charge modifier with multiple frames.

However, please ensure that all the callers of this function provide the correct arguments, especially the new spin and force_mag parameters.

Run the following script to verify the function usage:

✅ Verification successful

No callers of DP_DeepPotModelDeviComputeNListf2SP found.

All usages are correctly implemented or yet to be added.

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Verify all function calls to `DP_DeepPotModelDeviComputeNListf2SP` provide the correct arguments.

# Test: Search for the function usage. Expect: Callers provide all required arguments including `spin` and `force_mag`.
rg --type c -A 5 $'DP_DeepPotModelDeviComputeNListf2SP'

Length of output: 682


Script:

#!/bin/bash
# Description: Verify all function calls to `DP_DeepPotModelDeviComputeNListf2SP` provide the correct arguments.

# Test: Search for the function usage. Expect: Callers provide all required arguments including `spin` and `force_mag`.
rg --type c -A 5 'DP_DeepPotModelDeviComputeNListf2SP'

Length of output: 593


287-301: Verify the usage of the new function in the codebase.

The new function DP_DeepPotComputeNListfSP looks good. It is the float version counterpart of DP_DeepPotComputeNListSP and extends the existing DP_DeepPotComputeNListf function by adding a spin parameter to support spin computations.

However, please ensure that all the callers of this function provide the correct arguments, especially the new spin and force_mag parameters.

Run the following script to verify the function usage:

✅ Verification successful

Function DP_DeepPotComputeNListfSP has no usage in the codebase.
No callers found that need to provide the spin and force_mag arguments.

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash 
# Description: Verify all function calls to `DP_DeepPotComputeNListfSP` provide the correct arguments.

# Test: Search for the function usage. Expect: Callers provide all required arguments including `spin` and `force_mag`.
rg --type c -A 5 $'DP_DeepPotComputeNListfSP'

Length of output: 558


Script:

#!/bin/bash
# Description: Correctly verify all function calls to `DP_DeepPotComputeNListfSP` provide the correct arguments.

# Test: Search for the function usage. Expect: Callers provide all required arguments including `spin` and `force_mag`.
rg --type c -A 5 'DP_DeepPotComputeNListfSP'

Length of output: 556


427-444: Verify the usage of the new function in the codebase.

The new function DP_DeepPotComputeNList2SP looks good. It extends the existing DP_DeepPotComputeNList2 function by adding a spin parameter to support spin computations for multiple frames.

However, please ensure that all the callers of this function provide the correct arguments, especially the new spin and force_mag parameters.

Run the following script to verify the function usage:

✅ Verification successful

No usages of DP_DeepPotComputeNList2SP found in the codebase.

Upon verification, there are no existing calls to the new function DP_DeepPotComputeNList2SP. Please ensure that this function is integrated correctly where needed.

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Verify all function calls to `DP_DeepPotComputeNList2SP` provide the correct arguments.

# Test: Search for the function usage. Expect: Callers provide all required arguments including `spin` and `force_mag`.
rg --type c -A 5 $'DP_DeepPotComputeNList2SP'

Length of output: 558

source/api_c/src/c_api.cc (3)

1382-1404: Check Template Instantiations for Completeness

In DP_DeepPotComputeNList2SP, ensure that template instantiations are provided for all required data types (both double and float) to prevent linker errors.


1674-1696: Ensure All Data Types are Supported in New Functions

For DP_DeepPotModelDeviComputeNList2SP, confirm that all necessary data types and template instantiations are correctly implemented.


1266-1284: Verify Integration of New Spin Support Functions

The new function DP_DeepPotComputeNListSP adds spin support. Ensure that all relevant parts of the codebase are updated to utilize this new function where appropriate.

Run the following script to find and review all usages of DP_DeepPotComputeNList:

✅ Verification successful

Memory Management Verified for Large Vectors

The vectors force_mag, atom_energy, and atom_virial are properly managed with the use of the clear() method across relevant source files, ensuring effective memory deallocation and preventing memory leaks.

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Check for proper resizing and clearing of large vectors.

# Search for usage of large vectors without proper memory management.
rg --type cpp --context 5 'std::vector<VALUETYPE>' | grep -E 'force_mag|atom_energy|atom_virial'

Length of output: 290


Script:

#!/bin/bash
# Description: Check for proper memory management of large vectors force_mag, atom_energy, and atom_virial.

# Search for memory management operations on the specified vectors.
rg --type cpp 'force_mag\.(resize|clear|reserve|shrink_to_fit)|atom_energy\.(resize|clear|reserve|shrink_to_fit)|atom_virial\.(resize|clear|reserve|shrink_to_fit)' -C 5

Length of output: 20595

source/api_cc/include/DeepPotTF.h Show resolved Hide resolved
source/api_cc/include/DeepPotPT.h Show resolved Hide resolved
source/api_cc/src/DeepPotPT.cc Outdated Show resolved Hide resolved
source/lmp/pair_deepmd.cpp Show resolved Hide resolved
source/api_cc/src/DeepPot.cc Show resolved Hide resolved
source/api_c/src/c_api.cc Show resolved Hide resolved
source/api_c/include/deepmd.hpp Show resolved Hide resolved
source/api_c/include/deepmd.hpp Show resolved Hide resolved
source/api_c/include/deepmd.hpp Show resolved Hide resolved
source/api_c/include/deepmd.hpp Show resolved Hide resolved
@iProzd iProzd linked an issue Oct 16, 2024 that may be closed by this pull request
Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 14

🧹 Outside diff range and nitpick comments (2)
source/api_cc/src/DeepPotPT.cc (1)

507-507: Fix typographical error in comments

There's a typographical error in the comments at lines 507, 733, and 794. "suported" should be "supported".

Apply this diff to correct the spelling:

- // spin model not suported yet
+ // spin model not supported yet

Also applies to: 733-733, 794-794

source/api_cc/include/DeepPotPT.h (1)

Line range hint 356-379: Review data type consistency in computew method overloads

In these computew overloads, there is a mix of double and float types among the parameters. For instance, ener and coord are of type std::vector<double>&, while force, force_mag, spin, and other parameters are std::vector<float>&. This inconsistency can lead to precision issues and may affect computational results.

Consider standardizing the parameter types to ensure consistency. Decide whether double or single precision is more appropriate for your application and apply it uniformly across all parameters. Here's an example of how to adjust the types:

 void computew(std::vector<double>& ener,
-              std::vector<float>& force,
-              std::vector<float>& force_mag,
-              std::vector<float>& virial,
-              std::vector<float>& atom_energy,
-              std::vector<float>& atom_virial,
-              const std::vector<float>& coord,
-              const std::vector<float>& spin,
-              const std::vector<int>& atype,
-              const std::vector<float>& box,
-              const int nghost,
-              const InputNlist& inlist,
-              const int& ago,
-              const std::vector<float>& fparam,
-              const std::vector<float>& aparam,
+              std::vector<double>& force,
+              std::vector<double>& force_mag,
+              std::vector<double>& virial,
+              std::vector<double>& atom_energy,
+              std::vector<double>& atom_virial,
+              const std::vector<double>& coord,
+              const std::vector<double>& spin,
+              const std::vector<int>& atype,
+              const std::vector<double>& box,
+              const int nghost,
+              const InputNlist& inlist,
+              const int& ago,
+              const std::vector<double>& fparam,
+              const std::vector<double>& aparam,
               const bool atomic);
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Files that changed from the base of the PR and between 3466e34 and e2e1e55.

📒 Files selected for processing (3)
  • deepmd/pt/model/model/spin_model.py (9 hunks)
  • source/api_cc/include/DeepPotPT.h (5 hunks)
  • source/api_cc/src/DeepPotPT.cc (4 hunks)
🧰 Additional context used
🔇 Additional comments (10)
source/api_cc/src/DeepPotPT.cc (3)

563-596: Template instantiations for new compute method are correct

The template instantiations for the new compute method overload (lines 563-596) are correctly implemented for both double and float types. This ensures that the new functionality is available for both precision levels.


832-859: Template instantiations for new non-ghost compute method are correct

The template instantiations for the new compute method overload for the non-ghost atom case (lines 832-859) are correctly implemented for both double and float types. This ensures that the new functionality is available for both precision levels.


363-365: Validate dimensions of spin_wrapped_Tensor

When creating spin_wrapped_Tensor, ensure that its dimensions match the expected input for the forward_lower method. Mismatched dimensions could lead to runtime errors or incorrect computations.

Verify that spin_wrapped_Tensor is initialized with the correct shape:

 at::Tensor spin_wrapped_Tensor =
-    torch::from_blob(spin_wrapped.data(), {1, nall_real, 3}, options)
+    torch::from_blob(spin_wrapped.data(), {1, nall_real, spin_dimension}, options)
        .to(device);

Replace spin_dimension with the appropriate size based on how spin data is structured.

To verify this, we can use the following script:

This script will help identify the correct dimension for the spin data used in the model.

source/api_cc/include/DeepPotPT.h (3)

77-90: The new compute method overload appears well-integrated

The addition of the compute method overload with force_mag and spin parameters enhances the functionality to support spin models. The parameter types and their placement in the signature are consistent with existing conventions in the codebase.


132-148: The compute method overload with neighbor list correctly incorporates new parameters

The overloaded compute method now includes force_mag, spin, and neighbor list parameters (nghost, lmp_list, ago). The method signature aligns with the existing design, and the inclusion of new parameters is handled appropriately.


304-314: Ensure consistent parameter types in computew method

In this computew overload, all parameters are of type std::vector<double>&, which maintains consistency and avoids unnecessary type conversions. This is good practice for ensuring numerical precision and performance.

deepmd/pt/model/model/spin_model.py (4)

86-89: Ensure consistent usage of 'recv_num' when calling 'concat_switch_virtual'.

In the calls to self.concat_switch_virtual at lines 86 and 89, you have added recv_num=recv_num. Verify that recv_num is correctly passed and handled within concat_switch_virtual.


99-99: Pass 'recv_num' consistently to 'extend_nlist'.

You are now passing recv_num to self.extend_nlist. Ensure that recv_num is properly utilized within extend_nlist and that all call sites are updated accordingly.


561-561: Ensure 'comm_dict' is correctly passed to 'backbone_model.forward_common_lower'.

Passing comm_dict to self.backbone_model.forward_common_lower is appropriate. Verify that the backbone model correctly handles comm_dict and that all necessary keys are included.


688-688: Pass 'comm_dict' to 'forward_common_lower' in 'SpinEnergyModel'.

Passing comm_dict to self.forward_common_lower ensures consistency with the updated method signatures. Confirm that comm_dict is correctly handled throughout the method chain.

source/api_cc/src/DeepPotPT.cc Outdated Show resolved Hide resolved
source/api_cc/include/DeepPotPT.h Show resolved Hide resolved
source/api_cc/include/DeepPotPT.h Outdated Show resolved Hide resolved
deepmd/pt/model/model/spin_model.py Outdated Show resolved Hide resolved
deepmd/pt/model/model/spin_model.py Outdated Show resolved Hide resolved
deepmd/pt/model/model/spin_model.py Outdated Show resolved Hide resolved
deepmd/pt/model/model/spin_model.py Outdated Show resolved Hide resolved
deepmd/pt/model/model/spin_model.py Outdated Show resolved Hide resolved
deepmd/pt/model/model/spin_model.py Outdated Show resolved Hide resolved
deepmd/pt/model/model/spin_model.py Outdated Show resolved Hide resolved
source/api_cc/src/DeepPotPT.cc Fixed Show fixed Hide fixed
source/api_cc/src/DeepPotPT.cc Fixed Show fixed Hide fixed
Comment on lines 397 to 402
// std::cout<<sendlist_part[0]<<std::endl;
// std::cout<<sendlist_part[1]<<std::endl;
// std::cout<<sendlist_part[2]<<std::endl;
// std::cout<<sendlist_part[3]<<std::endl;
// std::cout<<sendlist_part[4]<<std::endl;
// std::cout<<sendlist_part[5]<<std::endl;

Check notice

Code scanning / CodeQL

Commented-out code

This comment appears to contain commented-out code.
Comment on lines 431 to 435
// for(int j = 0; j < lmp_list.sendnum[i]; j++)
// {
// std::cout<<spin_sendlist[i][j]<<" ";
// }
// std::cout<<std::endl;

Check notice

Code scanning / CodeQL

Commented-out code

This comment appears to contain commented-out code.
Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 2

🧹 Outside diff range and nitpick comments (1)
source/api_cc/src/DeepPotPT.cc (1)

Line range hint 315-955: Summary: New spin model support added, but refinements needed

The changes to DeepPotPT.cc successfully introduce support for spin models by adding force_mag and spin parameters to various compute and computew method overloads. However, there are areas for improvement:

  1. Incomplete spin model support: Virial calculations are consistently commented out across new methods, indicating partial implementation.

  2. Code duplication: There's significant repetition across compute method overloads that could be reduced through further refactoring.

  3. Consistency: Ensure that all new methods maintain consistent error handling and parameter validation.

Consider a comprehensive refactoring to:

  1. Implement or remove commented virial calculations.
  2. Create a common core for compute methods to reduce duplication.
  3. Ensure consistent error handling and parameter validation across all new methods.

These improvements will enhance maintainability and completeness of the spin model support.

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Files that changed from the base of the PR and between e2e1e55 and cf85275.

📒 Files selected for processing (3)
  • deepmd/pt/model/model/spin_model.py (4 hunks)
  • source/api_cc/include/DeepPotPT.h (4 hunks)
  • source/api_cc/src/DeepPotPT.cc (4 hunks)
🧰 Additional context used
🔇 Additional comments (13)
source/api_cc/src/DeepPotPT.cc (3)

502-535: LGTM: Template instantiations for new compute method

The template instantiations for the new compute method overload with force_mag parameter are correctly implemented for both double and float types. This ensures type-specific compilations are available as needed.


840-875: LGTM: New computew overloads correctly implemented

The new computew method overloads for both double and float types are correctly implemented. They properly include the new force_mag and spin parameters and maintain the existing pattern of wrapping the corresponding compute methods within a translate_error lambda.


914-955: LGTM: New computew overloads with ghost atoms correctly implemented

The new computew method overloads for both double and float types, including ghost atom handling, are correctly implemented. They properly incorporate the new force_mag and spin parameters and maintain the existing pattern of wrapping the corresponding compute methods within a translate_error lambda.

source/api_cc/include/DeepPotPT.h (6)

77-90: Documentation for new compute method overload is still missing

The added overload of the compute method introduces force_mag and spin parameters but lacks accompanying documentation explaining their purpose and usage.


132-148: Documentation for compute method overload with neighbor list is still missing

The new compute overload that includes force_mag, spin, and neighbor list parameters lacks documentation. This information is crucial for users to understand how to utilize these parameters effectively.


304-314: Documentation for new computew method overload is still missing

The computew method overload with additional force_mag and spin parameters requires documentation to explain the functionality and usage of these new parameters.


315-327: Type inconsistency in computew method parameters persists

The computew overload starting at line 315 has a type mismatch where ener is of type std::vector<double>&, but other parameters like force, force_mag, and spin are of type std::vector<float>&. This could lead to precision issues or unintended behavior.


328-355: Documentation for computew method overload with neighbor list is still missing

The computew overloads that incorporate force_mag, spin, and neighbor list parameters lack necessary documentation. Providing detailed descriptions of these parameters will enhance code clarity and usability.


374-379: Type inconsistency in computew method parameters persists

In the computew method starting at line 374, there's a type mismatch between std::vector<double>& ener and std::vector<float>& for other parameters like force, force_mag, and spin. Consistent parameter types are important to prevent precision loss and runtime errors.

deepmd/pt/model/model/spin_model.py (4)

474-474: Include 'comm_dict' in the method documentation.

As previously noted, please update the docstring of forward_common_lower to include the comm_dict parameter and explain its purpose and expected usage.


496-496: LGTM!

The addition of comm_dict to the call ensures proper passing of communication data to the backbone model.


612-612: Include 'comm_dict' in the method documentation.

As previously noted, please update the docstring of forward_lower to include the comm_dict parameter and explain its purpose and expected usage.


623-623: LGTM!

Passing comm_dict to self.forward_common_lower maintains consistency in handling communication data across methods.

source/api_cc/src/DeepPotPT.cc Outdated Show resolved Hide resolved
Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 7

🧹 Outside diff range and nitpick comments (4)
source/api_cc/src/DeepPotPT.cc (2)

Line range hint 842-957: Consider template-based refactoring for computew methods

The new computew overloads correctly handle force_mag and spin parameters, but there's significant code duplication across all computew methods.

Consider using a template-based approach to reduce code duplication:

template <typename VALUETYPE, typename... Args>
void DeepPotPT::computew(std::vector<double>& ener,
                         std::vector<VALUETYPE>& force,
                         std::vector<VALUETYPE>& force_mag,
                         std::vector<VALUETYPE>& virial,
                         std::vector<VALUETYPE>& atom_energy,
                         std::vector<VALUETYPE>& atom_virial,
                         Args&&... args) {
  translate_error([&] {
    compute(ener, force, force_mag, virial, atom_energy, atom_virial,
            std::forward<Args>(args)...);
  });
}

This template can handle all variations of computew, reducing the number of nearly identical method implementations.


Line range hint 315-957: Overall implementation of spin model support is consistent but has room for improvement

The changes successfully implement support for spin models across multiple method overloads, maintaining consistency with existing code structure and error handling. However, there are a few areas that could be improved:

  1. The commented-out virial calculations for spin models indicate incomplete functionality. Consider implementing these calculations or adding clear TODO comments explaining the limitations and future plans.

  2. There are several opportunities for refactoring to reduce code duplication, particularly in the compute and computew methods. Consider using templates or helper functions to consolidate common logic.

  3. Ensure that the new functionality is thoroughly tested, especially the interaction between force_mag, spin, and other parameters.

Overall, the implementation is solid but could benefit from some refinement to improve maintainability and completeness.

deepmd/pt/utils/spin.py (1)

37-37: Use torch.floor_divide for integer division for clarity.

In line 37, origin_recv_num = torch.div(recv_num, 2).to(torch.int) performs division followed by type casting. Using torch.floor_divide(recv_num, 2) makes it explicit that integer division is intended, improving code clarity.

Apply the following diff:

-    origin_recv_num = torch.div(recv_num, 2).to(torch.int)
+    origin_recv_num = torch.floor_divide(recv_num, 2)
deepmd/pt/model/descriptor/repformers.py (1)

524-525: Specify data types and devices when creating tensors

When creating tensors with torch.tensor, it's best practice to explicitly specify the dtype and device to ensure consistency across different environments:

Apply the following changes to specify dtype and device:

- torch.tensor(real_nloc),  # pylint: disable=no-explicit-dtype,no-explicit-device
- torch.tensor(real_nall - real_nloc),  # pylint: disable=no-explicit-dtype,no-explicit-device
+ torch.tensor(real_nloc, dtype=torch.int64, device=env.DEVICE),
+ torch.tensor(real_nall - real_nloc, dtype=torch.int64, device=env.DEVICE),
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Files that changed from the base of the PR and between cf85275 and 1d6defe.

📒 Files selected for processing (4)
  • deepmd/pt/model/descriptor/repformers.py (4 hunks)
  • deepmd/pt/model/model/spin_model.py (6 hunks)
  • deepmd/pt/utils/spin.py (1 hunks)
  • source/api_cc/src/DeepPotPT.cc (4 hunks)
🧰 Additional context used
🔇 Additional comments (8)
source/api_cc/src/DeepPotPT.cc (2)

315-331: New parameters added to support spin models

The method signature has been updated to include force_mag and spin parameters, enhancing the functionality to support spin models. This change is consistent with the PR objectives.


661-674: Consistent implementation of spin model support

This overload maintains consistency with the previous one, correctly handling force_mag and spin parameters for cases without ghost atoms.

deepmd/pt/model/model/spin_model.py (4)

29-31: Importing concat_switch_virtual enhances code reusability

Importing concat_switch_virtual from deepmd.pt.utils.spin replaces the local implementation, promoting code reuse and maintainability.


450-450: Including comm_dict parameter enhances method flexibility

The addition of the optional comm_dict parameter to forward_common_lower allows for passing additional communication data, enhancing the method's versatility without impacting existing functionality.

Also applies to: 472-472


588-588: Extending forward_lower with comm_dict ensures consistency

Adding comm_dict to SpinEnergyModel.forward_lower aligns it with the updated SpinModel methods, maintaining consistency across the models.

Also applies to: 599-599


87-95: Verify consistency between old and new concat_switch_virtual implementations

Replacing the local concat_switch_virtual method with the imported version may introduce differences if the implementations are not identical. Please verify that both functions behave the same to prevent any unintended side effects.

Run the following script to compare the implementations:

deepmd/pt/model/descriptor/repformers.py (2)

48-50: Import of concat_switch_virtual is appropriate

The addition of concat_switch_virtual is necessary for handling spin configurations and is correctly imported.


462-462: Variable ng1 correctly assigned

The assignment ng1 = g1.shape[-1] accurately captures the last dimension of g1, which is essential for subsequent tensor operations.

source/api_cc/src/DeepPotPT.cc Show resolved Hide resolved
deepmd/pt/utils/spin.py Outdated Show resolved Hide resolved
deepmd/pt/utils/spin.py Outdated Show resolved Hide resolved
deepmd/pt/utils/spin.py Outdated Show resolved Hide resolved
deepmd/pt/model/descriptor/repformers.py Show resolved Hide resolved
Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 6

🧹 Outside diff range and nitpick comments (5)
deepmd/tf/entrypoints/freeze.py (2)

126-127: LGTM! Consider grouping spin-related attributes.

The addition of spin-related nodes "spin_attr/virtual_len" and "spin_attr/spin_norm" is consistent with the PR objectives to enhance LAMMPS support for spin models.

For improved readability, consider grouping all spin-related attributes together. You could move the existing "spin_attr/ntypes_spin" to be adjacent to the newly added spin attributes:

"spin_attr/ntypes_spin",
"spin_attr/virtual_len",
"spin_attr/spin_norm",
"fitting_attr/dfparam",
# ... rest of the attributes

263-264: LGTM! Consider grouping spin-related attributes for consistency.

The addition of spin-related nodes "spin_attr/virtual_len" and "spin_attr/spin_norm" to the optional_node list is consistent with the changes made in the _make_node_names function and aligns with the PR objectives.

For consistency with the suggested change in the _make_node_names function, consider grouping all spin-related attributes together in the optional_node list:

optional_node = [
    "train_attr/min_nbor_dist",
    "fitting_attr/aparam_nall",
    "spin_attr/ntypes_spin",
    "spin_attr/virtual_len",
    "spin_attr/spin_norm",
]

This grouping would improve readability and maintain consistency throughout the file.

deepmd/pt/model/descriptor/repformers.py (2)

490-491: Consider using explicit dtype and device for tensor creation

When creating tensors for real_nloc and real_nall - real_nloc, it's better to specify the dtype and device explicitly for consistency and to avoid potential issues with dtype or device mismatches.

Consider updating the tensor creation as follows:

- torch.tensor(real_nloc),
- torch.tensor(real_nall - real_nloc),
+ torch.tensor(real_nloc, dtype=torch.int64, device=g1.device),
+ torch.tensor(real_nall - real_nloc, dtype=torch.int64, device=g1.device),

This ensures that the new tensors have a consistent dtype and are on the same device as the input tensor g1.


494-498: LGTM: Proper handling of spin configurations

The addition of spin-specific logic after the border_op call is well-implemented. It correctly splits and recombines the real and virtual components for spin configurations using the concat_switch_virtual function.

For improved readability, consider adding a brief comment explaining the purpose of this block:

 if has_spin:
+    # Recombine real and virtual components for spin configurations
     g1_real_ext, g1_virtual_ext = torch.split(g1_ext, [ng1, ng1], dim=2)
     g1_ext = concat_switch_virtual(
         g1_real_ext, g1_virtual_ext, real_nloc
     )

This comment will help future readers quickly understand the purpose of this code block.

source/api_cc/src/DeepPotPT.cc (1)

395-401: Clarify the use of has_spin tensor in comm_dict

The tensor has_spin is hardcoded with a value of 1. To enhance code clarity and maintainability, consider defining a constant or using a boolean flag to represent the presence of spin.

Apply this diff for improved clarity:

-torch::Tensor has_spin = torch::tensor({1}, int32_option);
+constexpr int HAS_SPIN = 1;
+torch::Tensor has_spin = torch::tensor({HAS_SPIN}, int32_option);
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Files that changed from the base of the PR and between 1d6defe and e5c0ecf.

📒 Files selected for processing (7)
  • deepmd/pt/model/descriptor/repformers.py (4 hunks)
  • deepmd/pt/model/model/spin_model.py (6 hunks)
  • deepmd/pt/utils/spin.py (1 hunks)
  • deepmd/tf/entrypoints/freeze.py (2 hunks)
  • source/api_cc/include/DeepPotPT.h (4 hunks)
  • source/api_cc/src/DeepPotPT.cc (4 hunks)
  • source/api_cc/src/DeepPotTF.cc (4 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • deepmd/pt/utils/spin.py
🧰 Additional context used
🪛 Ruff
deepmd/pt/model/model/spin_model.py

448-448: Undefined name Dict

(F821)


586-586: Undefined name Dict

(F821)

🔇 Additional comments (18)
deepmd/tf/entrypoints/freeze.py (1)

Line range hint 1-385: Overall assessment: Changes are well-implemented and align with PR objectives.

The modifications to deepmd/tf/entrypoints/freeze.py successfully integrate support for spin models in the LAMMPS freezing process. The changes are localized, maintain the existing structure of the code, and are consistent across functions. The minor suggestions for grouping spin-related attributes would further enhance code readability and consistency.

deepmd/pt/model/descriptor/repformers.py (1)

31-33: LGTM: New import for spin-related functionality

The addition of concat_switch_virtual import from deepmd.pt.utils.spin is appropriate for the new spin-handling functionality introduced in this file.

deepmd/pt/model/model/spin_model.py (6)

18-20: LGTM: Improved code reuse with utility function.

The addition of concat_switch_virtual from deepmd.pt.utils.spin is a good refactoring step. It promotes code reuse and maintainability by utilizing a common utility function instead of a class-specific method.


448-448: LGTM: New parameter for additional communication data.

The addition of the comm_dict parameter allows for passing additional communication data, which can enhance the flexibility and functionality of the model.

🧰 Tools
🪛 Ruff

448-448: Undefined name Dict

(F821)


470-472: LGTM: Proper usage of new parameters.

The comm_dict is correctly passed to the backbone model, and the addition of extra_nlist_sort parameter provides flexibility for neighbor list sorting optimization. These changes enhance the model's functionality and performance tuning capabilities.


586-586: LGTM: Consistent implementation of new parameter.

The addition of the comm_dict parameter in the forward_lower method of SpinEnergyModel is consistent with the changes made in the SpinModel class. This ensures that the new functionality is properly propagated through the class hierarchy.

🧰 Tools
🪛 Ruff

586-586: Undefined name Dict

(F821)


586-586: Import Dict from typing module.

The Dict type is used but not imported. This issue was previously mentioned for the SpinModel class. Ensure that the following import is added at the beginning of the file to resolve both occurrences:

from typing import Dict
🧰 Tools
🪛 Ruff

586-586: Undefined name Dict

(F821)


597-599: LGTM: Proper usage of new parameters and optimization.

The comm_dict is correctly passed to the forward_common_lower method, and the extra_nlist_sort parameter is set based on the backbone model's requirements. This ensures consistency with the SpinModel class and allows for proper optimization of neighbor list sorting.

source/api_cc/include/DeepPotPT.h (4)

77-90: Extension of compute method to support spin models

The addition of force_mag and spin parameters to the compute method overload enhances functionality, enabling support for spin models in your simulations.


132-148: Updated compute method overload with neighbor list and spin support

Incorporating force_mag and spin parameters alongside neighbor list data in the compute method overload is appropriate for handling spin interactions in conjunction with neighbor lists.


304-327: Enhancement of computew methods to include spin and force magnitude

Introducing force_mag and spin parameters in the computew method overloads expands the computational capabilities to accurately model spin-dependent interactions.


Line range hint 356-379: Addition of computew overloads with neighbor list, spin, and force magnitude

The new computew method overloads that include neighbor list parameters along with force_mag and spin are well-suited for efficient computation of spin models with neighbor interactions.

source/api_cc/src/DeepPotPT.cc (5)

841-876: New computew overloads for spin models are properly implemented

The added computew methods now support spin and force_mag parameters, enhancing the functionality to handle spin configurations. The implementation appears correct and consistent with the existing codebase.


915-956: Extended computew methods support ghost atoms with spin

The new computew overloads accommodate nghost and InputNlist parameters alongside spin and force_mag. This extension correctly integrates ghost atom support in spin model computations.


362-364: 🛠️ Refactor suggestion

Optimize spin tensor creation to avoid unnecessary copying

Currently, spin is copied into spin_wrapped before tensor creation. To improve performance and reduce memory usage, consider using spin directly when creating spin_wrapped_Tensor.

Apply this diff to eliminate the redundant copy:

-std::vector<VALUETYPE> spin_wrapped = spin;
-at::Tensor spin_wrapped_Tensor =
-    torch::from_blob(spin_wrapped.data(), {1, nall_real, 3}, options)
+at::Tensor spin_wrapped_Tensor =
+    torch::from_blob(const_cast<VALUETYPE*>(spin.data()), {1, nall_real, 3}, options)
         .to(device);

Likely invalid or redundant comment.


659-799: Ensure proper integration of new compute overload for spin models

The new compute method overload handles spin and force_mag parameters for calculations without neighbor lists. Verify that this method integrates seamlessly with other components of the system and that all necessary calls are updated.

Run the following script to find and review usages of the new compute method:

#!/bin/bash
# Description: Locate all calls to DeepPotPT::compute with the new signature

# Search for the new compute method signature in the codebase
ast-grep --pattern 'DeepPotPT::compute($_, $_, $_, $_, $_, $_, $_, $_, $_, $_, $_, $_)'

This will help ensure that the new overload is correctly utilized wherever required.


313-537: Verify updates to compute method usages across the codebase

The new overload of the compute method includes additional parameters force_mag and spin. Please ensure that all calls to DeepPotPT::compute are reviewed and updated as necessary to match the new method signature where applicable.

Run the following script to locate all usages of the compute method:

This will help identify any existing calls that may need adjustments due to the new parameters.

✅ Verification successful

All usages of DeepPotPT::compute are up to date with the new method signature.

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Search for all calls to DeepPotPT::compute and check their argument lists

# Find all files potentially containing calls to DeepPotPT::compute
fd -e cpp -e cc -e cxx -e h -e hh | xargs rg 'DeepPotPT::compute'

# Use ast-grep to match function calls to DeepPotPT::compute with various signatures
ast-grep --pattern 'DeepPotPT::compute($_, $$$)'

Length of output: 294


Script:

#!/bin/bash
# Description: Search for all calls to DeepPotPT::compute and check their argument lists

# Find all files potentially containing calls to DeepPotPT::compute
fd -e cpp -e cc -e cxx -e h -e hh | xargs rg 'DeepPotPT::compute'

# Use ast-grep to match function calls to DeepPotPT::compute with the new parameters
ast-grep --pattern 'DeepPotPT::compute($_, $_, $_, $_, $_, $_, $_, $_, $_, $_, $_, $_)'

Length of output: 2083

source/api_cc/src/DeepPotTF.cc (1)

1199-1237: Implementation of spin-supporting computew methods looks good

The added computew methods to support spin functionality are well-implemented and align with the existing code structure.

deepmd/pt/model/descriptor/repformers.py Show resolved Hide resolved
deepmd/pt/model/model/spin_model.py Outdated Show resolved Hide resolved
source/api_cc/src/DeepPotPT.cc Show resolved Hide resolved
source/api_cc/src/DeepPotTF.cc Show resolved Hide resolved
source/api_cc/src/DeepPotTF.cc Show resolved Hide resolved
Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 7

🧹 Outside diff range and nitpick comments (2)
deepmd/tf/entrypoints/freeze.py (2)

126-127: LGTM. Consider adding a comment for clarity.

The addition of "spin_attr/virtual_len" and "spin_attr/spin_norm" nodes for the "ener" model type is consistent with the PR objectives of enhancing LAMMPS support for spin models. These nodes likely represent attributes related to spin configurations in the energy model.

Consider adding a brief comment explaining the purpose of these new spin-related nodes for better code documentation.


263-264: LGTM. Consider adding "fitting_attr/aparam_nall" for consistency.

The addition of "spin_attr/virtual_len" and "spin_attr/spin_norm" to the optional_node list is consistent with the changes in the _make_node_names function. This change allows for backward compatibility with older checkpoints that may not have these spin-related attributes.

For consistency, consider also adding "fitting_attr/aparam_nall" to the optional_node list, as it was added to the "ener" model type nodes in the _make_node_names function.

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Files that changed from the base of the PR and between 1d6defe and e5c0ecf.

📒 Files selected for processing (7)
  • deepmd/pt/model/descriptor/repformers.py (4 hunks)
  • deepmd/pt/model/model/spin_model.py (6 hunks)
  • deepmd/pt/utils/spin.py (1 hunks)
  • deepmd/tf/entrypoints/freeze.py (2 hunks)
  • source/api_cc/include/DeepPotPT.h (4 hunks)
  • source/api_cc/src/DeepPotPT.cc (4 hunks)
  • source/api_cc/src/DeepPotTF.cc (4 hunks)
🚧 Files skipped from review as they are similar to previous changes (2)
  • deepmd/pt/model/descriptor/repformers.py
  • deepmd/pt/utils/spin.py
🧰 Additional context used
🪛 Ruff
deepmd/pt/model/model/spin_model.py

448-448: Undefined name Dict

(F821)


586-586: Undefined name Dict

(F821)

🪛 GitHub Check: CodeQL
source/api_cc/src/DeepPotPT.cc

[notice] 355-355: Unused local variable
Variable nloc is not used.


[notice] 463-467: Commented-out code
This comment appears to contain commented-out code.


[notice] 496-496: Commented-out code
This comment appears to contain commented-out code.


[notice] 688-688: Unused local variable
Variable nframes is not used.


[notice] 749-753: Commented-out code
This comment appears to contain commented-out code.


[notice] 755-755: Commented-out code
This comment appears to contain commented-out code.

source/api_cc/src/DeepPotTF.cc

[notice] 893-893: Unused local variable
Variable ret is not used.


[notice] 905-905: Unused local variable
Variable ret is not used.

🔇 Additional comments (14)
deepmd/tf/entrypoints/freeze.py (2)

130-130: LGTM. Please clarify the purpose of the new node.

The addition of the "fitting_attr/aparam_nall" node for the "ener" model type is consistent with the structure of the function. However, the purpose of this new node is not immediately clear from the context.

Could you please provide a brief explanation of the purpose and significance of the "fitting_attr/aparam_nall" node in the energy model? This will help improve code documentation and understanding for future maintainers.


Line range hint 1-391: Overall, the changes look good and align with the PR objectives.

The modifications to freeze.py enhance LAMMPS support for spin models by adding new nodes for spin attributes and fitting parameters. The changes are well-integrated into the existing code structure and maintain backward compatibility. Minor suggestions for improvement have been made, including adding comments for clarity and ensuring consistency in the optional node list.

deepmd/pt/model/model/spin_model.py (3)

18-20: Improved code organization with utility function.

The addition of concat_switch_virtual from a separate utility module enhances code organization and promotes reusability. This change is a positive step towards better modularization of the codebase.


470-472: Correct propagation of new parameters to backbone model.

The new comm_dict and extra_nlist_sort parameters are correctly passed to the backbone_model.forward_common_lower method, ensuring consistent behavior throughout the model hierarchy.


597-599: Proper handling of new parameters in SpinEnergyModel.forward_lower.

The comm_dict parameter is correctly passed to self.forward_common_lower, maintaining consistency with the parent class. The dynamic setting of extra_nlist_sort based on the backbone model's requirements is a good practice, ensuring flexibility and correct behavior across different model configurations.

source/api_cc/src/DeepPotPT.cc (1)

362-364: 🛠️ Refactor suggestion

Potential optimization for spin tensor creation

The creation of spin_wrapped_Tensor could be optimized by directly using the spin vector instead of creating an intermediate spin_wrapped vector.

Consider refactoring the tensor creation as follows:

-std::vector<VALUETYPE> spin_wrapped = spin;
-at::Tensor spin_wrapped_Tensor =
-    torch::from_blob(spin_wrapped.data(), {1, nall_real, 3}, options)
+at::Tensor spin_wrapped_Tensor =
+    torch::from_blob(const_cast<VALUETYPE*>(spin.data()), {1, nall_real, 3}, options)
         .to(device);

This change reduces memory usage and potentially improves performance by avoiding an unnecessary copy.

Likely invalid or redundant comment.

source/api_cc/include/DeepPotPT.h (5)

77-90: Changes look good in the new compute method overload

The added overload of the compute method with force_mag and spin parameters is correctly defined and consistent with the existing code structure.


132-148: New compute method overload with neighbor list parameters is appropriate

The additional overload of the compute method that includes force_mag, spin, and neighbor list parameters is well-defined and integrates properly with the class.


304-314: computew method overload with double types is consistent

The new overload of the computew method incorporating force_mag and spin parameters with double precision types aligns with the existing method signatures and ensures type consistency.


315-327: Ensure type consistency in computew method parameters

In this overload of the computew method, ener is of type std::vector<double>&, while other parameters like force, force_mag, virial, and spin are std::vector<float>&. Mixing double and float types can lead to precision loss and potential issues due to implicit type conversions.


328-355: Additional computew overload with neighbor list parameters is properly defined

The new overload of the computew method that includes neighbor list inputs is correctly implemented and maintains consistency with the class's interface.

source/api_cc/src/DeepPotTF.cc (3)

514-519: Function get_vector implementation is correct

The get_vector method properly wraps session_get_vector and correctly passes all necessary parameters. The implementation aligns with the expected functionality.


831-1024: Spin support is integrated effectively

The added compute method for spin support extends the functionality to handle spin-related computations. The code carefully manages spin states and integrates with the existing data structures and logic without introducing inconsistencies.

🧰 Tools
🪛 GitHub Check: CodeQL

[notice] 893-893: Unused local variable
Variable ret is not used.


[notice] 905-905: Unused local variable
Variable ret is not used.


934-948: Optimized assignment placement in loops

The assignments to datom_energy_[ii] and datom_virial_[ii] are correctly placed outside the inner loop over dd. This optimization prevents redundant assignments and enhances performance.

deepmd/pt/model/model/spin_model.py Outdated Show resolved Hide resolved
deepmd/pt/model/model/spin_model.py Outdated Show resolved Hide resolved
source/api_cc/src/DeepPotPT.cc Show resolved Hide resolved
source/api_cc/src/DeepPotPT.cc Show resolved Hide resolved
source/api_cc/src/DeepPotTF.cc Show resolved Hide resolved
Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 21

🧹 Outside diff range and nitpick comments (9)
source/api_cc/include/DeepPotPT.h (2)

77-117: New compute method overload for spin calculations looks good, but needs minor documentation improvement.

The new compute method overload for handling spin calculations is well-structured and consistent with the existing method signatures. However, the documentation could be slightly improved for clarity.

Consider updating the documentation to clarify the units and expected range of values for the spin parameter. For example:

-   * @param[in] spin The spins of atoms, [0, 0, 0] if no spin. The array should
-   *be of size nframes x natoms x 3.
+   * @param[in] spin The spins of atoms. Each spin is a 3D vector, [0, 0, 0] if no spin.
+   * The array should be of size nframes x natoms x 3. Units: [TODO: specify units].

Replace [TODO: specify units] with the appropriate units for spin (e.g., ħ/2 for electron spin).


Line range hint 413-436: New computew method overloads with neighbor list for spin calculations are good, but have type inconsistencies.

The new computew method overloads for spin calculations with neighbor list handling are well-structured, but there are type inconsistencies that should be addressed.

Similar to the previous comment, there are type mismatches between double and float in these method overloads. Consider updating the types to be consistent:

 void computew(std::vector<double>& ener,
-               std::vector<float>& force,
-               std::vector<float>& force_mag,
-               std::vector<float>& virial,
-               std::vector<float>& atom_energy,
-               std::vector<float>& atom_virial,
-               const std::vector<float>& coord,
-               const std::vector<float>& spin,
+               std::vector<double>& force,
+               std::vector<double>& force_mag,
+               std::vector<double>& virial,
+               std::vector<double>& atom_energy,
+               std::vector<double>& atom_virial,
+               const std::vector<double>& coord,
+               const std::vector<double>& spin,
                const std::vector<int>& atype,
-               const std::vector<float>& box,
+               const std::vector<double>& box,
                const int nghost,
                const InputNlist& inlist,
                const int& ago,
-               const std::vector<float>& fparam,
-               const std::vector<float>& aparam,
+               const std::vector<double>& fparam,
+               const std::vector<double>& aparam,
                const bool atomic);

If you intend to keep the mixed types, please add a comment explaining the rationale behind this design decision.

source/api_cc/tests/test_deeppot_dpa1_pt_spin.cc (1)

23-31: Consider using consistent formatting for vector initialization

For better readability and maintainability, align the vector initialization to follow a consistent style. This can make the code cleaner and easier to modify in the future.

Apply this diff to adjust the formatting:

      std::vector<VALUETYPE> coord = {
-         12.83, 2.56, 2.18, 12.09, 2.87, 2.74,
-         00.25, 3.32, 1.68, 3.36,  3.00, 1.81,
-         3.51,  2.51, 2.60, 4.27,  3.22, 1.56};
+      12.83, 2.56, 2.18,
+      12.09, 2.87, 2.74,
+      0.25,  3.32, 1.68,
+      3.36,  3.00, 1.81,
+      3.51,  2.51, 2.60,
+      4.27,  3.22, 1.56};
source/lmp/pair_deepmd.cpp (4)

582-586: Simplify code by eliminating unnecessary constant references

The variables dcoord_const and dspin_const are constant references to dcoord and dspin:

const vector<double> &dcoord_const = dcoord;
const vector<double> &dspin_const = dspin;
deep_pot.compute(dener, dforce, dvirial, dcoord_const, dspin_const, dtype, dbox, nghost, lmp_list, ago, fparam, daparam);

Since dcoord and dspin are not modified within this scope, you can pass them directly as arguments without creating new references. This reduces code redundancy.

Apply this change:

-const vector<double> &dcoord_const = dcoord;
-const vector<double> &dspin_const = dspin;
-deep_pot.compute(dener, dforce, dvirial, dcoord_const, dspin_const, dtype, dbox, nghost, lmp_list, ago, fparam, daparam);
+deep_pot.compute(dener, dforce, dvirial, dcoord, dspin, dtype, dbox, nghost, lmp_list, ago, fparam, daparam);

904-907: Simplify code by eliminating unnecessary constant references

Similar to a previous comment, in the serial computation branch, you can eliminate unnecessary constant references:

const vector<double> &dcoord_const = dcoord;
const vector<double> &dspin_const = dspin;
deep_pot.compute(dener, dforce, dforce_mag, dvirial, dcoord_const, dspin_const, dtype, dbox);

Apply this change:

-const vector<double> &dcoord_const = dcoord;
-const vector<double> &dspin_const = dspin;
-deep_pot.compute(dener, dforce, dforce_mag, dvirial, dcoord_const, dspin_const, dtype, dbox);
+deep_pot.compute(dener, dforce, dforce_mag, dvirial, dcoord, dspin, dtype, dbox);

848-848: Implement support for spin atomic forces

The comment indicates missing functionality:

// need support for spin atomic force.

To ensure that spin atomic forces are correctly handled, consider implementing the necessary computations in this section.

Would you like assistance in developing this feature or opening a new GitHub issue to track this task?


Line range hint 1416-1464: Refactor duplicated code in communication functions

The pack_reverse_comm and unpack_reverse_comm functions contain duplicated code for handling atom->sp_flag. Refactoring can improve readability and maintainability.

Apply the following changes to streamline the code:

// In pack_reverse_comm
int m = 0;
for (i = first; i < last; i++) {
  for (int dd = 0; dd < numb_models; ++dd) {
    buf[m++] = all_force[dd][3 * i + 0];
    buf[m++] = all_force[dd][3 * i + 1];
    buf[m++] = all_force[dd][3 * i + 2];
+   if (atom->sp_flag) {
+     buf[m++] = all_force_mag[dd][3 * i + 0];
+     buf[m++] = all_force_mag[dd][3 * i + 1];
+     buf[m++] = all_force_mag[dd][3 * i + 2];
+   }
  }
}

// In unpack_reverse_comm
int m = 0;
for (i = 0; i < n; i++) {
  j = list[i];
  for (int dd = 0; dd < numb_models; ++dd) {
    all_force[dd][3 * j + 0] += buf[m++];
    all_force[dd][3 * j + 1] += buf[m++];
    all_force[dd][3 * j + 2] += buf[m++];
+   if (atom->sp_flag) {
+     all_force_mag[dd][3 * j + 0] += buf[m++];
+     all_force_mag[dd][3 * j + 1] += buf[m++];
+     all_force_mag[dd][3 * j + 2] += buf[m++];
+   }
  }
}

This reduces code duplication by combining the common parts and conditionally adding spin-related data.

source/api_cc/include/DeepPot.h (2)

Line range hint 22-36: Pass primitive types by value instead of by const reference

In the method declarations, parameters of type int are passed as const int&, such as const int& gpu_rank and const int& ago. Passing primitive types like int by const reference can be less efficient than passing them by value due to the overhead of indirection. Consider changing these parameters to be passed by value instead.

Also applies to: 130-130, 247-247, 570-570, 787-787, 1107-1107, 1120-1120


1180-1182: Use std::size_t for size-related variables

The member variable numb_models is declared as unsigned. It's recommended to use std::size_t for variables representing sizes or counts to ensure portability and avoid potential issues with type sizes on different platforms.

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Files that changed from the base of the PR and between 11aeb17 and 114898f.

📒 Files selected for processing (10)
  • source/api_c/include/c_api.h (13 hunks)
  • source/api_c/include/deepmd.hpp (9 hunks)
  • source/api_c/src/c_api.cc (15 hunks)
  • source/api_cc/include/DeepPot.h (8 hunks)
  • source/api_cc/include/DeepPotPT.h (4 hunks)
  • source/api_cc/include/DeepPotTF.h (5 hunks)
  • source/api_cc/src/DeepPot.cc (6 hunks)
  • source/api_cc/src/DeepPotTF.cc (6 hunks)
  • source/api_cc/tests/test_deeppot_dpa1_pt_spin.cc (1 hunks)
  • source/lmp/pair_deepmd.cpp (15 hunks)
🧰 Additional context used
🔇 Additional comments (31)
source/api_cc/include/DeepPotPT.h (2)

159-205: New compute method overload with neighbor list for spin calculations is well-implemented.

The new compute method overload that includes neighbor list handling for spin calculations is consistent with the existing method signatures and properly documented.


385-412: New computew method overload with neighbor list is well-implemented.

The new computew method overload that includes neighbor list handling is consistent with the existing method signatures and properly structured.

source/api_cc/src/DeepPotTF.cc (1)

514-519: ⚠️ Potential issue

Add error handling to get_vector method

The get_vector method calls session_get_vector without handling potential exceptions. If session_get_vector fails, it might throw an exception that could propagate unexpectedly. It's advisable to handle exceptions to ensure the method's reliability.

Apply this diff to add exception handling:

 template <class VT>
 void DeepPotTF::get_vector(std::vector<VT>& vec,
                            const std::string& name) const {
+  try {
     session_get_vector<VT>(vec, session, name);
+  } catch (const std::exception& e) {
+    throw deepmd::deepmd_exception("Failed to get vector '" + name + "': " + e.what());
+  }
 }

Likely invalid or redundant comment.

source/api_cc/tests/test_deeppot_dpa1_pt_spin.cc (3)

134-138: Validate the correctness of the test assertions

The current test uses a tolerance EPSILON to compare computed values with expected results. Given the previous comment about an unclear bug affecting precision (lines 16-18), please ensure that the tolerance is appropriate and that the test reliably verifies the correctness of the computation.

Consider running additional checks or adjusting the EPSILON value if necessary to ensure the test's robustness.


110-110: Ensure all necessary types are included in ValueTypes

In the TYPED_TEST_SUITE, the ValueTypes type list should include all relevant numeric types intended for testing. Verify that ValueTypes is defined and contains the types you wish to test, such as float and double.

If ValueTypes is not already defined, you can define it as follows:

using ValueTypes = ::testing::Types<float, double>;

Ensure this is included in your test suite to cover all necessary cases.


145-183: LGTM!

The second test case cpu_build_nlist_atomic correctly extends the first test by verifying atomic energies. The implementation aligns with the intended functionality.

source/api_cc/include/DeepPotTF.h (4)

399-399: The previous comment about including parameter names in the cum_sum function declaration is still valid.


408-409: The previous comment about adding documentation for the get_vector template method is still valid.


418-419: The previous comment about the commented-out code in lines 418-419 is still valid.


425-425: The previous comment about the commented-out code in line 425 is still valid.

source/lmp/pair_deepmd.cpp (1)

1258-1263: Verify correctness of comm_reverse buffer size calculation

The buffer size comm_reverse is adjusted based on the atom->sp_flag:

// comm_reverse = numb_models * 3;
if (atom->sp_flag) {
  comm_reverse = numb_models * 3 * 2;
} else {
  comm_reverse = numb_models * 3;
}

Ensure that the multiplication by 2 accurately accounts for the additional spin-related data (all_force_mag) being communicated. Double-check that this calculation matches the data packed and unpacked in pack_reverse_comm and unpack_reverse_comm to prevent buffer overflows or data corruption.

To confirm, you can review the data sizes in the communication functions.

source/api_cc/include/DeepPot.h (1)

91-147: Well-documented addition of spin-related functionality

The newly added methods that include force_mag and spin parameters are thoroughly documented, and the parameter ordering is consistent across overloads. This enhances the clarity and usability of the API.

Also applies to: 207-270, 437-484, 534-588, 637-689, 746-807, 1080-1120, 1162-1206

source/api_cc/src/DeepPot.cc (6)

137-155: Addition of overloaded compute method with spin support

The new overloaded compute method correctly introduces spin support by adding dforce_mag_ and dspin_ parameters. The implementation aligns with existing coding patterns and extends functionality appropriately.


157-172: Overloaded compute method for vector energies with spin

The addition of this overloaded compute method to handle vectors of energies with spin parameters is implemented correctly and consistently with the codebase.


476-564: Addition of overloaded compute methods with atom-level outputs and spin support

These overloaded compute methods appropriately include dforce_mag_, dspin_, and atom-level outputs datom_energy_, datom_virial_ to support spin calculations. The implementation is consistent with the existing design.


658-766: Redundant code in overloaded compute methods

Multiple overloaded compute methods are added to support spin, but they introduce redundancy in the codebase. Although previous comments addressed code duplication, the issue persists. Consider refactoring these methods to reduce duplication by extracting common logic into helper functions or utilizing default parameters.


1124-1182: Proper initialization of all_force_mag

The vector all_force_mag is now correctly resized to match numb_models before use, addressing the previous concern about uninitialized vectors. This ensures safe memory operations and prevents potential runtime errors.


1243-1310: Refactoring opportunity in DeepPotModelDevi::compute methods

The newly added compute methods supporting spin and atom-level outputs contain similar code structures that could benefit from refactoring. Consider extracting shared code into private helper functions to adhere to the DRY principle and enhance maintainability.

source/api_c/include/c_api.h (13)

13-15: API Version Updated to 23

The API version has been correctly incremented to 23, reflecting the changes introduced with spin support.


164-201: New Function DP_DeepPotComputeSP Added Successfully

The function DP_DeepPotComputeSP is properly documented, and the parameter list aligns with the documentation. This addition extends the API to handle spin inputs effectively.


233-270: New Function DP_DeepPotComputefSP Implemented Correctly

The float version of the spin computation function, DP_DeepPotComputefSP, is added with appropriate documentation. Parameter names and types are consistent with the double version.


309-323: Missing Documentation for DP_DeepPotComputeNListSP

The function DP_DeepPotComputeNListSP lacks a documentation block explaining its purpose, parameters, and usage. To maintain consistency and aid users, please add detailed documentation above the function declaration.


363-378: Add Documentation for DP_DeepPotComputeNListfSP

Similar to the double version, the float version DP_DeepPotComputeNListfSP is missing its documentation block. Please provide comprehensive comments to describe the function's behavior and parameters.


418-462: Function DP_DeepPotCompute2SP Added with Proper Documentation

The multi-frame spin computation function DP_DeepPotCompute2SP is correctly implemented, and the documentation thoroughly describes its usage. Parameters are consistently named and well-documented.


502-546: Function DP_DeepPotComputef2SP Implemented Correctly

The float version DP_DeepPotComputef2SP mirrors the double version with accurate documentation and consistent parameter naming. The function enhances spin support for multiple frames.


593-611: Add Documentation for DP_DeepPotComputeNList2SP

The function DP_DeepPotComputeNList2SP is missing a documentation block. Providing detailed comments will help users understand its purpose and how to use it properly.


658-676: Provide Documentation for DP_DeepPotComputeNListf2SP

The function DP_DeepPotComputeNListf2SP lacks documentation. Please include comprehensive comments to maintain consistency and assist users.


973-988: Add Documentation for DP_DeepPotModelDeviComputeNListSP

The function DP_DeepPotModelDeviComputeNListSP does not have a documentation block. Adding one will improve understandability and maintain consistency across the API.


1026-1041: Document DP_DeepPotModelDeviComputeNListfSP

Please provide a documentation block for the function DP_DeepPotModelDeviComputeNListfSP to explain its usage and parameters.


1088-1106: Include extern Keyword in Function Declaration

The function DP_DeepPotModelDeviComputeNList2SP lacks the extern keyword, which is used in other function declarations. This inconsistency might lead to linkage issues.


1153-1171: Add extern Keyword to DP_DeepPotModelDeviComputeNListf2SP

The function DP_DeepPotModelDeviComputeNListf2SP is missing the extern keyword. Including it will ensure consistency and prevent potential linkage problems.

source/api_cc/include/DeepPotPT.h Show resolved Hide resolved
source/api_cc/src/DeepPotTF.cc Show resolved Hide resolved
source/api_cc/src/DeepPotTF.cc Show resolved Hide resolved
source/api_cc/src/DeepPotTF.cc Show resolved Hide resolved
source/api_cc/src/DeepPotTF.cc Show resolved Hide resolved
source/api_c/src/c_api.cc Show resolved Hide resolved
source/api_c/include/deepmd.hpp Show resolved Hide resolved
source/api_c/include/deepmd.hpp Show resolved Hide resolved
source/api_c/include/deepmd.hpp Show resolved Hide resolved
source/api_c/include/deepmd.hpp Show resolved Hide resolved
Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 3

🧹 Outside diff range and nitpick comments (1)
source/api_cc/src/DeepPotPT.cc (1)

Line range hint 839-954: Reduce duplication in computew methods using templates

The computew methods contain significant duplication in their error handling wrappers.

Consider using a template function to reduce duplication:

template <typename VALUETYPE, typename ENERGYVTYPE>
void DeepPotPT::computew_impl(
    ENERGYVTYPE& ener,
    std::vector<VALUETYPE>& force,
    std::vector<VALUETYPE>& force_mag,
    std::vector<VALUETYPE>& virial,
    std::vector<VALUETYPE>& atom_energy,
    std::vector<VALUETYPE>& atom_virial,
    const std::vector<VALUETYPE>& coord,
    const std::vector<VALUETYPE>& spin,
    const std::vector<int>& atype,
    const std::vector<VALUETYPE>& box,
    const std::vector<VALUETYPE>& fparam,
    const std::vector<VALUETYPE>& aparam,
    const bool atomic,
    const std::optional<InputNlist>& inlist = std::nullopt,
    const std::optional<int>& nghost = std::nullopt,
    const std::optional<int>& ago = std::nullopt) {
    translate_error([&] {
        if (inlist) {
            compute(ener, force, force_mag, virial, atom_energy, atom_virial,
                   coord, spin, atype, box, *nghost, *inlist, *ago, fparam,
                   aparam, atomic);
        } else {
            compute(ener, force, force_mag, virial, atom_energy, atom_virial,
                   coord, spin, atype, box, fparam, aparam, atomic);
        }
    });
}

This would allow you to replace all the computew overloads with calls to this template function.

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Files that changed from the base of the PR and between 114898f and fef13f5.

📒 Files selected for processing (2)
  • deepmd/pt/model/descriptor/repformers.py (4 hunks)
  • source/api_cc/src/DeepPotPT.cc (4 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • deepmd/pt/model/descriptor/repformers.py
🔇 Additional comments (1)
source/api_cc/src/DeepPotPT.cc (1)

359-362: 🛠️ Refactor suggestion

Optimize spin tensor creation

The creation of spin_wrapped_Tensor could be optimized by directly using the spin vector data instead of creating an intermediate copy.

Apply this diff to optimize memory usage:

-std::vector<VALUETYPE> spin_wrapped = spin;
-at::Tensor spin_wrapped_Tensor =
-    torch::from_blob(spin_wrapped.data(), {1, nall_real, 3}, options)
+at::Tensor spin_wrapped_Tensor =
+    torch::from_blob(const_cast<VALUETYPE*>(spin.data()), {1, nall_real, 3}, options)
         .to(device);

Likely invalid or redundant comment.

Comment on lines +312 to +500
if (do_message_passing == 1 && nghost == 0) {
// for the situation that no ghost atoms (e.g. serial nopbc)
// set the mapping arange(nloc) is enough
auto option = torch::TensorOptions().device(device).dtype(torch::kInt64);
mapping_tensor = at::arange(nloc_real, option).unsqueeze(0);
}
}
at::Tensor firstneigh = createNlistTensor(nlist_data.jlist);
firstneigh_tensor = firstneigh.to(torch::kInt64).to(device);
bool do_atom_virial_tensor = atomic;
c10::optional<torch::Tensor> fparam_tensor;
if (!fparam.empty()) {
fparam_tensor =
torch::from_blob(const_cast<VALUETYPE*>(fparam.data()),
{1, static_cast<std::int64_t>(fparam.size())}, options)
.to(device);
}
c10::optional<torch::Tensor> aparam_tensor;
if (!aparam_.empty()) {
aparam_tensor =
torch::from_blob(
const_cast<VALUETYPE*>(aparam_.data()),
{1, lmp_list.inum,
static_cast<std::int64_t>(aparam_.size()) / lmp_list.inum},
options)
.to(device);
}
c10::Dict<c10::IValue, c10::IValue> outputs =
(do_message_passing == 1 && nghost > 0)
? module
.run_method("forward_lower", coord_wrapped_Tensor, atype_Tensor,
spin_wrapped_Tensor, firstneigh_tensor,
mapping_tensor, fparam_tensor, aparam_tensor,
do_atom_virial_tensor, comm_dict)
.toGenericDict()
: module
.run_method("forward_lower", coord_wrapped_Tensor, atype_Tensor,
spin_wrapped_Tensor, firstneigh_tensor,
mapping_tensor, fparam_tensor, aparam_tensor,
do_atom_virial_tensor)
.toGenericDict();
c10::IValue energy_ = outputs.at("energy");
c10::IValue force_ = outputs.at("extended_force");
c10::IValue force_mag_ = outputs.at("extended_force_mag");
// spin model not suported yet
// c10::IValue virial_ = outputs.at("virial");
torch::Tensor flat_energy_ = energy_.toTensor().view({-1});
torch::Tensor cpu_energy_ = flat_energy_.to(torch::kCPU);
ener.assign(cpu_energy_.data_ptr<ENERGYTYPE>(),
cpu_energy_.data_ptr<ENERGYTYPE>() + cpu_energy_.numel());
torch::Tensor flat_force_ = force_.toTensor().view({-1}).to(floatType);
torch::Tensor cpu_force_ = flat_force_.to(torch::kCPU);
dforce.assign(cpu_force_.data_ptr<VALUETYPE>(),
cpu_force_.data_ptr<VALUETYPE>() + cpu_force_.numel());
torch::Tensor flat_force_mag_ =
force_mag_.toTensor().view({-1}).to(floatType);
torch::Tensor cpu_force_mag_ = flat_force_mag_.to(torch::kCPU);
dforce_mag.assign(
cpu_force_mag_.data_ptr<VALUETYPE>(),
cpu_force_mag_.data_ptr<VALUETYPE>() + cpu_force_mag_.numel());
// spin model not suported yet
// torch::Tensor flat_virial_ = virial_.toTensor().view({-1}).to(floatType);
// torch::Tensor cpu_virial_ = flat_virial_.to(torch::kCPU);
// virial.assign(cpu_virial_.data_ptr<VALUETYPE>(),
// cpu_virial_.data_ptr<VALUETYPE>() + cpu_virial_.numel());

// bkw map
force.resize(static_cast<size_t>(nframes) * fwd_map.size() * 3);
force_mag.resize(static_cast<size_t>(nframes) * fwd_map.size() * 3);
select_map<VALUETYPE>(force, dforce, bkw_map, 3, nframes, fwd_map.size(),
nall_real);
select_map<VALUETYPE>(force_mag, dforce_mag, bkw_map, 3, nframes,
fwd_map.size(), nall_real);
if (atomic) {
// spin model not suported yet
// c10::IValue atom_virial_ = outputs.at("extended_virial");
c10::IValue atom_energy_ = outputs.at("atom_energy");
torch::Tensor flat_atom_energy_ =
atom_energy_.toTensor().view({-1}).to(floatType);
torch::Tensor cpu_atom_energy_ = flat_atom_energy_.to(torch::kCPU);
datom_energy.resize(nall_real,
0.0); // resize to nall to be consistenet with TF.
datom_energy.assign(
cpu_atom_energy_.data_ptr<VALUETYPE>(),
cpu_atom_energy_.data_ptr<VALUETYPE>() + cpu_atom_energy_.numel());
// spin model not suported yet
// torch::Tensor flat_atom_virial_ =
// atom_virial_.toTensor().view({-1}).to(floatType);
// torch::Tensor cpu_atom_virial_ = flat_atom_virial_.to(torch::kCPU);
// datom_virial.assign(
// cpu_atom_virial_.data_ptr<VALUETYPE>(),
// cpu_atom_virial_.data_ptr<VALUETYPE>() + cpu_atom_virial_.numel());
atom_energy.resize(static_cast<size_t>(nframes) * fwd_map.size());
// atom_virial.resize(static_cast<size_t>(nframes) * fwd_map.size() * 9);
select_map<VALUETYPE>(atom_energy, datom_energy, bkw_map, 1, nframes,
fwd_map.size(), nall_real);
// select_map<VALUETYPE>(atom_virial, datom_virial, bkw_map, 9, nframes,
// fwd_map.size(), nall_real);
}
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Consider refactoring to improve maintainability

The compute method is quite long (>180 lines) and handles multiple responsibilities. This makes it harder to maintain and test.

Consider breaking down the method into smaller, focused helper methods:

private:
    template <typename VALUETYPE>
    void prepareInputTensors(/* params */);
    
    template <typename VALUETYPE>
    void handleMessagePassing(/* params */);
    
    template <typename VALUETYPE>
    void processOutputs(/* params */);

This would improve readability and make the code easier to maintain.

Comment on lines +658 to +768
const std::vector<VALUETYPE>& fparam,
const std::vector<VALUETYPE>& aparam,
const bool atomic) {
torch::Device device(torch::kCUDA, gpu_id);
if (!gpu_enabled) {
device = torch::Device(torch::kCPU);
}
std::vector<VALUETYPE> coord_wrapped = coord;
std::vector<VALUETYPE> spin_wrapped = spin;
int natoms = atype.size();
auto options = torch::TensorOptions().dtype(torch::kFloat64);
torch::ScalarType floatType = torch::kFloat64;
if (std::is_same_v<VALUETYPE, float>) {
options = torch::TensorOptions().dtype(torch::kFloat32);
floatType = torch::kFloat32;
}
auto int_options = torch::TensorOptions().dtype(torch::kInt64);
int nframes = 1;
std::vector<torch::jit::IValue> inputs;
at::Tensor coord_wrapped_Tensor =
torch::from_blob(coord_wrapped.data(), {1, natoms, 3}, options)
.to(device);
inputs.push_back(coord_wrapped_Tensor);
std::vector<std::int64_t> atype_64(atype.begin(), atype.end());
at::Tensor atype_Tensor =
torch::from_blob(atype_64.data(), {1, natoms}, int_options).to(device);
inputs.push_back(atype_Tensor);
at::Tensor spin_wrapped_Tensor =
torch::from_blob(spin_wrapped.data(), {1, natoms, 3}, options).to(device);
inputs.push_back(spin_wrapped_Tensor);
c10::optional<torch::Tensor> box_Tensor;
if (!box.empty()) {
box_Tensor =
torch::from_blob(const_cast<VALUETYPE*>(box.data()), {1, 9}, options)
.to(device);
}
inputs.push_back(box_Tensor);
c10::optional<torch::Tensor> fparam_tensor;
if (!fparam.empty()) {
fparam_tensor =
torch::from_blob(const_cast<VALUETYPE*>(fparam.data()),
{1, static_cast<std::int64_t>(fparam.size())}, options)
.to(device);
}
inputs.push_back(fparam_tensor);
c10::optional<torch::Tensor> aparam_tensor;
if (!aparam.empty()) {
aparam_tensor =
torch::from_blob(
const_cast<VALUETYPE*>(aparam.data()),
{1, natoms, static_cast<std::int64_t>(aparam.size()) / natoms},
options)
.to(device);
}
inputs.push_back(aparam_tensor);
bool do_atom_virial_tensor = atomic;
inputs.push_back(do_atom_virial_tensor);
c10::Dict<c10::IValue, c10::IValue> outputs =
module.forward(inputs).toGenericDict();
c10::IValue energy_ = outputs.at("energy");
c10::IValue force_ = outputs.at("force");
c10::IValue force_mag_ = outputs.at("force_mag");
// spin model not suported yet
// c10::IValue virial_ = outputs.at("virial");
torch::Tensor flat_energy_ = energy_.toTensor().view({-1});
torch::Tensor cpu_energy_ = flat_energy_.to(torch::kCPU);
ener.assign(cpu_energy_.data_ptr<ENERGYTYPE>(),
cpu_energy_.data_ptr<ENERGYTYPE>() + cpu_energy_.numel());
torch::Tensor flat_force_ = force_.toTensor().view({-1}).to(floatType);
torch::Tensor cpu_force_ = flat_force_.to(torch::kCPU);
force.assign(cpu_force_.data_ptr<VALUETYPE>(),
cpu_force_.data_ptr<VALUETYPE>() + cpu_force_.numel());
torch::Tensor flat_force_mag_ =
force_mag_.toTensor().view({-1}).to(floatType);
torch::Tensor cpu_force_mag_ = flat_force_mag_.to(torch::kCPU);
force_mag.assign(
cpu_force_mag_.data_ptr<VALUETYPE>(),
cpu_force_mag_.data_ptr<VALUETYPE>() + cpu_force_mag_.numel());
// spin model not suported yet
// torch::Tensor flat_virial_ = virial_.toTensor().view({-1}).to(floatType);
// torch::Tensor cpu_virial_ = flat_virial_.to(torch::kCPU);
// virial.assign(cpu_virial_.data_ptr<VALUETYPE>(),
// cpu_virial_.data_ptr<VALUETYPE>() + cpu_virial_.numel());
if (atomic) {
// c10::IValue atom_virial_ = outputs.at("atom_virial");
c10::IValue atom_energy_ = outputs.at("atom_energy");
torch::Tensor flat_atom_energy_ =
atom_energy_.toTensor().view({-1}).to(floatType);
torch::Tensor cpu_atom_energy_ = flat_atom_energy_.to(torch::kCPU);
atom_energy.assign(
cpu_atom_energy_.data_ptr<VALUETYPE>(),
cpu_atom_energy_.data_ptr<VALUETYPE>() + cpu_atom_energy_.numel());
// torch::Tensor flat_atom_virial_ =
// atom_virial_.toTensor().view({-1}).to(floatType);
// torch::Tensor cpu_atom_virial_ = flat_atom_virial_.to(torch::kCPU);
// atom_virial.assign(
// cpu_atom_virial_.data_ptr<VALUETYPE>(),
// cpu_atom_virial_.data_ptr<VALUETYPE>() + cpu_atom_virial_.numel());
}
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Reduce code duplication between compute overloads

There's significant code duplication between this overload and the previous one, particularly in tensor creation and output processing.

Consider extracting common logic into helper methods:

template <typename VALUETYPE>
void DeepPotPT::createInputTensors(
    std::vector<torch::jit::IValue>& inputs,
    const std::vector<VALUETYPE>& coord,
    const std::vector<VALUETYPE>& spin,
    const std::vector<int>& atype,
    const torch::Device& device) {
    // Common tensor creation logic
}

template <typename VALUETYPE>
void DeepPotPT::processOutputs(
    const c10::Dict<c10::IValue, c10::IValue>& outputs,
    std::vector<VALUETYPE>& force,
    std::vector<VALUETYPE>& force_mag) {
    // Common output processing logic
}

Comment on lines +445 to +465
// spin model not suported yet
// c10::IValue virial_ = outputs.at("virial");
torch::Tensor flat_energy_ = energy_.toTensor().view({-1});
torch::Tensor cpu_energy_ = flat_energy_.to(torch::kCPU);
ener.assign(cpu_energy_.data_ptr<ENERGYTYPE>(),
cpu_energy_.data_ptr<ENERGYTYPE>() + cpu_energy_.numel());
torch::Tensor flat_force_ = force_.toTensor().view({-1}).to(floatType);
torch::Tensor cpu_force_ = flat_force_.to(torch::kCPU);
dforce.assign(cpu_force_.data_ptr<VALUETYPE>(),
cpu_force_.data_ptr<VALUETYPE>() + cpu_force_.numel());
torch::Tensor flat_force_mag_ =
force_mag_.toTensor().view({-1}).to(floatType);
torch::Tensor cpu_force_mag_ = flat_force_mag_.to(torch::kCPU);
dforce_mag.assign(
cpu_force_mag_.data_ptr<VALUETYPE>(),
cpu_force_mag_.data_ptr<VALUETYPE>() + cpu_force_mag_.numel());
// spin model not suported yet
// torch::Tensor flat_virial_ = virial_.toTensor().view({-1}).to(floatType);
// torch::Tensor cpu_virial_ = flat_virial_.to(torch::kCPU);
// virial.assign(cpu_virial_.data_ptr<VALUETYPE>(),
// cpu_virial_.data_ptr<VALUETYPE>() + cpu_virial_.numel());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Document the current limitations of virial calculations

The commented-out code for virial calculations needs proper documentation to explain the current limitations and future plans.

Add a detailed comment explaining why virial calculations are not supported for spin models and what the plans are for implementing them:

-// spin model not suported yet
+// TODO(iProzd): Virial calculations for spin models are not yet supported
+// Current limitation: The virial tensor calculation needs to be extended to handle spin contributions
+// Planned implementation: Update the virial calculations to include spin-dependent terms
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
// spin model not suported yet
// c10::IValue virial_ = outputs.at("virial");
torch::Tensor flat_energy_ = energy_.toTensor().view({-1});
torch::Tensor cpu_energy_ = flat_energy_.to(torch::kCPU);
ener.assign(cpu_energy_.data_ptr<ENERGYTYPE>(),
cpu_energy_.data_ptr<ENERGYTYPE>() + cpu_energy_.numel());
torch::Tensor flat_force_ = force_.toTensor().view({-1}).to(floatType);
torch::Tensor cpu_force_ = flat_force_.to(torch::kCPU);
dforce.assign(cpu_force_.data_ptr<VALUETYPE>(),
cpu_force_.data_ptr<VALUETYPE>() + cpu_force_.numel());
torch::Tensor flat_force_mag_ =
force_mag_.toTensor().view({-1}).to(floatType);
torch::Tensor cpu_force_mag_ = flat_force_mag_.to(torch::kCPU);
dforce_mag.assign(
cpu_force_mag_.data_ptr<VALUETYPE>(),
cpu_force_mag_.data_ptr<VALUETYPE>() + cpu_force_mag_.numel());
// spin model not suported yet
// torch::Tensor flat_virial_ = virial_.toTensor().view({-1}).to(floatType);
// torch::Tensor cpu_virial_ = flat_virial_.to(torch::kCPU);
// virial.assign(cpu_virial_.data_ptr<VALUETYPE>(),
// cpu_virial_.data_ptr<VALUETYPE>() + cpu_virial_.numel());
// TODO(iProzd): Virial calculations for spin models are not yet supported
// Current limitation: The virial tensor calculation needs to be extended to handle spin contributions
// Planned implementation: Update the virial calculations to include spin-dependent terms
// c10::IValue virial_ = outputs.at("virial");
torch::Tensor flat_energy_ = energy_.toTensor().view({-1});
torch::Tensor cpu_energy_ = flat_energy_.to(torch::kCPU);
ener.assign(cpu_energy_.data_ptr<ENERGYTYPE>(),
cpu_energy_.data_ptr<ENERGYTYPE>() + cpu_energy_.numel());
torch::Tensor flat_force_ = force_.toTensor().view({-1}).to(floatType);
torch::Tensor cpu_force_ = flat_force_.to(torch::kCPU);
dforce.assign(cpu_force_.data_ptr<VALUETYPE>(),
cpu_force_.data_ptr<VALUETYPE>() + cpu_force_.numel());
torch::Tensor flat_force_mag_ =
force_mag_.toTensor().view({-1}).to(floatType);
torch::Tensor cpu_force_mag_ = flat_force_mag_.to(torch::kCPU);
dforce_mag.assign(
cpu_force_mag_.data_ptr<VALUETYPE>(),
cpu_force_mag_.data_ptr<VALUETYPE>() + cpu_force_mag_.numel());
// TODO(iProzd): Virial calculations for spin models are not yet supported
// Current limitation: The virial tensor calculation needs to be extended to handle spin contributions
// Planned implementation: Update the virial calculations to include spin-dependent terms
// torch::Tensor flat_virial_ = virial_.toTensor().view({-1}).to(floatType);
// torch::Tensor cpu_virial_ = flat_virial_.to(torch::kCPU);
// virial.assign(cpu_virial_.data_ptr<VALUETYPE>(),
// cpu_virial_.data_ptr<VALUETYPE>() + cpu_virial_.numel());

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[Feature Request] Support C++ interface for pytorch spin model.
4 participants