Skip to content

feat(gemma4): context parallelism for dense E2B/E4B#2621

Merged
HuiyingLi merged 17 commits into
mainfrom
huiyingl/feat/gemma4-2b-4b-dense-cp
Jun 19, 2026
Merged

feat(gemma4): context parallelism for dense E2B/E4B#2621
HuiyingLi merged 17 commits into
mainfrom
huiyingl/feat/gemma4-2b-4b-dense-cp

Conversation

@HuiyingLi

@HuiyingLi HuiyingLi commented Jun 17, 2026

Copy link
Copy Markdown
Contributor

Summary

Enables context parallelism for the small dense Gemma4 models (E2B/E4B) on top of the model-owned flex-attention ring, plus a correctness fix and a throughput optimization.

  • E2B/E4B dense CP: route the dense path through the model-owned p2p ring CP attention (contiguous sequence shard + flex ring), threading the per-token CP metadata (packing / padding / vision-group ids) into the mask builder.
  • kv-sharing fix under CP (_Gemma4KVShareHolder): E2B/E4B share K/V across their trailing layers, gated by past_key_values is not None. Under use_cache=false (forced by CP) that read falls back to dead K/V projections and inflates loss ~4×. A cache-free holder restores HF's kv-sharing without a per-token cache (gated to kv-sharing variants only, so 31B/MoE are unaffected).
  • Perf: cache the ring's create_block_mask results (168 → ~4 builds/step; all same-type layers + both chunks + fwd/bwd share one entry, keyed on position scalars, reset per batch via a pinned data_ptr generation), and disable flex use_duck_shape so variable-length batches stop recompiling every step. Bit-identical loss; medpix cp2 single-stream ~22 → ~70 tps/gpu.
  • Recipe: gemma4_e4b_tulu3_text_cp16_64k.yaml — the validated 2-node config (dp1×cp16, 64k packed → 4k tokens/rank, ~54 GiB/GPU, step-0 loss ~1). Self-documents all E4B-specific settings.

