diff --git a/src/agents/realtime/session.py b/src/agents/realtime/session.py index 2b45ccaed6..1db65930ac 100644 --- a/src/agents/realtime/session.py +++ b/src/agents/realtime/session.py @@ -574,6 +574,18 @@ async def _send_tool_rejection( ) ) + async def _send_tool_failure_output( + self, event: RealtimeModelToolCallEvent, error: Exception + ) -> None: + """Complete a failed known tool call with model-visible output.""" + await self._model.send_event( + RealtimeModelSendToolOutput( + tool_call=event, + output=f"Tool {event.name} failed: {error}", + start_response=True, + ) + ) + async def _resolve_approval_rejection_message(self, *, tool: FunctionTool, call_id: str) -> str: """Resolve model-visible output text for approval rejections.""" explicit_message = self._context_wrapper.get_rejection_message( @@ -694,11 +706,15 @@ async def _handle_tool_call( tool_arguments=event.arguments, agent=agent, ) - result = await invoke_function_tool( - function_tool=func_tool, - context=tool_context, - arguments=event.arguments, - ) + try: + result = await invoke_function_tool( + function_tool=func_tool, + context=tool_context, + arguments=event.arguments, + ) + except Exception as exc: + await self._send_tool_failure_output(event, exc) + raise await self._model.send_event( RealtimeModelSendToolOutput( @@ -729,11 +745,15 @@ async def _handle_tool_call( ) # Execute the handoff to get the new agent - result = await handoff.on_invoke_handoff(self._context_wrapper, event.arguments) - if not isinstance(result, RealtimeAgent): - raise UserError( - f"Handoff {handoff.tool_name} returned invalid result: {type(result)}" - ) + try: + result = await handoff.on_invoke_handoff(self._context_wrapper, event.arguments) + if not isinstance(result, RealtimeAgent): + raise UserError( + f"Handoff {handoff.tool_name} returned invalid result: {type(result)}" + ) + except Exception as exc: + await self._send_tool_failure_output(event, exc) + raise # Store previous agent for event previous_agent = agent diff --git a/tests/realtime/test_session.py b/tests/realtime/test_session.py index 67cf717aa5..2f49665497 100644 --- a/tests/realtime/test_session.py +++ b/tests/realtime/test_session.py @@ -1109,7 +1109,13 @@ async def invoke_slow_tool(_ctx: ToolContext[Any], _arguments: str) -> str: with pytest.raises(ToolTimeoutError, match="timed out"): await session._handle_tool_call(tool_call_event) - assert len(mock_model.sent_tool_outputs) == 0 + assert len(mock_model.sent_tool_outputs) == 1 + sent_call, sent_output, start_response = mock_model.sent_tool_outputs[0] + assert sent_call == tool_call_event + assert "Tool slow_tool failed" in sent_output + assert "timed out" in sent_output + assert start_response is True + assert session._event_queue.qsize() == 1 tool_start_event = await session._event_queue.get() @@ -1196,7 +1202,12 @@ async def invoke_slow_tool(_ctx: ToolContext[Any], _arguments: str) -> str: assert isinstance(session._stored_exception, ToolTimeoutError) assert session._stored_exception.tool_name == "slow_tool" - assert len(mock_model.sent_tool_outputs) == 0 + assert len(mock_model.sent_tool_outputs) == 1 + sent_call, sent_output, start_response = mock_model.sent_tool_outputs[0] + assert sent_call == tool_call_event + assert "Tool slow_tool failed" in sent_output + assert "timed out" in sent_output + assert start_response is True events = [] while True: @@ -1215,6 +1226,58 @@ async def invoke_slow_tool(_ctx: ToolContext[Any], _arguments: str) -> str: assert "Tool call task failed" in error_event.error["message"] assert "timed out" in error_event.error["message"] + @pytest.mark.asyncio + async def test_function_call_event_exception_sends_model_visible_output( + self, mock_model, mock_agent + ): + async def invoke_failing_tool(_ctx: ToolContext[Any], _arguments: str) -> str: + raise ValueError("tool failed") + + failing_tool = FunctionTool( + name="failing_tool", + description="fails", + params_json_schema={"type": "object", "properties": {}}, + on_invoke_tool=invoke_failing_tool, + ) + mock_agent.get_all_tools.return_value = [failing_tool] + + session = RealtimeSession(mock_model, mock_agent, None) + tool_call_event = RealtimeModelToolCallEvent( + name="failing_tool", + call_id="call_fails_async", + arguments="{}", + ) + + await session.on_event(tool_call_event) + + tool_call_tasks = list(session._tool_call_tasks) + assert len(tool_call_tasks) == 1 + await asyncio.gather(*tool_call_tasks, return_exceptions=True) + + assert isinstance(session._stored_exception, ValueError) + assert len(mock_model.sent_tool_outputs) == 1 + sent_call, sent_output, start_response = mock_model.sent_tool_outputs[0] + assert sent_call == tool_call_event + assert sent_output == "Tool failing_tool failed: tool failed" + assert start_response is True + + events = [] + while True: + event = await asyncio.wait_for(session._event_queue.get(), timeout=1) + events.append(event) + if isinstance(event, RealtimeError): + break + + assert any( + isinstance(event, RealtimeRawModelEvent) and event.data == tool_call_event + for event in events + ) + assert any(isinstance(event, RealtimeToolStart) for event in events) + + error_event = next(event for event in events if isinstance(event, RealtimeError)) + assert "Tool call task failed" in error_event.error["message"] + assert "tool failed" in error_event.error["message"] + @pytest.mark.asyncio async def test_function_tool_with_multiple_tools_available(self, mock_model, mock_agent): """Test function tool execution when multiple tools are available""" @@ -1286,6 +1349,37 @@ async def test_handoff_tool_handling(self, mock_model): # Verify agent was updated assert session._current_agent == second_agent + @pytest.mark.asyncio + async def test_handoff_tool_failure_sends_model_visible_output(self, mock_model): + target_agent = RealtimeAgent(name="target_agent") + failing_handoff = Handoff( + tool_name="switch", + tool_description="Switch agents", + input_json_schema={}, + on_invoke_handoff=AsyncMock(side_effect=RuntimeError("handoff failed")), + input_filter=None, + agent_name=target_agent.name, + is_enabled=True, + ) + agent = RealtimeAgent(name="agent", handoffs=[failing_handoff]) + session = RealtimeSession(mock_model, agent, None) + + tool_call_event = RealtimeModelToolCallEvent( + name="switch", + call_id="call_handoff_failed", + arguments="{}", + ) + + with pytest.raises(RuntimeError, match="handoff failed"): + await session._handle_tool_call(tool_call_event) + + assert len(mock_model.sent_tool_outputs) == 1 + sent_call, sent_output, start_response = mock_model.sent_tool_outputs[0] + assert sent_call == tool_call_event + assert sent_output == "Tool switch failed: handoff failed" + assert start_response is True + assert session._current_agent == agent + @pytest.mark.asyncio async def test_unknown_tool_handling(self, mock_model, mock_agent, mock_function_tool): """Test that unknown tools complete the model call without starting a response.""" @@ -1605,7 +1699,7 @@ async def invoke_tool(_ctx: ToolContext[Any], _arguments: str) -> str: async def test_function_tool_exception_handling( self, mock_model, mock_agent, mock_function_tool ): - """Test that exceptions in function tools are handled (currently they propagate)""" + """Test that function tool exceptions are sent to the model and then propagated.""" # Set up tool to raise exception mock_function_tool.on_invoke_tool.side_effect = ValueError("Tool error") mock_agent.get_all_tools.return_value = [mock_function_tool] @@ -1616,7 +1710,6 @@ async def test_function_tool_exception_handling( name="test_function", call_id="call_error", arguments="{}" ) - # Currently exceptions propagate (no error handling implemented) with pytest.raises(ValueError, match="Tool error"): await session._handle_tool_call(tool_call_event) @@ -1626,8 +1719,11 @@ async def test_function_tool_exception_handling( assert isinstance(tool_start_event, RealtimeToolStart) assert tool_start_event.arguments == "{}" - # But no tool output should have been sent and no end event queued - assert len(mock_model.sent_tool_outputs) == 0 + assert len(mock_model.sent_tool_outputs) == 1 + sent_call, sent_output, start_response = mock_model.sent_tool_outputs[0] + assert sent_call == tool_call_event + assert sent_output == "Tool test_function failed: Tool error" + assert start_response is True @pytest.mark.asyncio async def test_tool_call_with_complex_arguments(