Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 8 additions & 12 deletions src/agents/lifecycle.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,10 +75,9 @@ async def on_tool_start(
) -> None:
"""Called immediately before a local tool is invoked.

For function-tool invocations, ``context`` is typically a ``ToolContext`` instance,
Local tool invocations with a concrete tool call receive a ``ToolContext`` instance,
which exposes tool-call-specific metadata such as ``tool_call_id``, ``tool_name``,
and ``tool_arguments``. Other local tool families may provide a plain
``RunContextWrapper`` instead.
and ``tool_arguments``.
"""
pass

Expand All @@ -91,10 +90,9 @@ async def on_tool_end(
) -> None:
"""Called immediately after a local tool is invoked.

For function-tool invocations, ``context`` is typically a ``ToolContext`` instance,
Local tool invocations with a concrete tool call receive a ``ToolContext`` instance,
which exposes tool-call-specific metadata such as ``tool_call_id``, ``tool_name``,
and ``tool_arguments``. Other local tool families may provide a plain
``RunContextWrapper`` instead.
and ``tool_arguments``.
"""
pass

Expand Down Expand Up @@ -149,10 +147,9 @@ async def on_tool_start(
) -> None:
"""Called immediately before a local tool is invoked.

For function-tool invocations, ``context`` is typically a ``ToolContext`` instance,
Local tool invocations with a concrete tool call receive a ``ToolContext`` instance,
which exposes tool-call-specific metadata such as ``tool_call_id``, ``tool_name``,
and ``tool_arguments``. Other local tool families may provide a plain
``RunContextWrapper`` instead.
and ``tool_arguments``.
"""
pass

Expand All @@ -165,10 +162,9 @@ async def on_tool_end(
) -> None:
"""Called immediately after a local tool is invoked.

For function-tool invocations, ``context`` is typically a ``ToolContext`` instance,
Local tool invocations with a concrete tool call receive a ``ToolContext`` instance,
which exposes tool-call-specific metadata such as ``tool_call_id``, ``tool_name``,
and ``tool_arguments``. Other local tool families may provide a plain
``RunContextWrapper`` instead.
and ``tool_arguments``.
"""
pass

Expand Down
91 changes: 75 additions & 16 deletions src/agents/run_internal/tool_actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,26 @@ def _serialize_trace_payload(payload: Any) -> str:
return str(payload)


def _local_tool_context(
context_wrapper: RunContextWrapper[Any],
*,
tool_name: str,
tool_call_id: str,
tool_arguments: Any,
agent: Agent[Any],
config: RunConfig,
) -> ToolContext:
"""Build hook context for local tool calls that are not function tools."""
return ToolContext.from_agent_context(
context_wrapper,
tool_call_id,
tool_name=tool_name,
tool_arguments=_serialize_trace_payload(tool_arguments),
agent=agent,
run_config=config,
)


class ComputerAction:
"""Execute computer tool actions and emit screenshot outputs with hooks fired."""