Note: this branch is stacked on the dense-31B ring CP work (PR #2592); since this PR targets main, the 31B commits (incl. gemma4_31b_tulu3_text_cp8_16k.yaml) appear here until #2592 merges.

Validation runs (W&B)

image image

Test plan

  • tests/unit_tests/models/gemma4/test_cp_attention.py — CPU-testable ring helpers: base/sliding/global masks, ring chunk collection + rotation, flex-chunk packed/padding/vision branches, GQA, and the block-mask cache (build-once/hit, eviction cap, clear-on-new-data_ptr, None-stable). 51 passing.
  • test_gemma4_2b4b_cp.py, test_gemma4_model_cp.py, test_kv_sharing_freeze.py — model-level CP + kv-share paths.
  • Numerical validation is the W&B runs above (integration, not in unit CI — needs multi-GPU).

🤖 Generated with Claude Code

HuiyingLi and others added 8 commits June 16, 2026 06:55
Wire context parallelism into the plain-dense Gemma4 31B (google/gemma-4-31B-it)
using the model-owned p2p ring flex_attention already used by the 26B MoE path.

The ring installs from within the gemma4 model: prepare_model_inputs_for_cp
attaches Gemma4's own _cp_shard_batch as the batch's _cp_make_batch_fn, which
receives the CP submesh from cp_utils.make_cp_batch_and_ctx, idempotently installs
the ring (setup_cp_attention), then shards the batch.

- get_capabilities: plain-dense 31B now supports_cp=True. E2B/E4B (dense+audio)
  left unsupported -- kv-sharing + per-layer-inputs are not yet wired through the
  ring; tracked separately.
- Dense forward stashes CP-sharded vision/packed metadata on the ring-hooked
  self_attn modules so the ring builds the correct bidirectional/packed mask.
- Adds 31B recipe: tulu3 text-only 16k (cp8).

Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
Cover the new dense-CP lines added with this PR:
- cp_attention: _patch_fsdp_accumulated_grad_guard (skip uninitialized param,
  run original when present, idempotent, no-op on import failure) and the
  _pre_hook fallback to module._cp_dense_metadata (plus kwarg precedence).
- model: get_capabilities for plain-dense 31B (cp on), dense+audio E2B/E4B
  (cp off), and MoE (cp+ep); dense __init__ ring attach; model-owned
  setup_cp_attention fan-out + idempotency; _cp_shard_batch install+delegate;
  prepare_model_inputs_for_cp attaching the bound batch fn; dense CP forward
  stashing CP-sharded metadata on the ring-hooked self_attn modules.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
Extend the model-owned p2p ring context-parallelism (built for plain-dense
gemma4 31B) to the two small dense gemma4 checkpoints, google/gemma-4-E2B-it
and google/gemma-4-E4B-it (the dense+audio variant). These differ from 31B in
two ways that were investigated and confirmed to flow through the existing CP
path with no change to the ring or batch sharder:

  * per-layer inputs (hidden_size_per_layer_input=256): computed on the full
    sequence in prepare_model_inputs_for_cp and sharded contiguously on the
    seq dim alongside inputs_embeds (the 4D per_layer_inputs tensor is already
    a known key in cp_batch). Verified the CP-prep inputs_embeds and projected
    per-layer-inputs are bit-identical (MAXDIFF=0) to HF's native prep.

  * KV-sharing (E2B shares K/V across the last 20 layers, E4B the last 18):
    when the KV cache is active a shared layer reads its source layer's
    CP-local sharded K/V from DynamicCache.shared_layers and rotates it through
    the same ring as any other K/V. Under the activation checkpointing that CP
    training requires, HF disables the cache, so shared layers recompute their
    own K/V -- identical between CP and non-CP, so parity holds either way (the
    dead shared-layer K/V projections are frozen by
    freeze_unused_kv_sharing_params).

Only get_capabilities is changed (dense+audio variant flipped to
supports_cp=True) plus an accurate docstring; cp_attention.py, cp_batch.py and
the common infrastructure.py / cp_utils.py are untouched. TP stays OFF for this
variant: HF's Gemma4Model.forward builds per-layer inputs via
torch.where(multimodal_mask, pad_embedding, inputs_embeds) where pad_embedding
is sliced from the (TP-sharded) embedding weight, which raises a mixed
torch.Tensor/DTensor error under DTensor -- an HF-side limitation not fixable
without patching frozen transformers source.

Validation (single node, attn=sdpa, activation_checkpointing=true, bf16,
deterministic text mock, cp1 vs cp2 at matching steps, full layers):
  E2B lr=0:     loss 20.2786 vs 20.2964 (Dloss 1.8e-2); grad_norm 3650 vs 2846
  E2B lr=1e-5:  step1 12.83 vs 12.85, step5 2.86 vs 3.08, step14 0.038 vs 0.041
  E4B lr=0:     loss 18.8539 vs 18.8965 (Dloss 4.3e-2); grad_norm 1178 vs 1168 (<1%)
  E4B lr=1e-5:  step1 8.19 vs 8.18, step5 0.37 vs 0.32, step9 0.021 vs 0.022
No growing drift; deltas stay ~1e-2 at normal lr.

CAVEAT: CP for E2B/E4B is validated only UNDER activation_checkpointing (cache
off -> shared-KV layers recompute their own K/V). The cache-ON
kv-sharing-through-the-ring path is reasoned-correct but NOT empirically
validated. This is acceptable scope since CP training always uses activation
checkpointing, but it is documented rather than hidden.

Adds tests/unit_tests/models/gemma4/test_gemma4_2b4b_cp.py (per_layer_inputs
4D seq-sharding in cp_batch, kv-shared source-layer resolution, capability flip,
per-layer-input threading) and updates the dense+audio capability assertion in
test_gemma4_model_cp.py.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
Add a working E4B (small dense 4B) 64k sequence-length recipe: Tulu-3 text,
sequence-packed to 64k, single 8-GPU node, dp1xcp8 (CP shards 64k -> 8k/rank).
Peak ~32 GiB/GPU; validated 10 steps, loss descends 12.6 -> 4.6, no OOM.

Three E4B-specific settings (each validated, vs the 31B tulu3 recipe):
- dataset.inject_fake_images=false: the VLM pipeline injects a fake image into
  text-only samples to keep the (frozen) vision tower in the FSDP graph; E4B's
  vision patch-embedder then builds a ~54 GiB position one-hot and OOMs. The
  vision tower is frozen, so the fake images are unnecessary -> disable.
- No thinking-prefix hook: E4B is not a reasoning model and cannot predict the
  injected <|channel>thought tokens (plain assistant loss ~1.5 vs blown-up with
  thinking); the 31B is a reasoning model and is the opposite.
- FusedLinearCrossEntropy: avoids the [S, 262k-vocab] logit spike at 64k.

E4B forward verified bit-identical to HuggingFace (logit parity max|diff|=0).
TP unavailable for E2B/E4B (HF per-layer-input merge crashes under DTensor), so
scale with dp_shard x cp only.

Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
E2B/E4B share K/V across their trailing num_kv_shared_layers layers: each shared
layer reads its source layer's K/V from past_key_values.shared_layers (HF
Gemma4Attention.forward). HF gates that read on `past_key_values is not None`, so
with use_cache=False -- which the CP path uses -- the shared layers fall back to
their frozen/unused K/V projections and emit garbage, inflating the loss ~4x
(step-0 ~12 vs ~3). This is why E4B CP runs trained at the wrong loss.

