Skip to content

fix(jax): stabilize repflow dynamic selection export#5533

Open
njzjz wants to merge 3 commits into
deepmodeling:masterfrom
njzjz:fix/jax-repflow-static-dynamic-sel
Open

fix(jax): stabilize repflow dynamic selection export#5533
njzjz wants to merge 3 commits into
deepmodeling:masterfrom
njzjz:fix/jax-repflow-static-dynamic-sel

Conversation

@njzjz

@njzjz njzjz commented Jun 14, 2026

Copy link
Copy Markdown
Member

Summary

  • add an internal fixed-capacity dynamic-selection layout for repflows so JAX/jax2tf export avoids runtime-sized edge/angle tensors
  • skip unnecessary bincount in sum-only aggregate calls with a known owner count
  • add regression coverage comparing compact and static dynamic selection outputs

Validation

  • ruff check .
  • ruff format .
  • pytest source/tests/universal/dpmodel/descriptor/test_descriptor.py::TestDPA3StaticDynamicSelDP::test_static_dynamic_sel_matches_packed_dynamic_sel -q
  • dp convert-backend DPA-3.2-5M.pth DPA-3.2-5M.savedmodel

Summary by CodeRabbit

  • New Features
    • Enhanced dynamic neighbor/edge selection using a fixed-capacity layout to improve JAX/JAX-to-export compatibility.
  • Bug Fixes
    • Improved static-dynamic selection to keep edge/angle counts aligned and maintain consistent outputs, including correct handling of padded neighbor slots.
    • Updated export/restoration behavior to preserve the selected layout mode.
  • Chores
    • Refined aggregation to compute bin counts only when needed.
  • Tests
    • Added coverage verifying static-dynamic selection matches packed dynamic behavior.

@coderabbitai

coderabbitai Bot commented Jun 14, 2026

Copy link
Copy Markdown
Contributor

Review Change Stack

No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: a7b8ef9f-76ca-4a54-b504-de5f59abf7e4

📥 Commits

Reviewing files that changed from the base of the PR and between 31516c6 and 6933bca.

📒 Files selected for processing (1)
  • deepmd/dpmodel/descriptor/repflows.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • deepmd/dpmodel/descriptor/repflows.py

📝 Walkthrough

Walkthrough

Adds a backend-internal _use_static_dynamic_sel flag to DescrptBlockRepflows and RepFlowLayer, enabling a fixed-capacity (padded) execution path for edges and angles as an alternative to compact boolean-mask compaction. A new _get_static_graph_index helper builds the padded index tensors. The JAX subclasses override the flag to True. The aggregate utility is refactored to skip bin_count computation for pure sum reductions. An equivalence test verifies both paths produce matching outputs.

Changes

Static dynamic selection for RepFlows

Layer / File(s) Summary
Flag definition and instance snapshot
deepmd/dpmodel/descriptor/repflows.py
Defines _use_static_dynamic_sel as a class-level bool on both DescrptBlockRepflows and RepFlowLayer. Adds import warnings to support runtime alerts. In DescrptBlockRepflows.__init__, snapshots the class default to the instance and conditionally issues a warning when use_dynamic_sel and static mode are both enabled. RepFlowLayer.__init__ initializes the flag from the class default for the descriptor to override later.
Flag propagation to owned layers and deserialization
deepmd/dpmodel/descriptor/repflows.py
During DescrptBlockRepflows.__init__, propagates the snapped _use_static_dynamic_sel value to each owned RepFlowLayer. After deserialize(), re-copies the restored descriptor's flag to all restored layers to maintain consistent layout mode across the deserialized graph.
_get_static_graph_index helper and call() branching
deepmd/dpmodel/descriptor/repflows.py
Implements _get_static_graph_index returning fixed-capacity edge_index (2×n_edges) and angle_index (3×n_angles) with padded slots. Extends DescrptBlockRepflows.call() to branch on the flag: static path uses _get_static_graph_index and reshapes to flattened fixed capacities with full (j,k) angle gating; compact path keeps get_graph_index with boolean-mask compaction. Updates RepFlowLayer.call() to read n_edge from h2.shape[0] in static mode versus a masked sum in compact mode.
aggregate: conditional bin_count computation
deepmd/dpmodel/utils/network.py
Refactors aggregate to compute bin_count only when num_owner is absent or averaging is requested; skips bin_count for pure sum reductions by setting it to None. Allocates output directly with (num_owner, feature_dim) and asserts bin_count is not None before the averaging divide.
JAX backend override
deepmd/jax/descriptor/repflows.py
Sets _use_static_dynamic_sel = True on both DescrptBlockRepflows and RepFlowLayer JAX subclasses to activate fixed-capacity layout for JAX/jax2tf export.
Equivalence test
source/tests/universal/dpmodel/descriptor/test_descriptor.py
Adds TestDPA3StaticDynamicSelDP with imports for NumPy, DescrptBlockRepflows, and TestCaseSingleFrameWithNlist. Includes a _make_dpa3 helper that toggles _use_static_dynamic_sel at construction time, and a test that asserts packed dynamic and static dynamic outputs match by masking padded slots via nlist != -1.

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Possibly related PRs

  • deepmodeling/deepmd-kit#5355: Both PRs modify deepmd/dpmodel/descriptor/repflows.py in the same core classes (DescrptBlockRepflows/RepFlowLayer) and extend call/forward control-flow with new execution-mode flags, so they overlap at the graph-execution logic level.

