Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 47 additions & 15 deletions src/inference_endpoint/commands/benchmark/execute.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,10 @@
import shutil
import signal
import tempfile
import threading
import uuid
from collections.abc import Callable
from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import dataclass, field
from dataclasses import replace as dataclass_replace
from datetime import datetime
Expand Down Expand Up @@ -326,24 +328,33 @@ def _precompute_isl_for_multi_turn(
Only affects dataset-history turns; live-history turns override 'messages'
at runtime so the stored input_tokens are stale (acceptable approximation).
"""
samples_with_messages = [s for s in (dataloader.data or []) if s.get("messages")]
if not samples_with_messages:
return

try:
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
AutoTokenizer.from_pretrained(tokenizer_name)
except Exception:
logger.exception(
"ISL pre-computation: failed to load tokenizer %s; "
"falling back to text-tokenization at runtime",
tokenizer_name,
)
return
skipped = 0
first_failure_logged = False
for sample in dataloader.data or []:
messages = sample.get("messages")
if not messages:
continue

thread_local = threading.local()
first_failure_lock = threading.Lock()

def _get_tokenizer() -> Any:
if getattr(thread_local, "tokenizer", None) is None:
thread_local.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
return thread_local.tokenizer

def _tokenize_sample(sample: dict) -> list[int] | None:
try:
tokenizer = _get_tokenizer()
normalized_messages = []
for msg in messages:
for msg in sample["messages"]:
if msg.get("tool_calls"):
msg = {
**msg,
Expand All @@ -362,26 +373,47 @@ def _precompute_isl_for_multi_turn(
# Some tokenizers (e.g. Qwen3 fast tokenizer) return BatchEncoding
# instead of a plain list; extract .input_ids in that case.
token_ids: list[int] = raw.input_ids if hasattr(raw, "input_ids") else raw
sample["input_tokens"] = token_ids
return token_ids
except Exception:
if not first_failure_logged:
if first_failure_lock.acquire(blocking=False):
logger.exception(
"ISL pre-computation: apply_chat_template failed (first failure shown)"
)
first_failure_logged = True
skipped += 1
return None

n_workers = min(os.cpu_count() or 4, 8)
skipped = 0
with ThreadPoolExecutor(

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If tokenization is believed to be CPU compute-bound, would ProcessPoolExecutor be a better choice here?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let me try. The fast tokenizer doesn't release the GIL until after the chat template step, so ProcessPoolExecutor might help.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The benchmark pins the first 5 cores by default, and with that there isn't much difference between ThreadPool and ProcessPool. In fact, 16 workers oversubscribe the CPUs.

max_workers=n_workers, thread_name_prefix="ISLPrecompute"
) as pool:
futures = {
pool.submit(_tokenize_sample, sample): sample
for sample in samples_with_messages
}
for future in tqdm(
as_completed(futures),
total=len(futures),
desc="Pre-computing ISL",
unit="turn",
):
sample = futures[future]
token_ids = future.result()
if token_ids is not None:
sample["input_tokens"] = token_ids
else:
skipped += 1

if skipped:
logger.warning(
"ISL pre-computation: %d turn(s) skipped (apply_chat_template failed)",
skipped,
)
total_with_messages = len([s for s in (dataloader.data or []) if s.get("messages")])
if total_with_messages > 0 and skipped == total_with_messages:
if skipped == len(samples_with_messages):
logger.warning(
"ISL precomputation: all %d turn(s) failed apply_chat_template; "
"ISL metrics will use text-tokenization fallback. "
"Check tokenizer/template compatibility.",
total_with_messages,
len(samples_with_messages),
)


Expand Down
Loading