Skip to content

feat(xtoken): add multi-teacher support to cross-tokenizer off-policy distillation#2797

Open
avenkateshha wants to merge 18 commits into
mainfrom
avenkateshha/xtoken-multi-teacher
Open

feat(xtoken): add multi-teacher support to cross-tokenizer off-policy distillation#2797
avenkateshha wants to merge 18 commits into
mainfrom
avenkateshha/xtoken-multi-teacher

Conversation

@avenkateshha

Copy link
Copy Markdown
Contributor

Builds on #2757 (CE + H-KL gold loss). Extends cross-tokenizer off-policy distillation to support multiple teachers via a teachers: list of TeacherConfig.

What

  • Same-vocab teachers (projection_matrix_path=null) use direct per-position KL with no alignment; cross-tokenizer teachers keep the projection + alignment path.
  • Per-teacher indexed data-dict keys and serial teacher forward passes.
  • kd_loss_mode aggregation knob: sum / averaged_logits / select_teacher.
  • New examples/configs/multi_teacher_cross_tokenizer_smoke.yaml exemplar (Qwen3-4B cross-tokenizer teacher + Llama-3.2-3B same-vocab teacher).
  • Updated unit tests in tests/unit/algorithms/x_token/test_xtoken_off_policy_distillation.py.

🤖 Generated with Claude Code

… distillation

Extend cross-tokenizer off-policy distillation to support multiple teachers
via a `teachers:` list of TeacherConfig. Same-vocab teachers
(projection_matrix_path=null) use direct per-position KL with no alignment;
cross-tokenizer teachers keep the projection + alignment path. Adds
per-teacher indexed data-dict keys, serial teacher forward passes, and a
kd_loss_mode (sum / averaged_logits / select_teacher) aggregation knob.

Includes a multi_teacher_cross_tokenizer_smoke.yaml exemplar (Qwen3-4B
cross-tokenizer + Llama-3.2-3B same-vocab teacher) and updated unit tests.

Signed-off-by: Adithya Hanasoge <avenkateshha@nvidia.com>
@avenkateshha avenkateshha requested review from a team as code owners June 12, 2026 07:22
@copy-pr-bot

copy-pr-bot Bot commented Jun 12, 2026

Copy link
Copy Markdown

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

- Shorten the H1 to 'Cross-Tokenizer (X-Token)'.
- Trim the future-work note to a generic 'actively improving support'.
- Reframe Step 2 around tokenizer overlap (similar algorithms such as BPE).
- Describe Step 3's sparse [V_student, top_k] projection representation and
  why it avoids a dense [student_vocab, teacher_vocab] matrix.
- Remove the --preserve_last paragraph; the recommended recipe disables the
  scale trick, so it never engages.
- Drop the 'via CUDA IPC' qualifier from the P-KL loss-mode row.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: Adithya Hanasoge <avenkateshha@nvidia.com>
@avenkateshha avenkateshha requested a review from a team as a code owner June 14, 2026 01:28
@github-actions github-actions Bot added the Documentation Improvements or additions to documentation label Jun 14, 2026
- Drop 'Off-Policy' from the X-Token section/subsection headings and the
  feature-list entry; update the matching table-of-contents anchors.
- Reword the section's first line to 'distillation' (no 'off-policy').
- Add a link to the x-token distillation paper alongside the implementation
  guide in the 'read about the details' line.

Real file paths (run_xtoken_off_policy_distillation.py,
xtoken_off_policy_distillation.yaml, the guide filename) are left unchanged.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: Adithya Hanasoge <avenkateshha@nvidia.com>
@avenkateshha avenkateshha requested a review from a team as a code owner June 14, 2026 01:48
avenkateshha and others added 4 commits June 13, 2026 20:18
Replace the single-teacher P-KL results with the multi-teacher run
(Llama-3.2-1B student <- Phi-4-mini-instruct + Llama-3.2-3B teachers):

- Regenerate the training-curve figure from the multi-teacher run and swap
  the asset (xtoken_pkl_smoke_curves.png -> xtoken_mt_curves.png).
- Rewrite the config sentence and loss bullets; note that the cross-tokenizer
  Phi-4-mini teacher carries the KL while the same-tokenizer Llama-3.2-3B
  teacher's direct KL stays small.