Suggested reviewers

  • wanghan-iapcm
🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 42.86% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (4 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The title clearly summarizes the main change: stabilizing repflow dynamic selection export for JAX, which is the primary objective of the PR.
Linked Issues check ✅ Passed Check skipped because no linked issues were found for this pull request.
Out of Scope Changes check ✅ Passed Check skipped because no linked issues were found for this pull request.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🧹 Nitpick comments (1)
source/tests/universal/dpmodel/descriptor/test_descriptor.py (1)

915-933: ⚡ Quick win

Cover the use_loc_mapping=False static-index branch too.

_get_static_graph_index() changes its indexing stride when use_loc_mapping is disabled, but this helper always constructs the default mapped configuration. Extending this test to run both values would cover the second branch of the new static-dynamic layout logic that JAX now enables by default.

♻️ Suggested test expansion
-    def _make_dpa3(self, use_static_dynamic_sel: bool) -> DescrptDPA3:
+    def _make_dpa3(
+        self,
+        use_static_dynamic_sel: bool,
+        *,
+        use_loc_mapping: bool,
+    ) -> DescrptDPA3:
         # The switch is intentionally class-level and internal, so tests toggle
         # it only around construction and then restore the previous backend mode.
         old_use_static_dynamic_sel = DescrptBlockRepflows._use_static_dynamic_sel
         DescrptBlockRepflows._use_static_dynamic_sel = use_static_dynamic_sel
         try:
             return DescrptDPA3(
                 **DescriptorParamDPA3(
                     self.nt,
                     self.rcut,
                     self.rcut_smth,
                     self.sel,
                     ["O", "H"],
                     smooth_edge_update=True,
                     use_dynamic_sel=True,
+                    use_loc_mapping=use_loc_mapping,
                 )
             )
         finally:
             DescrptBlockRepflows._use_static_dynamic_sel = old_use_static_dynamic_sel

     def test_static_dynamic_sel_matches_packed_dynamic_sel(self) -> None:
-        packed = self._make_dpa3(False)
-        static = self._make_dpa3(True)
+        for use_loc_mapping in (True, False):
+            packed = self._make_dpa3(False, use_loc_mapping=use_loc_mapping)
+            static = self._make_dpa3(True, use_loc_mapping=use_loc_mapping)
 
-        packed_out = packed(
-            self.coord_ext,
-            self.atype_ext,
-            self.nlist,
-            mapping=self.mapping,
-        )
-        static_out = static(
-            self.coord_ext,
-            self.atype_ext,
-            self.nlist,
-            mapping=self.mapping,
-        )
+            packed_out = packed(
+                self.coord_ext,
+                self.atype_ext,
+                self.nlist,
+                mapping=self.mapping,
+            )
+            static_out = static(
+                self.coord_ext,
+                self.atype_ext,
+                self.nlist,
+                mapping=self.mapping,
+            )

-        np.testing.assert_allclose(packed_out[0], static_out[0], atol=self.atol)
-        np.testing.assert_allclose(packed_out[1], static_out[1], atol=self.atol)
+            np.testing.assert_allclose(packed_out[0], static_out[0], atol=self.atol)
+            np.testing.assert_allclose(packed_out[1], static_out[1], atol=self.atol)

-        valid_edge_mask = np.reshape(self.nlist != -1, (-1,))
-        assert static_out[2].shape[0] == self.nf * self.nloc * sum(self.sel)
-        np.testing.assert_allclose(
-            packed_out[2], static_out[2][valid_edge_mask], atol=self.atol
-        )
-        np.testing.assert_allclose(
-            packed_out[3], static_out[3][valid_edge_mask], atol=self.atol
-        )
-        np.testing.assert_allclose(
-            packed_out[4], static_out[4][valid_edge_mask], atol=self.atol
-        )
+            valid_edge_mask = np.reshape(self.nlist != -1, (-1,))
+            assert static_out[2].shape[0] == self.nf * self.nloc * sum(self.sel)
+            np.testing.assert_allclose(
+                packed_out[2], static_out[2][valid_edge_mask], atol=self.atol
+            )
+            np.testing.assert_allclose(
+                packed_out[3], static_out[3][valid_edge_mask], atol=self.atol
+            )
+            np.testing.assert_allclose(
+                packed_out[4], static_out[4][valid_edge_mask], atol=self.atol
+            )

Also applies to: 935-968

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@source/tests/universal/dpmodel/descriptor/test_descriptor.py` around lines
915 - 933, The _make_dpa3 helper method currently only constructs the default
mapped configuration, but _get_static_graph_index() has different indexing
behavior when use_loc_mapping is disabled. Extend the _make_dpa3 method to
accept a parameter for use_loc_mapping (similar to how it accepts
use_static_dynamic_sel) and ensure it applies this parameter when constructing
DescriptorParamDPA3. Then update all test methods that use _make_dpa3 (including
those at lines 935-968) to run test assertions with both use_loc_mapping=True
and use_loc_mapping=False so that both branches of the static-dynamic layout
logic are covered.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Nitpick comments:
In `@source/tests/universal/dpmodel/descriptor/test_descriptor.py`:
- Around line 915-933: The _make_dpa3 helper method currently only constructs
the default mapped configuration, but _get_static_graph_index() has different
indexing behavior when use_loc_mapping is disabled. Extend the _make_dpa3 method
to accept a parameter for use_loc_mapping (similar to how it accepts
use_static_dynamic_sel) and ensure it applies this parameter when constructing
DescriptorParamDPA3. Then update all test methods that use _make_dpa3 (including
those at lines 935-968) to run test assertions with both use_loc_mapping=True
and use_loc_mapping=False so that both branches of the static-dynamic layout
logic are covered.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: b5b5ca58-1ff0-4fbd-beea-b662619d0e7b

📥 Commits

Reviewing files that changed from the base of the PR and between c0b0319 and b338bb1.

📒 Files selected for processing (4)
  • deepmd/dpmodel/descriptor/repflows.py
  • deepmd/dpmodel/utils/network.py
  • deepmd/jax/descriptor/repflows.py
  • source/tests/universal/dpmodel/descriptor/test_descriptor.py

@codecov

codecov Bot commented Jun 14, 2026

Copy link
Copy Markdown

Codecov Report

❌ Patch coverage is 87.93103% with 7 lines in your changes missing coverage. Please review.
✅ Project coverage is 82.23%. Comparing base (c0b0319) to head (6933bca).
⚠️ Report is 21 commits behind head on master.

Files with missing lines Patch % Lines
deepmd/dpmodel/utils/network.py 30.00% 7 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           master    #5533      +/-   ##
==========================================
+ Coverage   82.18%   82.23%   +0.04%     
==========================================
  Files         890      894       +4     
  Lines      101357   102048     +691     
  Branches     4240     4274      +34     
==========================================
+ Hits        83301    83916     +615     
- Misses      16754    16829      +75     
- Partials     1302     1303       +1     

☔ View full report in Codecov by Harness.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.
  • 📦 JS Bundle Analysis: Save yourself from yourself by tracking and limiting bundle sizes in JS merges.

@njzjz njzjz requested review from iProzd and wanghan-iapcm June 14, 2026 17:32
@njzjz njzjz added the Test CUDA Trigger test CUDA workflow label Jun 16, 2026
@github-actions github-actions Bot removed the Test CUDA Trigger test CUDA workflow label Jun 16, 2026
Comment thread deepmd/jax/descriptor/repflows.py

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@deepmd/dpmodel/descriptor/repflows.py`:
- Around line 280-285: In the warnings.warn() call that checks
self.use_dynamic_sel and self._use_static_dynamic_sel, add the stacklevel=2
parameter to the warnings.warn() function invocation. This parameter should be
added as a keyword argument after the warning message string to ensure the
warning points to the caller's instantiation site rather than the internal line
where the warning is issued.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: 75695fc0-9c46-4dbf-ad41-383d69efbd80

📥 Commits

Reviewing files that changed from the base of the PR and between b338bb1 and 31516c6.

📒 Files selected for processing (1)
  • deepmd/dpmodel/descriptor/repflows.py

Comment thread deepmd/dpmodel/descriptor/repflows.py
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
Signed-off-by: Jinzhe Zeng <njzjz@qq.com>

@iProzd iProzd left a comment

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Two non-blocking suggestions:

  1. The a_sel**2 angle materialization is the main caveat. The warnings.warn helps, but it'd be good to also document in the export docs that JAX export materializes nf * nloc * a_sel**2 angle slots, so a_sel should be kept modest for exportable models.
  2. Patch coverage on network.py is low (~30%) — the new aggregate branches (the num_owner concat path and the average=True assert) aren't exercised. A small unit test targeting aggregate directly would prevent regressions.

LGTM otherwise.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants