diff --git a/src/inference_endpoint/commands/benchmark/execute.py b/src/inference_endpoint/commands/benchmark/execute.py index c59aa4346..857b60a62 100644 --- a/src/inference_endpoint/commands/benchmark/execute.py +++ b/src/inference_endpoint/commands/benchmark/execute.py @@ -31,8 +31,10 @@ import shutil import signal import tempfile +import threading import uuid from collections.abc import Callable +from concurrent.futures import ThreadPoolExecutor, as_completed from dataclasses import dataclass, field from dataclasses import replace as dataclass_replace from datetime import datetime @@ -326,8 +328,12 @@ def _precompute_isl_for_multi_turn( Only affects dataset-history turns; live-history turns override 'messages' at runtime so the stored input_tokens are stale (acceptable approximation). """ + samples_with_messages = [s for s in (dataloader.data or []) if s.get("messages")] + if not samples_with_messages: + return + try: - tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) + AutoTokenizer.from_pretrained(tokenizer_name) except Exception: logger.exception( "ISL pre-computation: failed to load tokenizer %s; " @@ -335,15 +341,20 @@ def _precompute_isl_for_multi_turn( tokenizer_name, ) return - skipped = 0 - first_failure_logged = False - for sample in dataloader.data or []: - messages = sample.get("messages") - if not messages: - continue + + thread_local = threading.local() + first_failure_lock = threading.Lock() + + def _get_tokenizer() -> Any: + if getattr(thread_local, "tokenizer", None) is None: + thread_local.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) + return thread_local.tokenizer + + def _tokenize_sample(sample: dict) -> list[int] | None: try: + tokenizer = _get_tokenizer() normalized_messages = [] - for msg in messages: + for msg in sample["messages"]: if msg.get("tool_calls"): msg = { **msg, @@ -362,26 +373,47 @@ def _precompute_isl_for_multi_turn( # Some tokenizers (e.g. Qwen3 fast tokenizer) return BatchEncoding # instead of a plain list; extract .input_ids in that case. token_ids: list[int] = raw.input_ids if hasattr(raw, "input_ids") else raw - sample["input_tokens"] = token_ids + return token_ids except Exception: - if not first_failure_logged: + if first_failure_lock.acquire(blocking=False): logger.exception( "ISL pre-computation: apply_chat_template failed (first failure shown)" ) - first_failure_logged = True - skipped += 1 + return None + + n_workers = min(os.cpu_count() or 4, 8) + skipped = 0 + with ThreadPoolExecutor( + max_workers=n_workers, thread_name_prefix="ISLPrecompute" + ) as pool: + futures = { + pool.submit(_tokenize_sample, sample): sample + for sample in samples_with_messages + } + for future in tqdm( + as_completed(futures), + total=len(futures), + desc="Pre-computing ISL", + unit="turn", + ): + sample = futures[future] + token_ids = future.result() + if token_ids is not None: + sample["input_tokens"] = token_ids + else: + skipped += 1 + if skipped: logger.warning( "ISL pre-computation: %d turn(s) skipped (apply_chat_template failed)", skipped, ) - total_with_messages = len([s for s in (dataloader.data or []) if s.get("messages")]) - if total_with_messages > 0 and skipped == total_with_messages: + if skipped == len(samples_with_messages): logger.warning( "ISL precomputation: all %d turn(s) failed apply_chat_template; " "ISL metrics will use text-tokenization fallback. " "Check tokenizer/template compatibility.", - total_with_messages, + len(samples_with_messages), )