diff --git a/src/agents/__init__.py b/src/agents/__init__.py index 49c7c794b7..973260e54f 100644 --- a/src/agents/__init__.py +++ b/src/agents/__init__.py @@ -148,6 +148,8 @@ ShellToolLocalSkill, ShellToolSkillReference, Tool, + ToolOrigin, + ToolOriginType, ToolOutputFileContent, ToolOutputFileContentDict, ToolOutputImage, @@ -408,6 +410,8 @@ def enable_verbose_stdout_logging(): "ApplyPatchResult", "ApplyPatchTool", "Tool", + "ToolOrigin", + "ToolOriginType", "WebSearchTool", "HostedMCPTool", "MCPToolApprovalFunction", diff --git a/src/agents/agent.py b/src/agents/agent.py index f28df1a14c..54b91a5daf 100644 --- a/src/agents/agent.py +++ b/src/agents/agent.py @@ -47,6 +47,8 @@ FunctionToolResult, Tool, ToolErrorFunction, + ToolOrigin, + ToolOriginType, _extract_tool_argument_json_error, default_tool_error_function, ) @@ -851,6 +853,11 @@ async def _run_agent_tool(context: ToolContext, input_json: str) -> Any: ) run_agent_tool._is_agent_tool = True run_agent_tool._agent_instance = self + # Set origin tracking on run_agent (the FunctionTool returned by @function_tool) + run_agent_tool._tool_origin = ToolOrigin( + type=ToolOriginType.AGENT_AS_TOOL, + agent_as_tool=self, + ) return run_agent_tool diff --git a/src/agents/items.py b/src/agents/items.py index 8eadcbcd0a..5f5eb2617c 100644 --- a/src/agents/items.py +++ b/src/agents/items.py @@ -50,6 +50,7 @@ from .exceptions import AgentsException, ModelBehaviorError from .logger import logger from .tool import ( + ToolOrigin, ToolOutputFileContent, ToolOutputImage, ToolOutputText, @@ -250,6 +251,15 @@ class ToolCallItem(RunItemBase[Any]): description: str | None = None """Optional tool description if known at item creation time.""" + tool_origin: ToolOrigin | None = field(default=None, repr=False) + """Information about the origin/source of the tool call. Only set for FunctionTool calls.""" + + def release_agent(self) -> None: + """Release agent references including tool_origin.agent_as_tool.""" + super().release_agent() + if self.tool_origin is not None: + self.tool_origin.release_agent() + ToolCallOutputTypes: TypeAlias = Union[ FunctionCallOutput, @@ -274,6 +284,15 @@ class ToolCallOutputItem(RunItemBase[Any]): type: Literal["tool_call_output_item"] = "tool_call_output_item" + tool_origin: ToolOrigin | None = field(default=None, repr=False) + """Information about the origin/source of the tool call. Only set for FunctionTool calls.""" + + def release_agent(self) -> None: + """Release agent references including tool_origin.agent_as_tool.""" + super().release_agent() + if self.tool_origin is not None: + self.tool_origin.release_agent() + def to_input_item(self) -> TResponseInputItem: """Converts the tool output into an input item for the next model turn. @@ -375,8 +394,17 @@ class ToolApprovalItem(RunItemBase[Any]): tool_name: str | None = None """Tool name for approval tracking; falls back to raw_item.name when absent.""" + tool_origin: ToolOrigin | None = field(default=None, repr=False) + """Information about the origin/source of the tool. Only set for FunctionTool calls.""" + type: Literal["tool_approval_item"] = "tool_approval_item" + def release_agent(self) -> None: + """Release agent references including tool_origin.agent_as_tool.""" + super().release_agent() + if self.tool_origin is not None: + self.tool_origin.release_agent() + def __post_init__(self) -> None: """Populate tool_name from the raw item if not provided.""" if self.tool_name is None: diff --git a/src/agents/mcp/util.py b/src/agents/mcp/util.py index 4a5f5c4a4d..31f1cc4200 100644 --- a/src/agents/mcp/util.py +++ b/src/agents/mcp/util.py @@ -20,6 +20,8 @@ FunctionTool, Tool, ToolErrorFunction, + ToolOrigin, + ToolOriginType, ToolOutputImageDict, ToolOutputTextDict, default_tool_error_function, @@ -302,7 +304,7 @@ async def invoke_func(ctx: ToolContext[Any], input_json: str) -> ToolOutput: bool | Callable[[RunContextWrapper[Any], dict[str, Any], str], Awaitable[bool]] ) = server._get_needs_approval_for_tool(tool, agent) - return FunctionTool( + function_tool = FunctionTool( name=tool.name, description=tool.description or "", params_json_schema=schema, @@ -310,6 +312,11 @@ async def invoke_func(ctx: ToolContext[Any], input_json: str) -> ToolOutput: strict_json_schema=is_strict, needs_approval=needs_approval, ) + function_tool._tool_origin = ToolOrigin( + type=ToolOriginType.MCP, + mcp_server=server, + ) + return function_tool @staticmethod def _merge_mcp_meta( diff --git a/src/agents/realtime/session.py b/src/agents/realtime/session.py index 6448c17bfc..f3ebc809eb 100644 --- a/src/agents/realtime/session.py +++ b/src/agents/realtime/session.py @@ -429,7 +429,14 @@ def _build_tool_approval_item( "call_id": tool_call.call_id, "arguments": tool_call.arguments, } - return ToolApprovalItem(agent=cast(Any, agent), raw_item=raw_item, tool_name=tool.name) + from ..tool import _get_tool_origin_info + + return ToolApprovalItem( + agent=cast(Any, agent), + raw_item=raw_item, + tool_name=tool.name, + tool_origin=_get_tool_origin_info(tool), + ) async def _maybe_request_tool_approval( self, diff --git a/src/agents/run_internal/items.py b/src/agents/run_internal/items.py index ed21e801f9..ede190aeb7 100644 --- a/src/agents/run_internal/items.py +++ b/src/agents/run_internal/items.py @@ -15,7 +15,7 @@ from ..agent_tool_state import drop_agent_tool_run_result from ..items import ItemHelpers, RunItem, ToolCallOutputItem, TResponseInputItem from ..models.fake_id import FAKE_RESPONSES_ID -from ..tool import DEFAULT_APPROVAL_REJECTION_MESSAGE +from ..tool import DEFAULT_APPROVAL_REJECTION_MESSAGE, ToolOrigin REJECTION_MESSAGE = DEFAULT_APPROVAL_REJECTION_MESSAGE _TOOL_CALL_TO_OUTPUT_TYPE: dict[str, str] = { @@ -246,6 +246,7 @@ def function_rejection_item( tool_call: Any, *, rejection_message: str = REJECTION_MESSAGE, + tool_origin: ToolOrigin | None = None, scope_id: str | None = None, ) -> ToolCallOutputItem: """Build a ToolCallOutputItem representing a rejected function tool call.""" @@ -255,6 +256,7 @@ def function_rejection_item( output=rejection_message, raw_item=ItemHelpers.tool_call_output_item(tool_call, rejection_message), agent=agent, + tool_origin=tool_origin, ) diff --git a/src/agents/run_internal/run_loop.py b/src/agents/run_internal/run_loop.py index b32276bdc3..dc90a336fa 100644 --- a/src/agents/run_internal/run_loop.py +++ b/src/agents/run_internal/run_loop.py @@ -10,7 +10,12 @@ from collections.abc import Awaitable, Callable from typing import Any, TypeVar, cast -from openai.types.responses import Response, ResponseCompletedEvent, ResponseOutputItemDoneEvent +from openai.types.responses import ( + Response, + ResponseCompletedEvent, + ResponseFunctionToolCall, + ResponseOutputItemDoneEvent, +) from openai.types.responses.response_prompt_param import ResponsePromptParam from openai.types.responses.response_reasoning_item import ResponseReasoningItem @@ -49,7 +54,7 @@ RawResponsesStreamEvent, RunItemStreamEvent, ) -from ..tool import Tool, dispose_resolved_computers +from ..tool import FunctionTool, Tool, _get_tool_origin_info, dispose_resolved_computers from ..tracing import Span, SpanError, agent_span, get_current_trace from ..tracing.model_tracing import get_model_tracing_impl from ..tracing.span_data import AgentSpanData @@ -113,6 +118,7 @@ from .streaming import stream_step_items_to_queue, stream_step_result_to_queue from .tool_actions import ApplyPatchAction, ComputerAction, LocalShellAction, ShellAction from .tool_execution import ( + build_litellm_json_tool_call, coerce_shell_call, execute_apply_patch_calls, execute_computer_actions, @@ -1063,6 +1069,8 @@ async def run_single_turn_streamed( # execution in process_model_response, so duplicate names (e.g., MCP + local tool) # stream the same description that execution uses. tool_map = {t.name: t for t in all_tools if hasattr(t, "name") and t.name} + # FunctionTool-only map for tool_origin; matches process_model_response's last-wins. + function_map = {t.name: t for t in all_tools if isinstance(t, FunctionTool)} try: turn_input = ItemHelpers.input_to_new_input_list(streamed_result.input) @@ -1250,13 +1258,31 @@ async def run_single_turn_streamed( # execution behavior in process_model_response). tool_name = getattr(output_item, "name", None) tool_description: str | None = None + tool_origin = None if isinstance(tool_name, str) and tool_name in tool_map: - tool_description = getattr(tool_map[tool_name], "description", None) + tool = tool_map[tool_name] + tool_description = getattr(tool, "description", None) + # Use function_map for tool_origin to match process_model_response's + # last-wins semantics when multiple FunctionTools share a name. + func_tool = function_map.get(tool_name) + if func_tool is not None: + tool_origin = _get_tool_origin_info(func_tool) + elif ( + isinstance(tool_name, str) + and tool_name == "json_tool_call" + and output_schema is not None + and isinstance(output_item, ResponseFunctionToolCall) + ): + # json_tool_call is synthesized dynamically and not in tool_map. + # Synthesize it here to get tool_origin, matching process_model_response. + json_tool = build_litellm_json_tool_call(output_item) + tool_origin = _get_tool_origin_info(json_tool) tool_item = ToolCallItem( raw_item=cast(ToolCallItemTypes, output_item), agent=agent, description=tool_description, + tool_origin=tool_origin, ) streamed_result._event_queue.put_nowait( RunItemStreamEvent(item=tool_item, name="tool_called") diff --git a/src/agents/run_internal/tool_execution.py b/src/agents/run_internal/tool_execution.py index 45b32f418d..377bf2090e 100644 --- a/src/agents/run_internal/tool_execution.py +++ b/src/agents/run_internal/tool_execution.py @@ -56,6 +56,7 @@ ShellCallOutcome, ShellCommandOutput, Tool, + _get_tool_origin_info, invoke_function_tool, resolve_computer, ) @@ -872,8 +873,12 @@ async def run_single_tool(func_tool: FunctionTool, tool_call: ResponseFunctionTo ) if approval_status is None: + tool_origin = _get_tool_origin_info(func_tool) approval_item = ToolApprovalItem( - agent=agent, raw_item=tool_call, tool_name=func_tool.name + agent=agent, + raw_item=tool_call, + tool_name=func_tool.name, + tool_origin=tool_origin, ) return FunctionToolResult( tool=func_tool, output=None, run_item=approval_item @@ -901,6 +906,7 @@ async def run_single_tool(func_tool: FunctionTool, tool_call: ResponseFunctionTo ) result = rejection_message span_fn.span_data.output = result + tool_origin = _get_tool_origin_info(func_tool) return FunctionToolResult( tool=func_tool, output=result, @@ -908,6 +914,7 @@ async def run_single_tool(func_tool: FunctionTool, tool_call: ResponseFunctionTo agent, tool_call, rejection_message=rejection_message, + tool_origin=tool_origin, scope_id=tool_state_scope_id, ), ) @@ -1024,10 +1031,12 @@ async def run_single_tool(func_tool: FunctionTool, tool_call: ResponseFunctionTo run_item: RunItem | None = None if not nested_interruptions: + tool_origin = _get_tool_origin_info(tool_run.function_tool) run_item = ToolCallOutputItem( output=result, raw_item=ItemHelpers.tool_call_output_item(tool_run.tool_call, result), agent=agent, + tool_origin=tool_origin, ) else: # Skip tool output until nested interruptions are resolved. diff --git a/src/agents/run_internal/tool_use_tracker.py b/src/agents/run_internal/tool_use_tracker.py index 33e2d72156..aa8c35645b 100644 --- a/src/agents/run_internal/tool_use_tracker.py +++ b/src/agents/run_internal/tool_use_tracker.py @@ -85,7 +85,7 @@ def hydrate_tool_use_tracker( if not snapshot: return - agent_map = _build_agent_map(starting_agent) + agent_map, _ = _build_agent_map(starting_agent) for agent_name, tool_names in snapshot.items(): agent = agent_map.get(agent_name) if agent is None: diff --git a/src/agents/run_internal/turn_resolution.py b/src/agents/run_internal/turn_resolution.py index 59996c6e0f..7e14702d52 100644 --- a/src/agents/run_internal/turn_resolution.py +++ b/src/agents/run_internal/turn_resolution.py @@ -63,6 +63,7 @@ LocalShellTool, ShellTool, Tool, + _get_tool_origin_info, ) from ..tool_guardrails import ToolInputGuardrailResult, ToolOutputGuardrailResult from ..tracing import SpanError, handoff_span @@ -696,11 +697,13 @@ async def _record_function_rejection( tool_name=function_tool.name, call_id=call_id, ) + tool_origin = _get_tool_origin_info(function_tool) rejected_function_outputs.append( function_rejection_item( agent, tool_call, rejection_message=rejection_message, + tool_origin=tool_origin, scope_id=tool_state_scope_id, ) ) @@ -979,7 +982,12 @@ def _add_unmatched_pending(approval: ToolApprovalItem) -> None: output_exists_checker=_function_output_exists, record_rejection=_record_function_rejection, pending_interruption_adder=_add_pending_interruption, - pending_item_builder=lambda run: ToolApprovalItem(agent=agent, raw_item=run.tool_call), + pending_item_builder=lambda run: ToolApprovalItem( + agent=agent, + raw_item=run.tool_call, + tool_name=run.function_tool.name, + tool_origin=_get_tool_origin_info(run.function_tool), + ), ) rebuilt_function_tool_runs = await _rebuild_function_runs_from_approvals() @@ -1536,11 +1544,19 @@ def process_model_response( else: if output.name not in function_map: if output_schema is not None and output.name == "json_tool_call": - items.append(ToolCallItem(raw_item=output, agent=agent)) + json_tool = build_litellm_json_tool_call(output) + tool_origin = _get_tool_origin_info(json_tool) + items.append( + ToolCallItem( + raw_item=output, + agent=agent, + tool_origin=tool_origin, + ) + ) functions.append( ToolRunFunction( tool_call=output, - function_tool=build_litellm_json_tool_call(output), + function_tool=json_tool, ) ) continue @@ -1554,8 +1570,14 @@ def process_model_response( raise ModelBehaviorError(error) func_tool = function_map[output.name] + tool_origin = _get_tool_origin_info(func_tool) items.append( - ToolCallItem(raw_item=output, agent=agent, description=func_tool.description) + ToolCallItem( + raw_item=output, + agent=agent, + description=func_tool.description, + tool_origin=tool_origin, + ) ) functions.append( ToolRunFunction( diff --git a/src/agents/run_state.py b/src/agents/run_state.py index b541d40d6b..072ddb9858 100644 --- a/src/agents/run_state.py +++ b/src/agents/run_state.py @@ -65,6 +65,8 @@ HostedMCPTool, LocalShellTool, ShellTool, + ToolOrigin, + ToolOriginType, ) from .tool_guardrails import ( AllowBehavior, @@ -100,8 +102,8 @@ # 2. Keep older readable versions in SUPPORTED_SCHEMA_VERSIONS for backward reads. # 3. to_json() always emits CURRENT_SCHEMA_VERSION. # 4. Forward compatibility is intentionally fail-fast (older SDKs reject newer versions). -CURRENT_SCHEMA_VERSION = "1.4" -SUPPORTED_SCHEMA_VERSIONS = frozenset({"1.0", "1.1", "1.2", "1.3", CURRENT_SCHEMA_VERSION}) +CURRENT_SCHEMA_VERSION = "1.5" +SUPPORTED_SCHEMA_VERSIONS = frozenset({"1.0", "1.1", "1.2", "1.3", "1.4", CURRENT_SCHEMA_VERSION}) _FUNCTION_OUTPUT_ADAPTER: TypeAdapter[FunctionCallOutput] = TypeAdapter(FunctionCallOutput) _COMPUTER_OUTPUT_ADAPTER: TypeAdapter[ComputerCallOutput] = TypeAdapter(ComputerCallOutput) @@ -207,11 +209,13 @@ def __init__( conversation_id: str | None = None, previous_response_id: str | None = None, auto_previous_response_id: bool = False, + root_agent: TAgent | None = None, ): """Initialize a new RunState.""" self._context = context self._original_input = _clone_original_input(original_input) self._current_agent = starting_agent + self._starting_agent = root_agent if root_agent is not None else starting_agent self._max_turns = max_turns self._conversation_id = conversation_id self._previous_response_id = previous_response_id @@ -530,6 +534,8 @@ def to_json( if self._context is None: raise UserError("Cannot serialize RunState: No context") + _, agent_list = _build_agent_map(self._starting_agent) + approvals_dict = self._serialize_approvals() model_responses = self._serialize_model_responses() original_input_serialized = self._serialize_original_input() @@ -576,13 +582,18 @@ def to_json( } generated_items = self._merge_generated_items_with_processed() - result["generated_items"] = [self._serialize_item(item) for item in generated_items] - result["session_items"] = [self._serialize_item(item) for item in list(self._session_items)] - result["current_step"] = self._serialize_current_step() + result["generated_items"] = [ + self._serialize_item(item, agent_list) for item in generated_items + ] + result["session_items"] = [ + self._serialize_item(item, agent_list) for item in list(self._session_items) + ] + result["current_step"] = self._serialize_current_step(agent_list) result["last_model_response"] = _serialize_last_model_response(model_responses) result["last_processed_response"] = ( self._serialize_processed_response( self._last_processed_response, + agent_list, context_serializer=context_serializer, strict_context=strict_context, include_tracing_api_key=include_tracing_api_key, @@ -600,6 +611,7 @@ def to_json( def _serialize_processed_response( self, processed_response: ProcessedResponse, + agent_list: list[Any], *, context_serializer: ContextSerializer | None = None, strict_context: bool = False, @@ -609,6 +621,7 @@ def _serialize_processed_response( Args: processed_response: The ProcessedResponse to serialize. + agent_list: List of agents in traversal order for tool_origin serialization. Returns: A dictionary representation of the ProcessedResponse. @@ -626,19 +639,25 @@ def _serialize_processed_response( ) interruptions_data = [ - _serialize_tool_approval_interruption(interruption, include_tool_name=True) + _serialize_tool_approval_interruption( + interruption, + include_tool_name=True, + agent_list=agent_list, + ) for interruption in processed_response.interruptions if isinstance(interruption, ToolApprovalItem) ] return { - "new_items": [self._serialize_item(item) for item in processed_response.new_items], + "new_items": [ + self._serialize_item(item, agent_list) for item in processed_response.new_items + ], "tools_used": processed_response.tools_used, **action_groups, "interruptions": interruptions_data, } - def _serialize_current_step(self) -> dict[str, Any] | None: + def _serialize_current_step(self, agent_list: list[Any] | None = None) -> dict[str, Any] | None: """Serialize the current step if it's an interruption.""" # Import at runtime to avoid circular import from .run_internal.run_steps import NextStepInterruption @@ -648,7 +667,9 @@ def _serialize_current_step(self) -> dict[str, Any] | None: interruptions_data = [ _serialize_tool_approval_interruption( - item, include_tool_name=item.tool_name is not None + item, + include_tool_name=item.tool_name is not None, + agent_list=agent_list, ) for item in self._current_step.interruptions if isinstance(item, ToolApprovalItem) @@ -661,7 +682,7 @@ def _serialize_current_step(self) -> dict[str, Any] | None: }, } - def _serialize_item(self, item: RunItem) -> dict[str, Any]: + def _serialize_item(self, item: RunItem, agent_list: list[Any] | None = None) -> dict[str, Any]: """Serialize a run item to JSON-compatible dict.""" raw_item_dict: Any = _serialize_raw_item_value(item.raw_item) @@ -691,6 +712,10 @@ def _serialize_item(self, item: RunItem) -> dict[str, Any]: result["tool_name"] = item.tool_name if hasattr(item, "description") and item.description is not None: result["description"] = item.description + if hasattr(item, "tool_origin") and item.tool_origin is not None: + tool_origin_data = _serialize_tool_origin(item.tool_origin, agent_list) + if tool_origin_data: + result["tool_origin"] = tool_origin_data return result @@ -1111,8 +1136,33 @@ def _serialize_mcp_tool(mcp_tool: Any) -> dict[str, Any]: return {"value": normalized} +def _serialize_tool_origin( + tool_origin: ToolOrigin, agent_list: list[Any] | None = None +) -> dict[str, Any] | None: + """Serialize ToolOrigin to JSON. Uses agent_index when agent_list provided.""" + if tool_origin is None: + return None + result: dict[str, Any] = {"type": tool_origin.type.value} + if tool_origin.agent_as_tool is not None: + agent = tool_origin.agent_as_tool + if agent_list: + try: + idx = next(i for i, a in enumerate(agent_list) if a is agent) + result["agent_as_tool"] = {"agent_index": idx} + except StopIteration: + result["agent_as_tool"] = {"name": agent.name} + else: + result["agent_as_tool"] = {"name": agent.name} + if tool_origin.mcp_server is not None: + result["mcp_server"] = {"name": tool_origin.mcp_server.name} + return result + + def _serialize_tool_approval_interruption( - interruption: ToolApprovalItem, *, include_tool_name: bool + interruption: ToolApprovalItem, + *, + include_tool_name: bool, + agent_list: list[Any] | None = None, ) -> dict[str, Any]: """Serialize a ToolApprovalItem interruption.""" interruption_dict: dict[str, Any] = { @@ -1122,6 +1172,10 @@ def _serialize_tool_approval_interruption( } if include_tool_name and interruption.tool_name is not None: interruption_dict["tool_name"] = interruption.tool_name + if hasattr(interruption, "tool_origin") and interruption.tool_origin is not None: + tool_origin_data = _serialize_tool_origin(interruption.tool_origin, agent_list) + if tool_origin_data: + interruption_dict["tool_origin"] = tool_origin_data return interruption_dict @@ -1409,6 +1463,7 @@ async def _deserialize_processed_response( current_agent: Agent[Any], context: RunContextWrapper[Any], agent_map: dict[str, Agent[Any]], + agent_list: list[Agent[Any]], *, scope_id: str | None = None, context_deserializer: ContextDeserializer | None = None, @@ -1421,11 +1476,14 @@ async def _deserialize_processed_response( current_agent: The current agent (used to get tools and handoffs). context: The run context wrapper. agent_map: Map of agent names to agents. + agent_list: List of agents in traversal order for agent_index lookup. Returns: A reconstructed ProcessedResponse instance. """ - new_items = _deserialize_items(processed_response_data.get("new_items", []), agent_map) + new_items = _deserialize_items( + processed_response_data.get("new_items", []), agent_map, agent_list + ) if hasattr(current_agent, "get_all_tools"): all_tools = await current_agent.get_all_tools(context) @@ -1704,6 +1762,7 @@ def _deserialize_tool_approval_item( item_data: Mapping[str, Any], *, agent_map: Mapping[str, Agent[Any]], + agent_list: list[Agent[Any]] | None = None, fallback_agent: Agent[Any] | None = None, pre_normalized_raw_item: Any | None = None, ) -> ToolApprovalItem | None: @@ -1720,7 +1779,12 @@ def _deserialize_tool_approval_item( tool_name = item_data.get("tool_name") raw_item = _deserialize_tool_approval_raw_item(raw_item_data) - return ToolApprovalItem(agent=agent, raw_item=raw_item, tool_name=tool_name) + tool_origin = _deserialize_tool_origin( + item_data.get("tool_origin"), agent_map, agent, agent_list + ) + return ToolApprovalItem( + agent=agent, raw_item=raw_item, tool_name=tool_name, tool_origin=tool_origin + ) def _deserialize_tool_call_output_raw_item( @@ -1947,7 +2011,7 @@ async def _build_run_state_from_json( f"New snapshots are written as version {CURRENT_SCHEMA_VERSION}." ) - agent_map = _build_agent_map(initial_agent) + agent_map, agent_list = _build_agent_map(initial_agent) current_agent_name = state_json["current_agent"]["name"] current_agent = agent_map.get(current_agent_name) @@ -2030,6 +2094,7 @@ async def _build_run_state_from_json( conversation_id=state_json.get("conversation_id"), previous_response_id=state_json.get("previous_response_id"), auto_previous_response_id=bool(state_json.get("auto_previous_response_id", False)), + root_agent=initial_agent, ) from .agent_tool_state import set_agent_tool_state_scope @@ -2038,7 +2103,9 @@ async def _build_run_state_from_json( state._current_turn = state_json["current_turn"] state._model_responses = _deserialize_model_responses(state_json.get("model_responses", [])) - state._generated_items = _deserialize_items(state_json.get("generated_items", []), agent_map) + state._generated_items = _deserialize_items( + state_json.get("generated_items", []), agent_map, agent_list + ) last_processed_response_data = state_json.get("last_processed_response") if last_processed_response_data and state._context is not None: @@ -2047,6 +2114,7 @@ async def _build_run_state_from_json( current_agent, state._context, agent_map, + agent_list, scope_id=state._agent_tool_state_scope_id, context_deserializer=context_deserializer, strict_context=strict_context, @@ -2055,7 +2123,9 @@ async def _build_run_state_from_json( state._last_processed_response = None if "session_items" in state_json: - state._session_items = _deserialize_items(state_json.get("session_items", []), agent_map) + state._session_items = _deserialize_items( + state_json.get("session_items", []), agent_map, agent_list + ) else: state._session_items = state._merge_generated_items_with_processed() @@ -2081,7 +2151,9 @@ async def _build_run_state_from_json( "interruptions", current_step_data.get("interruptions", []) ) for item_data in interruptions_data: - approval_item = _deserialize_tool_approval_item(item_data, agent_map=agent_map) + approval_item = _deserialize_tool_approval_item( + item_data, agent_map=agent_map, agent_list=agent_list + ) if approval_item is not None: interruptions.append(approval_item) @@ -2109,23 +2181,35 @@ async def _build_run_state_from_json( return state -def _build_agent_map(initial_agent: Agent[Any]) -> dict[str, Agent[Any]]: - """Build a map of agent names to agents by traversing handoffs. +def _build_agent_map( + initial_agent: Agent[Any], +) -> tuple[dict[str, Agent[Any]], list[Agent[Any]]]: + """Build agent map and agent list by traversing handoffs and tools. + + Uses id(agent) for visited tracking so duplicate-name agents are included in + agent_list. agent_map uses name as key (first wins) for backward compatibility. + agent_list preserves order for stable agent_index in tool_origin serialization. Args: initial_agent: The starting agent. Returns: - Dictionary mapping agent names to agent instances. + Tuple of (agent_map, agent_list). agent_map maps name -> agent (first wins). + agent_list is all agents in BFS order, including duplicates by name. """ agent_map: dict[str, Agent[Any]] = {} + agent_list: list[Agent[Any]] = [] + visited: set[int] = set() queue: deque[Agent[Any]] = deque([initial_agent]) while queue: current = queue.popleft() - if current.name in agent_map: + if id(current) in visited: continue - agent_map[current.name] = current + visited.add(id(current)) + agent_list.append(current) + if current.name not in agent_map: + agent_map[current.name] = current # Add handoff agents to the queue for handoff_item in current.handoffs: @@ -2133,21 +2217,15 @@ def _build_agent_map(initial_agent: Agent[Any]) -> dict[str, Agent[Any]]: handoff_agent_name: str | None = None if isinstance(handoff_item, Handoff): - # Some custom/mocked Handoff subclasses bypass dataclass initialization. - # Prefer agent_name, then legacy name fallback used in tests. candidate_name = getattr(handoff_item, "agent_name", None) or getattr( handoff_item, "name", None ) if isinstance(candidate_name, str): handoff_agent_name = candidate_name - if handoff_agent_name in agent_map: - continue handoff_ref = getattr(handoff_item, "_agent_ref", None) handoff_agent = handoff_ref() if callable(handoff_ref) else None if handoff_agent is None: - # Backward-compatibility fallback for custom legacy handoff objects that store - # the target directly on `.agent`. New code should prefer `handoff()` objects. legacy_agent = getattr(handoff_item, "agent", None) if legacy_agent is not None: handoff_agent = legacy_agent @@ -2155,7 +2233,7 @@ def _build_agent_map(initial_agent: Agent[Any]) -> dict[str, Agent[Any]]: "Using legacy handoff `.agent` fallback while building agent map. " "This compatibility path is not recommended for new code." ) - if handoff_agent_name is None: + if handoff_agent_name is None and handoff_agent is not None: candidate_name = getattr(handoff_agent, "name", None) handoff_agent_name = candidate_name if isinstance(candidate_name, str) else None if handoff_agent is None or not hasattr(handoff_agent, "handoffs"): @@ -2166,8 +2244,6 @@ def _build_agent_map(initial_agent: Agent[Any]) -> dict[str, Agent[Any]]: ) continue else: - # Backward-compatibility fallback for custom legacy handoff wrappers that expose - # the target directly on `.agent` without inheriting from `Handoff`. legacy_agent = getattr(handoff_item, "agent", None) if legacy_agent is not None: handoff_agent = legacy_agent @@ -2182,7 +2258,7 @@ def _build_agent_map(initial_agent: Agent[Any]) -> dict[str, Agent[Any]]: if ( handoff_agent is not None and handoff_agent_name - and handoff_agent_name not in agent_map + and id(handoff_agent) not in visited ): queue.append(cast(Any, handoff_agent)) @@ -2194,10 +2270,86 @@ def _build_agent_map(initial_agent: Agent[Any]) -> dict[str, Agent[Any]]: continue tool_agent = getattr(tool, "_agent_instance", None) tool_agent_name = getattr(tool_agent, "name", None) - if tool_agent and tool_agent_name and tool_agent_name not in agent_map: + if tool_agent and tool_agent_name and id(tool_agent) not in visited: queue.append(tool_agent) - return agent_map + return agent_map, agent_list + + +def _deserialize_tool_origin( + tool_origin_data: dict[str, Any] | None, + agent_map: Mapping[str, Agent[Any]], + agent: Agent[Any], + agent_list: list[Agent[Any]] | None = None, +) -> ToolOrigin | None: + """Deserialize ToolOrigin from JSON data. + + Args: + tool_origin_data: Serialized tool origin dictionary. + agent_map: Map of agent names to agent instances. + agent: The agent associated with this item (used for MCP server lookup). + agent_list: List of agents in traversal order for agent_index lookup. + + Returns: + ToolOrigin instance or None if data is missing/invalid. + """ + if not tool_origin_data: + return None + + origin_type_str = tool_origin_data.get("type") + if not origin_type_str: + return None + + try: + origin_type = ToolOriginType(origin_type_str) + except ValueError: + logger.warning(f"Unknown tool origin type: {origin_type_str}") + return None + + agent_as_tool: Agent[Any] | None = None + mcp_server: Any | None = None + + if origin_type == ToolOriginType.AGENT_AS_TOOL: + agent_data = tool_origin_data.get("agent_as_tool") + if agent_data and isinstance(agent_data, Mapping): + agent_index = agent_data.get("agent_index") + if agent_index is not None and isinstance(agent_index, int) and agent_list: + if 0 <= agent_index < len(agent_list): + agent_as_tool = agent_list[agent_index] + else: + logger.warning( + "Agent index %s out of range for tool_origin (list len=%s)", + agent_index, + len(agent_list), + ) + else: + agent_name = agent_data.get("name") + if agent_name: + agent_as_tool = agent_map.get(agent_name) + if not agent_as_tool: + logger.warning(f"Agent {agent_name} not found in agent map for tool_origin") + + elif origin_type == ToolOriginType.MCP: + mcp_data = tool_origin_data.get("mcp_server") + if mcp_data and isinstance(mcp_data, Mapping): + server_name = mcp_data.get("name") + if server_name: + # Try to find the MCP server from the agent's mcp_servers list + mcp_servers = getattr(agent, "mcp_servers", []) + for server in mcp_servers: + if hasattr(server, "name") and server.name == server_name: + mcp_server = server + break + if not mcp_server: + logger.debug( + f"MCP server {server_name} not found in agent's mcp_servers for tool_origin" + ) + + return ToolOrigin( + type=origin_type, + agent_as_tool=agent_as_tool, + mcp_server=mcp_server, + ) def _deserialize_model_responses(responses_data: list[dict[str, Any]]) -> list[ModelResponse]: @@ -2237,17 +2389,23 @@ def _deserialize_model_responses(responses_data: list[dict[str, Any]]) -> list[M def _deserialize_items( - items_data: list[dict[str, Any]], agent_map: dict[str, Agent[Any]] + items_data: list[dict[str, Any]], + agent_map: dict[str, Agent[Any]], + agent_list: list[Agent[Any]] | None = None, ) -> list[RunItem]: """Deserialize run items from JSON data. Args: items_data: List of serialized run item dictionaries. agent_map: Map of agent names to agent instances. + agent_list: List of agents in traversal order for agent_index lookup. If None, + agent_as_tool falls back to name-based lookup (duplicate names not supported). Returns: List of RunItem instances. """ + if agent_list is None: + agent_list = list(agent_map.values()) result: list[RunItem] = [] @@ -2303,8 +2461,17 @@ def _resolve_agent_info( raw_item_tool = _deserialize_tool_call_raw_item(normalized_raw_item) # Preserve description if it was stored with the item description = item_data.get("description") + # Preserve tool_origin if it was stored with the item + tool_origin = _deserialize_tool_origin( + item_data.get("tool_origin"), agent_map, agent, agent_list + ) result.append( - ToolCallItem(agent=agent, raw_item=raw_item_tool, description=description) + ToolCallItem( + agent=agent, + raw_item=raw_item_tool, + description=description, + tool_origin=tool_origin, + ) ) elif item_type == "tool_call_output_item": @@ -2313,11 +2480,16 @@ def _resolve_agent_info( raw_item_output = _deserialize_tool_call_output_raw_item(normalized_raw_item) if raw_item_output is None: continue + # Preserve tool_origin if it was stored with the item + tool_origin = _deserialize_tool_origin( + item_data.get("tool_origin"), agent_map, agent, agent_list + ) result.append( ToolCallOutputItem( agent=agent, raw_item=raw_item_output, output=item_data.get("output", ""), + tool_origin=tool_origin, ) ) diff --git a/src/agents/tool.py b/src/agents/tool.py index 163b966c89..8bc828fb6b 100644 --- a/src/agents/tool.py +++ b/src/agents/tool.py @@ -1,6 +1,7 @@ from __future__ import annotations import asyncio +import enum import inspect import json import math @@ -50,6 +51,7 @@ if TYPE_CHECKING: from .agent import Agent, AgentBase from .items import RunItem, ToolApprovalItem + from .mcp.server import MCPServer ToolParams = ParamSpec("ToolParams") @@ -187,6 +189,96 @@ class ComputerProvider(Generic[ComputerT]): ] +class ToolOriginType(str, enum.Enum): + """The type of tool origin.""" + + FUNCTION = "function" + """Regular Python function tool created via @function_tool decorator.""" + + MCP = "mcp" + """MCP server tool converted via MCPUtil.to_function_tool().""" + + AGENT_AS_TOOL = "agent_as_tool" + """Agent converted to tool via agent.as_tool().""" + + +@dataclass +class ToolOrigin: + """Information about the origin/source of a function tool.""" + + type: ToolOriginType + """The type of tool origin.""" + + mcp_server: MCPServer | None = None + """The MCP server object. Only set when type is MCP.""" + + agent_as_tool: Agent[Any] | None = None + """The agent object. Only set when type is AGENT_AS_TOOL.""" + + _agent_as_tool_ref: weakref.ReferenceType[Agent[Any]] | None = field( + default=None, init=False, repr=False + ) + """Weak reference to agent_as_tool for memory management.""" + + def __post_init__(self) -> None: + """Initialize weak reference for agent_as_tool.""" + if self.agent_as_tool is not None: + self._agent_as_tool_ref = weakref.ref(self.agent_as_tool) + + def __getattribute__(self, name: str) -> Any: + """Lazily resolve agent_as_tool via weakref when strong ref is cleared.""" + if name == "agent_as_tool": + # Check if strong reference still exists + value = object.__getattribute__(self, "__dict__").get("agent_as_tool") + if value is not None: + return value + # Try to resolve via weakref + ref = object.__getattribute__(self, "_agent_as_tool_ref") + if ref is not None: + agent = ref() + if agent is not None: + return agent + return None + return super().__getattribute__(name) + + def release_agent(self) -> None: + """Release the strong reference to agent_as_tool while keeping a weak reference.""" + if "agent_as_tool" not in self.__dict__: + return + agent = self.__dict__.get("agent_as_tool") + if agent is not None: + self._agent_as_tool_ref = weakref.ref(agent) + # Set to None instead of deleting so dataclass repr/asdict keep working. + self.__dict__["agent_as_tool"] = None + + def __repr__(self) -> str: + """Custom repr that only includes relevant fields.""" + parts = [f"type={self.type.value!r}"] + if self.mcp_server is not None: + parts.append(f"mcp_server_name={self.mcp_server.name!r}") + agent = self.agent_as_tool + if agent is not None: + parts.append(f"agent_as_tool_name={agent.name!r}") + return f"ToolOrigin({', '.join(parts)})" + + +def _get_tool_origin_info(function_tool: FunctionTool) -> ToolOrigin | None: + """Extract origin information from a FunctionTool. + + Args: + function_tool: The function tool to extract origin info from. + + Returns: + ToolOrigin object if origin is set, otherwise None (defaults to FUNCTION type). + """ + origin = function_tool._tool_origin + if origin is None: + # Default to FUNCTION if not explicitly set + return ToolOrigin(type=ToolOriginType.FUNCTION) + + return origin + + @dataclass class FunctionToolResult: tool: FunctionTool @@ -286,6 +378,9 @@ class FunctionTool: _agent_instance: Any = field(default=None, init=False, repr=False) """Internal reference to the agent instance if this is an agent-as-tool.""" + _tool_origin: ToolOrigin | None = field(default=None, init=False, repr=False) + """Internal field tracking the origin of this tool (FUNCTION, MCP, or AGENT_AS_TOOL).""" + def __post_init__(self): if self.strict_json_schema: self.params_json_schema = ensure_strict_json_schema(self.params_json_schema) diff --git a/tests/test_run_state.py b/tests/test_run_state.py index 39d0cdb43c..55854d40c9 100644 --- a/tests/test_run_state.py +++ b/tests/test_run_state.py @@ -901,7 +901,7 @@ def test_build_agent_map_collects_agents_without_looping(self): agent_a.handoffs = [agent_b] agent_b.handoffs = [agent_a] - agent_map = _build_agent_map(agent_a) + agent_map, _ = _build_agent_map(agent_a) assert agent_map.get("AgentA") is not None assert agent_map.get("AgentB") is not None @@ -921,7 +921,7 @@ def test_build_agent_map_handles_complex_handoff_graphs(self): agent_b.handoffs = [agent_d] agent_c.handoffs = [agent_d] - agent_map = _build_agent_map(agent_a) + agent_map, _ = _build_agent_map(agent_a) assert len(agent_map) == 4 assert all(agent_map.get(name) is not None for name in ["A", "B", "C", "D"]) @@ -932,7 +932,7 @@ def test_build_agent_map_handles_handoff_objects(self): agent_b = Agent(name="AgentB") agent_a.handoffs = [handoff(agent_b)] - agent_map = _build_agent_map(agent_a) + agent_map, _ = _build_agent_map(agent_a) assert sorted(agent_map.keys()) == ["AgentA", "AgentB"] @@ -950,7 +950,7 @@ def __init__(self, target: Agent[Any]): agent_a.handoffs = [LegacyHandoff(agent_b)] - agent_map = _build_agent_map(agent_a) + agent_map, _ = _build_agent_map(agent_a) assert sorted(agent_map.keys()) == ["AgentA", "AgentB"] @@ -965,7 +965,7 @@ def __init__(self, target: Agent[Any]): agent_a.handoffs = [LegacyWrapper(agent_b)] # type: ignore[list-item] - agent_map = _build_agent_map(agent_a) + agent_map, _ = _build_agent_map(agent_a) assert sorted(agent_map.keys()) == ["AgentA", "AgentB"] @@ -986,7 +986,7 @@ async def _invoke_handoff(_ctx: RunContextWrapper[Any], _input: str) -> Agent[An ) agent_a.handoffs = [detached_handoff] - agent_map = _build_agent_map(agent_a) + agent_map, _ = _build_agent_map(agent_a) assert sorted(agent_map.keys()) == ["AgentA"] @@ -2965,7 +2965,7 @@ class AgentWithoutGetAllTools(Agent): # This should trigger line 759 (all_tools = []) result = await _deserialize_processed_response( - processed_response_data, agent_no_tools, context, {} + processed_response_data, agent_no_tools, context, {}, [] ) assert result is not None @@ -3002,8 +3002,9 @@ async def test_deserialize_processed_response_handoff_with_tool_name(self): } # This should trigger lines 778-782 and 787-796 + agent_map = {"AgentA": agent_a, "AgentB": agent_b} result = await _deserialize_processed_response( - processed_response_data, agent_a, context, {"AgentA": agent_a, "AgentB": agent_b} + processed_response_data, agent_a, context, agent_map, list(agent_map.values()) ) assert result is not None assert len(result.handoffs) == 1 @@ -3047,8 +3048,9 @@ async def tool_func(context: ToolContext[Any], arguments: str) -> str: } # This should trigger lines 801-808 + agent_map = {"TestAgent": agent} result = await _deserialize_processed_response( - processed_response_data, agent, context, {"TestAgent": agent} + processed_response_data, agent, context, agent_map, list(agent_map.values()) ) assert result is not None assert len(result.functions) == 1 @@ -3124,8 +3126,9 @@ def wait(self) -> None: } # This should trigger lines 815-824 + agent_map = {"TestAgent": agent} result = await _deserialize_processed_response( - processed_response_data, agent, context, {"TestAgent": agent} + processed_response_data, agent, context, agent_map, list(agent_map.values()) ) assert result is not None assert len(result.computer_actions) == 1 @@ -3165,8 +3168,9 @@ async def shell_executor(request: Any) -> Any: } # This should trigger the ValidationError path (lines 1299-1302) + agent_map = {"TestAgent": agent} result = await _deserialize_processed_response( - processed_response_data, agent, context, {"TestAgent": agent} + processed_response_data, agent, context, agent_map, list(agent_map.values()) ) assert result is not None # Should fall back to using tool_call_data directly when validation fails @@ -3218,8 +3222,9 @@ def delete_file(self, operation: Any) -> Any: } # This should trigger the Exception path (lines 1314-1317) + agent_map = {"TestAgent": agent} result = await _deserialize_processed_response( - processed_response_data, agent, context, {"TestAgent": agent} + processed_response_data, agent, context, agent_map, list(agent_map.values()) ) assert result is not None # Should fall back to using tool_call_data directly when deserialization fails @@ -3260,8 +3265,9 @@ async def test_deserialize_processed_response_local_shell_action_round_trip(self "interruptions": [], } + agent_map = {"TestAgent": agent} result = await _deserialize_processed_response( - processed_response_data, agent, context, {"TestAgent": agent} + processed_response_data, agent, context, agent_map, list(agent_map.values()) ) assert len(result.local_shell_calls) == 1 @@ -3310,8 +3316,9 @@ def __init__(self): } # This should trigger lines 831-852 + agent_map = {"TestAgent": agent} result = await _deserialize_processed_response( - processed_response_data, agent, context, {"TestAgent": agent} + processed_response_data, agent, context, agent_map, list(agent_map.values()) ) assert result is not None # The MCP approval request might not be deserialized if MockMCPTool isn't a HostedMCPTool, @@ -3635,6 +3642,7 @@ class AgentWithoutGetAllTools: agent, # type: ignore[arg-type] context, {}, + [], ) assert result is not None @@ -3666,7 +3674,9 @@ async def test_deserialize_processed_response_empty_mcp_tool_data(self): ], } - result = await _deserialize_processed_response(processed_response_data, agent, context, {}) + result = await _deserialize_processed_response( + processed_response_data, agent, context, {}, [] + ) # Should skip the empty mcp_tool_data and not add it to mcp_approval_requests assert len(result.mcp_approval_requests) == 0 diff --git a/tests/test_tool_origin.py b/tests/test_tool_origin.py new file mode 100644 index 0000000000..245f982491 --- /dev/null +++ b/tests/test_tool_origin.py @@ -0,0 +1,386 @@ +"""Tests for tool origin tracking feature.""" + +from __future__ import annotations + +import gc +import sys +import weakref +from typing import cast + +import pytest + +from agents import Agent, FunctionTool, RunContextWrapper, Runner, function_tool +from agents.items import ToolCallItem, ToolCallItemTypes, ToolCallOutputItem +from agents.tool import ToolOrigin, ToolOriginType + +from .fake_model import FakeModel +from .test_responses import get_function_tool_call, get_text_message + +if sys.version_info >= (3, 10): + from .mcp.helpers import FakeMCPServer + + +@pytest.mark.asyncio +async def test_function_tool_origin(): + """Test that regular function tools have FUNCTION origin.""" + model = FakeModel() + + @function_tool + def test_tool(x: int) -> str: + """Test tool.""" + return f"result: {x}" + + agent = Agent(name="test", model=model, tools=[test_tool]) + + model.add_multiple_turn_outputs( + [ + [get_function_tool_call("test_tool", '{"x": 42}')], + [get_text_message("done")], + ] + ) + + result = await Runner.run(agent, input="test") + tool_call_items = [item for item in result.new_items if isinstance(item, ToolCallItem)] + tool_output_items = [item for item in result.new_items if isinstance(item, ToolCallOutputItem)] + + assert len(tool_call_items) == 1 + assert tool_call_items[0].tool_origin is not None + assert tool_call_items[0].tool_origin.type == ToolOriginType.FUNCTION + assert tool_call_items[0].tool_origin.mcp_server is None + assert tool_call_items[0].tool_origin.agent_as_tool is None + + assert len(tool_output_items) == 1 + assert tool_output_items[0].tool_origin is not None + assert tool_output_items[0].tool_origin.type == ToolOriginType.FUNCTION + assert tool_output_items[0].tool_origin.mcp_server is None + assert tool_output_items[0].tool_origin.agent_as_tool is None + + +@pytest.mark.asyncio +@pytest.mark.skipif(sys.version_info < (3, 10), reason="MCP tests require Python 3.10+") +async def test_mcp_tool_origin(): + """Test that MCP tools have MCP origin with server name.""" + model = FakeModel() + server = FakeMCPServer(server_name="test_mcp_server") + server.add_tool("mcp_tool", {}) + + agent = Agent(name="test", model=model, mcp_servers=[server]) + + model.add_multiple_turn_outputs( + [ + [get_function_tool_call("mcp_tool", "")], + [get_text_message("done")], + ] + ) + + result = await Runner.run(agent, input="test") + tool_call_items = [item for item in result.new_items if isinstance(item, ToolCallItem)] + tool_output_items = [item for item in result.new_items if isinstance(item, ToolCallOutputItem)] + + assert len(tool_call_items) == 1 + assert tool_call_items[0].tool_origin is not None + assert tool_call_items[0].tool_origin.type == ToolOriginType.MCP + assert tool_call_items[0].tool_origin.mcp_server is not None + assert tool_call_items[0].tool_origin.mcp_server.name == "test_mcp_server" + assert tool_call_items[0].tool_origin.agent_as_tool is None + + assert len(tool_output_items) == 1 + assert tool_output_items[0].tool_origin is not None + assert tool_output_items[0].tool_origin.type == ToolOriginType.MCP + assert tool_output_items[0].tool_origin.mcp_server is not None + assert tool_output_items[0].tool_origin.mcp_server.name == "test_mcp_server" + assert tool_output_items[0].tool_origin.agent_as_tool is None + + +@pytest.mark.asyncio +async def test_agent_as_tool_origin(): + """Test that agent-as-tool has AGENT_AS_TOOL origin with agent name.""" + model = FakeModel() + nested_model = FakeModel() + + nested_agent = Agent( + name="nested_agent", + model=nested_model, + instructions="You are a nested agent.", + ) + + nested_model.add_multiple_turn_outputs([[get_text_message("nested response")]]) + + tool = nested_agent.as_tool( + tool_name="nested_tool", + tool_description="A nested agent tool", + ) + + orchestrator = Agent(name="orchestrator", model=model, tools=[tool]) + + model.add_multiple_turn_outputs( + [ + [get_function_tool_call("nested_tool", '{"input": "test"}')], + [get_text_message("done")], + ] + ) + + result = await Runner.run(orchestrator, input="test") + tool_call_items = [item for item in result.new_items if isinstance(item, ToolCallItem)] + tool_output_items = [item for item in result.new_items if isinstance(item, ToolCallOutputItem)] + + assert len(tool_call_items) == 1 + assert tool_call_items[0].tool_origin is not None + assert tool_call_items[0].tool_origin.type == ToolOriginType.AGENT_AS_TOOL + assert tool_call_items[0].tool_origin.mcp_server is None + assert tool_call_items[0].tool_origin.agent_as_tool is not None + assert tool_call_items[0].tool_origin.agent_as_tool.name == "nested_agent" + + assert len(tool_output_items) == 1 + assert tool_output_items[0].tool_origin is not None + assert tool_output_items[0].tool_origin.type == ToolOriginType.AGENT_AS_TOOL + assert tool_output_items[0].tool_origin.mcp_server is None + assert tool_output_items[0].tool_origin.agent_as_tool is not None + assert tool_output_items[0].tool_origin.agent_as_tool.name == "nested_agent" + + +@pytest.mark.asyncio +@pytest.mark.skipif(sys.version_info < (3, 10), reason="MCP tests require Python 3.10+") +async def test_multiple_tool_origins(): + """Test that multiple tools from different origins work together.""" + model = FakeModel() + nested_model = FakeModel() + + @function_tool + def func_tool(x: int) -> str: + """Function tool.""" + return f"function: {x}" + + mcp_server = FakeMCPServer(server_name="mcp_server") + mcp_server.add_tool("mcp_tool", {}) + + nested_agent = Agent(name="nested", model=nested_model, instructions="Nested agent") + nested_model.add_multiple_turn_outputs([[get_text_message("nested response")]]) + agent_tool = nested_agent.as_tool(tool_name="agent_tool", tool_description="Agent tool") + + agent = Agent( + name="test", + model=model, + tools=[func_tool, agent_tool], + mcp_servers=[mcp_server], + ) + + model.add_multiple_turn_outputs( + [ + [ + get_function_tool_call("func_tool", '{"x": 1}'), + get_function_tool_call("mcp_tool", ""), + get_function_tool_call("agent_tool", '{"input": "test"}'), + ], + [get_text_message("done")], + ] + ) + + result = await Runner.run(agent, input="test") + tool_call_items = [item for item in result.new_items if isinstance(item, ToolCallItem)] + tool_output_items = [item for item in result.new_items if isinstance(item, ToolCallOutputItem)] + + assert len(tool_call_items) == 3 + assert len(tool_output_items) == 3 + + # Find items by tool name + function_item = next( + item for item in tool_call_items if getattr(item.raw_item, "name", None) == "func_tool" + ) + mcp_item = next( + item for item in tool_call_items if getattr(item.raw_item, "name", None) == "mcp_tool" + ) + agent_item = next( + item for item in tool_call_items if getattr(item.raw_item, "name", None) == "agent_tool" + ) + + assert function_item.tool_origin is not None + assert function_item.tool_origin.type == ToolOriginType.FUNCTION + assert mcp_item.tool_origin is not None + assert mcp_item.tool_origin.type == ToolOriginType.MCP + assert mcp_item.tool_origin.mcp_server is not None + assert mcp_item.tool_origin.mcp_server.name == "mcp_server" + assert agent_item.tool_origin is not None + assert agent_item.tool_origin.type == ToolOriginType.AGENT_AS_TOOL + assert agent_item.tool_origin.agent_as_tool is not None + assert agent_item.tool_origin.agent_as_tool.name == "nested" + + +@pytest.mark.asyncio +@pytest.mark.skipif(sys.version_info < (3, 10), reason="MCP tests require Python 3.10+") +async def test_tool_origin_streaming(): + """Test that tool origin is populated correctly in streaming scenarios.""" + model = FakeModel() + server = FakeMCPServer(server_name="streaming_server") + server.add_tool("streaming_tool", {}) + + agent = Agent(name="test", model=model, mcp_servers=[server]) + + model.add_multiple_turn_outputs( + [ + [get_function_tool_call("streaming_tool", "")], + [get_text_message("done")], + ] + ) + + result = Runner.run_streamed(agent, input="test") + tool_call_items = [] + tool_output_items = [] + + async for event in result.stream_events(): + if event.type == "run_item_stream_event": + if isinstance(event.item, ToolCallItem): + tool_call_items.append(event.item) + elif isinstance(event.item, ToolCallOutputItem): + tool_output_items.append(event.item) + + assert len(tool_call_items) == 1 + assert tool_call_items[0].tool_origin is not None + assert tool_call_items[0].tool_origin.type == ToolOriginType.MCP + assert tool_call_items[0].tool_origin.mcp_server is not None + assert tool_call_items[0].tool_origin.mcp_server.name == "streaming_server" + + assert len(tool_output_items) == 1 + assert tool_output_items[0].tool_origin is not None + assert tool_output_items[0].tool_origin.type == ToolOriginType.MCP + assert tool_output_items[0].tool_origin.mcp_server is not None + assert tool_output_items[0].tool_origin.mcp_server.name == "streaming_server" + + +@pytest.mark.asyncio +async def test_tool_origin_repr(): + """Test that ToolOrigin repr only shows relevant fields.""" + # FUNCTION origin + function_origin = ToolOrigin(type=ToolOriginType.FUNCTION) + assert "mcp_server_name" not in repr(function_origin) + assert "agent_as_tool_name" not in repr(function_origin) + + # MCP origin + if sys.version_info >= (3, 10): + from .mcp.helpers import FakeMCPServer + + test_server = FakeMCPServer(server_name="test_server") + mcp_origin = ToolOrigin(type=ToolOriginType.MCP, mcp_server=test_server) + assert "mcp_server_name='test_server'" in repr(mcp_origin) + assert "agent_as_tool_name" not in repr(mcp_origin) + + # AGENT_AS_TOOL origin + model = FakeModel() + test_agent = Agent(name="test_agent", model=model, instructions="Test agent") + agent_origin = ToolOrigin(type=ToolOriginType.AGENT_AS_TOOL, agent_as_tool=test_agent) + assert "agent_as_tool_name='test_agent'" in repr(agent_origin) + assert "mcp_server_name" not in repr(agent_origin) + + +@pytest.mark.asyncio +async def test_tool_origin_defaults_to_function(): + """Test that tools without explicit origin default to FUNCTION.""" + model = FakeModel() + + # Create a FunctionTool directly without using @function_tool decorator + async def test_func(ctx: RunContextWrapper, args: str) -> str: + return "result" + + tool = FunctionTool( + name="direct_tool", + description="Direct tool", + params_json_schema={"type": "object", "properties": {}}, + on_invoke_tool=test_func, + ) + + agent = Agent(name="test", model=model, tools=[tool]) + + model.add_multiple_turn_outputs( + [ + [get_function_tool_call("direct_tool", "")], + [get_text_message("done")], + ] + ) + + result = await Runner.run(agent, input="test") + tool_call_items = [item for item in result.new_items if isinstance(item, ToolCallItem)] + + assert len(tool_call_items) == 1 + # Even though _tool_origin is None, _get_tool_origin_info defaults to FUNCTION + assert tool_call_items[0].tool_origin is not None + assert tool_call_items[0].tool_origin.type == ToolOriginType.FUNCTION + + +@pytest.mark.asyncio +async def test_non_function_tool_items_have_no_origin(): + """Test that non-FunctionTool items (computer, shell, etc.) don't have tool_origin.""" + model = FakeModel() + + @function_tool + def func_tool() -> str: + """Function tool.""" + return "result" + + agent = Agent(name="test", model=model, tools=[func_tool]) + + # Create a ToolCallItem for a non-function tool (simulating computer/shell tool) + computer_call = { + "type": "computer_use_preview", + "call_id": "call_123", + "actions": [], + } + + # This simulates what happens for non-FunctionTool items + # They should not have tool_origin set + item = ToolCallItem( + raw_item=cast(ToolCallItemTypes, computer_call), + agent=agent, + ) + + assert item.tool_origin is None + + +def test_tool_origin_release_agent_clears_strong_reference(): + """Test that release_agent() clears strong reference to agent_as_tool.""" + # Create a ToolOrigin with an agent_as_tool + nested_agent = Agent( + name="nested_agent", + model=FakeModel(), + instructions="You are a nested agent.", + ) + + tool_origin = ToolOrigin( + type=ToolOriginType.AGENT_AS_TOOL, + agent_as_tool=nested_agent, + ) + + # Create a ToolCallItem with this tool_origin + tool_call_item = ToolCallItem( + raw_item=cast( + ToolCallItemTypes, + { + "type": "function_call", + "name": "test_tool", + "call_id": "call_123", + "arguments": "{}", + }, + ), + agent=nested_agent, + tool_origin=tool_origin, + ) + + # Verify agent_as_tool is set + assert tool_call_item.tool_origin is not None + assert tool_call_item.tool_origin.agent_as_tool is nested_agent + + # Create weak reference to verify GC behavior + nested_agent_ref = weakref.ref(nested_agent) + + # Release agent - this should clear strong reference in tool_origin + tool_call_item.release_agent() + + # After release, agent_as_tool should still be accessible via weakref + assert tool_call_item.tool_origin.agent_as_tool is nested_agent + + # Delete the agent and force GC + del nested_agent + gc.collect() + + # After GC, agent_as_tool should be None since strong refs were cleared + assert nested_agent_ref() is None + assert tool_call_item.tool_origin.agent_as_tool is None diff --git a/tests/test_tool_origin_output_schema.py b/tests/test_tool_origin_output_schema.py new file mode 100644 index 0000000000..8556c9e9f3 --- /dev/null +++ b/tests/test_tool_origin_output_schema.py @@ -0,0 +1,105 @@ +"""Tests for tool_origin with output_schema json_tool_call.""" + +from __future__ import annotations + +import pytest +from pydantic import BaseModel + +from agents import Agent, Runner +from agents.agent_output import AgentOutputSchema +from agents.items import ModelResponse, ToolCallItem +from agents.run_internal.turn_resolution import process_model_response +from agents.tool import ToolOriginType +from agents.usage import Usage + +from .fake_model import FakeModel +from .test_responses import get_final_output_message, get_function_tool_call + + +class OutputSchema(BaseModel): + """Test output schema.""" + + result: str + + +def test_output_schema_json_tool_call_has_tool_origin(): + """Test that json_tool_call ToolCallItem has tool_origin when output_schema is enabled.""" + agent = Agent(name="test", output_type=OutputSchema) + + # Get the output_schema + from agents.run_internal.run_loop import get_output_schema + + output_schema = get_output_schema(agent) + assert output_schema is not None + assert isinstance(output_schema, AgentOutputSchema) + + # Simulate a json_tool_call response + json_output = OutputSchema(result="test").model_dump_json() + json_tool_call = get_function_tool_call("json_tool_call", json_output) + + response = ModelResponse( + output=[json_tool_call], + usage=Usage(), + response_id=None, + ) + + # Process the response + processed = process_model_response( + agent=agent, + all_tools=[], + response=response, + output_schema=output_schema, + handoffs=[], + ) + + # Find the json_tool_call item + json_tool_call_item = next( + item + for item in processed.new_items + if isinstance(item, ToolCallItem) + and hasattr(item.raw_item, "name") + and item.raw_item.name == "json_tool_call" + ) + + # Verify tool_origin is set on ToolCallItem + assert json_tool_call_item.tool_origin is not None + assert json_tool_call_item.tool_origin.type == ToolOriginType.FUNCTION + + # Verify that a ToolRunFunction was created for execution + assert len(processed.functions) == 1 + function_run = processed.functions[0] + assert function_run.function_tool.name == "json_tool_call" + + +@pytest.mark.asyncio +async def test_output_schema_json_tool_call_streaming_has_tool_origin(): + """ + Test that streamed json_tool_call ToolCallItem has tool_origin when output_schema is enabled. + """ + model = FakeModel() + agent = Agent(name="test", model=model, output_type=OutputSchema) + + # Simulate a json_tool_call response followed by completion + json_output = OutputSchema(result="test").model_dump_json() + json_tool_call = get_function_tool_call("json_tool_call", json_output) + final_output = get_final_output_message(json_output) + model.add_multiple_turn_outputs([[json_tool_call], [final_output]]) + + # Collect streamed events + streamed_tool_call_items: list[ToolCallItem] = [] + + result = Runner.run_streamed(agent, input="test") + async for event in result.stream_events(): + if event.type == "run_item_stream_event" and isinstance(event.item, ToolCallItem): + streamed_tool_call_items.append(event.item) + + # Find the json_tool_call item + json_tool_call_item = next( + item + for item in streamed_tool_call_items + if hasattr(item.raw_item, "name") and item.raw_item.name == "json_tool_call" + ) + + # Verify tool_origin is set on streamed ToolCallItem + assert json_tool_call_item.tool_origin is not None + assert json_tool_call_item.tool_origin.type == ToolOriginType.FUNCTION diff --git a/tests/test_tool_origin_rejection.py b/tests/test_tool_origin_rejection.py new file mode 100644 index 0000000000..8582e03fdb --- /dev/null +++ b/tests/test_tool_origin_rejection.py @@ -0,0 +1,205 @@ +"""Tests for tool_origin preservation on rejected function tool calls.""" + +from __future__ import annotations + +import sys + +import pytest + +from agents import Agent, function_tool +from agents.items import ToolCallOutputItem +from agents.tool import ToolOriginType + +from .fake_model import FakeModel +from .test_responses import get_function_tool_call, get_text_message +from .utils.hitl import reject_tool_call + +if sys.version_info >= (3, 10): + from .mcp.helpers import FakeMCPServer + + +@pytest.mark.asyncio +async def test_rejected_function_tool_preserves_tool_origin(): + """Test that rejected function tools preserve tool_origin.""" + model = FakeModel() + + @function_tool + def test_tool(x: int) -> str: + """Test tool.""" + return f"result: {x}" + + agent = Agent(name="test", model=model, tools=[test_tool]) + + model.add_multiple_turn_outputs( + [ + [get_function_tool_call("test_tool", '{"x": 42}')], + [get_text_message("done")], + ] + ) + + # Pre-reject the tool call + tool_call = get_function_tool_call("test_tool", '{"x": 42}') + from openai.types.responses import ResponseFunctionToolCall + + from agents.lifecycle import RunHooks + from agents.run_config import RunConfig + from agents.run_context import RunContextWrapper + from agents.run_internal.run_steps import ToolRunFunction + from agents.run_internal.tool_execution import execute_function_tool_calls + + context: RunContextWrapper[dict[str, str]] = RunContextWrapper(context={}) + assert isinstance(tool_call, ResponseFunctionToolCall) + reject_tool_call(context, agent, tool_call, "test_tool") + + # Execute the tool call which should be rejected + tool_run = ToolRunFunction(tool_call=tool_call, function_tool=test_tool) + results, _, _ = await execute_function_tool_calls( + agent=agent, + tool_runs=[tool_run], + hooks=RunHooks(), + context_wrapper=context, + config=RunConfig(), + ) + + # Should have a rejection result + assert len(results) == 1 + result = results[0] + assert result.run_item is not None + assert isinstance(result.run_item, ToolCallOutputItem) + + # Verify tool_origin is preserved on rejection + assert result.run_item.tool_origin is not None + assert result.run_item.tool_origin.type == ToolOriginType.FUNCTION + + +@pytest.mark.asyncio +async def test_rejected_agent_as_tool_preserves_tool_origin(): + """Test that rejected agent-as-tool preserves tool_origin.""" + model = FakeModel() + nested_model = FakeModel() + + nested_agent = Agent( + name="nested_agent", + model=nested_model, + instructions="You are a nested agent.", + ) + + nested_model.add_multiple_turn_outputs([[get_text_message("nested response")]]) + + tool = nested_agent.as_tool( + tool_name="nested_tool", + tool_description="A nested agent tool", + ) + + orchestrator = Agent(name="orchestrator", model=model, tools=[tool]) + + model.add_multiple_turn_outputs( + [ + [get_function_tool_call("nested_tool", '{"input": "test"}')], + [get_text_message("done")], + ] + ) + + # Pre-reject the tool call + tool_call = get_function_tool_call("nested_tool", '{"input": "test"}') + from openai.types.responses import ResponseFunctionToolCall + + from agents.lifecycle import RunHooks + from agents.run_config import RunConfig + from agents.run_context import RunContextWrapper + from agents.run_internal.run_steps import ToolRunFunction + from agents.run_internal.tool_execution import execute_function_tool_calls + from agents.tool import FunctionTool + + context: RunContextWrapper[dict[str, str]] = RunContextWrapper(context={}) + assert isinstance(tool_call, ResponseFunctionToolCall) + assert isinstance(tool, FunctionTool) + reject_tool_call(context, orchestrator, tool_call, "nested_tool") + + # Execute the tool call which should be rejected + tool_run = ToolRunFunction(tool_call=tool_call, function_tool=tool) + results, _, _ = await execute_function_tool_calls( + agent=orchestrator, + tool_runs=[tool_run], + hooks=RunHooks(), + context_wrapper=context, + config=RunConfig(), + ) + + # Should have a rejection result + assert len(results) == 1 + result = results[0] + assert result.run_item is not None + assert isinstance(result.run_item, ToolCallOutputItem) + + # Verify tool_origin is preserved on rejection + assert result.run_item.tool_origin is not None + assert result.run_item.tool_origin.type == ToolOriginType.AGENT_AS_TOOL + assert result.run_item.tool_origin.agent_as_tool is not None + assert result.run_item.tool_origin.agent_as_tool.name == "nested_agent" + + +@pytest.mark.asyncio +@pytest.mark.skipif(sys.version_info < (3, 10), reason="MCP tests require Python 3.10+") +async def test_rejected_mcp_tool_preserves_tool_origin(): + """Test that rejected MCP tools preserve tool_origin.""" + model = FakeModel() + server = FakeMCPServer(server_name="test_mcp_server") + server.add_tool("mcp_tool", {}) + + agent = Agent(name="test", model=model, mcp_servers=[server]) + + model.add_multiple_turn_outputs( + [ + [get_function_tool_call("mcp_tool", "")], + [get_text_message("done")], + ] + ) + + # Pre-reject the tool call + tool_call = get_function_tool_call("mcp_tool", "") + from openai.types.responses import ResponseFunctionToolCall + + from agents.lifecycle import RunHooks + from agents.mcp import MCPUtil + from agents.run_config import RunConfig + from agents.run_context import RunContextWrapper + from agents.run_internal.run_steps import ToolRunFunction + from agents.run_internal.tool_execution import execute_function_tool_calls + from agents.tool import FunctionTool + + context: RunContextWrapper[dict[str, str]] = RunContextWrapper(context={}) + assert isinstance(tool_call, ResponseFunctionToolCall) + reject_tool_call(context, agent, tool_call, "mcp_tool") + + # Get the MCP tool as FunctionTool + mcp_tools = await MCPUtil.get_all_function_tools( + agent.mcp_servers, + convert_schemas_to_strict=False, + run_context=context, + agent=agent, + ) + mcp_tool = next(tool for tool in mcp_tools if tool.name == "mcp_tool") + assert isinstance(mcp_tool, FunctionTool) + + # Execute the tool call which should be rejected + tool_run = ToolRunFunction(tool_call=tool_call, function_tool=mcp_tool) + results, _, _ = await execute_function_tool_calls( + agent=agent, + tool_runs=[tool_run], + hooks=RunHooks(), + context_wrapper=context, + config=RunConfig(), + ) + + # Should have a rejection result + assert len(results) == 1 + result = results[0] + assert result.run_item is not None + assert isinstance(result.run_item, ToolCallOutputItem) + + # Verify tool_origin is preserved on rejection + assert result.run_item.tool_origin is not None + assert result.run_item.tool_origin.type == ToolOriginType.MCP + assert result.run_item.tool_origin.mcp_server is not None + assert result.run_item.tool_origin.mcp_server.name == "test_mcp_server" diff --git a/tests/test_tool_origin_serialization.py b/tests/test_tool_origin_serialization.py new file mode 100644 index 0000000000..87bca9fcc4 --- /dev/null +++ b/tests/test_tool_origin_serialization.py @@ -0,0 +1,228 @@ +"""Tests for tool_origin serialization in RunState.""" + +from __future__ import annotations + +import sys + +import pytest + +from agents import Agent, Runner, function_tool +from agents.items import ToolCallItem, ToolCallOutputItem +from agents.run_state import RunState +from agents.tool import ToolOriginType + +from .fake_model import FakeModel +from .test_responses import get_function_tool_call, get_text_message + +if sys.version_info >= (3, 10): + from .mcp.helpers import FakeMCPServer + + +@pytest.mark.asyncio +async def test_serialize_tool_origin_function(): + """Test that FUNCTION tool_origin is serialized and deserialized.""" + model = FakeModel() + + @function_tool + def test_tool(x: int) -> str: + """Test tool.""" + return f"result: {x}" + + agent = Agent(name="test", model=model, tools=[test_tool]) + + model.add_multiple_turn_outputs( + [ + [get_function_tool_call("test_tool", '{"x": 42}')], + [get_text_message("done")], + ] + ) + + result = await Runner.run(agent, input="test") + tool_call_items = [item for item in result.new_items if isinstance(item, ToolCallItem)] + tool_output_items = [item for item in result.new_items if isinstance(item, ToolCallOutputItem)] + + assert len(tool_call_items) == 1 + assert len(tool_output_items) == 1 + + tool_call_item = tool_call_items[0] + tool_output_item = tool_output_items[0] + + # Verify tool_origin is set + assert tool_call_item.tool_origin is not None + assert tool_call_item.tool_origin.type == ToolOriginType.FUNCTION + assert tool_output_item.tool_origin is not None + assert tool_output_item.tool_origin.type == ToolOriginType.FUNCTION + + # Serialize and deserialize + context = result.context_wrapper + state = RunState( + context=context, + original_input="test", + starting_agent=agent, + max_turns=5, + ) + state._generated_items = [tool_call_item, tool_output_item] + + json_data = state.to_json() + deserialized_state = await RunState.from_json(agent, json_data) + + # Verify tool_origin was preserved + deserialized_tool_call = next( + item for item in deserialized_state._generated_items if isinstance(item, ToolCallItem) + ) + deserialized_tool_output = next( + item for item in deserialized_state._generated_items if isinstance(item, ToolCallOutputItem) + ) + + assert deserialized_tool_call.tool_origin is not None + assert deserialized_tool_call.tool_origin.type == ToolOriginType.FUNCTION + assert deserialized_tool_output.tool_origin is not None + assert deserialized_tool_output.tool_origin.type == ToolOriginType.FUNCTION + + +@pytest.mark.asyncio +async def test_serialize_tool_origin_agent_as_tool(): + """Test that AGENT_AS_TOOL tool_origin is serialized and deserialized.""" + model = FakeModel() + nested_model = FakeModel() + + nested_agent = Agent( + name="nested_agent", + model=nested_model, + instructions="You are a nested agent.", + ) + + nested_model.add_multiple_turn_outputs([[get_text_message("nested response")]]) + + tool = nested_agent.as_tool( + tool_name="nested_tool", + tool_description="A nested agent tool", + ) + + orchestrator = Agent(name="orchestrator", model=model, tools=[tool]) + + model.add_multiple_turn_outputs( + [ + [get_function_tool_call("nested_tool", '{"input": "test"}')], + [get_text_message("done")], + ] + ) + + result = await Runner.run(orchestrator, input="test") + tool_call_items = [item for item in result.new_items if isinstance(item, ToolCallItem)] + tool_output_items = [item for item in result.new_items if isinstance(item, ToolCallOutputItem)] + + assert len(tool_call_items) == 1 + assert len(tool_output_items) == 1 + + tool_call_item = tool_call_items[0] + tool_output_item = tool_output_items[0] + + # Verify tool_origin is set + assert tool_call_item.tool_origin is not None + assert tool_call_item.tool_origin.type == ToolOriginType.AGENT_AS_TOOL + assert tool_call_item.tool_origin.agent_as_tool is not None + assert tool_call_item.tool_origin.agent_as_tool.name == "nested_agent" + assert tool_output_item.tool_origin is not None + assert tool_output_item.tool_origin.type == ToolOriginType.AGENT_AS_TOOL + assert tool_output_item.tool_origin.agent_as_tool is not None + assert tool_output_item.tool_origin.agent_as_tool.name == "nested_agent" + + # Serialize and deserialize + context = result.context_wrapper + state = RunState( + context=context, + original_input="test", + starting_agent=orchestrator, + max_turns=5, + ) + state._generated_items = [tool_call_item, tool_output_item] + + json_data = state.to_json() + deserialized_state = await RunState.from_json(orchestrator, json_data) + + # Verify tool_origin was preserved + deserialized_tool_call = next( + item for item in deserialized_state._generated_items if isinstance(item, ToolCallItem) + ) + deserialized_tool_output = next( + item for item in deserialized_state._generated_items if isinstance(item, ToolCallOutputItem) + ) + + assert deserialized_tool_call.tool_origin is not None + assert deserialized_tool_call.tool_origin.type == ToolOriginType.AGENT_AS_TOOL + assert deserialized_tool_call.tool_origin.agent_as_tool is not None + assert deserialized_tool_call.tool_origin.agent_as_tool.name == "nested_agent" + assert deserialized_tool_output.tool_origin is not None + assert deserialized_tool_output.tool_origin.type == ToolOriginType.AGENT_AS_TOOL + assert deserialized_tool_output.tool_origin.agent_as_tool is not None + assert deserialized_tool_output.tool_origin.agent_as_tool.name == "nested_agent" + + +@pytest.mark.asyncio +@pytest.mark.skipif(sys.version_info < (3, 10), reason="MCP tests require Python 3.10+") +async def test_serialize_tool_origin_mcp(): + """Test that MCP tool_origin is serialized and deserialized.""" + model = FakeModel() + server = FakeMCPServer(server_name="test_mcp_server") + server.add_tool("mcp_tool", {}) + + agent = Agent(name="test", model=model, mcp_servers=[server]) + + model.add_multiple_turn_outputs( + [ + [get_function_tool_call("mcp_tool", "")], + [get_text_message("done")], + ] + ) + + result = await Runner.run(agent, input="test") + tool_call_items = [item for item in result.new_items if isinstance(item, ToolCallItem)] + tool_output_items = [item for item in result.new_items if isinstance(item, ToolCallOutputItem)] + + assert len(tool_call_items) == 1 + assert len(tool_output_items) == 1 + + tool_call_item = tool_call_items[0] + tool_output_item = tool_output_items[0] + + # Verify tool_origin is set + assert tool_call_item.tool_origin is not None + assert tool_call_item.tool_origin.type == ToolOriginType.MCP + assert tool_call_item.tool_origin.mcp_server is not None + assert tool_call_item.tool_origin.mcp_server.name == "test_mcp_server" + assert tool_output_item.tool_origin is not None + assert tool_output_item.tool_origin.type == ToolOriginType.MCP + assert tool_output_item.tool_origin.mcp_server is not None + assert tool_output_item.tool_origin.mcp_server.name == "test_mcp_server" + + # Serialize and deserialize + context = result.context_wrapper + state = RunState( + context=context, + original_input="test", + starting_agent=agent, + max_turns=5, + ) + state._generated_items = [tool_call_item, tool_output_item] + + json_data = state.to_json() + deserialized_state = await RunState.from_json(agent, json_data) + + # Verify tool_origin was preserved + deserialized_tool_call = next( + item for item in deserialized_state._generated_items if isinstance(item, ToolCallItem) + ) + deserialized_tool_output = next( + item for item in deserialized_state._generated_items if isinstance(item, ToolCallOutputItem) + ) + + assert deserialized_tool_call.tool_origin is not None + assert deserialized_tool_call.tool_origin.type == ToolOriginType.MCP + # MCP server should be reconstructed from agent's mcp_servers + assert deserialized_tool_call.tool_origin.mcp_server is not None + assert deserialized_tool_call.tool_origin.mcp_server.name == "test_mcp_server" + assert deserialized_tool_output.tool_origin is not None + assert deserialized_tool_output.tool_origin.type == ToolOriginType.MCP + assert deserialized_tool_output.tool_origin.mcp_server is not None + assert deserialized_tool_output.tool_origin.mcp_server.name == "test_mcp_server"