perf(dpa4): opt so3grid (with pt_expt GridProduct wrapping fix)#5552
Conversation
The parameter-free `GridProduct` NativeOP (added for the so3grid
optimization) has no `serialize`/`deserialize` and is not registered via
`register_dpmodel_mapping`. The pt_expt backend auto-wraps every dpmodel
NativeOP sub-component through `_auto_wrap_native_op`, which requires the
op to be serializable (or registered) to build its dynamic torch wrapper;
otherwise it raises:
TypeError: Cannot auto-wrap GridProduct: it must implement
serialize()/deserialize() or be explicitly registered via
register_dpmodel_mapping().
This broke every `Test Python` shard that loads a DPA4 pt_expt model
(e.g. test_get_model_dpa4.py). Add trivial `serialize`/`deserialize`
(no state, mirroring the GridMLP @class/@Version convention) so the op
auto-wraps cleanly.
|
No actionable comments were generated in the recent review. 🎉 ℹ️ Recent review info⚙️ Run configurationConfiguration used: Repository UI Review profile: CHILL Plan: Pro Run ID: 📒 Files selected for processing (5)
📝 WalkthroughWalkthroughPorts ChangesPort GridMLP and refactor grid-op projector contract
Sequence Diagram(s)sequenceDiagram
participant EquivariantFFN
participant S2GridNet
participant BaseGridNet
participant GridOp as GridProduct/GridMLP/GridBranch
participant _to_grid
participant _from_grid
EquivariantFFN->>S2GridNet: forward(left, right, scalar_pair)
S2GridNet->>BaseGridNet: call(left, right, scalar_pair)
BaseGridNet->>BaseGridNet: _apply_grid_op(left, right, scalar_pair)
BaseGridNet->>GridOp: call(left, right, scalar_pair, to_grid=_to_grid, from_grid=_from_grid)
GridOp->>_to_grid: project coefficients to grid space
_to_grid-->>GridOp: grid tensor (channel width inferred from shape)
GridOp->>GridOp: quadratic product / MLP / branch routing in grid space
GridOp->>_from_grid: project grid back to coefficient space
_from_grid-->>GridOp: coeff_out
GridOp-->>BaseGridNet: coeff_out
BaseGridNet-->>S2GridNet: coeff_out
S2GridNet-->>EquivariantFFN: output
Estimated code review effort🎯 4 (Complex) | ⏱️ ~60 minutes Possibly related PRs
Suggested labels
Suggested reviewers
🚥 Pre-merge checks | ✅ 5✅ Passed checks (5 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
Codecov Report❌ Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## master #5552 +/- ##
==========================================
+ Coverage 82.21% 82.24% +0.02%
==========================================
Files 892 894 +2
Lines 101532 102084 +552
Branches 4240 4276 +36
==========================================
+ Hits 83475 83955 +480
- Misses 16753 16828 +75
+ Partials 1304 1301 -3 ☔ View full report in Codecov by Harness. 🚀 New features to boost your workflow:
|
Based on #5517 (
perf(dpa4): opt so3gridby @OutisLi) — this branch contains all of its commits plus one fix commit that addresses the CI failures on that PR.Problem
#5517 introduces a new parameter-free
GridProductNativeOPindeepmd/dpmodel/descriptor/dpa4_nn/grid_net.pyfor the so3grid optimization, but it has noserialize/deserializeand is not registered viaregister_dpmodel_mapping. The pt_expt backend auto-wraps every dpmodelNativeOPsub-component through_auto_wrap_native_op, which requires the op to be serializable (or registered) to build its dynamic torch wrapper. Otherwise it raises:This broke every
Test Pythonshard that loads a DPA4 pt_expt model (e.g.source/tests/pt_expt/model/test_get_model_dpa4.py::TestGetModelDPA4::test_pair_exclude_types_from_descriptor) on #5517.Fix
Add trivial
serialize/deserializetoGridProduct(no state — mirrors theGridMLP@class/@versionconvention)._auto_wrap_native_opthen passes itshasattr(value, "serialize")guard and returnswrapped_cls.deserialize(value.serialize())cleanly.Notes
GridMLP(also new in perf(dpa4): opt so3grid #5517) already implementsserialize/deserialize; only the parameter-freeGridProductwas missing them._auto_wrap_native_opcode path (deepmd/pt_expt/common.py:138-170); the actual pt_expt DPA4 test runs in CI here.Summary by CodeRabbit
Release Notes
Refactor
Documentation
Tests