feat(xtoken): add multi-teacher support to cross-tokenizer off-policy distillation #2797
avenkateshha wants to merge 18 commits
… distillation
Extend cross-tokenizer off-policy distillation to support multiple teachers via a `teachers:` list of TeacherConfig. Same-vocab teachers (projection_matrix_path=null) use direct per-position KL with no alignment; cross-tokenizer teachers keep the projection + alignment path. Adds per-teacher indexed data-dict keys, serial teacher forward passes, and a kd_loss_mode (sum / averaged_logits / select_teacher) aggregation knob. Includes a multi_teacher_cross_tokenizer_smoke.yaml exemplar (Qwen3-4B cross-tokenizer + Llama-3.2-3B same-vocab teacher) and updated unit tests.
Signed-off-by: Adithya Hanasoge <avenkateshha@nvidia.com>
- Shorten the H1 to 'Cross-Tokenizer (X-Token)'.
- Trim the future-work note to a generic 'actively improving support'.
- Reframe Step 2 around tokenizer overlap (similar algorithms such as BPE).
- Describe Step 3's sparse [V_student, top_k] projection representation and why it avoids a dense [student_vocab, teacher_vocab] matrix.
- Remove the --preserve_last paragraph; the recommended recipe disables the scale trick, so it never engages.
- Drop the 'via CUDA IPC' qualifier from the P-KL loss-mode row.
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: Adithya Hanasoge <avenkateshha@nvidia.com>
- Drop 'Off-Policy' from the X-Token section/subsection headings and the feature-list entry; update the matching table-of-contents anchors.
- Reword the section's first line to 'distillation' (no 'off-policy').
- Add a link to the x-token distillation paper alongside the implementation guide in the 'read about the details' line.
Real file paths (run_xtoken_off_policy_distillation.py, xtoken_off_policy_distillation.yaml, the guide filename) are left unchanged.
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: Adithya Hanasoge <avenkateshha@nvidia.com>
Replace the single-teacher P-KL results with the multi-teacher run (Llama-3.2-1B student <- Phi-4-mini-instruct + Llama-3.2-3B teachers):
- Regenerate the training-curve figure from the multi-teacher run and swap the asset (xtoken_pkl_smoke_curves.png -> xtoken_mt_curves.png).
- Rewrite the config sentence and loss bullets; note that the cross-tokenizer Phi-4-mini teacher carries the KL while the same-tokenizer Llama-3.2-3B teacher's direct KL stays small.
- Add a Downstream evaluation table (distilled student vs. base Llama-3.2-1B).
- Refresh the throughput table (mean step time 6.65 s, teacher forward dominates, ~29.5k tok/s); drop the stale peak-memory row.
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: Adithya Hanasoge <avenkateshha@nvidia.com>
Same-vocab teachers previously shipped top-k logits via the driver path (get_topk_logits), which gathers dense [B, T, k] value/index tensors to the driver and re-shards them to the student every step (~16-19 GB/step for k=8192), inflating policy_training/submit_training_futures. Cross-tokenizer teachers already used zero-copy CUDA IPC.
Add a single producer, get_teacher_logits_ipc(mode='full'|'topk'), that handles both transports over persistent CUDA IPC buffers (1 buffer for full logits; 2 for top-k values+indices), generalizing the existing single-buffer machinery. Top-k values ship in native compute dtype (bf16) and indices in int32 (lossless for <2^31 vocab; cast to int64 at the gather). get_full_logits_ipc now delegates to mode='full'; same-vocab top-k teachers ship a teacher_{i}_topk_ipc handle list instead of dense tensors. Numerically bf16-exact vs the prior driver path.
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Signed-off-by: Adithya Hanasoge <avenkateshha@nvidia.com>
Compare the MOPD (on-policy, Megatron) and xToken (off-policy, DTensor V2) distillation recipes across multi-teacher, async, policy, loss, tokenizer, and backend, in a new Distillation overview preceding both recipe sections. Signed-off-by: Adithya Hanasoge <avenkateshha@nvidia.com>
Add a paragraph and figure on pairwise tokenizer vocabulary overlap (intersection over min vocab size) to motivate why the cross-tokenizer projection is necessary. Signed-off-by: Adithya Hanasoge <avenkateshha@nvidia.com> Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
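For readers of this commit, the overlap metric named above (intersection over min vocab size) can be sketched as follows. This is a minimal illustration, not the script behind the figure; `vocab_overlap` and the toy token sets are hypothetical names and data.

```python
def vocab_overlap(vocab_a: set, vocab_b: set) -> float:
    """Shared tokens divided by the size of the smaller vocabulary."""
    smaller = min(len(vocab_a), len(vocab_b))
    if smaller == 0:
        return 0.0
    return len(vocab_a & vocab_b) / smaller


# Toy vocabularies (made-up token strings, not real tokenizer dumps).
student_vocab = {"the", "Ġcat", "Ġsat", "ing", "##s"}
teacher_vocab = {"the", "Ġcat", "▁sat", "ing"}

# Three shared tokens over a smaller vocabulary of four.
overlap = vocab_overlap(student_vocab, teacher_vocab)
```

Normalizing by the smaller vocabulary keeps the score at 1.0 when one vocabulary is a strict subset of the other, which a plain Jaccard index would not.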
yuki-97 left a comment
@avenkateshha thanks for supporting this, nice work! reviewed partly and left some comments.
        reduction="mean",
    )
    ces.append(ce_i.item())
    best = int(min(range(self.num_teachers), key=lambda j: ces[j]))
ces.append(ce_i.item()) + min is on rank-local CE — DP ranks can pick different best teachers. _compute_teacher_kd(best, ...) then runs DP all-reduce collectives whose ordering must match across ranks (per the comment at line 1829), so divergent participation deadlocks or silently produces inconsistent reductions. Same issue at _compute_dynamic_weights:1988: softmax(alpha * stack(rank-local-scores)) gives each rank different weights. All-reduce the per-teacher scalar score before the min / softmax.
Fixed in f9fc467. Added _dp_global_masked_mean, which all-reduces the masked score sum and the valid count across the DP group, and used it for the select-teacher CE and the weight-metric scores. Every rank now computes the identical global score, so min (select) and softmax (dynamic weights) agree across ranks, with no divergent participation in the per-teacher dp_all_reduce_sum. The score is detached (it gates selection/weighting and is not back-propagated, matching the reference's no_grad scoring). Verified with a 2-rank gloo run: ranks disagree on the local mean but agree on the global mean.
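To illustrate why the reduction has to be global, here is a minimal pure-Python sketch that simulates two DP ranks. It is not the PR's `_dp_global_masked_mean`; the helper names and the toy scores are hypothetical, and the all-reduce is mimicked by summing over a list of shards.

```python
def local_masked_mean(scores, mask):
    """Mean of one rank's scores over valid positions only."""
    total = sum(s for s, m in zip(scores, mask) if m)
    count = sum(1 for m in mask if m)
    return total / max(count, 1)


def global_masked_mean(shards):
    """Mimic all-reduce(sum) of (masked_sum, valid_count) across ranks."""
    total = sum(s for scores, mask in shards for s, m in zip(scores, mask) if m)
    count = sum(1 for _, mask in shards for m in mask if m)
    return total / max(count, 1)


# teacher index -> one (per-token CE, validity mask) shard per simulated rank.
teacher_shards = {
    0: [([1.0, 1.0, 9.0], [1, 1, 0]), ([3.0, 3.0, 3.0], [1, 1, 1])],
    1: [([2.0, 2.0, 0.0], [1, 1, 0]), ([2.0, 2.0, 2.0], [1, 1, 1])],
}
num_ranks = 2

# Rank-local selection: each rank only sees its own shard.
local_choice = [
    min(teacher_shards, key=lambda t: local_masked_mean(*teacher_shards[t][rank]))
    for rank in range(num_ranks)
]
# Global selection: every rank would compute this same value.
global_choice = min(teacher_shards, key=lambda t: global_masked_mean(teacher_shards[t]))
```

In this toy data rank 0 prefers teacher 0 and rank 1 prefers teacher 1 when scoring locally, while the global score gives a single answer that every rank shares.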
    checkpointing: CheckpointingConfig

    def normalize_multi_teacher_config(raw_config: dict[str, Any]) -> dict[str, Any]:
we can remove this function and always use teachers instead of teacher.
Done in 5f33a30. Removed normalize_multi_teacher_config and its call in the entrypoint; configs now must use teachers: (one teacher is just a 1-element list). Also dropped the loss constructor's single-teacher fallback branch so it always reads the injected per-teacher lists.
    # Per-teacher metadata injected at runtime by
    # ``xtoken_off_policy_distillation.setup`` (parallel to ``teachers``).
    teacher_vocab_sizes: NotRequired[list[int]]
    projection_matrix_paths: NotRequired[list[Optional[str]]]
    teacher_weights: NotRequired[list[float]]
    teacher_send_full_logits: NotRequired[list[bool]]
    teacher_gold_loss: NotRequired[list[Optional[bool]]]
    teacher_xtoken_loss: NotRequired[list[Optional[bool]]]
am I understanding correctly that we will always use these even for single teacher (I think it's great)? if so,
- can we remove the legacy params like teacher_vocab_size?
- is it sometimes NotRequired? if not, can we remove NotRequired?
Done in 5f33a30. 1) Removed the legacy singular fields teacher_vocab_size and projection_matrix_path from CrossTokenizerDistillationLossConfig (and the constructor branch that read them) — the per-teacher teacher_vocab_sizes / projection_matrix_paths lists carry these now. 2) On NotRequired: I kept it on the per-teacher lists. They're injected at runtime by setup and are absent in the raw YAML, so the same TypedDict describes both the pre- and post-injection dict; marking them required would be inaccurate at config-load time. So those stay NotRequired by design rather than being removed.
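A minimal sketch of the point about `NotRequired`, assuming Python 3.11+ (or `typing_extensions` on older versions). `LossConfigSketch`, the two field subsets shown, and the placeholder values are illustrative only and are not the real `CrossTokenizerDistillationLossConfig`.

```python
from typing import Optional, TypedDict

try:  # NotRequired lives in typing from Python 3.11 onward
    from typing import NotRequired
except ImportError:  # assumption: typing_extensions is available otherwise
    from typing_extensions import NotRequired


class LossConfigSketch(TypedDict):
    # Present in the raw YAML, so required at config-load time.
    kd_loss_mode: str
    # Injected later by setup, so absent in the raw YAML.
    teacher_vocab_sizes: NotRequired[list[int]]
    projection_matrix_paths: NotRequired[list[Optional[str]]]


# Pre-injection view: what the YAML loader would hand over.
raw_cfg: LossConfigSketch = {"kd_loss_mode": "sum"}

# Post-injection view: same TypedDict, per-teacher lists filled in
# (placeholder sizes and path, one entry per teacher).
injected_cfg: LossConfigSketch = {
    **raw_cfg,
    "teacher_vocab_sizes": [1000, 2000],
    "projection_matrix_paths": ["proj_teacher0.pt", None],
}
```

One type then describes both stages; the cost is that consumers must treat the per-teacher lists as possibly missing until setup has run.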
    - Cross-tokenizer teacher: ``teacher_{i}_full_logits_ipc`` (List[B] of
      CUDA IPC handle dicts from ``Policy.get_full_logits_ipc``),
      ``teacher_{i}_input_ids`` / ``teacher_{i}_token_mask`` ``[B, T_t]``, and
      ``alignment_{i}_*`` (``pair_valid`` / ``pair_is_correct``
      ``[B, max_pairs]``; ``student_exact_partition_mask`` /
      ``student_chunk_id`` ``[B, T_s]``; ``teacher_exact_partition_mask`` /
      ``teacher_chunk_id`` ``[B, T_t]``; ``num_chunks`` ``[B]``).
    - Same-vocab teacher: either ``teacher_{i}_full_logits_ipc`` (when
      ``send_full_logits``) or ``teacher_{i}_topk_logits`` /
      ``teacher_{i}_topk_indices`` ``[B, T_s, k]``; no ``alignment_{i}_*``
      (it reuses the student tokenization, identity-aligned).
can we use list or dict like alignment_pair_valid: list[torch.Tensor] and still list things at CrossTokenizerDistillationLossDataDict?
for maybe full_logits or topk_logits, we can mark them Optional but still put them both under CrossTokenizerDistillationLossDataDict.
Thanks — thought about this carefully and I would keep the per-teacher indexed keys rather than convert them to list[torch.Tensor] fields, for a concrete reason in the data path:
train_data is split into microbatches by BatchedDataDict (shared across all algorithms), whose slicer treats every value's first axis as the batch (example) axis — slice(start, end) does value[start:end] for every key. With per-teacher separate keys, each value is a [B, ...] tensor, so the example-slice is correct. If a field were a list indexed by teacher ([t0_tensor, t1_tensor]), the slicer would slice the teacher axis (list[start:end]), not the example axis — so microbatches would be wrong. Two more blockers: (1) a list value already means per-example here (e.g. teacher_{i}_full_logits_ipc is a length-B list of per-sample IPC handles, sliced correctly by example), so a per-teacher list would be ambiguous to the same machinery; (2) cross-tokenizer teachers are ragged — different T_t and max_pairs per teacher — so they can't be stacked into one [B, num_teachers, ...] tensor either. Supporting per-teacher lists would mean changing the shared BatchedDataDict slicing/sharding for a cosmetic schema gain.
The typed view you're after does exist where it's consumed: the loss rehydrates each teacher's flat keys into a typed AlignmentBatch via alignment_from_flat_batch(data, f"alignment_{i}_"). And I corrected the CrossTokenizerDistillationLossDataDict docstring in 41ee5a9 so it accurately enumerates the per-teacher keys + shapes and notes the two conditionally-present transports (full_logits_ipc vs topk_ipc) — i.e. the listing this schema can't express as fields, kept accurate in the docstring instead. Let me know if you'd still prefer the list form despite the slicing constraint.
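The slicing constraint described in this reply can be reproduced with a small stand-in. This is a sketch using plain lists rather than the real `BatchedDataDict`; `slice_batch` and the key `teacher_input_ids` are hypothetical.

```python
def slice_batch(batch: dict, start: int, end: int) -> dict:
    """Slice every value along its first axis, like a generic microbatch slicer."""
    return {key: value[start:end] for key, value in batch.items()}


# Per-teacher keys: axis 0 of every value is the example axis (B = 4).
# The teachers may have different sequence lengths without any problem.
per_teacher_keys = {
    "teacher_0_input_ids": [[10, 11], [12, 13], [14, 15], [16, 17]],
    "teacher_1_input_ids": [[20], [21], [22], [23]],
}

# Hypothetical list-of-teachers layout: axis 0 is now the teacher axis.
per_teacher_list = {
    "teacher_input_ids": [
        [[10, 11], [12, 13], [14, 15], [16, 17]],
        [[20], [21], [22], [23]],
    ],
}

# Second microbatch of size 2 (examples 2 and 3).
micro_ok = slice_batch(per_teacher_keys, 2, 4)
micro_bad = slice_batch(per_teacher_list, 2, 4)
```

With the indexed keys the microbatch holds examples 2 and 3 for each teacher; with the list layout the same slice walks off the end of the two-element teacher axis and comes back empty.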
    """
    device = logits.device
    if self.sum_weights_metric is not None:
        weights = self._compute_dynamic_weights(data, teacher_full_logits_by_idx)
[BUG] Dynamic weights are computed per-DP-rank, so the loss is not the same function across ranks.
_teacher_weight_score uses rank-local cross_entropy(reduction="mean") / entropy / max-prob, then softmaxes — each DP rank gets different weights[i] for the same teacher. Same root cause as the previous comment on _select_teacher_kd:1922.
Suggest fix: Aggregate the score across DP before the softmax.
Fixed in f9fc467 (same change as the select-teacher thread). _teacher_weight_score now reduces ce/entropy/max_prob via _dp_global_masked_mean (all-reduce of masked-sum and valid-count across DP), so the softmax sees the same global scores on every rank and weights[i] is identical across ranks. The loss is now the same function on all ranks.
    assert total_kd is not None
    return total_kd, per_metrics

    total_w = sum(self.teacher_weights)
[BUG] Fast path silently ignores sum_weights_metric and always uses static teacher_weights[i].
The docstring only warns about this on the fallback branch, so kd_loss_mode="averaged_logits" + sum_weights_metric="ce" quietly drops dynamic weighting.
Fixed in d152377 (chose fail-loud over honoring it, to keep reference parity). The PT reference (train_distillation_ddp.py) does not apply dynamic weights in averaged_logits — it sets dynamic_weights = None and blends with the static config weight (lines 1589-1624, token_weights=None), and --sum_weights_metric is documented as a sum-mode-only knob (lines 109-110). select_teacher likewise sets dynamic_weights = None. So honoring the metric here would diverge from the reference; instead the loss fn now raises at construction if sum_weights_metric is set with any mode other than sum, so it can no longer be silently dropped. Added a test covering both averaged_logits and select_teacher.
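A minimal sketch of the fail-loud check described in this reply, assuming only what the thread states: `sum_weights_metric` is valid with `kd_loss_mode="sum"` and rejected otherwise. `check_kd_config` and `is_rejected` are hypothetical names, not the loss function's real constructor.

```python
from typing import Optional

VALID_MODES = ("sum", "averaged_logits", "select_teacher")


def check_kd_config(kd_loss_mode: str, sum_weights_metric: Optional[str]) -> None:
    """Raise instead of silently ignoring an unsupported combination."""
    if kd_loss_mode not in VALID_MODES:
        raise ValueError(f"unknown kd_loss_mode: {kd_loss_mode!r}")
    if sum_weights_metric is not None and kd_loss_mode != "sum":
        raise ValueError(
            f"sum_weights_metric={sum_weights_metric!r} only applies to "
            f"kd_loss_mode='sum', got {kd_loss_mode!r}"
        )


def is_rejected(kd_loss_mode: str, sum_weights_metric: Optional[str]) -> bool:
    """Small helper for tests: True when the combination raises."""
    try:
        check_kd_config(kd_loss_mode, sum_weights_metric)
    except ValueError:
        return True
    return False
```

Raising at construction surfaces the misconfiguration before any teacher forward pass runs, which is cheaper to debug than a training run that quietly used static weights.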
    "kl_loss_weight": 1.0,
    "ce_loss_scale": 1.0,
    "dynamic_loss_scaling": False,
    "kd_loss_mode": "sum",
No unit coverage for averaged_logits or select_teacher — every test sets kd_loss_mode="sum".
Those two branches are exactly where the prev. comments on 1922 / 1919 / 1871 and the new loss-side comments live, so CI won't catch any of those regressions.
Partially addressed in 842c095 — added two tests that exercise the averaged_logits guard directly against the real loss fn (the cross-tokenizer-shape bug above), built via __new__ since this module otherwise mocks the loss fn. Broader mode coverage (select_teacher, the DP-aggregation issue) is still pending and I'll follow up separately.
Closed out in aa79d97. Added a parametrized end-to-end test asserting select_teacher picks the lowest-CE teacher (two same-vocab teachers: one peaks on the true next token, the other is uniform; parametrized over which index is better). Together with the averaged_logits guard tests (M1), the masked ce/entropy/max_prob scoring tests (M2), the sum_weights_metric-rejection tests (M3) and the DP-global-score helper test (H1), both previously-uncovered modes now have coverage.
    @@ -492,68 +538,218 @@ def test_ipc_buffer_released_on_student_train_failure(mock_xtoken_components):
test_ipc_buffer_released_on_student_train_failure (~lines 527-537, unchanged by this PR) wouldn't catch a multi-teacher regression.
It's single-teacher only and asserts release_ipc_buffer.call_count >= 1 — releasing only teacher_policies[0] (a moved finally, an early break) still passes.
Suggest fix: Parametrize on num_teachers ∈ {1, 2} and assert call_count == len(teacher_policies).
Done in 1b124e7. Parametrized on num_teachers in {1, 2} and now assert release_ipc_buffer is called exactly once per teacher (== len(teacher_policies)) instead of >= 1, so a release-only-teacher-0 bug would now fail. (Used same-vocab teachers so the release-loop check doesn't depend on per-teacher batch keys.)
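The shape of that stricter assertion can be sketched with `unittest.mock`. This is an illustration under assumptions: `run_step_and_release` is a hypothetical stand-in for the train step's cleanup, and it assumes `release_ipc_buffer` is a method on each teacher policy.

```python
from unittest.mock import MagicMock


def run_step_and_release(teacher_policies, train_step):
    """Hypothetical stand-in: release every teacher's buffer even on failure."""
    try:
        train_step()
    finally:
        for teacher in teacher_policies:
            teacher.release_ipc_buffer()


def failing_step():
    raise RuntimeError("student train failure")


def release_counts(num_teachers: int) -> list:
    """Run one failing step and report release calls per teacher."""
    teachers = [MagicMock() for _ in range(num_teachers)]
    try:
        run_step_and_release(teachers, failing_step)
    except RuntimeError:
        pass  # the failure is expected; only the cleanup is under test
    return [t.release_ipc_buffer.call_count for t in teachers]
```

Checking the count per teacher, rather than `>= 1` overall, is what makes a cleanup loop that stops after teacher 0 visible.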
    t1.get_teacher_logits_ipc.assert_called_once()
    assert t1.get_teacher_logits_ipc.call_args.kwargs["mode"] == "topk"
    assert t1.get_teacher_logits_ipc.call_args.kwargs["k"] == 8
    # Serial collocated execution: each teacher onloaded then offloaded.
"Serial" is in the test name but isn't actually verified.
The test only checks that each per-teacher mock method is called once, not that t0.offload happens BEFORE t1.prepare — which is the whole point of serial collocation.
Suggest fix: Attach the per-teacher mocks to a shared MagicMock parent and assert parent.mock_calls shows the interleaved order.
Done in 1b124e7. Attached both teacher mocks to a shared parent and now assert the interleaved order — teacher 0's offload_after_refit happens before teacher 1's prepare_for_lp_inference — so loading both teachers at once (the failure the 'serial' name guards against) would now fail. Kept the per-teacher called-once checks too.
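For reference, the shared-parent ordering check can be sketched as below. `run_teachers_serially` is a hypothetical loop standing in for the real train step; the method names are the ones quoted in this thread.

```python
from unittest.mock import MagicMock, call


def run_teachers_serially(teachers):
    """Hypothetical serial loop: onload, forward, offload, one teacher at a time."""
    for teacher in teachers:
        teacher.prepare_for_lp_inference()
        teacher.get_teacher_logits_ipc(mode="topk", k=8)
        teacher.offload_after_refit()


# Attach both teacher mocks to one parent so their calls land in a
# single, globally ordered parent.mock_calls list.
parent = MagicMock()
t0, t1 = MagicMock(), MagicMock()
parent.attach_mock(t0, "t0")
parent.attach_mock(t1, "t1")

run_teachers_serially([t0, t1])

calls = parent.mock_calls
offload_t0 = calls.index(call.t0.offload_after_refit())
prepare_t1 = calls.index(call.t1.prepare_for_lp_inference())
```

Asserting `offload_t0 < prepare_t1` fails if both teachers are onloaded before either is offloaded, which per-mock `assert_called_once` checks cannot detect.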
    """
    distill_cfg = master_config.distillation
    timer = timer if timer is not None else Timer()
    skip_keys = xtoken_non_student_seq_keys(loss_fn)
skip_keys = xtoken_non_student_seq_keys(loss_fn) is rebuilt here on every validate call, parallel to the train loop's build at line 624.
It's purely a function of loss_fn.{num_teachers, projection_matrix_paths, teacher_ships_full}, so duplicating it lets the two paths drift.
Suggest fix: Cache on loss_fn (or return from setup) and thread through.
Done in 00358ad. xtoken_non_student_seq_keys(loss_fn) is now built once in the train loop and passed into validate as a skip_keys arg; removed validate's own rebuild. It is a pure function of the loss fn's static per-teacher metadata, so both sites always produced the identical set — this is behavior-preserving (verified the validate + train-loop tests still pass), just removes the duplicate build and the drift risk. (Kept the build in the train loop rather than setup: the setup unit tests patch the loss-fn class to an unconfigured mock, so building there would fail on range(mock.num_teachers).)
    @@ -1,8 +1,12 @@
    -# Single-teacher cross-tokenizer off-policy distillation.
    +# Multi-teacher cross-tokenizer off-policy distillation.
could you add a nightly test to guard multi-teacher?
you can refer to the three files in #2745.
- examples/configs/recipes/llm/distillation-xtoken-off-policy-qwen3-4b-to-llama3.2-1b-1n8g-dtensor-tp4cp2.yaml
- tests/test_suites/llm/distillation-xtoken-off-policy-qwen3-4b-to-llama3.2-1b-1n8g-dtensor-tp4cp2.sh
- tests/test_suites/nightly.txt
@terrykong could you take a review as well?
- Rename module docstring to multi-teacher (was single-teacher).
- Add "Off-Policy" to the X-Token guide title and README section (and
fix the README TOC anchor) so the off-policy nature is explicit.
- Set `teacher: null` in the exemplar (inherited from distillation_math,
unused by xtoken) with a clarifying comment.
- Drop the no-op `packed.to("cpu")` and its stale comment: train_data
holds only IPC handle dicts and already-CPU dataloader tensors.
- Remove the redundant multi_teacher_cross_tokenizer_smoke.yaml exemplar.
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: Adithya Hanasoge <avenkateshha@nvidia.com>
The final box's right vertical border sat 2 columns past its top and bottom horizontal borders, so the box rendered crooked. Trim the 2 trailing pad spaces on each content line so the right edge lines up at the same column as the box corners. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> Signed-off-by: Adithya Hanasoge <avenkateshha@nvidia.com>
…ct-KL The averaged_logits fast path took a single direct per-position KL whenever all teachers' full-logit tensors shared a shape. Two cross-tokenizer teachers can share a shape while still needing the projection/alignment path, so they slipped into the fast path and mismatched the student's token_mask when the teacher sequence length differs (T_t != T_s). Gate the fast path on every projection_matrix_path being None (genuine same-tokenizer teachers) in addition to the shape check. Add regression tests for both branches of the guard in the tracked test module so CI exercises them. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> Signed-off-by: Adithya Hanasoge <avenkateshha@nvidia.com>
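The tightened gate from this commit can be sketched as a predicate. This is a minimal illustration of the stated rule (every `projection_matrix_path` is None and all full-logit shapes match); `can_use_direct_kl_fast_path` and the example shapes are hypothetical.

```python
from typing import Optional, Sequence, Tuple


def can_use_direct_kl_fast_path(
    projection_matrix_paths: Sequence[Optional[str]],
    logit_shapes: Sequence[Tuple[int, ...]],
) -> bool:
    """Fast path only for genuine same-tokenizer teachers with equal shapes."""
    all_same_vocab = all(path is None for path in projection_matrix_paths)
    shapes_match = len(set(logit_shapes)) <= 1
    return all_same_vocab and shapes_match


# Two cross-tokenizer teachers that happen to share a [B, T, V] shape:
# the shape check alone would have let them through.
cross_tok = can_use_direct_kl_fast_path(
    ["proj_a.pt", "proj_b.pt"], [(2, 16, 100), (2, 16, 100)]
)
# Two same-vocab teachers with matching shapes.
same_vocab = can_use_direct_kl_fast_path([None, None], [(2, 16, 100), (2, 16, 100)])
```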
select_teacher and dynamic-weight (sum_weights_metric) scoring averaged the per-teacher ce/entropy/max_prob over every position with reduction="mean". Padded positions carry near-uniform logits that dominate the score on long-padded batches and skew teacher selection / weights. Mask each score by the teacher's own token_mask (the student's for a same-vocab teacher) and the sample_mask before reducing, matching the main loss; thread the per-teacher token mask through _teacher_score_inputs. Add regression tests asserting padded positions and dropped samples do not move the score. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> Signed-off-by: Adithya Hanasoge <avenkateshha@nvidia.com>
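A small worked example of the masking described in this commit, in pure Python. `masked_mean` and `naive_mean` are hypothetical helpers and the numbers are made up; the point is only that padded positions and dropped samples stop influencing the score.

```python
def naive_mean(values):
    """Average over every position, padded or not."""
    flat = [v for row in values for v in row]
    return sum(flat) / len(flat)


def masked_mean(values, token_mask, sample_mask):
    """Average over positions kept by both the token mask and the sample mask."""
    total, count = 0.0, 0
    for row, row_mask, keep_sample in zip(values, token_mask, sample_mask):
        if not keep_sample:
            continue
        for value, keep_token in zip(row, row_mask):
            if keep_token:
                total += value
                count += 1
    return total / max(count, 1)


# Per-position scores for a batch of 2; the high values sit on padding
# (sample 0, last two positions) and on a dropped sample (sample 1).
scores = [[0.5, 0.5, 7.0, 7.0], [9.0, 9.0, 9.0, 9.0]]
token_mask = [[1, 1, 0, 0], [1, 1, 1, 1]]
sample_mask = [1, 0]

unmasked_score = naive_mean(scores)
masked_score = masked_mean(scores, token_mask, sample_mask)
```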
Dynamic teacher weighting (sum_weights_metric) is only applied in kd_loss_mode="sum"; averaged_logits and select_teacher use static per-teacher weights, matching the reference. Combining sum_weights_metric with either mode silently dropped the metric. Reject the combo at loss-fn construction instead, alongside the existing xtoken/gold guard. Add a test asserting the combo raises for both averaged_logits and select_teacher. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> Signed-off-by: Adithya Hanasoge <avenkateshha@nvidia.com>
select_teacher and dynamic-weight (sum_weights_metric) scoring reduced each teacher's metric over the rank-local shard only, so DP ranks could pick a different teacher / different weights. The per-teacher KD then fires a dp_all_reduce_sum collective, so divergent choices across ranks deadlock (a cross-tokenizer teacher fires the collective, a same-vocab one does not) or silently mix reductions. Add _dp_global_masked_mean, which all-reduces the masked score sum and the valid count across the DP group, and use it for the select-teacher CE and the weight-metric (ce/entropy/max_prob) scores. Every rank now computes the identical global score, so all ranks select the same teacher / same weights and hit the same collective sequence. The score is detached (it gates selection/weighting and is not back-propagated, matching the reference's no_grad scoring). Verified cross-rank agreement with a 2-rank gloo run (ranks disagree on the local mean, agree on the global mean). Add a single-process unit test for the helper. No effect on the single-teacher sum/null path used for gold parity. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> Signed-off-by: Adithya Hanasoge <avenkateshha@nvidia.com>
Multi-teacher is now the only model (one teacher is a 1-element `teachers:` list), so the back-compat single-`teacher:` path is dead weight.
- Remove `normalize_multi_teacher_config` (the legacy `teacher:` -> `teachers:` fold) and its call in the entrypoint; configs must use `teachers:`.
- Remove the loss constructor's single-teacher fallback branch; it now always reads the injected per-teacher lists.
- Drop the now-unused legacy fields `teacher_vocab_size` and `projection_matrix_path` from CrossTokenizerDistillationLossConfig. The per-teacher `teacher_vocab_sizes` / `projection_matrix_paths` lists (and `TeacherConfig.projection_matrix_path`) carry these; the `NotRequired` markers stay on the runtime-injected lists since they are absent in raw YAML.
- Migrate the `_ct_loss_cfg` test helper to the plural per-teacher form and remove the now-obsolete shim unit test.
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: Adithya Hanasoge <avenkateshha@nvidia.com>
`xtoken_non_student_seq_keys(loss_fn)` was rebuilt inside `validate` on every validation pass, parallel to the train loop's build, so the two copies could drift. It is a pure function of the loss fn's static per-teacher metadata, so both always produce the identical set. Build it once in the train loop and pass it into `validate` as a `skip_keys` argument; drop validate's rebuild. No behavior change. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> Signed-off-by: Adithya Hanasoge <avenkateshha@nvidia.com>
- IPC release: parametrize on num_teachers in {1, 2} and assert
release_ipc_buffer is called exactly once per teacher. The old test was
single-teacher and asserted call_count >= 1, so a bug releasing only
teacher 0 (misplaced finally / early break) would still pass while leaking
the rest.
- Serial collocation: attach both teacher mocks to a shared parent and assert
teacher 0 is offloaded before teacher 1 is onloaded. The test was named
"runs_serially" but only checked each load/offload was called once, never
the order, so loading both teachers at once would have passed.
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: Adithya Hanasoge <avenkateshha@nvidia.com>
Asserts select_teacher uses the teacher with the lowest next-token CE: two same-vocab teachers, one peaking on the true next token (CE ~ 0) and one uniform (CE = log V), parametrized over which index is better so it can't pass by always returning 0. Completes the averaged_logits / select_teacher coverage gap (the guard + scoring branches were already covered by the M1/M2/M3/H1 fixes). Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> Signed-off-by: Adithya Hanasoge <avenkateshha@nvidia.com>
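The test idea in this commit (a peaked teacher against a uniform one, with the better index swapped) can be sketched without the real loss function. `next_token_ce` and `select_teacher` below are hypothetical helpers operating on plain probability lists.

```python
import math


def next_token_ce(probs_per_position, targets):
    """Mean negative log-likelihood of the target token at each position."""
    nll = [-math.log(p[t]) for p, t in zip(probs_per_position, targets)]
    return sum(nll) / len(nll)


def select_teacher(teacher_probs, targets):
    """Index of the teacher with the lowest next-token CE."""
    ces = [next_token_ce(p, targets) for p in teacher_probs]
    return min(range(len(ces)), key=lambda j: ces[j])


V = 4
targets = [2, 0, 3]
# Uniform teacher: CE equals log V. Peaked teacher: CE close to 0.
uniform = [[1.0 / V] * V for _ in targets]
peaked = [[0.997 if j == t else 0.001 for j in range(V)] for t in targets]

# Swap which index holds the better teacher so that a selector which
# always returns 0 cannot pass.
picks = {
    better: select_teacher(
        [peaked, uniform] if better == 0 else [uniform, peaked], targets
    )
    for better in (0, 1)
}
```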
…ring
The CrossTokenizerDistillationLossDataDict docstring described same-vocab
top-K teachers as shipping dense teacher_{i}_topk_logits / _indices [B, T_s, k]
and attributed teacher_{i}_full_logits_ipc to Policy.get_full_logits_ipc. Both
are stale after the IPC-producer unification: top-K teachers now ship
teacher_{i}_topk_ipc (a CUDA IPC handle list), and the handles are produced by
get_teacher_logits_ipc(mode="full"/"topk").
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: Adithya Hanasoge <avenkateshha@nvidia.com>
Builds on #2757 (CE + H-KL gold loss). Extends cross-tokenizer off-policy distillation to support multiple teachers via a `teachers:` list of `TeacherConfig`.
What
- Same-vocab teachers (`projection_matrix_path=null`) use direct per-position KL with no alignment; cross-tokenizer teachers keep the projection + alignment path.
- `kd_loss_mode` aggregation knob: `sum` / `averaged_logits` / `select_teacher`.
- `examples/configs/multi_teacher_cross_tokenizer_smoke.yaml` exemplar (Qwen3-4B cross-tokenizer teacher + Llama-3.2-3B same-vocab teacher).
- `tests/unit/algorithms/x_token/test_xtoken_off_policy_distillation.py`.
🤖 Generated with Claude Code