Expand Down Expand Up @@ -119,11 +139,19 @@ async def _run_action(span: Any | None) -> RunItem:
computer = await resolve_computer(
tool=action.computer_tool, run_context=context_wrapper
)
tool_context = _local_tool_context(
context_wrapper,
tool_name=action.computer_tool.name,
tool_call_id=action.tool_call.call_id,
tool_arguments=cls._get_trace_input_payload(action.tool_call),
agent=agent,
config=config,
)
agent_hooks = agent.hooks
await asyncio.gather(
hooks.on_tool_start(context_wrapper, agent, action.computer_tool),
hooks.on_tool_start(tool_context, agent, action.computer_tool),
(
agent_hooks.on_tool_start(context_wrapper, agent, action.computer_tool)
agent_hooks.on_tool_start(tool_context, agent, action.computer_tool)
if agent_hooks
else _coro.noop_coroutine()
),
Expand Down Expand Up @@ -151,9 +179,9 @@ async def _run_action(span: Any | None) -> RunItem:
raise

await asyncio.gather(
hooks.on_tool_end(context_wrapper, agent, action.computer_tool, output),
hooks.on_tool_end(tool_context, agent, action.computer_tool, output),
(
agent_hooks.on_tool_end(context_wrapper, agent, action.computer_tool, output)
agent_hooks.on_tool_end(tool_context, agent, action.computer_tool, output)
if agent_hooks
else _coro.noop_coroutine()
),
Expand Down Expand Up @@ -374,10 +402,18 @@ async def execute(
) -> RunItem:
"""Run a local shell tool call and wrap the result as a ToolCallOutputItem."""
agent_hooks = agent.hooks
tool_context = _local_tool_context(
context_wrapper,
tool_name=call.local_shell_tool.name,
tool_call_id=call.tool_call.call_id,
tool_arguments=call.tool_call.action,
agent=agent,
config=config,
)
await asyncio.gather(
hooks.on_tool_start(context_wrapper, agent, call.local_shell_tool),
hooks.on_tool_start(tool_context, agent, call.local_shell_tool),
(
agent_hooks.on_tool_start(context_wrapper, agent, call.local_shell_tool)
agent_hooks.on_tool_start(tool_context, agent, call.local_shell_tool)
if agent_hooks
else _coro.noop_coroutine()
),
Expand All @@ -391,9 +427,9 @@ async def execute(
result = await output if inspect.isawaitable(output) else output

await asyncio.gather(
hooks.on_tool_end(context_wrapper, agent, call.local_shell_tool, result),
hooks.on_tool_end(tool_context, agent, call.local_shell_tool, result),
(
agent_hooks.on_tool_end(context_wrapper, agent, call.local_shell_tool, result)
agent_hooks.on_tool_end(tool_context, agent, call.local_shell_tool, result)
if agent_hooks
else _coro.noop_coroutine()
),
Expand Down Expand Up @@ -428,6 +464,14 @@ async def execute(
shell_call = coerce_shell_call(call.tool_call)
shell_tool = call.shell_tool
agent_hooks = agent.hooks
tool_context = _local_tool_context(
context_wrapper,
tool_name=shell_tool.name,
tool_call_id=shell_call.call_id,
tool_arguments=dataclasses.asdict(shell_call.action),
agent=agent,
config=config,
)

async def _run_call(span: Any | None) -> RunItem:
if span and config.trace_include_sensitive_data:
Expand Down Expand Up @@ -467,9 +511,9 @@ async def _run_call(span: Any | None) -> RunItem:
return approval_item

await asyncio.gather(
hooks.on_tool_start(context_wrapper, agent, shell_tool),
hooks.on_tool_start(tool_context, agent, shell_tool),
(
agent_hooks.on_tool_start(context_wrapper, agent, shell_tool)
agent_hooks.on_tool_start(tool_context, agent, shell_tool)
if agent_hooks
else _coro.noop_coroutine()
),
Expand Down Expand Up @@ -541,9 +585,9 @@ async def _run_call(span: Any | None) -> RunItem:
logger.error("Shell executor failed: %s", exc, exc_info=True)

await asyncio.gather(
hooks.on_tool_end(context_wrapper, agent, call.shell_tool, output_text),
hooks.on_tool_end(tool_context, agent, call.shell_tool, output_text),
(
agent_hooks.on_tool_end(context_wrapper, agent, call.shell_tool, output_text)
agent_hooks.on_tool_end(tool_context, agent, call.shell_tool, output_text)
if agent_hooks
else _coro.noop_coroutine()
),
Expand Down Expand Up @@ -747,6 +791,21 @@ async def execute(
context_wrapper=context_wrapper,
)
call_id = extract_apply_patch_call_id(call.tool_call)
tool_context = _local_tool_context(
context_wrapper,
tool_name=apply_patch_tool.name,
tool_call_id=call_id,
tool_arguments=[
{
"type": operation.type,
"path": operation.path,
"diff": operation.diff,
}
for operation in operations
],
agent=agent,
config=config,
)

async def _run_call(span: Any | None) -> RunItem:
if span and config.trace_include_sensitive_data:
Expand Down Expand Up @@ -798,9 +857,9 @@ async def _run_call(span: Any | None) -> RunItem:
return approval_item

await asyncio.gather(
hooks.on_tool_start(context_wrapper, agent, apply_patch_tool),
hooks.on_tool_start(tool_context, agent, apply_patch_tool),
(
agent_hooks.on_tool_start(context_wrapper, agent, apply_patch_tool)
agent_hooks.on_tool_start(tool_context, agent, apply_patch_tool)
if agent_hooks
else _coro.noop_coroutine()
),
Expand Down Expand Up @@ -854,9 +913,9 @@ async def _run_call(span: Any | None) -> RunItem:
logger.error("Apply patch editor failed: %s", exc, exc_info=True)

await asyncio.gather(
hooks.on_tool_end(context_wrapper, agent, apply_patch_tool, output_text),
hooks.on_tool_end(tool_context, agent, apply_patch_tool, output_text),
(
agent_hooks.on_tool_end(context_wrapper, agent, apply_patch_tool, output_text)
agent_hooks.on_tool_end(tool_context, agent, apply_patch_tool, output_text)
if agent_hooks
else _coro.noop_coroutine()
),
Expand Down
30 changes: 29 additions & 1 deletion tests/test_apply_patch_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from agents.editor import ApplyPatchOperation, ApplyPatchResult
from agents.items import ToolApprovalItem, ToolCallOutputItem
from agents.run_internal.run_loop import ApplyPatchAction, ToolRunApplyPatchCall
from agents.tool_context import ToolContext

from .testing_processor import SPAN_PROCESSOR_TESTING
from .utils.hitl import (
Expand Down Expand Up @@ -83,18 +84,38 @@ def delete_file(self, operation: ApplyPatchOperation) -> ApplyPatchResult:
return ApplyPatchResult(output=f"Deleted {operation.path}")


class RecordingRunHooks(RunHooks[Any]):
"""Capture apply_patch hook contexts."""

def __init__(self) -> None:
super().__init__()
self.start_contexts: list[RunContextWrapper[Any]] = []
self.end_contexts: list[RunContextWrapper[Any]] = []

async def on_tool_start(
self, context: RunContextWrapper[Any], agent: Agent[Any], tool: Any
) -> None:
self.start_contexts.append(context)

async def on_tool_end(
self, context: RunContextWrapper[Any], agent: Agent[Any], tool: Any, result: str
) -> None:
self.end_contexts.append(context)


@pytest.mark.asyncio
async def test_apply_patch_tool_success() -> None:
editor = RecordingEditor()
tool = ApplyPatchTool(editor=editor)
agent, context_wrapper, tool_run = build_apply_patch_call(
tool, "call_apply", {"type": "update_file", "path": "tasks.md", "diff": "-a\n+b\n"}
)
hooks = RecordingRunHooks()

result = await ApplyPatchAction.execute(
agent=agent,
call=tool_run,
hooks=RunHooks[Any](),
hooks=hooks,
context_wrapper=context_wrapper,
config=RunConfig(),
)
Expand All @@ -109,6 +130,13 @@ async def test_apply_patch_tool_success() -> None:
assert editor.operations[0].ctx_wrapper is context_wrapper
assert isinstance(raw_item["output"], str)
assert raw_item["output"].startswith("Updated tasks.md")
assert len(hooks.start_contexts) == 1
assert hooks.start_contexts == hooks.end_contexts
hook_context = hooks.start_contexts[0]
assert isinstance(hook_context, ToolContext)
assert hook_context.tool_call_id == "call_apply"
assert hook_context.tool_name == tool.name
assert json.loads(hook_context.tool_arguments)[0]["path"] == "tasks.md"
input_payload = result.to_input_item()
assert isinstance(input_payload, dict)
payload_dict = cast(dict[str, Any], input_payload)
Expand Down
41 changes: 27 additions & 14 deletions tests/test_computer_action.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
from agents.run_internal import run_loop
from agents.run_internal.run_loop import ComputerAction, ToolRunComputerAction
from agents.tool import ComputerToolSafetyCheckData
from agents.tool_context import ToolContext

from .fake_model import FakeModel
from .test_responses import get_text_message
Expand Down Expand Up @@ -501,37 +502,37 @@ class LoggingRunHooks(RunHooks[Any]):

def __init__(self) -> None:
super().__init__()
self.started: list[tuple[Agent[Any], Any]] = []
self.ended: list[tuple[Agent[Any], Any, str]] = []
self.started: list[tuple[RunContextWrapper[Any], Agent[Any], Any]] = []
self.ended: list[tuple[RunContextWrapper[Any], Agent[Any], Any, str]] = []

async def on_tool_start(
self, context: RunContextWrapper[Any], agent: Agent[Any], tool: Any
) -> None:
self.started.append((agent, tool))
self.started.append((context, agent, tool))

async def on_tool_end(
self, context: RunContextWrapper[Any], agent: Agent[Any], tool: Any, result: str
) -> None:
self.ended.append((agent, tool, result))
self.ended.append((context, agent, tool, result))


class LoggingAgentHooks(AgentHooks[Any]):
"""Minimal override to capture agent's tool hook invocations."""

def __init__(self) -> None:
super().__init__()
self.started: list[tuple[Agent[Any], Any]] = []
self.ended: list[tuple[Agent[Any], Any, str]] = []
self.started: list[tuple[RunContextWrapper[Any], Agent[Any], Any]] = []
self.ended: list[tuple[RunContextWrapper[Any], Agent[Any], Any, str]] = []

async def on_tool_start(
self, context: RunContextWrapper[Any], agent: Agent[Any], tool: Any
) -> None:
self.started.append((agent, tool))
self.started.append((context, agent, tool))

async def on_tool_end(
self, context: RunContextWrapper[Any], agent: Agent[Any], tool: Any, result: str
) -> None:
self.ended.append((agent, tool, result))
self.ended.append((context, agent, tool, result))


@pytest.mark.asyncio
Expand Down Expand Up @@ -572,13 +573,25 @@ async def test_execute_invokes_hooks_and_returns_tool_call_output() -> None:
assert len(run_hooks.started) == 1 and len(agent_hooks.started) == 1
assert len(run_hooks.ended) == 1 and len(agent_hooks.ended) == 1
# The hook invocations should refer to our agent and tool.
assert run_hooks.started[0][0] is agent
assert run_hooks.ended[0][0] is agent
assert run_hooks.started[0][1] is comptool
assert run_hooks.ended[0][1] is comptool
run_start_context = run_hooks.started[0][0]
run_end_context = run_hooks.ended[0][0]
agent_start_context = agent_hooks.started[0][0]
agent_end_context = agent_hooks.ended[0][0]
assert isinstance(run_start_context, ToolContext)
assert run_start_context is run_end_context
assert isinstance(agent_start_context, ToolContext)
assert agent_start_context is agent_end_context
assert run_start_context.tool_call_id == "tool123"
assert agent_start_context.tool_call_id == "tool123"
assert run_start_context.tool_name == comptool.name
assert json.loads(run_start_context.tool_arguments)["type"] == "click"
assert run_hooks.started[0][1] is agent
assert run_hooks.ended[0][1] is agent
assert run_hooks.started[0][2] is comptool
assert run_hooks.ended[0][2] is comptool
# The result passed to on_tool_end should be the raw screenshot string.
assert run_hooks.ended[0][2] == "xyz"
assert agent_hooks.ended[0][2] == "xyz"
assert run_hooks.ended[0][3] == "xyz"
assert agent_hooks.ended[0][3] == "xyz"
# The computer should have performed a click then a screenshot.
assert computer.calls == [("click", (1, 2, "left")), ("screenshot", ())]
# The returned item should include the agent, output string, and a ComputerCallOutput.
Expand Down
Loading
Loading