Fix: inject a minimal cache-free _Gemma4KVShareHolder as past_key_values in the
dense forward so HF's own kv-sharing fires (source layers populate shared_layers,
shared layers read it) with no per-token cache accumulation. Gated strictly to
kv-sharing variants (num_kv_shared_layers>0); 31B / MoE / other dense models get
past_key_values=None exactly as before -- no behavior change, no regression to
the flex-ring CP path.

Validated: E4B + use_cache=False + holder reproduces the use_cache=True reference
loss bit-for-bit (3.36 medpix); E4B cp8 (ring) text SFT now starts at the correct
~1-2 loss (was ~12). Note: kv-sharing is incompatible with activation
checkpointing (shared read breaks under recompute), so the E4B recipe runs with
activation_checkpointing=false, capping single-node context at ~32k (cp8, ~54 GiB).

Adds recipe gemma4_e4b_tulu3_text_cp8_32k.yaml; removes the 64k variant (does not
fit without AC on a single node).

Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
…mpiles

The Gemma4 model-owned CP ring rebuilt create_block_mask (~7.6ms each) once
per layer per ring chunk -- 168 builds/step for E4B's 42 layers -- and the
compiled flex_attention recompiled every step on variable sequence lengths
(duck-shape specialization guarding on block_mask.kv_indices.size). On short
sequences these two costs dominated the step (~66% in create_block_mask).

- Disable torch.fx use_duck_shape before compiling flex_attention so variable
  seqlens stop triggering per-step recompiles (no-op for fixed-length runs).
- Cache block masks on position scalars only (never tensor storage -- the CUDA
  allocator recycles addresses), reset per batch via the metadata data_ptr while
  pinning the tensor so the address cannot recycle mid-step. The mask is fully
  position-determined within a step, so all same-type layers, both ring chunks,
  and forward+backward share one entry: 168 -> 4 create_block_mask builds/step.

Loss is bit-identical before/after (single-sample and batched-mask paths). On
medpix cp2 single-stream tps goes 22 -> ~70/gpu; the 64k packed recipe is
unaffected in correctness and benefits from the reduced mask builds.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
The single-node cp8 64k recipe (added in a46ff01, removed in 067b9e7)
OOMs without activation checkpointing, which E4B's kv-sharing forbids. The
validated 64k config shards across 2 nodes (16 GPUs, dp1 x cp16 -> 4k
tokens/rank, the same local footprint as the single-node cp8/32k recipe,
~54 GiB/GPU). Step-0 loss ~1, matching the use_cache=true reference.

