feat(dpmodel): complete DPA4/SeZM SO3 grid projection (mirror current pt)#5555
Conversation
…) for DPA4 SO3 grid
… for DPA4 SO3 grid
for more information, see https://pre-commit.ci
|
No actionable comments were generated in the recent review. 🎉 ℹ️ Recent review info⚙️ Run configurationConfiguration used: Repository UI Review profile: CHILL Plan: Pro Run ID: 📒 Files selected for processing (2)
🚧 Files skipped from review as they are similar to previous changes (2)
📝 WalkthroughWalkthroughPorts the previously unimplemented SO(3)/S2 cross-mode grid paths in the DPA4 dpmodel descriptor. Adds ChangesDPA4 dpmodel SO(3) grid port
Sequence Diagram(s)sequenceDiagram
participant Caller
participant EquivariantFFN
participant SO3GridNet
participant BaseGridNet
participant FrameExpand
participant GridOp as GridMLP/GridBranch
participant FrameContract
Caller->>EquivariantFFN: call(x_coeffs)
EquivariantFFN->>SO3GridNet: act(x_coeffs) [ffn_so3_grid=True]
SO3GridNet->>BaseGridNet: call(query, context=None)
BaseGridNet->>FrameExpand: expand packed frames
FrameExpand-->>BaseGridNet: expanded coefficients
BaseGridNet->>GridOp: to_grid / quadratic product / from_grid
GridOp-->>BaseGridNet: grid op output
BaseGridNet->>FrameContract: contract back to channel dim
FrameContract-->>BaseGridNet: contracted output
BaseGridNet-->>SO3GridNet: result + residual_scale
SO3GridNet-->>EquivariantFFN: activated features
EquivariantFFN-->>Caller: updated x_coeffs
sequenceDiagram
participant SO2Conv as SO2Convolution.call
participant NodeWise as node_wise_grid_product
participant AttnAgg as attention aggregation
participant MsgNode as message_node_grid_product
SO2Conv->>NodeWise: (x_dst_local, x_local) → residual [if enabled]
NodeWise-->>SO2Conv: add to x_local before SO(2) focus
SO2Conv->>AttnAgg: SO(2) convolution + edge aggregation → out
SO2Conv->>MsgNode: (out, node features) → residual [if enabled]
MsgNode-->>SO2Conv: add to out before final channel mixing
Estimated code review effort🎯 5 (Critical) | ⏱️ ~120 minutes Possibly related PRs
Suggested labels
Suggested reviewers
🚥 Pre-merge checks | ✅ 4 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
Actionable comments posted: 3
🧹 Nitpick comments (2)
source/tests/common/dpmodel/test_dpa4_so3_gridnet.py (1)
333-338: ⚡ Quick winAssert both frame-mixer variables are absent in self mode.
test_so3_serialize_roundtriponly checks"frame_expand.weight"is absent formode="self". Add the symmetric check for"frame_contract.weight"to prevent silent schema regressions.Suggested patch
if mode == "cross": assert "frame_expand.weight" in data["`@variables`"] assert "frame_contract.weight" in data["`@variables`"] else: assert "frame_expand.weight" not in data["`@variables`"] + assert "frame_contract.weight" not in data["`@variables`"]🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@source/tests/common/dpmodel/test_dpa4_so3_gridnet.py` around lines 333 - 338, In the test_so3_serialize_roundtrip function, the else block (which handles mode="self") only asserts that "frame_expand.weight" is absent from data["`@variables`"] but is missing the symmetric check for "frame_contract.weight". Add an additional assert statement in the else block to verify that "frame_contract.weight" is also not in data["`@variables`"], mirroring the structure of the cross mode block which checks both variables.source/tests/common/dpmodel/test_dpa4_so2_grid.py (1)
194-197: ⚡ Quick winAdd a non-trivial-output guard to avoid vacuous parity passes.
_assert_conv_paritycurrently only checks DP/PT closeness. Add a simple magnitude guard so parity can’t pass if both outputs collapse to (near) zero.Suggested patch
out_dp = dp_mod.call(x, dp_cache, radial) out_pt = pt_mod(_to_pt(x), pt_cache, _to_pt(radial_valid)) + assert np.max(np.abs(np.asarray(out_dp))) > 1e-8 _assert_parity(out_dp, out_pt, rtol=rtol, atol=atol)🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@source/tests/common/dpmodel/test_dpa4_so2_grid.py` around lines 194 - 197, The _assert_parity function currently only checks if DP and PT outputs are close to each other, which means the test can pass vacuously if both outputs collapse to near-zero values. Add a magnitude guard within the _assert_parity function that verifies at least one of the outputs has a non-trivial magnitude (e.g., check that the absolute maximum value or norm of the outputs exceeds a small threshold) before asserting their closeness, ensuring the parity test only passes when the outputs are both meaningful and close to each other.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@deepmd/dpmodel/descriptor/dpa4_nn/grid_net.py`:
- Around line 1532-1544: The code in the class method (around the
projector_config assignment) is accessing only the `["config"]` portion of the
projector object, bypassing validation of the nested `@class` and `@version`
schema fields. Before extracting and using projector_config, validate the full
projector object structure (the complete config["projector"] payload) to ensure
it conforms to the expected schema and version, preventing acceptance of wrong
or incompatible projector schemas when rebuilding the object.
- Around line 950-956: The operands query and context_ndfc are being passed to
self.frame_expand() without first casting them to compute_dtype, unlike
scalar_pair which is properly cast on lines 952-953. This causes dtype
mismatches when fp32 inputs flow through float64 grid nets. Cast both query and
context_ndfc to compute_dtype (using xp.astype) before passing them to the
self.frame_expand() calls in the return statement, similar to how scalar_pair is
cast to compute_dtype.
In `@deepmd/dpmodel/descriptor/dpa4_nn/so2.py`:
- Around line 1827-1841: The deserialization of `node_wise_grid_product` and
`message_node_grid_product` does not validate that the template contains only
expected keys before calling `deserialize()`. After assigning the result of
`sub_vars()` to the `@variables` key in the template, add validation to reject
any unexpected or unknown top-level keys in the template before passing it to
the `deserialize()` method for both `node_wise_grid_product` and
`message_node_grid_product`. This prevents schema drift keys from being silently
ignored during the deserialization process.
---
Nitpick comments:
In `@source/tests/common/dpmodel/test_dpa4_so2_grid.py`:
- Around line 194-197: The _assert_parity function currently only checks if DP
and PT outputs are close to each other, which means the test can pass vacuously
if both outputs collapse to near-zero values. Add a magnitude guard within the
_assert_parity function that verifies at least one of the outputs has a
non-trivial magnitude (e.g., check that the absolute maximum value or norm of
the outputs exceeds a small threshold) before asserting their closeness,
ensuring the parity test only passes when the outputs are both meaningful and
close to each other.
In `@source/tests/common/dpmodel/test_dpa4_so3_gridnet.py`:
- Around line 333-338: In the test_so3_serialize_roundtrip function, the else
block (which handles mode="self") only asserts that "frame_expand.weight" is
absent from data["`@variables`"] but is missing the symmetric check for
"frame_contract.weight". Add an additional assert statement in the else block to
verify that "frame_contract.weight" is also not in data["`@variables`"], mirroring
the structure of the cross mode block which checks both variables.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
Run ID: 8516e16a-0724-4162-a7a7-f1167567c8f4
📒 Files selected for processing (19)
deepmd/dpmodel/descriptor/dpa4_nn/block.pydeepmd/dpmodel/descriptor/dpa4_nn/ffn.pydeepmd/dpmodel/descriptor/dpa4_nn/grid_net.pydeepmd/dpmodel/descriptor/dpa4_nn/projection.pydeepmd/dpmodel/descriptor/dpa4_nn/so2.pysource/tests/common/dpmodel/test_descrpt_dpa4.pysource/tests/common/dpmodel/test_dpa4_basegridnet_cross.pysource/tests/common/dpmodel/test_dpa4_ffn_so3.pysource/tests/common/dpmodel/test_dpa4_frame_mixers.pysource/tests/common/dpmodel/test_dpa4_grid_descriptor.pysource/tests/common/dpmodel/test_dpa4_gridbranch_frames.pysource/tests/common/dpmodel/test_dpa4_gridmlp_frames.pysource/tests/common/dpmodel/test_dpa4_project_frames.pysource/tests/common/dpmodel/test_dpa4_so2_grid.pysource/tests/common/dpmodel/test_dpa4_so3_grid_utils.pysource/tests/common/dpmodel/test_dpa4_so3_gridnet.pysource/tests/common/dpmodel/test_dpa4_so3_projector.pysource/tests/consistent/descriptor/test_dpa4.pysource/tests/pt/model/test_dpa4_dpmodel_parity.py
…ng#5555) - BaseGridNet cross-mode: lift query/context to compute_dtype before FrameExpand so the frame expansion runs in the net's precision (mirrors pt's fp64-weight-forced FrameExpand); _FrameMixer otherwise casts weights down to the operand dtype, expanding fp32 inputs in fp32. (CodeRabbit) - SO3GridNet.deserialize: validate the nested projector @class/@Version instead of blindly reading config["projector"]["config"]. (CodeRabbit) - SO2Convolution.deserialize: reject schema-drift keys under the node_wise/message_node grid-product prefixes (loaded @variables key set must match the fresh template). (CodeRabbit) - drop unused n_valid local in test_dpa4_so2_grid.py. (CodeQL) Tests: mixed-precision cross run, deserialize rejects bad projector @class/@Version, deserialize rejects grid-product drift key.
for more information, see https://pre-commit.ci
There was a problem hiding this comment.
Actionable comments posted: 2
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@deepmd/dpmodel/descriptor/dpa4_nn/grid_net.py`:
- Around line 1541-1548: The version key for the nested projector is being
accessed with a default value of 1 using get("`@version`", 1), which allows
malformed projector payloads without a version to silently pass validation.
Remove the default value parameter and add explicit validation to require the
`@version` key to be present in projector_data before calling
check_version_compatibility(). Raise a ValueError with a descriptive message if
the `@version` key is missing, following the same pattern as the existing `@class`
validation.
In `@source/tests/common/dpmodel/test_dpa4_so3_gridnet.py`:
- Around line 472-473: The test for invalid version handling in
DPSO3GridNet.deserialize uses overly broad exception handling with
pytest.raises(Exception), which can mask unrelated failures and weakens test
specificity. Replace Exception with the concrete exception type that the
deserialize method actually raises for invalid versions (examine the method
implementation to determine if it's ValueError, RuntimeError, or another
specific exception type), and optionally add a message parameter to match
against the error message fragment for even more precise assertion.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
Run ID: ca2ea555-f9dd-49f9-8000-5426bb8cb238
📒 Files selected for processing (4)
deepmd/dpmodel/descriptor/dpa4_nn/grid_net.pydeepmd/dpmodel/descriptor/dpa4_nn/so2.pysource/tests/common/dpmodel/test_dpa4_so2_grid.pysource/tests/common/dpmodel/test_dpa4_so3_gridnet.py
🚧 Files skipped from review as they are similar to previous changes (2)
- source/tests/common/dpmodel/test_dpa4_so2_grid.py
- deepmd/dpmodel/descriptor/dpa4_nn/so2.py
Codecov Report❌ Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## master #5555 +/- ##
==========================================
+ Coverage 82.16% 82.22% +0.05%
==========================================
Files 896 896
Lines 102643 103018 +375
Branches 4340 4340
==========================================
+ Hits 84341 84708 +367
- Misses 16965 16972 +7
- Partials 1337 1338 +1 ☔ View full report in Codecov by Harness. 🚀 New features to boost your workflow:
|
…eepmodeling#5555) Follow-up to CodeRabbit re-review: - SO3GridNet.deserialize now requires the nested projector @Version key (was silently defaulting a missing version to 1). - the version test asserts ValueError(match="version") instead of a blind Exception (ruff B017), and adds a missing-@Version case.
## What Adds the DPA4/SeZM `so3_readout` option (`"none"` / `"glu"` / `"mlp"`) **across all backends** (pt + dpmodel + pt_expt), making it cross-backend consistent. Builds on **#5556** (@OutisLi): its pt `so3_readout` commit (`refactor(dpa4): output ffn`) is included here with original authorship preserved; this PR adds the missing **dpmodel** counterpart so the shared DPA4 serialize format round-trips across backends. (The unrelated `nv_nlist`/compiler-check fixes from #5556 are intentionally left to #5556.) ## Why `so3_readout` is implemented by configuring the final output FFN — `"glu"`/`"mlp"` turn on the SO(3)-grid FFN (`ffn_so3_grid`, `grid_mlp`). On its own (pt-only, as in #5556) it breaks DPA4 cross-backend consistency: pt `serialize()` emits `so3_readout` but dpmodel `DescrptDPA4` couldn't round-trip it → `source/tests/consistent/descriptor/test_dpa4.py::...::test_pt_consistent_with_ref` failed on every Test Python shard. This is now feasible and small because **#5555** already ported `ffn_so3_grid` + the SO(3)-grid machinery (`SO3GridNet`/`GridMLP`/`GridProduct`) into the dpmodel `EquivariantFFN`. So the dpmodel `so3_readout` is just: accept the param, configure `output_ffn` exactly like pt, wire the readout (l=0 slice for `"none"`; full `(N,D,1,C)` fold for `"glu"`/`"mlp"` then slice l=0), and serialize the key. pt_expt auto-wraps. ## Changes - **pt** (from #5556): `so3_readout` in `DescrptSeZM` + argcheck + `examples/water/dpa4/input.json`. - **dpmodel** `descriptor/dpa4.py`: `so3_readout` param + validation; `output_ffn` configured (`lmax=node_l_schedule[-1]`, `kmax=min(kmax, readout_lmax)`, `ffn_so3_grid`, `grid_mlp`, `grid_branch=0`); readout forward mirrors pt; serialize the key. - **pt_expt**: auto-wrapped (no explicit change). ## Validation - `test_dpa4.py` cross-backend consistency rows for `so3_readout ∈ {none, glu, mlp}` (pt vs dpmodel vs pt_expt, mixed_types) — green; `test_pt_consistent_with_ref` now passes. - Full-descriptor pt→dpmodel parity (weight-copied, `glu`+`mlp`) — **~7e-15 abs** (gate 1e-10), proving serialize interop. - dpa4 suite: 611 passed; ruff clean. ## Notes - Depends on #5555 (merged) for the dpmodel SO(3)-grid FFN. - `so3_readout` no longer "(Supported Backend: PyTorch)" — now multi-backend. <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit ## Release Notes * **New Features** * Added a configurable `so3_readout` option to the DPA4 and SeZM descriptors (modes: `"none"`, `"glu"`, `"mlp"`), controlling how the final SO(3) readout is computed. * The setting is included in descriptor configuration serialization/deserialization to support round-tripping. * Updated the water DPA4 example to use `so3_readout: "mlp"`. * **Tests** * Added tests covering multiple readout modes, backend parity between implementations, and correct behavior for edge-free scenarios. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Co-authored-by: OutisLi <LTC201806070316@gmail.com> Co-authored-by: Han Wang <wang_han@iapcm.ac.cn>
What
Completes the DPA4/SeZM SO3 grid projection port to the dpmodel backend so it faithfully mirrors master's current pt
sezm_nn/grid_net.py. After this, the flagshipexamples/water/dpa4/input.json(which setsffn_so3_grid=true,message_node_so3=true,grid_mlp) runs on dpmodel/pt_expt.Builds on top of the S2-grid base that #5517/#5552 landed (
GridProduct/GridMLP/op_type='mlp', the_project_framesrefactor). Master's dpmodel was the S2 (n_frames==1) slice with SO3/cross-mode fail-fast guarded; this PR generalizes those ops to frame-aware (n_frames>1) + cross-mode and adds the missing SO3 pieces — matching current pt exactly (single source of truth: dpmodel == pt).Supersedes #5547 (which ported the pre-#5552 design and went structurally stale).
Changes (all mirror current pt)
grid_net.py: add_project_frames; generalizeGridMLP/GridBranchto frame-aware (n_frames); generalizeBaseGridNet(un-guardmode='cross',layout='flat',residual_scale_init,n_frames>1; frame-axis to/from-grid viaxp.matmul+reshape); addFrameContract/FrameExpand/_build_frame_degree_index; addSO3GridNet(self+cross).projection.py: addSO3GridProjector(Wigner-D quadrature) +resolve_so3_grid/_build_so3_frame_set.ffn.py: un-guardffn_so3_grid→SO3GridNet(mode='self').so2.py: un-guardnode_wise_{s2,so3}/message_node_{s2,so3}→ cross-mode grid products, applied incall+ round-tripped in serialize.Validation
_project_frames,GridMLP/GridBranch(incl. S2 byte-identical regression),BaseGridNetcross/flat/residual,FrameContract/FrameExpand,SO3GridProjectormatrices,SO3GridNetself+cross (op_type glu/mlp/branch, kmax 1&2) — all 1e-12; rotation equivariance 1e-10.DescrptDPA4.deserialize(pt.serialize())on the example config (lmax=3, mmax=1) — ~1e-14 — provingdp convert-backendschema interop.pt_expt forward works today via auto-wrap (consistency + descriptor trio green) — no explicit registration needed.
Known limitations
torch.export/AOTI grid coverage, training e2e, argcheckdoc_only_pt_supportedremoval, and freeze/DeepEval are a follow-up PR.grid_method='e3nn'(non-Lebedev product grid) stays fail-fast (Lebedev-only, per parent design).Summary by CodeRabbit