Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
174 changes: 174 additions & 0 deletions scripts/agentic_inference_isl_precompute.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
Comment thread
tianmu-li marked this conversation as resolved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Offline ISL (Input Sequence Length) computation for multi-turn datasets.

Run from the repo root to print the ISL distribution for a dataset::

python scripts/agentic_inference_isl_precompute.py \\
--dataset path/to/dataset.jsonl \\
--tokenizer <model-name-or-path>
"""

from __future__ import annotations

import argparse
import logging
import os
import threading
from concurrent.futures import ThreadPoolExecutor, as_completed

import pandas as pd
from inference_endpoint.async_utils.services.metrics_aggregator.token_metrics import (
_normalize_tool_calls_for_template,
)
from inference_endpoint.dataset_manager.multi_turn_dataset import MultiTurnDataset
from tqdm import tqdm
from transformers import AutoTokenizer

logger = logging.getLogger(__name__)


def _precompute_isl(dataloader: MultiTurnDataset, tokenizer_name: str) -> None:
samples_with_messages = [s for s in (dataloader.data or []) if s.get("messages")]
if not samples_with_messages:
return

try:
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
except Exception:
logger.exception("Failed to load tokenizer %s", tokenizer_name)
return

first_failure_logged = False
first_failure_lock = threading.Lock()

def _tokenize_sample(sample: dict) -> list[int] | None:
Comment thread
tianmu-li marked this conversation as resolved.
try:
normalized_messages = []
for msg in sample["messages"]:
if msg.get("tool_calls"):
msg = {
**msg,
"tool_calls": _normalize_tool_calls_for_template(
msg["tool_calls"]
),
}
normalized_messages.append(msg)
tools = sample.get("tools")
raw = tokenizer.apply_chat_template(
normalized_messages,
tools=tools if tools else None,
tokenize=True,
add_generation_prompt=True,
)
# Some tokenizers (e.g. Qwen3 fast tokenizer) return BatchEncoding
# instead of a plain list; extract .input_ids in that case.
token_ids: list[int] = raw.input_ids if hasattr(raw, "input_ids") else raw
return token_ids
except Exception:
nonlocal first_failure_logged
with first_failure_lock:
if not first_failure_logged:
logger.exception("apply_chat_template failed (first failure shown)")
first_failure_logged = True
return None
Comment thread
tianmu-li marked this conversation as resolved.

n_workers = min(os.cpu_count() or 32, 32)
skipped = 0
with ThreadPoolExecutor(
max_workers=n_workers, thread_name_prefix="ISLPrecompute"
) as pool:
futures = {
pool.submit(_tokenize_sample, sample): sample
for sample in samples_with_messages
}
for future in tqdm(
as_completed(futures),
total=len(futures),
desc="Pre-computing ISL",
unit="turn",
):
sample = futures[future]
token_ids = future.result()
if token_ids is not None:
sample["input_tokens"] = token_ids
else:
skipped += 1

if skipped:
logger.warning("%d turn(s) skipped (apply_chat_template failed)", skipped)
if skipped == len(samples_with_messages):
logger.warning(
"All %d turn(s) failed apply_chat_template. "
"Check tokenizer/template compatibility.",
len(samples_with_messages),
)


def _isl_distribution(dataloader: MultiTurnDataset) -> dict[str, float]:
values = sorted(
len(s["input_tokens"])
for s in (dataloader.data or [])
if s.get("input_tokens") is not None
)
if not values:
raise ValueError(
"No input_tokens found — tokenization may have failed entirely."
)
n = len(values)

def percentile(p: float) -> float:
idx = (p / 100) * (n - 1)
lo, frac = int(idx), idx % 1
return values[lo] + frac * (values[lo + 1] - values[lo] if lo + 1 < n else 0)

return {
"min": values[0],
"max": values[-1],
"mean": sum(values) / n,
"p50": percentile(50),
"p99": percentile(99),
}


def main() -> None:
logging.basicConfig(level=logging.INFO, format="%(levelname)s %(message)s")

parser = argparse.ArgumentParser(
description="Compute ISL distribution for a multi-turn dataset."
)
parser.add_argument("--dataset", required=True, help="Path to JSONL dataset file.")
parser.add_argument(
"--tokenizer", required=True, help="HuggingFace repo ID or local path."
)
args = parser.parse_args()

ds = MultiTurnDataset(pd.read_json(args.dataset, lines=True))
ds.load()
_precompute_isl(ds, args.tokenizer)

stats = _isl_distribution(ds)
Comment thread
tianmu-li marked this conversation as resolved.
n = sum(1 for s in (ds.data or []) if s.get("input_tokens") is not None)
print(f"ISL distribution ({n} turns)")
print(f" min : {stats['min']:.0f}")
print(f" mean : {stats['mean']:.1f}")
print(f" p50 : {stats['p50']:.0f}")
print(f" p99 : {stats['p99']:.0f}")
print(f" max : {stats['max']:.0f}")


if __name__ == "__main__":
main()
77 changes: 0 additions & 77 deletions src/inference_endpoint/commands/benchmark/execute.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@
import msgspec.json
from huggingface_hub import model_info
from tqdm import tqdm
from transformers import AutoTokenizer
from transformers.utils import logging as transformers_logging

from inference_endpoint.async_utils.event_publisher import EventPublisherService
Expand All @@ -58,9 +57,6 @@
from inference_endpoint.async_utils.services.metrics_aggregator.subscriber import (
MetricsSnapshotSubscriber,
)
from inference_endpoint.async_utils.services.metrics_aggregator.token_metrics import (
_normalize_tool_calls_for_template,
)
from inference_endpoint.async_utils.transport.zmq.context import ManagedZMQContext
from inference_endpoint.config.runtime_settings import RuntimeSettings
from inference_endpoint.config.schema import (
Expand Down Expand Up @@ -314,75 +310,6 @@ def _load_datasets(
return dataloader, accuracy_datasets, eval_configs


def _precompute_isl_for_multi_turn(
dataloader: MultiTurnDataset, tokenizer_name: str
) -> None:
"""Tokenize pre-built message lists and store token counts in each sample.

