diff --git a/src/agents/realtime/session.py b/src/agents/realtime/session.py index 2b45ccaed6..d8465df9c5 100644 --- a/src/agents/realtime/session.py +++ b/src/agents/realtime/session.py @@ -22,7 +22,12 @@ from ..logger import logger from ..run_config import ToolErrorFormatterArgs from ..run_context import RunContextWrapper, TContext -from ..tool import DEFAULT_APPROVAL_REJECTION_MESSAGE, FunctionTool, invoke_function_tool +from ..tool import ( + DEFAULT_APPROVAL_REJECTION_MESSAGE, + FunctionTool, + default_tool_error_function, + invoke_function_tool, +) from ..tool_context import ToolContext from ..util._approvals import evaluate_needs_approval_setting from .agent import RealtimeAgent @@ -694,11 +699,28 @@ async def _handle_tool_call( tool_arguments=event.arguments, agent=agent, ) - result = await invoke_function_tool( - function_tool=func_tool, - context=tool_context, - arguments=event.arguments, - ) + try: + result = await invoke_function_tool( + function_tool=func_tool, + context=tool_context, + arguments=event.arguments, + ) + except Exception as exc: + error_message = default_tool_error_function(self._context_wrapper, exc) + logger.warning( + "Tool %r raised %s: %s; sending error output to model.", + event.name, + type(exc).__name__, + exc, + ) + await self._model.send_event( + RealtimeModelSendToolOutput( + tool_call=event, + output=error_message, + start_response=True, + ) + ) + return await self._model.send_event( RealtimeModelSendToolOutput( diff --git a/tests/realtime/test_session.py b/tests/realtime/test_session.py index 67cf717aa5..57f2666220 100644 --- a/tests/realtime/test_session.py +++ b/tests/realtime/test_session.py @@ -1106,12 +1106,17 @@ async def invoke_slow_tool(_ctx: ToolContext[Any], _arguments: str) -> str: arguments="{}", ) - with pytest.raises(ToolTimeoutError, match="timed out"): - await session._handle_tool_call(tool_call_event) + # After the fix: exception is caught, error output is sent to the model, + # and _handle_tool_call returns normally without re-raising. + await session._handle_tool_call(tool_call_event) - assert len(mock_model.sent_tool_outputs) == 0 - assert session._event_queue.qsize() == 1 + assert len(mock_model.sent_tool_outputs) == 1 + sent_call, sent_output, start_response = mock_model.sent_tool_outputs[0] + assert sent_call == tool_call_event + assert "error" in sent_output.lower() + assert start_response is True + assert session._event_queue.qsize() == 1 tool_start_event = await session._event_queue.get() assert isinstance(tool_start_event, RealtimeToolStart) assert tool_start_event.tool == timeout_tool @@ -1194,26 +1199,30 @@ async def invoke_slow_tool(_ctx: ToolContext[Any], _arguments: str) -> str: assert len(tool_call_tasks) == 1 await asyncio.gather(*tool_call_tasks, return_exceptions=True) - assert isinstance(session._stored_exception, ToolTimeoutError) - assert session._stored_exception.tool_name == "slow_tool" - assert len(mock_model.sent_tool_outputs) == 0 + # After the fix: the exception is caught inside _handle_tool_call, an error + # message is sent to the model, and the task completes normally — so + # _stored_exception is NOT set (session stays alive) and the model DOES receive + # a tool output so it can continue rather than hanging for 30+ seconds. + assert session._stored_exception is None + assert len(mock_model.sent_tool_outputs) == 1 + sent_call, sent_output, start_response = mock_model.sent_tool_outputs[0] + assert sent_call == tool_call_event + assert "error" in sent_output.lower() + assert start_response is True + # Drain all queued events and verify the basic lifecycle events are present. events = [] - while True: - event = await asyncio.wait_for(session._event_queue.get(), timeout=1) - events.append(event) - if isinstance(event, RealtimeError): - break + while not session._event_queue.empty(): + events.append(session._event_queue.get_nowait()) assert any( isinstance(event, RealtimeRawModelEvent) and event.data == tool_call_event for event in events ) assert any(isinstance(event, RealtimeToolStart) for event in events) - - error_event = next(event for event in events if isinstance(event, RealtimeError)) - assert "Tool call task failed" in error_event.error["message"] - assert "timed out" in error_event.error["message"] + # No RealtimeError should be emitted via the task-done callback since the + # exception was handled gracefully inside _handle_tool_call. + assert not any(isinstance(event, RealtimeError) for event in events) @pytest.mark.asyncio async def test_function_tool_with_multiple_tools_available(self, mock_model, mock_agent): @@ -1605,7 +1614,8 @@ async def invoke_tool(_ctx: ToolContext[Any], _arguments: str) -> str: async def test_function_tool_exception_handling( self, mock_model, mock_agent, mock_function_tool ): - """Test that exceptions in function tools are handled (currently they propagate)""" + """Test that exceptions in function tools are caught and an error message is sent to + the model, so the session doesn't hang waiting for a tool output that never arrives.""" # Set up tool to raise exception mock_function_tool.on_invoke_tool.side_effect = ValueError("Tool error") mock_agent.get_all_tools.return_value = [mock_function_tool] @@ -1616,9 +1626,8 @@ async def test_function_tool_exception_handling( name="test_function", call_id="call_error", arguments="{}" ) - # Currently exceptions propagate (no error handling implemented) - with pytest.raises(ValueError, match="Tool error"): - await session._handle_tool_call(tool_call_event) + # Exception is now caught — _handle_tool_call returns normally. + await session._handle_tool_call(tool_call_event) # Tool start event should have been queued before the error assert session._event_queue.qsize() == 1 @@ -1626,8 +1635,12 @@ async def test_function_tool_exception_handling( assert isinstance(tool_start_event, RealtimeToolStart) assert tool_start_event.arguments == "{}" - # But no tool output should have been sent and no end event queued - assert len(mock_model.sent_tool_outputs) == 0 + # An error tool output IS sent so the model can continue rather than hang. + assert len(mock_model.sent_tool_outputs) == 1 + sent_call, sent_output, start_response = mock_model.sent_tool_outputs[0] + assert sent_call == tool_call_event + assert "error" in sent_output.lower() + assert start_response is True @pytest.mark.asyncio async def test_tool_call_with_complex_arguments(