- Add a Downstream evaluation table (distilled student vs. base Llama-3.2-1B).
- Refresh the throughput table (mean step time 6.65 s, teacher forward
  dominates, ~29.5k tok/s); drop the stale peak-memory row.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: Adithya Hanasoge <avenkateshha@nvidia.com>
Same-vocab teachers previously shipped top-k logits via the driver path (get_topk_logits), which gathers dense [B, T, k] value/index tensors to the driver and re-shards them to the student every step (~16-19 GB/step for k=8192), inflating policy_training/submit_training_futures. Cross-tokenizer teachers already used zero-copy CUDA IPC.

Add a single producer, get_teacher_logits_ipc(mode='full'|'topk'), that handles both transports over persistent CUDA IPC buffers (1 buffer for full logits; 2 for top-k values+indices), generalizing the existing single-buffer machinery. Top-k values ship in native compute dtype (bf16) and indices in int32 (lossless for <2^31 vocab; cast to int64 at the gather). get_full_logits_ipc now delegates to mode='full'; same-vocab top-k teachers ship a teacher_{i}_topk_ipc handle list instead of dense tensors. Numerically bf16-exact vs the prior driver path.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Signed-off-by: Adithya Hanasoge <avenkateshha@nvidia.com>
Compare the MOPD (on-policy, Megatron) and xToken (off-policy, DTensor V2)
distillation recipes across multi-teacher, async, policy, loss, tokenizer,
and backend, in a new Distillation overview preceding both recipe sections.

Signed-off-by: Adithya Hanasoge <avenkateshha@nvidia.com>
Add a paragraph and figure on pairwise tokenizer vocabulary overlap
(intersection over min vocab size) to motivate why the cross-tokenizer
projection is necessary.

Signed-off-by: Adithya Hanasoge <avenkateshha@nvidia.com>
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>

@yuki-97 yuki-97 left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@avenkateshha thanks for supporting this, nice work! reviewed partly and left some comments.

