From 0de5fed08de994b9ad3640e99e66113ce5b3bab3 Mon Sep 17 00:00:00 2001 From: netan-sa Date: Tue, 2 Jun 2026 23:54:28 +0200 Subject: [PATCH 1/6] Add: add blocking gates for langgraph_stategraph and langchain_executor --- sdk/adrian/__init__.py | 443 +++++++++++++++++++++++++++++++---------- sdk/adrian/ws.py | 12 ++ 2 files changed, 350 insertions(+), 105 deletions(-) diff --git a/sdk/adrian/__init__.py b/sdk/adrian/__init__.py index 0b1ed09..1d312ab 100644 --- a/sdk/adrian/__init__.py +++ b/sdk/adrian/__init__.py @@ -512,6 +512,23 @@ def _inject_callbacks(config: Any) -> Any: # noqa: ANN401 # ------------------------------------------------------------------ +def _warn_unsupported_frameworks() -> None: + """Warn when raw openai SDK is used without a patchable framework.""" + import importlib.util + + if importlib.util.find_spec("openai") is not None: + # Only warn if openai is present without langchain + if importlib.util.find_spec("langchain_core") is None: + logger.warning( + "Detected raw openai SDK without LangChain/LangGraph. " + "Adrian's pre-execution block gate requires a supported " + "framework (LangGraph ToolNode or LangChain AgentExecutor). " + "Tool calls from raw openai loops are OBSERVED but NOT " + "pre-blocked in MODE_BLOCK. " + "See https://docs.adrian.secureagentics.ai/supported-frameworks" + ) + + def _auto_instrument_langchain() -> None: """Apply all monkey-patches to LangChain / LangGraph.""" try: @@ -520,6 +537,8 @@ def _auto_instrument_langchain() -> None: _patch_chat_model() _patch_langgraph() _patch_tool_node() + _patch_agent_executor() + _warn_unsupported_frameworks() logger.debug("LangChain auto-instrumentation applied") except ImportError: logger.debug("LangChain not found, skipping auto-instrumentation") @@ -531,12 +550,14 @@ def _auto_instrument_langchain() -> None: def _patch_runnable() -> None: - """Patch ``Runnable.invoke`` / ``ainvoke`` to inject callbacks.""" + """Patch ``Runnable.invoke`` / ``ainvoke`` / ``astream`` / ``stream``.""" if getattr(Runnable, "_adrian_patched", False): return original_invoke = Runnable.invoke original_ainvoke = Runnable.ainvoke + original_astream = Runnable.astream + original_stream = Runnable.stream def patched_invoke( self: Any, # noqa: ANN401 @@ -544,9 +565,7 @@ def patched_invoke( config: Any = None, # noqa: ANN401 **kwargs: Any, ) -> Any: # noqa: ANN401 - """Inject Adrian callbacks into sync Runnable call.""" config = _inject_callbacks(config) - return original_invoke(self, input, config, **kwargs) async def patched_ainvoke( @@ -555,15 +574,35 @@ async def patched_ainvoke( config: Any = None, # noqa: ANN401 **kwargs: Any, ) -> Any: # noqa: ANN401 - """Inject Adrian callbacks into async Runnable call.""" config = _inject_callbacks(config) - return await original_ainvoke(self, input, config, **kwargs) + async def patched_astream( + self: Any, # noqa: ANN401 + input: Any, # noqa: A002, ANN401 + config: Any = None, # noqa: ANN401 + **kwargs: Any, + ) -> Any: # noqa: ANN401 + """AgentExecutor calls astream on the agent chain by default.""" + config = _inject_callbacks(config) + async for chunk in original_astream(self, input, config, **kwargs): + yield chunk + + def patched_stream( + self: Any, # noqa: ANN401 + input: Any, # noqa: A002, ANN401 + config: Any = None, # noqa: ANN401 + **kwargs: Any, + ) -> Any: # noqa: ANN401 + config = _inject_callbacks(config) + yield from original_stream(self, input, config, **kwargs) + Runnable.invoke = patched_invoke # type: ignore[assignment] Runnable.ainvoke = patched_ainvoke # type: ignore[assignment] + Runnable.astream = patched_astream # type: ignore[assignment] + Runnable.stream = patched_stream # type: ignore[assignment] Runnable._adrian_patched = True # type: ignore[attr-defined] - logger.debug("Patched Runnable.invoke / ainvoke") + logger.debug("Patched Runnable.invoke / ainvoke / astream / stream") # --- 2. CallbackManager --- @@ -634,12 +673,14 @@ def patched_configure( def _patch_chat_model() -> None: - """Patch ``BaseChatModel.invoke`` / ``ainvoke`` to inject callbacks.""" + """Patch ``BaseChatModel.invoke`` / ``ainvoke`` / ``astream`` / ``stream``.""" if getattr(BaseChatModel, "_adrian_chat_model_patched", False): return original_invoke = BaseChatModel.invoke original_ainvoke = BaseChatModel.ainvoke + original_astream = BaseChatModel.astream + original_stream = BaseChatModel.stream def patched_invoke( self: Any, # noqa: ANN401 @@ -647,9 +688,7 @@ def patched_invoke( config: Any = None, # noqa: ANN401 **kwargs: Any, ) -> Any: # noqa: ANN401 - """Inject Adrian callbacks into sync chat model call.""" config = _inject_callbacks(config) - return original_invoke(self, input, config=config, **kwargs) async def patched_ainvoke( @@ -658,15 +697,34 @@ async def patched_ainvoke( config: Any = None, # noqa: ANN401 **kwargs: Any, ) -> Any: # noqa: ANN401 - """Inject Adrian callbacks into async chat model call.""" config = _inject_callbacks(config) - return await original_ainvoke(self, input, config=config, **kwargs) + async def patched_astream( + self: Any, # noqa: ANN401 + input: Any, # noqa: A002, ANN401 + config: Any = None, # noqa: ANN401 + **kwargs: Any, + ) -> Any: # noqa: ANN401 + config = _inject_callbacks(config) + async for chunk in original_astream(self, input, config=config, **kwargs): + yield chunk + + def patched_stream( + self: Any, # noqa: ANN401 + input: Any, # noqa: A002, ANN401 + config: Any = None, # noqa: ANN401 + **kwargs: Any, + ) -> Any: # noqa: ANN401 + config = _inject_callbacks(config) + yield from original_stream(self, input, config=config, **kwargs) + BaseChatModel.invoke = patched_invoke # type: ignore[assignment] BaseChatModel.ainvoke = patched_ainvoke # type: ignore[assignment] + BaseChatModel.astream = patched_astream # type: ignore[assignment] + BaseChatModel.stream = patched_stream # type: ignore[assignment] BaseChatModel._adrian_chat_model_patched = True # type: ignore[attr-defined] - logger.debug("Patched BaseChatModel.invoke / ainvoke") + logger.debug("Patched BaseChatModel.invoke / ainvoke / astream / stream") # --- 4. LangGraph Pregel --- @@ -761,26 +819,22 @@ async def patched_astream( def _extract_tool_calls( - state: dict[str, Any] | list[BaseMessage], + state: dict[str, Any] | list[BaseMessage] | Any, ) -> list[dict[str, str]]: - """Extract tool_calls from the last AIMessage in ToolNode state. - - LangGraph's ``ToolNode.ainvoke`` accepts two input shapes: a state - dict whose ``"messages"`` key holds the message list, or a bare - list of messages. We handle both. - - Args: - state: The ToolNode input, a state dict with a ``"messages"`` - key, or a direct list of ``BaseMessage`` instances. + """Extract tool_calls from ToolNode input (state dict, message list, or per-tool-call dict).""" + if isinstance(state, dict) and "tool_call" in state: + tc = state["tool_call"] + if isinstance(tc, dict) and tc.get("id"): + return [tc] + if hasattr(tc, "id") and tc.id: + return [{"id": tc.id, "name": getattr(tc, "name", ""), "args": getattr(tc, "args", {})}] - Returns: - List of tool call dicts from the most recent ``AIMessage``, or - an empty list when none is found. - """ if isinstance(state, dict): messages = list(state.get("messages") or []) # pyright: ignore[reportUnknownVariableType, reportUnknownArgumentType] - else: + elif isinstance(state, list): messages = list(state) + else: + return [] for msg in reversed(messages): if isinstance(msg, AIMessage) and getattr(msg, "tool_calls", None): @@ -792,10 +846,7 @@ def _extract_tool_calls( def _should_halt(verdict: pb.Verdict) -> bool: """Decide whether a verdict should halt tool execution. - HITL resolutions override everything: ``continue_execution=False`` - means halt, ``True`` means continue. Otherwise the per-MAD policy - bool is the sole scope authority, if the verdict's tier is - in-scope, halt; if not, continue. + HITL resolutions override the per-MAD policy scope check. """ if verdict.HasField("hitl"): return not verdict.hitl.continue_execution @@ -814,14 +865,7 @@ def _should_halt(verdict: pb.Verdict) -> bool: def _build_blocked_response( tool_calls: list[dict[str, str]], ) -> dict[str, list[ToolMessage]]: - """Build synthetic ToolMessage responses for blocked tool calls. - - Args: - tool_calls: List of tool call dicts extracted from the AIMessage. - - Returns: - Dict in the format ToolNode expects. - """ + """Build synthetic ToolMessage responses for blocked tool calls.""" blocked_messages: list[ToolMessage] = [ ToolMessage( content="[BLOCKED by security policy]", @@ -834,13 +878,67 @@ def _build_blocked_response( return {"messages": blocked_messages} +async def _adrian_tool_gate( + input: Any, # noqa: A002, ANN401 +) -> tuple[str, dict[str, Any] | None]: + """Pre-execution verdict gate. Returns ("halt", response), ("proceed", None), or ("skip", None).""" + ws = _ws_client + + if ws is None: + return ("skip", None) + + if not ws._login_ack_received.is_set(): # pyright: ignore[reportPrivateUsage] + try: + await asyncio.wait_for( + ws._login_ack_received.wait(), # pyright: ignore[reportPrivateUsage] + timeout=5.0, + ) + except TimeoutError: + logger.warning( + "ToolNode: LoginAck not received within 5s; halting " + "(refusing to run a tool without a verified policy)" + ) + return ("halt", _build_blocked_response(_extract_tool_calls(input))) + + if not ws.policy_active(): + return ("skip", None) + + tool_calls = _extract_tool_calls(input) + tool_call_id = next( + (tc.get("id") for tc in tool_calls if tc.get("id")), + None, + ) + + if not tool_call_id: + return ("skip", None) + + cfg = _get_config() + timeout = ws.block_timeout(cfg.block_timeout if cfg else 30.0) + + verdict = await ws.wait_for_tool_call_verdict(tool_call_id, timeout) + + if verdict is None: + logger.warning( + "verdict timeout for tool_call_id=%s, fail-open", + tool_call_id, + ) + return ("skip", None) + + if _should_halt(verdict): + logger.warning( + "halting tool execution for event_id=%s mad_code=%s", + verdict.event_id, + verdict.mad_code, + ) + return ("halt", _build_blocked_response(tool_calls)) + + return ("proceed", None) + + def _patch_tool_node() -> None: - """Patch ``ToolNode.invoke`` / ``ainvoke``. + """Patch ToolNode._afunc with the verdict gate, and public methods for callback injection. - In block mode, the async patch waits for the preceding LLM's verdict - before executing tools. On BLOCK (unless overridden by ``on_block``) - it returns synthetic ``ToolMessage`` responses instead of running the - tools. On timeout it fails open. + _afunc is the only reliable intercept -- Pregel bypasses ainvoke/astream entirely. """ try: from langgraph.prebuilt import ToolNode @@ -852,6 +950,22 @@ def _patch_tool_node() -> None: original_invoke = ToolNode.invoke original_ainvoke = ToolNode.ainvoke + original_astream = getattr(ToolNode, "astream", None) + original_stream = getattr(ToolNode, "stream", None) + original_afunc = ToolNode._afunc # type: ignore[attr-defined] + + async def patched_afunc( + self: Any, # noqa: ANN401 + input: Any, # noqa: A002, ANN401 + config: Any = None, # noqa: ANN401 + runtime: Any = None, # noqa: ANN401 + ) -> Any: # noqa: ANN401 + """Verdict gate on ToolNode._afunc.""" + decision, blocked = await _adrian_tool_gate(input) + if decision == "halt": + return blocked + + return await original_afunc(self, input, config=config, runtime=runtime) def patched_invoke( self: Any, # noqa: ANN401 @@ -859,7 +973,7 @@ def patched_invoke( config: Any = None, # noqa: ANN401 **kwargs: Any, ) -> Any: # noqa: ANN401 - """Inject Adrian callbacks into sync ToolNode invocation.""" + """Inject Adrian callbacks into ToolNode.invoke.""" config = _inject_callbacks(config) return original_invoke(self, input, config=config, **kwargs) @@ -870,75 +984,194 @@ async def patched_ainvoke( config: Any = None, # noqa: ANN401 **kwargs: Any, ) -> Any: # noqa: ANN401 - """Inject Adrian callbacks; in BLOCK / HITL modes wait for verdict. + """Inject Adrian callbacks into ToolNode.ainvoke.""" + config = _inject_callbacks(config) - Per-tool-call correlation: every tool_call.id is mapped (in - ``WebSocketClient`` ) to the event_id of the LLM that emitted - it. Each ToolNode invocation awaits its specific LLM's verdict, - race-free under parallel agents, no graph-wide pause. - """ + return await original_ainvoke(self, input, config=config, **kwargs) + + async def patched_astream( + self: Any, # noqa: ANN401 + input: Any, # noqa: A002, ANN401 + config: Any = None, # noqa: ANN401 + **kwargs: Any, + ) -> Any: # noqa: ANN401 + """Inject Adrian callbacks into ToolNode.astream.""" config = _inject_callbacks(config) - ws = _ws_client - if ws is None: - return await original_ainvoke(self, input, config=config, **kwargs) + async for chunk in original_astream(self, input, config=config, **kwargs): + yield chunk - # First-tool-call window: the recv loop may not have processed - # ``LoginAck`` yet, so ``policy_active()`` reads False even - # when the org is in BLOCK or HITL. Wait for the LoginAck - # event before checking. If it doesn't arrive within the - # window, halt, refusing to run is the only safe outcome - # when we can't verify the org's policy. - if not ws._login_ack_received.is_set(): # pyright: ignore[reportPrivateUsage] - try: - await asyncio.wait_for( - ws._login_ack_received.wait(), # pyright: ignore[reportPrivateUsage] - timeout=5.0, - ) - except TimeoutError: - logger.warning( - "ToolNode: LoginAck not received within 5s; halting " - "(refusing to run a tool without a verified policy)" - ) - return _build_blocked_response(_extract_tool_calls(input)) - - if not ws.policy_active(): - return await original_ainvoke(self, input, config=config, **kwargs) + ToolNode._afunc = patched_afunc # type: ignore[attr-defined] + ToolNode.invoke = patched_invoke # type: ignore[assignment] + ToolNode.ainvoke = patched_ainvoke # type: ignore[assignment] + if original_astream is not None: + ToolNode.astream = patched_astream # type: ignore[assignment] + ToolNode._adrian_tool_node_patched = True # type: ignore[attr-defined] + logger.debug("Patched ToolNode._afunc / invoke / ainvoke / astream") - tool_calls = _extract_tool_calls(input) - tool_call_id = next( - (tc.get("id") for tc in tool_calls if tc.get("id")), - None, - ) - if not tool_call_id: - # Direct ToolNode invocation outside an LLM flow, no - # producing event_id to wait on, so let the tool run. - return await original_ainvoke(self, input, config=config, **kwargs) +# --- 6. AgentExecutor (langchain / langchain-classic) --- - cfg = _get_config() - timeout = ws.block_timeout(cfg.block_timeout if cfg else 30.0) - verdict = await ws.wait_for_tool_call_verdict(tool_call_id, timeout) +_BLOCKED_OBSERVATION = "[BLOCKED by security policy]" - if verdict is None: - logger.warning( - "verdict timeout for tool_call_id=%s, fail-open", - tool_call_id, - ) - return await original_ainvoke(self, input, config=config, **kwargs) - if _should_halt(verdict): - logger.warning( - "halting tool execution for event_id=%s mad_code=%s", - verdict.event_id, - verdict.mad_code, - ) - return _build_blocked_response(tool_calls) +def _patch_agent_executor() -> None: + """Patch AgentExecutor tool dispatch with the verdict gate. - return await original_ainvoke(self, input, config=config, **kwargs) + Covers the legacy AgentExecutor path which bypasses ToolNode entirely. + Falls through for ReAct parsers that don't emit tool_call_id. + """ + AgentExecutor = None + AgentStep = None + for mod_path in ("langchain_classic.agents.agent", "langchain.agents.agent"): + try: + mod = __import__(mod_path, fromlist=["AgentExecutor", "AgentStep"]) + AgentExecutor = getattr(mod, "AgentExecutor", None) + AgentStep = getattr(mod, "AgentStep", None) + if AgentExecutor and AgentStep: + break + except ImportError: + continue + + if AgentExecutor is None or AgentStep is None: + return - ToolNode.invoke = patched_invoke # type: ignore[assignment] - ToolNode.ainvoke = patched_ainvoke # type: ignore[assignment] - ToolNode._adrian_tool_node_patched = True # type: ignore[attr-defined] - logger.debug("Patched ToolNode.invoke / ainvoke") + if getattr(AgentExecutor, "_adrian_executor_patched", False): + return + + original_aperform = AgentExecutor._aperform_agent_action + original_perform = AgentExecutor._perform_agent_action + + async def patched_aperform( + self: Any, # noqa: ANN401 + name_to_tool_map: Any, # noqa: ANN401 + color_mapping: Any, # noqa: ANN401 + agent_action: Any, # noqa: ANN401 + run_manager: Any = None, # noqa: ANN401 + ) -> Any: # noqa: ANN401 + """Verdict gate before AgentExecutor dispatches a tool (async).""" + tool_call_id = getattr(agent_action, "tool_call_id", None) + + if tool_call_id: + ws = _ws_client + + if ws is not None: + if not ws._login_ack_received.is_set(): # pyright: ignore[reportPrivateUsage] + try: + await asyncio.wait_for( + ws._login_ack_received.wait(), # pyright: ignore[reportPrivateUsage] + timeout=5.0, + ) + except TimeoutError: + logger.warning( + "AgentExecutor: LoginAck not received within 5s; " + "blocking tool %s", + agent_action.tool, + ) + return AgentStep( + action=agent_action, + observation=_BLOCKED_OBSERVATION, + ) + + if ws.policy_active(): + cfg = _get_config() + # Short timeout: AgentExecutor LLM callbacks may not propagate, + # so verdicts may never arrive. + cfg = _get_config() + timeout = ws.block_timeout(cfg.block_timeout if cfg else 30.0) + verdict = await ws.wait_for_tool_call_verdict( + tool_call_id, timeout, + ) + + if verdict is not None and _should_halt(verdict): + logger.warning( + "halting tool execution for event_id=%s " + "mad_code=%s (AgentExecutor path)", + verdict.event_id, + verdict.mad_code, + ) + return AgentStep( + action=agent_action, + observation=_BLOCKED_OBSERVATION, + ) + + if verdict is None: + logger.warning( + "AgentExecutor: verdict timeout for " + "tool_call_id=%s, fail-open", + tool_call_id, + ) + + return await original_aperform( + self, name_to_tool_map, color_mapping, agent_action, run_manager, + ) + + def patched_perform( + self: Any, # noqa: ANN401 + name_to_tool_map: Any, # noqa: ANN401 + color_mapping: Any, # noqa: ANN401 + agent_action: Any, # noqa: ANN401 + run_manager: Any = None, # noqa: ANN401 + ) -> Any: # noqa: ANN401 + """Verdict gate before AgentExecutor dispatches a tool (sync).""" + tool_call_id = getattr(agent_action, "tool_call_id", None) + + if tool_call_id: + ws = _ws_client + + if ws is not None and ws._login_ack_received.is_set() and ws.policy_active(): # pyright: ignore[reportPrivateUsage] + import concurrent.futures + + async def _gate() -> bool: + cfg = _get_config() + timeout = ws.block_timeout(cfg.block_timeout if cfg else 30.0) + verdict = await ws.wait_for_tool_call_verdict( + tool_call_id, timeout, + ) + if verdict is not None and _should_halt(verdict): + logger.warning( + "halting tool execution for event_id=%s " + "mad_code=%s (AgentExecutor sync path)", + verdict.event_id, + verdict.mad_code, + ) + return True + return False + + try: + loop = asyncio.get_event_loop() + if loop.is_running(): + future: concurrent.futures.Future[bool] = concurrent.futures.Future() + + async def _run() -> None: + try: + result = await _gate() + future.set_result(result) + except Exception as exc: + future.set_exception(exc) + + loop.create_task(_run()) + should_block = future.result(timeout=35) + else: + should_block = loop.run_until_complete(_gate()) + + if should_block: + return AgentStep( + action=agent_action, + observation=_BLOCKED_OBSERVATION, + ) + except Exception: + logger.debug( + "AgentExecutor sync gate failed, falling through", + exc_info=True, + ) + + return original_perform( + self, name_to_tool_map, color_mapping, agent_action, run_manager, + ) + + AgentExecutor._aperform_agent_action = patched_aperform # type: ignore[assignment] + AgentExecutor._perform_agent_action = patched_perform # type: ignore[assignment] + AgentExecutor._adrian_executor_patched = True # type: ignore[attr-defined] + logger.debug("Patched AgentExecutor._aperform_agent_action / _perform_agent_action") diff --git a/sdk/adrian/ws.py b/sdk/adrian/ws.py index 1ab5df4..30f1ab5 100644 --- a/sdk/adrian/ws.py +++ b/sdk/adrian/ws.py @@ -513,6 +513,18 @@ async def connect(self) -> None: else: logger.info("WebSocket connected: %s", self._url) + # Eager login: send the SessionLogin frame immediately + # so the server responds with LoginAck before any tool + # gate fires. Previously login was deferred to the + # first _send_frame call, which meant frameworks that + # don't trigger paired events (AgentExecutor) would + # never receive LoginAck and the block gate would time + # out. Provider/model are best-effort at this point + # (empty until the first LLM event auto-detects them). + if not self._logged_in: + await self._send_login(self._ws) + self._logged_in = True + # Drain anything buffered while we were offline, even # on the very first connect. ``_send_mcp_inventory`` # and other init-time emitters queue frames before the From 912456d1a5ccae9b6e1c0327c3195e99aa7831fa Mon Sep 17 00:00:00 2001 From: netan-sa Date: Tue, 9 Jun 2026 14:48:10 +0200 Subject: [PATCH 2/6] Add: agent executor flow, and fixes --- sdk/adrian/__init__.py | 561 ++++++++++++++++++++--------------------- sdk/adrian/ws.py | 12 - 2 files changed, 276 insertions(+), 297 deletions(-) diff --git a/sdk/adrian/__init__.py b/sdk/adrian/__init__.py index 1d312ab..1b55d83 100644 --- a/sdk/adrian/__init__.py +++ b/sdk/adrian/__init__.py @@ -74,7 +74,7 @@ from adrian.types import ToolCallRecord, VerdictContext from adrian.ws import WebSocketClient -__version__ = "1.0.0" +__version__ = "1.0.2" __all__ = [ "init", "shutdown", @@ -231,10 +231,12 @@ def init( resolved_key = api_key or os.getenv("ADRIAN_API_KEY") or None resolved_file = Path(os.getenv("ADRIAN_LOG_FILE", str(log_file))) - # Default to a local self-hosted backend (the one `make dev` brings - # up at deploy/compose.yaml). OSS users pointing at a remote - # deployment override via ws_url= or ADRIAN_WS_URL. - resolved_ws_url = os.getenv("ADRIAN_WS_URL") or ws_url or "ws://localhost:8080/ws" + # Default to the hosted Adrian backend so `adrian.init(api_key=...)` + # Just Works for freemium users. Self-hosted users override via + # ws_url= or ADRIAN_WS_URL. + resolved_ws_url = ( + os.getenv("ADRIAN_WS_URL") or ws_url or "wss://adrian.secureagentics.ai/ws" + ) resolved_session = ( os.getenv("ADRIAN_SESSION_ID") or session_id or resolve_session_id() ) @@ -512,23 +514,6 @@ def _inject_callbacks(config: Any) -> Any: # noqa: ANN401 # ------------------------------------------------------------------ -def _warn_unsupported_frameworks() -> None: - """Warn when raw openai SDK is used without a patchable framework.""" - import importlib.util - - if importlib.util.find_spec("openai") is not None: - # Only warn if openai is present without langchain - if importlib.util.find_spec("langchain_core") is None: - logger.warning( - "Detected raw openai SDK without LangChain/LangGraph. " - "Adrian's pre-execution block gate requires a supported " - "framework (LangGraph ToolNode or LangChain AgentExecutor). " - "Tool calls from raw openai loops are OBSERVED but NOT " - "pre-blocked in MODE_BLOCK. " - "See https://docs.adrian.secureagentics.ai/supported-frameworks" - ) - - def _auto_instrument_langchain() -> None: """Apply all monkey-patches to LangChain / LangGraph.""" try: @@ -537,8 +522,8 @@ def _auto_instrument_langchain() -> None: _patch_chat_model() _patch_langgraph() _patch_tool_node() + _patch_base_tool() _patch_agent_executor() - _warn_unsupported_frameworks() logger.debug("LangChain auto-instrumentation applied") except ImportError: logger.debug("LangChain not found, skipping auto-instrumentation") @@ -583,7 +568,6 @@ async def patched_astream( config: Any = None, # noqa: ANN401 **kwargs: Any, ) -> Any: # noqa: ANN401 - """AgentExecutor calls astream on the agent chain by default.""" config = _inject_callbacks(config) async for chunk in original_astream(self, input, config, **kwargs): yield chunk @@ -602,7 +586,7 @@ def patched_stream( Runnable.astream = patched_astream # type: ignore[assignment] Runnable.stream = patched_stream # type: ignore[assignment] Runnable._adrian_patched = True # type: ignore[attr-defined] - logger.debug("Patched Runnable.invoke / ainvoke / astream / stream") + logger.debug("Patched Runnable.invoke / ainvoke") # --- 2. CallbackManager --- @@ -724,7 +708,7 @@ def patched_stream( BaseChatModel.astream = patched_astream # type: ignore[assignment] BaseChatModel.stream = patched_stream # type: ignore[assignment] BaseChatModel._adrian_chat_model_patched = True # type: ignore[attr-defined] - logger.debug("Patched BaseChatModel.invoke / ainvoke / astream / stream") + logger.debug("Patched BaseChatModel.invoke / ainvoke") # --- 4. LangGraph Pregel --- @@ -815,157 +799,256 @@ async def patched_astream( logger.debug("Patched Pregel.invoke / ainvoke / astream") -# --- 5. ToolNode --- - - -def _extract_tool_calls( - state: dict[str, Any] | list[BaseMessage] | Any, -) -> list[dict[str, str]]: - """Extract tool_calls from ToolNode input (state dict, message list, or per-tool-call dict).""" - if isinstance(state, dict) and "tool_call" in state: - tc = state["tool_call"] - if isinstance(tc, dict) and tc.get("id"): - return [tc] - if hasattr(tc, "id") and tc.id: - return [{"id": tc.id, "name": getattr(tc, "name", ""), "args": getattr(tc, "args", {})}] - - if isinstance(state, dict): - messages = list(state.get("messages") or []) # pyright: ignore[reportUnknownVariableType, reportUnknownArgumentType] - elif isinstance(state, list): - messages = list(state) - else: - return [] - - for msg in reversed(messages): - if isinstance(msg, AIMessage) and getattr(msg, "tool_calls", None): - return msg.tool_calls # type: ignore[no-any-return] - - return [] +# --- 5. ToolNode (callback injection only — gate is on BaseTool) --- def _should_halt(verdict: pb.Verdict) -> bool: """Decide whether a verdict should halt tool execution. - HITL resolutions override the per-MAD policy scope check. + HITL resolutions override per-MAD policy when present. """ if verdict.HasField("hitl"): return not verdict.hitl.continue_execution mad_prefix = verdict.mad_code[:2] - in_scope = { + return { "M0": verdict.policy.policy_m0, "M2": verdict.policy.policy_m2, "M3": verdict.policy.policy_m3, "M4": verdict.policy.policy_m4, }.get(mad_prefix, False) - return in_scope - -def _build_blocked_response( - tool_calls: list[dict[str, str]], -) -> dict[str, list[ToolMessage]]: - """Build synthetic ToolMessage responses for blocked tool calls.""" - blocked_messages: list[ToolMessage] = [ - ToolMessage( - content="[BLOCKED by security policy]", - tool_call_id=str(tc.get("id", "")), - name=str(tc.get("name", "")), - ) - for tc in tool_calls - ] +def _patch_tool_node() -> None: + """Patch ToolNode for callback injection + async verdict gate. - return {"messages": blocked_messages} + ToolNode dispatches tools via tool.invoke (sync) even within async + Pregel. BaseTool.invoke can't await a verdict from the event loop + thread, so we add the verdict gate here on ToolNode.ainvoke — the + entry point Pregel calls before tool dispatch begins. This is a + complementary gate to BaseTool (which covers direct callers). + """ + try: + from langgraph.prebuilt import ToolNode + except ImportError: + return + if getattr(ToolNode, "_adrian_tool_node_patched", False): + return -async def _adrian_tool_gate( - input: Any, # noqa: A002, ANN401 -) -> tuple[str, dict[str, Any] | None]: - """Pre-execution verdict gate. Returns ("halt", response), ("proceed", None), or ("skip", None).""" - ws = _ws_client + original_invoke = ToolNode.invoke + original_ainvoke = ToolNode.ainvoke + original_astream = getattr(ToolNode, "astream", None) - if ws is None: - return ("skip", None) + def _extract_tool_call_ids(state: Any) -> list[str]: # noqa: ANN401 + """Extract tool_call_ids from ToolNode input (any shape).""" + # Shape 3: per-tool-call dict from _afunc dispatch + if isinstance(state, dict) and "tool_call" in state: + tc = state["tool_call"] + tc_id = tc.get("id") if isinstance(tc, dict) else getattr(tc, "id", None) + return [tc_id] if tc_id else [] + # Shape 1/2: state dict or message list + messages = ( + list(state.get("messages") or []) + if isinstance(state, dict) + else list(state) + if isinstance(state, list) + else [] + ) + for msg in reversed(messages): + if isinstance(msg, AIMessage) and getattr(msg, "tool_calls", None): + return [tc.get("id") for tc in msg.tool_calls if tc.get("id")] + return [] - if not ws._login_ack_received.is_set(): # pyright: ignore[reportPrivateUsage] - try: - await asyncio.wait_for( - ws._login_ack_received.wait(), # pyright: ignore[reportPrivateUsage] - timeout=5.0, - ) - except TimeoutError: + async def _gate_tool_calls(state: Any) -> bool: # noqa: ANN401 + """Returns True if tools should be BLOCKED.""" + ws = _ws_client + if ws is None: + return False + if not ws._login_ack_received.is_set(): # pyright: ignore[reportPrivateUsage] + try: + await asyncio.wait_for(ws._login_ack_received.wait(), timeout=5.0) # pyright: ignore[reportPrivateUsage] + except TimeoutError: + logger.warning("ToolNode: LoginAck not received within 5s; blocking") + return True + if not ws.policy_active(): + return False + + tc_ids = _extract_tool_call_ids(state) + if not tc_ids: + return False + + cfg = _get_config() + timeout = ws.block_timeout(cfg.block_timeout if cfg else 30.0) + # Gate on the first tool_call_id (all come from the same LLM turn) + verdict = await ws.wait_for_tool_call_verdict(tc_ids[0], timeout) + if verdict is None: + logger.warning("ToolNode: verdict timeout, blocking (fail-closed)") + return True + if _should_halt(verdict): logger.warning( - "ToolNode: LoginAck not received within 5s; halting " - "(refusing to run a tool without a verified policy)" + "halting tool execution for event_id=%s mad_code=%s", + verdict.event_id, + verdict.mad_code, ) - return ("halt", _build_blocked_response(_extract_tool_calls(input))) + return True + return False + + def _build_blocked(state: Any) -> dict[str, list[ToolMessage]]: # noqa: ANN401 + tc_ids = _extract_tool_call_ids(state) + return { + "messages": [ + ToolMessage( + content="[BLOCKED by security policy]", tool_call_id=tid, name="" + ) + for tid in tc_ids + ] + } - if not ws.policy_active(): - return ("skip", None) + def patched_invoke( + self: Any, + input: Any, + config: Any = None, + **kwargs: Any, # noqa: A002, ANN401 + ) -> Any: # noqa: ANN401 + config = _inject_callbacks(config) + return original_invoke(self, input, config=config, **kwargs) - tool_calls = _extract_tool_calls(input) - tool_call_id = next( - (tc.get("id") for tc in tool_calls if tc.get("id")), - None, - ) + async def patched_ainvoke( + self: Any, + input: Any, + config: Any = None, + **kwargs: Any, # noqa: A002, ANN401 + ) -> Any: # noqa: ANN401 + config = _inject_callbacks(config) + if await _gate_tool_calls(input): + return _build_blocked(input) + return await original_ainvoke(self, input, config=config, **kwargs) + + async def patched_astream( + self: Any, + input: Any, + config: Any = None, + **kwargs: Any, # noqa: A002, ANN401 + ) -> Any: # noqa: ANN401 + config = _inject_callbacks(config) + if await _gate_tool_calls(input): + yield _build_blocked(input) + return + async for chunk in original_astream(self, input, config=config, **kwargs): + yield chunk - if not tool_call_id: - return ("skip", None) + ToolNode.invoke = patched_invoke # type: ignore[assignment] + ToolNode.ainvoke = patched_ainvoke # type: ignore[assignment] + if original_astream is not None: + ToolNode.astream = patched_astream # type: ignore[assignment] + ToolNode._adrian_tool_node_patched = True # type: ignore[attr-defined] + logger.debug("Patched ToolNode.invoke / ainvoke / astream") - cfg = _get_config() - timeout = ws.block_timeout(cfg.block_timeout if cfg else 30.0) - verdict = await ws.wait_for_tool_call_verdict(tool_call_id, timeout) +# --- 6. BaseTool (universal verdict gate) --- - if verdict is None: - logger.warning( - "verdict timeout for tool_call_id=%s, fail-open", - tool_call_id, - ) - return ("skip", None) - if _should_halt(verdict): - logger.warning( - "halting tool execution for event_id=%s mad_code=%s", - verdict.event_id, - verdict.mad_code, - ) - return ("halt", _build_blocked_response(tool_calls)) +_BLOCKED_CONTENT = "[BLOCKED by security policy]" - return ("proceed", None) +def _patch_base_tool() -> None: + """Patch ``BaseTool.invoke`` and ``BaseTool.ainvoke`` with the verdict gate. -def _patch_tool_node() -> None: - """Patch ToolNode._afunc with the verdict gate, and public methods for callback injection. + Every LangChain tool — whether dispatched by ToolNode, AgentExecutor, + create_react_agent, or a manual ``tool.invoke(tool_call)`` loop — + funnels through ``BaseTool.invoke`` (sync) or ``BaseTool.ainvoke`` + (async). Gating here covers all frameworks in one place. + + The gate extracts ``tool_call_id`` from the input (a ``ToolCall`` + TypedDict), awaits the classifier verdict for the producing LLM + event, and returns a ``[BLOCKED]`` string instead of running the + tool body when the verdict is in-scope (M3/M4 under MODE_BLOCK). - _afunc is the only reliable intercept -- Pregel bypasses ainvoke/astream entirely. + In MODE_BLOCK, verdict timeout is fail-closed (block the tool) + because the absence of a verdict in block mode is a policy violation. + In MODE_ALERT, no gate fires at all (skip). """ - try: - from langgraph.prebuilt import ToolNode - except ImportError: - return + from langchain_core.tools import BaseTool + from langchain_core.tools.base import _is_tool_call # pyright: ignore[reportPrivateUsage] - if getattr(ToolNode, "_adrian_tool_node_patched", False): + if getattr(BaseTool, "_adrian_base_tool_patched", False): return - original_invoke = ToolNode.invoke - original_ainvoke = ToolNode.ainvoke - original_astream = getattr(ToolNode, "astream", None) - original_stream = getattr(ToolNode, "stream", None) - original_afunc = ToolNode._afunc # type: ignore[attr-defined] + original_invoke = BaseTool.invoke + original_ainvoke = BaseTool.ainvoke - async def patched_afunc( - self: Any, # noqa: ANN401 - input: Any, # noqa: A002, ANN401 - config: Any = None, # noqa: ANN401 - runtime: Any = None, # noqa: ANN401 - ) -> Any: # noqa: ANN401 - """Verdict gate on ToolNode._afunc.""" - decision, blocked = await _adrian_tool_gate(input) - if decision == "halt": - return blocked + def _extract_tool_call_id(input: Any) -> str | None: # noqa: A002, ANN401 + """Extract tool_call_id from a ToolCall input, or None.""" + if isinstance(input, dict) and _is_tool_call(input): + return input.get("id") + return None + + async def _async_gate(tool_call_id: str) -> bool: + """Returns True if the tool should be BLOCKED.""" + ws = _ws_client + if ws is None: + return False + + if not ws._login_ack_received.is_set(): # pyright: ignore[reportPrivateUsage] + try: + await asyncio.wait_for( + ws._login_ack_received.wait(), # pyright: ignore[reportPrivateUsage] + timeout=5.0, + ) + except TimeoutError: + logger.warning( + "BaseTool: LoginAck not received within 5s; " + "blocking tool (refusing to run without verified policy)" + ) + return True + + if not ws.policy_active(): + return False + + cfg = _get_config() + timeout = ws.block_timeout(cfg.block_timeout if cfg else 30.0) + verdict = await ws.wait_for_tool_call_verdict(tool_call_id, timeout) + + if verdict is None: + # Fail-closed in block mode: no verdict = block. + logger.warning( + "BaseTool: verdict timeout for tool_call_id=%s; " + "blocking (fail-closed in MODE_BLOCK)", + tool_call_id, + ) + return True + + if _should_halt(verdict): + logger.warning( + "halting tool execution for event_id=%s mad_code=%s", + verdict.event_id, + verdict.mad_code, + ) + return True + + return False + + def _sync_gate(tool_call_id: str) -> bool: + """Sync verdict gate for pure-sync callers (no running event loop). + + When called from within an async loop (ToolNode._func dispatched + by Pregel), this cannot work — use the ToolNode.ainvoke gate + instead. Returns False (skip) when a running loop is detected. + """ + ws = _ws_client + if ws is None or not ws._login_ack_received.is_set() or not ws.policy_active(): # pyright: ignore[reportPrivateUsage] + return False - return await original_afunc(self, input, config=config, runtime=runtime) + try: + loop = asyncio.get_event_loop() + if loop.is_running(): + # Can't block the event loop thread. The ToolNode.ainvoke + # gate handles this path. + return False + return loop.run_until_complete(_async_gate(tool_call_id)) + except RuntimeError: + return False def patched_invoke( self: Any, # noqa: ANN401 @@ -973,9 +1056,10 @@ def patched_invoke( config: Any = None, # noqa: ANN401 **kwargs: Any, ) -> Any: # noqa: ANN401 - """Inject Adrian callbacks into ToolNode.invoke.""" config = _inject_callbacks(config) - + tc_id = _extract_tool_call_id(input) + if tc_id and _sync_gate(tc_id): + return _BLOCKED_CONTENT return original_invoke(self, input, config=config, **kwargs) async def patched_ainvoke( @@ -984,43 +1068,46 @@ async def patched_ainvoke( config: Any = None, # noqa: ANN401 **kwargs: Any, ) -> Any: # noqa: ANN401 - """Inject Adrian callbacks into ToolNode.ainvoke.""" config = _inject_callbacks(config) - + tc_id = _extract_tool_call_id(input) + if tc_id and await _async_gate(tc_id): + return _BLOCKED_CONTENT return await original_ainvoke(self, input, config=config, **kwargs) - async def patched_astream( + original_arun = BaseTool.arun + + async def patched_arun( self: Any, # noqa: ANN401 - input: Any, # noqa: A002, ANN401 - config: Any = None, # noqa: ANN401 + tool_input: Any, # noqa: ANN401 + *args: Any, + tool_call_id: str | None = None, **kwargs: Any, ) -> Any: # noqa: ANN401 - """Inject Adrian callbacks into ToolNode.astream.""" - config = _inject_callbacks(config) - - async for chunk in original_astream(self, input, config=config, **kwargs): - yield chunk - - ToolNode._afunc = patched_afunc # type: ignore[attr-defined] - ToolNode.invoke = patched_invoke # type: ignore[assignment] - ToolNode.ainvoke = patched_ainvoke # type: ignore[assignment] - if original_astream is not None: - ToolNode.astream = patched_astream # type: ignore[assignment] - ToolNode._adrian_tool_node_patched = True # type: ignore[attr-defined] - logger.debug("Patched ToolNode._afunc / invoke / ainvoke / astream") - + """Gate on arun — AgentExecutor calls tool.arun directly.""" + if tool_call_id and await _async_gate(tool_call_id): + return _BLOCKED_CONTENT + return await original_arun( + self, tool_input, *args, tool_call_id=tool_call_id, **kwargs + ) -# --- 6. AgentExecutor (langchain / langchain-classic) --- + BaseTool.invoke = patched_invoke # type: ignore[assignment] + BaseTool.ainvoke = patched_ainvoke # type: ignore[assignment] + BaseTool.arun = patched_arun # type: ignore[assignment] + BaseTool._adrian_base_tool_patched = True # type: ignore[attr-defined] + logger.debug("Patched BaseTool.invoke / ainvoke / arun (universal verdict gate)") -_BLOCKED_OBSERVATION = "[BLOCKED by security policy]" +# --- 7. AgentExecutor (tool_call_id on agent_action, not on tool.arun) --- def _patch_agent_executor() -> None: - """Patch AgentExecutor tool dispatch with the verdict gate. + """Patch AgentExecutor._aperform_agent_action for the executor path. - Covers the legacy AgentExecutor path which bypasses ToolNode entirely. - Falls through for ReAct parsers that don't emit tool_call_id. + AgentExecutor calls tool.arun without forwarding tool_call_id, + so the BaseTool.arun gate can't extract it. The tool_call_id lives + on agent_action.tool_call_id (set by OpenAI-style parsers). We + intercept here, await the verdict, and return a blocked observation + instead of calling the tool. """ AgentExecutor = None AgentStep = None @@ -1036,142 +1123,46 @@ def _patch_agent_executor() -> None: if AgentExecutor is None or AgentStep is None: return - if getattr(AgentExecutor, "_adrian_executor_patched", False): return original_aperform = AgentExecutor._aperform_agent_action - original_perform = AgentExecutor._perform_agent_action async def patched_aperform( - self: Any, # noqa: ANN401 - name_to_tool_map: Any, # noqa: ANN401 - color_mapping: Any, # noqa: ANN401 - agent_action: Any, # noqa: ANN401 - run_manager: Any = None, # noqa: ANN401 - ) -> Any: # noqa: ANN401 - """Verdict gate before AgentExecutor dispatches a tool (async).""" - tool_call_id = getattr(agent_action, "tool_call_id", None) - - if tool_call_id: - ws = _ws_client - - if ws is not None: - if not ws._login_ack_received.is_set(): # pyright: ignore[reportPrivateUsage] - try: - await asyncio.wait_for( - ws._login_ack_received.wait(), # pyright: ignore[reportPrivateUsage] - timeout=5.0, - ) - except TimeoutError: - logger.warning( - "AgentExecutor: LoginAck not received within 5s; " - "blocking tool %s", - agent_action.tool, - ) - return AgentStep( - action=agent_action, - observation=_BLOCKED_OBSERVATION, - ) - - if ws.policy_active(): - cfg = _get_config() - # Short timeout: AgentExecutor LLM callbacks may not propagate, - # so verdicts may never arrive. - cfg = _get_config() - timeout = ws.block_timeout(cfg.block_timeout if cfg else 30.0) - verdict = await ws.wait_for_tool_call_verdict( - tool_call_id, timeout, - ) - - if verdict is not None and _should_halt(verdict): - logger.warning( - "halting tool execution for event_id=%s " - "mad_code=%s (AgentExecutor path)", - verdict.event_id, - verdict.mad_code, - ) - return AgentStep( - action=agent_action, - observation=_BLOCKED_OBSERVATION, - ) - - if verdict is None: - logger.warning( - "AgentExecutor: verdict timeout for " - "tool_call_id=%s, fail-open", - tool_call_id, - ) - - return await original_aperform( - self, name_to_tool_map, color_mapping, agent_action, run_manager, - ) - - def patched_perform( - self: Any, # noqa: ANN401 - name_to_tool_map: Any, # noqa: ANN401 + self: Any, + name_to_tool_map: Any, color_mapping: Any, # noqa: ANN401 - agent_action: Any, # noqa: ANN401 + agent_action: Any, run_manager: Any = None, # noqa: ANN401 ) -> Any: # noqa: ANN401 - """Verdict gate before AgentExecutor dispatches a tool (sync).""" - tool_call_id = getattr(agent_action, "tool_call_id", None) - - if tool_call_id: + tc_id = getattr(agent_action, "tool_call_id", None) + if tc_id: ws = _ws_client - - if ws is not None and ws._login_ack_received.is_set() and ws.policy_active(): # pyright: ignore[reportPrivateUsage] - import concurrent.futures - - async def _gate() -> bool: - cfg = _get_config() - timeout = ws.block_timeout(cfg.block_timeout if cfg else 30.0) - verdict = await ws.wait_for_tool_call_verdict( - tool_call_id, timeout, + if ( + ws is not None + and ws._login_ack_received.is_set() + and ws.policy_active() + ): # pyright: ignore[reportPrivateUsage] + cfg = _get_config() + timeout = ws.block_timeout(cfg.block_timeout if cfg else 30.0) + verdict = await ws.wait_for_tool_call_verdict(tc_id, timeout) + if verdict is None: + logger.warning( + "AgentExecutor: verdict timeout for tool_call_id=%s, blocking (fail-closed)", + tc_id, ) - if verdict is not None and _should_halt(verdict): - logger.warning( - "halting tool execution for event_id=%s " - "mad_code=%s (AgentExecutor sync path)", - verdict.event_id, - verdict.mad_code, - ) - return True - return False - - try: - loop = asyncio.get_event_loop() - if loop.is_running(): - future: concurrent.futures.Future[bool] = concurrent.futures.Future() - - async def _run() -> None: - try: - result = await _gate() - future.set_result(result) - except Exception as exc: - future.set_exception(exc) - - loop.create_task(_run()) - should_block = future.result(timeout=35) - else: - should_block = loop.run_until_complete(_gate()) - - if should_block: - return AgentStep( - action=agent_action, - observation=_BLOCKED_OBSERVATION, - ) - except Exception: - logger.debug( - "AgentExecutor sync gate failed, falling through", - exc_info=True, + return AgentStep(action=agent_action, observation=_BLOCKED_CONTENT) + if _should_halt(verdict): + logger.warning( + "halting tool execution for event_id=%s mad_code=%s", + verdict.event_id, + verdict.mad_code, ) - - return original_perform( - self, name_to_tool_map, color_mapping, agent_action, run_manager, + return AgentStep(action=agent_action, observation=_BLOCKED_CONTENT) + return await original_aperform( + self, name_to_tool_map, color_mapping, agent_action, run_manager ) AgentExecutor._aperform_agent_action = patched_aperform # type: ignore[assignment] - AgentExecutor._perform_agent_action = patched_perform # type: ignore[assignment] AgentExecutor._adrian_executor_patched = True # type: ignore[attr-defined] - logger.debug("Patched AgentExecutor._aperform_agent_action / _perform_agent_action") + logger.debug("Patched AgentExecutor._aperform_agent_action") diff --git a/sdk/adrian/ws.py b/sdk/adrian/ws.py index 30f1ab5..1ab5df4 100644 --- a/sdk/adrian/ws.py +++ b/sdk/adrian/ws.py @@ -513,18 +513,6 @@ async def connect(self) -> None: else: logger.info("WebSocket connected: %s", self._url) - # Eager login: send the SessionLogin frame immediately - # so the server responds with LoginAck before any tool - # gate fires. Previously login was deferred to the - # first _send_frame call, which meant frameworks that - # don't trigger paired events (AgentExecutor) would - # never receive LoginAck and the block gate would time - # out. Provider/model are best-effort at this point - # (empty until the first LLM event auto-detects them). - if not self._logged_in: - await self._send_login(self._ws) - self._logged_in = True - # Drain anything buffered while we were offline, even # on the very first connect. ``_send_mcp_inventory`` # and other init-time emitters queue frames before the From b66f0bc619899509131ed6762766809fc6b762b3 Mon Sep 17 00:00:00 2001 From: netan-sa Date: Mon, 15 Jun 2026 22:22:27 +0200 Subject: [PATCH 3/6] Fix: remove double-gating --- sdk/adrian/__init__.py | 200 ++++++++++++++++++++--------- sdk/adrian/ws.py | 41 ++++-- sdk/tests/test_block_mode.py | 29 +++-- sdk/tests/test_block_mode_races.py | 22 ++-- sdk/tests/test_exec_modes.py | 2 +- 5 files changed, 201 insertions(+), 93 deletions(-) diff --git a/sdk/adrian/__init__.py b/sdk/adrian/__init__.py index 1b55d83..e50ddda 100644 --- a/sdk/adrian/__init__.py +++ b/sdk/adrian/__init__.py @@ -799,7 +799,46 @@ async def patched_astream( logger.debug("Patched Pregel.invoke / ainvoke / astream") -# --- 5. ToolNode (callback injection only — gate is on BaseTool) --- +# --- 5. ToolNode --- + + +def _extract_tool_calls( + state: dict[str, Any] | list[BaseMessage] | Any, +) -> list[dict[str, Any]]: + """Extract tool_calls from ToolNode input (all three dispatch shapes). + + Returns full tool_call dicts (with id, name, args) for backward + compat with tests and callers that need the full shape. + """ + # Shape 3: per-tool-call dict from _afunc dispatch + if isinstance(state, dict) and "tool_call" in state: + tc = state["tool_call"] + if isinstance(tc, dict) and tc.get("id"): + return [tc] + tc_id = getattr(tc, "id", None) + if tc_id: + return [ + { + "id": tc_id, + "name": getattr(tc, "name", ""), + "args": getattr(tc, "args", {}), + } + ] + return [] + + # Shape 1/2: state dict or message list + if isinstance(state, dict): + messages = list(state.get("messages") or []) # pyright: ignore[reportUnknownVariableType, reportUnknownArgumentType] + elif isinstance(state, list): + messages = list(state) + else: + return [] + + for msg in reversed(messages): + if isinstance(msg, AIMessage) and getattr(msg, "tool_calls", None): + return msg.tool_calls # type: ignore[no-any-return] + + return [] def _should_halt(verdict: pb.Verdict) -> bool: @@ -840,26 +879,6 @@ def _patch_tool_node() -> None: original_ainvoke = ToolNode.ainvoke original_astream = getattr(ToolNode, "astream", None) - def _extract_tool_call_ids(state: Any) -> list[str]: # noqa: ANN401 - """Extract tool_call_ids from ToolNode input (any shape).""" - # Shape 3: per-tool-call dict from _afunc dispatch - if isinstance(state, dict) and "tool_call" in state: - tc = state["tool_call"] - tc_id = tc.get("id") if isinstance(tc, dict) else getattr(tc, "id", None) - return [tc_id] if tc_id else [] - # Shape 1/2: state dict or message list - messages = ( - list(state.get("messages") or []) - if isinstance(state, dict) - else list(state) - if isinstance(state, list) - else [] - ) - for msg in reversed(messages): - if isinstance(msg, AIMessage) and getattr(msg, "tool_calls", None): - return [tc.get("id") for tc in msg.tool_calls if tc.get("id")] - return [] - async def _gate_tool_calls(state: Any) -> bool: # noqa: ANN401 """Returns True if tools should be BLOCKED.""" ws = _ws_client @@ -874,13 +893,14 @@ async def _gate_tool_calls(state: Any) -> bool: # noqa: ANN401 if not ws.policy_active(): return False - tc_ids = _extract_tool_call_ids(state) + tc_ids: list[str] = [ + str(tc.get("id")) for tc in _extract_tool_calls(state) if tc.get("id") + ] if not tc_ids: return False cfg = _get_config() timeout = ws.block_timeout(cfg.block_timeout if cfg else 30.0) - # Gate on the first tool_call_id (all come from the same LLM turn) verdict = await ws.wait_for_tool_call_verdict(tc_ids[0], timeout) if verdict is None: logger.warning("ToolNode: verdict timeout, blocking (fail-closed)") @@ -895,7 +915,7 @@ async def _gate_tool_calls(state: Any) -> bool: # noqa: ANN401 return False def _build_blocked(state: Any) -> dict[str, list[ToolMessage]]: # noqa: ANN401 - tc_ids = _extract_tool_call_ids(state) + tc_ids = [tc.get("id") for tc in _extract_tool_calls(state) if tc.get("id")] return { "messages": [ ToolMessage( @@ -921,8 +941,11 @@ async def patched_ainvoke( **kwargs: Any, # noqa: A002, ANN401 ) -> Any: # noqa: ANN401 config = _inject_callbacks(config) - if await _gate_tool_calls(input): - return _build_blocked(input) + # Verdict gate removed — BaseTool.ainvoke/arun is the single + # gate layer. Gating here too caused double-gate: ToolNode + # consumed the verdict future, BaseTool's gate registered a + # fresh future that never resolved → 30s timeout on a benign + # verdict. Callback injection is kept so events still flow. return await original_ainvoke(self, input, config=config, **kwargs) async def patched_astream( @@ -932,9 +955,7 @@ async def patched_astream( **kwargs: Any, # noqa: A002, ANN401 ) -> Any: # noqa: ANN401 config = _inject_callbacks(config) - if await _gate_tool_calls(input): - yield _build_blocked(input) - return + assert original_astream is not None # guarded by line below async for chunk in original_astream(self, input, config=config, **kwargs): yield chunk @@ -970,7 +991,9 @@ def _patch_base_tool() -> None: In MODE_ALERT, no gate fires at all (skip). """ from langchain_core.tools import BaseTool - from langchain_core.tools.base import _is_tool_call # pyright: ignore[reportPrivateUsage] + from langchain_core.tools.base import ( + _is_tool_call, # pyright: ignore[reportPrivateUsage] + ) if getattr(BaseTool, "_adrian_base_tool_patched", False): return @@ -1030,11 +1053,19 @@ async def _async_gate(tool_call_id: str) -> bool: return False def _sync_gate(tool_call_id: str) -> bool: - """Sync verdict gate for pure-sync callers (no running event loop). + """Sync verdict gate — works for pure-sync and worker-thread callers. + + Pure-sync (no event loop): runs ``_async_gate`` via + ``loop.run_until_complete``. + + Worker-thread (Pregel dispatches sync tools on a thread-pool + worker while the event loop runs on the main thread): bridges + the async gate to the main loop via ``run_coroutine_threadsafe`` + and blocks the worker thread until the verdict resolves. - When called from within an async loop (ToolNode._func dispatched - by Pregel), this cannot work — use the ToolNode.ainvoke gate - instead. Returns False (skip) when a running loop is detected. + Event-loop thread (calling tool.invoke directly from async + code): cannot block — returns False (skip). The async path + (BaseTool.ainvoke) handles this case. """ ws = _ws_client if ws is None or not ws._login_ack_received.is_set() or not ws.policy_active(): # pyright: ignore[reportPrivateUsage] @@ -1042,14 +1073,45 @@ def _sync_gate(tool_call_id: str) -> bool: try: loop = asyncio.get_event_loop() - if loop.is_running(): - # Can't block the event loop thread. The ToolNode.ainvoke - # gate handles this path. - return False + except RuntimeError: + return False + + if not loop.is_running(): + # Pure-sync caller — safe to block return loop.run_until_complete(_async_gate(tool_call_id)) + + # Check if we're on a worker thread (no running loop on THIS + # thread) vs the event-loop thread itself. + try: + asyncio.get_running_loop() + # We ARE on the event-loop thread — can't block it. + return False except RuntimeError: + pass + + # Worker thread: bridge the async gate to the main loop. + main_loop = getattr(ws, "_loop", None) + if main_loop is None or not main_loop.is_running(): + return False + + try: + cfg = _get_config() + timeout = ws.block_timeout(cfg.block_timeout if cfg else 30.0) + future = asyncio.run_coroutine_threadsafe( + _async_gate(tool_call_id), main_loop + ) + return future.result(timeout=timeout if timeout else 60.0) + except Exception: return False + def _blocked_response(tc_id: str) -> Any: # noqa: ANN401 + """Return a blocked response compatible with ToolNode (ToolMessage) + and legacy callers (falls back to bare string).""" + try: + return ToolMessage(content=_BLOCKED_CONTENT, tool_call_id=tc_id, name="") + except Exception: + return _BLOCKED_CONTENT + def patched_invoke( self: Any, # noqa: ANN401 input: Any, # noqa: A002, ANN401 @@ -1059,7 +1121,7 @@ def patched_invoke( config = _inject_callbacks(config) tc_id = _extract_tool_call_id(input) if tc_id and _sync_gate(tc_id): - return _BLOCKED_CONTENT + return _blocked_response(tc_id) return original_invoke(self, input, config=config, **kwargs) async def patched_ainvoke( @@ -1071,7 +1133,7 @@ async def patched_ainvoke( config = _inject_callbacks(config) tc_id = _extract_tool_call_id(input) if tc_id and await _async_gate(tc_id): - return _BLOCKED_CONTENT + return _blocked_response(tc_id) return await original_ainvoke(self, input, config=config, **kwargs) original_arun = BaseTool.arun @@ -1085,7 +1147,7 @@ async def patched_arun( ) -> Any: # noqa: ANN401 """Gate on arun — AgentExecutor calls tool.arun directly.""" if tool_call_id and await _async_gate(tool_call_id): - return _BLOCKED_CONTENT + return _blocked_response(tool_call_id) return await original_arun( self, tool_input, *args, tool_call_id=tool_call_id, **kwargs ) @@ -1138,27 +1200,41 @@ async def patched_aperform( tc_id = getattr(agent_action, "tool_call_id", None) if tc_id: ws = _ws_client - if ( - ws is not None - and ws._login_ack_received.is_set() - and ws.policy_active() - ): # pyright: ignore[reportPrivateUsage] - cfg = _get_config() - timeout = ws.block_timeout(cfg.block_timeout if cfg else 30.0) - verdict = await ws.wait_for_tool_call_verdict(tc_id, timeout) - if verdict is None: - logger.warning( - "AgentExecutor: verdict timeout for tool_call_id=%s, blocking (fail-closed)", - tc_id, - ) - return AgentStep(action=agent_action, observation=_BLOCKED_CONTENT) - if _should_halt(verdict): - logger.warning( - "halting tool execution for event_id=%s mad_code=%s", - verdict.event_id, - verdict.mad_code, - ) - return AgentStep(action=agent_action, observation=_BLOCKED_CONTENT) + if ws is not None: + if not ws._login_ack_received.is_set(): # pyright: ignore[reportPrivateUsage] + try: + await asyncio.wait_for( + ws._login_ack_received.wait(), # pyright: ignore[reportPrivateUsage] + timeout=5.0, + ) + except TimeoutError: + logger.warning( + "AgentExecutor: LoginAck not received within 5s; blocking" + ) + return AgentStep( + action=agent_action, observation=_BLOCKED_CONTENT + ) + if ws.policy_active(): + cfg = _get_config() + timeout = ws.block_timeout(cfg.block_timeout if cfg else 30.0) + verdict = await ws.wait_for_tool_call_verdict(tc_id, timeout) + if verdict is None: + logger.warning( + "AgentExecutor: verdict timeout for tool_call_id=%s, blocking (fail-closed)", + tc_id, + ) + return AgentStep( + action=agent_action, observation=_BLOCKED_CONTENT + ) + if _should_halt(verdict): + logger.warning( + "halting tool execution for event_id=%s mad_code=%s", + verdict.event_id, + verdict.mad_code, + ) + return AgentStep( + action=agent_action, observation=_BLOCKED_CONTENT + ) return await original_aperform( self, name_to_tool_map, color_mapping, agent_action, run_manager ) diff --git a/sdk/adrian/ws.py b/sdk/adrian/ws.py index 1ab5df4..169cbdc 100644 --- a/sdk/adrian/ws.py +++ b/sdk/adrian/ws.py @@ -52,6 +52,8 @@ _MAX_RUN_ID_MAP = 1024 # Cap on in-flight tool_call_id → event_id mappings (block-mode correlation). _MAX_TOOL_CALL_MAP = 1024 +# Cap on resolved verdict futures kept for late-waiter replay. +_MAX_PENDING_VERDICTS = 512 _DEFAULT_REPLAY_BUFFER_FRAMES = 1000 @@ -254,6 +256,10 @@ def __init__( # Set by close() so _handle_disconnect knows not to spawn a reconnect # during a graceful shutdown. self._closing = False + # Event loop running the WebSocket tasks. Captured on first + # connect so _sync_gate can bridge async waits from worker + # threads via run_coroutine_threadsafe. + self._loop: asyncio.AbstractEventLoop | None = None # Futures awaited by the patched ToolNode.ainvoke when the # active mode requires a wait (BLOCK or HITL). Each resolves # with the matching ``Verdict`` proto. Futures survive a @@ -472,6 +478,7 @@ async def connect(self) -> None: backoff = _INITIAL_BACKOFF loop = asyncio.get_running_loop() + self._loop = loop headers: dict[str, str] = {} @@ -491,7 +498,6 @@ async def connect(self) -> None: disconnected_at = self._disconnected_at is_reconnect = disconnected_at is not None - if disconnected_at is not None: downtime = time.monotonic() - disconnected_at self._disconnected_at = None @@ -927,6 +933,18 @@ def register_pending( return fut + def _evict_resolved_verdicts(self) -> None: + """Remove oldest resolved futures when the dict exceeds the cap.""" + while len(self._pending_verdicts) > _MAX_PENDING_VERDICTS: + # Evict the oldest entry (dict preserves insertion order). + oldest_id = next(iter(self._pending_verdicts)) + oldest_fut = self._pending_verdicts[oldest_id] + if oldest_fut.done(): + del self._pending_verdicts[oldest_id] + else: + # Don't evict an in-flight future; stop evicting. + break + async def wait_for_verdict( self, event_id: str, @@ -939,25 +957,30 @@ async def wait_for_verdict( ``None`` for ``MODE_HITL`` (wait indefinitely). Returns the verdict, or ``None`` on timeout (fail-open). - Cleans up the ``_pending_verdicts`` entry on either path: - ``_on_verdict_frame`` only resolves the future, the dict - ownership belongs here so a late ``register_pending`` after the - verdict has already arrived can still find the resolved future. + Resolved futures are kept in ``_pending_verdicts`` so a second + waiter on the same event_id (e.g. BaseTool.ainvoke firing after + ToolNode.ainvoke already consumed the verdict) finds the already- + resolved future and returns instantly instead of timing out. + Timed-out (unconsumed) futures are removed immediately; resolved + futures are evicted when the dict exceeds ``_MAX_PENDING_VERDICTS``. """ fut = self.register_pending(event_id) try: - return await asyncio.wait_for(fut, timeout=timeout) + result = await asyncio.wait_for(fut, timeout=timeout) + # Keep resolved future in dict for late waiters; cap size. + self._evict_resolved_verdicts() + return result except TimeoutError: logger.warning( "Verdict timeout for event_id=%s after %ss", event_id, timeout, ) - - return None - finally: + # Timed-out future is useless — remove so a retry can + # register a fresh one. self._pending_verdicts.pop(event_id, None) + return None async def wait_for_tool_verdict( self, diff --git a/sdk/tests/test_block_mode.py b/sdk/tests/test_block_mode.py index 0d1c352..0bbbdaf 100644 --- a/sdk/tests/test_block_mode.py +++ b/sdk/tests/test_block_mode.py @@ -142,10 +142,16 @@ async def test_looks_up_llm_event_id_and_resolves(self) -> None: class TestToolNodePatchBlocking: async def test_in_scope_block_verdict_halts_tool(self, tmp_path: Path) -> None: - """MODE_BLOCK + policy_m4=true + mad_code='M4_a' → halt with synthetic ToolMessage.""" + """MODE_BLOCK + policy_m4=true + mad_code='M4_a' → BaseTool.ainvoke gate blocks. - def _real_tool(x: str) -> str: - """Real tool stub for block-mode tests.""" + The verdict gate lives on BaseTool (the universal layer), not + ToolNode.ainvoke. Uses an async tool so BaseTool.ainvoke (not + BaseTool.invoke) is the entry point — matching the production + path for create_react_agent with async tools. + """ + + async def _real_tool(x: str) -> str: + """Real async tool stub for block-mode tests.""" _real_tool.called = True # type: ignore[attr-defined] return x @@ -180,6 +186,7 @@ def _real_tool(x: str) -> str: result = await tool_node.ainvoke(state, config=_runtime_config()) # pyright: ignore[reportUnknownMemberType] + # BaseTool.ainvoke gate blocks — tool body does NOT run. assert _real_tool.called is False # type: ignore[attr-defined] msgs = result["messages"] assert len(msgs) == 1 @@ -190,7 +197,7 @@ async def test_out_of_scope_verdict_runs_tool(self, tmp_path: Path) -> None: captured: list[str] = [] - def _real_tool(x: str) -> str: + async def _real_tool(x: str) -> str: """Real tool stub for block-mode tests.""" captured.append(x) @@ -226,11 +233,12 @@ def _real_tool(x: str) -> str: assert captured == ["hi"] - async def test_timeout_fail_open_runs_tool(self, tmp_path: Path) -> None: + async def test_timeout_fail_closed_blocks_tool(self, tmp_path: Path) -> None: + """Verdict timeout in MODE_BLOCK → fail-closed (tool does NOT run).""" captured: list[str] = [] - def _real_tool(x: str) -> str: - """Real tool stub for block-mode tests.""" + async def _real_tool(x: str) -> str: + """Real async tool stub for block-mode tests.""" captured.append(x) return x @@ -248,7 +256,7 @@ def _real_tool(x: str) -> str: _apply_mode(ws, pb.MODE_BLOCK, policy_m4=True) ws._connected.set() ws._tool_call_id_to_event_id["tc-1"] = "llm-evt" - # No pending future → wait_for_verdict times out → fail-open. + # No pending future → wait_for_verdict times out → fail-closed (MODE_BLOCK). tool_node = ToolNode([_real_tool]) ai = AIMessage( @@ -259,7 +267,8 @@ def _real_tool(x: str) -> str: await tool_node.ainvoke(state, config=_runtime_config()) # pyright: ignore[reportUnknownMemberType] - assert captured == ["hi"] + # Fail-closed: tool should NOT have run. + assert captured == [] class TestModeAlert: @@ -268,7 +277,7 @@ async def test_alert_mode_skips_wait(self, tmp_path: Path) -> None: captured: list[str] = [] - def _real_tool(x: str) -> str: + async def _real_tool(x: str) -> str: """Real tool stub for block-mode tests.""" captured.append(x) diff --git a/sdk/tests/test_block_mode_races.py b/sdk/tests/test_block_mode_races.py index fa0ad57..16d8e4a 100644 --- a/sdk/tests/test_block_mode_races.py +++ b/sdk/tests/test_block_mode_races.py @@ -5,17 +5,17 @@ LLM calls; no running backend. Scenarios mirror the validated shapes from the multi-agent work: - S1 subagents-as-tools - director → worker (nested) - S2 handoffs - triage → specialist (sequential) - S3 router - parallel fan-out via Send() - S4 hierarchical - 3-level deep (director → team-lead → worker) - S5 custom workflow - deterministic + LLM nodes mixed - S6 swarm - back-and-forth handoffs (Alice ↔ Bob) - S7 supervisor - central dispatcher to N workers - S8 deep research - parallel researchers via asyncio.gather + S1 subagents-as-tools , director → worker (nested) + S2 handoffs , triage → specialist (sequential) + S3 router , parallel fan-out via Send() + S4 hierarchical , 3-level deep (director → team-lead → worker) + S5 custom workflow , deterministic + LLM nodes mixed + S6 swarm , back-and-forth handoffs (Alice ↔ Bob) + S7 supervisor , central dispatcher to N workers + S8 deep research , parallel researchers via asyncio.gather The invariant under test: for EVERY pattern, each ToolNode invocation -blocks on the verdict of the LLM that emitted its specific tool_call.id - +blocks on the verdict of the LLM that emitted its specific tool_call.id , never a sibling, never a parent, never a stale global. """ @@ -117,9 +117,9 @@ def _init_block_mode(tmp_path: Path, block_timeout: float = 1.0) -> Any: def _tool(name: str, captured: list[str]) -> Any: - """Build a named stub tool that records its argument.""" + """Build a named async stub tool that records its argument.""" - def _impl(x: str) -> str: + async def _impl(x: str) -> str: """Stub tool.""" captured.append(f"{name}:{x}") diff --git a/sdk/tests/test_exec_modes.py b/sdk/tests/test_exec_modes.py index 1ea8ae1..f3f5e42 100644 --- a/sdk/tests/test_exec_modes.py +++ b/sdk/tests/test_exec_modes.py @@ -61,7 +61,7 @@ def _cleanup() -> Iterator[None]: # pyright: ignore[reportUnusedFunction] def _stub_tool(captured: list[str]) -> Any: # noqa: ANN401 - def _impl(x: str) -> str: + async def _impl(x: str) -> str: """Stub tool.""" captured.append(x) From 3b9218876d4924bbe4bdfa23ea948b0a0eb66404 Mon Sep 17 00:00:00 2001 From: netan-sa Date: Mon, 15 Jun 2026 22:55:47 +0200 Subject: [PATCH 4/6] Add: merge conflicts and merge --- sdk/adrian/ws.py | 1038 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 1038 insertions(+) create mode 100644 sdk/adrian/ws.py diff --git a/sdk/adrian/ws.py b/sdk/adrian/ws.py new file mode 100644 index 0000000..169cbdc --- /dev/null +++ b/sdk/adrian/ws.py @@ -0,0 +1,1038 @@ +"""Async WebSocket ``EventHandler`` that streams ``PairedEvent`` to the worker core API. + +Converts each ``PairedEvent`` into a ``pb.PairedEvent`` protobuf, wraps it in a +``ClientFrame.paired_batch``, and sends it over a long-lived WebSocket +connection. Verdicts received back resolve block-mode futures and fire the +callback handler's verdict processing. + +Implements the ``EventHandler`` protocol so it slots into the SDK's hook +registry alongside ``JSONLHandler``. +""" + +from __future__ import annotations + +import asyncio +import contextlib +import json +import logging +import time +from collections import OrderedDict, deque +from collections.abc import Awaitable, Callable +from typing import TYPE_CHECKING, Any + +import websockets + +if TYPE_CHECKING: + from adrian.config import OnDisconnectCallback, OnReconnectCallback + from adrian.handler import AdrianCallbackHandler + +from adrian.format.types import ( + AgentContext, + LlmPairData, + PairedEvent, + ParentContext, +) +from adrian.proto import event_pb2 as pb + +logger = logging.getLogger("adrian.ws") + +SCHEMA_VERSION = 2 + +_INITIAL_BACKOFF = 1.0 +_MAX_BACKOFF = 30.0 +# Server close code: quota exhausted. Spec'd in +# server/internal/websocket/handler.go (closeQuotaExceeded). Returning +# every 30s would hammer the server while quota is depleted; one +# minute is slow enough to be cheap, fast enough that the next hourly +# / daily / monthly window-rollover is picked up within tolerance. +_QUOTA_EXHAUSTED_CLOSE_CODE = 4003 +_QUOTA_RECONNECT_DELAY = 60.0 +# Cap on in-flight LLM run_id → event_id mappings. Evicted LRU-style; +# block-mode lookups for evicted entries fail open. +_MAX_RUN_ID_MAP = 1024 +# Cap on in-flight tool_call_id → event_id mappings (block-mode correlation). +_MAX_TOOL_CALL_MAP = 1024 +# Cap on resolved verdict futures kept for late-waiter replay. +_MAX_PENDING_VERDICTS = 512 + +_DEFAULT_REPLAY_BUFFER_FRAMES = 1000 + +# Heartbeat tuning. 10s interval / 15s pong timeout detects half-open +# connections (ALB idle cut, NAT drop, dead remote process) without +# flooding the wire. Kept in sync with the backend's pingInterval / +# pongTimeout, if these change, update server/internal/websocket/handler.go. +_PING_INTERVAL = 10.0 +_PING_TIMEOUT = 15.0 + +_PROVIDER_PREFIXES: dict[str, str] = { + "chatanthropic": "anthropic", + "chatopenai": "openai", + "chatgooglegenai": "google", + "chatcohere": "cohere", + "chatmistralai": "mistral", +} + +_PAIR_TYPE_MAP: dict[str, pb.PairType.ValueType] = { + "llm": pb.PAIR_TYPE_LLM, + "tool": pb.PAIR_TYPE_TOOL, +} + + +def _derive_provider(model_class_name: str) -> str: + """Derive the LLM provider from the model class name. + + Args: + model_class_name: Class name like ``"ChatAnthropic"`` or ``"ChatOpenAI"``. + + Returns: + Provider string (e.g. ``"anthropic"``), or the class name lower-cased + if no known prefix matches. + """ + key = model_class_name.lower() + + return _PROVIDER_PREFIXES.get(key, key) + + +def _fill_agent_context( + pb_ctx: pb.AgentContext, src: AgentContext | ParentContext +) -> None: + """Copy an AgentContext / ParentContext dataclass into its proto counterpart.""" + pb_ctx.agent_id = src.agent_id + pb_ctx.system_prompt = src.system_prompt + pb_ctx.user_instruction = src.user_instruction + + +def _safe_cancel( + task_or_future: asyncio.Task[Any] | asyncio.Future[Any] | None, +) -> None: + """Cancel a task / future, ignoring closed-loop errors at shutdown. + + Adrian's ``atexit`` handler may run after the user's loop has been + closed; in that path ``adrian.shutdown`` spawns a new ``asyncio.run`` + and walks each handler's ``close()``. Tasks bound to the *old* loop + can no longer be cancelled (``call_soon`` raises ``Event loop is + closed``). Swallowing the error here keeps the cleanup path quiet, + the task will be reaped when the dead loop is GC'd. + """ + if task_or_future is None or task_or_future.done(): + return + # "Event loop is closed", old loop is gone, nothing to cancel. + with contextlib.suppress(RuntimeError): + task_or_future.cancel() + + +def _paired_event_to_proto(event: PairedEvent) -> pb.PairedEvent: + """Convert a ``PairedEvent`` dataclass into its protobuf form. + + ``parent.agent_id`` empty-string signals "no parent agent". + ``parent_run_id`` empty-string signals "no parent in run tree". + """ + proto = pb.PairedEvent( + event_id=event.event_id, + invocation_id=event.invocation_id, + session_id=event.session_id, + run_id=event.run_id, + parent_run_id=event.parent_run_id, + timestamp=event.timestamp, + pair_type=_PAIR_TYPE_MAP.get(event.pair_type, pb.PAIR_TYPE_UNSPECIFIED), + ) + + _fill_agent_context(proto.agent, event.agent) + + if event.parent is not None: + _fill_agent_context(proto.parent, event.parent) + + if isinstance(event.data, LlmPairData): + proto.llm.model = event.data.model + + for msg in event.data.messages: + pb_msg = proto.llm.messages.add() + pb_msg.role = msg["role"] + pb_msg.content = msg["content"] + + proto.llm.output = event.data.output + + for tc in event.data.tool_calls: + pb_tc = proto.llm.tool_calls.add() + pb_tc.name = tc["name"] + pb_tc.args = json.dumps(tc["args"], default=str) + pb_tc.id = tc["id"] + + if event.data.usage is not None: + proto.llm.usage.prompt_tokens = event.data.usage["prompt_tokens"] + proto.llm.usage.completion_tokens = event.data.usage["completion_tokens"] + proto.llm.usage.total_tokens = event.data.usage["total_tokens"] + else: + # Union is LlmPairData | ToolPairData; this branch is the + # ToolPairData case. + proto.tool.tool_name = event.data.tool_name + proto.tool.tool_call_id = event.data.tool_call_id or "" + proto.tool.input = event.data.input + proto.tool.output = event.data.output + + if event.metadata: + proto.metadata_json = json.dumps(event.metadata, default=str).encode() + + return proto + + +class WebSocketClient: + """Streams ``PairedEvent`` instances to the worker core API. + + Connects eagerly via :meth:`schedule_connect` with exponential backoff, + auto-detects the LLM provider on the first LLM pair, sends paired events + as protobuf frames, and resolves block-mode futures when verdicts arrive. + """ + + def __init__( + self, + url: str, + session_id: str, + api_key: str, + handler: AdrianCallbackHandler | None = None, + on_disconnect: OnDisconnectCallback | None = None, + on_reconnect: OnReconnectCallback | None = None, + on_login_ack: Callable[[], Awaitable[None]] | None = None, + replay_buffer_frames: int = _DEFAULT_REPLAY_BUFFER_FRAMES, + ) -> None: + """Initialise without connecting. + + Args: + url: WebSocket endpoint URL. + session_id: Session ID sent in the login frame. + api_key: Adrian API key for the ``Authorization`` header. + handler: Callback handler for verdict processing. + on_disconnect: Fired when the connection is lost (sync or async). + Receives a reason string. + on_reconnect: Fired when the connection re-establishes after a + prior disconnect (sync or async). Does not fire on initial + connect. + on_login_ack: Async hook fired after each ``LoginAck`` frame is + applied, once per (re)connect. Used internally to push a + fresh ``McpInventory`` on every login. Exceptions are + logged and swallowed. + replay_buffer_frames: Ring-buffer capacity (frame count, not + bytes). When the cap is reached each further append evicts + the oldest frame; a one-shot WARN fires on first fill, and + the cumulative drop count is logged at WARN on the next + reconnect. + """ + self._url = url + self._session_id = session_id + self._api_key = api_key + self._handler = handler + self._on_disconnect = on_disconnect + self._on_reconnect = on_reconnect + self._on_login_ack_cb = on_login_ack + self._provider = "" + self._model = "" + # Server-supplied execution-mode policy. Populated when the + # first ServerFrame{login_ack} arrives after each (re)connect. + # ``policy_active()`` and ``block_timeout()`` read this state + # to decide whether the patched ToolNode should wait for a + # verdict and how long. + self._mode: int = pb.MODE_UNSPECIFIED + self._policy: pb.PolicySnapshot | None = None + # Set the first time a ``ServerFrame{login_ack}`` is applied. + # Used in two places: + # 1. ``on_paired_event`` defensively pre-registers a + # verdict-wait future when this event is unset, so the + # very first tool-bearing LLM emission is covered even + # though the recv loop hasn't yet processed LoginAck and + # ``policy_active()`` reads False. + # 2. The patched ``ToolNode.ainvoke`` ``await``s this event + # (with a short timeout) before deciding whether to wait + # for a verdict, so the first ToolNode invocation cannot + # run-through-without-waiting in the same window. + # Stays set across disconnect/reconnect because mode/policy + # state survives, a fresh LoginAck on reconnect simply re-sets + # an already-set event. + self._login_ack_received: asyncio.Event = asyncio.Event() + self._ws: websockets.ClientConnection | None = None + self._logged_in = False + self._connected = asyncio.Event() + self._connect_task: asyncio.Task[None] | None = None + self._recv_task: asyncio.Task[None] | None = None + # Set by close() so _handle_disconnect knows not to spawn a reconnect + # during a graceful shutdown. + self._closing = False + # Event loop running the WebSocket tasks. Captured on first + # connect so _sync_gate can bridge async waits from worker + # threads via run_coroutine_threadsafe. + self._loop: asyncio.AbstractEventLoop | None = None + # Futures awaited by the patched ToolNode.ainvoke when the + # active mode requires a wait (BLOCK or HITL). Each resolves + # with the matching ``Verdict`` proto. Futures survive a + # disconnect: a late verdict after reconnect still resolves + # the wait; if none arrives, ``wait_for_verdict``'s timeout + # produces a natural fail-open in BLOCK mode. + self._pending_verdicts: dict[str, asyncio.Future[pb.Verdict]] = {} + # Maps LLM pair run_id → event_id so a subsequent tool call can + # look up the verdict by its parent_run_id (the LLM's run_id). + # LRU-capped at _MAX_RUN_ID_MAP to bound memory on long sessions. + self._run_id_to_event_id: OrderedDict[str, str] = OrderedDict() + # Verdict-correlation map: maps each tool_call.id emitted by + # an LLM to the event_id of the LLM pair that emitted it. + # Populated on every LLM PairedEvent that has tool_calls. + # Consulted by the patched ``ToolNode.ainvoke`` so each tool + # in a parallel fan-out waits on its own producing LLM's + # verdict, not a global "last" pointer. LRU-capped at + # ``_MAX_TOOL_CALL_MAP``. + self._tool_call_id_to_event_id: OrderedDict[str, str] = OrderedDict() + # Serialises the lazy login-then-send sequence so two concurrent + # on_paired_event calls (parallel agents) cannot both send a login. + # Reused by _replay_buffer_to_ws to coordinate with live sends. + self._login_lock = asyncio.Lock() + # Ring buffer of recently serialised ClientFrame bytes. Appended + # only from the offline-or-send-failure paths in _send_frame; the + # happy path bypasses the ring entirely. Drained on reconnect. + self._replay_buffer: deque[bytes] = deque(maxlen=replay_buffer_frames) + # Flips True on the first append that reaches maxlen. Gates the + # one-shot "buffer full" WARN so we don't flood logs. + self._replay_buffer_filled: bool = False + # Monotonic counter of frames dropped due to buffer overflow + # (oldest evicted when a new append arrives at a full ring). + # Logged at WARN on the next reconnect. + self._replay_buffer_dropped: int = 0 + # True while the reconnect path is draining the replay buffer. + # Live sends observed during this window are routed back into + # the same deque so they slot in AFTER the pre-outage tail + # rather than racing onto the wire ahead of older buffered + # frames. Flipped on as the first sync line of + # _replay_buffer_to_ws and cleared in its finally. + self._replaying: bool = False + # Set by _handle_disconnect, cleared on successful reconnect. + # Used to gate on_reconnect and measure downtime. + self._disconnected_at: float | None = None + # One-shot delay applied before the next ``connect()`` attempt. + # Set when the server closes with a code that requests a longer + # wait (currently only 4003 quota exhausted); cleared by + # ``connect()`` after honouring it. ``None`` means use the + # standard exponential schedule. + self._next_reconnect_delay: float | None = None + + # -- Mode / policy state (populated by LoginAck) -- + + def policy_active(self) -> bool: + """Whether the active server mode requires waiting on verdicts. + + Single predicate consulted by the patched ``ToolNode.ainvoke``. + Returns ``True`` for ``MODE_BLOCK`` and ``MODE_HITL``; ``False`` + for ``MODE_ALERT`` and unset (pre-login) state. + """ + return self._mode in (pb.MODE_BLOCK, pb.MODE_HITL) + + def block_timeout(self, kwarg_default: float) -> float | None: + """Effective per-tool-call wait timeout for the active mode. + + - ``MODE_BLOCK``: ``kwarg_default`` (typically 30s), fail-open + if the server doesn't classify in time. + - ``MODE_HITL``: ``None``, wait indefinitely for human review. + - ``MODE_ALERT`` / unset: ``0``, caller short-circuits before + registering a future. + """ + if self._mode == pb.MODE_BLOCK: + return kwarg_default + elif self._mode == pb.MODE_HITL: + return None + else: + return 0 + + # -- EventHandler protocol -- + + async def on_paired_event(self, event: PairedEvent) -> None: + """Send a paired event over the WebSocket. + + Auto-detects the LLM provider on the first LLM pair, updates the + run_id → event_id map for block mode, converts the dataclass to + protobuf, and sends a ``ClientFrame.paired_batch`` frame. + + For LLM pairs that carry tool_calls, registers the verdict-wait + future *before* the frame leaves the SDK. This closes the race + where a fast verdict roundtrip resolves and is dropped before + the patched ``ToolNode.ainvoke`` reaches its own + ``register_pending`` call. The matching ``register_pending`` + from the wait site is a get-or-create that returns the existing + future. + + Args: + event: The paired event to stream. + """ + if ( + event.pair_type == "llm" + and not self._provider + and isinstance(event.data, LlmPairData) + ): + self._model = event.data.model + self._provider = _derive_provider(event.data.model) + + if event.pair_type == "llm": + self._run_id_to_event_id[event.run_id] = event.event_id + self._run_id_to_event_id.move_to_end(event.run_id) + + if len(self._run_id_to_event_id) > _MAX_RUN_ID_MAP: + self._run_id_to_event_id.popitem(last=False) + + # Populate tool_call.id → event_id so each tool call can block + # on its own producing LLM's verdict under parallel fan-out. + if isinstance(event.data, LlmPairData) and event.data.tool_calls: + for tc in event.data.tool_calls: + tc_id = tc.get("id") or "" + + if not tc_id: + continue + + self._tool_call_id_to_event_id[tc_id] = event.event_id + self._tool_call_id_to_event_id.move_to_end(tc_id) + + if len(self._tool_call_id_to_event_id) > _MAX_TOOL_CALL_MAP: + self._tool_call_id_to_event_id.popitem(last=False) + + # Pre-register the wait future so an eager verdict + # cannot race ahead of the ToolNode patch. Gated on + # ``policy_active()`` so ALERT-mode sessions don't + # accumulate futures that will never be resolved or + # awaited, except for the very first event of the + # session, where ``LoginAck`` may not yet have been + # processed by the recv loop and ``policy_active()`` + # therefore reads False even when the mode will + # imminently be set to BLOCK or HITL. Pre-register + # defensively in that window; in ALERT mode the gate + # filters out every subsequent event so the leak is + # bounded to one orphan future per session. + if self.policy_active() or not self._login_ack_received.is_set(): + self.register_pending(event.event_id) + + proto = _paired_event_to_proto(event) + frame = pb.ClientFrame() + added = frame.paired_batch.events.add() + added.CopyFrom(proto) + + await self._send_frame(frame) + + async def close(self) -> None: + """Cancel background tasks and close the WebSocket. + + Sets ``_closing`` so any in-flight ``_handle_disconnect`` does not + spawn a reconnect during graceful shutdown. + + Defensive against the ``atexit`` shutdown path: ``adrian.shutdown`` + spawns a fresh ``asyncio.run`` loop after the user's loop has + already closed, so background tasks bound to the old loop can no + longer be cancelled cleanly (``call_soon`` raises + ``Event loop is closed``). Skip the cancel in that case, the + old loop is gone, the task will be reaped by GC. + """ + self._closing = True + + _safe_cancel(self._recv_task) + self._recv_task = None + _safe_cancel(self._connect_task) + self._connect_task = None + + if self._ws is not None: + with contextlib.suppress(Exception): + await asyncio.wait_for(self._ws.close(), timeout=2.0) + self._ws = None + + for fut in self._pending_verdicts.values(): + if not fut.done(): + _safe_cancel(fut) + self._pending_verdicts.clear() + + # -- Connection lifecycle -- + + def schedule_connect(self, loop: asyncio.AbstractEventLoop) -> None: + """Schedule :meth:`connect` as a background task on the given loop.""" + if self._connect_task is None or self._connect_task.done(): + self._connect_task = loop.create_task(self.connect()) + + async def connect(self) -> None: + """Establish the WebSocket with exponential-backoff retry. + + Heartbeat (``ping_interval`` / ``ping_timeout``) is configured on + the underlying ``websockets`` client; if the server fails to pong + within ``_PING_TIMEOUT`` the library closes the connection and + ``_recv_loop`` surfaces the disconnect via ``_handle_disconnect``. + + On a reconnect (``_disconnected_at`` set by a prior disconnect), + drains the replay buffer and fires ``on_reconnect``. Login is + deferred to ``_send_frame`` / ``_replay_buffer_to_ws`` so the + auto-detected provider/model is included. An ``api_key``, if + configured, is sent as an ``Authorization: Bearer `` header. + + Honours ``_next_reconnect_delay`` if a previous disconnect set + it (e.g. 4003 quota exhausted requests a slower retry). The + delay is consumed on the first attempt; subsequent failures + fall back to the standard exponential schedule. + """ + initial_delay = self._next_reconnect_delay + self._next_reconnect_delay = None + + if initial_delay is not None: + logger.info( + "delaying reconnect by %.0fs (server-requested)", + initial_delay, + ) + await asyncio.sleep(initial_delay) + + backoff = _INITIAL_BACKOFF + loop = asyncio.get_running_loop() + self._loop = loop + + headers: dict[str, str] = {} + + if self._api_key: + headers["Authorization"] = f"Bearer {self._api_key}" + + while True: + try: + self._ws = await websockets.connect( + self._url, + additional_headers=headers, + ping_interval=_PING_INTERVAL, + ping_timeout=_PING_TIMEOUT, + ) + self._connected.set() + self._recv_task = loop.create_task(self._recv_loop()) + + disconnected_at = self._disconnected_at + is_reconnect = disconnected_at is not None + if disconnected_at is not None: + downtime = time.monotonic() - disconnected_at + self._disconnected_at = None + logger.warning( + "WebSocket reconnected: %s (session_id=%s, downtime=%.2fs)", + self._url, + self._session_id, + downtime, + ) + + if self._replay_buffer_dropped > 0: + logger.warning( + "replay buffer dropped %d frames due to overflow " + "before this reconnect (session_id=%s); " + "increase replay_buffer_frames if this recurs", + self._replay_buffer_dropped, + self._session_id, + ) + else: + logger.info("WebSocket connected: %s", self._url) + + # Drain anything buffered while we were offline, even + # on the very first connect. ``_send_mcp_inventory`` + # and other init-time emitters queue frames before the + # WS is open; without this drain those frames never + # ship until something else triggers a live send. + if self._replay_buffer: + logger.info( + "replaying %d buffered frames after connect", + len(self._replay_buffer), + ) + await self._replay_buffer_to_ws() + + if is_reconnect: + await self._fire_on_reconnect() + + return + except Exception: + logger.warning( + "WebSocket connect to %s failed, retrying in %.0fs", + self._url, + backoff, + ) + try: + await asyncio.sleep(backoff) + except RuntimeError: + # Loop closed mid-retry (atexit shutdown). Bail out + # quietly rather than dumping a traceback. + return + backoff = min(backoff * 2, _MAX_BACKOFF) + + async def _send_login(self, ws: websockets.ClientConnection) -> None: + """Send the mandatory SessionLogin frame.""" + frame = pb.ClientFrame() + frame.login.session_id = self._session_id + frame.login.llm_stack.provider = self._provider + frame.login.llm_stack.model = self._model + frame.login.schema_version = SCHEMA_VERSION + await ws.send(frame.SerializeToString()) + logger.debug( + "Sent login (session=%s, provider=%s, model=%s, schema=%d)", + self._session_id, + self._provider, + self._model, + SCHEMA_VERSION, + ) + + async def _send_frame(self, frame: pb.ClientFrame) -> None: + """Serialise and send a ``ClientFrame``, buffering on failure. + + Happy path (connected + healthy): send over WS, bypass the ring + entirely, zero overhead. Offline on entry: buffer for replay. + During reconnect replay: buffer as well, so the drain loop picks + this frame up after the pre-outage tail (preserves order across + the outage boundary). Send raises: buffer the in-flight frame + then trigger ``_handle_disconnect`` so state is cleared and + reconnect is spawned. + """ + frame_bytes = frame.SerializeToString() + kind = frame.WhichOneof("frame") + + if not self._connected.is_set() or self._replaying: + self._buffer_frame(frame_bytes) + reason = "disconnected" if not self._connected.is_set() else "replaying" + logger.info( + "buffered for replay (session_id=%s, kind=%s, " + "buffer_size=%d, reason=%s)", + self._session_id, + kind, + len(self._replay_buffer), + reason, + ) + + return + + ws = self._ws + + if ws is None: + self._buffer_frame(frame_bytes) + + return + + try: + async with self._login_lock: + if not self._logged_in: + await self._send_login(ws) + self._logged_in = True + + await ws.send(frame_bytes) + logger.debug("Sent %s frame", kind) + except Exception: + # Send raised, we cannot confirm the server received this frame. + # Buffer it so the reconnect replay ships it, then clean up state. + self._buffer_frame(frame_bytes) + await self._handle_disconnect("send_failure") + + async def _recv_loop(self) -> None: + """Read ``ServerFrame``s, dispatch by oneof kind. + + First frame after each (re)login MUST be ``login_ack``; anything + else is a protocol error and we tear the connection down so the + reconnect path can try again. Subsequent frames are + ``verdict``s. Unknown oneof kinds (future server additions like + a quota-exhausted signal) are logged and dropped rather than + crashing the loop. + + Any exit path (clean close, exception, cancellation) calls + ``_handle_disconnect`` via ``finally`` so state is cleared and a + reconnect is spawned. + """ + ws = self._ws + + if ws is None: + return + + awaiting_login_ack = True + try: + async for message in ws: + if not isinstance(message, bytes): + continue + + frame = pb.ServerFrame() + frame.ParseFromString(message) + kind = frame.WhichOneof("frame") + + if awaiting_login_ack: + awaiting_login_ack = False + if kind != "login_ack": + logger.error( + "expected ServerFrame{login_ack} as first frame, " + "got %r, closing connection", + kind, + ) + return + + if kind == "login_ack": + self._on_login_ack(frame.login_ack) + elif kind == "verdict": + await self._on_verdict_frame(frame.verdict) + else: + logger.warning( + "ignoring unknown ServerFrame kind %r " + "(future server addition?)", + kind, + ) + except asyncio.CancelledError: + # Expected on graceful shutdown or when _handle_disconnect cancels + # us from the send_failure path. Re-raise to honour cancellation. + raise + except Exception as exc: + logger.warning("recv_loop exited: %s", exc) + finally: + close_code = getattr(ws, "close_code", None) + + if close_code == _QUOTA_EXHAUSTED_CLOSE_CODE: + self._next_reconnect_delay = _QUOTA_RECONNECT_DELAY + + reason = ( + f"quota_exhausted (close={close_code})" + if close_code == _QUOTA_EXHAUSTED_CLOSE_CODE + else "recv_loop_exit" + ) + await self._handle_disconnect(reason) + + def _on_login_ack(self, ack: pb.LoginAck) -> None: + """Apply the org's effective execution-mode policy. + + Fires the ``on_login_ack`` hook (if configured) as a fire-and-forget + task on the running loop so the recv loop doesn't block waiting on it. + """ + self._mode = ack.policy.mode + self._policy = ack.policy + self._login_ack_received.set() + logger.info( + "LoginAck received: mode=%s policy_m0=%s policy_m2=%s " + "policy_m3=%s policy_m4=%s", + pb.Mode.Name(ack.policy.mode), + ack.policy.policy_m0, + ack.policy.policy_m2, + ack.policy.policy_m3, + ack.policy.policy_m4, + ) + + if self._on_login_ack_cb is not None: + asyncio.create_task(self._run_login_ack_cb()) + + async def _run_login_ack_cb(self) -> None: + """Invoke the on_login_ack hook, swallowing exceptions.""" + if self._on_login_ack_cb is None: + return + try: + await self._on_login_ack_cb() + except Exception: + logger.exception("on_login_ack hook raised") + + async def _on_verdict_frame(self, verdict: pb.Verdict) -> None: + """Fire callbacks then resolve the matching pending future, if any. + + The future is left in ``_pending_verdicts`` after ``set_result`` so + a later ``register_pending`` (e.g. from the patched ToolNode after + the verdict has already round-tripped) returns the resolved + future and the wait completes immediately. ``wait_for_verdict`` + owns the cleanup: its ``finally`` pops the entry after the await + returns. + """ + logger.info( + "Verdict received: event_id=%s mad_code=%s mode=%s hitl=%s", + verdict.event_id, + verdict.mad_code or "-", + pb.Mode.Name(verdict.policy.mode), + verdict.HasField("hitl"), + ) + + if self._handler is not None: + await self._handler.handle_verdict(verdict) + + fut = self._pending_verdicts.get(verdict.event_id) + + if fut is None: + if verdict.HasField("hitl"): + logger.warning( + "HITL resolution for unknown event_id=%s, ignoring " + "(stale resolution from a prior SDK process)", + verdict.event_id, + ) + return + + if not fut.done(): + fut.set_result(verdict) + + # -- Resilience: buffering, replay, disconnect/reconnect -- + + def _buffer_frame(self, frame_bytes: bytes) -> None: + """Append a serialised frame to the replay ring. + + Tracks overflow drops and fires the one-shot "buffer full" WARN. + Called only from the offline or send-failure paths in + ``_send_frame``, the happy path bypasses the ring entirely. + """ + if len(self._replay_buffer) == self._replay_buffer.maxlen: + self._replay_buffer_dropped += 1 + + self._replay_buffer.append(frame_bytes) + + if ( + not self._replay_buffer_filled + and len(self._replay_buffer) == self._replay_buffer.maxlen + ): + self._replay_buffer_filled = True + logger.warning( + "adrian replay buffer reached capacity (%d frames); " + "further frames will evict oldest. Tune via " + "replay_buffer_frames or ADRIAN_REPLAY_BUFFER_FRAMES.", + self._replay_buffer.maxlen, + ) + + async def _replay_buffer_to_ws(self) -> None: + """Reissue buffered frames over the current WebSocket. + + Sends ``SessionLogin`` first if not already logged in (the server + requires it as the first frame on every new connection). Uses + ``_login_lock`` so a concurrent live send does not race on the + login check. + + Drains the deque one frame at a time via ``popleft`` inside a + ``while`` loop, rather than taking a snapshot up front. That + way, a live ``_send_frame`` call observed during the drain + routes its frame to the back of the same deque (because + ``_replaying`` is set) and this loop picks it up in the next + iteration, preserving across-outage order + ``[pre-outage] → [live during replay] → [post-replay live]``. + + On a mid-drain send failure, the failed frame is put back at + the front with ``appendleft`` and the function returns; the + next reconnect resumes from exactly where this one stopped. + """ + ws = self._ws + + if ws is None: + return + + self._replaying = True + try: + async with self._login_lock: + if not self._logged_in: + try: + await self._send_login(ws) + self._logged_in = True + except Exception as exc: + logger.warning( + "replay aborted: login send failed: %s", + exc, + ) + + return + + sent = 0 + while self._replay_buffer: + frame_bytes = self._replay_buffer.popleft() + try: + await ws.send(frame_bytes) + except Exception as exc: + # Put the failed frame back at the front so the next + # reconnect's drain resumes from exactly this point. + self._replay_buffer.appendleft(frame_bytes) + logger.warning( + "replay aborted after %d frame(s), %d remaining: %s", + sent, + len(self._replay_buffer), + exc, + ) + + return + sent += 1 + + logger.info("replayed %d buffered frames", sent) + self._replay_buffer_dropped = 0 + self._replay_buffer_filled = False + finally: + self._replaying = False + + async def _handle_disconnect(self, reason: str) -> None: + """Clear connection state and spawn a reconnect. + + Idempotent: if already disconnected or closing, returns immediately. + Pending verdict futures are intentionally left pending across the + disconnect, a late verdict after reconnect resolves them; if none + arrives, ``wait_for_verdict``'s timeout fires naturally. + """ + if self._closing or not self._connected.is_set(): + return + + self._connected.clear() + self._disconnected_at = time.monotonic() + + # Only cancel the recv task if we are not currently running inside it. + # When _recv_loop's own finally calls us, self._recv_task IS the + # current task, cancelling it would raise CancelledError inside the + # finally and prevent us from finishing disconnect handling. + current = asyncio.current_task() + + if self._recv_task is not None and self._recv_task is not current: + self._recv_task.cancel() + + self._recv_task = None + self._ws = None + self._logged_in = False + + logger.warning( + "disconnected (session_id=%s, reason=%s, pending_verdicts=%d)", + self._session_id, + reason, + len(self._pending_verdicts), + ) + + await self._fire_on_disconnect(reason) + + if self._closing: + return + + loop = asyncio.get_running_loop() + + if self._connect_task is None or self._connect_task.done(): + self._connect_task = loop.create_task(self.connect()) + + async def _fire_on_disconnect(self, reason: str) -> None: + """Invoke the on_disconnect callback, catching any exception.""" + if self._on_disconnect is None: + return + + try: + result = self._on_disconnect(reason) + + if asyncio.iscoroutine(result): + await result + except Exception: + logger.exception("on_disconnect callback raised") + + async def _fire_on_reconnect(self) -> None: + """Invoke the on_reconnect callback, catching any exception.""" + if self._on_reconnect is None: + return + + try: + result = self._on_reconnect() + + if asyncio.iscoroutine(result): + await result + except Exception: + logger.exception("on_reconnect callback raised") + + # -- Verdict-wait support -- + + def register_pending( + self, + event_id: str, + ) -> asyncio.Future[pb.Verdict]: + """Return a future awaiting a verdict for ``event_id``. + + Reuses an existing pending future if one is already registered, + so concurrent callers waiting on the same event_id see the same + verdict once it arrives. Must be called BEFORE sending the event + to avoid the race where the verdict arrives before the future exists. + """ + existing = self._pending_verdicts.get(event_id) + + if existing is not None: + return existing + + loop = asyncio.get_running_loop() + fut: asyncio.Future[pb.Verdict] = loop.create_future() + self._pending_verdicts[event_id] = fut + + return fut + + def _evict_resolved_verdicts(self) -> None: + """Remove oldest resolved futures when the dict exceeds the cap.""" + while len(self._pending_verdicts) > _MAX_PENDING_VERDICTS: + # Evict the oldest entry (dict preserves insertion order). + oldest_id = next(iter(self._pending_verdicts)) + oldest_fut = self._pending_verdicts[oldest_id] + if oldest_fut.done(): + del self._pending_verdicts[oldest_id] + else: + # Don't evict an in-flight future; stop evicting. + break + + async def wait_for_verdict( + self, + event_id: str, + timeout: float | None, + ) -> pb.Verdict | None: + """Wait for a verdict for ``event_id``. + + ``timeout`` is mode-derived (see :meth:`block_timeout`): + a positive float for ``MODE_BLOCK`` (fail-open at timeout), + ``None`` for ``MODE_HITL`` (wait indefinitely). Returns the + verdict, or ``None`` on timeout (fail-open). + + Resolved futures are kept in ``_pending_verdicts`` so a second + waiter on the same event_id (e.g. BaseTool.ainvoke firing after + ToolNode.ainvoke already consumed the verdict) finds the already- + resolved future and returns instantly instead of timing out. + Timed-out (unconsumed) futures are removed immediately; resolved + futures are evicted when the dict exceeds ``_MAX_PENDING_VERDICTS``. + """ + fut = self.register_pending(event_id) + + try: + result = await asyncio.wait_for(fut, timeout=timeout) + # Keep resolved future in dict for late waiters; cap size. + self._evict_resolved_verdicts() + return result + except TimeoutError: + logger.warning( + "Verdict timeout for event_id=%s after %ss", + event_id, + timeout, + ) + # Timed-out future is useless — remove so a retry can + # register a fresh one. + self._pending_verdicts.pop(event_id, None) + return None + + async def wait_for_tool_verdict( + self, + parent_run_id: str, + timeout: float | None, + ) -> pb.Verdict | None: + """Wait for the verdict of the LLM pair that produced this tool call. + + Looks up the LLM event_id from the run_id map and awaits its verdict. + Returns ``None`` (fail-open) when the parent LLM has not been seen, + e.g. tools invoked outside an LLM flow. + """ + event_id = self._run_id_to_event_id.get(parent_run_id) + + if event_id is None: + logger.debug( + "No LLM context for parent_run_id=%s, skipping verdict wait", + parent_run_id, + ) + + return None + + return await self.wait_for_verdict(event_id, timeout) + + async def wait_for_tool_call_verdict( + self, + tool_call_id: str, + timeout: float | None, + ) -> pb.Verdict | None: + """Wait for the verdict of the LLM pair that emitted ``tool_call_id``. + + Every tool call in an AIMessage carries the id the LLM assigned + to it; that id is threaded through LangChain to the ToolNode + invocation. Looking it up against ``_tool_call_id_to_event_id`` + gives the producing LLM's event_id, correct under parallel + agents where a ``last_llm_event_id``-style global would race. + + Returns ``None`` (fail-open) when ``tool_call_id`` is empty or + unknown (direct ToolNode invocation, pre-LLM tool, or the LLM + pair that produced it was evicted from the LRU map). + """ + if not tool_call_id: + return None + + event_id = self._tool_call_id_to_event_id.get(tool_call_id) + + if event_id is None: + logger.debug( + "No LLM context for tool_call_id=%s, skipping verdict wait", + tool_call_id, + ) + + return None + + return await self.wait_for_verdict(event_id, timeout) From 087f47fa23e8ae5158853573e372bdf06ced2145 Mon Sep 17 00:00:00 2001 From: netan-sa Date: Mon, 15 Jun 2026 23:03:20 +0200 Subject: [PATCH 5/6] Fix: linter --- sdk/python/adrian/__init__.py | 48 +---------------------------------- 1 file changed, 1 insertion(+), 47 deletions(-) diff --git a/sdk/python/adrian/__init__.py b/sdk/python/adrian/__init__.py index 8a264cf..d0d6d81 100644 --- a/sdk/python/adrian/__init__.py +++ b/sdk/python/adrian/__init__.py @@ -802,7 +802,7 @@ async def patched_astream( # --- 5. ToolNode --- -def _extract_tool_calls( +def _extract_tool_calls( # pyright: ignore[reportUnusedFunction] state: dict[str, Any] | list[BaseMessage] | Any, ) -> list[dict[str, Any]]: """Extract tool_calls from ToolNode input (all three dispatch shapes). @@ -879,52 +879,6 @@ def _patch_tool_node() -> None: original_ainvoke = ToolNode.ainvoke original_astream = getattr(ToolNode, "astream", None) - async def _gate_tool_calls(state: Any) -> bool: # noqa: ANN401 - """Returns True if tools should be BLOCKED.""" - ws = _ws_client - if ws is None: - return False - if not ws._login_ack_received.is_set(): # pyright: ignore[reportPrivateUsage] - try: - await asyncio.wait_for(ws._login_ack_received.wait(), timeout=5.0) # pyright: ignore[reportPrivateUsage] - except TimeoutError: - logger.warning("ToolNode: LoginAck not received within 5s; blocking") - return True - if not ws.policy_active(): - return False - - tc_ids: list[str] = [ - str(tc.get("id")) for tc in _extract_tool_calls(state) if tc.get("id") - ] - if not tc_ids: - return False - - cfg = _get_config() - timeout = ws.block_timeout(cfg.block_timeout if cfg else 30.0) - verdict = await ws.wait_for_tool_call_verdict(tc_ids[0], timeout) - if verdict is None: - logger.warning("ToolNode: verdict timeout, blocking (fail-closed)") - return True - if _should_halt(verdict): - logger.warning( - "halting tool execution for event_id=%s mad_code=%s", - verdict.event_id, - verdict.mad_code, - ) - return True - return False - - def _build_blocked(state: Any) -> dict[str, list[ToolMessage]]: # noqa: ANN401 - tc_ids = [tc.get("id") for tc in _extract_tool_calls(state) if tc.get("id")] - return { - "messages": [ - ToolMessage( - content="[BLOCKED by security policy]", tool_call_id=tid, name="" - ) - for tid in tc_ids - ] - } - def patched_invoke( self: Any, input: Any, From 10a0b1a5d12bafd1ab19f2a6b4daf7c129664dcb Mon Sep 17 00:00:00 2001 From: yanny-sec Date: Wed, 17 Jun 2026 17:58:25 +0100 Subject: [PATCH 6/6] fix(sdk): gate sync tools under create_react_agent/create_agent The single BaseTool verdict gate never fired for synchronous @tool functions dispatched by create_react_agent / create_agent. Fix 60s fail-open timeout on HITL mode for sync tools. Removes old sdk/adrian directory. --- sdk/adrian/__init__.py | 1247 ----------------- sdk/adrian/ws.py | 1038 -------------- sdk/python/adrian/__init__.py | 92 +- sdk/python/adrian/ws.py | 2 +- sdk/python/tests/test_block_mode.py | 252 +++- sdk/python/tests/test_block_mode_races.py | 18 +- sdk/python/tests/test_extract_tool_calls.py | 2 +- .../tests/test_parent_context_scenarios.py | 2 +- 8 files changed, 312 insertions(+), 2341 deletions(-) delete mode 100644 sdk/adrian/__init__.py delete mode 100644 sdk/adrian/ws.py diff --git a/sdk/adrian/__init__.py b/sdk/adrian/__init__.py deleted file mode 100644 index 8a264cf..0000000 --- a/sdk/adrian/__init__.py +++ /dev/null @@ -1,1247 +0,0 @@ -"""Adrian: multi-agent event capture SDK for LangChain/LangGraph as of 2026-05-10. - -Initialise with a single call and all LLM / tool activity is automatically -captured, paired, and emitted as ``PairedEvent`` objects through registered -handlers:: - - import adrian - - adrian.init(api_key="...") - -Events are paired (chat_model_start + llm_end, tool_start + tool_end), -enriched with agent identity and parent context, and emitted through -pluggable handlers (JSONL, WebSocket, custom). - -""" - -# pyright: reportUnknownVariableType=false -# pyright: reportUnknownMemberType=false -# pyright: reportUnknownArgumentType=false -# pyright: reportUnknownLambdaType=false - -from __future__ import annotations - -import asyncio -import atexit -import logging -import os -from pathlib import Path -from typing import Any -from uuid import uuid4 - -from langchain_core.callbacks.manager import CallbackManager -from langchain_core.language_models.chat_models import BaseChatModel -from langchain_core.messages import AIMessage, BaseMessage, ToolMessage -from langchain_core.runnables.base import Runnable -from langchain_core.runnables.config import ensure_config - -from adrian.config import ( - AdrianConfig, - OnAuditCallback, - OnBlockCallback, - OnDisconnectCallback, - OnEventCallback, - OnMcpServerCallback, - OnReconnectCallback, - OnVerdictCallback, - get_config, - is_initialized, - set_config, -) -from adrian.context import AgentContextTracker, get_invocation_id, set_invocation_id -from adrian.format.types import PairedEvent -from adrian.handler import AdrianCallbackHandler -from adrian.handlers.jsonl import JSONLHandler -from adrian.hooks import EventHandler, HookRegistry -from adrian.mcp import ( - McpServer, - _patch_mcp_adapter, # pyright: ignore[reportPrivateUsage] - mcp_servers, -) -from adrian.mcp import ( - _reset as _reset_mcp, # pyright: ignore[reportPrivateUsage] -) -from adrian.pairing import EventPairBuffer -from adrian.pii import ( - PiiConfig, - PiiRedactor, - RedactingHandler, - RedactionStrategy, - redact_text, -) -from adrian.proto import event_pb2 as pb -from adrian.session_persistence import resolve_session_id -from adrian.types import ToolCallRecord, VerdictContext -from adrian.ws import WebSocketClient - -__version__ = "1.0.2" -__all__ = [ - "init", - "shutdown", - "get_handler", - "AdrianCallbackHandler", - "AdrianConfig", - "EventHandler", - "JSONLHandler", - "McpServer", - "OnAuditCallback", - "OnBlockCallback", - "OnDisconnectCallback", - "OnEventCallback", - "OnMcpServerCallback", - "OnReconnectCallback", - "OnVerdictCallback", - "PairedEvent", - "PiiConfig", - "PiiRedactor", - "RedactingHandler", - "RedactionStrategy", - "ToolCallRecord", - "VerdictContext", - "__version__", - "mcp_servers", - "redact_text", -] - -logger = logging.getLogger("adrian") - -_hooks: HookRegistry | None = None -_handler: AdrianCallbackHandler | None = None -_ws_client: WebSocketClient | None = None -_fork_handler_registered: bool = False - - -# ------------------------------------------------------------------ -# Fork safety -# ------------------------------------------------------------------ - - -def _reset_after_fork() -> None: - """Drop inherited Adrian state in a forked child process. - - Registered via ``os.register_at_fork`` on the first :func:`init` call. - Nulls out module globals so the child does not silently share the - parent's WebSocket socket, writing to a shared socket from two - processes interleaves bytes on the wire, corrupting frames the - server cannot parse. - - Triggered by pre-fork deployments (``gunicorn --preload``, - ``multiprocessing.Pool``, Celery prefork). The child must call - :func:`init` again from its worker startup hook to establish its - own connection. - """ - global _hooks, _handler, _ws_client # noqa: PLW0603 - - _hooks = None - _handler = None - _ws_client = None - _reset_mcp() - - -# ------------------------------------------------------------------ -# Public API -# ------------------------------------------------------------------ - - -def init( - api_key: str | None = None, - log_file: str | Path = "events.jsonl", - handlers: list[EventHandler] | None = None, - auto_instrument: bool = True, - log_level: str | None = None, - ws_url: str | None = None, - session_id: str | None = None, - block_timeout: float = 30.0, - on_event: OnEventCallback | None = None, - on_verdict: OnVerdictCallback | None = None, - on_block: OnBlockCallback | None = None, - on_audit: OnAuditCallback | None = None, - on_disconnect: OnDisconnectCallback | None = None, - on_reconnect: OnReconnectCallback | None = None, - on_mcp_server: OnMcpServerCallback | None = None, - replay_buffer_frames: int = 1000, -) -> None: - """Initialise the Adrian SDK. - - Creates the event pairing buffer, agent context tracker, and hook - registry, then monkey-patches LangChain so every LLM call and tool - invocation is captured as a ``PairedEvent``. - - Events are emitted through registered handlers. If no handlers are - provided, defaults to a ``JSONLHandler`` writing to ``log_file``. - - Transport (WebSocket, HTTP, etc.) is not managed by the SDK, pass - a pre-configured handler via the ``handlers`` list instead. - - Args: - api_key: Adrian API key. Falls back to ``ADRIAN_API_KEY`` env - var. Stored in config for handlers that need it. - log_file: Path to the JSONL output file (used when no handlers - are explicitly provided). - handlers: List of ``EventHandler`` instances to receive paired - events. If ``None``, defaults to ``JSONLHandler(log_file)``. - auto_instrument: Patch LangChain / LangGraph at import time. - log_level: Optional override for the ``adrian`` logger's level. - ``None`` (default) inherits from the application's logging - config; pass e.g. ``"DEBUG"`` to force-enable verbose SDK - logging without touching global config. - ws_url: WebSocket URL for the Adrian server (e.g. - ``"ws://localhost:8080/ws"``). Falls back to ``ADRIAN_WS_URL``. - When set and ``handlers`` is ``None``, a ``WebSocketClient`` is - auto-registered alongside the default ``JSONLHandler``. Requires - ``api_key``. - session_id: Session identifier. Falls back to - ``ADRIAN_SESSION_ID``, then to a per-cwd persistent UUID. - See :mod:`adrian.session_persistence`. - block_timeout: Max seconds to wait for a verdict in ``MODE_BLOCK`` - before fail-open. Ignored in ``MODE_ALERT`` (no wait) and - ``MODE_HITL`` (wait indefinitely). Falls back to - ``ADRIAN_BLOCK_TIMEOUT``. - on_event: Callback for every paired event. - on_verdict: Callback for every verdict. - on_block: Callback for BLOCK-tier verdicts (M3 / M4). Notification - only; return value is ignored. - on_audit: Callback for NOTIFY-tier verdicts (M2). - on_disconnect: Callback fired when the WebSocket is lost. Receives - a reason string. Sync or async. - on_reconnect: Callback fired when the WebSocket reconnects after a - prior disconnect. Does not fire on initial connection. Sync - or async. - on_mcp_server: Callback fired when an MCP server is registered or - updated. Receives the freshly-registered ``McpServer``. Does - NOT fire on no-op re-observations. Sync or async. - replay_buffer_frames: Max serialised frames kept in the in-memory - ring for replay after a transient WS outage (server restart, - ALB shuffle). Each frame is one ``ClientFrame.paired_batch`` - (~4KB). Default 1000 frames ≈ ~4MB RAM. Falls back to - ``ADRIAN_REPLAY_BUFFER_FRAMES``. At capacity each further - append evicts the oldest; a one-shot WARN fires on first fill - and cumulative drops are logged on the next reconnect. - """ - global _hooks, _handler, _ws_client, _fork_handler_registered # noqa: PLW0603 - - if not _fork_handler_registered and hasattr(os, "register_at_fork"): - os.register_at_fork(after_in_child=_reset_after_fork) - _fork_handler_registered = True - - try: - loop: asyncio.AbstractEventLoop | None = asyncio.get_running_loop() - except RuntimeError: - loop = None - - resolved_key = api_key or os.getenv("ADRIAN_API_KEY") or None - resolved_file = Path(os.getenv("ADRIAN_LOG_FILE", str(log_file))) - # Default to the hosted Adrian backend so `adrian.init(api_key=...)` - # Just Works for freemium users. Self-hosted users override via - # ws_url= or ADRIAN_WS_URL. - resolved_ws_url = ( - os.getenv("ADRIAN_WS_URL") or ws_url or "wss://adrian.secureagentics.ai/ws" - ) - resolved_session = ( - os.getenv("ADRIAN_SESSION_ID") or session_id or resolve_session_id() - ) - resolved_block_timeout = float( - os.getenv("ADRIAN_BLOCK_TIMEOUT", str(block_timeout)), - ) - - resolved_replay_buffer_frames = replay_buffer_frames - env_replay = os.getenv("ADRIAN_REPLAY_BUFFER_FRAMES", "").strip() - - if env_replay: - try: - resolved_replay_buffer_frames = int(env_replay) - except ValueError: - logger.warning( - "ADRIAN_REPLAY_BUFFER_FRAMES=%r is not an int; " - "falling back to kwarg default %d", - env_replay, - replay_buffer_frames, - ) - - if resolved_ws_url and not resolved_key: - logger.warning( - "ws_url is set but no api_key provided. Set api_key or " - "ADRIAN_API_KEY; the server will reject the WS connection." - ) - - config = AdrianConfig( - api_key=resolved_key, - log_file=resolved_file, - log_level=log_level, - session_id=resolved_session, - ws_url=resolved_ws_url, - block_timeout=resolved_block_timeout, - on_event=on_event, - on_verdict=on_verdict, - on_block=on_block, - on_audit=on_audit, - on_disconnect=on_disconnect, - on_reconnect=on_reconnect, - on_mcp_server=_make_on_mcp_server_chain(on_mcp_server), - replay_buffer_frames=resolved_replay_buffer_frames, - ) - - set_config(config) - - if log_level is not None: - # Only override the adrian logger's level when the caller asks - # for it explicitly. Default behaviour respects whatever the - # application configured via logging.basicConfig / .config. - logging.getLogger("adrian").setLevel( - getattr(logging, log_level.upper(), logging.INFO), - ) - - # Build handler list, then optionally wrap with PII redaction - handler_list: list[EventHandler] = [] - - if handlers: - handler_list = list(handlers) - else: - handler_list.append(JSONLHandler(path=resolved_file)) - - if resolved_ws_url: - _ws_client = WebSocketClient( - url=resolved_ws_url, - session_id=config.session_id, - api_key=resolved_key or "", - on_disconnect=on_disconnect, - on_reconnect=on_reconnect, - on_login_ack=_send_mcp_inventory, - replay_buffer_frames=resolved_replay_buffer_frames, - ) - handler_list.append(_ws_client) - - handler_list = [RedactingHandler(h) for h in handler_list] - - # Create hook registry and register handlers - _hooks = HookRegistry() - - for h in handler_list: - _hooks.register(h) - - # Create pairing and context tracking components - pair_buffer = EventPairBuffer() - context_tracker = AgentContextTracker() - - # Create handler with new components - _handler = AdrianCallbackHandler( - pair_buffer=pair_buffer, - context_tracker=context_tracker, - hooks=_hooks, - config=config, - ) - - if _ws_client is not None: - # Back-reference so the recv loop can dispatch verdicts into the - # handler's block/audit/verdict callback machinery. - _ws_client._handler = _handler # pyright: ignore[reportPrivateUsage] - - if loop is not None: - _ws_client.schedule_connect(loop) - else: - logger.debug( - "No running event loop at init(); WebSocket will connect on " - "first send from within an async context." - ) - - if auto_instrument: - _auto_instrument_langchain() - - # MCP server tracking is independent of LangChain auto-instrumentation, - # it observes a different library (langchain-mcp-adapters) and is the - # only path the SDK has to learn about MCP servers. Always run. - _patch_mcp_adapter() - - atexit.register(shutdown) - logger.info( - "Adrian v%s initialised (handlers=%d, ws=%s)", - __version__, - len(_hooks), - resolved_ws_url or "disabled", - ) - - -def shutdown() -> None: - """Close all handlers and reset state.""" - global _hooks, _handler, _ws_client # noqa: PLW0603 - - if _hooks is not None: - try: - loop = asyncio.get_running_loop() - loop.create_task(_hooks.close()) - except RuntimeError: - asyncio.run(_hooks.close()) - - _hooks = None - - _handler = None - _ws_client = None - set_config(None) - - -def get_handler() -> AdrianCallbackHandler | None: - """Return the SDK's callback handler, or ``None`` if uninitialised. - - Useful when ``adrian.init(auto_instrument=False)`` is set and you - need to attach the handler to LangChain calls explicitly, e.g.:: - - adrian.init(api_key=..., auto_instrument=False) - handler = adrian.get_handler() - await llm.ainvoke(prompt, config={"callbacks": [handler]}) - - The handler is wired into Adrian's WS hook chain at ``init()`` - time; constructing a fresh ``AdrianCallbackHandler`` directly will - not emit events. - """ - return _handler - - -# ------------------------------------------------------------------ -# Internal helpers -# ------------------------------------------------------------------ - - -def _get_callback_handler() -> AdrianCallbackHandler | None: - """Return the current callback handler (closure helper).""" - return _handler - - -def _get_config() -> AdrianConfig | None: - """Return the current config without raising (closure helper).""" - if not is_initialized(): - return None - - return get_config() - - -async def _send_mcp_inventory() -> None: - """Send the current MCP server registry as a ``ClientFrame``. - - Triggers: once per connect (after each ``LoginAck``) and on every - ``on_mcp_server`` registry change. The server replaces its full - list on every frame, so a fresh snapshot is correct on every fire. - No-op when the WebSocket transport is disabled or when the registry - is empty (the registry is additive, so an empty snapshot is - indistinguishable from "not yet observed", sending it would only - log a ``which=`` warning on the server). - """ - ws = _ws_client - - if ws is None: - return - - servers = mcp_servers() - - if not servers: - return - - frame = pb.ClientFrame() - - for server in servers: - added = frame.mcp_inventory.servers.add() - added.name = server.name - added.transport = server.transport - added.endpoint = server.endpoint - - await ws._send_frame(frame) # pyright: ignore[reportPrivateUsage] - - -def _make_on_mcp_server_chain( - user_cb: OnMcpServerCallback | None, -) -> OnMcpServerCallback: - """Compose ``_send_mcp_inventory`` with the user's ``on_mcp_server``. - - Schedules the inventory sync as a fire-and-forget task on the - running loop (if any) and forwards transparently to the user's - callback so its sync-vs-async return shape is preserved for - :func:`adrian.callbacks.fire` to handle. When no loop is running, - the inventory sync is skipped, the next ``LoginAck`` (which only - fires once a loop is up) will catch up. - """ - - def chain(server: McpServer) -> Any: # noqa: ANN401 - try: - loop = asyncio.get_running_loop() - except RuntimeError: - pass - else: - loop.create_task(_send_mcp_inventory()) - - if user_cb is None: - return None - - return user_cb(server) - - return chain - - -def _inject_callbacks(config: Any) -> Any: # noqa: ANN401 - """Merge the Adrian handler into a LangChain ``RunnableConfig``. - - Args: - config: An existing LangChain RunnableConfig or ``None``. - - Returns: - A config dict guaranteed to contain the Adrian handler. - """ - handler = _get_callback_handler() - - if handler is None: - return ensure_config(config) - - config = ensure_config(config) - callbacks = config.get("callbacks") or [] - - if hasattr(callbacks, "handlers"): - callbacks = list(callbacks.handlers) # pyright: ignore[reportAttributeAccessIssue] - elif not isinstance(callbacks, list): - callbacks = [callbacks] if callbacks else [] - else: - callbacks = list(callbacks) - - handler_types = [type(h).__name__ for h in callbacks] - - if "AdrianCallbackHandler" not in handler_types: - callbacks.insert(0, handler) - - config["callbacks"] = callbacks - - return config - - -# ------------------------------------------------------------------ -# Auto-instrumentation -# ------------------------------------------------------------------ - - -def _auto_instrument_langchain() -> None: - """Apply all monkey-patches to LangChain / LangGraph.""" - try: - _patch_runnable() - _patch_callback_manager() - _patch_chat_model() - _patch_langgraph() - _patch_tool_node() - _patch_base_tool() - _patch_agent_executor() - logger.debug("LangChain auto-instrumentation applied") - except ImportError: - logger.debug("LangChain not found, skipping auto-instrumentation") - except Exception: - logger.exception("Auto-instrumentation failed") - - -# --- 1. Runnable --- - - -def _patch_runnable() -> None: - """Patch ``Runnable.invoke`` / ``ainvoke`` / ``astream`` / ``stream``.""" - if getattr(Runnable, "_adrian_patched", False): - return - - original_invoke = Runnable.invoke - original_ainvoke = Runnable.ainvoke - original_astream = Runnable.astream - original_stream = Runnable.stream - - def patched_invoke( - self: Any, # noqa: ANN401 - input: Any, # noqa: A002, ANN401 - config: Any = None, # noqa: ANN401 - **kwargs: Any, - ) -> Any: # noqa: ANN401 - config = _inject_callbacks(config) - return original_invoke(self, input, config, **kwargs) - - async def patched_ainvoke( - self: Any, # noqa: ANN401 - input: Any, # noqa: A002, ANN401 - config: Any = None, # noqa: ANN401 - **kwargs: Any, - ) -> Any: # noqa: ANN401 - config = _inject_callbacks(config) - return await original_ainvoke(self, input, config, **kwargs) - - async def patched_astream( - self: Any, # noqa: ANN401 - input: Any, # noqa: A002, ANN401 - config: Any = None, # noqa: ANN401 - **kwargs: Any, - ) -> Any: # noqa: ANN401 - config = _inject_callbacks(config) - async for chunk in original_astream(self, input, config, **kwargs): - yield chunk - - def patched_stream( - self: Any, # noqa: ANN401 - input: Any, # noqa: A002, ANN401 - config: Any = None, # noqa: ANN401 - **kwargs: Any, - ) -> Any: # noqa: ANN401 - config = _inject_callbacks(config) - yield from original_stream(self, input, config, **kwargs) - - Runnable.invoke = patched_invoke # type: ignore[assignment] - Runnable.ainvoke = patched_ainvoke # type: ignore[assignment] - Runnable.astream = patched_astream # type: ignore[assignment] - Runnable.stream = patched_stream # type: ignore[assignment] - Runnable._adrian_patched = True # type: ignore[attr-defined] - logger.debug("Patched Runnable.invoke / ainvoke") - - -# --- 2. CallbackManager --- - - -def _patch_callback_manager() -> None: - """Patch ``CallbackManager.__init__`` to always include Adrian.""" - if getattr(CallbackManager, "_adrian_cbm_patched", False): - return - - original_configure = CallbackManager.configure - - def patched_configure( - _cls: Any, # noqa: ANN401 - inheritable_callbacks: Any = None, # noqa: ANN401 - local_callbacks: Any = None, # noqa: ANN401 - verbose: bool = False, - inheritable_tags: Any = None, # noqa: ANN401 - local_tags: Any = None, # noqa: ANN401 - inheritable_metadata: Any = None, # noqa: ANN401 - local_metadata: Any = None, # noqa: ANN401 - **extra: Any, # noqa: ANN401 - ) -> Any: # noqa: ANN401 - """Inject Adrian handler into inheritable callbacks. - - ``**extra`` forwards any kwargs newer langchain-core releases - add to ``CallbackManager.configure`` (e.g. 1.3 added - ``langsmith_inheritable_metadata``) so the patch stays - forward-compatible without re-declaring every signature change. - """ - handler = _get_callback_handler() - - if handler: - if inheritable_callbacks is None: - inheritable_callbacks = [handler] - elif isinstance(inheritable_callbacks, list): - handler_types = [type(h).__name__ for h in inheritable_callbacks] - - if "AdrianCallbackHandler" not in handler_types: - inheritable_callbacks = [handler, *inheritable_callbacks] - elif hasattr(inheritable_callbacks, "handlers"): - handler_types = [ - type(h).__name__ for h in inheritable_callbacks.handlers - ] - - if "AdrianCallbackHandler" not in handler_types: - inheritable_callbacks.handlers.insert(0, handler) - - return original_configure( - inheritable_callbacks=inheritable_callbacks, - local_callbacks=local_callbacks, - verbose=verbose, - inheritable_tags=inheritable_tags, - local_tags=local_tags, - inheritable_metadata=inheritable_metadata, - local_metadata=local_metadata, - **extra, - ) - - CallbackManager.configure = classmethod( # type: ignore[assignment] - lambda _cls, *a, **kw: patched_configure(_cls, *a, **kw), # pyright: ignore[reportCallIssue] - ) - CallbackManager._adrian_cbm_patched = True # type: ignore[attr-defined] - logger.debug("Patched CallbackManager.configure") - - -# --- 3. BaseChatModel --- - - -def _patch_chat_model() -> None: - """Patch ``BaseChatModel.invoke`` / ``ainvoke`` / ``astream`` / ``stream``.""" - if getattr(BaseChatModel, "_adrian_chat_model_patched", False): - return - - original_invoke = BaseChatModel.invoke - original_ainvoke = BaseChatModel.ainvoke - original_astream = BaseChatModel.astream - original_stream = BaseChatModel.stream - - def patched_invoke( - self: Any, # noqa: ANN401 - input: Any, # noqa: A002, ANN401 - config: Any = None, # noqa: ANN401 - **kwargs: Any, - ) -> Any: # noqa: ANN401 - config = _inject_callbacks(config) - return original_invoke(self, input, config=config, **kwargs) - - async def patched_ainvoke( - self: Any, # noqa: ANN401 - input: Any, # noqa: A002, ANN401 - config: Any = None, # noqa: ANN401 - **kwargs: Any, - ) -> Any: # noqa: ANN401 - config = _inject_callbacks(config) - return await original_ainvoke(self, input, config=config, **kwargs) - - async def patched_astream( - self: Any, # noqa: ANN401 - input: Any, # noqa: A002, ANN401 - config: Any = None, # noqa: ANN401 - **kwargs: Any, - ) -> Any: # noqa: ANN401 - config = _inject_callbacks(config) - async for chunk in original_astream(self, input, config=config, **kwargs): - yield chunk - - def patched_stream( - self: Any, # noqa: ANN401 - input: Any, # noqa: A002, ANN401 - config: Any = None, # noqa: ANN401 - **kwargs: Any, - ) -> Any: # noqa: ANN401 - config = _inject_callbacks(config) - yield from original_stream(self, input, config=config, **kwargs) - - BaseChatModel.invoke = patched_invoke # type: ignore[assignment] - BaseChatModel.ainvoke = patched_ainvoke # type: ignore[assignment] - BaseChatModel.astream = patched_astream # type: ignore[assignment] - BaseChatModel.stream = patched_stream # type: ignore[assignment] - BaseChatModel._adrian_chat_model_patched = True # type: ignore[attr-defined] - logger.debug("Patched BaseChatModel.invoke / ainvoke") - - -# --- 4. LangGraph Pregel --- - - -def _patch_langgraph() -> None: - """Patch ``Pregel.invoke`` / ``ainvoke`` / ``astream``. - - The async patches also set the invocation_id ContextVar at the - top-level call so all sub-agent events share the same ID. - """ - try: - from langgraph.pregel import Pregel - except ImportError: - return - - if getattr(Pregel, "_adrian_pregel_patched", False): - return - - original_invoke = Pregel.invoke - original_ainvoke = Pregel.ainvoke - original_astream = Pregel.astream - - def patched_invoke( - self: Any, # noqa: ANN401 - input: Any, # noqa: A002, ANN401 - config: Any = None, # noqa: ANN401 - **kwargs: Any, - ) -> Any: # noqa: ANN401 - """Inject Adrian callbacks into sync graph invocation.""" - config = _inject_callbacks(config) - - return original_invoke(self, input, config=config, **kwargs) - - async def patched_ainvoke( - self: Any, # noqa: ANN401 - input: Any, # noqa: A002, ANN401 - config: Any = None, # noqa: ANN401 - **kwargs: Any, - ) -> Any: # noqa: ANN401 - """Inject Adrian callbacks and set invocation_id. - - Only the top-level call sets the invocation_id. Nested calls - (sub-agent ainvoke) inherit it via contextvars propagation. - """ - config = _inject_callbacks(config) - - current = get_invocation_id() - token = None - - if current is None: - uuid_ = uuid4() - token = set_invocation_id(str(uuid_)) - - try: - return await original_ainvoke(self, input, config=config, **kwargs) - finally: - if token is not None: - token.var.reset(token) - - async def patched_astream( - self: Any, # noqa: ANN401 - input: Any, # noqa: A002, ANN401 - config: Any = None, # noqa: ANN401 - **kwargs: Any, - ) -> Any: # noqa: ANN401 - """Inject Adrian callbacks and set invocation_id for streaming.""" - config = _inject_callbacks(config) - - current = get_invocation_id() - token = None - - if current is None: - uuid_ = uuid4() - token = set_invocation_id(str(uuid_)) - - try: - async for chunk in original_astream(self, input, config=config, **kwargs): - yield chunk - finally: - if token is not None: - token.var.reset(token) - - Pregel.invoke = patched_invoke # type: ignore[assignment] - Pregel.ainvoke = patched_ainvoke # type: ignore[assignment] - Pregel.astream = patched_astream # type: ignore[assignment] - Pregel._adrian_pregel_patched = True # type: ignore[attr-defined] - logger.debug("Patched Pregel.invoke / ainvoke / astream") - - -# --- 5. ToolNode --- - - -def _extract_tool_calls( - state: dict[str, Any] | list[BaseMessage] | Any, -) -> list[dict[str, Any]]: - """Extract tool_calls from ToolNode input (all three dispatch shapes). - - Returns full tool_call dicts (with id, name, args) for backward - compat with tests and callers that need the full shape. - """ - # Shape 3: per-tool-call dict from _afunc dispatch - if isinstance(state, dict) and "tool_call" in state: - tc = state["tool_call"] - if isinstance(tc, dict) and tc.get("id"): - return [tc] - tc_id = getattr(tc, "id", None) - if tc_id: - return [ - { - "id": tc_id, - "name": getattr(tc, "name", ""), - "args": getattr(tc, "args", {}), - } - ] - return [] - - # Shape 1/2: state dict or message list - if isinstance(state, dict): - messages = list(state.get("messages") or []) # pyright: ignore[reportUnknownVariableType, reportUnknownArgumentType] - elif isinstance(state, list): - messages = list(state) - else: - return [] - - for msg in reversed(messages): - if isinstance(msg, AIMessage) and getattr(msg, "tool_calls", None): - return msg.tool_calls # type: ignore[no-any-return] - - return [] - - -def _should_halt(verdict: pb.Verdict) -> bool: - """Decide whether a verdict should halt tool execution. - - HITL resolutions override per-MAD policy when present. - """ - if verdict.HasField("hitl"): - return not verdict.hitl.continue_execution - - mad_prefix = verdict.mad_code[:2] - return { - "M0": verdict.policy.policy_m0, - "M2": verdict.policy.policy_m2, - "M3": verdict.policy.policy_m3, - "M4": verdict.policy.policy_m4, - }.get(mad_prefix, False) - - -def _patch_tool_node() -> None: - """Patch ToolNode for callback injection + async verdict gate. - - ToolNode dispatches tools via tool.invoke (sync) even within async - Pregel. BaseTool.invoke can't await a verdict from the event loop - thread, so we add the verdict gate here on ToolNode.ainvoke — the - entry point Pregel calls before tool dispatch begins. This is a - complementary gate to BaseTool (which covers direct callers). - """ - try: - from langgraph.prebuilt import ToolNode - except ImportError: - return - - if getattr(ToolNode, "_adrian_tool_node_patched", False): - return - - original_invoke = ToolNode.invoke - original_ainvoke = ToolNode.ainvoke - original_astream = getattr(ToolNode, "astream", None) - - async def _gate_tool_calls(state: Any) -> bool: # noqa: ANN401 - """Returns True if tools should be BLOCKED.""" - ws = _ws_client - if ws is None: - return False - if not ws._login_ack_received.is_set(): # pyright: ignore[reportPrivateUsage] - try: - await asyncio.wait_for(ws._login_ack_received.wait(), timeout=5.0) # pyright: ignore[reportPrivateUsage] - except TimeoutError: - logger.warning("ToolNode: LoginAck not received within 5s; blocking") - return True - if not ws.policy_active(): - return False - - tc_ids: list[str] = [ - str(tc.get("id")) for tc in _extract_tool_calls(state) if tc.get("id") - ] - if not tc_ids: - return False - - cfg = _get_config() - timeout = ws.block_timeout(cfg.block_timeout if cfg else 30.0) - verdict = await ws.wait_for_tool_call_verdict(tc_ids[0], timeout) - if verdict is None: - logger.warning("ToolNode: verdict timeout, blocking (fail-closed)") - return True - if _should_halt(verdict): - logger.warning( - "halting tool execution for event_id=%s mad_code=%s", - verdict.event_id, - verdict.mad_code, - ) - return True - return False - - def _build_blocked(state: Any) -> dict[str, list[ToolMessage]]: # noqa: ANN401 - tc_ids = [tc.get("id") for tc in _extract_tool_calls(state) if tc.get("id")] - return { - "messages": [ - ToolMessage( - content="[BLOCKED by security policy]", tool_call_id=tid, name="" - ) - for tid in tc_ids - ] - } - - def patched_invoke( - self: Any, - input: Any, - config: Any = None, - **kwargs: Any, # noqa: A002, ANN401 - ) -> Any: # noqa: ANN401 - config = _inject_callbacks(config) - return original_invoke(self, input, config=config, **kwargs) - - async def patched_ainvoke( - self: Any, - input: Any, - config: Any = None, - **kwargs: Any, # noqa: A002, ANN401 - ) -> Any: # noqa: ANN401 - config = _inject_callbacks(config) - # Verdict gate removed — BaseTool.ainvoke/arun is the single - # gate layer. Gating here too caused double-gate: ToolNode - # consumed the verdict future, BaseTool's gate registered a - # fresh future that never resolved → 30s timeout on a benign - # verdict. Callback injection is kept so events still flow. - return await original_ainvoke(self, input, config=config, **kwargs) - - async def patched_astream( - self: Any, - input: Any, - config: Any = None, - **kwargs: Any, # noqa: A002, ANN401 - ) -> Any: # noqa: ANN401 - config = _inject_callbacks(config) - assert original_astream is not None # guarded by line below - async for chunk in original_astream(self, input, config=config, **kwargs): - yield chunk - - ToolNode.invoke = patched_invoke # type: ignore[assignment] - ToolNode.ainvoke = patched_ainvoke # type: ignore[assignment] - if original_astream is not None: - ToolNode.astream = patched_astream # type: ignore[assignment] - ToolNode._adrian_tool_node_patched = True # type: ignore[attr-defined] - logger.debug("Patched ToolNode.invoke / ainvoke / astream") - - -# --- 6. BaseTool (universal verdict gate) --- - - -_BLOCKED_CONTENT = "[BLOCKED by security policy]" - - -def _patch_base_tool() -> None: - """Patch ``BaseTool.invoke`` and ``BaseTool.ainvoke`` with the verdict gate. - - Every LangChain tool — whether dispatched by ToolNode, AgentExecutor, - create_react_agent, or a manual ``tool.invoke(tool_call)`` loop — - funnels through ``BaseTool.invoke`` (sync) or ``BaseTool.ainvoke`` - (async). Gating here covers all frameworks in one place. - - The gate extracts ``tool_call_id`` from the input (a ``ToolCall`` - TypedDict), awaits the classifier verdict for the producing LLM - event, and returns a ``[BLOCKED]`` string instead of running the - tool body when the verdict is in-scope (M3/M4 under MODE_BLOCK). - - In MODE_BLOCK, verdict timeout is fail-closed (block the tool) - because the absence of a verdict in block mode is a policy violation. - In MODE_ALERT, no gate fires at all (skip). - """ - from langchain_core.tools import BaseTool - from langchain_core.tools.base import ( - _is_tool_call, # pyright: ignore[reportPrivateUsage] - ) - - if getattr(BaseTool, "_adrian_base_tool_patched", False): - return - - original_invoke = BaseTool.invoke - original_ainvoke = BaseTool.ainvoke - - def _extract_tool_call_id(input: Any) -> str | None: # noqa: A002, ANN401 - """Extract tool_call_id from a ToolCall input, or None.""" - if isinstance(input, dict) and _is_tool_call(input): - return input.get("id") - return None - - async def _async_gate(tool_call_id: str) -> bool: - """Returns True if the tool should be BLOCKED.""" - ws = _ws_client - if ws is None: - return False - - if not ws._login_ack_received.is_set(): # pyright: ignore[reportPrivateUsage] - try: - await asyncio.wait_for( - ws._login_ack_received.wait(), # pyright: ignore[reportPrivateUsage] - timeout=5.0, - ) - except TimeoutError: - logger.warning( - "BaseTool: LoginAck not received within 5s; " - "blocking tool (refusing to run without verified policy)" - ) - return True - - if not ws.policy_active(): - return False - - cfg = _get_config() - timeout = ws.block_timeout(cfg.block_timeout if cfg else 30.0) - verdict = await ws.wait_for_tool_call_verdict(tool_call_id, timeout) - - if verdict is None: - # Fail-closed in block mode: no verdict = block. - logger.warning( - "BaseTool: verdict timeout for tool_call_id=%s; " - "blocking (fail-closed in MODE_BLOCK)", - tool_call_id, - ) - return True - - if _should_halt(verdict): - logger.warning( - "halting tool execution for event_id=%s mad_code=%s", - verdict.event_id, - verdict.mad_code, - ) - return True - - return False - - def _sync_gate(tool_call_id: str) -> bool: - """Sync verdict gate — works for pure-sync and worker-thread callers. - - Pure-sync (no event loop): runs ``_async_gate`` via - ``loop.run_until_complete``. - - Worker-thread (Pregel dispatches sync tools on a thread-pool - worker while the event loop runs on the main thread): bridges - the async gate to the main loop via ``run_coroutine_threadsafe`` - and blocks the worker thread until the verdict resolves. - - Event-loop thread (calling tool.invoke directly from async - code): cannot block — returns False (skip). The async path - (BaseTool.ainvoke) handles this case. - """ - ws = _ws_client - if ws is None or not ws._login_ack_received.is_set() or not ws.policy_active(): # pyright: ignore[reportPrivateUsage] - return False - - try: - loop = asyncio.get_event_loop() - except RuntimeError: - return False - - if not loop.is_running(): - # Pure-sync caller — safe to block - return loop.run_until_complete(_async_gate(tool_call_id)) - - # Check if we're on a worker thread (no running loop on THIS - # thread) vs the event-loop thread itself. - try: - asyncio.get_running_loop() - # We ARE on the event-loop thread — can't block it. - return False - except RuntimeError: - pass - - # Worker thread: bridge the async gate to the main loop. - main_loop = getattr(ws, "_loop", None) - if main_loop is None or not main_loop.is_running(): - return False - - try: - cfg = _get_config() - timeout = ws.block_timeout(cfg.block_timeout if cfg else 30.0) - future = asyncio.run_coroutine_threadsafe( - _async_gate(tool_call_id), main_loop - ) - return future.result(timeout=timeout if timeout else 60.0) - except Exception: - return False - - def _blocked_response(tc_id: str) -> Any: # noqa: ANN401 - """Return a blocked response compatible with ToolNode. - - Returns a ToolMessage for create_react_agent / ToolNode - compatibility. Falls back to bare string on import failure. - """ - try: - return ToolMessage(content=_BLOCKED_CONTENT, tool_call_id=tc_id, name="") - except Exception: - return _BLOCKED_CONTENT - - def patched_invoke( - self: Any, # noqa: ANN401 - input: Any, # noqa: A002, ANN401 - config: Any = None, # noqa: ANN401 - **kwargs: Any, - ) -> Any: # noqa: ANN401 - config = _inject_callbacks(config) - tc_id = _extract_tool_call_id(input) - if tc_id and _sync_gate(tc_id): - return _blocked_response(tc_id) - return original_invoke(self, input, config=config, **kwargs) - - async def patched_ainvoke( - self: Any, # noqa: ANN401 - input: Any, # noqa: A002, ANN401 - config: Any = None, # noqa: ANN401 - **kwargs: Any, - ) -> Any: # noqa: ANN401 - config = _inject_callbacks(config) - tc_id = _extract_tool_call_id(input) - if tc_id and await _async_gate(tc_id): - return _blocked_response(tc_id) - return await original_ainvoke(self, input, config=config, **kwargs) - - original_arun = BaseTool.arun - - async def patched_arun( - self: Any, # noqa: ANN401 - tool_input: Any, # noqa: ANN401 - *args: Any, - tool_call_id: str | None = None, - **kwargs: Any, - ) -> Any: # noqa: ANN401 - """Gate on arun — AgentExecutor calls tool.arun directly.""" - if tool_call_id and await _async_gate(tool_call_id): - return _blocked_response(tool_call_id) - return await original_arun( - self, tool_input, *args, tool_call_id=tool_call_id, **kwargs - ) - - BaseTool.invoke = patched_invoke # type: ignore[assignment] - BaseTool.ainvoke = patched_ainvoke # type: ignore[assignment] - BaseTool.arun = patched_arun # type: ignore[assignment] - BaseTool._adrian_base_tool_patched = True # type: ignore[attr-defined] - logger.debug("Patched BaseTool.invoke / ainvoke / arun (universal verdict gate)") - - -# --- 7. AgentExecutor (tool_call_id on agent_action, not on tool.arun) --- - - -def _patch_agent_executor() -> None: - """Patch AgentExecutor._aperform_agent_action for the executor path. - - AgentExecutor calls tool.arun without forwarding tool_call_id, - so the BaseTool.arun gate can't extract it. The tool_call_id lives - on agent_action.tool_call_id (set by OpenAI-style parsers). We - intercept here, await the verdict, and return a blocked observation - instead of calling the tool. - """ - AgentExecutor = None - AgentStep = None - for mod_path in ("langchain_classic.agents.agent", "langchain.agents.agent"): - try: - mod = __import__(mod_path, fromlist=["AgentExecutor", "AgentStep"]) - AgentExecutor = getattr(mod, "AgentExecutor", None) - AgentStep = getattr(mod, "AgentStep", None) - if AgentExecutor and AgentStep: - break - except ImportError: - continue - - if AgentExecutor is None or AgentStep is None: - return - if getattr(AgentExecutor, "_adrian_executor_patched", False): - return - - original_aperform = AgentExecutor._aperform_agent_action - - async def patched_aperform( - self: Any, - name_to_tool_map: Any, - color_mapping: Any, # noqa: ANN401 - agent_action: Any, - run_manager: Any = None, # noqa: ANN401 - ) -> Any: # noqa: ANN401 - tc_id = getattr(agent_action, "tool_call_id", None) - if tc_id: - ws = _ws_client - if ws is not None: - if not ws._login_ack_received.is_set(): # pyright: ignore[reportPrivateUsage] - try: - await asyncio.wait_for( - ws._login_ack_received.wait(), # pyright: ignore[reportPrivateUsage] - timeout=5.0, - ) - except TimeoutError: - logger.warning( - "AgentExecutor: LoginAck not received within 5s; blocking" - ) - return AgentStep( - action=agent_action, observation=_BLOCKED_CONTENT - ) - if ws.policy_active(): - cfg = _get_config() - timeout = ws.block_timeout(cfg.block_timeout if cfg else 30.0) - verdict = await ws.wait_for_tool_call_verdict(tc_id, timeout) - if verdict is None: - logger.warning( - "AgentExecutor: verdict timeout for tool_call_id=%s, blocking (fail-closed)", - tc_id, - ) - return AgentStep( - action=agent_action, observation=_BLOCKED_CONTENT - ) - if _should_halt(verdict): - logger.warning( - "halting tool execution for event_id=%s mad_code=%s", - verdict.event_id, - verdict.mad_code, - ) - return AgentStep( - action=agent_action, observation=_BLOCKED_CONTENT - ) - return await original_aperform( - self, name_to_tool_map, color_mapping, agent_action, run_manager - ) - - AgentExecutor._aperform_agent_action = patched_aperform # type: ignore[assignment] - AgentExecutor._adrian_executor_patched = True # type: ignore[attr-defined] - logger.debug("Patched AgentExecutor._aperform_agent_action") diff --git a/sdk/adrian/ws.py b/sdk/adrian/ws.py deleted file mode 100644 index 169cbdc..0000000 --- a/sdk/adrian/ws.py +++ /dev/null @@ -1,1038 +0,0 @@ -"""Async WebSocket ``EventHandler`` that streams ``PairedEvent`` to the worker core API. - -Converts each ``PairedEvent`` into a ``pb.PairedEvent`` protobuf, wraps it in a -``ClientFrame.paired_batch``, and sends it over a long-lived WebSocket -connection. Verdicts received back resolve block-mode futures and fire the -callback handler's verdict processing. - -Implements the ``EventHandler`` protocol so it slots into the SDK's hook -registry alongside ``JSONLHandler``. -""" - -from __future__ import annotations - -import asyncio -import contextlib -import json -import logging -import time -from collections import OrderedDict, deque -from collections.abc import Awaitable, Callable -from typing import TYPE_CHECKING, Any - -import websockets - -if TYPE_CHECKING: - from adrian.config import OnDisconnectCallback, OnReconnectCallback - from adrian.handler import AdrianCallbackHandler - -from adrian.format.types import ( - AgentContext, - LlmPairData, - PairedEvent, - ParentContext, -) -from adrian.proto import event_pb2 as pb - -logger = logging.getLogger("adrian.ws") - -SCHEMA_VERSION = 2 - -_INITIAL_BACKOFF = 1.0 -_MAX_BACKOFF = 30.0 -# Server close code: quota exhausted. Spec'd in -# server/internal/websocket/handler.go (closeQuotaExceeded). Returning -# every 30s would hammer the server while quota is depleted; one -# minute is slow enough to be cheap, fast enough that the next hourly -# / daily / monthly window-rollover is picked up within tolerance. -_QUOTA_EXHAUSTED_CLOSE_CODE = 4003 -_QUOTA_RECONNECT_DELAY = 60.0 -# Cap on in-flight LLM run_id → event_id mappings. Evicted LRU-style; -# block-mode lookups for evicted entries fail open. -_MAX_RUN_ID_MAP = 1024 -# Cap on in-flight tool_call_id → event_id mappings (block-mode correlation). -_MAX_TOOL_CALL_MAP = 1024 -# Cap on resolved verdict futures kept for late-waiter replay. -_MAX_PENDING_VERDICTS = 512 - -_DEFAULT_REPLAY_BUFFER_FRAMES = 1000 - -# Heartbeat tuning. 10s interval / 15s pong timeout detects half-open -# connections (ALB idle cut, NAT drop, dead remote process) without -# flooding the wire. Kept in sync with the backend's pingInterval / -# pongTimeout, if these change, update server/internal/websocket/handler.go. -_PING_INTERVAL = 10.0 -_PING_TIMEOUT = 15.0 - -_PROVIDER_PREFIXES: dict[str, str] = { - "chatanthropic": "anthropic", - "chatopenai": "openai", - "chatgooglegenai": "google", - "chatcohere": "cohere", - "chatmistralai": "mistral", -} - -_PAIR_TYPE_MAP: dict[str, pb.PairType.ValueType] = { - "llm": pb.PAIR_TYPE_LLM, - "tool": pb.PAIR_TYPE_TOOL, -} - - -def _derive_provider(model_class_name: str) -> str: - """Derive the LLM provider from the model class name. - - Args: - model_class_name: Class name like ``"ChatAnthropic"`` or ``"ChatOpenAI"``. - - Returns: - Provider string (e.g. ``"anthropic"``), or the class name lower-cased - if no known prefix matches. - """ - key = model_class_name.lower() - - return _PROVIDER_PREFIXES.get(key, key) - - -def _fill_agent_context( - pb_ctx: pb.AgentContext, src: AgentContext | ParentContext -) -> None: - """Copy an AgentContext / ParentContext dataclass into its proto counterpart.""" - pb_ctx.agent_id = src.agent_id - pb_ctx.system_prompt = src.system_prompt - pb_ctx.user_instruction = src.user_instruction - - -def _safe_cancel( - task_or_future: asyncio.Task[Any] | asyncio.Future[Any] | None, -) -> None: - """Cancel a task / future, ignoring closed-loop errors at shutdown. - - Adrian's ``atexit`` handler may run after the user's loop has been - closed; in that path ``adrian.shutdown`` spawns a new ``asyncio.run`` - and walks each handler's ``close()``. Tasks bound to the *old* loop - can no longer be cancelled (``call_soon`` raises ``Event loop is - closed``). Swallowing the error here keeps the cleanup path quiet, - the task will be reaped when the dead loop is GC'd. - """ - if task_or_future is None or task_or_future.done(): - return - # "Event loop is closed", old loop is gone, nothing to cancel. - with contextlib.suppress(RuntimeError): - task_or_future.cancel() - - -def _paired_event_to_proto(event: PairedEvent) -> pb.PairedEvent: - """Convert a ``PairedEvent`` dataclass into its protobuf form. - - ``parent.agent_id`` empty-string signals "no parent agent". - ``parent_run_id`` empty-string signals "no parent in run tree". - """ - proto = pb.PairedEvent( - event_id=event.event_id, - invocation_id=event.invocation_id, - session_id=event.session_id, - run_id=event.run_id, - parent_run_id=event.parent_run_id, - timestamp=event.timestamp, - pair_type=_PAIR_TYPE_MAP.get(event.pair_type, pb.PAIR_TYPE_UNSPECIFIED), - ) - - _fill_agent_context(proto.agent, event.agent) - - if event.parent is not None: - _fill_agent_context(proto.parent, event.parent) - - if isinstance(event.data, LlmPairData): - proto.llm.model = event.data.model - - for msg in event.data.messages: - pb_msg = proto.llm.messages.add() - pb_msg.role = msg["role"] - pb_msg.content = msg["content"] - - proto.llm.output = event.data.output - - for tc in event.data.tool_calls: - pb_tc = proto.llm.tool_calls.add() - pb_tc.name = tc["name"] - pb_tc.args = json.dumps(tc["args"], default=str) - pb_tc.id = tc["id"] - - if event.data.usage is not None: - proto.llm.usage.prompt_tokens = event.data.usage["prompt_tokens"] - proto.llm.usage.completion_tokens = event.data.usage["completion_tokens"] - proto.llm.usage.total_tokens = event.data.usage["total_tokens"] - else: - # Union is LlmPairData | ToolPairData; this branch is the - # ToolPairData case. - proto.tool.tool_name = event.data.tool_name - proto.tool.tool_call_id = event.data.tool_call_id or "" - proto.tool.input = event.data.input - proto.tool.output = event.data.output - - if event.metadata: - proto.metadata_json = json.dumps(event.metadata, default=str).encode() - - return proto - - -class WebSocketClient: - """Streams ``PairedEvent`` instances to the worker core API. - - Connects eagerly via :meth:`schedule_connect` with exponential backoff, - auto-detects the LLM provider on the first LLM pair, sends paired events - as protobuf frames, and resolves block-mode futures when verdicts arrive. - """ - - def __init__( - self, - url: str, - session_id: str, - api_key: str, - handler: AdrianCallbackHandler | None = None, - on_disconnect: OnDisconnectCallback | None = None, - on_reconnect: OnReconnectCallback | None = None, - on_login_ack: Callable[[], Awaitable[None]] | None = None, - replay_buffer_frames: int = _DEFAULT_REPLAY_BUFFER_FRAMES, - ) -> None: - """Initialise without connecting. - - Args: - url: WebSocket endpoint URL. - session_id: Session ID sent in the login frame. - api_key: Adrian API key for the ``Authorization`` header. - handler: Callback handler for verdict processing. - on_disconnect: Fired when the connection is lost (sync or async). - Receives a reason string. - on_reconnect: Fired when the connection re-establishes after a - prior disconnect (sync or async). Does not fire on initial - connect. - on_login_ack: Async hook fired after each ``LoginAck`` frame is - applied, once per (re)connect. Used internally to push a - fresh ``McpInventory`` on every login. Exceptions are - logged and swallowed. - replay_buffer_frames: Ring-buffer capacity (frame count, not - bytes). When the cap is reached each further append evicts - the oldest frame; a one-shot WARN fires on first fill, and - the cumulative drop count is logged at WARN on the next - reconnect. - """ - self._url = url - self._session_id = session_id - self._api_key = api_key - self._handler = handler - self._on_disconnect = on_disconnect - self._on_reconnect = on_reconnect - self._on_login_ack_cb = on_login_ack - self._provider = "" - self._model = "" - # Server-supplied execution-mode policy. Populated when the - # first ServerFrame{login_ack} arrives after each (re)connect. - # ``policy_active()`` and ``block_timeout()`` read this state - # to decide whether the patched ToolNode should wait for a - # verdict and how long. - self._mode: int = pb.MODE_UNSPECIFIED - self._policy: pb.PolicySnapshot | None = None - # Set the first time a ``ServerFrame{login_ack}`` is applied. - # Used in two places: - # 1. ``on_paired_event`` defensively pre-registers a - # verdict-wait future when this event is unset, so the - # very first tool-bearing LLM emission is covered even - # though the recv loop hasn't yet processed LoginAck and - # ``policy_active()`` reads False. - # 2. The patched ``ToolNode.ainvoke`` ``await``s this event - # (with a short timeout) before deciding whether to wait - # for a verdict, so the first ToolNode invocation cannot - # run-through-without-waiting in the same window. - # Stays set across disconnect/reconnect because mode/policy - # state survives, a fresh LoginAck on reconnect simply re-sets - # an already-set event. - self._login_ack_received: asyncio.Event = asyncio.Event() - self._ws: websockets.ClientConnection | None = None - self._logged_in = False - self._connected = asyncio.Event() - self._connect_task: asyncio.Task[None] | None = None - self._recv_task: asyncio.Task[None] | None = None - # Set by close() so _handle_disconnect knows not to spawn a reconnect - # during a graceful shutdown. - self._closing = False - # Event loop running the WebSocket tasks. Captured on first - # connect so _sync_gate can bridge async waits from worker - # threads via run_coroutine_threadsafe. - self._loop: asyncio.AbstractEventLoop | None = None - # Futures awaited by the patched ToolNode.ainvoke when the - # active mode requires a wait (BLOCK or HITL). Each resolves - # with the matching ``Verdict`` proto. Futures survive a - # disconnect: a late verdict after reconnect still resolves - # the wait; if none arrives, ``wait_for_verdict``'s timeout - # produces a natural fail-open in BLOCK mode. - self._pending_verdicts: dict[str, asyncio.Future[pb.Verdict]] = {} - # Maps LLM pair run_id → event_id so a subsequent tool call can - # look up the verdict by its parent_run_id (the LLM's run_id). - # LRU-capped at _MAX_RUN_ID_MAP to bound memory on long sessions. - self._run_id_to_event_id: OrderedDict[str, str] = OrderedDict() - # Verdict-correlation map: maps each tool_call.id emitted by - # an LLM to the event_id of the LLM pair that emitted it. - # Populated on every LLM PairedEvent that has tool_calls. - # Consulted by the patched ``ToolNode.ainvoke`` so each tool - # in a parallel fan-out waits on its own producing LLM's - # verdict, not a global "last" pointer. LRU-capped at - # ``_MAX_TOOL_CALL_MAP``. - self._tool_call_id_to_event_id: OrderedDict[str, str] = OrderedDict() - # Serialises the lazy login-then-send sequence so two concurrent - # on_paired_event calls (parallel agents) cannot both send a login. - # Reused by _replay_buffer_to_ws to coordinate with live sends. - self._login_lock = asyncio.Lock() - # Ring buffer of recently serialised ClientFrame bytes. Appended - # only from the offline-or-send-failure paths in _send_frame; the - # happy path bypasses the ring entirely. Drained on reconnect. - self._replay_buffer: deque[bytes] = deque(maxlen=replay_buffer_frames) - # Flips True on the first append that reaches maxlen. Gates the - # one-shot "buffer full" WARN so we don't flood logs. - self._replay_buffer_filled: bool = False - # Monotonic counter of frames dropped due to buffer overflow - # (oldest evicted when a new append arrives at a full ring). - # Logged at WARN on the next reconnect. - self._replay_buffer_dropped: int = 0 - # True while the reconnect path is draining the replay buffer. - # Live sends observed during this window are routed back into - # the same deque so they slot in AFTER the pre-outage tail - # rather than racing onto the wire ahead of older buffered - # frames. Flipped on as the first sync line of - # _replay_buffer_to_ws and cleared in its finally. - self._replaying: bool = False - # Set by _handle_disconnect, cleared on successful reconnect. - # Used to gate on_reconnect and measure downtime. - self._disconnected_at: float | None = None - # One-shot delay applied before the next ``connect()`` attempt. - # Set when the server closes with a code that requests a longer - # wait (currently only 4003 quota exhausted); cleared by - # ``connect()`` after honouring it. ``None`` means use the - # standard exponential schedule. - self._next_reconnect_delay: float | None = None - - # -- Mode / policy state (populated by LoginAck) -- - - def policy_active(self) -> bool: - """Whether the active server mode requires waiting on verdicts. - - Single predicate consulted by the patched ``ToolNode.ainvoke``. - Returns ``True`` for ``MODE_BLOCK`` and ``MODE_HITL``; ``False`` - for ``MODE_ALERT`` and unset (pre-login) state. - """ - return self._mode in (pb.MODE_BLOCK, pb.MODE_HITL) - - def block_timeout(self, kwarg_default: float) -> float | None: - """Effective per-tool-call wait timeout for the active mode. - - - ``MODE_BLOCK``: ``kwarg_default`` (typically 30s), fail-open - if the server doesn't classify in time. - - ``MODE_HITL``: ``None``, wait indefinitely for human review. - - ``MODE_ALERT`` / unset: ``0``, caller short-circuits before - registering a future. - """ - if self._mode == pb.MODE_BLOCK: - return kwarg_default - elif self._mode == pb.MODE_HITL: - return None - else: - return 0 - - # -- EventHandler protocol -- - - async def on_paired_event(self, event: PairedEvent) -> None: - """Send a paired event over the WebSocket. - - Auto-detects the LLM provider on the first LLM pair, updates the - run_id → event_id map for block mode, converts the dataclass to - protobuf, and sends a ``ClientFrame.paired_batch`` frame. - - For LLM pairs that carry tool_calls, registers the verdict-wait - future *before* the frame leaves the SDK. This closes the race - where a fast verdict roundtrip resolves and is dropped before - the patched ``ToolNode.ainvoke`` reaches its own - ``register_pending`` call. The matching ``register_pending`` - from the wait site is a get-or-create that returns the existing - future. - - Args: - event: The paired event to stream. - """ - if ( - event.pair_type == "llm" - and not self._provider - and isinstance(event.data, LlmPairData) - ): - self._model = event.data.model - self._provider = _derive_provider(event.data.model) - - if event.pair_type == "llm": - self._run_id_to_event_id[event.run_id] = event.event_id - self._run_id_to_event_id.move_to_end(event.run_id) - - if len(self._run_id_to_event_id) > _MAX_RUN_ID_MAP: - self._run_id_to_event_id.popitem(last=False) - - # Populate tool_call.id → event_id so each tool call can block - # on its own producing LLM's verdict under parallel fan-out. - if isinstance(event.data, LlmPairData) and event.data.tool_calls: - for tc in event.data.tool_calls: - tc_id = tc.get("id") or "" - - if not tc_id: - continue - - self._tool_call_id_to_event_id[tc_id] = event.event_id - self._tool_call_id_to_event_id.move_to_end(tc_id) - - if len(self._tool_call_id_to_event_id) > _MAX_TOOL_CALL_MAP: - self._tool_call_id_to_event_id.popitem(last=False) - - # Pre-register the wait future so an eager verdict - # cannot race ahead of the ToolNode patch. Gated on - # ``policy_active()`` so ALERT-mode sessions don't - # accumulate futures that will never be resolved or - # awaited, except for the very first event of the - # session, where ``LoginAck`` may not yet have been - # processed by the recv loop and ``policy_active()`` - # therefore reads False even when the mode will - # imminently be set to BLOCK or HITL. Pre-register - # defensively in that window; in ALERT mode the gate - # filters out every subsequent event so the leak is - # bounded to one orphan future per session. - if self.policy_active() or not self._login_ack_received.is_set(): - self.register_pending(event.event_id) - - proto = _paired_event_to_proto(event) - frame = pb.ClientFrame() - added = frame.paired_batch.events.add() - added.CopyFrom(proto) - - await self._send_frame(frame) - - async def close(self) -> None: - """Cancel background tasks and close the WebSocket. - - Sets ``_closing`` so any in-flight ``_handle_disconnect`` does not - spawn a reconnect during graceful shutdown. - - Defensive against the ``atexit`` shutdown path: ``adrian.shutdown`` - spawns a fresh ``asyncio.run`` loop after the user's loop has - already closed, so background tasks bound to the old loop can no - longer be cancelled cleanly (``call_soon`` raises - ``Event loop is closed``). Skip the cancel in that case, the - old loop is gone, the task will be reaped by GC. - """ - self._closing = True - - _safe_cancel(self._recv_task) - self._recv_task = None - _safe_cancel(self._connect_task) - self._connect_task = None - - if self._ws is not None: - with contextlib.suppress(Exception): - await asyncio.wait_for(self._ws.close(), timeout=2.0) - self._ws = None - - for fut in self._pending_verdicts.values(): - if not fut.done(): - _safe_cancel(fut) - self._pending_verdicts.clear() - - # -- Connection lifecycle -- - - def schedule_connect(self, loop: asyncio.AbstractEventLoop) -> None: - """Schedule :meth:`connect` as a background task on the given loop.""" - if self._connect_task is None or self._connect_task.done(): - self._connect_task = loop.create_task(self.connect()) - - async def connect(self) -> None: - """Establish the WebSocket with exponential-backoff retry. - - Heartbeat (``ping_interval`` / ``ping_timeout``) is configured on - the underlying ``websockets`` client; if the server fails to pong - within ``_PING_TIMEOUT`` the library closes the connection and - ``_recv_loop`` surfaces the disconnect via ``_handle_disconnect``. - - On a reconnect (``_disconnected_at`` set by a prior disconnect), - drains the replay buffer and fires ``on_reconnect``. Login is - deferred to ``_send_frame`` / ``_replay_buffer_to_ws`` so the - auto-detected provider/model is included. An ``api_key``, if - configured, is sent as an ``Authorization: Bearer `` header. - - Honours ``_next_reconnect_delay`` if a previous disconnect set - it (e.g. 4003 quota exhausted requests a slower retry). The - delay is consumed on the first attempt; subsequent failures - fall back to the standard exponential schedule. - """ - initial_delay = self._next_reconnect_delay - self._next_reconnect_delay = None - - if initial_delay is not None: - logger.info( - "delaying reconnect by %.0fs (server-requested)", - initial_delay, - ) - await asyncio.sleep(initial_delay) - - backoff = _INITIAL_BACKOFF - loop = asyncio.get_running_loop() - self._loop = loop - - headers: dict[str, str] = {} - - if self._api_key: - headers["Authorization"] = f"Bearer {self._api_key}" - - while True: - try: - self._ws = await websockets.connect( - self._url, - additional_headers=headers, - ping_interval=_PING_INTERVAL, - ping_timeout=_PING_TIMEOUT, - ) - self._connected.set() - self._recv_task = loop.create_task(self._recv_loop()) - - disconnected_at = self._disconnected_at - is_reconnect = disconnected_at is not None - if disconnected_at is not None: - downtime = time.monotonic() - disconnected_at - self._disconnected_at = None - logger.warning( - "WebSocket reconnected: %s (session_id=%s, downtime=%.2fs)", - self._url, - self._session_id, - downtime, - ) - - if self._replay_buffer_dropped > 0: - logger.warning( - "replay buffer dropped %d frames due to overflow " - "before this reconnect (session_id=%s); " - "increase replay_buffer_frames if this recurs", - self._replay_buffer_dropped, - self._session_id, - ) - else: - logger.info("WebSocket connected: %s", self._url) - - # Drain anything buffered while we were offline, even - # on the very first connect. ``_send_mcp_inventory`` - # and other init-time emitters queue frames before the - # WS is open; without this drain those frames never - # ship until something else triggers a live send. - if self._replay_buffer: - logger.info( - "replaying %d buffered frames after connect", - len(self._replay_buffer), - ) - await self._replay_buffer_to_ws() - - if is_reconnect: - await self._fire_on_reconnect() - - return - except Exception: - logger.warning( - "WebSocket connect to %s failed, retrying in %.0fs", - self._url, - backoff, - ) - try: - await asyncio.sleep(backoff) - except RuntimeError: - # Loop closed mid-retry (atexit shutdown). Bail out - # quietly rather than dumping a traceback. - return - backoff = min(backoff * 2, _MAX_BACKOFF) - - async def _send_login(self, ws: websockets.ClientConnection) -> None: - """Send the mandatory SessionLogin frame.""" - frame = pb.ClientFrame() - frame.login.session_id = self._session_id - frame.login.llm_stack.provider = self._provider - frame.login.llm_stack.model = self._model - frame.login.schema_version = SCHEMA_VERSION - await ws.send(frame.SerializeToString()) - logger.debug( - "Sent login (session=%s, provider=%s, model=%s, schema=%d)", - self._session_id, - self._provider, - self._model, - SCHEMA_VERSION, - ) - - async def _send_frame(self, frame: pb.ClientFrame) -> None: - """Serialise and send a ``ClientFrame``, buffering on failure. - - Happy path (connected + healthy): send over WS, bypass the ring - entirely, zero overhead. Offline on entry: buffer for replay. - During reconnect replay: buffer as well, so the drain loop picks - this frame up after the pre-outage tail (preserves order across - the outage boundary). Send raises: buffer the in-flight frame - then trigger ``_handle_disconnect`` so state is cleared and - reconnect is spawned. - """ - frame_bytes = frame.SerializeToString() - kind = frame.WhichOneof("frame") - - if not self._connected.is_set() or self._replaying: - self._buffer_frame(frame_bytes) - reason = "disconnected" if not self._connected.is_set() else "replaying" - logger.info( - "buffered for replay (session_id=%s, kind=%s, " - "buffer_size=%d, reason=%s)", - self._session_id, - kind, - len(self._replay_buffer), - reason, - ) - - return - - ws = self._ws - - if ws is None: - self._buffer_frame(frame_bytes) - - return - - try: - async with self._login_lock: - if not self._logged_in: - await self._send_login(ws) - self._logged_in = True - - await ws.send(frame_bytes) - logger.debug("Sent %s frame", kind) - except Exception: - # Send raised, we cannot confirm the server received this frame. - # Buffer it so the reconnect replay ships it, then clean up state. - self._buffer_frame(frame_bytes) - await self._handle_disconnect("send_failure") - - async def _recv_loop(self) -> None: - """Read ``ServerFrame``s, dispatch by oneof kind. - - First frame after each (re)login MUST be ``login_ack``; anything - else is a protocol error and we tear the connection down so the - reconnect path can try again. Subsequent frames are - ``verdict``s. Unknown oneof kinds (future server additions like - a quota-exhausted signal) are logged and dropped rather than - crashing the loop. - - Any exit path (clean close, exception, cancellation) calls - ``_handle_disconnect`` via ``finally`` so state is cleared and a - reconnect is spawned. - """ - ws = self._ws - - if ws is None: - return - - awaiting_login_ack = True - try: - async for message in ws: - if not isinstance(message, bytes): - continue - - frame = pb.ServerFrame() - frame.ParseFromString(message) - kind = frame.WhichOneof("frame") - - if awaiting_login_ack: - awaiting_login_ack = False - if kind != "login_ack": - logger.error( - "expected ServerFrame{login_ack} as first frame, " - "got %r, closing connection", - kind, - ) - return - - if kind == "login_ack": - self._on_login_ack(frame.login_ack) - elif kind == "verdict": - await self._on_verdict_frame(frame.verdict) - else: - logger.warning( - "ignoring unknown ServerFrame kind %r " - "(future server addition?)", - kind, - ) - except asyncio.CancelledError: - # Expected on graceful shutdown or when _handle_disconnect cancels - # us from the send_failure path. Re-raise to honour cancellation. - raise - except Exception as exc: - logger.warning("recv_loop exited: %s", exc) - finally: - close_code = getattr(ws, "close_code", None) - - if close_code == _QUOTA_EXHAUSTED_CLOSE_CODE: - self._next_reconnect_delay = _QUOTA_RECONNECT_DELAY - - reason = ( - f"quota_exhausted (close={close_code})" - if close_code == _QUOTA_EXHAUSTED_CLOSE_CODE - else "recv_loop_exit" - ) - await self._handle_disconnect(reason) - - def _on_login_ack(self, ack: pb.LoginAck) -> None: - """Apply the org's effective execution-mode policy. - - Fires the ``on_login_ack`` hook (if configured) as a fire-and-forget - task on the running loop so the recv loop doesn't block waiting on it. - """ - self._mode = ack.policy.mode - self._policy = ack.policy - self._login_ack_received.set() - logger.info( - "LoginAck received: mode=%s policy_m0=%s policy_m2=%s " - "policy_m3=%s policy_m4=%s", - pb.Mode.Name(ack.policy.mode), - ack.policy.policy_m0, - ack.policy.policy_m2, - ack.policy.policy_m3, - ack.policy.policy_m4, - ) - - if self._on_login_ack_cb is not None: - asyncio.create_task(self._run_login_ack_cb()) - - async def _run_login_ack_cb(self) -> None: - """Invoke the on_login_ack hook, swallowing exceptions.""" - if self._on_login_ack_cb is None: - return - try: - await self._on_login_ack_cb() - except Exception: - logger.exception("on_login_ack hook raised") - - async def _on_verdict_frame(self, verdict: pb.Verdict) -> None: - """Fire callbacks then resolve the matching pending future, if any. - - The future is left in ``_pending_verdicts`` after ``set_result`` so - a later ``register_pending`` (e.g. from the patched ToolNode after - the verdict has already round-tripped) returns the resolved - future and the wait completes immediately. ``wait_for_verdict`` - owns the cleanup: its ``finally`` pops the entry after the await - returns. - """ - logger.info( - "Verdict received: event_id=%s mad_code=%s mode=%s hitl=%s", - verdict.event_id, - verdict.mad_code or "-", - pb.Mode.Name(verdict.policy.mode), - verdict.HasField("hitl"), - ) - - if self._handler is not None: - await self._handler.handle_verdict(verdict) - - fut = self._pending_verdicts.get(verdict.event_id) - - if fut is None: - if verdict.HasField("hitl"): - logger.warning( - "HITL resolution for unknown event_id=%s, ignoring " - "(stale resolution from a prior SDK process)", - verdict.event_id, - ) - return - - if not fut.done(): - fut.set_result(verdict) - - # -- Resilience: buffering, replay, disconnect/reconnect -- - - def _buffer_frame(self, frame_bytes: bytes) -> None: - """Append a serialised frame to the replay ring. - - Tracks overflow drops and fires the one-shot "buffer full" WARN. - Called only from the offline or send-failure paths in - ``_send_frame``, the happy path bypasses the ring entirely. - """ - if len(self._replay_buffer) == self._replay_buffer.maxlen: - self._replay_buffer_dropped += 1 - - self._replay_buffer.append(frame_bytes) - - if ( - not self._replay_buffer_filled - and len(self._replay_buffer) == self._replay_buffer.maxlen - ): - self._replay_buffer_filled = True - logger.warning( - "adrian replay buffer reached capacity (%d frames); " - "further frames will evict oldest. Tune via " - "replay_buffer_frames or ADRIAN_REPLAY_BUFFER_FRAMES.", - self._replay_buffer.maxlen, - ) - - async def _replay_buffer_to_ws(self) -> None: - """Reissue buffered frames over the current WebSocket. - - Sends ``SessionLogin`` first if not already logged in (the server - requires it as the first frame on every new connection). Uses - ``_login_lock`` so a concurrent live send does not race on the - login check. - - Drains the deque one frame at a time via ``popleft`` inside a - ``while`` loop, rather than taking a snapshot up front. That - way, a live ``_send_frame`` call observed during the drain - routes its frame to the back of the same deque (because - ``_replaying`` is set) and this loop picks it up in the next - iteration, preserving across-outage order - ``[pre-outage] → [live during replay] → [post-replay live]``. - - On a mid-drain send failure, the failed frame is put back at - the front with ``appendleft`` and the function returns; the - next reconnect resumes from exactly where this one stopped. - """ - ws = self._ws - - if ws is None: - return - - self._replaying = True - try: - async with self._login_lock: - if not self._logged_in: - try: - await self._send_login(ws) - self._logged_in = True - except Exception as exc: - logger.warning( - "replay aborted: login send failed: %s", - exc, - ) - - return - - sent = 0 - while self._replay_buffer: - frame_bytes = self._replay_buffer.popleft() - try: - await ws.send(frame_bytes) - except Exception as exc: - # Put the failed frame back at the front so the next - # reconnect's drain resumes from exactly this point. - self._replay_buffer.appendleft(frame_bytes) - logger.warning( - "replay aborted after %d frame(s), %d remaining: %s", - sent, - len(self._replay_buffer), - exc, - ) - - return - sent += 1 - - logger.info("replayed %d buffered frames", sent) - self._replay_buffer_dropped = 0 - self._replay_buffer_filled = False - finally: - self._replaying = False - - async def _handle_disconnect(self, reason: str) -> None: - """Clear connection state and spawn a reconnect. - - Idempotent: if already disconnected or closing, returns immediately. - Pending verdict futures are intentionally left pending across the - disconnect, a late verdict after reconnect resolves them; if none - arrives, ``wait_for_verdict``'s timeout fires naturally. - """ - if self._closing or not self._connected.is_set(): - return - - self._connected.clear() - self._disconnected_at = time.monotonic() - - # Only cancel the recv task if we are not currently running inside it. - # When _recv_loop's own finally calls us, self._recv_task IS the - # current task, cancelling it would raise CancelledError inside the - # finally and prevent us from finishing disconnect handling. - current = asyncio.current_task() - - if self._recv_task is not None and self._recv_task is not current: - self._recv_task.cancel() - - self._recv_task = None - self._ws = None - self._logged_in = False - - logger.warning( - "disconnected (session_id=%s, reason=%s, pending_verdicts=%d)", - self._session_id, - reason, - len(self._pending_verdicts), - ) - - await self._fire_on_disconnect(reason) - - if self._closing: - return - - loop = asyncio.get_running_loop() - - if self._connect_task is None or self._connect_task.done(): - self._connect_task = loop.create_task(self.connect()) - - async def _fire_on_disconnect(self, reason: str) -> None: - """Invoke the on_disconnect callback, catching any exception.""" - if self._on_disconnect is None: - return - - try: - result = self._on_disconnect(reason) - - if asyncio.iscoroutine(result): - await result - except Exception: - logger.exception("on_disconnect callback raised") - - async def _fire_on_reconnect(self) -> None: - """Invoke the on_reconnect callback, catching any exception.""" - if self._on_reconnect is None: - return - - try: - result = self._on_reconnect() - - if asyncio.iscoroutine(result): - await result - except Exception: - logger.exception("on_reconnect callback raised") - - # -- Verdict-wait support -- - - def register_pending( - self, - event_id: str, - ) -> asyncio.Future[pb.Verdict]: - """Return a future awaiting a verdict for ``event_id``. - - Reuses an existing pending future if one is already registered, - so concurrent callers waiting on the same event_id see the same - verdict once it arrives. Must be called BEFORE sending the event - to avoid the race where the verdict arrives before the future exists. - """ - existing = self._pending_verdicts.get(event_id) - - if existing is not None: - return existing - - loop = asyncio.get_running_loop() - fut: asyncio.Future[pb.Verdict] = loop.create_future() - self._pending_verdicts[event_id] = fut - - return fut - - def _evict_resolved_verdicts(self) -> None: - """Remove oldest resolved futures when the dict exceeds the cap.""" - while len(self._pending_verdicts) > _MAX_PENDING_VERDICTS: - # Evict the oldest entry (dict preserves insertion order). - oldest_id = next(iter(self._pending_verdicts)) - oldest_fut = self._pending_verdicts[oldest_id] - if oldest_fut.done(): - del self._pending_verdicts[oldest_id] - else: - # Don't evict an in-flight future; stop evicting. - break - - async def wait_for_verdict( - self, - event_id: str, - timeout: float | None, - ) -> pb.Verdict | None: - """Wait for a verdict for ``event_id``. - - ``timeout`` is mode-derived (see :meth:`block_timeout`): - a positive float for ``MODE_BLOCK`` (fail-open at timeout), - ``None`` for ``MODE_HITL`` (wait indefinitely). Returns the - verdict, or ``None`` on timeout (fail-open). - - Resolved futures are kept in ``_pending_verdicts`` so a second - waiter on the same event_id (e.g. BaseTool.ainvoke firing after - ToolNode.ainvoke already consumed the verdict) finds the already- - resolved future and returns instantly instead of timing out. - Timed-out (unconsumed) futures are removed immediately; resolved - futures are evicted when the dict exceeds ``_MAX_PENDING_VERDICTS``. - """ - fut = self.register_pending(event_id) - - try: - result = await asyncio.wait_for(fut, timeout=timeout) - # Keep resolved future in dict for late waiters; cap size. - self._evict_resolved_verdicts() - return result - except TimeoutError: - logger.warning( - "Verdict timeout for event_id=%s after %ss", - event_id, - timeout, - ) - # Timed-out future is useless — remove so a retry can - # register a fresh one. - self._pending_verdicts.pop(event_id, None) - return None - - async def wait_for_tool_verdict( - self, - parent_run_id: str, - timeout: float | None, - ) -> pb.Verdict | None: - """Wait for the verdict of the LLM pair that produced this tool call. - - Looks up the LLM event_id from the run_id map and awaits its verdict. - Returns ``None`` (fail-open) when the parent LLM has not been seen, - e.g. tools invoked outside an LLM flow. - """ - event_id = self._run_id_to_event_id.get(parent_run_id) - - if event_id is None: - logger.debug( - "No LLM context for parent_run_id=%s, skipping verdict wait", - parent_run_id, - ) - - return None - - return await self.wait_for_verdict(event_id, timeout) - - async def wait_for_tool_call_verdict( - self, - tool_call_id: str, - timeout: float | None, - ) -> pb.Verdict | None: - """Wait for the verdict of the LLM pair that emitted ``tool_call_id``. - - Every tool call in an AIMessage carries the id the LLM assigned - to it; that id is threaded through LangChain to the ToolNode - invocation. Looking it up against ``_tool_call_id_to_event_id`` - gives the producing LLM's event_id, correct under parallel - agents where a ``last_llm_event_id``-style global would race. - - Returns ``None`` (fail-open) when ``tool_call_id`` is empty or - unknown (direct ToolNode invocation, pre-LLM tool, or the LLM - pair that produced it was evicted from the LRU map). - """ - if not tool_call_id: - return None - - event_id = self._tool_call_id_to_event_id.get(tool_call_id) - - if event_id is None: - logger.debug( - "No LLM context for tool_call_id=%s, skipping verdict wait", - tool_call_id, - ) - - return None - - return await self.wait_for_verdict(event_id, timeout) diff --git a/sdk/python/adrian/__init__.py b/sdk/python/adrian/__init__.py index d0d6d81..c8ce231 100644 --- a/sdk/python/adrian/__init__.py +++ b/sdk/python/adrian/__init__.py @@ -863,7 +863,7 @@ def _patch_tool_node() -> None: ToolNode dispatches tools via tool.invoke (sync) even within async Pregel. BaseTool.invoke can't await a verdict from the event loop - thread, so we add the verdict gate here on ToolNode.ainvoke — the + thread, so we add the verdict gate here on ToolNode.ainvoke - the entry point Pregel calls before tool dispatch begins. This is a complementary gate to BaseTool (which covers direct callers). """ @@ -895,7 +895,7 @@ async def patched_ainvoke( **kwargs: Any, # noqa: A002, ANN401 ) -> Any: # noqa: ANN401 config = _inject_callbacks(config) - # Verdict gate removed — BaseTool.ainvoke/arun is the single + # Verdict gate removed - BaseTool.ainvoke/arun is the single # gate layer. Gating here too caused double-gate: ToolNode # consumed the verdict future, BaseTool's gate registered a # fresh future that never resolved → 30s timeout on a benign @@ -930,8 +930,8 @@ async def patched_astream( def _patch_base_tool() -> None: """Patch ``BaseTool.invoke`` and ``BaseTool.ainvoke`` with the verdict gate. - Every LangChain tool — whether dispatched by ToolNode, AgentExecutor, - create_react_agent, or a manual ``tool.invoke(tool_call)`` loop — + Every LangChain tool - whether dispatched by ToolNode, AgentExecutor, + create_react_agent, or a manual ``tool.invoke(tool_call)`` loop - funnels through ``BaseTool.invoke`` (sync) or ``BaseTool.ainvoke`` (async). Gating here covers all frameworks in one place. @@ -1007,56 +1007,64 @@ async def _async_gate(tool_call_id: str) -> bool: return False def _sync_gate(tool_call_id: str) -> bool: - """Sync verdict gate — works for pure-sync and worker-thread callers. - - Pure-sync (no event loop): runs ``_async_gate`` via - ``loop.run_until_complete``. - - Worker-thread (Pregel dispatches sync tools on a thread-pool - worker while the event loop runs on the main thread): bridges - the async gate to the main loop via ``run_coroutine_threadsafe`` - and blocks the worker thread until the verdict resolves. - - Event-loop thread (calling tool.invoke directly from async - code): cannot block — returns False (skip). The async path - (BaseTool.ainvoke) handles this case. + """Sync verdict gate - works for pure-sync and worker-thread callers. + + Worker-thread (the common LangGraph case: ``StructuredTool.ainvoke`` + dispatches a *sync* tool via ``run_in_executor(self.invoke)``, so the + gate runs on a thread-pool worker while the WS event loop runs on + another thread): bridges the async gate onto the WS loop via + ``run_coroutine_threadsafe`` and blocks the worker until the verdict + resolves. + + Pure-sync (no event loop anywhere): runs ``_async_gate`` to + completion on this thread. + + Event-loop thread (calling ``tool.invoke`` directly from async + code): cannot block without deadlocking - returns False (skip). + The async path (``BaseTool.ainvoke``) handles this case. + + Thread detection uses ``asyncio.get_running_loop()`` rather than + ``get_event_loop()``: the latter raises ``RuntimeError`` on a worker + thread (no loop *set* there, since Python 3.10+), which would + misclassify the worker-thread case as "no loop" and skip the gate - + leaving sync tools ungated under ``create_react_agent``. """ ws = _ws_client if ws is None or not ws._login_ack_received.is_set() or not ws.policy_active(): # pyright: ignore[reportPrivateUsage] return False - try: - loop = asyncio.get_event_loop() - except RuntimeError: - return False - - if not loop.is_running(): - # Pure-sync caller — safe to block - return loop.run_until_complete(_async_gate(tool_call_id)) - - # Check if we're on a worker thread (no running loop on THIS - # thread) vs the event-loop thread itself. + # Is THIS thread running an event loop? try: asyncio.get_running_loop() - # We ARE on the event-loop thread — can't block it. - return False except RuntimeError: - pass + pass # no loop on this thread: worker thread or pure-sync caller + else: + # On the event-loop thread - can't block it. The async gate + # (BaseTool.ainvoke) covers direct-from-async callers. + return False - # Worker thread: bridge the async gate to the main loop. + # Worker thread: the WS loop runs elsewhere - bridge onto it and + # block this worker until the verdict resolves. ``_async_gate`` owns + # the wait policy (bounded with fail-closed in MODE_BLOCK, indefinite + # in MODE_HITL where execution must pause until a human acts), so we + # wait on the future with no timeout of our own - a finite timeout + # here would fail-open a HITL hold once it elapsed. Fail closed (treat + # as halt) if the bridge itself raises. main_loop = getattr(ws, "_loop", None) - if main_loop is None or not main_loop.is_running(): - return False + if main_loop is not None and main_loop.is_running(): + try: + future = asyncio.run_coroutine_threadsafe( + _async_gate(tool_call_id), main_loop + ) + return future.result() + except Exception: + return True + # Pure-sync caller, no loop anywhere - run the gate to completion. try: - cfg = _get_config() - timeout = ws.block_timeout(cfg.block_timeout if cfg else 30.0) - future = asyncio.run_coroutine_threadsafe( - _async_gate(tool_call_id), main_loop - ) - return future.result(timeout=timeout if timeout else 60.0) + return asyncio.run(_async_gate(tool_call_id)) except Exception: - return False + return True def _blocked_response(tc_id: str) -> Any: # noqa: ANN401 """Return a blocked response compatible with ToolNode. @@ -1102,7 +1110,7 @@ async def patched_arun( tool_call_id: str | None = None, **kwargs: Any, ) -> Any: # noqa: ANN401 - """Gate on arun — AgentExecutor calls tool.arun directly.""" + """Gate on arun - AgentExecutor calls tool.arun directly.""" if tool_call_id and await _async_gate(tool_call_id): return _blocked_response(tool_call_id) return await original_arun( diff --git a/sdk/python/adrian/ws.py b/sdk/python/adrian/ws.py index 169cbdc..9eb4bc3 100644 --- a/sdk/python/adrian/ws.py +++ b/sdk/python/adrian/ws.py @@ -977,7 +977,7 @@ async def wait_for_verdict( event_id, timeout, ) - # Timed-out future is useless — remove so a retry can + # Timed-out future is useless - remove so a retry can # register a fresh one. self._pending_verdicts.pop(event_id, None) return None diff --git a/sdk/python/tests/test_block_mode.py b/sdk/python/tests/test_block_mode.py index 0bbbdaf..742249b 100644 --- a/sdk/python/tests/test_block_mode.py +++ b/sdk/python/tests/test_block_mode.py @@ -146,7 +146,7 @@ async def test_in_scope_block_verdict_halts_tool(self, tmp_path: Path) -> None: The verdict gate lives on BaseTool (the universal layer), not ToolNode.ainvoke. Uses an async tool so BaseTool.ainvoke (not - BaseTool.invoke) is the entry point — matching the production + BaseTool.invoke) is the entry point - matching the production path for create_react_agent with async tools. """ @@ -186,7 +186,7 @@ async def _real_tool(x: str) -> str: result = await tool_node.ainvoke(state, config=_runtime_config()) # pyright: ignore[reportUnknownMemberType] - # BaseTool.ainvoke gate blocks — tool body does NOT run. + # BaseTool.ainvoke gate blocks - tool body does NOT run. assert _real_tool.called is False # type: ignore[attr-defined] msgs = result["messages"] assert len(msgs) == 1 @@ -308,3 +308,251 @@ async def _real_tool(x: str) -> str: assert captured == ["hi"] assert not ws._pending_verdicts + + +class TestSyncToolNodeBlocking: + """Regression: sync (``def``) tools dispatched by ToolNode / create_react_agent. + + The tests in ``TestToolNodePatchBlocking`` use ``async def`` tools, so + they exercise ``BaseTool.ainvoke`` (the async gate). A sync ``def`` tool + takes a different path: ``StructuredTool.ainvoke`` has no coroutine, so it + runs ``self.invoke`` via ``run_in_executor`` on a worker thread. The gate + therefore lands in ``BaseTool.invoke`` -> ``_sync_gate`` on a thread that + is not running an event loop, and ``_sync_gate`` must bridge the gate onto + the WS loop. A regression here (e.g. probing the thread with + ``get_event_loop()``, which raises on a worker thread) silently skips the + gate and lets block-level tool calls run ungated under create_react_agent. + """ + + @staticmethod + def _prep(ws: WebSocketClient, policy_m4: bool, mad_code: str) -> None: + """Drive a logged-in MODE_BLOCK state with a pre-resolved verdict. + + ``ws._loop`` points at the test loop so the worker-thread bridge in + ``_sync_gate`` has a running target, mirroring production where the + WS loop lives on its own thread, separate from the Pregel worker. + """ + policy = _apply_mode(ws, pb.MODE_BLOCK, policy_m4=policy_m4) + ws._connected.set() + ws._loop = asyncio.get_running_loop() + ws._tool_call_id_to_event_id["tc-1"] = "llm-evt" + fut = ws.register_pending("llm-evt") + fut.set_result(pb.Verdict(event_id="llm-evt", mad_code=mad_code, policy=policy)) + + async def test_sync_tool_block_verdict_halts(self, tmp_path: Path) -> None: + """MODE_BLOCK + policy_m4 + M4 verdict: sync tool body must NOT run.""" + captured: list[str] = [] + + def _real_tool(x: str) -> str: + """Sync tool stub; records execution.""" + captured.append(x) + + return x + + adrian.init( + api_key="k", + log_file=str(tmp_path / "events.jsonl"), + auto_instrument=True, + ws_url="ws://x", + block_timeout=1.0, + ) + + ws = adrian._ws_client + assert ws is not None + self._prep(ws, policy_m4=True, mad_code="M4_a") + + tool_node = ToolNode([_real_tool]) + ai = AIMessage( + content="", + tool_calls=[{"id": "tc-1", "name": "_real_tool", "args": {"x": "hi"}}], + ) + state: dict[str, Any] = {"messages": [ai]} + + result = await tool_node.ainvoke(state, config=_runtime_config()) # pyright: ignore[reportUnknownMemberType] + + # Sync tool body must NOT run; a BLOCKED ToolMessage is returned. + assert captured == [] + msgs = result["messages"] + assert len(msgs) == 1 + assert "BLOCKED" in msgs[0].content + + async def test_sync_tool_out_of_scope_runs(self, tmp_path: Path) -> None: + """MODE_BLOCK, M2 verdict with policy_m2 false: sync tool runs (no over-block).""" + captured: list[str] = [] + + def _real_tool(x: str) -> str: + """Sync tool stub; records execution.""" + captured.append(x) + + return x + + adrian.init( + api_key="k", + log_file=str(tmp_path / "events.jsonl"), + auto_instrument=True, + ws_url="ws://x", + block_timeout=1.0, + ) + + ws = adrian._ws_client + assert ws is not None + self._prep(ws, policy_m4=True, mad_code="M2") # m2 not in policy scope + + tool_node = ToolNode([_real_tool]) + ai = AIMessage( + content="", + tool_calls=[{"id": "tc-1", "name": "_real_tool", "args": {"x": "hi"}}], + ) + state: dict[str, Any] = {"messages": [ai]} + + await tool_node.ainvoke(state, config=_runtime_config()) # pyright: ignore[reportUnknownMemberType] + + assert captured == ["hi"] + + @staticmethod + def _prep_hitl( + ws: WebSocketClient, + ) -> tuple[pb.PolicySnapshot, asyncio.Future[pb.Verdict]]: + """MODE_HITL, logged in, with an UNRESOLVED pending verdict (held). + + Returns the policy and the pending future so the test can resolve it + later, standing in for a human approve/reject. + """ + policy = _apply_mode(ws, pb.MODE_HITL, policy_m4=True) + ws._connected.set() + ws._loop = asyncio.get_running_loop() + ws._tool_call_id_to_event_id["tc-1"] = "llm-evt" + fut = ws.register_pending("llm-evt") + return policy, fut + + @staticmethod + def _tool_call_state() -> dict[str, Any]: + ai = AIMessage( + content="", + tool_calls=[{"id": "tc-1", "name": "_real_tool", "args": {"x": "hi"}}], + ) + return {"messages": [ai]} + + async def test_sync_tool_hitl_holds_until_human_then_blocks_on_reject( + self, tmp_path: Path + ) -> None: + """MODE_HITL: a sync tool is HELD indefinitely, never fail-opens. + + The gate must wait past ``block_timeout`` (the bounded MODE_BLOCK wait + does not apply to HITL); a human reject then halts the tool. Regression + for the worker-thread bridge fail-opening a HITL hold once a finite + ``future.result`` timeout elapsed. + """ + captured: list[str] = [] + + def _real_tool(x: str) -> str: + """Sync tool stub; records execution.""" + captured.append(x) + + return x + + adrian.init( + api_key="k", + log_file=str(tmp_path / "events.jsonl"), + auto_instrument=True, + ws_url="ws://x", + block_timeout=0.5, + ) + + ws = adrian._ws_client + assert ws is not None + policy, fut = self._prep_hitl(ws) + + task = asyncio.ensure_future( + ToolNode([_real_tool]).ainvoke( # pyright: ignore[reportUnknownMemberType] + self._tool_call_state(), config=_runtime_config() + ) + ) + + # Held well past block_timeout: neither run nor returned, waiting for a human. + await asyncio.sleep(1.5) + assert not task.done() + assert captured == [] + + # Human rejects -> HITL verdict with continue_execution=False. + verdict = pb.Verdict(event_id="llm-evt", mad_code="M4_a", policy=policy) + verdict.hitl.continue_execution = False + fut.set_result(verdict) + + result = await asyncio.wait_for(task, timeout=2.0) + assert captured == [] + msgs = result["messages"] + assert len(msgs) == 1 + assert "BLOCKED" in msgs[0].content + + async def test_sync_tool_hitl_resumes_on_approve(self, tmp_path: Path) -> None: + """MODE_HITL: after a human approve (continue_execution=True), the sync tool runs.""" + captured: list[str] = [] + + def _real_tool(x: str) -> str: + """Sync tool stub; records execution.""" + captured.append(x) + + return x + + adrian.init( + api_key="k", + log_file=str(tmp_path / "events.jsonl"), + auto_instrument=True, + ws_url="ws://x", + block_timeout=0.5, + ) + + ws = adrian._ws_client + assert ws is not None + policy, fut = self._prep_hitl(ws) + + task = asyncio.ensure_future( + ToolNode([_real_tool]).ainvoke( # pyright: ignore[reportUnknownMemberType] + self._tool_call_state(), config=_runtime_config() + ) + ) + + await asyncio.sleep(0.3) + assert not task.done() + assert captured == [] + + verdict = pb.Verdict(event_id="llm-evt", mad_code="M4_a", policy=policy) + verdict.hitl.continue_execution = True + fut.set_result(verdict) + + await asyncio.wait_for(task, timeout=2.0) + assert captured == ["hi"] + + async def test_sync_tool_block_timeout_fails_closed(self, tmp_path: Path) -> None: + """MODE_BLOCK: no verdict before block_timeout -> sync tool blocked (fail-closed).""" + captured: list[str] = [] + + def _real_tool(x: str) -> str: + """Sync tool stub; records execution.""" + captured.append(x) + + return x + + adrian.init( + api_key="k", + log_file=str(tmp_path / "events.jsonl"), + auto_instrument=True, + ws_url="ws://x", + block_timeout=0.1, + ) + + ws = adrian._ws_client + assert ws is not None + _apply_mode(ws, pb.MODE_BLOCK, policy_m4=True) + ws._connected.set() + ws._loop = asyncio.get_running_loop() + ws._tool_call_id_to_event_id["tc-1"] = "llm-evt" + ws.register_pending("llm-evt") # never resolved -> verdict times out + + result = await ToolNode([_real_tool]).ainvoke( # pyright: ignore[reportUnknownMemberType] + self._tool_call_state(), config=_runtime_config() + ) + + assert captured == [] + assert "BLOCKED" in result["messages"][0].content diff --git a/sdk/python/tests/test_block_mode_races.py b/sdk/python/tests/test_block_mode_races.py index 16d8e4a..394408d 100644 --- a/sdk/python/tests/test_block_mode_races.py +++ b/sdk/python/tests/test_block_mode_races.py @@ -5,17 +5,17 @@ LLM calls; no running backend. Scenarios mirror the validated shapes from the multi-agent work: - S1 subagents-as-tools , director → worker (nested) - S2 handoffs , triage → specialist (sequential) - S3 router , parallel fan-out via Send() - S4 hierarchical , 3-level deep (director → team-lead → worker) - S5 custom workflow , deterministic + LLM nodes mixed - S6 swarm , back-and-forth handoffs (Alice ↔ Bob) - S7 supervisor , central dispatcher to N workers - S8 deep research , parallel researchers via asyncio.gather + S1 subagents-as-tools - director → worker (nested) + S2 handoffs - triage → specialist (sequential) + S3 router - parallel fan-out via Send() + S4 hierarchical - 3-level deep (director → team-lead → worker) + S5 custom workflow - deterministic + LLM nodes mixed + S6 swarm - back-and-forth handoffs (Alice ↔ Bob) + S7 supervisor - central dispatcher to N workers + S8 deep research - parallel researchers via asyncio.gather The invariant under test: for EVERY pattern, each ToolNode invocation -blocks on the verdict of the LLM that emitted its specific tool_call.id , +blocks on the verdict of the LLM that emitted its specific tool_call.id - never a sibling, never a parent, never a stale global. """ diff --git a/sdk/python/tests/test_extract_tool_calls.py b/sdk/python/tests/test_extract_tool_calls.py index 9910673..9cad0d4 100644 --- a/sdk/python/tests/test_extract_tool_calls.py +++ b/sdk/python/tests/test_extract_tool_calls.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # Copyright (c) 2026 SecureAgentics -"""Unit tests for ``_extract_tool_calls`` — the function whose missing shape +"""Unit tests for ``_extract_tool_calls`` - the function whose missing shape handling let block/HITL skip the verdict wait for ``create_react_agent`` agents. Covers all three ToolNode input shapes. Shape 3 (per-tool-call dispatch) is the diff --git a/sdk/python/tests/test_parent_context_scenarios.py b/sdk/python/tests/test_parent_context_scenarios.py index 157b5e1..327884b 100644 --- a/sdk/python/tests/test_parent_context_scenarios.py +++ b/sdk/python/tests/test_parent_context_scenarios.py @@ -1,4 +1,4 @@ -"""End-to-end parent-context derivation per multi-agent scenario (S1–S8). +"""End-to-end parent-context derivation per multi-agent scenario (S1-S8). Fires the LangChain-shaped callback sequence each scenario produces - with the ``langgraph_checkpoint_ns`` metadata LangGraph would emit -