feat(gemma4): context parallelism for dense E2B/E4B#2621
Merged
Conversation
Wire context parallelism into the plain-dense Gemma4 31B (google/gemma-4-31B-it) using the model-owned p2p ring flex_attention already used by the 26B MoE path. The ring installs from within the gemma4 model: prepare_model_inputs_for_cp attaches Gemma4's own _cp_shard_batch as the batch's _cp_make_batch_fn, which receives the CP submesh from cp_utils.make_cp_batch_and_ctx, idempotently installs the ring (setup_cp_attention), then shards the batch. - get_capabilities: plain-dense 31B now supports_cp=True. E2B/E4B (dense+audio) left unsupported -- kv-sharing + per-layer-inputs are not yet wired through the ring; tracked separately. - Dense forward stashes CP-sharded vision/packed metadata on the ring-hooked self_attn modules so the ring builds the correct bidirectional/packed mask. - Adds 31B recipe: tulu3 text-only 16k (cp8). Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
Cover the new dense-CP lines added with this PR: - cp_attention: _patch_fsdp_accumulated_grad_guard (skip uninitialized param, run original when present, idempotent, no-op on import failure) and the _pre_hook fallback to module._cp_dense_metadata (plus kwarg precedence). - model: get_capabilities for plain-dense 31B (cp on), dense+audio E2B/E4B (cp off), and MoE (cp+ep); dense __init__ ring attach; model-owned setup_cp_attention fan-out + idempotency; _cp_shard_batch install+delegate; prepare_model_inputs_for_cp attaching the bound batch fn; dense CP forward stashing CP-sharded metadata on the ring-hooked self_attn modules. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
Extend the model-owned p2p ring context-parallelism (built for plain-dense
gemma4 31B) to the two small dense gemma4 checkpoints, google/gemma-4-E2B-it
and google/gemma-4-E4B-it (the dense+audio variant). These differ from 31B in
two ways that were investigated and confirmed to flow through the existing CP
path with no change to the ring or batch sharder:
* per-layer inputs (hidden_size_per_layer_input=256): computed on the full
sequence in prepare_model_inputs_for_cp and sharded contiguously on the
seq dim alongside inputs_embeds (the 4D per_layer_inputs tensor is already
a known key in cp_batch). Verified the CP-prep inputs_embeds and projected
per-layer-inputs are bit-identical (MAXDIFF=0) to HF's native prep.
* KV-sharing (E2B shares K/V across the last 20 layers, E4B the last 18):
when the KV cache is active a shared layer reads its source layer's
CP-local sharded K/V from DynamicCache.shared_layers and rotates it through
the same ring as any other K/V. Under the activation checkpointing that CP
training requires, HF disables the cache, so shared layers recompute their
own K/V -- identical between CP and non-CP, so parity holds either way (the
dead shared-layer K/V projections are frozen by
freeze_unused_kv_sharing_params).
Only get_capabilities is changed (dense+audio variant flipped to
supports_cp=True) plus an accurate docstring; cp_attention.py, cp_batch.py and
the common infrastructure.py / cp_utils.py are untouched. TP stays OFF for this
variant: HF's Gemma4Model.forward builds per-layer inputs via
torch.where(multimodal_mask, pad_embedding, inputs_embeds) where pad_embedding
is sliced from the (TP-sharded) embedding weight, which raises a mixed
torch.Tensor/DTensor error under DTensor -- an HF-side limitation not fixable
without patching frozen transformers source.
Validation (single node, attn=sdpa, activation_checkpointing=true, bf16,
deterministic text mock, cp1 vs cp2 at matching steps, full layers):
E2B lr=0: loss 20.2786 vs 20.2964 (Dloss 1.8e-2); grad_norm 3650 vs 2846
E2B lr=1e-5: step1 12.83 vs 12.85, step5 2.86 vs 3.08, step14 0.038 vs 0.041
E4B lr=0: loss 18.8539 vs 18.8965 (Dloss 4.3e-2); grad_norm 1178 vs 1168 (<1%)
E4B lr=1e-5: step1 8.19 vs 8.18, step5 0.37 vs 0.32, step9 0.021 vs 0.022
No growing drift; deltas stay ~1e-2 at normal lr.
CAVEAT: CP for E2B/E4B is validated only UNDER activation_checkpointing (cache
off -> shared-KV layers recompute their own K/V). The cache-ON
kv-sharing-through-the-ring path is reasoned-correct but NOT empirically
validated. This is acceptable scope since CP training always uses activation
checkpointing, but it is documented rather than hidden.
Adds tests/unit_tests/models/gemma4/test_gemma4_2b4b_cp.py (per_layer_inputs
4D seq-sharding in cp_batch, kv-shared source-layer resolution, capability flip,
per-layer-input threading) and updates the dense+audio capability assertion in
test_gemma4_model_cp.py.
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
Add a working E4B (small dense 4B) 64k sequence-length recipe: Tulu-3 text, sequence-packed to 64k, single 8-GPU node, dp1xcp8 (CP shards 64k -> 8k/rank). Peak ~32 GiB/GPU; validated 10 steps, loss descends 12.6 -> 4.6, no OOM. Three E4B-specific settings (each validated, vs the 31B tulu3 recipe): - dataset.inject_fake_images=false: the VLM pipeline injects a fake image into text-only samples to keep the (frozen) vision tower in the FSDP graph; E4B's vision patch-embedder then builds a ~54 GiB position one-hot and OOMs. The vision tower is frozen, so the fake images are unnecessary -> disable. - No thinking-prefix hook: E4B is not a reasoning model and cannot predict the injected <|channel>thought tokens (plain assistant loss ~1.5 vs blown-up with thinking); the 31B is a reasoning model and is the opposite. - FusedLinearCrossEntropy: avoids the [S, 262k-vocab] logit spike at 64k. E4B forward verified bit-identical to HuggingFace (logit parity max|diff|=0). TP unavailable for E2B/E4B (HF per-layer-input merge crashes under DTensor), so scale with dp_shard x cp only. Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
E2B/E4B share K/V across their trailing num_kv_shared_layers layers: each shared layer reads its source layer's K/V from past_key_values.shared_layers (HF Gemma4Attention.forward). HF gates that read on `past_key_values is not None`, so with use_cache=False -- which the CP path uses -- the shared layers fall back to their frozen/unused K/V projections and emit garbage, inflating the loss ~4x (step-0 ~12 vs ~3). This is why E4B CP runs trained at the wrong loss. Fix: inject a minimal cache-free _Gemma4KVShareHolder as past_key_values in the dense forward so HF's own kv-sharing fires (source layers populate shared_layers, shared layers read it) with no per-token cache accumulation. Gated strictly to kv-sharing variants (num_kv_shared_layers>0); 31B / MoE / other dense models get past_key_values=None exactly as before -- no behavior change, no regression to the flex-ring CP path. Validated: E4B + use_cache=False + holder reproduces the use_cache=True reference loss bit-for-bit (3.36 medpix); E4B cp8 (ring) text SFT now starts at the correct ~1-2 loss (was ~12). Note: kv-sharing is incompatible with activation checkpointing (shared read breaks under recompute), so the E4B recipe runs with activation_checkpointing=false, capping single-node context at ~32k (cp8, ~54 GiB). Adds recipe gemma4_e4b_tulu3_text_cp8_32k.yaml; removes the 64k variant (does not fit without AC on a single node). Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
…mpiles The Gemma4 model-owned CP ring rebuilt create_block_mask (~7.6ms each) once per layer per ring chunk -- 168 builds/step for E4B's 42 layers -- and the compiled flex_attention recompiled every step on variable sequence lengths (duck-shape specialization guarding on block_mask.kv_indices.size). On short sequences these two costs dominated the step (~66% in create_block_mask). - Disable torch.fx use_duck_shape before compiling flex_attention so variable seqlens stop triggering per-step recompiles (no-op for fixed-length runs). - Cache block masks on position scalars only (never tensor storage -- the CUDA allocator recycles addresses), reset per batch via the metadata data_ptr while pinning the tensor so the address cannot recycle mid-step. The mask is fully position-determined within a step, so all same-type layers, both ring chunks, and forward+backward share one entry: 168 -> 4 create_block_mask builds/step. Loss is bit-identical before/after (single-sample and batched-mask paths). On medpix cp2 single-stream tps goes 22 -> ~70/gpu; the 64k packed recipe is unaffected in correctness and benefits from the reduced mask builds. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com> Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
The single-node cp8 64k recipe (added in a46ff01, removed in 067b9e7) OOMs without activation checkpointing, which E4B's kv-sharing forbids. The validated 64k config shards across 2 nodes (16 GPUs, dp1 x cp16 -> 4k tokens/rank, the same local footprint as the single-node cp8/32k recipe, ~54 GiB/GPU). Step-0 loss ~1, matching the use_cache=true reference. Per-recipe slurm launchers are not tracked (repo keeps only the generic slurm.sub); see the 2-node sbatch in the working tree for invocation. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com> Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
Add unit tests for the block-mask cache helpers (_cached_block_mask, _block_mask_set_generation): build-once-then-hit, eviction above the cap, clear-only-on-new-data_ptr, and the None generation staying stable. Also sort the lazy imports in _compiled_flex_attention and add the missing blank lines flagged by ruff. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com> Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
Contributor
Author
|
/ok to test 6febcb5 |
Remove gemma4_e4b_tulu3_text_cp8_32k.yaml and inline its E4B-settings rationale into the 2-node cp16/64k recipe (which previously referenced it), so the 64k recipe is self-documenting. The 64k/2-node recipe is the validated E4B CP config. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com> Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
Contributor
Author
|
/ok to test 69c180c |
…ers ver L0_Unit_Tests_CPU surfaced two env-specific failures in the dense-CP tests: - 3 contiguous-shard tests hit `_FakeCPMesh has no get_group` because the full suite leaves torch.distributed initialized (test-order pollution), pushing the shard code down its distributed branch. Add an autouse fixture pinning `is_initialized` False so these single-process tests deterministically take the non-distributed branch. Verified by reproducing with a real gloo group. - test_kv_shared_layers_resolve_source_layer asserted HF's internal `kv_shared_layer_index`, which exists in transformers 5.5 but not 5.8 (CI). Guard it with getattr; still assert the version-stable `is_kv_shared_layer`. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com> Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
Contributor
Author
|
/ok to test 0340a2f |
import_linting failed in setup (`uv sync` network timeout downloading torch==2.10.0+cu130, UV_HTTP_TIMEOUT 30s) — same transient infra flake that hit linting earlier and cleared on retry. No-op commit to re-run. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com> Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
Contributor
Author
|
/ok to test eca19f4 |
Brings in the merged dense-31B CP work (PR #2592, squash-merged). Resolved conflicts in favor of the E2B/E4B CP enablement this branch adds: - get_capabilities dense+audio variant -> supports_cp=True (was the 31B's "CP not yet supported" placeholder). - dense forward injects the cache-free _Gemma4KVShareHolder for kv-sharing. - kept the matching unit tests (capabilities + holder + non-CP forward). Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com> Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
Contributor
Author
|
/ok to test d3437bd |
The CP dense-forward holder injection (model.py: past_key_values = _Gemma4KVShareHolder() under cp_enabled + kv-sharing) was the one new line not hit by the CPU unit tests. Add test_forward_dense_cp_injects_kv_share_holder (kv-sharing config + CP forward, mocked language_model) to assert the holder is threaded through — closing the patch-coverage gap (both changed files now 100% on new lines). Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com> Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
Contributor
Author
|
/ok to test 07adf3d |
1 similar comment
Contributor
Author
|
/ok to test 07adf3d |
The prior L0_Unit_Tests_GPU hung (30-min timeout x3 retries) in the unowned gemma4_drafter/test_composite_fsdp2.py NCCL mp.spawn test. Not reproducible locally under CI-matching conditions (transformers 5.8.1, 2 GPUs, coverage, full gemma4+drafter order: 300 passed) — consistent with a degraded-runner flake. No-op commit to re-run on a fresh runner. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com> Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
Contributor
Author
|
/ok to test a34cb4d |
akoumpa
approved these changes
Jun 19, 2026
3 tasks
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Enables context parallelism for the small dense Gemma4 models (E2B/E4B) on top of the model-owned flex-attention ring, plus a correctness fix and a throughput optimization.
_Gemma4KVShareHolder): E2B/E4B share K/V across their trailing layers, gated bypast_key_values is not None. Underuse_cache=false(forced by CP) that read falls back to dead K/V projections and inflates loss ~4×. A cache-free holder restores HF's kv-sharing without a per-token cache (gated to kv-sharing variants only, so 31B/MoE are unaffected).create_block_maskresults (168 → ~4 builds/step; all same-type layers + both chunks + fwd/bwd share one entry, keyed on position scalars, reset per batch via a pinneddata_ptrgeneration), and disable flexuse_duck_shapeso variable-length batches stop recompiling every step. Bit-identical loss; medpix cp2 single-stream ~22 → ~70 tps/gpu.gemma4_e4b_tulu3_text_cp16_64k.yaml— the validated 2-node config (dp1×cp16, 64k packed → 4k tokens/rank, ~54 GiB/GPU, step-0 loss ~1). Self-documents all E4B-specific settings.Note: this branch is stacked on the dense-31B ring CP work (PR #2592); since this PR targets
main, the 31B commits (incl.gemma4_31b_tulu3_text_cp8_16k.yaml) appear here until #2592 merges.Validation runs (W&B)
gemma4_e4b_tulu3_text_cp16_64k.yaml, 100 steps, loss ~1.02, matching theuse_cache=truereference: https://wandb.ai/Nemo-automodel/huiyingl_workspace/runs/0ubhxrs6Test plan
tests/unit_tests/models/gemma4/test_cp_attention.py— CPU-testable ring helpers: base/sliding/global masks, ring chunk collection + rotation, flex-chunk packed/padding/vision branches, GQA, and the block-mask cache (build-once/hit, eviction cap, clear-on-new-data_ptr, None-stable). 51 passing.test_gemma4_2b4b_cp.py,test_gemma4_model_cp.py,test_kv_sharing_freeze.py— model-level CP + kv-share paths.🤖 Generated with Claude Code