feat(pt): support spin virial #4545

Draft
wants to merge 1 commit into
base: devel

Conversation

iProzd
Collaborator

@iProzd iProzd commented Jan 10, 2025

Summary by CodeRabbit

  • New Features

    • Added support for virial loss calculations in spin energy models
    • Enhanced model's ability to process spin-related coordinate corrections
    • Improved handling of virial outputs in model computations
  • Bug Fixes

    • Corrected virial array population in model deviation and computation methods
    • Fixed issues with virial output processing in various model implementations
  • Tests

    • Added new test cases for spin energy models with virial calculations
    • Extended testing framework to support spin-related model evaluations

Contributor

coderabbitai bot commented Jan 10, 2025

Walkthrough

This pull request introduces enhanced support for virial calculations in the DeepMD-kit framework, specifically focusing on spin-related models. The changes span multiple files across the project, adding new functionality to handle virial outputs, coordinate corrections, and spin-related computations. The modifications enable more comprehensive energy and virial loss calculations, with updates to model processing, loss computation, and testing frameworks.

Changes

| File | Change Summary |
|---|---|
| deepmd/pt/loss/ener_spin.py | Added virial loss calculation support in EnergySpinLoss class |
| deepmd/pt/model/model/make_model.py | Introduced coord_corr_for_virial parameter in forward methods |
| deepmd/pt/model/model/spin_model.py | Enhanced spin input processing and model forward methods |
| deepmd/pt/model/model/transform_output.py | Added extended_coord_corr parameter to output transformation |
| source/api_c/include/deepmd.hpp | Fixed virial array population in DeepSpinModelDevi class |
| source/api_c/src/c_api.cc | Restored virial output handling |
| source/api_cc/src/DeepSpinPT.cc | Implemented virial tensor retrieval and assignment |
| source/tests/pt/model/test_autodiff.py | Added spin-related test cases and logic |
| source/tests/pt/model/test_ener_spin_model.py | Updated spin input processing tests |
| source/tests/universal/common/cases/model/utils.py | Modified virial calculation logic |
| source/tests/universal/pt/model/test_model.py | Added test_spin_virial property |

Suggested Labels

Python, Docs

Suggested Reviewers

  • njzjz
  • wanghan-iapcm

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 4

🧹 Nitpick comments (12)
deepmd/pt/model/model/spin_model.py (3)

58-59: Enhance code readability by reshaping tensor in a single step

In line 58, consider reshaping self.virtual_scale_mask.to(atype.device)[atype] directly without wrapping it in parentheses for better readability.


379-381: Avoid unnecessary computation by not calling process_spin_input during stat computation

In the compute_or_load_stat method, calling process_spin_input may introduce unnecessary computational overhead if coord_corr is not used. Consider modifying the code to exclude coord_corr when it's not needed.


591-594: Consider consistency in handling do_grad_c checks

In the forward method, ensure that the handling of do_grad_c("energy") and subsequent assignments align with the changes made in translated_output_def. This maintains consistency across the methods.

source/tests/pt/model/test_autodiff.py (4)

144-144: Initialize the spin variable only when necessary

The spin variable is initialized even when test_spin is False. Consider moving the initialization inside the conditional block to optimize performance.

Apply this diff to adjust the initialization:

-        spin = torch.rand([natoms, 3], dtype=dtype, device="cpu", generator=generator)

Move the initialization to after line 150, within the if test_spin block.


148-148: Ensure spin is only converted to NumPy when necessary

Similar to the previous comment, the conversion of spin to a NumPy array should be conditional based on test_spin to avoid unnecessary computations.


151-154: Simplify the conditional assignment of test_keys

The assignment of test_keys can be streamlined for clarity.

Apply this diff to simplify the code:

-        if not test_spin:
-            test_keys = ["energy", "force", "virial"]
-        else:
-            test_keys = ["energy", "force", "force_mag", "virial"]
+        test_keys = ["energy", "force", "virial"]
+        if test_spin:
+            test_keys.append("force_mag")
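One detail of the suggested simplification is worth checking before applying it: appending changes the key order. A minimal sketch, assuming the key list is only iterated over (`build_test_keys` is a hypothetical helper, not part of the PR):

```python
def build_test_keys(test_spin: bool) -> list[str]:
    """Return the output keys to compare (hypothetical helper for illustration)."""
    test_keys = ["energy", "force", "virial"]
    if test_spin:
        # Appending puts "force_mag" last, whereas the original branch
        # listed it before "virial". The set of keys is unchanged.
        test_keys.append("force_mag")
    return test_keys


print(build_test_keys(False))
print(build_test_keys(True))
```

If any downstream code depends on the position of "virial" in the list, the original two-branch form is the safer choice.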

263-268: Add a newline for code style consistency

Include a blank line after the class definition to follow PEP 8 style guidelines for better readability.

Apply this diff:

 class TestEnergyModelSpinSeAVirial(unittest.TestCase, VirialTest):
 
+    def setUp(self) -> None:
         model_params = copy.deepcopy(model_spin)
deepmd/pt/model/model/transform_output.py (3)

159-159: Update function documentation to include new parameter

The fit_output_to_model_output function has a new parameter extended_coord_corr. Update the docstring to describe this parameter and its role in the computation.


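195-195: Reconsider the # noqa: RUF005 suppression

RUF005 is Ruff's collection-literal-concatenation rule, not a line-length rule: it flags the expression list(dc.shape[:-2]) + [1, 9]. Rewriting the concatenation with unpacking removes the need for the suppression. If the concatenation form is required here (for example, by a scripting constraint), keep the # noqa comment instead.

Apply this diff to drop the suppression:

-                            ).view(list(dc.shape[:-2]) + [1, 9])  # noqa: RUF005
+                            ).view([*dc.shape[:-2], 1, 9])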

Line range hint 226-226: Consider adding type annotations for function returns

Adding type annotations to functions enhances code clarity and aids in static analysis. Consider specifying the return types for the functions in this module.

deepmd/pt/loss/ener_spin.py (1)

271-286: LGTM! The virial loss calculation is well implemented.

The implementation follows the established pattern for loss calculations, with proper scaling and optional MAE computation. The code is clean and well-structured.

Consider extracting the common pattern of loss calculation (L2, MAE, scaling) into a helper method to reduce code duplication across energy, force, and virial loss calculations.
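The shared pattern could be factored out roughly as follows. This is a sketch only: `scaled_l2_term`, its parameters, and the way the normalization factor is applied are assumptions for illustration and are not taken from EnergySpinLoss.

```python
import math


def scaled_l2_term(pred, label, prefactor, norm, with_mae=False):
    """Hypothetical helper: weighted L2 loss plus RMSE and optional MAE.

    pred, label: flat sequences of floats of equal length.
    prefactor:   loss weight for this quantity (assumed name).
    norm:        normalization factor, e.g. 1 / natoms (assumed convention).
    """
    diffs = [p - l for p, l in zip(pred, label)]
    l2 = sum(d * d for d in diffs) / len(diffs)  # mean squared error
    metrics = {"rmse": math.sqrt(l2) * norm}
    if with_mae:
        metrics["mae"] = sum(abs(d) for d in diffs) / len(diffs) * norm
    # The weighted contribution that would be added to the total loss.
    return prefactor * norm * l2, metrics


loss, metrics = scaled_l2_term(
    [1.0, 2.0, 3.0], [1.0, 2.0, 5.0], prefactor=0.5, norm=1.0, with_mae=True
)
```

Energy, force, and virial terms would then differ only in the arguments passed, which keeps the scaling logic in one place.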

source/tests/pt/model/test_ener_spin_model.py (1)

118-118: Document the purpose of the ignored return values.

The additional return values (marked with _) from process_spin_input and process_spin_input_lower are silently ignored. Consider adding a comment explaining what these values represent and why they can be safely ignored in these tests.

Also applies to: 177-177

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 9af197c and 8fd9565.

📒 Files selected for processing (11)
  • deepmd/pt/loss/ener_spin.py (1 hunks)
  • deepmd/pt/model/model/make_model.py (7 hunks)
  • deepmd/pt/model/model/spin_model.py (11 hunks)
  • deepmd/pt/model/model/transform_output.py (2 hunks)
  • source/api_c/include/deepmd.hpp (2 hunks)
  • source/api_c/src/c_api.cc (1 hunks)
  • source/api_cc/src/DeepSpinPT.cc (4 hunks)
  • source/tests/pt/model/test_autodiff.py (3 hunks)
  • source/tests/pt/model/test_ener_spin_model.py (2 hunks)
  • source/tests/universal/common/cases/model/utils.py (3 hunks)
  • source/tests/universal/pt/model/test_model.py (1 hunks)
⏰ Context from checks skipped due to timeout of 90000ms (21)
  • GitHub Check: Build wheels for cp310-manylinux_aarch64
  • GitHub Check: Build wheels for cp311-macosx_arm64
  • GitHub Check: Test Python (4, 3.12)
  • GitHub Check: Test Python (4, 3.9)
  • GitHub Check: Test Python (3, 3.12)
  • GitHub Check: Test Python (3, 3.9)
  • GitHub Check: Build C++ (clang, clang)
  • GitHub Check: Test Python (2, 3.12)
  • GitHub Check: Build C++ (rocm, rocm)
  • GitHub Check: Test Python (2, 3.9)
  • GitHub Check: Analyze (python)
  • GitHub Check: Build C++ (cuda120, cuda)
  • GitHub Check: Build C library (2.14, >=2.5.0rc0,<2.15, libdeepmd_c_cu11.tar.gz)
  • GitHub Check: Test Python (1, 3.12)
  • GitHub Check: Test C++ (false)
  • GitHub Check: Build C++ (cuda, cuda)
  • GitHub Check: Build C library (2.18, libdeepmd_c.tar.gz)
  • GitHub Check: Test Python (1, 3.9)
  • GitHub Check: Test C++ (true)
  • GitHub Check: Analyze (c-cpp)
  • GitHub Check: Build C++ (cpu, cpu)
🔇 Additional comments (13)
deepmd/pt/model/model/spin_model.py (4)

63-64: Ensure proper alignment of coordinate corrections

The concatenation of tensors in coord_corr must maintain correct alignment with the corresponding atoms. Verify that torch.zeros_like(coord) and -spin_dist are correctly ordered, ensuring that the coordinate corrections apply to the appropriate atoms.
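The alignment requirement can be illustrated with plain lists. This sketch assumes that virtual atoms are placed at coord + spin_dist and appended after the real atoms; `build_coord_corr` is a hypothetical stand-in, not the PR's process_spin_input.

```python
def build_coord_corr(coord, spin_dist):
    """coord, spin_dist: one [x, y, z] list per real atom (assumed layout)."""
    # Assumption: each virtual atom sits at its real atom's position plus spin_dist.
    virtual_coord = [[c + d for c, d in zip(ci, di)] for ci, di in zip(coord, spin_dist)]
    coord_updated = [list(ci) for ci in coord] + virtual_coord
    # Real atoms get zero correction; virtual atoms get -spin_dist, in the same order.
    coord_corr = [[0.0, 0.0, 0.0] for _ in coord] + [[-d for d in di] for di in spin_dist]
    return coord_updated, coord_corr


coord = [[0.0, 0.0, 0.0], [1.0, 2.0, 3.0]]
spin_dist = [[0.5, 0.0, 0.0], [0.0, 0.25, 0.0]]
coord_updated, coord_corr = build_coord_corr(coord, spin_dist)

# Adding the correction to a virtual atom should recover its real atom's position.
n = len(coord)
for i in range(n):
    restored = [u + c for u, c in zip(coord_updated[n + i], coord_corr[n + i])]
    assert restored == coord[i]
```

A check of this kind fails immediately if the two halves of the concatenation are swapped or ordered differently from the coordinates.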


92-95: Validate the consistency of virtual atom handling

When creating extended_coord_corr, confirm that the virtual atoms are correctly accounted for, and that the concatenation preserves the intended structure. This is crucial for accurate virial calculations involving spin corrections.


410-412: Handle the new output coord_corr_for_virial appropriately

Ensure that all downstream methods that receive coord_corr_for_virial can handle this new parameter without errors. Verify that self.backbone_model.forward_common accepts coord_corr_for_virial as an argument.


631-636: Verify accurate squeezing of tensors and assignment

In the forward_lower method, confirm that the squeeze operations correctly reduce tensor dimensions and that the results are assigned to the appropriate keys in model_predict.

source/tests/pt/model/test_autodiff.py (1)

150-150: Conditionally handle spin and test_spin variables

Verify that all usages of spin and test_spin within the VirialTest class are properly guarded by conditionals to prevent errors when test_spin is False.

deepmd/pt/model/model/transform_output.py (1)

191-196: Ensure tensor shapes are compatible during matrix multiplication

In the computation of dc_corr, validate that the shapes of the tensors involved in the matrix multiplication are compatible to prevent runtime errors.

deepmd/pt/model/model/make_model.py (2)

138-138: LGTM! The parameter addition is well documented.

The new coord_corr_for_virial parameter is properly documented with clear shape information.

Also applies to: 157-159


187-194: LGTM! The coordinate correction handling is robust.

The implementation properly handles both cases:

  • When correction is provided: converts to correct dtype and gathers extended coordinates
  • When correction is not provided: sets extended_coord_corr to None
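The two cases can be sketched without tensors. Here `extend_coord_corr` and the `mapping` argument (extended-atom index to local-atom index) are hypothetical names chosen for illustration; the PR's actual gather operates on tensors.

```python
def extend_coord_corr(coord_corr, mapping):
    """Gather per-atom corrections onto extended (ghost-inclusive) atoms.

    coord_corr: list of [x, y, z] per local atom, or None.
    mapping:    for each extended atom, the index of its local atom (assumed).
    """
    if coord_corr is None:
        # No correction supplied: propagate None so callers can skip the term.
        return None
    # Convert to float (standing in for the dtype cast) and gather by index.
    return [[float(x) for x in coord_corr[idx]] for idx in mapping]


print(extend_coord_corr([[0, 0, 0], [1, 2, 3]], [0, 1, 1, 0]))
print(extend_coord_corr(None, [0, 1]))
```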
source/api_cc/src/DeepSpinPT.cc (1)

254-254: LGTM! The virial output handling is properly implemented.

The implementation correctly:

  • Retrieves virial tensor from outputs
  • Converts tensor to appropriate float type
  • Transfers data to CPU
  • Assigns to output vector

Also applies to: 269-273, 417-417, 432-435

source/tests/universal/pt/model/test_model.py (1)

716-717: Note the temporary nature of this flag.

The added test_spin_virial flag enables spin virial testing for this backend. As indicated by the comment, this can be removed once other backends support spin virial calculations.

Let's verify which backends currently lack spin virial support:

✅ Verification successful

Temporary flag is correctly annotated

The comment accurately reflects the current state of the codebase. The spin virial implementation is currently only available in the PyTorch (PT) backend, and the flag is appropriately marked as temporary until other backends (TensorFlow, PaddlePaddle) implement this functionality.

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Search for spin virial implementations across backends
rg -A 5 "spin.*virial" source/api_cc/src/

Length of output: 43


Script:

#!/bin/bash
# Broader search for spin virial across the codebase
echo "=== Searching for spin virial across the codebase ==="
rg -i "spin.*virial" --type py

echo -e "\n=== Searching for related test files ==="
fd "test.*model.*py$"

echo -e "\n=== Looking for backend-specific implementations ==="
rg -i "spin.*virial" --type py -A 5 source/tests/

Length of output: 5891

source/tests/universal/common/cases/model/utils.py (1)

895-898: Improved control flow for spin virial testing.

The condition ensures virial computation only happens when appropriate - either for non-spin models or when spin virial testing is explicitly enabled.
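The described condition reduces to a small predicate. In this sketch `test_spin_virial` is the flag named in the PR, while `should_check_virial` and `is_spin_model` are hypothetical names:

```python
def should_check_virial(is_spin_model: bool, test_spin_virial: bool) -> bool:
    """Run the virial check for non-spin models, or when spin virial is enabled."""
    return (not is_spin_model) or test_spin_virial


for is_spin in (False, True):
    for flag in (False, True):
        print(is_spin, flag, should_check_virial(is_spin, flag))
```

Only the combination of a spin model with the flag disabled skips the virial check.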

source/api_c/src/c_api.cc (1)

865-869: Re-enabled virial data propagation.

The previously commented-out code for copying virial data has been restored, allowing proper propagation of virial values in spin model deviation computations.

source/api_c/include/deepmd.hpp (1)

2605-2607: Re-enabled virial data propagation in C++ interface.

The previously commented-out code for copying virial data has been restored, maintaining consistency with the C API changes.

@@ -408,6 +419,7 @@ def forward_common(
fparam=fparam,
aparam=aparam,
do_atomic_virial=do_atomic_virial,
coord_corr_for_virial=coord_corr_for_virial,
Contributor

⚠️ Potential issue

Confirm the compatibility of the new argument coord_corr_for_virial

Check if the backbone model's forward_common method is designed to accept coord_corr_for_virial. If not, update the backbone model accordingly or modify the call to prevent runtime errors.

@@ -469,6 +482,7 @@
do_atomic_virial=do_atomic_virial,
comm_dict=comm_dict,
extra_nlist_sort=extra_nlist_sort,
extended_coord_corr=extended_coord_corr_for_virial,
Contributor

⚠️ Potential issue

Ensure extended_coord_corr is accepted by forward_common_lower

Similar to the previous comment, verify that self.backbone_model.forward_common_lower accepts extended_coord_corr as an argument. This prevents potential issues during model execution.

Comment on lines +558 to +562
if self.do_grad_c("energy"):
    output_def["virial"] = deepcopy(out_def_data["energy_derv_c_redu"])
    output_def["virial"].squeeze(-2)
    output_def["atom_virial"] = deepcopy(out_def_data["energy_derv_c"])
    output_def["atom_virial"].squeeze(-3)
Contributor

⚠️ Potential issue

Assign the result of squeeze operations to reduce tensor dimensions

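The squeeze calls on lines 560 and 562 discard their return values. If squeeze returns a new object, as torch.Tensor.squeeze does, the dimensions are never reduced and the results must be assigned. Note that the output_def entries are copies of entries in out_def_data rather than values produced by the model, so first confirm whether their squeeze method mutates in place; if it does, the original code is already correct and the reassignment below would be wrong.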

Apply this diff to fix the issue:

-        output_def["virial"].squeeze(-2)
+        output_def["virial"] = output_def["virial"].squeeze(-2)
-        output_def["atom_virial"].squeeze(-3)
+        output_def["atom_virial"] = output_def["atom_virial"].squeeze(-3)
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change

 if self.do_grad_c("energy"):
     output_def["virial"] = deepcopy(out_def_data["energy_derv_c_redu"])
-    output_def["virial"].squeeze(-2)
+    output_def["virial"] = output_def["virial"].squeeze(-2)
     output_def["atom_virial"] = deepcopy(out_def_data["energy_derv_c"])
-    output_def["atom_virial"].squeeze(-3)
+    output_def["atom_virial"] = output_def["atom_virial"].squeeze(-3)
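Whether the reassignment is correct depends on the squeeze semantics of the objects involved. The sketch below contrasts the two conventions using two hypothetical classes; neither is the PR's actual definition type.

```python
class InPlaceShape:
    """Hypothetical object whose squeeze mutates itself and returns None."""

    def __init__(self, shape):
        self.shape = list(shape)

    def squeeze(self, dim):
        assert self.shape[dim] == 1, "can only squeeze an axis of length 1"
        del self.shape[dim]


class ReturningShape:
    """Hypothetical object whose squeeze returns a new object (tensor-like)."""

    def __init__(self, shape):
        self.shape = tuple(shape)

    def squeeze(self, dim):
        axis = dim % len(self.shape)
        assert self.shape[axis] == 1, "can only squeeze an axis of length 1"
        return ReturningShape(self.shape[:axis] + self.shape[axis + 1:])


a = InPlaceShape([1, 9])
result_a = a.squeeze(-2)   # a.shape is updated; result_a is None

b = ReturningShape([1, 9])
b.squeeze(-2)              # return value discarded; b.shape is unchanged
result_b = b.squeeze(-2)   # the squeezed object must be captured like this
```

With the in-place convention, writing `x = x.squeeze(...)` would replace the object with None; with the returning convention, omitting the assignment silently leaves the shape unchanged.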

Comment on lines +166 to 169
spins=torch.tensor(spin, device=env.DEVICE).unsqueeze(0),
)
# detach
ret = {key: to_numpy_array(result[key].squeeze(0)) for key in test_keys}
Contributor

⚠️ Potential issue

Ensure compatibility of tensor devices

When creating tensors within the np_infer function, ensure that all tensors are on the same device to prevent device mismatch errors, especially when env.DEVICE differs from "cpu".

Apply this diff to correct the device assignment:

-                    spins=torch.tensor(spin, device=env.DEVICE).unsqueeze(0),
+                    spins=torch.tensor(spin, device=new_cell.device).unsqueeze(0),

Committable suggestion skipped: line range outside the PR's diff.

@iProzd iProzd marked this pull request as draft January 13, 2025 06:05
1 participant