Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 30 additions & 10 deletions src/agents/realtime/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -574,6 +574,18 @@ async def _send_tool_rejection(
)
)

async def _send_tool_failure_output(
    self, event: RealtimeModelToolCallEvent, error: Exception
) -> None:
    """Complete a failed known tool call with model-visible output.

    Sends a ``RealtimeModelSendToolOutput`` for ``event`` whose output text is
    ``"Tool <name> failed: <detail>"``, so the call is answered rather than
    left without any output.

    Args:
        event: The tool call event whose invocation raised.
        error: The exception raised while running the tool or handoff.
    """
    # Some exceptions stringify to "" (for example a bare ``TimeoutError()``).
    # Fall back to the exception class name so the model never receives a
    # dangling "failed: " with no detail after it.
    detail = str(error) or type(error).__name__
    await self._model.send_event(
        RealtimeModelSendToolOutput(
            tool_call=event,
            output=f"Tool {event.name} failed: {detail}",
            # NOTE(review): a reviewer commented on this line:
            # "this is a breaking change" -- confirm before merging.
            start_response=True,
        )
    )

async def _resolve_approval_rejection_message(self, *, tool: FunctionTool, call_id: str) -> str:
"""Resolve model-visible output text for approval rejections."""
explicit_message = self._context_wrapper.get_rejection_message(
Expand Down Expand Up @@ -694,11 +706,15 @@ async def _handle_tool_call(
tool_arguments=event.arguments,
agent=agent,
)
result = await invoke_function_tool(
function_tool=func_tool,
context=tool_context,
arguments=event.arguments,
)
try:
result = await invoke_function_tool(
function_tool=func_tool,
context=tool_context,
arguments=event.arguments,
)
except Exception as exc:
await self._send_tool_failure_output(event, exc)
raise

await self._model.send_event(
RealtimeModelSendToolOutput(
Expand Down Expand Up @@ -729,11 +745,15 @@ async def _handle_tool_call(
)

# Execute the handoff to get the new agent
result = await handoff.on_invoke_handoff(self._context_wrapper, event.arguments)
if not isinstance(result, RealtimeAgent):
raise UserError(
f"Handoff {handoff.tool_name} returned invalid result: {type(result)}"
)
try:
result = await handoff.on_invoke_handoff(self._context_wrapper, event.arguments)
if not isinstance(result, RealtimeAgent):
raise UserError(
f"Handoff {handoff.tool_name} returned invalid result: {type(result)}"
)
except Exception as exc:
await self._send_tool_failure_output(event, exc)
raise

# Store previous agent for event
previous_agent = agent
Expand Down
108 changes: 102 additions & 6 deletions tests/realtime/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -1109,7 +1109,13 @@ async def invoke_slow_tool(_ctx: ToolContext[Any], _arguments: str) -> str:
with pytest.raises(ToolTimeoutError, match="timed out"):
await session._handle_tool_call(tool_call_event)

assert len(mock_model.sent_tool_outputs) == 0
assert len(mock_model.sent_tool_outputs) == 1
sent_call, sent_output, start_response = mock_model.sent_tool_outputs[0]
assert sent_call == tool_call_event
assert "Tool slow_tool failed" in sent_output
assert "timed out" in sent_output
assert start_response is True

assert session._event_queue.qsize() == 1

tool_start_event = await session._event_queue.get()
Expand Down Expand Up @@ -1196,7 +1202,12 @@ async def invoke_slow_tool(_ctx: ToolContext[Any], _arguments: str) -> str:

assert isinstance(session._stored_exception, ToolTimeoutError)
assert session._stored_exception.tool_name == "slow_tool"
assert len(mock_model.sent_tool_outputs) == 0
assert len(mock_model.sent_tool_outputs) == 1
sent_call, sent_output, start_response = mock_model.sent_tool_outputs[0]
assert sent_call == tool_call_event
assert "Tool slow_tool failed" in sent_output
assert "timed out" in sent_output
assert start_response is True

events = []
while True:
Expand All @@ -1215,6 +1226,58 @@ async def invoke_slow_tool(_ctx: ToolContext[Any], _arguments: str) -> str:
assert "Tool call task failed" in error_event.error["message"]
assert "timed out" in error_event.error["message"]

@pytest.mark.asyncio
async def test_function_call_event_exception_sends_model_visible_output(
    self, mock_model, mock_agent
):
    """A tool that raises while handled via ``on_event`` still answers the model.

    Checks that the exception is stored on the session, that exactly one
    failure output is sent with ``start_response`` set, and that the queued
    events include the raw model event, a tool-start event and an error event.
    """

    async def always_raises(_ctx: ToolContext[Any], _arguments: str) -> str:
        raise ValueError("tool failed")

    mock_agent.get_all_tools.return_value = [
        FunctionTool(
            name="failing_tool",
            description="fails",
            params_json_schema={"type": "object", "properties": {}},
            on_invoke_tool=always_raises,
        )
    ]

    tool_call_event = RealtimeModelToolCallEvent(
        name="failing_tool",
        call_id="call_fails_async",
        arguments="{}",
    )
    session = RealtimeSession(mock_model, mock_agent, None)

    await session.on_event(tool_call_event)

    # Snapshot the task set before awaiting so the collection being iterated
    # cannot change underneath us.
    pending = [*session._tool_call_tasks]
    assert len(pending) == 1
    await asyncio.gather(*pending, return_exceptions=True)

    assert isinstance(session._stored_exception, ValueError)

    outputs = mock_model.sent_tool_outputs
    assert len(outputs) == 1
    answered_call, output_text, started_response = outputs[0]
    assert answered_call == tool_call_event
    assert output_text == "Tool failing_tool failed: tool failed"
    assert started_response is True

    # Drain the queue up to and including the first error event.
    drained = []
    saw_error = False
    while not saw_error:
        queued = await asyncio.wait_for(session._event_queue.get(), timeout=1)
        drained.append(queued)
        saw_error = isinstance(queued, RealtimeError)

    assert any(
        isinstance(item, RealtimeRawModelEvent) and item.data == tool_call_event
        for item in drained
    )
    assert any(isinstance(item, RealtimeToolStart) for item in drained)

    error_event = next(item for item in drained if isinstance(item, RealtimeError))
    error_message = error_event.error["message"]
    assert "Tool call task failed" in error_message
    assert "tool failed" in error_message

@pytest.mark.asyncio
async def test_function_tool_with_multiple_tools_available(self, mock_model, mock_agent):
"""Test function tool execution when multiple tools are available"""
Expand Down Expand Up @@ -1286,6 +1349,37 @@ async def test_handoff_tool_handling(self, mock_model):
# Verify agent was updated
assert session._current_agent == second_agent

@pytest.mark.asyncio
async def test_handoff_tool_failure_sends_model_visible_output(self, mock_model):
    """A handoff whose callback raises sends a failure output and re-raises.

    Also checks that the session's current agent is unchanged afterwards.
    """
    destination = RealtimeAgent(name="target_agent")
    exploding_callback = AsyncMock(side_effect=RuntimeError("handoff failed"))
    broken_handoff = Handoff(
        tool_name="switch",
        tool_description="Switch agents",
        input_json_schema={},
        on_invoke_handoff=exploding_callback,
        input_filter=None,
        agent_name=destination.name,
        is_enabled=True,
    )
    agent = RealtimeAgent(name="agent", handoffs=[broken_handoff])
    session = RealtimeSession(mock_model, agent, None)

    tool_call_event = RealtimeModelToolCallEvent(
        name="switch",
        call_id="call_handoff_failed",
        arguments="{}",
    )

    with pytest.raises(RuntimeError, match="handoff failed"):
        await session._handle_tool_call(tool_call_event)

    outputs = mock_model.sent_tool_outputs
    assert len(outputs) == 1
    answered_call, output_text, started_response = outputs[0]
    assert answered_call == tool_call_event
    assert output_text == "Tool switch failed: handoff failed"
    assert started_response is True

    # The failed handoff must not have switched agents.
    assert session._current_agent == agent

@pytest.mark.asyncio
async def test_unknown_tool_handling(self, mock_model, mock_agent, mock_function_tool):
"""Test that unknown tools complete the model call without starting a response."""
Expand Down Expand Up @@ -1605,7 +1699,7 @@ async def invoke_tool(_ctx: ToolContext[Any], _arguments: str) -> str:
async def test_function_tool_exception_handling(
self, mock_model, mock_agent, mock_function_tool
):
"""Test that exceptions in function tools are handled (currently they propagate)"""
"""Test that function tool exceptions are sent to the model and then propagated."""
# Set up tool to raise exception
mock_function_tool.on_invoke_tool.side_effect = ValueError("Tool error")
mock_agent.get_all_tools.return_value = [mock_function_tool]
Expand All @@ -1616,7 +1710,6 @@ async def test_function_tool_exception_handling(
name="test_function", call_id="call_error", arguments="{}"
)

# Currently exceptions propagate (no error handling implemented)
with pytest.raises(ValueError, match="Tool error"):
await session._handle_tool_call(tool_call_event)

Expand All @@ -1626,8 +1719,11 @@ async def test_function_tool_exception_handling(
assert isinstance(tool_start_event, RealtimeToolStart)
assert tool_start_event.arguments == "{}"

# But no tool output should have been sent and no end event queued
assert len(mock_model.sent_tool_outputs) == 0
assert len(mock_model.sent_tool_outputs) == 1
sent_call, sent_output, start_response = mock_model.sent_tool_outputs[0]
assert sent_call == tool_call_event
assert sent_output == "Tool test_function failed: Tool error"
assert start_response is True

@pytest.mark.asyncio
async def test_tool_call_with_complex_arguments(
Expand Down