Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion python/packages/core/agent_framework/_clients.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,8 +380,15 @@ async def _prepare_messages_for_model_call(
return prepared_messages
from ._compaction import apply_compaction

# Compact the caller's list in place when possible. A compaction operation has
# two halves: exclusion flags (mutated on shared Message objects) and inserted
# summary messages. Operating on the original list keeps both halves on the list
# the function-invocation tool loop reuses across iterations; otherwise inserted
# summaries would be lost on a throwaway copy while exclusions persisted, silently
# dropping older groups (issue #4991).
working_messages = messages if isinstance(messages, list) else prepared_messages
return await apply_compaction(
prepared_messages,
working_messages,
strategy=compaction_strategy,
tokenizer=tokenizer,
)
Expand Down
41 changes: 34 additions & 7 deletions python/packages/core/agent_framework/_compaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import json
import logging
from collections.abc import Mapping, Sequence
from collections.abc import Iterable, Mapping, Sequence
from typing import (
TYPE_CHECKING,
Any,
Expand Down Expand Up @@ -92,10 +92,23 @@ def _is_reasoning_only_assistant(message: Message) -> bool:
return all(content.type == "text_reasoning" for content in message.contents)


def _ensure_message_ids(messages: list[Message]) -> None:
def _ensure_message_ids(
messages: list[Message], *, id_offset: int = 0, reserved_ids: Iterable[str] | None = None
) -> None:
existing_ids: set[str] = set(reserved_ids) if reserved_ids is not None else set()
existing_ids.update(message.message_id for message in messages if message.message_id)
for index, message in enumerate(messages):
if not message.message_id:
message.message_id = f"msg_{index}"
if message.message_id:
continue
candidate = f"msg_{id_offset + index}"
if candidate in existing_ids:
counter = id_offset + len(messages)
candidate = f"msg_{counter}"
while candidate in existing_ids:
counter += 1
candidate = f"msg_{counter}"
message.message_id = candidate
existing_ids.add(candidate)


def _group_id_for(message: Message, group_index: int) -> str:
Expand All @@ -104,14 +117,27 @@ def _group_id_for(message: Message, group_index: int) -> str:
return f"group_index_{group_index}"


def group_messages(messages: list[Message]) -> list[dict[str, Any]]:
def group_messages(
messages: list[Message], *, id_offset: int = 0, reserved_ids: Iterable[str] | None = None
) -> list[dict[str, Any]]:
"""Compute group spans and metadata for annotation.

Args:
messages: The messages (or a slice of them) to group.

Keyword Args:
id_offset: Absolute starting index used when auto-assigning ``message_id``
values, so incremental annotation of a list slice produces ids that
stay unique across the full list.
reserved_ids: Message ids that already exist outside ``messages`` (for
example in a preserved prefix). Auto-assigned ids are guaranteed not
to collide with these, preventing duplicate ids across the full list.

Returns:
Ordered list of lightweight span dicts with keys:
``group_id``, ``kind``, ``start_index``, ``end_index``, ``has_reasoning``.
"""
_ensure_message_ids(messages)
_ensure_message_ids(messages, id_offset=id_offset, reserved_ids=reserved_ids)
spans: list[dict[str, Any]] = []
i = 0
group_index = 0
Expand Down Expand Up @@ -439,7 +465,8 @@ def annotate_message_groups(
if previous_group_index is not None:
group_index_offset = previous_group_index + 1

spans = group_messages(messages[start_index:])
reserved_ids = {message.message_id for message in messages[:start_index] if message.message_id}
spans = group_messages(messages[start_index:], id_offset=start_index, reserved_ids=reserved_ids)
for span_index, span in enumerate(spans):
group_id = str(span["group_id"])
kind = _coerce_group_kind(span["kind"])
Expand Down
194 changes: 194 additions & 0 deletions python/packages/core/tests/core/test_clients.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,14 @@
GROUP_TOKEN_COUNT_KEY,
BaseChatClient,
ChatResponse,
ChatResponseUpdate,
Content,
Message,
SlidingWindowStrategy,
SupportsChatGetResponse,
ToolResultCompactionStrategy,
TruncationStrategy,
tool,
)


Expand Down Expand Up @@ -258,6 +262,196 @@ async def _capture(
assert captured_token_counts == [[19, 19]]


def _tool_call_response(call_id: str, location: str) -> ChatResponse:
return ChatResponse(
messages=Message(
role="assistant",
contents=[
Content.from_function_call(
call_id=call_id,
name="lookup_weather",
arguments=f'{{"location": "{location}"}}',
)
],
),
response_id=f"resp_{call_id}",
)


def _is_tool_result_summary(message: Message) -> bool:
text = message.text or ""
return message.role == "assistant" and text.startswith("[Tool results:")


async def test_function_loop_persists_inserted_summaries_across_iterations(
chat_client_base: SupportsChatGetResponse,
) -> None:
# Regression test for #4991: compaction inserts summary messages and excludes the
# originals. Across tool-loop iterations the exclusion flags persisted (shared Message
# objects) but the inserted summaries were dropped (they only lived on a throwaway copy),
# so older tool groups were silently lost with no summary representing them.
chat_client_base.function_invocation_configuration["enabled"] = True # type: ignore[attr-defined]
chat_client_base.function_invocation_configuration["max_iterations"] = 3 # type: ignore[attr-defined]
chat_client_base.compaction_strategy = ToolResultCompactionStrategy(keep_last_tool_call_groups=1) # type: ignore[attr-defined]

@tool(name="lookup_weather", approval_mode="never_require")
def lookup_weather(location: str) -> str:
return f"Weather in {location}: sunny"

chat_client_base.run_responses = [ # type: ignore[attr-defined]
_tool_call_response("call_1", "London"),
_tool_call_response("call_2", "Paris"),
_tool_call_response("call_3", "Tokyo"),
]

captured_inputs: list[list[Message]] = []
original = chat_client_base._get_non_streaming_response # type: ignore[attr-defined]

async def _capture(
*,
messages: list[Message],
options: dict[str, Any],
**kwargs: Any,
) -> ChatResponse:
captured_inputs.append(list(messages))
return await original(messages=messages, options=options, **kwargs)

chat_client_base._get_non_streaming_response = _capture # type: ignore[attr-defined,method-assign]

await chat_client_base.get_response(
[Message(role="user", contents=["What is the weather in London?"])],
options={"tools": [lookup_weather]}, # type: ignore[typeddict-unknown-key]
)

# The final model call should represent every compacted tool group with a summary.
# Two older tool groups get collapsed (London, Paris) while the last (Tokyo) is kept.
final_input = captured_inputs[-1]
summaries = [message for message in final_input if _is_tool_result_summary(message)]
summary_text = " ".join(message.text or "" for message in summaries)

assert len(summaries) == 2, [message.text for message in final_input]
assert "London" in summary_text
assert "Paris" in summary_text


def _tool_call_update(call_id: str, location: str) -> list[ChatResponseUpdate]:
return [
ChatResponseUpdate(
contents=[
Content.from_function_call(
call_id=call_id,
name="lookup_weather",
arguments=f'{{"location": "{location}"}}',
)
],
role="assistant",
finish_reason="stop",
response_id=f"resp_{call_id}",
)
]


async def test_function_loop_persists_inserted_summaries_across_iterations_streaming(
chat_client_base: SupportsChatGetResponse,
) -> None:
# Streaming counterpart of the #4991 regression test: the summary persistence fix in
# ``_prepare_messages_for_model_call`` must cover the streaming tool loop too.
chat_client_base.function_invocation_configuration["enabled"] = True # type: ignore[attr-defined]
chat_client_base.function_invocation_configuration["max_iterations"] = 3 # type: ignore[attr-defined]
chat_client_base.compaction_strategy = ToolResultCompactionStrategy(keep_last_tool_call_groups=1) # type: ignore[attr-defined]

@tool(name="lookup_weather", approval_mode="never_require")
def lookup_weather(location: str) -> str:
return f"Weather in {location}: sunny"

chat_client_base.streaming_responses = [ # type: ignore[attr-defined]
_tool_call_update("call_1", "London"),
_tool_call_update("call_2", "Paris"),
_tool_call_update("call_3", "Tokyo"),
]

captured_inputs: list[list[Message]] = []
original = chat_client_base._get_streaming_response # type: ignore[attr-defined]

def _capture(
*,
messages: list[Message],
options: dict[str, Any],
**kwargs: Any,
):
captured_inputs.append(list(messages))
return original(messages=messages, options=options, **kwargs)

chat_client_base._get_streaming_response = _capture # type: ignore[attr-defined,method-assign]

stream = chat_client_base.get_response(
[Message(role="user", contents=["What is the weather in London?"])],
stream=True,
options={"tools": [lookup_weather]}, # type: ignore[typeddict-unknown-key]
)
async for _ in stream:
pass

final_input = captured_inputs[-1]
summaries = [message for message in final_input if _is_tool_result_summary(message)]
summary_text = " ".join(message.text or "" for message in summaries)

assert len(summaries) == 2, [message.text for message in final_input]
assert "London" in summary_text
assert "Paris" in summary_text


async def test_function_loop_compaction_conversation_id_mode_does_not_resend_history(
chat_client_base: SupportsChatGetResponse,
) -> None:
# In conversation-id mode the server owns prior context, so the tool loop clears
# ``prepped_messages`` and only sends the latest message. Compaction must not fight that
# by re-inserting summaries or re-sending earlier turns.
chat_client_base.function_invocation_configuration["enabled"] = True # type: ignore[attr-defined]
chat_client_base.function_invocation_configuration["max_iterations"] = 3 # type: ignore[attr-defined]
chat_client_base.compaction_strategy = ToolResultCompactionStrategy(keep_last_tool_call_groups=1) # type: ignore[attr-defined]

@tool(name="lookup_weather", approval_mode="never_require")
def lookup_weather(location: str) -> str:
return f"Weather in {location}: sunny"

def _conversation_tool_call(call_id: str, location: str) -> ChatResponse:
response = _tool_call_response(call_id, location)
response.conversation_id = "conv_1"
return response

chat_client_base.run_responses = [ # type: ignore[attr-defined]
_conversation_tool_call("call_1", "London"),
_conversation_tool_call("call_2", "Paris"),
_conversation_tool_call("call_3", "Tokyo"),
]

captured_inputs: list[list[Message]] = []
original = chat_client_base._get_non_streaming_response # type: ignore[attr-defined]

async def _capture(
*,
messages: list[Message],
options: dict[str, Any],
**kwargs: Any,
) -> ChatResponse:
captured_inputs.append(list(messages))
return await original(messages=messages, options=options, **kwargs)

chat_client_base._get_non_streaming_response = _capture # type: ignore[attr-defined,method-assign]

await chat_client_base.get_response(
[Message(role="user", contents=["What is the weather in London?"])],
options={"tools": [lookup_weather]}, # type: ignore[typeddict-unknown-key]
)

# After the conversation id is established the loop only forwards the latest message,
# so subsequent model calls never receive the full history or summary messages.
for sent in captured_inputs[1:]:
assert len(sent) <= 1, [message.text for message in sent]
assert not any(_is_tool_result_summary(message) for message in sent)


def test_base_client_as_agent_does_not_copy_client_compaction_defaults(
chat_client_base: SupportsChatGetResponse,
) -> None:
Expand Down
Loading
Loading