reduction="mean",
)
ces.append(ce_i.item())
best = int(min(range(self.num_teachers), key=lambda j: ces[j]))

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ces.append(ce_i.item()) + min is on rank-local CE — DP ranks can pick different best teachers. _compute_teacher_kd(best, ...) then runs DP all-reduce collectives whose ordering must match across ranks (per the comment at line 1829), so divergent participation deadlocks or silently produces inconsistent reductions. Same issue at _compute_dynamic_weights:1988: softmax(alpha * stack(rank-local-scores)) gives each rank different weights. All-reduce the per-teacher scalar score before the min / softmax.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed in f9fc467. Added _dp_global_masked_mean, which all-reduces the masked score sum and the valid count across the DP group, and used it for the select-teacher CE and the weight-metric scores. Every rank now computes the identical global score, so min (select) and softmax (dynamic weights) agree across ranks — no divergent participation in the per-teacher dp_all_reduce_sum. The score is detached (it gates selection/weighting, not back-propagated, matching the reference'''s no_grad scoring). Verified with a 2-rank gloo run: ranks disagree on the local mean but agree on the global mean.

Comment thread nemo_rl/algorithms/loss/loss_functions.py
Comment thread nemo_rl/algorithms/loss/loss_functions.py Outdated
Comment thread docs/guides/xtoken-off-policy-distillation.md Outdated
Comment thread docs/guides/xtoken-off-policy-distillation.md
Comment thread examples/configs/multi_teacher_cross_tokenizer_smoke.yaml Outdated
checkpointing: CheckpointingConfig


def normalize_multi_teacher_config(raw_config: dict[str, Any]) -> dict[str, Any]:

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we can remove this function and always use teachers instead of teacher.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done in 5f33a30. Removed normalize_multi_teacher_config and its call in the entrypoint; configs now must use teachers: (one teacher is just a 1-element list). Also dropped the loss constructor's single-teacher fallback branch so it always reads the injected per-teacher lists.

Comment thread examples/configs/xtoken_off_policy_distillation.yaml
Comment on lines +1353 to +1360
# Per-teacher metadata injected at runtime by
# ``xtoken_off_policy_distillation.setup`` (parallel to ``teachers``).
teacher_vocab_sizes: NotRequired[list[int]]
projection_matrix_paths: NotRequired[list[Optional[str]]]
teacher_weights: NotRequired[list[float]]
teacher_send_full_logits: NotRequired[list[bool]]
teacher_gold_loss: NotRequired[list[Optional[bool]]]
teacher_xtoken_loss: NotRequired[list[Optional[bool]]]

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

am I understanding correctly that we will always use these even for single teacher (I think it's great)? if so,

  1. can we remove the legacy params like teacher_vocab_size?
  2. is it sometimes NotRequired, if not can we remove NotRequired?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done in 5f33a30. 1) Removed the legacy singular fields teacher_vocab_size and projection_matrix_path from CrossTokenizerDistillationLossConfig (and the constructor branch that read them) — the per-teacher teacher_vocab_sizes / projection_matrix_paths lists carry these now. 2) On NotRequired: I kept it on the per-teacher lists. They're injected at runtime by setup and are absent in the raw YAML, so the same TypedDict describes both the pre- and post-injection dict; marking them required would be inaccurate at config-load time. So those stay NotRequired by design rather than being removed.

Comment on lines +1370 to +1380
- Cross-tokenizer teacher: ``teacher_{i}_full_logits_ipc`` (List[B] of
CUDA IPC handle dicts from ``Policy.get_full_logits_ipc``),
``teacher_{i}_input_ids`` / ``teacher_{i}_token_mask`` ``[B, T_t]``, and
``alignment_{i}_*`` (``pair_valid`` / ``pair_is_correct``
``[B, max_pairs]``; ``student_exact_partition_mask`` /
``student_chunk_id`` ``[B, T_s]``; ``teacher_exact_partition_mask`` /
``teacher_chunk_id`` ``[B, T_t]``; ``num_chunks`` ``[B]``).
- Same-vocab teacher: either ``teacher_{i}_full_logits_ipc`` (when
``send_full_logits``) or ``teacher_{i}_topk_logits`` /
``teacher_{i}_topk_indices`` ``[B, T_s, k]``; no ``alignment_{i}_*``
(it reuses the student tokenization, identity-aligned).

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we use list or dict like alignment_pair_valid: list[torch.Tensor] and still list things at CrossTokenizerDistillationLossDataDict?

for maybe full_logtis or topk_logits, we can mark them Optional but still put them both under CrossTokenizerDistillationLossDataDict.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks — thought about this carefully and I would keep the per-teacher indexed keys rather than convert them to list[torch.Tensor] fields, for a concrete reason in the data path:

train_data is split into microbatches by BatchedDataDict (shared across all algorithms), whose slicer treats every value's first axis as the batch (example) axisslice(start, end) does value[start:end] for every key. With per-teacher separate keys, each value is a [B, ...] tensor, so the example-slice is correct. If a field were a list indexed by teacher ([t0_tensor, t1_tensor]), the slicer would slice the teacher axis (list[start:end]), not the example axis — so microbatches would be wrong. Two more blockers: (1) a list value already means per-example here (e.g. teacher_{i}_full_logits_ipc is a length-B list of per-sample IPC handles, sliced correctly by example), so a per-teacher list would be ambiguous to the same machinery; (2) cross-tokenizer teachers are ragged — different T_t and max_pairs per teacher — so they can't be stacked into one [B, num_teachers, ...] tensor either. Supporting per-teacher lists would mean changing the shared BatchedDataDict slicing/sharding for a cosmetic schema gain.

The typed view you're after does exist where it's consumed: the loss rehydrates each teacher's flat keys into a typed AlignmentBatch via alignment_from_flat_batch(data, f"alignment_{i}_"). And I corrected the CrossTokenizerDistillationLossDataDict docstring in 41ee5a9 so it accurately enumerates the per-teacher keys + shapes and notes the two conditionally-present transports (full_logits_ipc vs topk_ipc) — i.e. the listing this schema can't express as fields, kept accurate in the docstring instead. Let me know if you'd still prefer the list form despite the slicing constraint.

"""
device = logits.device
if self.sum_weights_metric is not None:
weights = self._compute_dynamic_weights(data, teacher_full_logits_by_idx)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[BUG] Dynamic weights are computed per-DP-rank, so the loss is not the same function across ranks.

_teacher_weight_score uses rank-local cross_entropy(reduction="mean") / entropy / max-prob, then softmaxes — each DP rank gets different weights[i] for the same teacher. Same root cause as the previous comment on _select_teacher_kd:1922.

Suggest fix: Aggregate the score across DP before the softmax.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed in f9fc467 (same change as the select-teacher thread). _teacher_weight_score now reduces ce/entropy/max_prob via _dp_global_masked_mean (all-reduce of masked-sum and valid-count across DP), so the softmax sees the same global scores on every rank and weights[i] is identical across ranks. The loss is now the same function on all ranks.

Comment thread nemo_rl/algorithms/loss/loss_functions.py
assert total_kd is not None
return total_kd, per_metrics

total_w = sum(self.teacher_weights)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[BUG] Fast path silently ignores sum_weights_metric and always uses static teacher_weights[i].

The docstring only warns about this on the fallback branch, so kd_loss_mode="averaged_logits" + sum_weights_metric="ce" quietly drops dynamic weighting.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed in d152377 (chose fail-loud over honoring it, to keep reference parity). The PT reference (train_distillation_ddp.py) does not apply dynamic weights in averaged_logits — it sets dynamic_weights = None and blends with the static config weight (lines 1589-1624, token_weights=None), and --sum_weights_metric is documented as a sum-mode-only knob (lines 109-110). select_teacher likewise sets dynamic_weights = None. So honoring the metric here would diverge from the reference; instead the loss fn now raises at construction if sum_weights_metric is set with any mode other than sum, so it can no longer be silently dropped. Added a test covering both averaged_logits and select_teacher.

"kl_loss_weight": 1.0,
"ce_loss_scale": 1.0,
"dynamic_loss_scaling": False,
"kd_loss_mode": "sum",

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No unit coverage for averaged_logits or select_teacher — every test sets kd_loss_mode="sum".

Those two branches are exactly where the prev. comments on 1922 / 1919 / 1871 and the new loss-side comments live, so CI won't catch any of those regressions.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Partially addressed in 842c095 — added two tests that exercise the averaged_logits guard directly against the real loss fn (the cross-tokenizer-shape bug above), built via __new__ since this module otherwise mocks the loss fn. Broader mode coverage (select_teacher, the DP-aggregation issue) is still pending and I'll follow up separately.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Closed out in aa79d97. Added a parametrized end-to-end test asserting select_teacher picks the lowest-CE teacher (two same-vocab teachers: one peaks on the true next token, the other is uniform; parametrized over which index is better). Together with the averaged_logits guard tests (M1), the masked ce/entropy/max_prob scoring tests (M2), the sum_weights_metric-rejection tests (M3) and the DP-global-score helper test (H1), both previously-uncovered modes now have coverage.

@@ -492,68 +538,218 @@ def test_ipc_buffer_released_on_student_train_failure(mock_xtoken_components):

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

test_ipc_buffer_released_on_student_train_failure (~lines 527-537, unchanged by this PR) wouldn't catch a multi-teacher regression.

It's single-teacher only and asserts release_ipc_buffer.call_count >= 1 — releasing only teacher_policies[0] (a moved finally, an early break) still passes.

Suggest fix: Parametrize on num_teachers ∈ {1, 2} and assert call_count == len(teacher_policies).

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done in 1b124e7. Parametrized on num_teachers in {1, 2} and now assert release_ipc_buffer is called exactly once per teacher (== len(teacher_policies)) instead of >= 1, so a release-only-teacher-0 bug would now fail. (Used same-vocab teachers so the release-loop check doesn't depend on per-teacher batch keys.)

t1.get_teacher_logits_ipc.assert_called_once()
assert t1.get_teacher_logits_ipc.call_args.kwargs["mode"] == "topk"
assert t1.get_teacher_logits_ipc.call_args.kwargs["k"] == 8
# Serial collocated execution: each teacher onloaded then offloaded.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"Serial" is in the test name but isn't actually verified.

The test only checks that each per-teacher mock method is called once, not that t0.offload happens BEFORE t1.prepare — which is the whole point of serial collocation.

Suggest fix: Attach the per-teacher mocks to a shared MagicMock parent and assert parent.mock_calls shows the interleaved order.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done in 1b124e7. Attached both teacher mocks to a shared parent and now assert the interleaved order — teacher 0's offload_after_refit happens before teacher 1's prepare_for_lp_inference — so loading both teachers at once (the failure the 'serial' name guards against) would now fail. Kept the per-teacher called-once checks too.

Comment thread nemo_rl/algorithms/xtoken_off_policy_distillation.py
Comment thread nemo_rl/algorithms/xtoken_off_policy_distillation.py Outdated
"""
distill_cfg = master_config.distillation
timer = timer if timer is not None else Timer()
skip_keys = xtoken_non_student_seq_keys(loss_fn)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

skip_keys = xtoken_non_student_seq_keys(loss_fn) is rebuilt here on every validate call, parallel to the train loop's build at line 624.

It's purely a function of loss_fn.{num_teachers, projection_matrix_paths, teacher_ships_full}, so duplicating it lets the two paths drift.

Suggest fix: Cache on loss_fn (or return from setup) and thread through.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done in 00358ad. xtoken_non_student_seq_keys(loss_fn) is now built once in the train loop and passed into validate as a skip_keys arg; removed validate's own rebuild. It is a pure function of the loss fn's static per-teacher metadata, so both sites always produced the identical set — this is behavior-preserving (verified the validate + train-loop tests still pass), just removes the duplicate build and the drift risk. (Kept the build in the train loop rather than setup: the setup unit tests patch the loss-fn class to an unconfigured mock, so building there would fail on range(mock.num_teachers).)

@@ -1,8 +1,12 @@
# Single-teacher cross-tokenizer off-policy distillation.
# Multi-teacher cross-tokenizer off-policy distillation.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could you add a nightly test to guard multi-teacher?

you can refer to the three file in #2745.

  1. examples/configs/recipes/llm/distillation-xtoken-off-policy-qwen3-4b-to-llama3.2-1b-1n8g-dtensor-tp4cp2.yaml
  2. tests/test_suites/llm/distillation-xtoken-off-policy-qwen3-4b-to-llama3.2-1b-1n8g-dtensor-tp4cp2.sh
  3. tests/test_suites/nightly.txt

@yuki-97

yuki-97 commented Jun 17, 2026

Copy link
Copy Markdown
Contributor

@terrykong could you take a review as well?

@yuki-97 yuki-97 requested a review from terrykong June 17, 2026 13:47
avenkateshha and others added 10 commits June 17, 2026 14:33
- Rename module docstring to multi-teacher (was single-teacher).
- Add "Off-Policy" to the X-Token guide title and README section (and
  fix the README TOC anchor) so the off-policy nature is explicit.
- Set `teacher: null` in the exemplar (inherited from distillation_math,
  unused by xtoken) with a clarifying comment.
- Drop the no-op `packed.to("cpu")` and its stale comment: train_data
  holds only IPC handle dicts and already-CPU dataloader tensors.
- Remove the redundant multi_teacher_cross_tokenizer_smoke.yaml exemplar.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: Adithya Hanasoge <avenkateshha@nvidia.com>
The final box's right vertical border sat 2 columns past its top and
bottom horizontal borders, so the box rendered crooked. Trim the 2
trailing pad spaces on each content line so the right edge lines up at
the same column as the box corners.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: Adithya Hanasoge <avenkateshha@nvidia.com>
…ct-KL

The averaged_logits fast path took a single direct per-position KL whenever
all teachers' full-logit tensors shared a shape. Two cross-tokenizer teachers
can share a shape while still needing the projection/alignment path, so they
slipped into the fast path and mismatched the student's token_mask when the
teacher sequence length differs (T_t != T_s). Gate the fast path on every
projection_matrix_path being None (genuine same-tokenizer teachers) in
addition to the shape check.

Add regression tests for both branches of the guard in the tracked test
module so CI exercises them.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: Adithya Hanasoge <avenkateshha@nvidia.com>
select_teacher and dynamic-weight (sum_weights_metric) scoring averaged the
per-teacher ce/entropy/max_prob over every position with reduction="mean".
Padded positions carry near-uniform logits that dominate the score on
long-padded batches and skew teacher selection / weights. Mask each score by
the teacher's own token_mask (the student's for a same-vocab teacher) and the
sample_mask before reducing, matching the main loss; thread the per-teacher
token mask through _teacher_score_inputs.

Add regression tests asserting padded positions and dropped samples do not
move the score.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: Adithya Hanasoge <avenkateshha@nvidia.com>
Dynamic teacher weighting (sum_weights_metric) is only applied in
kd_loss_mode="sum"; averaged_logits and select_teacher use static per-teacher
weights, matching the reference. Combining sum_weights_metric with either
mode silently dropped the metric. Reject the combo at loss-fn construction
instead, alongside the existing xtoken/gold guard.

Add a test asserting the combo raises for both averaged_logits and
select_teacher.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: Adithya Hanasoge <avenkateshha@nvidia.com>
select_teacher and dynamic-weight (sum_weights_metric) scoring reduced each
teacher's metric over the rank-local shard only, so DP ranks could pick a
different teacher / different weights. The per-teacher KD then fires a
dp_all_reduce_sum collective, so divergent choices across ranks deadlock (a
cross-tokenizer teacher fires the collective, a same-vocab one does not) or
silently mix reductions.

Add _dp_global_masked_mean, which all-reduces the masked score sum and the
valid count across the DP group, and use it for the select-teacher CE and the
weight-metric (ce/entropy/max_prob) scores. Every rank now computes the
identical global score, so all ranks select the same teacher / same weights
and hit the same collective sequence. The score is detached (it gates
selection/weighting and is not back-propagated, matching the reference's
no_grad scoring).

Verified cross-rank agreement with a 2-rank gloo run (ranks disagree on the
local mean, agree on the global mean). Add a single-process unit test for the
helper. No effect on the single-teacher sum/null path used for gold parity.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: Adithya Hanasoge <avenkateshha@nvidia.com>
Multi-teacher is now the only model (one teacher is a 1-element `teachers:`
list), so the back-compat single-`teacher:` path is dead weight.

- Remove `normalize_multi_teacher_config` (the legacy `teacher:` -> `teachers:`
  fold) and its call in the entrypoint; configs must use `teachers:`.
- Remove the loss constructor's single-teacher fallback branch; it now always
  reads the injected per-teacher lists.
- Drop the now-unused legacy fields `teacher_vocab_size` and
  `projection_matrix_path` from CrossTokenizerDistillationLossConfig. The
  per-teacher `teacher_vocab_sizes` / `projection_matrix_paths` lists (and
  `TeacherConfig.projection_matrix_path`) carry these; the `NotRequired`
  markers stay on the runtime-injected lists since they are absent in raw YAML.
- Migrate the `_ct_loss_cfg` test helper to the plural per-teacher form and
  remove the now-obsolete shim unit test.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: Adithya Hanasoge <avenkateshha@nvidia.com>
`xtoken_non_student_seq_keys(loss_fn)` was rebuilt inside `validate` on every
validation pass, parallel to the train loop's build, so the two copies could
drift. It is a pure function of the loss fn's static per-teacher metadata, so
both always produce the identical set. Build it once in the train loop and
pass it into `validate` as a `skip_keys` argument; drop validate's rebuild.
No behavior change.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: Adithya Hanasoge <avenkateshha@nvidia.com>
- IPC release: parametrize on num_teachers in {1, 2} and assert
  release_ipc_buffer is called exactly once per teacher. The old test was
  single-teacher and asserted call_count >= 1, so a bug releasing only
  teacher 0 (misplaced finally / early break) would still pass while leaking
  the rest.
- Serial collocation: attach both teacher mocks to a shared parent and assert
  teacher 0 is offloaded before teacher 1 is onloaded. The test was named
  "runs_serially" but only checked each load/offload was called once, never
  the order, so loading both teachers at once would have passed.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: Adithya Hanasoge <avenkateshha@nvidia.com>
Asserts select_teacher uses the teacher with the lowest next-token CE: two
same-vocab teachers, one peaking on the true next token (CE ~ 0) and one
uniform (CE = log V), parametrized over which index is better so it can't
pass by always returning 0. Completes the averaged_logits / select_teacher
coverage gap (the guard + scoring branches were already covered by the
M1/M2/M3/H1 fixes).

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: Adithya Hanasoge <avenkateshha@nvidia.com>
…ring

The CrossTokenizerDistillationLossDataDict docstring described same-vocab
top-K teachers as shipping dense teacher_{i}_topk_logits / _indices [B, T_s, k]
and attributed teacher_{i}_full_logits_ipc to Policy.get_full_logits_ipc. Both
are stale after the IPC-producer unification: top-K teachers now ship
teacher_{i}_topk_ipc (a CUDA IPC handle list), and the handles are produced by
get_teacher_logits_ipc(mode="full"/"topk").

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: Adithya Hanasoge <avenkateshha@nvidia.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Documentation Improvements or additions to documentation

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants