diff --git a/src/agents/lifecycle.py b/src/agents/lifecycle.py index 2ca7484739..125853693d 100644 --- a/src/agents/lifecycle.py +++ b/src/agents/lifecycle.py @@ -75,10 +75,9 @@ async def on_tool_start( ) -> None: """Called immediately before a local tool is invoked. - For function-tool invocations, ``context`` is typically a ``ToolContext`` instance, + Local tool invocations with a concrete tool call receive a ``ToolContext`` instance, which exposes tool-call-specific metadata such as ``tool_call_id``, ``tool_name``, - and ``tool_arguments``. Other local tool families may provide a plain - ``RunContextWrapper`` instead. + and ``tool_arguments``. """ pass @@ -91,10 +90,9 @@ async def on_tool_end( ) -> None: """Called immediately after a local tool is invoked. - For function-tool invocations, ``context`` is typically a ``ToolContext`` instance, + Local tool invocations with a concrete tool call receive a ``ToolContext`` instance, which exposes tool-call-specific metadata such as ``tool_call_id``, ``tool_name``, - and ``tool_arguments``. Other local tool families may provide a plain - ``RunContextWrapper`` instead. + and ``tool_arguments``. """ pass @@ -149,10 +147,9 @@ async def on_tool_start( ) -> None: """Called immediately before a local tool is invoked. - For function-tool invocations, ``context`` is typically a ``ToolContext`` instance, + Local tool invocations with a concrete tool call receive a ``ToolContext`` instance, which exposes tool-call-specific metadata such as ``tool_call_id``, ``tool_name``, - and ``tool_arguments``. Other local tool families may provide a plain - ``RunContextWrapper`` instead. + and ``tool_arguments``. """ pass @@ -165,10 +162,9 @@ async def on_tool_end( ) -> None: """Called immediately after a local tool is invoked. - For function-tool invocations, ``context`` is typically a ``ToolContext`` instance, + Local tool invocations with a concrete tool call receive a ``ToolContext`` instance, which exposes tool-call-specific metadata such as ``tool_call_id``, ``tool_name``, - and ``tool_arguments``. Other local tool families may provide a plain - ``RunContextWrapper`` instead. + and ``tool_arguments``. """ pass diff --git a/src/agents/run_internal/tool_actions.py b/src/agents/run_internal/tool_actions.py index 310fdc2592..90ed8072e7 100644 --- a/src/agents/run_internal/tool_actions.py +++ b/src/agents/run_internal/tool_actions.py @@ -90,6 +90,26 @@ def _serialize_trace_payload(payload: Any) -> str: return str(payload) +def _local_tool_context( + context_wrapper: RunContextWrapper[Any], + *, + tool_name: str, + tool_call_id: str, + tool_arguments: Any, + agent: Agent[Any], + config: RunConfig, +) -> ToolContext: + """Build hook context for local tool calls that are not function tools.""" + return ToolContext.from_agent_context( + context_wrapper, + tool_call_id, + tool_name=tool_name, + tool_arguments=_serialize_trace_payload(tool_arguments), + agent=agent, + run_config=config, + ) + + class ComputerAction: """Execute computer tool actions and emit screenshot outputs with hooks fired.""" @@ -119,11 +139,19 @@ async def _run_action(span: Any | None) -> RunItem: computer = await resolve_computer( tool=action.computer_tool, run_context=context_wrapper ) + tool_context = _local_tool_context( + context_wrapper, + tool_name=action.computer_tool.name, + tool_call_id=action.tool_call.call_id, + tool_arguments=cls._get_trace_input_payload(action.tool_call), + agent=agent, + config=config, + ) agent_hooks = agent.hooks await asyncio.gather( - hooks.on_tool_start(context_wrapper, agent, action.computer_tool), + hooks.on_tool_start(tool_context, agent, action.computer_tool), ( - agent_hooks.on_tool_start(context_wrapper, agent, action.computer_tool) + agent_hooks.on_tool_start(tool_context, agent, action.computer_tool) if agent_hooks else _coro.noop_coroutine() ), @@ -151,9 +179,9 @@ async def _run_action(span: Any | None) -> RunItem: raise await asyncio.gather( - hooks.on_tool_end(context_wrapper, agent, action.computer_tool, output), + hooks.on_tool_end(tool_context, agent, action.computer_tool, output), ( - agent_hooks.on_tool_end(context_wrapper, agent, action.computer_tool, output) + agent_hooks.on_tool_end(tool_context, agent, action.computer_tool, output) if agent_hooks else _coro.noop_coroutine() ), @@ -374,10 +402,18 @@ async def execute( ) -> RunItem: """Run a local shell tool call and wrap the result as a ToolCallOutputItem.""" agent_hooks = agent.hooks + tool_context = _local_tool_context( + context_wrapper, + tool_name=call.local_shell_tool.name, + tool_call_id=call.tool_call.call_id, + tool_arguments=call.tool_call.action, + agent=agent, + config=config, + ) await asyncio.gather( - hooks.on_tool_start(context_wrapper, agent, call.local_shell_tool), + hooks.on_tool_start(tool_context, agent, call.local_shell_tool), ( - agent_hooks.on_tool_start(context_wrapper, agent, call.local_shell_tool) + agent_hooks.on_tool_start(tool_context, agent, call.local_shell_tool) if agent_hooks else _coro.noop_coroutine() ), @@ -391,9 +427,9 @@ async def execute( result = await output if inspect.isawaitable(output) else output await asyncio.gather( - hooks.on_tool_end(context_wrapper, agent, call.local_shell_tool, result), + hooks.on_tool_end(tool_context, agent, call.local_shell_tool, result), ( - agent_hooks.on_tool_end(context_wrapper, agent, call.local_shell_tool, result) + agent_hooks.on_tool_end(tool_context, agent, call.local_shell_tool, result) if agent_hooks else _coro.noop_coroutine() ), @@ -428,6 +464,14 @@ async def execute( shell_call = coerce_shell_call(call.tool_call) shell_tool = call.shell_tool agent_hooks = agent.hooks + tool_context = _local_tool_context( + context_wrapper, + tool_name=shell_tool.name, + tool_call_id=shell_call.call_id, + tool_arguments=dataclasses.asdict(shell_call.action), + agent=agent, + config=config, + ) async def _run_call(span: Any | None) -> RunItem: if span and config.trace_include_sensitive_data: @@ -467,9 +511,9 @@ async def _run_call(span: Any | None) -> RunItem: return approval_item await asyncio.gather( - hooks.on_tool_start(context_wrapper, agent, shell_tool), + hooks.on_tool_start(tool_context, agent, shell_tool), ( - agent_hooks.on_tool_start(context_wrapper, agent, shell_tool) + agent_hooks.on_tool_start(tool_context, agent, shell_tool) if agent_hooks else _coro.noop_coroutine() ), @@ -541,9 +585,9 @@ async def _run_call(span: Any | None) -> RunItem: logger.error("Shell executor failed: %s", exc, exc_info=True) await asyncio.gather( - hooks.on_tool_end(context_wrapper, agent, call.shell_tool, output_text), + hooks.on_tool_end(tool_context, agent, call.shell_tool, output_text), ( - agent_hooks.on_tool_end(context_wrapper, agent, call.shell_tool, output_text) + agent_hooks.on_tool_end(tool_context, agent, call.shell_tool, output_text) if agent_hooks else _coro.noop_coroutine() ), @@ -747,6 +791,21 @@ async def execute( context_wrapper=context_wrapper, ) call_id = extract_apply_patch_call_id(call.tool_call) + tool_context = _local_tool_context( + context_wrapper, + tool_name=apply_patch_tool.name, + tool_call_id=call_id, + tool_arguments=[ + { + "type": operation.type, + "path": operation.path, + "diff": operation.diff, + } + for operation in operations + ], + agent=agent, + config=config, + ) async def _run_call(span: Any | None) -> RunItem: if span and config.trace_include_sensitive_data: @@ -798,9 +857,9 @@ async def _run_call(span: Any | None) -> RunItem: return approval_item await asyncio.gather( - hooks.on_tool_start(context_wrapper, agent, apply_patch_tool), + hooks.on_tool_start(tool_context, agent, apply_patch_tool), ( - agent_hooks.on_tool_start(context_wrapper, agent, apply_patch_tool) + agent_hooks.on_tool_start(tool_context, agent, apply_patch_tool) if agent_hooks else _coro.noop_coroutine() ), @@ -854,9 +913,9 @@ async def _run_call(span: Any | None) -> RunItem: logger.error("Apply patch editor failed: %s", exc, exc_info=True) await asyncio.gather( - hooks.on_tool_end(context_wrapper, agent, apply_patch_tool, output_text), + hooks.on_tool_end(tool_context, agent, apply_patch_tool, output_text), ( - agent_hooks.on_tool_end(context_wrapper, agent, apply_patch_tool, output_text) + agent_hooks.on_tool_end(tool_context, agent, apply_patch_tool, output_text) if agent_hooks else _coro.noop_coroutine() ), diff --git a/tests/test_apply_patch_tool.py b/tests/test_apply_patch_tool.py index 1e66312c4a..586648824b 100644 --- a/tests/test_apply_patch_tool.py +++ b/tests/test_apply_patch_tool.py @@ -18,6 +18,7 @@ from agents.editor import ApplyPatchOperation, ApplyPatchResult from agents.items import ToolApprovalItem, ToolCallOutputItem from agents.run_internal.run_loop import ApplyPatchAction, ToolRunApplyPatchCall +from agents.tool_context import ToolContext from .testing_processor import SPAN_PROCESSOR_TESTING from .utils.hitl import ( @@ -83,6 +84,25 @@ def delete_file(self, operation: ApplyPatchOperation) -> ApplyPatchResult: return ApplyPatchResult(output=f"Deleted {operation.path}") +class RecordingRunHooks(RunHooks[Any]): + """Capture apply_patch hook contexts.""" + + def __init__(self) -> None: + super().__init__() + self.start_contexts: list[RunContextWrapper[Any]] = [] + self.end_contexts: list[RunContextWrapper[Any]] = [] + + async def on_tool_start( + self, context: RunContextWrapper[Any], agent: Agent[Any], tool: Any + ) -> None: + self.start_contexts.append(context) + + async def on_tool_end( + self, context: RunContextWrapper[Any], agent: Agent[Any], tool: Any, result: str + ) -> None: + self.end_contexts.append(context) + + @pytest.mark.asyncio async def test_apply_patch_tool_success() -> None: editor = RecordingEditor() @@ -90,11 +110,12 @@ async def test_apply_patch_tool_success() -> None: agent, context_wrapper, tool_run = build_apply_patch_call( tool, "call_apply", {"type": "update_file", "path": "tasks.md", "diff": "-a\n+b\n"} ) + hooks = RecordingRunHooks() result = await ApplyPatchAction.execute( agent=agent, call=tool_run, - hooks=RunHooks[Any](), + hooks=hooks, context_wrapper=context_wrapper, config=RunConfig(), ) @@ -109,6 +130,13 @@ async def test_apply_patch_tool_success() -> None: assert editor.operations[0].ctx_wrapper is context_wrapper assert isinstance(raw_item["output"], str) assert raw_item["output"].startswith("Updated tasks.md") + assert len(hooks.start_contexts) == 1 + assert hooks.start_contexts == hooks.end_contexts + hook_context = hooks.start_contexts[0] + assert isinstance(hook_context, ToolContext) + assert hook_context.tool_call_id == "call_apply" + assert hook_context.tool_name == tool.name + assert json.loads(hook_context.tool_arguments)[0]["path"] == "tasks.md" input_payload = result.to_input_item() assert isinstance(input_payload, dict) payload_dict = cast(dict[str, Any], input_payload) diff --git a/tests/test_computer_action.py b/tests/test_computer_action.py index 3aa908c66c..354e928213 100644 --- a/tests/test_computer_action.py +++ b/tests/test_computer_action.py @@ -47,6 +47,7 @@ from agents.run_internal import run_loop from agents.run_internal.run_loop import ComputerAction, ToolRunComputerAction from agents.tool import ComputerToolSafetyCheckData +from agents.tool_context import ToolContext from .fake_model import FakeModel from .test_responses import get_text_message @@ -501,18 +502,18 @@ class LoggingRunHooks(RunHooks[Any]): def __init__(self) -> None: super().__init__() - self.started: list[tuple[Agent[Any], Any]] = [] - self.ended: list[tuple[Agent[Any], Any, str]] = [] + self.started: list[tuple[RunContextWrapper[Any], Agent[Any], Any]] = [] + self.ended: list[tuple[RunContextWrapper[Any], Agent[Any], Any, str]] = [] async def on_tool_start( self, context: RunContextWrapper[Any], agent: Agent[Any], tool: Any ) -> None: - self.started.append((agent, tool)) + self.started.append((context, agent, tool)) async def on_tool_end( self, context: RunContextWrapper[Any], agent: Agent[Any], tool: Any, result: str ) -> None: - self.ended.append((agent, tool, result)) + self.ended.append((context, agent, tool, result)) class LoggingAgentHooks(AgentHooks[Any]): @@ -520,18 +521,18 @@ class LoggingAgentHooks(AgentHooks[Any]): def __init__(self) -> None: super().__init__() - self.started: list[tuple[Agent[Any], Any]] = [] - self.ended: list[tuple[Agent[Any], Any, str]] = [] + self.started: list[tuple[RunContextWrapper[Any], Agent[Any], Any]] = [] + self.ended: list[tuple[RunContextWrapper[Any], Agent[Any], Any, str]] = [] async def on_tool_start( self, context: RunContextWrapper[Any], agent: Agent[Any], tool: Any ) -> None: - self.started.append((agent, tool)) + self.started.append((context, agent, tool)) async def on_tool_end( self, context: RunContextWrapper[Any], agent: Agent[Any], tool: Any, result: str ) -> None: - self.ended.append((agent, tool, result)) + self.ended.append((context, agent, tool, result)) @pytest.mark.asyncio @@ -572,13 +573,25 @@ async def test_execute_invokes_hooks_and_returns_tool_call_output() -> None: assert len(run_hooks.started) == 1 and len(agent_hooks.started) == 1 assert len(run_hooks.ended) == 1 and len(agent_hooks.ended) == 1 # The hook invocations should refer to our agent and tool. - assert run_hooks.started[0][0] is agent - assert run_hooks.ended[0][0] is agent - assert run_hooks.started[0][1] is comptool - assert run_hooks.ended[0][1] is comptool + run_start_context = run_hooks.started[0][0] + run_end_context = run_hooks.ended[0][0] + agent_start_context = agent_hooks.started[0][0] + agent_end_context = agent_hooks.ended[0][0] + assert isinstance(run_start_context, ToolContext) + assert run_start_context is run_end_context + assert isinstance(agent_start_context, ToolContext) + assert agent_start_context is agent_end_context + assert run_start_context.tool_call_id == "tool123" + assert agent_start_context.tool_call_id == "tool123" + assert run_start_context.tool_name == comptool.name + assert json.loads(run_start_context.tool_arguments)["type"] == "click" + assert run_hooks.started[0][1] is agent + assert run_hooks.ended[0][1] is agent + assert run_hooks.started[0][2] is comptool + assert run_hooks.ended[0][2] is comptool # The result passed to on_tool_end should be the raw screenshot string. - assert run_hooks.ended[0][2] == "xyz" - assert agent_hooks.ended[0][2] == "xyz" + assert run_hooks.ended[0][3] == "xyz" + assert agent_hooks.ended[0][3] == "xyz" # The computer should have performed a click then a screenshot. assert computer.calls == [("click", (1, 2, "left")), ("screenshot", ())] # The returned item should include the agent, output string, and a ComputerCallOutput. diff --git a/tests/test_local_shell_tool.py b/tests/test_local_shell_tool.py index cdc0d9a7f1..b4d10ef20f 100644 --- a/tests/test_local_shell_tool.py +++ b/tests/test_local_shell_tool.py @@ -4,6 +4,7 @@ and that Runner.run executes local shell calls and records their outputs. """ +import json from typing import Any, cast import pytest @@ -21,6 +22,7 @@ ) from agents.items import ToolCallOutputItem from agents.run_internal.run_loop import LocalShellAction, ToolRunLocalShellCall +from agents.tool_context import ToolContext from .fake_model import FakeModel from .test_responses import get_text_message @@ -38,6 +40,25 @@ def __call__(self, request: LocalShellCommandRequest) -> str: return self.output +class RecordingRunHooks(RunHooks[Any]): + """Capture local shell hook contexts.""" + + def __init__(self) -> None: + super().__init__() + self.start_contexts: list[RunContextWrapper[Any]] = [] + self.end_contexts: list[RunContextWrapper[Any]] = [] + + async def on_tool_start( + self, context: RunContextWrapper[Any], agent: Agent[Any], tool: Any + ) -> None: + self.start_contexts.append(context) + + async def on_tool_end( + self, context: RunContextWrapper[Any], agent: Agent[Any], tool: Any, result: str + ) -> None: + self.end_contexts.append(context) + + @pytest.mark.asyncio async def test_local_shell_action_execute_invokes_executor() -> None: executor = RecordingLocalShellExecutor(output="test output") @@ -61,13 +82,15 @@ async def test_local_shell_action_execute_invokes_executor() -> None: tool_run = ToolRunLocalShellCall(tool_call=tool_call, local_shell_tool=tool) agent = Agent(name="test_agent", tools=[tool]) context_wrapper: RunContextWrapper[Any] = RunContextWrapper(context=None) + hooks = RecordingRunHooks() + config = RunConfig() output_item = await LocalShellAction.execute( agent=agent, call=tool_run, - hooks=RunHooks[Any](), + hooks=hooks, context_wrapper=context_wrapper, - config=RunConfig(), + config=config, ) assert len(executor.calls) == 1 @@ -79,6 +102,16 @@ async def test_local_shell_action_execute_invokes_executor() -> None: assert request.data.action.env == {"TEST": "value"} assert request.data.action.timeout_ms == 5000 assert request.data.action.working_directory == "/tmp" + assert len(hooks.start_contexts) == 1 + assert hooks.start_contexts == hooks.end_contexts + hook_context = hooks.start_contexts[0] + assert isinstance(hook_context, ToolContext) + assert hook_context.context is context_wrapper.context + assert hook_context.tool_name == tool.name + assert hook_context.tool_call_id == "call_456" + assert hook_context.agent is agent + assert hook_context.run_config is config + assert json.loads(hook_context.tool_arguments)["command"] == ["bash", "-c", "ls"] assert isinstance(output_item, ToolCallOutputItem) assert output_item.agent is agent diff --git a/tests/test_shell_tool.py b/tests/test_shell_tool.py index f9467e2f90..78226fbe2c 100644 --- a/tests/test_shell_tool.py +++ b/tests/test_shell_tool.py @@ -21,6 +21,7 @@ from agents.items import ToolApprovalItem, ToolCallOutputItem from agents.run_internal.run_loop import ShellAction, ToolRunShellCall, execute_shell_calls from agents.tool import ShellOnApprovalFunctionResult +from agents.tool_context import ToolContext from .testing_processor import SPAN_PROCESSOR_TESTING from .utils.hitl import ( @@ -59,6 +60,25 @@ def _shell_call(call_id: str = "call_shell") -> dict[str, Any]: ) +class RecordingRunHooks(RunHooks[Any]): + """Capture shell tool hook contexts.""" + + def __init__(self) -> None: + super().__init__() + self.start_contexts: list[RunContextWrapper[Any]] = [] + self.end_contexts: list[RunContextWrapper[Any]] = [] + + async def on_tool_start( + self, context: RunContextWrapper[Any], agent: Agent[Any], tool: Any + ) -> None: + self.start_contexts.append(context) + + async def on_tool_end( + self, context: RunContextWrapper[Any], agent: Agent[Any], tool: Any, result: str + ) -> None: + self.end_contexts.append(context) + + def test_shell_tool_defaults_to_local_environment() -> None: shell_tool = ShellTool(executor=lambda request: "ok") @@ -251,11 +271,12 @@ async def test_shell_tool_structured_output_is_rendered() -> None: tool_run = ToolRunShellCall(tool_call=tool_call, shell_tool=shell_tool) agent = Agent(name="shell-agent", tools=[shell_tool]) context_wrapper: RunContextWrapper[Any] = RunContextWrapper(context=None) + hooks = RecordingRunHooks() result = await ShellAction.execute( agent=agent, call=tool_run, - hooks=RunHooks[Any](), + hooks=hooks, context_wrapper=context_wrapper, config=RunConfig(), ) @@ -277,6 +298,13 @@ async def test_shell_tool_structured_output_is_rendered() -> None: assert first_output["outcome"]["type"] == "exit" assert first_output["outcome"]["exit_code"] == 0 assert "command" not in first_output + assert len(hooks.start_contexts) == 1 + assert hooks.start_contexts == hooks.end_contexts + hook_context = hooks.start_contexts[0] + assert isinstance(hook_context, ToolContext) + assert hook_context.tool_call_id == "call_shell" + assert hook_context.tool_name == shell_tool.name + assert json.loads(hook_context.tool_arguments)["commands"] == ["echo hi", "ls"] input_payload = result.to_input_item() assert isinstance(input_payload, dict) payload_dict = cast(dict[str, Any], input_payload)