From edbec4e30b2a9bd3eb0b466be0fab69a0854bfaa Mon Sep 17 00:00:00 2001 From: habema Date: Thu, 5 Feb 2026 15:23:10 +0300 Subject: [PATCH 01/11] feat: Add tool origin tracking to ToolCallItem and ToolCallOutputItem - Add ToolOriginType enum and ToolOrigin dataclass - Add _tool_origin field to FunctionTool - Set tool_origin for MCP tools and agent-as-tool - Extract and set tool_origin in ToolCallItem and ToolCallOutputItem creation - Add comprehensive tests for tool origin tracking --- src/agents/agent.py | 7 + src/agents/items.py | 7 + src/agents/mcp/util.py | 9 +- src/agents/run_internal/run_loop.py | 9 +- src/agents/run_internal/tool_execution.py | 3 + src/agents/run_internal/turn_resolution.py | 9 +- src/agents/tool.py | 58 ++++ tests/test_tool_origin.py | 333 +++++++++++++++++++++ 8 files changed, 431 insertions(+), 4 deletions(-) create mode 100644 tests/test_tool_origin.py diff --git a/src/agents/agent.py b/src/agents/agent.py index b0368e8698..1afc33757b 100644 --- a/src/agents/agent.py +++ b/src/agents/agent.py @@ -45,6 +45,8 @@ FunctionToolResult, Tool, ToolErrorFunction, + ToolOrigin, + ToolOriginType, _extract_tool_argument_json_error, default_tool_error_function, ) @@ -802,6 +804,11 @@ async def _run_agent_tool(context: ToolContext, input_json: str) -> Any: ) run_agent_tool._is_agent_tool = True run_agent_tool._agent_instance = self + # Set origin tracking on run_agent (the FunctionTool returned by @function_tool) + run_agent_tool._tool_origin = ToolOrigin( + type=ToolOriginType.AGENT_AS_TOOL, + agent_as_tool=self, + ) return run_agent_tool diff --git a/src/agents/items.py b/src/agents/items.py index 94ab5daa35..64565b6037 100644 --- a/src/agents/items.py +++ b/src/agents/items.py @@ -49,6 +49,7 @@ from .exceptions import AgentsException, ModelBehaviorError from .logger import logger from .tool import ( + ToolOrigin, ToolOutputFileContent, ToolOutputImage, ToolOutputText, @@ -248,6 +249,9 @@ class ToolCallItem(RunItemBase[Any]): description: str | None = None """Optional tool description if known at item creation time.""" + tool_origin: ToolOrigin | None = field(default=None, repr=False) + """Information about the origin/source of the tool call. Only set for FunctionTool calls.""" + ToolCallOutputTypes: TypeAlias = Union[ FunctionCallOutput, @@ -271,6 +275,9 @@ class ToolCallOutputItem(RunItemBase[Any]): type: Literal["tool_call_output_item"] = "tool_call_output_item" + tool_origin: ToolOrigin | None = field(default=None, repr=False) + """Information about the origin/source of the tool call. Only set for FunctionTool calls.""" + def to_input_item(self) -> TResponseInputItem: """Converts the tool output into an input item for the next model turn. diff --git a/src/agents/mcp/util.py b/src/agents/mcp/util.py index 9c9a59f683..a72c55ccaf 100644 --- a/src/agents/mcp/util.py +++ b/src/agents/mcp/util.py @@ -20,6 +20,8 @@ FunctionTool, Tool, ToolErrorFunction, + ToolOrigin, + ToolOriginType, ToolOutputImageDict, ToolOutputTextDict, default_tool_error_function, @@ -301,7 +303,7 @@ async def invoke_func(ctx: ToolContext[Any], input_json: str) -> ToolOutput: bool | Callable[[RunContextWrapper[Any], dict[str, Any], str], Awaitable[bool]] ) = server._get_needs_approval_for_tool(tool, agent) - return FunctionTool( + function_tool = FunctionTool( name=tool.name, description=tool.description or "", params_json_schema=schema, @@ -309,6 +311,11 @@ async def invoke_func(ctx: ToolContext[Any], input_json: str) -> ToolOutput: strict_json_schema=is_strict, needs_approval=needs_approval, ) + function_tool._tool_origin = ToolOrigin( + type=ToolOriginType.MCP, + mcp_server=server, + ) + return function_tool @staticmethod def _merge_mcp_meta( diff --git a/src/agents/run_internal/run_loop.py b/src/agents/run_internal/run_loop.py index e807c0cb11..4404868ed8 100644 --- a/src/agents/run_internal/run_loop.py +++ b/src/agents/run_internal/run_loop.py @@ -49,7 +49,7 @@ RawResponsesStreamEvent, RunItemStreamEvent, ) -from ..tool import Tool, dispose_resolved_computers +from ..tool import FunctionTool, Tool, _get_tool_origin_info, dispose_resolved_computers from ..tracing import Span, SpanError, agent_span, get_current_trace from ..tracing.model_tracing import get_model_tracing_impl from ..tracing.span_data import AgentSpanData @@ -1216,13 +1216,18 @@ async def run_single_turn_streamed( # execution behavior in process_model_response). tool_name = getattr(output_item, "name", None) tool_description: str | None = None + tool_origin = None if isinstance(tool_name, str) and tool_name in tool_map: - tool_description = getattr(tool_map[tool_name], "description", None) + tool = tool_map[tool_name] + tool_description = getattr(tool, "description", None) + if isinstance(tool, FunctionTool): + tool_origin = _get_tool_origin_info(tool) tool_item = ToolCallItem( raw_item=cast(ToolCallItemTypes, output_item), agent=agent, description=tool_description, + tool_origin=tool_origin, ) streamed_result._event_queue.put_nowait( RunItemStreamEvent(item=tool_item, name="tool_called") diff --git a/src/agents/run_internal/tool_execution.py b/src/agents/run_internal/tool_execution.py index bc370ea611..a22f9a5cdc 100644 --- a/src/agents/run_internal/tool_execution.py +++ b/src/agents/run_internal/tool_execution.py @@ -52,6 +52,7 @@ ShellCallOutcome, ShellCommandOutput, Tool, + _get_tool_origin_info, resolve_computer, ) from ..tool_context import ToolContext @@ -973,10 +974,12 @@ async def run_single_tool(func_tool: FunctionTool, tool_call: ResponseFunctionTo run_item: RunItem | None = None if not nested_interruptions: + tool_origin = _get_tool_origin_info(tool_run.function_tool) run_item = ToolCallOutputItem( output=result, raw_item=ItemHelpers.tool_call_output_item(tool_run.tool_call, result), agent=agent, + tool_origin=tool_origin, ) else: # Skip tool output until nested interruptions are resolved. diff --git a/src/agents/run_internal/turn_resolution.py b/src/agents/run_internal/turn_resolution.py index fed661ea9a..86872f4d27 100644 --- a/src/agents/run_internal/turn_resolution.py +++ b/src/agents/run_internal/turn_resolution.py @@ -62,6 +62,7 @@ LocalShellTool, ShellTool, Tool, + _get_tool_origin_info, ) from ..tool_guardrails import ToolInputGuardrailResult, ToolOutputGuardrailResult from ..tracing import SpanError, handoff_span @@ -1473,8 +1474,14 @@ def process_model_response( raise ModelBehaviorError(error) func_tool = function_map[output.name] + tool_origin = _get_tool_origin_info(func_tool) items.append( - ToolCallItem(raw_item=output, agent=agent, description=func_tool.description) + ToolCallItem( + raw_item=output, + agent=agent, + description=func_tool.description, + tool_origin=tool_origin, + ) ) functions.append( ToolRunFunction( diff --git a/src/agents/tool.py b/src/agents/tool.py index 4f70adc0f8..06cc25a734 100644 --- a/src/agents/tool.py +++ b/src/agents/tool.py @@ -1,6 +1,7 @@ from __future__ import annotations import asyncio +import enum import inspect import json import weakref @@ -48,6 +49,7 @@ if TYPE_CHECKING: from .agent import Agent, AgentBase from .items import RunItem, ToolApprovalItem + from .mcp.server import MCPServer ToolParams = ParamSpec("ToolParams") @@ -182,6 +184,59 @@ class ComputerProvider(Generic[ComputerT]): ] +class ToolOriginType(str, enum.Enum): + """The type of tool origin.""" + + FUNCTION = "function" + """Regular Python function tool created via @function_tool decorator.""" + + MCP = "mcp" + """MCP server tool converted via MCPUtil.to_function_tool().""" + + AGENT_AS_TOOL = "agent_as_tool" + """Agent converted to tool via agent.as_tool().""" + + +@dataclass +class ToolOrigin: + """Information about the origin/source of a function tool.""" + + type: ToolOriginType + """The type of tool origin.""" + + mcp_server: MCPServer | None = None + """The MCP server object. Only set when type is MCP.""" + + agent_as_tool: Agent[Any] | None = None + """The agent object. Only set when type is AGENT_AS_TOOL.""" + + def __repr__(self) -> str: + """Custom repr that only includes relevant fields.""" + parts = [f"type={self.type.value!r}"] + if self.mcp_server is not None: + parts.append(f"mcp_server_name={self.mcp_server.name!r}") + if self.agent_as_tool is not None: + parts.append(f"agent_as_tool_name={self.agent_as_tool.name!r}") + return f"ToolOrigin({', '.join(parts)})" + + +def _get_tool_origin_info(function_tool: FunctionTool) -> ToolOrigin | None: + """Extract origin information from a FunctionTool. + + Args: + function_tool: The function tool to extract origin info from. + + Returns: + ToolOrigin object if origin is set, otherwise None (defaults to FUNCTION type). + """ + origin = function_tool._tool_origin + if origin is None: + # Default to FUNCTION if not explicitly set + return ToolOrigin(type=ToolOriginType.FUNCTION) + + return origin + + @dataclass class FunctionToolResult: tool: FunctionTool @@ -264,6 +319,9 @@ class FunctionTool: _agent_instance: Any = field(default=None, init=False, repr=False) """Internal reference to the agent instance if this is an agent-as-tool.""" + _tool_origin: ToolOrigin | None = field(default=None, init=False, repr=False) + """Internal field tracking the origin of this tool (FUNCTION, MCP, or AGENT_AS_TOOL).""" + def __post_init__(self): if self.strict_json_schema: self.params_json_schema = ensure_strict_json_schema(self.params_json_schema) diff --git a/tests/test_tool_origin.py b/tests/test_tool_origin.py new file mode 100644 index 0000000000..5800b87099 --- /dev/null +++ b/tests/test_tool_origin.py @@ -0,0 +1,333 @@ +"""Tests for tool origin tracking feature.""" + +from __future__ import annotations + +import sys +from typing import cast + +import pytest + +from agents import Agent, FunctionTool, RunContextWrapper, Runner, function_tool +from agents.items import ToolCallItem, ToolCallItemTypes, ToolCallOutputItem +from agents.tool import ToolOrigin, ToolOriginType + +from .fake_model import FakeModel +from .test_responses import get_function_tool_call, get_text_message + +if sys.version_info >= (3, 10): + from .mcp.helpers import FakeMCPServer + + +@pytest.mark.asyncio +async def test_function_tool_origin(): + """Test that regular function tools have FUNCTION origin.""" + model = FakeModel() + + @function_tool + def test_tool(x: int) -> str: + """Test tool.""" + return f"result: {x}" + + agent = Agent(name="test", model=model, tools=[test_tool]) + + model.add_multiple_turn_outputs( + [ + [get_function_tool_call("test_tool", '{"x": 42}')], + [get_text_message("done")], + ] + ) + + result = await Runner.run(agent, input="test") + tool_call_items = [item for item in result.new_items if isinstance(item, ToolCallItem)] + tool_output_items = [item for item in result.new_items if isinstance(item, ToolCallOutputItem)] + + assert len(tool_call_items) == 1 + assert tool_call_items[0].tool_origin is not None + assert tool_call_items[0].tool_origin.type == ToolOriginType.FUNCTION + assert tool_call_items[0].tool_origin.mcp_server is None + assert tool_call_items[0].tool_origin.agent_as_tool is None + + assert len(tool_output_items) == 1 + assert tool_output_items[0].tool_origin is not None + assert tool_output_items[0].tool_origin.type == ToolOriginType.FUNCTION + assert tool_output_items[0].tool_origin.mcp_server is None + assert tool_output_items[0].tool_origin.agent_as_tool is None + + +@pytest.mark.asyncio +@pytest.mark.skipif(sys.version_info < (3, 10), reason="MCP tests require Python 3.10+") +async def test_mcp_tool_origin(): + """Test that MCP tools have MCP origin with server name.""" + model = FakeModel() + server = FakeMCPServer(server_name="test_mcp_server") + server.add_tool("mcp_tool", {}) + + agent = Agent(name="test", model=model, mcp_servers=[server]) + + model.add_multiple_turn_outputs( + [ + [get_function_tool_call("mcp_tool", "")], + [get_text_message("done")], + ] + ) + + result = await Runner.run(agent, input="test") + tool_call_items = [item for item in result.new_items if isinstance(item, ToolCallItem)] + tool_output_items = [item for item in result.new_items if isinstance(item, ToolCallOutputItem)] + + assert len(tool_call_items) == 1 + assert tool_call_items[0].tool_origin is not None + assert tool_call_items[0].tool_origin.type == ToolOriginType.MCP + assert tool_call_items[0].tool_origin.mcp_server is not None + assert tool_call_items[0].tool_origin.mcp_server.name == "test_mcp_server" + assert tool_call_items[0].tool_origin.agent_as_tool is None + + assert len(tool_output_items) == 1 + assert tool_output_items[0].tool_origin is not None + assert tool_output_items[0].tool_origin.type == ToolOriginType.MCP + assert tool_output_items[0].tool_origin.mcp_server is not None + assert tool_output_items[0].tool_origin.mcp_server.name == "test_mcp_server" + assert tool_output_items[0].tool_origin.agent_as_tool is None + + +@pytest.mark.asyncio +async def test_agent_as_tool_origin(): + """Test that agent-as-tool has AGENT_AS_TOOL origin with agent name.""" + model = FakeModel() + nested_model = FakeModel() + + nested_agent = Agent( + name="nested_agent", + model=nested_model, + instructions="You are a nested agent.", + ) + + nested_model.add_multiple_turn_outputs([[get_text_message("nested response")]]) + + tool = nested_agent.as_tool( + tool_name="nested_tool", + tool_description="A nested agent tool", + ) + + orchestrator = Agent(name="orchestrator", model=model, tools=[tool]) + + model.add_multiple_turn_outputs( + [ + [get_function_tool_call("nested_tool", '{"input": "test"}')], + [get_text_message("done")], + ] + ) + + result = await Runner.run(orchestrator, input="test") + tool_call_items = [item for item in result.new_items if isinstance(item, ToolCallItem)] + tool_output_items = [item for item in result.new_items if isinstance(item, ToolCallOutputItem)] + + assert len(tool_call_items) == 1 + assert tool_call_items[0].tool_origin is not None + assert tool_call_items[0].tool_origin.type == ToolOriginType.AGENT_AS_TOOL + assert tool_call_items[0].tool_origin.mcp_server is None + assert tool_call_items[0].tool_origin.agent_as_tool is not None + assert tool_call_items[0].tool_origin.agent_as_tool.name == "nested_agent" + + assert len(tool_output_items) == 1 + assert tool_output_items[0].tool_origin is not None + assert tool_output_items[0].tool_origin.type == ToolOriginType.AGENT_AS_TOOL + assert tool_output_items[0].tool_origin.mcp_server is None + assert tool_output_items[0].tool_origin.agent_as_tool is not None + assert tool_output_items[0].tool_origin.agent_as_tool.name == "nested_agent" + + +@pytest.mark.asyncio +@pytest.mark.skipif(sys.version_info < (3, 10), reason="MCP tests require Python 3.10+") +async def test_multiple_tool_origins(): + """Test that multiple tools from different origins work together.""" + model = FakeModel() + nested_model = FakeModel() + + @function_tool + def func_tool(x: int) -> str: + """Function tool.""" + return f"function: {x}" + + mcp_server = FakeMCPServer(server_name="mcp_server") + mcp_server.add_tool("mcp_tool", {}) + + nested_agent = Agent(name="nested", model=nested_model, instructions="Nested agent") + nested_model.add_multiple_turn_outputs([[get_text_message("nested response")]]) + agent_tool = nested_agent.as_tool(tool_name="agent_tool", tool_description="Agent tool") + + agent = Agent( + name="test", + model=model, + tools=[func_tool, agent_tool], + mcp_servers=[mcp_server], + ) + + model.add_multiple_turn_outputs( + [ + [ + get_function_tool_call("func_tool", '{"x": 1}'), + get_function_tool_call("mcp_tool", ""), + get_function_tool_call("agent_tool", '{"input": "test"}'), + ], + [get_text_message("done")], + ] + ) + + result = await Runner.run(agent, input="test") + tool_call_items = [item for item in result.new_items if isinstance(item, ToolCallItem)] + tool_output_items = [item for item in result.new_items if isinstance(item, ToolCallOutputItem)] + + assert len(tool_call_items) == 3 + assert len(tool_output_items) == 3 + + # Find items by tool name + function_item = next( + item for item in tool_call_items if getattr(item.raw_item, "name", None) == "func_tool" + ) + mcp_item = next( + item for item in tool_call_items if getattr(item.raw_item, "name", None) == "mcp_tool" + ) + agent_item = next( + item for item in tool_call_items if getattr(item.raw_item, "name", None) == "agent_tool" + ) + + assert function_item.tool_origin is not None + assert function_item.tool_origin.type == ToolOriginType.FUNCTION + assert mcp_item.tool_origin is not None + assert mcp_item.tool_origin.type == ToolOriginType.MCP + assert mcp_item.tool_origin.mcp_server is not None + assert mcp_item.tool_origin.mcp_server.name == "mcp_server" + assert agent_item.tool_origin is not None + assert agent_item.tool_origin.type == ToolOriginType.AGENT_AS_TOOL + assert agent_item.tool_origin.agent_as_tool is not None + assert agent_item.tool_origin.agent_as_tool.name == "nested" + + +@pytest.mark.asyncio +@pytest.mark.skipif(sys.version_info < (3, 10), reason="MCP tests require Python 3.10+") +async def test_tool_origin_streaming(): + """Test that tool origin is populated correctly in streaming scenarios.""" + model = FakeModel() + server = FakeMCPServer(server_name="streaming_server") + server.add_tool("streaming_tool", {}) + + agent = Agent(name="test", model=model, mcp_servers=[server]) + + model.add_multiple_turn_outputs( + [ + [get_function_tool_call("streaming_tool", "")], + [get_text_message("done")], + ] + ) + + result = Runner.run_streamed(agent, input="test") + tool_call_items = [] + tool_output_items = [] + + async for event in result.stream_events(): + if event.type == "run_item_stream_event": + if isinstance(event.item, ToolCallItem): + tool_call_items.append(event.item) + elif isinstance(event.item, ToolCallOutputItem): + tool_output_items.append(event.item) + + assert len(tool_call_items) == 1 + assert tool_call_items[0].tool_origin is not None + assert tool_call_items[0].tool_origin.type == ToolOriginType.MCP + assert tool_call_items[0].tool_origin.mcp_server is not None + assert tool_call_items[0].tool_origin.mcp_server.name == "streaming_server" + + assert len(tool_output_items) == 1 + assert tool_output_items[0].tool_origin is not None + assert tool_output_items[0].tool_origin.type == ToolOriginType.MCP + assert tool_output_items[0].tool_origin.mcp_server is not None + assert tool_output_items[0].tool_origin.mcp_server.name == "streaming_server" + + +@pytest.mark.asyncio +async def test_tool_origin_repr(): + """Test that ToolOrigin repr only shows relevant fields.""" + # FUNCTION origin + function_origin = ToolOrigin(type=ToolOriginType.FUNCTION) + assert "mcp_server_name" not in repr(function_origin) + assert "agent_as_tool_name" not in repr(function_origin) + + # MCP origin + if sys.version_info >= (3, 10): + from .mcp.helpers import FakeMCPServer + + test_server = FakeMCPServer(server_name="test_server") + mcp_origin = ToolOrigin(type=ToolOriginType.MCP, mcp_server=test_server) + assert "mcp_server_name='test_server'" in repr(mcp_origin) + assert "agent_as_tool_name" not in repr(mcp_origin) + + # AGENT_AS_TOOL origin + model = FakeModel() + test_agent = Agent(name="test_agent", model=model, instructions="Test agent") + agent_origin = ToolOrigin(type=ToolOriginType.AGENT_AS_TOOL, agent_as_tool=test_agent) + assert "agent_as_tool_name='test_agent'" in repr(agent_origin) + assert "mcp_server_name" not in repr(agent_origin) + + +@pytest.mark.asyncio +async def test_tool_origin_defaults_to_function(): + """Test that tools without explicit origin default to FUNCTION.""" + model = FakeModel() + + # Create a FunctionTool directly without using @function_tool decorator + async def test_func(ctx: RunContextWrapper, args: str) -> str: + return "result" + + tool = FunctionTool( + name="direct_tool", + description="Direct tool", + params_json_schema={"type": "object", "properties": {}}, + on_invoke_tool=test_func, + ) + + agent = Agent(name="test", model=model, tools=[tool]) + + model.add_multiple_turn_outputs( + [ + [get_function_tool_call("direct_tool", "")], + [get_text_message("done")], + ] + ) + + result = await Runner.run(agent, input="test") + tool_call_items = [item for item in result.new_items if isinstance(item, ToolCallItem)] + + assert len(tool_call_items) == 1 + # Even though _tool_origin is None, _get_tool_origin_info defaults to FUNCTION + assert tool_call_items[0].tool_origin is not None + assert tool_call_items[0].tool_origin.type == ToolOriginType.FUNCTION + + +@pytest.mark.asyncio +async def test_non_function_tool_items_have_no_origin(): + """Test that non-FunctionTool items (computer, shell, etc.) don't have tool_origin.""" + model = FakeModel() + + @function_tool + def func_tool() -> str: + """Function tool.""" + return "result" + + agent = Agent(name="test", model=model, tools=[func_tool]) + + # Create a ToolCallItem for a non-function tool (simulating computer/shell tool) + computer_call = { + "type": "computer_use_preview", + "call_id": "call_123", + "actions": [], + } + + # This simulates what happens for non-FunctionTool items + # They should not have tool_origin set + item = ToolCallItem( + raw_item=cast(ToolCallItemTypes, computer_call), + agent=agent, + ) + + assert item.tool_origin is None From e1b635702c2f3f2f6b3e8ec98fe187c00beab71f Mon Sep 17 00:00:00 2001 From: habema Date: Thu, 5 Feb 2026 15:30:28 +0300 Subject: [PATCH 02/11] fix memory leak in code review and add test --- src/agents/items.py | 12 +++++++++ src/agents/tool.py | 41 ++++++++++++++++++++++++++++-- tests/test_tool_origin.py | 53 +++++++++++++++++++++++++++++++++++++++ 3 files changed, 104 insertions(+), 2 deletions(-) diff --git a/src/agents/items.py b/src/agents/items.py index 64565b6037..7139e07f99 100644 --- a/src/agents/items.py +++ b/src/agents/items.py @@ -252,6 +252,12 @@ class ToolCallItem(RunItemBase[Any]): tool_origin: ToolOrigin | None = field(default=None, repr=False) """Information about the origin/source of the tool call. Only set for FunctionTool calls.""" + def release_agent(self) -> None: + """Release agent references including tool_origin.agent_as_tool.""" + super().release_agent() + if self.tool_origin is not None: + self.tool_origin.release_agent() + ToolCallOutputTypes: TypeAlias = Union[ FunctionCallOutput, @@ -278,6 +284,12 @@ class ToolCallOutputItem(RunItemBase[Any]): tool_origin: ToolOrigin | None = field(default=None, repr=False) """Information about the origin/source of the tool call. Only set for FunctionTool calls.""" + def release_agent(self) -> None: + """Release agent references including tool_origin.agent_as_tool.""" + super().release_agent() + if self.tool_origin is not None: + self.tool_origin.release_agent() + def to_input_item(self) -> TResponseInputItem: """Converts the tool output into an input item for the next model turn. diff --git a/src/agents/tool.py b/src/agents/tool.py index 06cc25a734..2e0e043581 100644 --- a/src/agents/tool.py +++ b/src/agents/tool.py @@ -210,13 +210,50 @@ class ToolOrigin: agent_as_tool: Agent[Any] | None = None """The agent object. Only set when type is AGENT_AS_TOOL.""" + _agent_as_tool_ref: weakref.ReferenceType[Agent[Any]] | None = field( + default=None, init=False, repr=False + ) + """Weak reference to agent_as_tool for memory management.""" + + def __post_init__(self) -> None: + """Initialize weak reference for agent_as_tool.""" + if self.agent_as_tool is not None: + self._agent_as_tool_ref = weakref.ref(self.agent_as_tool) + + def __getattribute__(self, name: str) -> Any: + """Lazily resolve agent_as_tool via weakref when strong ref is cleared.""" + if name == "agent_as_tool": + # Check if strong reference still exists + value = object.__getattribute__(self, "__dict__").get("agent_as_tool") + if value is not None: + return value + # Try to resolve via weakref + ref = object.__getattribute__(self, "_agent_as_tool_ref") + if ref is not None: + agent = ref() + if agent is not None: + return agent + return None + return super().__getattribute__(name) + + def release_agent(self) -> None: + """Release the strong reference to agent_as_tool while keeping a weak reference.""" + if "agent_as_tool" not in self.__dict__: + return + agent = self.__dict__.get("agent_as_tool") + if agent is not None: + self._agent_as_tool_ref = weakref.ref(agent) + # Set to None instead of deleting so dataclass repr/asdict keep working. + self.__dict__["agent_as_tool"] = None + def __repr__(self) -> str: """Custom repr that only includes relevant fields.""" parts = [f"type={self.type.value!r}"] if self.mcp_server is not None: parts.append(f"mcp_server_name={self.mcp_server.name!r}") - if self.agent_as_tool is not None: - parts.append(f"agent_as_tool_name={self.agent_as_tool.name!r}") + agent = self.agent_as_tool + if agent is not None: + parts.append(f"agent_as_tool_name={agent.name!r}") return f"ToolOrigin({', '.join(parts)})" diff --git a/tests/test_tool_origin.py b/tests/test_tool_origin.py index 5800b87099..245f982491 100644 --- a/tests/test_tool_origin.py +++ b/tests/test_tool_origin.py @@ -2,7 +2,9 @@ from __future__ import annotations +import gc import sys +import weakref from typing import cast import pytest @@ -331,3 +333,54 @@ def func_tool() -> str: ) assert item.tool_origin is None + + +def test_tool_origin_release_agent_clears_strong_reference(): + """Test that release_agent() clears strong reference to agent_as_tool.""" + # Create a ToolOrigin with an agent_as_tool + nested_agent = Agent( + name="nested_agent", + model=FakeModel(), + instructions="You are a nested agent.", + ) + + tool_origin = ToolOrigin( + type=ToolOriginType.AGENT_AS_TOOL, + agent_as_tool=nested_agent, + ) + + # Create a ToolCallItem with this tool_origin + tool_call_item = ToolCallItem( + raw_item=cast( + ToolCallItemTypes, + { + "type": "function_call", + "name": "test_tool", + "call_id": "call_123", + "arguments": "{}", + }, + ), + agent=nested_agent, + tool_origin=tool_origin, + ) + + # Verify agent_as_tool is set + assert tool_call_item.tool_origin is not None + assert tool_call_item.tool_origin.agent_as_tool is nested_agent + + # Create weak reference to verify GC behavior + nested_agent_ref = weakref.ref(nested_agent) + + # Release agent - this should clear strong reference in tool_origin + tool_call_item.release_agent() + + # After release, agent_as_tool should still be accessible via weakref + assert tool_call_item.tool_origin.agent_as_tool is nested_agent + + # Delete the agent and force GC + del nested_agent + gc.collect() + + # After GC, agent_as_tool should be None since strong refs were cleared + assert nested_agent_ref() is None + assert tool_call_item.tool_origin.agent_as_tool is None From 5b2835b82a4a1732ab17fd363f25a119bd4724bc Mon Sep 17 00:00:00 2001 From: habema Date: Sat, 7 Feb 2026 21:46:39 +0300 Subject: [PATCH 03/11] address code review and add test --- src/agents/run_state.py | 86 ++++++++- tests/test_tool_origin_serialization.py | 228 ++++++++++++++++++++++++ 2 files changed, 313 insertions(+), 1 deletion(-) create mode 100644 tests/test_tool_origin_serialization.py diff --git a/src/agents/run_state.py b/src/agents/run_state.py index d02d298140..08b23e2505 100644 --- a/src/agents/run_state.py +++ b/src/agents/run_state.py @@ -63,6 +63,8 @@ HostedMCPTool, LocalShellTool, ShellTool, + ToolOrigin, + ToolOriginType, ) from .tool_guardrails import ( AllowBehavior, @@ -635,6 +637,13 @@ def _serialize_item(self, item: RunItem) -> dict[str, Any]: result["tool_name"] = item.tool_name if hasattr(item, "description") and item.description is not None: result["description"] = item.description + if hasattr(item, "tool_origin") and item.tool_origin is not None: + tool_origin_data: dict[str, Any] = {"type": item.tool_origin.type.value} + if item.tool_origin.agent_as_tool is not None: + tool_origin_data["agent_as_tool"] = {"name": item.tool_origin.agent_as_tool.name} + if item.tool_origin.mcp_server is not None: + tool_origin_data["mcp_server"] = {"name": item.tool_origin.mcp_server.name} + result["tool_origin"] = tool_origin_data return result @@ -1918,6 +1927,67 @@ def _build_agent_map(initial_agent: Agent[Any]) -> dict[str, Agent[Any]]: return agent_map +def _deserialize_tool_origin( + tool_origin_data: dict[str, Any] | None, agent_map: dict[str, Agent[Any]], agent: Agent[Any] +) -> ToolOrigin | None: + """Deserialize ToolOrigin from JSON data. + + Args: + tool_origin_data: Serialized tool origin dictionary. + agent_map: Map of agent names to agent instances. + agent: The agent associated with this item (used for MCP server lookup). + + Returns: + ToolOrigin instance or None if data is missing/invalid. + """ + if not tool_origin_data: + return None + + origin_type_str = tool_origin_data.get("type") + if not origin_type_str: + return None + + try: + origin_type = ToolOriginType(origin_type_str) + except ValueError: + logger.warning(f"Unknown tool origin type: {origin_type_str}") + return None + + agent_as_tool: Agent[Any] | None = None + mcp_server: Any | None = None + + if origin_type == ToolOriginType.AGENT_AS_TOOL: + agent_data = tool_origin_data.get("agent_as_tool") + if agent_data and isinstance(agent_data, Mapping): + agent_name = agent_data.get("name") + if agent_name: + agent_as_tool = agent_map.get(agent_name) + if not agent_as_tool: + logger.warning(f"Agent {agent_name} not found in agent map for tool_origin") + + elif origin_type == ToolOriginType.MCP: + mcp_data = tool_origin_data.get("mcp_server") + if mcp_data and isinstance(mcp_data, Mapping): + server_name = mcp_data.get("name") + if server_name: + # Try to find the MCP server from the agent's mcp_servers list + mcp_servers = getattr(agent, "mcp_servers", []) + for server in mcp_servers: + if hasattr(server, "name") and server.name == server_name: + mcp_server = server + break + if not mcp_server: + logger.debug( + f"MCP server {server_name} not found in agent's mcp_servers for tool_origin" + ) + + return ToolOrigin( + type=origin_type, + agent_as_tool=agent_as_tool, + mcp_server=mcp_server, + ) + + def _deserialize_model_responses(responses_data: list[dict[str, Any]]) -> list[ModelResponse]: """Deserialize model responses from JSON data. @@ -2019,8 +2089,17 @@ def _resolve_agent_info( raw_item_tool = _deserialize_tool_call_raw_item(normalized_raw_item) # Preserve description if it was stored with the item description = item_data.get("description") + # Preserve tool_origin if it was stored with the item + tool_origin = _deserialize_tool_origin( + item_data.get("tool_origin"), agent_map, agent + ) result.append( - ToolCallItem(agent=agent, raw_item=raw_item_tool, description=description) + ToolCallItem( + agent=agent, + raw_item=raw_item_tool, + description=description, + tool_origin=tool_origin, + ) ) elif item_type == "tool_call_output_item": @@ -2029,11 +2108,16 @@ def _resolve_agent_info( raw_item_output = _deserialize_tool_call_output_raw_item(normalized_raw_item) if raw_item_output is None: continue + # Preserve tool_origin if it was stored with the item + tool_origin = _deserialize_tool_origin( + item_data.get("tool_origin"), agent_map, agent + ) result.append( ToolCallOutputItem( agent=agent, raw_item=raw_item_output, output=item_data.get("output", ""), + tool_origin=tool_origin, ) ) diff --git a/tests/test_tool_origin_serialization.py b/tests/test_tool_origin_serialization.py new file mode 100644 index 0000000000..87bca9fcc4 --- /dev/null +++ b/tests/test_tool_origin_serialization.py @@ -0,0 +1,228 @@ +"""Tests for tool_origin serialization in RunState.""" + +from __future__ import annotations + +import sys + +import pytest + +from agents import Agent, Runner, function_tool +from agents.items import ToolCallItem, ToolCallOutputItem +from agents.run_state import RunState +from agents.tool import ToolOriginType + +from .fake_model import FakeModel +from .test_responses import get_function_tool_call, get_text_message + +if sys.version_info >= (3, 10): + from .mcp.helpers import FakeMCPServer + + +@pytest.mark.asyncio +async def test_serialize_tool_origin_function(): + """Test that FUNCTION tool_origin is serialized and deserialized.""" + model = FakeModel() + + @function_tool + def test_tool(x: int) -> str: + """Test tool.""" + return f"result: {x}" + + agent = Agent(name="test", model=model, tools=[test_tool]) + + model.add_multiple_turn_outputs( + [ + [get_function_tool_call("test_tool", '{"x": 42}')], + [get_text_message("done")], + ] + ) + + result = await Runner.run(agent, input="test") + tool_call_items = [item for item in result.new_items if isinstance(item, ToolCallItem)] + tool_output_items = [item for item in result.new_items if isinstance(item, ToolCallOutputItem)] + + assert len(tool_call_items) == 1 + assert len(tool_output_items) == 1 + + tool_call_item = tool_call_items[0] + tool_output_item = tool_output_items[0] + + # Verify tool_origin is set + assert tool_call_item.tool_origin is not None + assert tool_call_item.tool_origin.type == ToolOriginType.FUNCTION + assert tool_output_item.tool_origin is not None + assert tool_output_item.tool_origin.type == ToolOriginType.FUNCTION + + # Serialize and deserialize + context = result.context_wrapper + state = RunState( + context=context, + original_input="test", + starting_agent=agent, + max_turns=5, + ) + state._generated_items = [tool_call_item, tool_output_item] + + json_data = state.to_json() + deserialized_state = await RunState.from_json(agent, json_data) + + # Verify tool_origin was preserved + deserialized_tool_call = next( + item for item in deserialized_state._generated_items if isinstance(item, ToolCallItem) + ) + deserialized_tool_output = next( + item for item in deserialized_state._generated_items if isinstance(item, ToolCallOutputItem) + ) + + assert deserialized_tool_call.tool_origin is not None + assert deserialized_tool_call.tool_origin.type == ToolOriginType.FUNCTION + assert deserialized_tool_output.tool_origin is not None + assert deserialized_tool_output.tool_origin.type == ToolOriginType.FUNCTION + + +@pytest.mark.asyncio +async def test_serialize_tool_origin_agent_as_tool(): + """Test that AGENT_AS_TOOL tool_origin is serialized and deserialized.""" + model = FakeModel() + nested_model = FakeModel() + + nested_agent = Agent( + name="nested_agent", + model=nested_model, + instructions="You are a nested agent.", + ) + + nested_model.add_multiple_turn_outputs([[get_text_message("nested response")]]) + + tool = nested_agent.as_tool( + tool_name="nested_tool", + tool_description="A nested agent tool", + ) + + orchestrator = Agent(name="orchestrator", model=model, tools=[tool]) + + model.add_multiple_turn_outputs( + [ + [get_function_tool_call("nested_tool", '{"input": "test"}')], + [get_text_message("done")], + ] + ) + + result = await Runner.run(orchestrator, input="test") + tool_call_items = [item for item in result.new_items if isinstance(item, ToolCallItem)] + tool_output_items = [item for item in result.new_items if isinstance(item, ToolCallOutputItem)] + + assert len(tool_call_items) == 1 + assert len(tool_output_items) == 1 + + tool_call_item = tool_call_items[0] + tool_output_item = tool_output_items[0] + + # Verify tool_origin is set + assert tool_call_item.tool_origin is not None + assert tool_call_item.tool_origin.type == ToolOriginType.AGENT_AS_TOOL + assert tool_call_item.tool_origin.agent_as_tool is not None + assert tool_call_item.tool_origin.agent_as_tool.name == "nested_agent" + assert tool_output_item.tool_origin is not None + assert tool_output_item.tool_origin.type == ToolOriginType.AGENT_AS_TOOL + assert tool_output_item.tool_origin.agent_as_tool is not None + assert tool_output_item.tool_origin.agent_as_tool.name == "nested_agent" + + # Serialize and deserialize + context = result.context_wrapper + state = RunState( + context=context, + original_input="test", + starting_agent=orchestrator, + max_turns=5, + ) + state._generated_items = [tool_call_item, tool_output_item] + + json_data = state.to_json() + deserialized_state = await RunState.from_json(orchestrator, json_data) + + # Verify tool_origin was preserved + deserialized_tool_call = next( + item for item in deserialized_state._generated_items if isinstance(item, ToolCallItem) + ) + deserialized_tool_output = next( + item for item in deserialized_state._generated_items if isinstance(item, ToolCallOutputItem) + ) + + assert deserialized_tool_call.tool_origin is not None + assert deserialized_tool_call.tool_origin.type == ToolOriginType.AGENT_AS_TOOL + assert deserialized_tool_call.tool_origin.agent_as_tool is not None + assert deserialized_tool_call.tool_origin.agent_as_tool.name == "nested_agent" + assert deserialized_tool_output.tool_origin is not None + assert deserialized_tool_output.tool_origin.type == ToolOriginType.AGENT_AS_TOOL + assert deserialized_tool_output.tool_origin.agent_as_tool is not None + assert deserialized_tool_output.tool_origin.agent_as_tool.name == "nested_agent" + + +@pytest.mark.asyncio +@pytest.mark.skipif(sys.version_info < (3, 10), reason="MCP tests require Python 3.10+") +async def test_serialize_tool_origin_mcp(): + """Test that MCP tool_origin is serialized and deserialized.""" + model = FakeModel() + server = FakeMCPServer(server_name="test_mcp_server") + server.add_tool("mcp_tool", {}) + + agent = Agent(name="test", model=model, mcp_servers=[server]) + + model.add_multiple_turn_outputs( + [ + [get_function_tool_call("mcp_tool", "")], + [get_text_message("done")], + ] + ) + + result = await Runner.run(agent, input="test") + tool_call_items = [item for item in result.new_items if isinstance(item, ToolCallItem)] + tool_output_items = [item for item in result.new_items if isinstance(item, ToolCallOutputItem)] + + assert len(tool_call_items) == 1 + assert len(tool_output_items) == 1 + + tool_call_item = tool_call_items[0] + tool_output_item = tool_output_items[0] + + # Verify tool_origin is set + assert tool_call_item.tool_origin is not None + assert tool_call_item.tool_origin.type == ToolOriginType.MCP + assert tool_call_item.tool_origin.mcp_server is not None + assert tool_call_item.tool_origin.mcp_server.name == "test_mcp_server" + assert tool_output_item.tool_origin is not None + assert tool_output_item.tool_origin.type == ToolOriginType.MCP + assert tool_output_item.tool_origin.mcp_server is not None + assert tool_output_item.tool_origin.mcp_server.name == "test_mcp_server" + + # Serialize and deserialize + context = result.context_wrapper + state = RunState( + context=context, + original_input="test", + starting_agent=agent, + max_turns=5, + ) + state._generated_items = [tool_call_item, tool_output_item] + + json_data = state.to_json() + deserialized_state = await RunState.from_json(agent, json_data) + + # Verify tool_origin was preserved + deserialized_tool_call = next( + item for item in deserialized_state._generated_items if isinstance(item, ToolCallItem) + ) + deserialized_tool_output = next( + item for item in deserialized_state._generated_items if isinstance(item, ToolCallOutputItem) + ) + + assert deserialized_tool_call.tool_origin is not None + assert deserialized_tool_call.tool_origin.type == ToolOriginType.MCP + # MCP server should be reconstructed from agent's mcp_servers + assert deserialized_tool_call.tool_origin.mcp_server is not None + assert deserialized_tool_call.tool_origin.mcp_server.name == "test_mcp_server" + assert deserialized_tool_output.tool_origin is not None + assert deserialized_tool_output.tool_origin.type == ToolOriginType.MCP + assert deserialized_tool_output.tool_origin.mcp_server is not None + assert deserialized_tool_output.tool_origin.mcp_server.name == "test_mcp_server" From 68c9dde2f5f55f687a5ad510004c152fae8d5169 Mon Sep 17 00:00:00 2001 From: habema Date: Sun, 8 Feb 2026 15:00:32 +0300 Subject: [PATCH 04/11] address code review --- src/agents/run_internal/items.py | 4 +- src/agents/run_internal/tool_execution.py | 2 + src/agents/run_internal/turn_resolution.py | 5 +- tests/test_tool_origin_rejection.py | 205 +++++++++++++++++++++ 4 files changed, 214 insertions(+), 2 deletions(-) create mode 100644 tests/test_tool_origin_rejection.py diff --git a/src/agents/run_internal/items.py b/src/agents/run_internal/items.py index 04e00f598f..015d73afdc 100644 --- a/src/agents/run_internal/items.py +++ b/src/agents/run_internal/items.py @@ -15,7 +15,7 @@ from ..agent_tool_state import drop_agent_tool_run_result from ..items import ItemHelpers, ToolCallOutputItem, TResponseInputItem from ..models.fake_id import FAKE_RESPONSES_ID -from ..tool import DEFAULT_APPROVAL_REJECTION_MESSAGE +from ..tool import DEFAULT_APPROVAL_REJECTION_MESSAGE, ToolOrigin REJECTION_MESSAGE = DEFAULT_APPROVAL_REJECTION_MESSAGE _TOOL_CALL_TO_OUTPUT_TYPE: dict[str, str] = { @@ -191,6 +191,7 @@ def function_rejection_item( tool_call: Any, *, rejection_message: str = REJECTION_MESSAGE, + tool_origin: ToolOrigin | None = None, ) -> ToolCallOutputItem: """Build a ToolCallOutputItem representing a rejected function tool call.""" if isinstance(tool_call, ResponseFunctionToolCall): @@ -199,6 +200,7 @@ def function_rejection_item( output=rejection_message, raw_item=ItemHelpers.tool_call_output_item(tool_call, rejection_message), agent=agent, + tool_origin=tool_origin, ) diff --git a/src/agents/run_internal/tool_execution.py b/src/agents/run_internal/tool_execution.py index a22f9a5cdc..e8a26ff121 100644 --- a/src/agents/run_internal/tool_execution.py +++ b/src/agents/run_internal/tool_execution.py @@ -868,6 +868,7 @@ async def run_single_tool(func_tool: FunctionTool, tool_call: ResponseFunctionTo ) result = rejection_message span_fn.span_data.output = result + tool_origin = _get_tool_origin_info(func_tool) return FunctionToolResult( tool=func_tool, output=result, @@ -875,6 +876,7 @@ async def run_single_tool(func_tool: FunctionTool, tool_call: ResponseFunctionTo agent, tool_call, rejection_message=rejection_message, + tool_origin=tool_origin, ), ) diff --git a/src/agents/run_internal/turn_resolution.py b/src/agents/run_internal/turn_resolution.py index 86872f4d27..ac28ff3a1a 100644 --- a/src/agents/run_internal/turn_resolution.py +++ b/src/agents/run_internal/turn_resolution.py @@ -692,8 +692,11 @@ async def _record_function_rejection( tool_name=function_tool.name, call_id=call_id, ) + tool_origin = _get_tool_origin_info(function_tool) rejected_function_outputs.append( - function_rejection_item(agent, tool_call, rejection_message=rejection_message) + function_rejection_item( + agent, tool_call, rejection_message=rejection_message, tool_origin=tool_origin + ) ) if isinstance(call_id, str): rejected_function_call_ids.add(call_id) diff --git a/tests/test_tool_origin_rejection.py b/tests/test_tool_origin_rejection.py new file mode 100644 index 0000000000..8582e03fdb --- /dev/null +++ b/tests/test_tool_origin_rejection.py @@ -0,0 +1,205 @@ +"""Tests for tool_origin preservation on rejected function tool calls.""" + +from __future__ import annotations + +import sys + +import pytest + +from agents import Agent, function_tool +from agents.items import ToolCallOutputItem +from agents.tool import ToolOriginType + +from .fake_model import FakeModel +from .test_responses import get_function_tool_call, get_text_message +from .utils.hitl import reject_tool_call + +if sys.version_info >= (3, 10): + from .mcp.helpers import FakeMCPServer + + +@pytest.mark.asyncio +async def test_rejected_function_tool_preserves_tool_origin(): + """Test that rejected function tools preserve tool_origin.""" + model = FakeModel() + + @function_tool + def test_tool(x: int) -> str: + """Test tool.""" + return f"result: {x}" + + agent = Agent(name="test", model=model, tools=[test_tool]) + + model.add_multiple_turn_outputs( + [ + [get_function_tool_call("test_tool", '{"x": 42}')], + [get_text_message("done")], + ] + ) + + # Pre-reject the tool call + tool_call = get_function_tool_call("test_tool", '{"x": 42}') + from openai.types.responses import ResponseFunctionToolCall + + from agents.lifecycle import RunHooks + from agents.run_config import RunConfig + from agents.run_context import RunContextWrapper + from agents.run_internal.run_steps import ToolRunFunction + from agents.run_internal.tool_execution import execute_function_tool_calls + + context: RunContextWrapper[dict[str, str]] = RunContextWrapper(context={}) + assert isinstance(tool_call, ResponseFunctionToolCall) + reject_tool_call(context, agent, tool_call, "test_tool") + + # Execute the tool call which should be rejected + tool_run = ToolRunFunction(tool_call=tool_call, function_tool=test_tool) + results, _, _ = await execute_function_tool_calls( + agent=agent, + tool_runs=[tool_run], + hooks=RunHooks(), + context_wrapper=context, + config=RunConfig(), + ) + + # Should have a rejection result + assert len(results) == 1 + result = results[0] + assert result.run_item is not None + assert isinstance(result.run_item, ToolCallOutputItem) + + # Verify tool_origin is preserved on rejection + assert result.run_item.tool_origin is not None + assert result.run_item.tool_origin.type == ToolOriginType.FUNCTION + + +@pytest.mark.asyncio +async def test_rejected_agent_as_tool_preserves_tool_origin(): + """Test that rejected agent-as-tool preserves tool_origin.""" + model = FakeModel() + nested_model = FakeModel() + + nested_agent = Agent( + name="nested_agent", + model=nested_model, + instructions="You are a nested agent.", + ) + + nested_model.add_multiple_turn_outputs([[get_text_message("nested response")]]) + + tool = nested_agent.as_tool( + tool_name="nested_tool", + tool_description="A nested agent tool", + ) + + orchestrator = Agent(name="orchestrator", model=model, tools=[tool]) + + model.add_multiple_turn_outputs( + [ + [get_function_tool_call("nested_tool", '{"input": "test"}')], + [get_text_message("done")], + ] + ) + + # Pre-reject the tool call + tool_call = get_function_tool_call("nested_tool", '{"input": "test"}') + from openai.types.responses import ResponseFunctionToolCall + + from agents.lifecycle import RunHooks + from agents.run_config import RunConfig + from agents.run_context import RunContextWrapper + from agents.run_internal.run_steps import ToolRunFunction + from agents.run_internal.tool_execution import execute_function_tool_calls + from agents.tool import FunctionTool + + context: RunContextWrapper[dict[str, str]] = RunContextWrapper(context={}) + assert isinstance(tool_call, ResponseFunctionToolCall) + assert isinstance(tool, FunctionTool) + reject_tool_call(context, orchestrator, tool_call, "nested_tool") + + # Execute the tool call which should be rejected + tool_run = ToolRunFunction(tool_call=tool_call, function_tool=tool) + results, _, _ = await execute_function_tool_calls( + agent=orchestrator, + tool_runs=[tool_run], + hooks=RunHooks(), + context_wrapper=context, + config=RunConfig(), + ) + + # Should have a rejection result + assert len(results) == 1 + result = results[0] + assert result.run_item is not None + assert isinstance(result.run_item, ToolCallOutputItem) + + # Verify tool_origin is preserved on rejection + assert result.run_item.tool_origin is not None + assert result.run_item.tool_origin.type == ToolOriginType.AGENT_AS_TOOL + assert result.run_item.tool_origin.agent_as_tool is not None + assert result.run_item.tool_origin.agent_as_tool.name == "nested_agent" + + +@pytest.mark.asyncio +@pytest.mark.skipif(sys.version_info < (3, 10), reason="MCP tests require Python 3.10+") +async def test_rejected_mcp_tool_preserves_tool_origin(): + """Test that rejected MCP tools preserve tool_origin.""" + model = FakeModel() + server = FakeMCPServer(server_name="test_mcp_server") + server.add_tool("mcp_tool", {}) + + agent = Agent(name="test", model=model, mcp_servers=[server]) + + model.add_multiple_turn_outputs( + [ + [get_function_tool_call("mcp_tool", "")], + [get_text_message("done")], + ] + ) + + # Pre-reject the tool call + tool_call = get_function_tool_call("mcp_tool", "") + from openai.types.responses import ResponseFunctionToolCall + + from agents.lifecycle import RunHooks + from agents.mcp import MCPUtil + from agents.run_config import RunConfig + from agents.run_context import RunContextWrapper + from agents.run_internal.run_steps import ToolRunFunction + from agents.run_internal.tool_execution import execute_function_tool_calls + from agents.tool import FunctionTool + + context: RunContextWrapper[dict[str, str]] = RunContextWrapper(context={}) + assert isinstance(tool_call, ResponseFunctionToolCall) + reject_tool_call(context, agent, tool_call, "mcp_tool") + + # Get the MCP tool as FunctionTool + mcp_tools = await MCPUtil.get_all_function_tools( + agent.mcp_servers, + convert_schemas_to_strict=False, + run_context=context, + agent=agent, + ) + mcp_tool = next(tool for tool in mcp_tools if tool.name == "mcp_tool") + assert isinstance(mcp_tool, FunctionTool) + + # Execute the tool call which should be rejected + tool_run = ToolRunFunction(tool_call=tool_call, function_tool=mcp_tool) + results, _, _ = await execute_function_tool_calls( + agent=agent, + tool_runs=[tool_run], + hooks=RunHooks(), + context_wrapper=context, + config=RunConfig(), + ) + + # Should have a rejection result + assert len(results) == 1 + result = results[0] + assert result.run_item is not None + assert isinstance(result.run_item, ToolCallOutputItem) + + # Verify tool_origin is preserved on rejection + assert result.run_item.tool_origin is not None + assert result.run_item.tool_origin.type == ToolOriginType.MCP + assert result.run_item.tool_origin.mcp_server is not None + assert result.run_item.tool_origin.mcp_server.name == "test_mcp_server" From c2ea42d1a27e6a92720994b9f501a4e93160ac29 Mon Sep 17 00:00:00 2001 From: habema Date: Sun, 8 Feb 2026 15:13:16 +0300 Subject: [PATCH 05/11] address code review --- src/agents/run_internal/turn_resolution.py | 12 +++- tests/test_tool_origin_output_schema.py | 69 ++++++++++++++++++++++ 2 files changed, 79 insertions(+), 2 deletions(-) create mode 100644 tests/test_tool_origin_output_schema.py diff --git a/src/agents/run_internal/turn_resolution.py b/src/agents/run_internal/turn_resolution.py index ac28ff3a1a..806a112361 100644 --- a/src/agents/run_internal/turn_resolution.py +++ b/src/agents/run_internal/turn_resolution.py @@ -1459,11 +1459,19 @@ def process_model_response( else: if output.name not in function_map: if output_schema is not None and output.name == "json_tool_call": - items.append(ToolCallItem(raw_item=output, agent=agent)) + json_tool = build_litellm_json_tool_call(output) + tool_origin = _get_tool_origin_info(json_tool) + items.append( + ToolCallItem( + raw_item=output, + agent=agent, + tool_origin=tool_origin, + ) + ) functions.append( ToolRunFunction( tool_call=output, - function_tool=build_litellm_json_tool_call(output), + function_tool=json_tool, ) ) continue diff --git a/tests/test_tool_origin_output_schema.py b/tests/test_tool_origin_output_schema.py new file mode 100644 index 0000000000..3b7144a650 --- /dev/null +++ b/tests/test_tool_origin_output_schema.py @@ -0,0 +1,69 @@ +"""Tests for tool_origin with output_schema json_tool_call.""" + +from __future__ import annotations + +from pydantic import BaseModel + +from agents import Agent +from agents.agent_output import AgentOutputSchema +from agents.items import ModelResponse, ToolCallItem +from agents.run_internal.turn_resolution import process_model_response +from agents.tool import ToolOriginType +from agents.usage import Usage + +from .test_responses import get_function_tool_call + + +class OutputSchema(BaseModel): + """Test output schema.""" + + result: str + + +def test_output_schema_json_tool_call_has_tool_origin(): + """Test that json_tool_call ToolCallItem has tool_origin when output_schema is enabled.""" + agent = Agent(name="test", output_type=OutputSchema) + + # Get the output_schema + from agents.run_internal.run_loop import get_output_schema + + output_schema = get_output_schema(agent) + assert output_schema is not None + assert isinstance(output_schema, AgentOutputSchema) + + # Simulate a json_tool_call response + json_output = OutputSchema(result="test").model_dump_json() + json_tool_call = get_function_tool_call("json_tool_call", json_output) + + response = ModelResponse( + output=[json_tool_call], + usage=Usage(), + response_id=None, + ) + + # Process the response + processed = process_model_response( + agent=agent, + all_tools=[], + response=response, + output_schema=output_schema, + handoffs=[], + ) + + # Find the json_tool_call item + json_tool_call_item = next( + item + for item in processed.new_items + if isinstance(item, ToolCallItem) + and hasattr(item.raw_item, "name") + and item.raw_item.name == "json_tool_call" + ) + + # Verify tool_origin is set on ToolCallItem + assert json_tool_call_item.tool_origin is not None + assert json_tool_call_item.tool_origin.type == ToolOriginType.FUNCTION + + # Verify that a ToolRunFunction was created for execution + assert len(processed.functions) == 1 + function_run = processed.functions[0] + assert function_run.function_tool.name == "json_tool_call" From 6825a7453d34f2a9bb1210f94719098eddeded3d Mon Sep 17 00:00:00 2001 From: habema Date: Sun, 8 Feb 2026 15:38:17 +0300 Subject: [PATCH 06/11] address code review --- src/agents/run_internal/run_loop.py | 17 ++++++++++- tests/test_tool_origin_output_schema.py | 40 +++++++++++++++++++++++-- 2 files changed, 54 insertions(+), 3 deletions(-) diff --git a/src/agents/run_internal/run_loop.py b/src/agents/run_internal/run_loop.py index 4404868ed8..0116e04d14 100644 --- a/src/agents/run_internal/run_loop.py +++ b/src/agents/run_internal/run_loop.py @@ -10,7 +10,11 @@ from collections.abc import Awaitable, Callable from typing import Any, TypeVar, cast -from openai.types.responses import ResponseCompletedEvent, ResponseOutputItemDoneEvent +from openai.types.responses import ( + ResponseCompletedEvent, + ResponseFunctionToolCall, + ResponseOutputItemDoneEvent, +) from openai.types.responses.response_prompt_param import ResponsePromptParam from openai.types.responses.response_reasoning_item import ResponseReasoningItem @@ -113,6 +117,7 @@ from .streaming import stream_step_items_to_queue, stream_step_result_to_queue from .tool_actions import ApplyPatchAction, ComputerAction, LocalShellAction, ShellAction from .tool_execution import ( + build_litellm_json_tool_call, coerce_shell_call, execute_apply_patch_calls, execute_computer_actions, @@ -1222,6 +1227,16 @@ async def run_single_turn_streamed( tool_description = getattr(tool, "description", None) if isinstance(tool, FunctionTool): tool_origin = _get_tool_origin_info(tool) + elif ( + isinstance(tool_name, str) + and tool_name == "json_tool_call" + and output_schema is not None + and isinstance(output_item, ResponseFunctionToolCall) + ): + # json_tool_call is synthesized dynamically and not in tool_map. + # Synthesize it here to get tool_origin, matching process_model_response. + json_tool = build_litellm_json_tool_call(output_item) + tool_origin = _get_tool_origin_info(json_tool) tool_item = ToolCallItem( raw_item=cast(ToolCallItemTypes, output_item), diff --git a/tests/test_tool_origin_output_schema.py b/tests/test_tool_origin_output_schema.py index 3b7144a650..8556c9e9f3 100644 --- a/tests/test_tool_origin_output_schema.py +++ b/tests/test_tool_origin_output_schema.py @@ -2,16 +2,18 @@ from __future__ import annotations +import pytest from pydantic import BaseModel -from agents import Agent +from agents import Agent, Runner from agents.agent_output import AgentOutputSchema from agents.items import ModelResponse, ToolCallItem from agents.run_internal.turn_resolution import process_model_response from agents.tool import ToolOriginType from agents.usage import Usage -from .test_responses import get_function_tool_call +from .fake_model import FakeModel +from .test_responses import get_final_output_message, get_function_tool_call class OutputSchema(BaseModel): @@ -67,3 +69,37 @@ def test_output_schema_json_tool_call_has_tool_origin(): assert len(processed.functions) == 1 function_run = processed.functions[0] assert function_run.function_tool.name == "json_tool_call" + + +@pytest.mark.asyncio +async def test_output_schema_json_tool_call_streaming_has_tool_origin(): + """ + Test that streamed json_tool_call ToolCallItem has tool_origin when output_schema is enabled. + """ + model = FakeModel() + agent = Agent(name="test", model=model, output_type=OutputSchema) + + # Simulate a json_tool_call response followed by completion + json_output = OutputSchema(result="test").model_dump_json() + json_tool_call = get_function_tool_call("json_tool_call", json_output) + final_output = get_final_output_message(json_output) + model.add_multiple_turn_outputs([[json_tool_call], [final_output]]) + + # Collect streamed events + streamed_tool_call_items: list[ToolCallItem] = [] + + result = Runner.run_streamed(agent, input="test") + async for event in result.stream_events(): + if event.type == "run_item_stream_event" and isinstance(event.item, ToolCallItem): + streamed_tool_call_items.append(event.item) + + # Find the json_tool_call item + json_tool_call_item = next( + item + for item in streamed_tool_call_items + if hasattr(item.raw_item, "name") and item.raw_item.name == "json_tool_call" + ) + + # Verify tool_origin is set on streamed ToolCallItem + assert json_tool_call_item.tool_origin is not None + assert json_tool_call_item.tool_origin.type == ToolOriginType.FUNCTION From f850af4c7b174ea88794acacbb8e1288368e8d81 Mon Sep 17 00:00:00 2001 From: habema Date: Sun, 8 Feb 2026 15:43:59 +0300 Subject: [PATCH 07/11] export ToolOrigin and ToolOriginType --- src/agents/__init__.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/agents/__init__.py b/src/agents/__init__.py index c4f1de30f2..6393b45b0e 100644 --- a/src/agents/__init__.py +++ b/src/agents/__init__.py @@ -126,6 +126,8 @@ ShellResult, ShellTool, Tool, + ToolOrigin, + ToolOriginType, ToolOutputFileContent, ToolOutputFileContentDict, ToolOutputImage, @@ -359,6 +361,8 @@ def enable_verbose_stdout_logging(): "ApplyPatchResult", "ApplyPatchTool", "Tool", + "ToolOrigin", + "ToolOriginType", "WebSearchTool", "HostedMCPTool", "MCPToolApprovalFunction", From 8776b022953e89467d6eb4c5e14353c49b4d2ca3 Mon Sep 17 00:00:00 2001 From: habema Date: Sun, 1 Mar 2026 14:37:13 +0300 Subject: [PATCH 08/11] address code review --- src/agents/run_internal/run_loop.py | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/src/agents/run_internal/run_loop.py b/src/agents/run_internal/run_loop.py index 8b44529480..2bcd2d7d1d 100644 --- a/src/agents/run_internal/run_loop.py +++ b/src/agents/run_internal/run_loop.py @@ -1260,8 +1260,23 @@ async def run_single_turn_streamed( if isinstance(tool_name, str) and tool_name in tool_map: tool = tool_map[tool_name] tool_description = getattr(tool, "description", None) - if isinstance(tool, FunctionTool): - tool_origin = _get_tool_origin_info(tool) + # Resolve FunctionTool for tool_origin; tool_map may have non-FunctionTool + # due to name collision. + func_tool = ( + tool + if isinstance(tool, FunctionTool) + else next( + ( + t + for t in all_tools + if isinstance(t, FunctionTool) + and getattr(t, "name", None) == tool_name + ), + None, + ) + ) + if func_tool is not None: + tool_origin = _get_tool_origin_info(func_tool) elif ( isinstance(tool_name, str) and tool_name == "json_tool_call" From e96186113bb72b13d16b07d0305d6f7a4130c965 Mon Sep 17 00:00:00 2001 From: habema Date: Sun, 1 Mar 2026 15:00:39 +0300 Subject: [PATCH 09/11] update CURRENT_SCHEMA_VERSION to 1.5 and include it in SUPPORTED_SCHEMA_VERSIONS --- src/agents/run_state.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/agents/run_state.py b/src/agents/run_state.py index 2cea48d750..8a7c02cf17 100644 --- a/src/agents/run_state.py +++ b/src/agents/run_state.py @@ -102,8 +102,8 @@ # 2. Keep older readable versions in SUPPORTED_SCHEMA_VERSIONS for backward reads. # 3. to_json() always emits CURRENT_SCHEMA_VERSION. # 4. Forward compatibility is intentionally fail-fast (older SDKs reject newer versions). -CURRENT_SCHEMA_VERSION = "1.4" -SUPPORTED_SCHEMA_VERSIONS = frozenset({"1.0", "1.1", "1.2", "1.3", CURRENT_SCHEMA_VERSION}) +CURRENT_SCHEMA_VERSION = "1.5" +SUPPORTED_SCHEMA_VERSIONS = frozenset({"1.0", "1.1", "1.2", "1.3", "1.4", CURRENT_SCHEMA_VERSION}) _FUNCTION_OUTPUT_ADAPTER: TypeAdapter[FunctionCallOutput] = TypeAdapter(FunctionCallOutput) _COMPUTER_OUTPUT_ADAPTER: TypeAdapter[ComputerCallOutput] = TypeAdapter(ComputerCallOutput) From d649f0679abe675d8f83d680a75c16b308bdd4cf Mon Sep 17 00:00:00 2001 From: habema Date: Sun, 1 Mar 2026 15:13:11 +0300 Subject: [PATCH 10/11] address code review --- src/agents/run_internal/run_loop.py | 20 +++++--------------- 1 file changed, 5 insertions(+), 15 deletions(-) diff --git a/src/agents/run_internal/run_loop.py b/src/agents/run_internal/run_loop.py index 2bcd2d7d1d..dc90a336fa 100644 --- a/src/agents/run_internal/run_loop.py +++ b/src/agents/run_internal/run_loop.py @@ -1069,6 +1069,8 @@ async def run_single_turn_streamed( # execution in process_model_response, so duplicate names (e.g., MCP + local tool) # stream the same description that execution uses. tool_map = {t.name: t for t in all_tools if hasattr(t, "name") and t.name} + # FunctionTool-only map for tool_origin; matches process_model_response's last-wins. + function_map = {t.name: t for t in all_tools if isinstance(t, FunctionTool)} try: turn_input = ItemHelpers.input_to_new_input_list(streamed_result.input) @@ -1260,21 +1262,9 @@ async def run_single_turn_streamed( if isinstance(tool_name, str) and tool_name in tool_map: tool = tool_map[tool_name] tool_description = getattr(tool, "description", None) - # Resolve FunctionTool for tool_origin; tool_map may have non-FunctionTool - # due to name collision. - func_tool = ( - tool - if isinstance(tool, FunctionTool) - else next( - ( - t - for t in all_tools - if isinstance(t, FunctionTool) - and getattr(t, "name", None) == tool_name - ), - None, - ) - ) + # Use function_map for tool_origin to match process_model_response's + # last-wins semantics when multiple FunctionTools share a name. + func_tool = function_map.get(tool_name) if func_tool is not None: tool_origin = _get_tool_origin_info(func_tool) elif ( From 2eef12731cb3ba970c600a6760932e18e78f1d22 Mon Sep 17 00:00:00 2001 From: habema Date: Mon, 2 Mar 2026 02:09:49 +0300 Subject: [PATCH 11/11] Code Review MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 1. **Agent-as-tool identity** – Switched from `agent.name` to `agent_index` in serialization so duplicate agent names are handled correctly. 2. **`_build_agent_map`** – Now returns `(agent_map, agent_list)` with `id()`-based visited tracking. 3. **`ToolApprovalItem`** – Added `tool_origin` and `release_agent` for provenance on pending approvals. 4. **RunState** – Added `_starting_agent` and optional `root_agent` for correct agent graph traversal during serialization. --- src/agents/items.py | 9 + src/agents/realtime/session.py | 9 +- src/agents/run_internal/tool_execution.py | 6 +- src/agents/run_internal/tool_use_tracker.py | 2 +- src/agents/run_internal/turn_resolution.py | 7 +- src/agents/run_state.py | 182 +++++++++++++++----- tests/test_run_state.py | 40 +++-- 7 files changed, 189 insertions(+), 66 deletions(-) diff --git a/src/agents/items.py b/src/agents/items.py index e35d5e7e49..5f5eb2617c 100644 --- a/src/agents/items.py +++ b/src/agents/items.py @@ -394,8 +394,17 @@ class ToolApprovalItem(RunItemBase[Any]): tool_name: str | None = None """Tool name for approval tracking; falls back to raw_item.name when absent.""" + tool_origin: ToolOrigin | None = field(default=None, repr=False) + """Information about the origin/source of the tool. Only set for FunctionTool calls.""" + type: Literal["tool_approval_item"] = "tool_approval_item" + def release_agent(self) -> None: + """Release agent references including tool_origin.agent_as_tool.""" + super().release_agent() + if self.tool_origin is not None: + self.tool_origin.release_agent() + def __post_init__(self) -> None: """Populate tool_name from the raw item if not provided.""" if self.tool_name is None: diff --git a/src/agents/realtime/session.py b/src/agents/realtime/session.py index 6448c17bfc..f3ebc809eb 100644 --- a/src/agents/realtime/session.py +++ b/src/agents/realtime/session.py @@ -429,7 +429,14 @@ def _build_tool_approval_item( "call_id": tool_call.call_id, "arguments": tool_call.arguments, } - return ToolApprovalItem(agent=cast(Any, agent), raw_item=raw_item, tool_name=tool.name) + from ..tool import _get_tool_origin_info + + return ToolApprovalItem( + agent=cast(Any, agent), + raw_item=raw_item, + tool_name=tool.name, + tool_origin=_get_tool_origin_info(tool), + ) async def _maybe_request_tool_approval( self, diff --git a/src/agents/run_internal/tool_execution.py b/src/agents/run_internal/tool_execution.py index 224a19478e..377bf2090e 100644 --- a/src/agents/run_internal/tool_execution.py +++ b/src/agents/run_internal/tool_execution.py @@ -873,8 +873,12 @@ async def run_single_tool(func_tool: FunctionTool, tool_call: ResponseFunctionTo ) if approval_status is None: + tool_origin = _get_tool_origin_info(func_tool) approval_item = ToolApprovalItem( - agent=agent, raw_item=tool_call, tool_name=func_tool.name + agent=agent, + raw_item=tool_call, + tool_name=func_tool.name, + tool_origin=tool_origin, ) return FunctionToolResult( tool=func_tool, output=None, run_item=approval_item diff --git a/src/agents/run_internal/tool_use_tracker.py b/src/agents/run_internal/tool_use_tracker.py index 33e2d72156..aa8c35645b 100644 --- a/src/agents/run_internal/tool_use_tracker.py +++ b/src/agents/run_internal/tool_use_tracker.py @@ -85,7 +85,7 @@ def hydrate_tool_use_tracker( if not snapshot: return - agent_map = _build_agent_map(starting_agent) + agent_map, _ = _build_agent_map(starting_agent) for agent_name, tool_names in snapshot.items(): agent = agent_map.get(agent_name) if agent is None: diff --git a/src/agents/run_internal/turn_resolution.py b/src/agents/run_internal/turn_resolution.py index e73d56aa11..7e14702d52 100644 --- a/src/agents/run_internal/turn_resolution.py +++ b/src/agents/run_internal/turn_resolution.py @@ -982,7 +982,12 @@ def _add_unmatched_pending(approval: ToolApprovalItem) -> None: output_exists_checker=_function_output_exists, record_rejection=_record_function_rejection, pending_interruption_adder=_add_pending_interruption, - pending_item_builder=lambda run: ToolApprovalItem(agent=agent, raw_item=run.tool_call), + pending_item_builder=lambda run: ToolApprovalItem( + agent=agent, + raw_item=run.tool_call, + tool_name=run.function_tool.name, + tool_origin=_get_tool_origin_info(run.function_tool), + ), ) rebuilt_function_tool_runs = await _rebuild_function_runs_from_approvals() diff --git a/src/agents/run_state.py b/src/agents/run_state.py index 8a7c02cf17..072ddb9858 100644 --- a/src/agents/run_state.py +++ b/src/agents/run_state.py @@ -209,11 +209,13 @@ def __init__( conversation_id: str | None = None, previous_response_id: str | None = None, auto_previous_response_id: bool = False, + root_agent: TAgent | None = None, ): """Initialize a new RunState.""" self._context = context self._original_input = _clone_original_input(original_input) self._current_agent = starting_agent + self._starting_agent = root_agent if root_agent is not None else starting_agent self._max_turns = max_turns self._conversation_id = conversation_id self._previous_response_id = previous_response_id @@ -532,6 +534,8 @@ def to_json( if self._context is None: raise UserError("Cannot serialize RunState: No context") + _, agent_list = _build_agent_map(self._starting_agent) + approvals_dict = self._serialize_approvals() model_responses = self._serialize_model_responses() original_input_serialized = self._serialize_original_input() @@ -578,13 +582,18 @@ def to_json( } generated_items = self._merge_generated_items_with_processed() - result["generated_items"] = [self._serialize_item(item) for item in generated_items] - result["session_items"] = [self._serialize_item(item) for item in list(self._session_items)] - result["current_step"] = self._serialize_current_step() + result["generated_items"] = [ + self._serialize_item(item, agent_list) for item in generated_items + ] + result["session_items"] = [ + self._serialize_item(item, agent_list) for item in list(self._session_items) + ] + result["current_step"] = self._serialize_current_step(agent_list) result["last_model_response"] = _serialize_last_model_response(model_responses) result["last_processed_response"] = ( self._serialize_processed_response( self._last_processed_response, + agent_list, context_serializer=context_serializer, strict_context=strict_context, include_tracing_api_key=include_tracing_api_key, @@ -602,6 +611,7 @@ def to_json( def _serialize_processed_response( self, processed_response: ProcessedResponse, + agent_list: list[Any], *, context_serializer: ContextSerializer | None = None, strict_context: bool = False, @@ -611,6 +621,7 @@ def _serialize_processed_response( Args: processed_response: The ProcessedResponse to serialize. + agent_list: List of agents in traversal order for tool_origin serialization. Returns: A dictionary representation of the ProcessedResponse. @@ -628,19 +639,25 @@ def _serialize_processed_response( ) interruptions_data = [ - _serialize_tool_approval_interruption(interruption, include_tool_name=True) + _serialize_tool_approval_interruption( + interruption, + include_tool_name=True, + agent_list=agent_list, + ) for interruption in processed_response.interruptions if isinstance(interruption, ToolApprovalItem) ] return { - "new_items": [self._serialize_item(item) for item in processed_response.new_items], + "new_items": [ + self._serialize_item(item, agent_list) for item in processed_response.new_items + ], "tools_used": processed_response.tools_used, **action_groups, "interruptions": interruptions_data, } - def _serialize_current_step(self) -> dict[str, Any] | None: + def _serialize_current_step(self, agent_list: list[Any] | None = None) -> dict[str, Any] | None: """Serialize the current step if it's an interruption.""" # Import at runtime to avoid circular import from .run_internal.run_steps import NextStepInterruption @@ -650,7 +667,9 @@ def _serialize_current_step(self) -> dict[str, Any] | None: interruptions_data = [ _serialize_tool_approval_interruption( - item, include_tool_name=item.tool_name is not None + item, + include_tool_name=item.tool_name is not None, + agent_list=agent_list, ) for item in self._current_step.interruptions if isinstance(item, ToolApprovalItem) @@ -663,7 +682,7 @@ def _serialize_current_step(self) -> dict[str, Any] | None: }, } - def _serialize_item(self, item: RunItem) -> dict[str, Any]: + def _serialize_item(self, item: RunItem, agent_list: list[Any] | None = None) -> dict[str, Any]: """Serialize a run item to JSON-compatible dict.""" raw_item_dict: Any = _serialize_raw_item_value(item.raw_item) @@ -694,12 +713,9 @@ def _serialize_item(self, item: RunItem) -> dict[str, Any]: if hasattr(item, "description") and item.description is not None: result["description"] = item.description if hasattr(item, "tool_origin") and item.tool_origin is not None: - tool_origin_data: dict[str, Any] = {"type": item.tool_origin.type.value} - if item.tool_origin.agent_as_tool is not None: - tool_origin_data["agent_as_tool"] = {"name": item.tool_origin.agent_as_tool.name} - if item.tool_origin.mcp_server is not None: - tool_origin_data["mcp_server"] = {"name": item.tool_origin.mcp_server.name} - result["tool_origin"] = tool_origin_data + tool_origin_data = _serialize_tool_origin(item.tool_origin, agent_list) + if tool_origin_data: + result["tool_origin"] = tool_origin_data return result @@ -1120,8 +1136,33 @@ def _serialize_mcp_tool(mcp_tool: Any) -> dict[str, Any]: return {"value": normalized} +def _serialize_tool_origin( + tool_origin: ToolOrigin, agent_list: list[Any] | None = None +) -> dict[str, Any] | None: + """Serialize ToolOrigin to JSON. Uses agent_index when agent_list provided.""" + if tool_origin is None: + return None + result: dict[str, Any] = {"type": tool_origin.type.value} + if tool_origin.agent_as_tool is not None: + agent = tool_origin.agent_as_tool + if agent_list: + try: + idx = next(i for i, a in enumerate(agent_list) if a is agent) + result["agent_as_tool"] = {"agent_index": idx} + except StopIteration: + result["agent_as_tool"] = {"name": agent.name} + else: + result["agent_as_tool"] = {"name": agent.name} + if tool_origin.mcp_server is not None: + result["mcp_server"] = {"name": tool_origin.mcp_server.name} + return result + + def _serialize_tool_approval_interruption( - interruption: ToolApprovalItem, *, include_tool_name: bool + interruption: ToolApprovalItem, + *, + include_tool_name: bool, + agent_list: list[Any] | None = None, ) -> dict[str, Any]: """Serialize a ToolApprovalItem interruption.""" interruption_dict: dict[str, Any] = { @@ -1131,6 +1172,10 @@ def _serialize_tool_approval_interruption( } if include_tool_name and interruption.tool_name is not None: interruption_dict["tool_name"] = interruption.tool_name + if hasattr(interruption, "tool_origin") and interruption.tool_origin is not None: + tool_origin_data = _serialize_tool_origin(interruption.tool_origin, agent_list) + if tool_origin_data: + interruption_dict["tool_origin"] = tool_origin_data return interruption_dict @@ -1418,6 +1463,7 @@ async def _deserialize_processed_response( current_agent: Agent[Any], context: RunContextWrapper[Any], agent_map: dict[str, Agent[Any]], + agent_list: list[Agent[Any]], *, scope_id: str | None = None, context_deserializer: ContextDeserializer | None = None, @@ -1430,11 +1476,14 @@ async def _deserialize_processed_response( current_agent: The current agent (used to get tools and handoffs). context: The run context wrapper. agent_map: Map of agent names to agents. + agent_list: List of agents in traversal order for agent_index lookup. Returns: A reconstructed ProcessedResponse instance. """ - new_items = _deserialize_items(processed_response_data.get("new_items", []), agent_map) + new_items = _deserialize_items( + processed_response_data.get("new_items", []), agent_map, agent_list + ) if hasattr(current_agent, "get_all_tools"): all_tools = await current_agent.get_all_tools(context) @@ -1713,6 +1762,7 @@ def _deserialize_tool_approval_item( item_data: Mapping[str, Any], *, agent_map: Mapping[str, Agent[Any]], + agent_list: list[Agent[Any]] | None = None, fallback_agent: Agent[Any] | None = None, pre_normalized_raw_item: Any | None = None, ) -> ToolApprovalItem | None: @@ -1729,7 +1779,12 @@ def _deserialize_tool_approval_item( tool_name = item_data.get("tool_name") raw_item = _deserialize_tool_approval_raw_item(raw_item_data) - return ToolApprovalItem(agent=agent, raw_item=raw_item, tool_name=tool_name) + tool_origin = _deserialize_tool_origin( + item_data.get("tool_origin"), agent_map, agent, agent_list + ) + return ToolApprovalItem( + agent=agent, raw_item=raw_item, tool_name=tool_name, tool_origin=tool_origin + ) def _deserialize_tool_call_output_raw_item( @@ -1956,7 +2011,7 @@ async def _build_run_state_from_json( f"New snapshots are written as version {CURRENT_SCHEMA_VERSION}." ) - agent_map = _build_agent_map(initial_agent) + agent_map, agent_list = _build_agent_map(initial_agent) current_agent_name = state_json["current_agent"]["name"] current_agent = agent_map.get(current_agent_name) @@ -2039,6 +2094,7 @@ async def _build_run_state_from_json( conversation_id=state_json.get("conversation_id"), previous_response_id=state_json.get("previous_response_id"), auto_previous_response_id=bool(state_json.get("auto_previous_response_id", False)), + root_agent=initial_agent, ) from .agent_tool_state import set_agent_tool_state_scope @@ -2047,7 +2103,9 @@ async def _build_run_state_from_json( state._current_turn = state_json["current_turn"] state._model_responses = _deserialize_model_responses(state_json.get("model_responses", [])) - state._generated_items = _deserialize_items(state_json.get("generated_items", []), agent_map) + state._generated_items = _deserialize_items( + state_json.get("generated_items", []), agent_map, agent_list + ) last_processed_response_data = state_json.get("last_processed_response") if last_processed_response_data and state._context is not None: @@ -2056,6 +2114,7 @@ async def _build_run_state_from_json( current_agent, state._context, agent_map, + agent_list, scope_id=state._agent_tool_state_scope_id, context_deserializer=context_deserializer, strict_context=strict_context, @@ -2064,7 +2123,9 @@ async def _build_run_state_from_json( state._last_processed_response = None if "session_items" in state_json: - state._session_items = _deserialize_items(state_json.get("session_items", []), agent_map) + state._session_items = _deserialize_items( + state_json.get("session_items", []), agent_map, agent_list + ) else: state._session_items = state._merge_generated_items_with_processed() @@ -2090,7 +2151,9 @@ async def _build_run_state_from_json( "interruptions", current_step_data.get("interruptions", []) ) for item_data in interruptions_data: - approval_item = _deserialize_tool_approval_item(item_data, agent_map=agent_map) + approval_item = _deserialize_tool_approval_item( + item_data, agent_map=agent_map, agent_list=agent_list + ) if approval_item is not None: interruptions.append(approval_item) @@ -2118,23 +2181,35 @@ async def _build_run_state_from_json( return state -def _build_agent_map(initial_agent: Agent[Any]) -> dict[str, Agent[Any]]: - """Build a map of agent names to agents by traversing handoffs. +def _build_agent_map( + initial_agent: Agent[Any], +) -> tuple[dict[str, Agent[Any]], list[Agent[Any]]]: + """Build agent map and agent list by traversing handoffs and tools. + + Uses id(agent) for visited tracking so duplicate-name agents are included in + agent_list. agent_map uses name as key (first wins) for backward compatibility. + agent_list preserves order for stable agent_index in tool_origin serialization. Args: initial_agent: The starting agent. Returns: - Dictionary mapping agent names to agent instances. + Tuple of (agent_map, agent_list). agent_map maps name -> agent (first wins). + agent_list is all agents in BFS order, including duplicates by name. """ agent_map: dict[str, Agent[Any]] = {} + agent_list: list[Agent[Any]] = [] + visited: set[int] = set() queue: deque[Agent[Any]] = deque([initial_agent]) while queue: current = queue.popleft() - if current.name in agent_map: + if id(current) in visited: continue - agent_map[current.name] = current + visited.add(id(current)) + agent_list.append(current) + if current.name not in agent_map: + agent_map[current.name] = current # Add handoff agents to the queue for handoff_item in current.handoffs: @@ -2142,21 +2217,15 @@ def _build_agent_map(initial_agent: Agent[Any]) -> dict[str, Agent[Any]]: handoff_agent_name: str | None = None if isinstance(handoff_item, Handoff): - # Some custom/mocked Handoff subclasses bypass dataclass initialization. - # Prefer agent_name, then legacy name fallback used in tests. candidate_name = getattr(handoff_item, "agent_name", None) or getattr( handoff_item, "name", None ) if isinstance(candidate_name, str): handoff_agent_name = candidate_name - if handoff_agent_name in agent_map: - continue handoff_ref = getattr(handoff_item, "_agent_ref", None) handoff_agent = handoff_ref() if callable(handoff_ref) else None if handoff_agent is None: - # Backward-compatibility fallback for custom legacy handoff objects that store - # the target directly on `.agent`. New code should prefer `handoff()` objects. legacy_agent = getattr(handoff_item, "agent", None) if legacy_agent is not None: handoff_agent = legacy_agent @@ -2164,7 +2233,7 @@ def _build_agent_map(initial_agent: Agent[Any]) -> dict[str, Agent[Any]]: "Using legacy handoff `.agent` fallback while building agent map. " "This compatibility path is not recommended for new code." ) - if handoff_agent_name is None: + if handoff_agent_name is None and handoff_agent is not None: candidate_name = getattr(handoff_agent, "name", None) handoff_agent_name = candidate_name if isinstance(candidate_name, str) else None if handoff_agent is None or not hasattr(handoff_agent, "handoffs"): @@ -2175,8 +2244,6 @@ def _build_agent_map(initial_agent: Agent[Any]) -> dict[str, Agent[Any]]: ) continue else: - # Backward-compatibility fallback for custom legacy handoff wrappers that expose - # the target directly on `.agent` without inheriting from `Handoff`. legacy_agent = getattr(handoff_item, "agent", None) if legacy_agent is not None: handoff_agent = legacy_agent @@ -2191,7 +2258,7 @@ def _build_agent_map(initial_agent: Agent[Any]) -> dict[str, Agent[Any]]: if ( handoff_agent is not None and handoff_agent_name - and handoff_agent_name not in agent_map + and id(handoff_agent) not in visited ): queue.append(cast(Any, handoff_agent)) @@ -2203,14 +2270,17 @@ def _build_agent_map(initial_agent: Agent[Any]) -> dict[str, Agent[Any]]: continue tool_agent = getattr(tool, "_agent_instance", None) tool_agent_name = getattr(tool_agent, "name", None) - if tool_agent and tool_agent_name and tool_agent_name not in agent_map: + if tool_agent and tool_agent_name and id(tool_agent) not in visited: queue.append(tool_agent) - return agent_map + return agent_map, agent_list def _deserialize_tool_origin( - tool_origin_data: dict[str, Any] | None, agent_map: dict[str, Agent[Any]], agent: Agent[Any] + tool_origin_data: dict[str, Any] | None, + agent_map: Mapping[str, Agent[Any]], + agent: Agent[Any], + agent_list: list[Agent[Any]] | None = None, ) -> ToolOrigin | None: """Deserialize ToolOrigin from JSON data. @@ -2218,6 +2288,7 @@ def _deserialize_tool_origin( tool_origin_data: Serialized tool origin dictionary. agent_map: Map of agent names to agent instances. agent: The agent associated with this item (used for MCP server lookup). + agent_list: List of agents in traversal order for agent_index lookup. Returns: ToolOrigin instance or None if data is missing/invalid. @@ -2241,11 +2312,22 @@ def _deserialize_tool_origin( if origin_type == ToolOriginType.AGENT_AS_TOOL: agent_data = tool_origin_data.get("agent_as_tool") if agent_data and isinstance(agent_data, Mapping): - agent_name = agent_data.get("name") - if agent_name: - agent_as_tool = agent_map.get(agent_name) - if not agent_as_tool: - logger.warning(f"Agent {agent_name} not found in agent map for tool_origin") + agent_index = agent_data.get("agent_index") + if agent_index is not None and isinstance(agent_index, int) and agent_list: + if 0 <= agent_index < len(agent_list): + agent_as_tool = agent_list[agent_index] + else: + logger.warning( + "Agent index %s out of range for tool_origin (list len=%s)", + agent_index, + len(agent_list), + ) + else: + agent_name = agent_data.get("name") + if agent_name: + agent_as_tool = agent_map.get(agent_name) + if not agent_as_tool: + logger.warning(f"Agent {agent_name} not found in agent map for tool_origin") elif origin_type == ToolOriginType.MCP: mcp_data = tool_origin_data.get("mcp_server") @@ -2307,17 +2389,23 @@ def _deserialize_model_responses(responses_data: list[dict[str, Any]]) -> list[M def _deserialize_items( - items_data: list[dict[str, Any]], agent_map: dict[str, Agent[Any]] + items_data: list[dict[str, Any]], + agent_map: dict[str, Agent[Any]], + agent_list: list[Agent[Any]] | None = None, ) -> list[RunItem]: """Deserialize run items from JSON data. Args: items_data: List of serialized run item dictionaries. agent_map: Map of agent names to agent instances. + agent_list: List of agents in traversal order for agent_index lookup. If None, + agent_as_tool falls back to name-based lookup (duplicate names not supported). Returns: List of RunItem instances. """ + if agent_list is None: + agent_list = list(agent_map.values()) result: list[RunItem] = [] @@ -2375,7 +2463,7 @@ def _resolve_agent_info( description = item_data.get("description") # Preserve tool_origin if it was stored with the item tool_origin = _deserialize_tool_origin( - item_data.get("tool_origin"), agent_map, agent + item_data.get("tool_origin"), agent_map, agent, agent_list ) result.append( ToolCallItem( @@ -2394,7 +2482,7 @@ def _resolve_agent_info( continue # Preserve tool_origin if it was stored with the item tool_origin = _deserialize_tool_origin( - item_data.get("tool_origin"), agent_map, agent + item_data.get("tool_origin"), agent_map, agent, agent_list ) result.append( ToolCallOutputItem( diff --git a/tests/test_run_state.py b/tests/test_run_state.py index 39d0cdb43c..55854d40c9 100644 --- a/tests/test_run_state.py +++ b/tests/test_run_state.py @@ -901,7 +901,7 @@ def test_build_agent_map_collects_agents_without_looping(self): agent_a.handoffs = [agent_b] agent_b.handoffs = [agent_a] - agent_map = _build_agent_map(agent_a) + agent_map, _ = _build_agent_map(agent_a) assert agent_map.get("AgentA") is not None assert agent_map.get("AgentB") is not None @@ -921,7 +921,7 @@ def test_build_agent_map_handles_complex_handoff_graphs(self): agent_b.handoffs = [agent_d] agent_c.handoffs = [agent_d] - agent_map = _build_agent_map(agent_a) + agent_map, _ = _build_agent_map(agent_a) assert len(agent_map) == 4 assert all(agent_map.get(name) is not None for name in ["A", "B", "C", "D"]) @@ -932,7 +932,7 @@ def test_build_agent_map_handles_handoff_objects(self): agent_b = Agent(name="AgentB") agent_a.handoffs = [handoff(agent_b)] - agent_map = _build_agent_map(agent_a) + agent_map, _ = _build_agent_map(agent_a) assert sorted(agent_map.keys()) == ["AgentA", "AgentB"] @@ -950,7 +950,7 @@ def __init__(self, target: Agent[Any]): agent_a.handoffs = [LegacyHandoff(agent_b)] - agent_map = _build_agent_map(agent_a) + agent_map, _ = _build_agent_map(agent_a) assert sorted(agent_map.keys()) == ["AgentA", "AgentB"] @@ -965,7 +965,7 @@ def __init__(self, target: Agent[Any]): agent_a.handoffs = [LegacyWrapper(agent_b)] # type: ignore[list-item] - agent_map = _build_agent_map(agent_a) + agent_map, _ = _build_agent_map(agent_a) assert sorted(agent_map.keys()) == ["AgentA", "AgentB"] @@ -986,7 +986,7 @@ async def _invoke_handoff(_ctx: RunContextWrapper[Any], _input: str) -> Agent[An ) agent_a.handoffs = [detached_handoff] - agent_map = _build_agent_map(agent_a) + agent_map, _ = _build_agent_map(agent_a) assert sorted(agent_map.keys()) == ["AgentA"] @@ -2965,7 +2965,7 @@ class AgentWithoutGetAllTools(Agent): # This should trigger line 759 (all_tools = []) result = await _deserialize_processed_response( - processed_response_data, agent_no_tools, context, {} + processed_response_data, agent_no_tools, context, {}, [] ) assert result is not None @@ -3002,8 +3002,9 @@ async def test_deserialize_processed_response_handoff_with_tool_name(self): } # This should trigger lines 778-782 and 787-796 + agent_map = {"AgentA": agent_a, "AgentB": agent_b} result = await _deserialize_processed_response( - processed_response_data, agent_a, context, {"AgentA": agent_a, "AgentB": agent_b} + processed_response_data, agent_a, context, agent_map, list(agent_map.values()) ) assert result is not None assert len(result.handoffs) == 1 @@ -3047,8 +3048,9 @@ async def tool_func(context: ToolContext[Any], arguments: str) -> str: } # This should trigger lines 801-808 + agent_map = {"TestAgent": agent} result = await _deserialize_processed_response( - processed_response_data, agent, context, {"TestAgent": agent} + processed_response_data, agent, context, agent_map, list(agent_map.values()) ) assert result is not None assert len(result.functions) == 1 @@ -3124,8 +3126,9 @@ def wait(self) -> None: } # This should trigger lines 815-824 + agent_map = {"TestAgent": agent} result = await _deserialize_processed_response( - processed_response_data, agent, context, {"TestAgent": agent} + processed_response_data, agent, context, agent_map, list(agent_map.values()) ) assert result is not None assert len(result.computer_actions) == 1 @@ -3165,8 +3168,9 @@ async def shell_executor(request: Any) -> Any: } # This should trigger the ValidationError path (lines 1299-1302) + agent_map = {"TestAgent": agent} result = await _deserialize_processed_response( - processed_response_data, agent, context, {"TestAgent": agent} + processed_response_data, agent, context, agent_map, list(agent_map.values()) ) assert result is not None # Should fall back to using tool_call_data directly when validation fails @@ -3218,8 +3222,9 @@ def delete_file(self, operation: Any) -> Any: } # This should trigger the Exception path (lines 1314-1317) + agent_map = {"TestAgent": agent} result = await _deserialize_processed_response( - processed_response_data, agent, context, {"TestAgent": agent} + processed_response_data, agent, context, agent_map, list(agent_map.values()) ) assert result is not None # Should fall back to using tool_call_data directly when deserialization fails @@ -3260,8 +3265,9 @@ async def test_deserialize_processed_response_local_shell_action_round_trip(self "interruptions": [], } + agent_map = {"TestAgent": agent} result = await _deserialize_processed_response( - processed_response_data, agent, context, {"TestAgent": agent} + processed_response_data, agent, context, agent_map, list(agent_map.values()) ) assert len(result.local_shell_calls) == 1 @@ -3310,8 +3316,9 @@ def __init__(self): } # This should trigger lines 831-852 + agent_map = {"TestAgent": agent} result = await _deserialize_processed_response( - processed_response_data, agent, context, {"TestAgent": agent} + processed_response_data, agent, context, agent_map, list(agent_map.values()) ) assert result is not None # The MCP approval request might not be deserialized if MockMCPTool isn't a HostedMCPTool, @@ -3635,6 +3642,7 @@ class AgentWithoutGetAllTools: agent, # type: ignore[arg-type] context, {}, + [], ) assert result is not None @@ -3666,7 +3674,9 @@ async def test_deserialize_processed_response_empty_mcp_tool_data(self): ], } - result = await _deserialize_processed_response(processed_response_data, agent, context, {}) + result = await _deserialize_processed_response( + processed_response_data, agent, context, {}, [] + ) # Should skip the empty mcp_tool_data and not add it to mcp_approval_requests assert len(result.mcp_approval_requests) == 0