Runs apply_chat_template once per client turn so the hot-path IslTrigger
sync path (len(token_ids)) is used instead of on-the-fly text tokenization.
Only affects dataset-history turns; live-history turns override 'messages'
at runtime so the stored input_tokens are stale (acceptable approximation).
"""
try:
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
except Exception:
logger.exception(
"ISL pre-computation: failed to load tokenizer %s; "
"falling back to text-tokenization at runtime",
tokenizer_name,
)
return
skipped = 0
first_failure_logged = False
for sample in dataloader.data or []:
messages = sample.get("messages")
if not messages:
continue
try:
normalized_messages = []
for msg in messages:
if msg.get("tool_calls"):
msg = {
**msg,
"tool_calls": _normalize_tool_calls_for_template(
msg["tool_calls"]
),
}
normalized_messages.append(msg)
tools = sample.get("tools")
raw = tokenizer.apply_chat_template(
normalized_messages,
tools=tools if tools else None,
tokenize=True,
add_generation_prompt=True,
)
# Some tokenizers (e.g. Qwen3 fast tokenizer) return BatchEncoding
# instead of a plain list; extract .input_ids in that case.
token_ids: list[int] = raw.input_ids if hasattr(raw, "input_ids") else raw
sample["input_tokens"] = token_ids
except Exception:
if not first_failure_logged:
logger.exception(
"ISL pre-computation: apply_chat_template failed (first failure shown)"
)
first_failure_logged = True
skipped += 1
if skipped:
logger.warning(
"ISL pre-computation: %d turn(s) skipped (apply_chat_template failed)",
skipped,
)
total_with_messages = len([s for s in (dataloader.data or []) if s.get("messages")])
if total_with_messages > 0 and skipped == total_with_messages:
logger.warning(
"ISL precomputation: all %d turn(s) failed apply_chat_template; "
"ISL metrics will use text-tokenization fallback. "
"Check tokenizer/template compatibility.",
total_with_messages,
)


def setup_benchmark(config: BenchmarkConfig, test_mode: TestMode) -> BenchmarkContext:
"""Load tokenizer, dataset, create scheduler, setup report dir."""
# CPU affinity
Expand Down Expand Up @@ -423,10 +350,6 @@ def setup_benchmark(config: BenchmarkConfig, test_mode: TestMode) -> BenchmarkCo
# Datasets
dataloader, accuracy_datasets, eval_configs = _load_datasets(config, report_dir)

if isinstance(dataloader, MultiTurnDataset) and tokenizer_name is not None:
logger.info("Pre-computing ISL token counts for multi-turn dataset…")
_precompute_isl_for_multi_turn(dataloader, tokenizer_name)

# Setup runtime settings using factory method
Comment thread
tianmu-li marked this conversation as resolved.
rt_settings = RuntimeSettings.from_config(config, dataloader.num_samples())

Expand Down
Loading
Loading