diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index e8ea3c9bc..21ff784ca 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -857,7 +857,7 @@ async def stream_async( self._end_agent_trace_span(response=result) - except Exception as e: + except BaseException as e: self._end_agent_trace_span(error=e) raise @@ -1060,7 +1060,7 @@ def _start_agent_trace_span(self, messages: Messages) -> trace_api.Span: def _end_agent_trace_span( self, response: AgentResult | None = None, - error: Exception | None = None, + error: BaseException | None = None, ) -> None: """Ends a trace span for the agent. diff --git a/src/strands/event_loop/event_loop.py b/src/strands/event_loop/event_loop.py index bf1cc7a84..e1e601df1 100644 --- a/src/strands/event_loop/event_loop.py +++ b/src/strands/event_loop/event_loop.py @@ -138,6 +138,7 @@ async def event_loop_cycle( custom_trace_attributes=agent.trace_attributes, ) invocation_state["event_loop_cycle_span"] = cycle_span + model_events: AsyncGenerator[TypedEvent, None] | None = None with trace_api.use_span(cycle_span, end_on_exit=False): try: @@ -153,15 +154,21 @@ async def event_loop_cycle( model_events = _handle_model_execution( agent, cycle_span, cycle_trace, invocation_state, tracer, structured_output_context ) - async for model_event in model_events: - if not isinstance(model_event, ModelStopReason): - yield model_event + try: + async for model_event in model_events: + if not isinstance(model_event, ModelStopReason): + yield model_event + finally: + await model_events.aclose() stop_reason, message, *_ = model_event["stop"] yield ModelMessageEvent(message=message) except Exception as e: tracer.end_span_with_error(cycle_span, str(e), e) raise + except BaseException as e: + tracer.end_span_with_error(cycle_span, str(e), e) + raise try: if stop_reason == "max_tokens": @@ -238,6 +245,9 @@ async def event_loop_cycle( yield ForceStopEvent(reason=e) logger.exception("cycle failed") raise EventLoopException(e, invocation_state["request_state"]) from e + except BaseException as e: + tracer.end_span_with_error(cycle_span, str(e), e) + raise async def recurse_event_loop( @@ -323,6 +333,7 @@ async def _handle_model_execution( system_prompt=agent.system_prompt, system_prompt_content=agent._system_prompt_content, ) + streamed_events: AsyncGenerator[TypedEvent, None] | None = None with trace_api.use_span(model_invoke_span, end_on_exit=False): try: await agent.hooks.invoke_callbacks_async( @@ -338,18 +349,22 @@ async def _handle_model_execution( else: tool_specs = agent.tool_registry.get_all_tool_specs() - async for event in stream_messages( - agent.model, - agent.system_prompt, - agent.messages, - tool_specs, - system_prompt_content=agent._system_prompt_content, - tool_choice=structured_output_context.tool_choice, - invocation_state=invocation_state, - model_state=agent._model_state, - cancel_signal=agent._cancel_signal, - ): - yield event + try: + streamed_events = stream_messages( + agent.model, + agent.system_prompt, + agent.messages, + tool_specs, + system_prompt_content=agent._system_prompt_content, + tool_choice=structured_output_context.tool_choice, + invocation_state=invocation_state, + model_state=agent._model_state, + cancel_signal=agent._cancel_signal, + ) + async for event in streamed_events: + yield event + finally: + await streamed_events.aclose() stop_reason, message, usage, metrics = event["stop"] invocation_state.setdefault("request_state", {}) @@ -417,6 +432,9 @@ async def _handle_model_execution( # No retry requested, raise the exception yield ForceStopEvent(reason=e) raise e + except BaseException as e: + tracer.end_span_with_error(model_invoke_span, str(e), e) + raise try: # Add message in trace and mark the end of the stream messages trace diff --git a/src/strands/telemetry/tracer.py b/src/strands/telemetry/tracer.py index a422d3cbf..afbd5bc4f 100644 --- a/src/strands/telemetry/tracer.py +++ b/src/strands/telemetry/tracer.py @@ -185,7 +185,7 @@ def _end_span( self, span: Span, attributes: dict[str, AttributeValue] | None = None, - error: Exception | None = None, + error: BaseException | None = None, error_message: str | None = None, ) -> None: """Generic helper method to end a span. @@ -219,7 +219,7 @@ def _end_span( finally: span.end() - def end_span_with_error(self, span: Span, error_message: str, exception: Exception | None = None) -> None: + def end_span_with_error(self, span: Span, error_message: str, exception: BaseException | None = None) -> None: """End a span with error status. Args: @@ -445,7 +445,9 @@ def start_tool_call_span( return span - def end_tool_call_span(self, span: Span, tool_result: ToolResult | None, error: Exception | None = None) -> None: + def end_tool_call_span( + self, span: Span, tool_result: ToolResult | None, error: BaseException | None = None + ) -> None: """End a tool call span with results. Args: @@ -644,7 +646,7 @@ def end_agent_span( self, span: Span, response: AgentResult | None = None, - error: Exception | None = None, + error: BaseException | None = None, ) -> None: """End an agent span with results and metrics. diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index 1e27274a1..d483db573 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -1425,6 +1425,27 @@ async def test_agent_stream_async_creates_and_ends_span_on_exception(mock_get_tr mock_tracer.end_agent_span.assert_called_once_with(span=mock_span, error=test_exception) +@pytest.mark.asyncio +@unittest.mock.patch("strands.agent.agent.get_tracer") +async def test_agent_stream_async_creates_and_ends_span_on_base_exception(mock_get_tracer, mock_model, alist): + """Test that stream_async ends the agent span when a BaseException occurs.""" + mock_tracer = unittest.mock.MagicMock() + mock_span = unittest.mock.MagicMock() + mock_tracer.start_agent_span.return_value = mock_span + mock_get_tracer.return_value = mock_tracer + + test_exception = KeyboardInterrupt("stop now") + mock_model.mock_stream.side_effect = test_exception + + agent = Agent(model=mock_model) + + with pytest.raises(KeyboardInterrupt, match="stop now"): + stream = agent.stream_async("test prompt") + await alist(stream) + + mock_tracer.end_agent_span.assert_called_once_with(span=mock_span, error=test_exception) + + def test_agent_init_with_state_object(): agent = Agent(state=AgentState({"foo": "bar"})) assert agent.state.get("foo") == "bar" diff --git a/tests/strands/event_loop/test_event_loop.py b/tests/strands/event_loop/test_event_loop.py index 871371f5f..f40231432 100644 --- a/tests/strands/event_loop/test_event_loop.py +++ b/tests/strands/event_loop/test_event_loop.py @@ -685,6 +685,121 @@ async def test_event_loop_tracing_with_tool_execution( assert mock_tracer.end_model_invoke_span.call_count == 2 +@patch("strands.event_loop.event_loop.get_tracer") +@pytest.mark.asyncio +async def test_event_loop_cycle_closes_spans_on_stream_aclose( + mock_get_tracer, + agent, + model, + mock_tracer, +): + mock_get_tracer.return_value = mock_tracer + cycle_span = MagicMock() + mock_tracer.start_event_loop_cycle_span.return_value = cycle_span + model_span = MagicMock() + mock_tracer.start_model_invoke_span.return_value = model_span + + async def interrupted_stream(): + yield {"contentBlockDelta": {"delta": {"text": "test text"}}} + await asyncio.sleep(10) + yield {"contentBlockStop": {}} + + model.stream.return_value = interrupted_stream() + + stream = strands.event_loop.event_loop.event_loop_cycle( + agent=agent, + invocation_state={}, + ) + await anext(stream) + await anext(stream) + await anext(stream) + await stream.aclose() + + assert [call.args[0] for call in mock_tracer.end_span_with_error.call_args_list] == [model_span, cycle_span] + assert [call.args[1] for call in mock_tracer.end_span_with_error.call_args_list] == [ + "", + "", + ] + + +@patch("strands.event_loop.event_loop.get_tracer") +@pytest.mark.asyncio +async def test_event_loop_cycle_closes_spans_on_task_cancellation( + mock_get_tracer, + agent, + model, + mock_tracer, +): + mock_get_tracer.return_value = mock_tracer + cycle_span = MagicMock() + mock_tracer.start_event_loop_cycle_span.return_value = cycle_span + model_span = MagicMock() + mock_tracer.start_model_invoke_span.return_value = model_span + + blocked_on_stream = asyncio.Event() + release_stream = asyncio.Event() + + async def interrupted_stream(): + yield {"contentBlockDelta": {"delta": {"text": "test text"}}} + blocked_on_stream.set() + await release_stream.wait() + yield {"contentBlockStop": {}} + + model.stream.return_value = interrupted_stream() + + async def consume() -> None: + stream = strands.event_loop.event_loop.event_loop_cycle( + agent=agent, + invocation_state={}, + ) + async for _ in stream: + pass + + task = asyncio.create_task(consume()) + await blocked_on_stream.wait() + task.cancel() + + with pytest.raises(asyncio.CancelledError): + await task + + assert [call.args[0] for call in mock_tracer.end_span_with_error.call_args_list] == [model_span, cycle_span] + assert [call.args[1] for call in mock_tracer.end_span_with_error.call_args_list] == [ + "", + "", + ] + + +@patch("strands.event_loop.event_loop.get_tracer") +@pytest.mark.asyncio +async def test_event_loop_cycle_closes_spans_on_keyboard_interrupt( + mock_get_tracer, + agent, + model, + mock_tracer, + alist, +): + mock_get_tracer.return_value = mock_tracer + cycle_span = MagicMock() + mock_tracer.start_event_loop_cycle_span.return_value = cycle_span + model_span = MagicMock() + mock_tracer.start_model_invoke_span.return_value = model_span + + test_exception = KeyboardInterrupt("stop now") + model.stream.side_effect = test_exception + + with pytest.raises(KeyboardInterrupt, match="stop now"): + stream = strands.event_loop.event_loop.event_loop_cycle( + agent=agent, + invocation_state={}, + ) + await alist(stream) + + assert mock_tracer.end_span_with_error.call_args_list == [ + call(model_span, "stop now", test_exception), + call(cycle_span, "stop now", test_exception), + ] + + @pytest.mark.asyncio async def test_event_loop_cycle_closes_cycle_span_before_recursive_cycle( agent, diff --git a/tests/strands/telemetry/test_tracer.py b/tests/strands/telemetry/test_tracer.py index 8af7b782e..163a6741e 100644 --- a/tests/strands/telemetry/test_tracer.py +++ b/tests/strands/telemetry/test_tracer.py @@ -141,6 +141,18 @@ def test_end_span_with_empty_exception_message_uses_exception_name(mock_span): mock_span.end.assert_called_once() +def test_end_span_with_empty_base_exception_message_uses_exception_name(mock_span): + """Test that empty BaseException messages fall back to the exception type name.""" + tracer = Tracer() + error = KeyboardInterrupt() + + tracer.end_span_with_error(mock_span, "", error) + + mock_span.set_status.assert_called_once_with(StatusCode.ERROR, "KeyboardInterrupt") + mock_span.record_exception.assert_called_once_with(error) + mock_span.end.assert_called_once() + + def test_end_span_with_error_prefers_explicit_message(mock_span): """Test that an explicit error message takes precedence over the exception text.""" tracer = Tracer() @@ -1221,6 +1233,45 @@ def test_end_span_with_exception_handling(mock_span): pytest.fail("_end_span should not raise exceptions") +def test_force_flush_with_error(mock_span, mock_get_tracer_provider): + """Test force flush with error handling.""" + # Setup the tracer with a provider that raises an exception on force_flush + tracer = Tracer() + + mock_tracer_provider = mock_get_tracer_provider.return_value + mock_tracer_provider.force_flush.side_effect = Exception("Force flush error") + + # Should not raise an exception + tracer._end_span(mock_span) + + # Verify force_flush was called + mock_tracer_provider.force_flush.assert_called_once() + + +def test_end_agent_span_with_empty_error_message_uses_exception_name(mock_span): + """Test that agent spans fall back to the exception type name for empty errors.""" + tracer = Tracer() + error = Exception() + + tracer.end_agent_span(mock_span, error=error) + + mock_span.set_status.assert_called_once_with(StatusCode.ERROR, "Exception") + mock_span.record_exception.assert_called_once_with(error) + mock_span.end.assert_called_once() + + +def test_end_tool_call_span_with_empty_error_message_uses_exception_name(mock_span): + """Test that tool call spans fall back to the exception type name for empty errors.""" + tracer = Tracer() + error = Exception() + + tracer.end_tool_call_span(mock_span, None, error=error) + + mock_span.set_status.assert_called_once_with(StatusCode.ERROR, "Exception") + mock_span.record_exception.assert_called_once_with(error) + mock_span.end.assert_called_once() + + def test_end_tool_call_span_with_none(mock_span): """Test ending a tool call span with None result.""" tracer = Tracer()