feat(dpa4): multiple updates for dpa4 #5734

Open
OutisLi wants to merge 12 commits into deepmodeling:master from OutisLi:pr/dpa4update

@OutisLi OutisLi commented Jul 5, 2026

Collaborator

Because the commits depend on one another to pass the parity tests, this is submitted as a single fused PR.

Summary by CodeRabbit

  • New Features

    • Added readout_layers and native per-atom spin support across SeZM/DPA4 descriptors, including new spin runtime input and spin-aware readout behavior.
    • Introduced a public SpinEmbedding component for the native spin scheme.
    • Added optional fused Triton acceleration for several SeZM/DPA4 operations (including flash-attention-style aggregation) and a GPU tile-configuration tuning workflow for supported shapes.
  • Bug Fixes

    • Improved robustness for zero-block/empty-edge descriptor paths and reduced spin-validation issues.
    • Added consistent "fndc" tensor layout handling and adjusted equivariant normalization/mixing layouts for correctness.

Copilot AI left a comment

Contributor

Copilot wasn't able to review this pull request because it exceeds the maximum number of lines (20,000). Try reducing the number of changed lines and requesting a review from Copilot again.

Comment thread deepmd/pt/model/model/sezm_native_spin_model.py Dismissed
Comment thread deepmd/kernels/triton/sezm/so2_stack_fp16x3.py Dismissed
Comment thread deepmd/kernels/triton/sezm/so2_stack_fp16x3.py Dismissed
Comment thread deepmd/kernels/triton/sezm/so2_value_path.py Dismissed
Comment thread deepmd/kernels/triton/sezm/so2_value_path.py Dismissed
Comment thread deepmd/pt/model/descriptor/sezm_nn/so2.py Dismissed
coderabbitai Bot commented Jul 5, 2026

Contributor

Important

Review skipped

Review was skipped as selected files did not have any reviewable changes.

💤 Files selected but had no reviewable changes (4)
  • source/api_cc/include/DeepPot.h
  • source/api_cc/include/DeepPotPTExpt.h
  • source/api_cc/src/DeepPot.cc
  • source/api_cc/src/DeepPotPTExpt.cc
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: 69526d55-2e9f-457a-a1d0-57a7adfaaec4

📥 Commits

Reviewing files that changed from the base of the PR and between 2dac033 and b4b22df.

📒 Files selected for processing (4)
  • source/api_cc/include/DeepPot.h
  • source/api_cc/include/DeepPotPTExpt.h
  • source/api_cc/src/DeepPot.cc
  • source/api_cc/src/DeepPotPTExpt.cc

You can disable this status message by setting the reviews.review_status to false in the CodeRabbit configuration file.

Walkthrough

This PR adds native per-atom spin support, a configurable residual readout stack, focus-major tensor-layout updates, fused Triton/CuTe SeZM kernels, and spin-aware training/export/validation wiring.

Changes

Native Spin Descriptor and Readout Stack

| Layer | File(s) | Summary |
| --- | --- | --- |
| SpinEmbedding module (dpmodel) | deepmd/dpmodel/descriptor/dpa4_nn/embedding.py | Adds SpinEmbedding, spin l=1 folding in GeometricInitialEmbedding, and spin-aware coordinate channels in EnvironmentInitialEmbedding. |
| SpinEmbedding module (pt SeZM) | deepmd/pt/model/descriptor/sezm_nn/embedding.py | Mirrors the dpmodel spin embedding, GIE folding, and environment spin-channel support. |
| DescrptDPA4 spin injection and readout stack (dpmodel) | deepmd/dpmodel/descriptor/dpa4.py | Wires spin embedding into construction/forward, introduces canonical backbone degrees, and replaces the final output FFN with a residual readout stack. |
| DescrptSeZM spin injection and readout stack (pt) | deepmd/pt/model/descriptor/sezm.py | Mirrors the DPA4 spin/readout changes for the PT SeZM descriptor, including AMP-inference gating. |
| Package re-exports and dependent atomic model update | deepmd/dpmodel/descriptor/dpa4_nn/__init__.py, deepmd/pt/model/descriptor/sezm_nn/__init__.py, deepmd/pt/model/atomic_model/sezm_atomic_model.py | Re-exports SpinEmbedding and updates SeZMAtomicModel to use canonical descriptor lmax attributes. |

Focus-Major (fndc) Layout Refactor and Wigner-D Optimization

| Layer | File(s) | Summary |
| --- | --- | --- |
| GatedActivation fndc layout support | .../dpa4_nn/activation.py, .../sezm_nn/activation.py | Adds fndc layout validation and layout-dependent gate projection logic. |
| GridNet fndc layout support | .../dpa4_nn/grid_net.py, .../sezm_nn/grid_net.py | Adds the fndc layout literal and axis-permutation handling. |
| RMSNorm axis order swap | .../dpa4_nn/norm.py, .../sezm_nn/norm.py | Swaps the leading focus/edge axis order and related broadcasting. |
| SO2Linear/SO2Convolution refactor (dpmodel) | .../dpa4_nn/so2.py | Focus-major SO2Linear/matmul refactor plus fused value-path seam hooks. |
| SO2Linear/SO2Convolution refactor (pt) | .../sezm_nn/so2.py, .../sezm_nn/utils.py | Mirrors the focus-major refactor, adds triton_infer_level gating and fused flash-attention/value-path wiring, and removes the old local Triton helper. |
| Wigner-D monomial optimization | .../dpa4_nn/wignerd.py, .../sezm_nn/wignerd.py | Unifies monomial matrix construction via _monomial_matrix, with an optional Triton fast path. |
| pt_expt fused-backend wiring | deepmd/pt_expt/descriptor/dpa4.py, .../dpa4_nn/* | Selects fused Triton/CuTe backends and AMP-inference autocast. |

Fused Triton/CuTe Kernel Backends for SeZM SO(2)

| Layer | File(s) | Summary |
| --- | --- | --- |
| Kernel gates | deepmd/kernels/utils.py, deepmd/kernels/*/__init__.py | Adds triton_infer_level/use_cute_infer/use_amp_infer env-gate helpers. |
| CuTe value-path kernels | deepmd/kernels/cute/sezm/* | Fused forward/backward CuTe kernels and per-convolution operator entrypoint. |
| Triton flash attention/force assembly | deepmd/kernels/triton/sezm/flash_atten.py, force_assembly.py | Fused aggregation and force/virial assembly kernels. |
| Triton GEMM/rotation | so2_block_gemm.py, so2_rotation.py | Block-diagonal GEMM kernel and rotation import path update. |
| Triton fp16x3/value-path | so2_stack_fp16x3.py, so2_value_path.py | Compensated fp16x3 mixing stack and fused rotate+mix value-path kernels. |
| Triton Wigner monomials | wigner_monomials.py | Fused quaternion monomial kernel. |
| Tile-config tuning | tile_config_data.py, tile_configs.py, sweep_tile_configs.py | Built-in/runtime tile-config lookup and benchmarking/sweeping tool. |
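
The "Kernel gates" row above describes environment-variable gates for the fused backends. As a rough illustration of that pattern only, here is a minimal Python sketch. The helper names `env_int` and `env_flag` and the variable names `EXAMPLE_INFER_LEVEL` and `EXAMPLE_USE_AMP` are hypothetical and are not taken from this PR; the real helpers in deepmd/kernels/utils.py may read different variables and apply different defaults.

```python
import os

# Hypothetical helpers; names and semantics are assumptions for illustration.
_TRUTHY = {"1", "true", "yes", "on"}


def env_int(name: str, default: int = 0) -> int:
    """Read an integer level from the environment, falling back to `default`."""
    raw = os.environ.get(name)
    if raw is None or raw.strip() == "":
        return default
    try:
        return int(raw)
    except ValueError as exc:
        # Fail loudly so a typo does not silently disable a fast path.
        raise ValueError(f"{name} must be an integer, got {raw!r}") from exc


def env_flag(name: str, default: bool = False) -> bool:
    """Read a boolean switch from the environment, falling back to `default`."""
    raw = os.environ.get(name)
    if raw is None:
        return default
    return raw.strip().lower() in _TRUTHY
```

Reading the gate in one place and raising on malformed values keeps an opt-in fast path from being skipped without notice.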

Native Spin Model, Training, and Export Integration

| Layer | File(s) | Summary |
| --- | --- | --- |
| SeZMNativeSpinModel and dispatch | deepmd/pt/model/model/sezm_native_spin_model.py, __init__.py, serialization.py | New native-spin model class, scheme dispatch, deserialization support. |
| SeZMModel/SeZMPropertyModel spin plumbing | sezm_model.py, sezm_property_model.py | Threads spin through forward/lower/core_compute/export tracing. |
| Spin stat packing/export ABI | sezm_spin_model.py, spin_model.py | Centralizes stat sample packing, energy-key splitting, export ABI overrides. |
| Magnetic force/loss helpers | transform_output.py, ener_spin.py | Spin-leaf gradient computation and NaN-safe magnetic-force loss reductions. |
| Training/validation spin support | training.py, validation.py | AMP-inference default, loss-type gating, profile-driven full-validation metrics. |
| freeze_pt2 export ABI/compile options | freeze_pt2.py, compile_compat.py | Native-spin edge_vec ABI, Triton tuning before trace, inference inductor options. |

Estimated code review effort: 5 (Critical) | ~150 minutes

Sequence Diagram(s)

sequenceDiagram
  participant Caller
  participant DescrptDPA4
  participant SpinEmbedding
  participant GIE as GeometricInitialEmbedding
  participant Readout as _apply_readout

  Caller->>DescrptDPA4: call(coord, atype, spin)
  DescrptDPA4->>SpinEmbedding: _apply_spin_embedding(type_feat, spin)
  SpinEmbedding-->>DescrptDPA4: scalar l=0, vector l=1
  DescrptDPA4->>GIE: call(..., spin_l1_message)
  GIE-->>DescrptDPA4: non_scalar_message with folded spin l=1
  DescrptDPA4->>Readout: _apply_readout(x)
  Readout-->>Caller: descriptor output
sequenceDiagram
  participant Trainer
  participant SeZMNativeSpinModel
  participant SeZMModel as SeZMModel.core_compute
  participant TransformOutput as edge_energy_deriv

  Trainer->>SeZMNativeSpinModel: forward(coord, atype, spin)
  SeZMNativeSpinModel->>SeZMModel: forward_common(spin=spin)
  SeZMModel->>TransformOutput: edge_energy_deriv(edge_vec, spin_leaf=spin)
  TransformOutput-->>SeZMModel: force, virial, energy_derv_r_mag
  SeZMModel-->>SeZMNativeSpinModel: energy, force, force_mag
  SeZMNativeSpinModel-->>Trainer: atom_energy, energy, force, force_mag, mask_mag

Possibly related PRs

Suggested reviewers: wanghan-iapcm, iProzd

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 inconclusive)

| Check name | Status | Explanation | Resolution |
| --- | --- | --- | --- |
| Title check | ❓ Inconclusive | The title is generic and does not clearly describe the main DPA4 changes. | Rename it to name the primary change, such as adding native spin support and fused SeZM/DPA4 inference backends. |

✅ Passed checks (4 passed)

| Check name | Status | Explanation |
| --- | --- | --- |
| Description Check | ✅ Passed | Check skipped - CodeRabbit’s high-level summary is enabled. |
| Docstring Coverage | ✅ Passed | No functions found in the changed files to evaluate docstring coverage. Skipping docstring coverage check. |
| Linked Issues check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
| Out of Scope Changes check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |

@coderabbitai coderabbitai Bot left a comment

Contributor

🧹 Nitpick comments (2)
deepmd/kernels/triton/sezm/__init__.py (1)

9-42: 📐 Maintainability & Code Quality | 🔵 Trivial | 💤 Low value

flash_atten.py's availability flag isn't included in the aggregate.

TRITON_AVAILABLE ANDs 7 per-kernel flags but omits FLASH_ATTEN-related availability from flash_atten.py (part of this cohort). Functionally equivalent today since every flag reduces to "is triton importable" per the comment, but if flash_atten.py ever gains its own stricter gating (e.g. a Triton-version check), the package flag would silently miss it.

♻️ Proposed fix
+from .flash_atten import (
+    FLASH_ATTEN_TRITON_AVAILABLE,
+)
 from .force_assembly import (
     FORCE_ASSEMBLY_TRITON_AVAILABLE,
 )
@@
 TRITON_AVAILABLE = (
     TRITON_ROTATION_AVAILABLE
     and RADIAL_MIX_TRITON_AVAILABLE
     and SO2_BLOCK_GEMM_TRITON_AVAILABLE
     and SO2_VALUE_PATH_TRITON_AVAILABLE
     and STACK_FP16X3_TRITON_AVAILABLE
     and WIGNER_MONOMIALS_TRITON_AVAILABLE
     and FORCE_ASSEMBLY_TRITON_AVAILABLE
+    and FLASH_ATTEN_TRITON_AVAILABLE
 )
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@deepmd/kernels/triton/sezm/__init__.py` around lines 9 - 42, The aggregate
TRITON_AVAILABLE flag in the sezm package is missing the availability guard from
flash_atten.py, so the package-level check can drift out of sync with all
kernels in this cohort. Update the imports and the TRITON_AVAILABLE conjunction
in deepmd/kernels/triton/sezm/__init__.py to include the flash_atten module’s
availability symbol alongside TRITON_ROTATION_AVAILABLE,
RADIAL_MIX_TRITON_AVAILABLE, SO2_BLOCK_GEMM_TRITON_AVAILABLE,
SO2_VALUE_PATH_TRITON_AVAILABLE, STACK_FP16X3_TRITON_AVAILABLE,
WIGNER_MONOMIALS_TRITON_AVAILABLE, and FORCE_ASSEMBLY_TRITON_AVAILABLE, keeping
the package flag as the full AND of every per-kernel availability check.
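
One way to keep an aggregate flag from drifting as kernels are added is to compute it from a registry rather than a hand-written conjunction. The sketch below is an illustrative alternative, not the fix proposed in this review and not code from the PR. `KERNEL_FLAGS`, `register_kernel`, and `all_kernels_available` are hypothetical names; `importlib.util.find_spec` is the standard-library call used to test importability.

```python
import importlib.util

# Hypothetical registry: one availability flag per kernel module.
KERNEL_FLAGS: dict[str, bool] = {}


def module_importable(name: str) -> bool:
    """Return True when the top-level module `name` can be imported."""
    return importlib.util.find_spec(name) is not None


def register_kernel(kernel: str, available: bool) -> None:
    """Record the availability of one kernel under its name."""
    KERNEL_FLAGS[kernel] = available


def all_kernels_available() -> bool:
    """AND over every registered flag; an empty registry counts as unavailable."""
    return bool(KERNEL_FLAGS) and all(KERNEL_FLAGS.values())


# Each kernel module would register itself once, for example:
triton_ok = module_importable("triton")
register_kernel("flash_atten", triton_ok)
register_kernel("force_assembly", triton_ok)
```

With this shape, a kernel that later gains a stricter gate only changes its own `register_kernel` call, and the aggregate picks it up automatically.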
deepmd/pt/model/model/__init__.py (1)

142-156: 📐 Maintainability & Code Quality | 🔵 Trivial | 💤 Low value

Optional: validate index range before scattering into the mask.

For the magnetic-index form, mask[use_spin] = True will raise a raw NumPy IndexError (and silently accept negatives) when an index is out of range for type_map. A small explicit check would give a clearer, actionable error consistent with the symbol-form validation just above.
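
A minimal sketch of such a check follows. The helper name `build_spin_mask` is hypothetical, and a plain Python list stands in for the NumPy mask so the example has no dependencies; in the PR the same range check would sit immediately before `mask[use_spin] = True`.

```python
from collections.abc import Sequence


def build_spin_mask(use_spin: Sequence[int], type_map: Sequence[str]) -> list[bool]:
    """Return a per-type boolean mask, rejecting out-of-range indices up front."""
    ntypes = len(type_map)
    # Negative indices are rejected explicitly: plain indexing would accept
    # them and silently mark the wrong type as magnetic.
    out_of_range = [i for i in use_spin if not 0 <= i < ntypes]
    if out_of_range:
        raise ValueError(
            f"use_spin indices {out_of_range} are outside the valid range "
            f"[0, {ntypes - 1}] for type_map {list(type_map)}"
        )
    mask = [False] * ntypes
    for i in use_spin:
        mask[i] = True
    return mask
```

For example, `build_spin_mask([0, 2], ["Fe", "O", "Ni"])` marks the first and third types, while `build_spin_mask([-1], ["Fe", "O", "Ni"])` raises a `ValueError` that names the offending index.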

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@deepmd/pt/model/model/__init__.py` around lines 142 - 156, The spin.use_spin
normalization in deepmd/pt/model/model/__init__.py should validate magnetic
indices before scattering them into the mask. In the branch that handles index
lists in the model parameter processing logic around use_spin and type_map, add
an explicit range check for every index in use_spin so out-of-range values
(including negatives) raise a clear ValueError instead of a raw NumPy
IndexError. Keep the existing symbol-name validation path intact and update the
mask assignment only after the indices are confirmed valid.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: 1a30b519-28ad-43fe-835a-90307a68c825

📥 Commits

Reviewing files that changed from the base of the PR and between ffe57a3 and 0d54235.

📒 Files selected for processing (97)
  • deepmd/dpmodel/descriptor/dpa4.py
  • deepmd/dpmodel/descriptor/dpa4_nn/__init__.py
  • deepmd/dpmodel/descriptor/dpa4_nn/activation.py
  • deepmd/dpmodel/descriptor/dpa4_nn/embedding.py
  • deepmd/dpmodel/descriptor/dpa4_nn/grid_net.py
  • deepmd/dpmodel/descriptor/dpa4_nn/norm.py
  • deepmd/dpmodel/descriptor/dpa4_nn/so2.py
  • deepmd/dpmodel/descriptor/dpa4_nn/wignerd.py
  • deepmd/kernels/__init__.py
  • deepmd/kernels/cute/__init__.py
  • deepmd/kernels/cute/sezm/__init__.py
  • deepmd/kernels/cute/sezm/backward.py
  • deepmd/kernels/cute/sezm/forward.py
  • deepmd/kernels/cute/sezm/operator.py
  • deepmd/kernels/triton/__init__.py
  • deepmd/kernels/triton/sezm/__init__.py
  • deepmd/kernels/triton/sezm/flash_atten.py
  • deepmd/kernels/triton/sezm/force_assembly.py
  • deepmd/kernels/triton/sezm/radial_mix.py
  • deepmd/kernels/triton/sezm/so2_block_gemm.py
  • deepmd/kernels/triton/sezm/so2_rotation.py
  • deepmd/kernels/triton/sezm/so2_stack_fp16x3.py
  • deepmd/kernels/triton/sezm/so2_value_path.py
  • deepmd/kernels/triton/sezm/sweep_tile_configs.py
  • deepmd/kernels/triton/sezm/tile_config_data.py
  • deepmd/kernels/triton/sezm/tile_configs.py
  • deepmd/kernels/triton/sezm/wigner_monomials.py
  • deepmd/kernels/utils.py
  • deepmd/pt/entrypoints/freeze_pt2.py
  • deepmd/pt/infer/deep_eval.py
  • deepmd/pt/loss/ener_spin.py
  • deepmd/pt/model/atomic_model/sezm_atomic_model.py
  • deepmd/pt/model/descriptor/sezm.py
  • deepmd/pt/model/descriptor/sezm_nn/__init__.py
  • deepmd/pt/model/descriptor/sezm_nn/activation.py
  • deepmd/pt/model/descriptor/sezm_nn/cute/__init__.py
  • deepmd/pt/model/descriptor/sezm_nn/cute/so2_rotation.py
  • deepmd/pt/model/descriptor/sezm_nn/embedding.py
  • deepmd/pt/model/descriptor/sezm_nn/grid_net.py
  • deepmd/pt/model/descriptor/sezm_nn/norm.py
  • deepmd/pt/model/descriptor/sezm_nn/so2.py
  • deepmd/pt/model/descriptor/sezm_nn/triton/__init__.py
  • deepmd/pt/model/descriptor/sezm_nn/utils.py
  • deepmd/pt/model/descriptor/sezm_nn/wignerd.py
  • deepmd/pt/model/model/__init__.py
  • deepmd/pt/model/model/sezm_model.py
  • deepmd/pt/model/model/sezm_native_spin_model.py
  • deepmd/pt/model/model/sezm_property_model.py
  • deepmd/pt/model/model/sezm_spin_model.py
  • deepmd/pt/model/model/spin_model.py
  • deepmd/pt/model/model/transform_output.py
  • deepmd/pt/train/training.py
  • deepmd/pt/train/validation.py
  • deepmd/pt/utils/compile_compat.py
  • deepmd/pt/utils/serialization.py
  • deepmd/pt_expt/descriptor/dpa4.py
  • deepmd/pt_expt/descriptor/dpa4_nn/__init__.py
  • deepmd/pt_expt/descriptor/dpa4_nn/so2.py
  • deepmd/pt_expt/descriptor/dpa4_nn/triton/__init__.py
  • deepmd/pt_expt/descriptor/dpa4_nn/triton/radial_mix.py
  • deepmd/pt_expt/descriptor/dpa4_nn/triton/so2_rotation.py
  • deepmd/pt_expt/descriptor/dpa4_nn/wignerd.py
  • deepmd/pt_expt/infer/deep_eval.py
  • deepmd/pt_expt/utils/edge_schema.py
  • deepmd/utils/argcheck.py
  • deepmd/utils/eval_metrics.py
  • deepmd/utils/spin.py
  • doc/model/dpa4.md
  • examples/spin/dpa4/input-deepspin.json
  • examples/spin/dpa4/input.json
  • examples/spin/dpa4/lmp/README.md
  • examples/spin/dpa4/lmp/in.lammps
  • examples/spin/dpa4/lmp/init.data
  • examples/water/dpa4/README.md
  • examples/water/dpa4/input.json
  • pyproject.toml
  • source/api_cc/include/DeepPot.h
  • source/api_cc/include/DeepPotPTExpt.h
  • source/api_cc/include/DeepSpinPTExpt.h
  • source/api_cc/src/DeepPot.cc
  • source/api_cc/src/DeepPotPTExpt.cc
  • source/api_cc/src/DeepSpinPTExpt.cc
  • source/lmp/pair_deepmd.cpp
  • source/lmp/pair_deepspin.cpp
  • source/op/pt/comm.cc
  • source/tests/common/dpmodel/test_dpa4_so3_projector.py
  • source/tests/common/test_examples.py
  • source/tests/pt/model/test_descriptor_sezm.py
  • source/tests/pt/model/test_descriptor_sezm_triton.py
  • source/tests/pt/model/test_dpa4_dpmodel_parity.py
  • source/tests/pt/model/test_dpa4_ptexpt_grad_parity.py
  • source/tests/pt/model/test_sezm_export.py
  • source/tests/pt/model/test_sezm_model.py
  • source/tests/pt/model/test_sezm_parallel.py
  • source/tests/pt/model/test_sezm_spin_model.py
  • source/tests/pt/test_validation.py
  • source/tests/pt_expt/utils/test_border_op_backward.py
💤 Files with no reviewable changes (6)
  • deepmd/pt/model/descriptor/sezm_nn/cute/__init__.py
  • deepmd/pt_expt/descriptor/dpa4_nn/triton/__init__.py
  • deepmd/pt/model/descriptor/sezm_nn/triton/__init__.py
  • deepmd/pt_expt/descriptor/dpa4_nn/triton/radial_mix.py
  • deepmd/pt/model/descriptor/sezm_nn/cute/so2_rotation.py
  • deepmd/pt/model/descriptor/sezm_nn/utils.py

@OutisLi OutisLi requested a review from wanghan-iapcm July 5, 2026 05:39