Per-recipe slurm launchers are not tracked (repo keeps only the generic
slurm.sub); see the 2-node sbatch in the working tree for invocation.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
Add unit tests for the block-mask cache helpers (_cached_block_mask,
_block_mask_set_generation): build-once-then-hit, eviction above the cap,
clear-only-on-new-data_ptr, and the None generation staying stable. Also
sort the lazy imports in _compiled_flex_attention and add the missing
blank lines flagged by ruff.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
@HuiyingLi HuiyingLi requested a review from a team as a code owner June 17, 2026 12:09
@copy-pr-bot

copy-pr-bot Bot commented Jun 17, 2026

Copy link
Copy Markdown

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@HuiyingLi

Copy link
Copy Markdown
Contributor Author

/ok to test 6febcb5

Remove gemma4_e4b_tulu3_text_cp8_32k.yaml and inline its E4B-settings
rationale into the 2-node cp16/64k recipe (which previously referenced it),
so the 64k recipe is self-documenting. The 64k/2-node recipe is the
validated E4B CP config.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
@HuiyingLi

Copy link
Copy Markdown
Contributor Author

/ok to test 69c180c

…ers ver

L0_Unit_Tests_CPU surfaced two env-specific failures in the dense-CP tests:

- 3 contiguous-shard tests hit `_FakeCPMesh has no get_group` because the full
  suite leaves torch.distributed initialized (test-order pollution), pushing the
  shard code down its distributed branch. Add an autouse fixture pinning
  `is_initialized` False so these single-process tests deterministically take the
  non-distributed branch. Verified by reproducing with a real gloo group.
- test_kv_shared_layers_resolve_source_layer asserted HF's internal
  `kv_shared_layer_index`, which exists in transformers 5.5 but not 5.8 (CI).
  Guard it with getattr; still assert the version-stable `is_kv_shared_layer`.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
@HuiyingLi

Copy link
Copy Markdown
Contributor Author

/ok to test 0340a2f

import_linting failed in setup (`uv sync` network timeout downloading
torch==2.10.0+cu130, UV_HTTP_TIMEOUT 30s) — same transient infra flake that
hit linting earlier and cleared on retry. No-op commit to re-run.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
@HuiyingLi

Copy link
Copy Markdown
Contributor Author

/ok to test eca19f4

Brings in the merged dense-31B CP work (PR #2592, squash-merged). Resolved
conflicts in favor of the E2B/E4B CP enablement this branch adds:
- get_capabilities dense+audio variant -> supports_cp=True (was the 31B's
  "CP not yet supported" placeholder).
- dense forward injects the cache-free _Gemma4KVShareHolder for kv-sharing.
- kept the matching unit tests (capabilities + holder + non-CP forward).

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
@HuiyingLi

Copy link
Copy Markdown
Contributor Author

/ok to test d3437bd

The CP dense-forward holder injection (model.py: past_key_values =
_Gemma4KVShareHolder() under cp_enabled + kv-sharing) was the one new line not
hit by the CPU unit tests. Add test_forward_dense_cp_injects_kv_share_holder
(kv-sharing config + CP forward, mocked language_model) to assert the holder is
threaded through — closing the patch-coverage gap (both changed files now 100%
on new lines).

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
@HuiyingLi

Copy link
Copy Markdown
Contributor Author

/ok to test 07adf3d

1 similar comment
@HuiyingLi

Copy link
Copy Markdown
Contributor Author

/ok to test 07adf3d

The prior L0_Unit_Tests_GPU hung (30-min timeout x3 retries) in the unowned
gemma4_drafter/test_composite_fsdp2.py NCCL mp.spawn test. Not reproducible
locally under CI-matching conditions (transformers 5.8.1, 2 GPUs, coverage,
full gemma4+drafter order: 300 passed) — consistent with a degraded-runner
flake. No-op commit to re-run on a fresh runner.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
@HuiyingLi

Copy link
Copy Markdown
Contributor Author

/ok to test a34cb4d

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants