Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/strands/agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -857,7 +857,7 @@ async def stream_async(

self._end_agent_trace_span(response=result)

except Exception as e:
except BaseException as e:
self._end_agent_trace_span(error=e)
raise

Expand Down Expand Up @@ -1060,7 +1060,7 @@ def _start_agent_trace_span(self, messages: Messages) -> trace_api.Span:
def _end_agent_trace_span(
self,
response: AgentResult | None = None,
error: Exception | None = None,
error: BaseException | None = None,
) -> None:
"""Ends a trace span for the agent.

Expand Down
48 changes: 33 additions & 15 deletions src/strands/event_loop/event_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ async def event_loop_cycle(
custom_trace_attributes=agent.trace_attributes,
)
invocation_state["event_loop_cycle_span"] = cycle_span
model_events: AsyncGenerator[TypedEvent, None] | None = None

with trace_api.use_span(cycle_span, end_on_exit=False):
try:
Expand All @@ -153,15 +154,21 @@ async def event_loop_cycle(
model_events = _handle_model_execution(
agent, cycle_span, cycle_trace, invocation_state, tracer, structured_output_context
)
async for model_event in model_events:
if not isinstance(model_event, ModelStopReason):
yield model_event
try:
async for model_event in model_events:
if not isinstance(model_event, ModelStopReason):
yield model_event
finally:
await model_events.aclose()

stop_reason, message, *_ = model_event["stop"]
yield ModelMessageEvent(message=message)
except Exception as e:
tracer.end_span_with_error(cycle_span, str(e), e)
raise
except BaseException as e:
tracer.end_span_with_error(cycle_span, str(e), e)
raise
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Issue: The except Exception and except BaseException handlers here have identical bodies — both call tracer.end_span_with_error and re-raise. Since Exception is a subclass of BaseException, the except BaseException handler alone would catch both.

Suggestion: Consolidate into a single handler:

except BaseException as e:
    tracer.end_span_with_error(cycle_span, str(e), e)
    raise

Note: The other two locations (lines 233-250 and 398-430) are correctly separated since the except Exception handlers do additional work (wrapping in EventLoopException, retry logic, etc.) that should not apply to BaseException subclasses.


try:
if stop_reason == "max_tokens":
Expand Down Expand Up @@ -238,6 +245,9 @@ async def event_loop_cycle(
yield ForceStopEvent(reason=e)
logger.exception("cycle failed")
raise EventLoopException(e, invocation_state["request_state"]) from e
except BaseException as e:
tracer.end_span_with_error(cycle_span, str(e), e)
raise


async def recurse_event_loop(
Expand Down Expand Up @@ -323,6 +333,7 @@ async def _handle_model_execution(
system_prompt=agent.system_prompt,
system_prompt_content=agent._system_prompt_content,
)
streamed_events: AsyncGenerator[TypedEvent, None] | None = None
with trace_api.use_span(model_invoke_span, end_on_exit=False):
try:
await agent.hooks.invoke_callbacks_async(
Expand All @@ -338,18 +349,22 @@ async def _handle_model_execution(
else:
tool_specs = agent.tool_registry.get_all_tool_specs()

async for event in stream_messages(
agent.model,
agent.system_prompt,
agent.messages,
tool_specs,
system_prompt_content=agent._system_prompt_content,
tool_choice=structured_output_context.tool_choice,
invocation_state=invocation_state,
model_state=agent._model_state,
cancel_signal=agent._cancel_signal,
):
yield event
try:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: would it be better to move try on top of streamed_events = stream_messages(...)?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fix it.

streamed_events = stream_messages(
agent.model,
agent.system_prompt,
agent.messages,
tool_specs,
system_prompt_content=agent._system_prompt_content,
tool_choice=structured_output_context.tool_choice,
invocation_state=invocation_state,
model_state=agent._model_state,
cancel_signal=agent._cancel_signal,
)
async for event in streamed_events:
yield event
finally:
await streamed_events.aclose()
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Issue: streamed_events is initialized to None (line 336), but the finally block unconditionally calls await streamed_events.aclose(). If stream_messages() raises before returning the generator object (e.g., due to a future signature change or dynamic error), this would raise AttributeError: 'NoneType' object has no attribute 'aclose' — masking the original exception.

Currently this is safe because stream_messages is an async generator function (calling it just creates the object without executing any code), but the code is fragile.

Suggestion: Add a guard:

finally:
    if streamed_events is not None:
        await streamed_events.aclose()

The same pattern at line 162 for model_events has the same concern, though it's slightly safer because model_events is assigned before the try block.


stop_reason, message, usage, metrics = event["stop"]
invocation_state.setdefault("request_state", {})
Expand Down Expand Up @@ -417,6 +432,9 @@ async def _handle_model_execution(
# No retry requested, raise the exception
yield ForceStopEvent(reason=e)
raise e
except BaseException as e:
tracer.end_span_with_error(model_invoke_span, str(e), e)
raise

try:
# Add message in trace and mark the end of the stream messages trace
Expand Down
10 changes: 6 additions & 4 deletions src/strands/telemetry/tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ def _end_span(
self,
span: Span,
attributes: dict[str, AttributeValue] | None = None,
error: Exception | None = None,
error: BaseException | None = None,
error_message: str | None = None,
) -> None:
"""Generic helper method to end a span.
Expand Down Expand Up @@ -219,7 +219,7 @@ def _end_span(
finally:
span.end()

def end_span_with_error(self, span: Span, error_message: str, exception: Exception | None = None) -> None:
def end_span_with_error(self, span: Span, error_message: str, exception: BaseException | None = None) -> None:
"""End a span with error status.

Args:
Expand Down Expand Up @@ -445,7 +445,9 @@ def start_tool_call_span(

return span

def end_tool_call_span(self, span: Span, tool_result: ToolResult | None, error: Exception | None = None) -> None:
def end_tool_call_span(
self, span: Span, tool_result: ToolResult | None, error: BaseException | None = None
) -> None:
"""End a tool call span with results.

Args:
Expand Down Expand Up @@ -644,7 +646,7 @@ def end_agent_span(
self,
span: Span,
response: AgentResult | None = None,
error: Exception | None = None,
error: BaseException | None = None,
) -> None:
"""End an agent span with results and metrics.

Expand Down
21 changes: 21 additions & 0 deletions tests/strands/agent/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -1425,6 +1425,27 @@ async def test_agent_stream_async_creates_and_ends_span_on_exception(mock_get_tr
mock_tracer.end_agent_span.assert_called_once_with(span=mock_span, error=test_exception)


@pytest.mark.asyncio
@unittest.mock.patch("strands.agent.agent.get_tracer")
async def test_agent_stream_async_creates_and_ends_span_on_base_exception(mock_get_tracer, mock_model, alist):
"""Test that stream_async ends the agent span when a BaseException occurs."""
mock_tracer = unittest.mock.MagicMock()
mock_span = unittest.mock.MagicMock()
mock_tracer.start_agent_span.return_value = mock_span
mock_get_tracer.return_value = mock_tracer

test_exception = KeyboardInterrupt("stop now")
mock_model.mock_stream.side_effect = test_exception

agent = Agent(model=mock_model)

with pytest.raises(KeyboardInterrupt, match="stop now"):
stream = agent.stream_async("test prompt")
await alist(stream)

mock_tracer.end_agent_span.assert_called_once_with(span=mock_span, error=test_exception)


def test_agent_init_with_state_object():
agent = Agent(state=AgentState({"foo": "bar"}))
assert agent.state.get("foo") == "bar"
Expand Down
115 changes: 115 additions & 0 deletions tests/strands/event_loop/test_event_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -685,6 +685,121 @@ async def test_event_loop_tracing_with_tool_execution(
assert mock_tracer.end_model_invoke_span.call_count == 2


@patch("strands.event_loop.event_loop.get_tracer")
@pytest.mark.asyncio
async def test_event_loop_cycle_closes_spans_on_stream_aclose(
mock_get_tracer,
agent,
model,
mock_tracer,
):
mock_get_tracer.return_value = mock_tracer
cycle_span = MagicMock()
mock_tracer.start_event_loop_cycle_span.return_value = cycle_span
model_span = MagicMock()
mock_tracer.start_model_invoke_span.return_value = model_span

async def interrupted_stream():
yield {"contentBlockDelta": {"delta": {"text": "test text"}}}
await asyncio.sleep(10)
yield {"contentBlockStop": {}}

model.stream.return_value = interrupted_stream()

stream = strands.event_loop.event_loop.event_loop_cycle(
agent=agent,
invocation_state={},
)
await anext(stream)
await anext(stream)
await anext(stream)
await stream.aclose()

assert [call.args[0] for call in mock_tracer.end_span_with_error.call_args_list] == [model_span, cycle_span]
assert [call.args[1] for call in mock_tracer.end_span_with_error.call_args_list] == [
"",
"",
]


@patch("strands.event_loop.event_loop.get_tracer")
@pytest.mark.asyncio
async def test_event_loop_cycle_closes_spans_on_task_cancellation(
mock_get_tracer,
agent,
model,
mock_tracer,
):
mock_get_tracer.return_value = mock_tracer
cycle_span = MagicMock()
mock_tracer.start_event_loop_cycle_span.return_value = cycle_span
model_span = MagicMock()
mock_tracer.start_model_invoke_span.return_value = model_span

blocked_on_stream = asyncio.Event()
release_stream = asyncio.Event()

async def interrupted_stream():
yield {"contentBlockDelta": {"delta": {"text": "test text"}}}
blocked_on_stream.set()
await release_stream.wait()
yield {"contentBlockStop": {}}

model.stream.return_value = interrupted_stream()

async def consume() -> None:
stream = strands.event_loop.event_loop.event_loop_cycle(
agent=agent,
invocation_state={},
)
async for _ in stream:
pass

task = asyncio.create_task(consume())
await blocked_on_stream.wait()
task.cancel()

with pytest.raises(asyncio.CancelledError):
await task

assert [call.args[0] for call in mock_tracer.end_span_with_error.call_args_list] == [model_span, cycle_span]
assert [call.args[1] for call in mock_tracer.end_span_with_error.call_args_list] == [
"",
"",
]


@patch("strands.event_loop.event_loop.get_tracer")
@pytest.mark.asyncio
async def test_event_loop_cycle_closes_spans_on_keyboard_interrupt(
mock_get_tracer,
agent,
model,
mock_tracer,
alist,
):
mock_get_tracer.return_value = mock_tracer
cycle_span = MagicMock()
mock_tracer.start_event_loop_cycle_span.return_value = cycle_span
model_span = MagicMock()
mock_tracer.start_model_invoke_span.return_value = model_span

test_exception = KeyboardInterrupt("stop now")
model.stream.side_effect = test_exception

with pytest.raises(KeyboardInterrupt, match="stop now"):
stream = strands.event_loop.event_loop.event_loop_cycle(
agent=agent,
invocation_state={},
)
await alist(stream)

assert mock_tracer.end_span_with_error.call_args_list == [
call(model_span, "stop now", test_exception),
call(cycle_span, "stop now", test_exception),
]


@pytest.mark.asyncio
async def test_event_loop_cycle_closes_cycle_span_before_recursive_cycle(
agent,
Expand Down
51 changes: 51 additions & 0 deletions tests/strands/telemetry/test_tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,18 @@ def test_end_span_with_empty_exception_message_uses_exception_name(mock_span):
mock_span.end.assert_called_once()


def test_end_span_with_empty_base_exception_message_uses_exception_name(mock_span):
"""Test that empty BaseException messages fall back to the exception type name."""
tracer = Tracer()
error = KeyboardInterrupt()

tracer.end_span_with_error(mock_span, "", error)

mock_span.set_status.assert_called_once_with(StatusCode.ERROR, "KeyboardInterrupt")
mock_span.record_exception.assert_called_once_with(error)
mock_span.end.assert_called_once()


def test_end_span_with_error_prefers_explicit_message(mock_span):
"""Test that an explicit error message takes precedence over the exception text."""
tracer = Tracer()
Expand Down Expand Up @@ -1221,6 +1233,45 @@ def test_end_span_with_exception_handling(mock_span):
pytest.fail("_end_span should not raise exceptions")


def test_force_flush_with_error(mock_span, mock_get_tracer_provider):
"""Test force flush with error handling."""
# Setup the tracer with a provider that raises an exception on force_flush
tracer = Tracer()

mock_tracer_provider = mock_get_tracer_provider.return_value
mock_tracer_provider.force_flush.side_effect = Exception("Force flush error")

# Should not raise an exception
tracer._end_span(mock_span)

# Verify force_flush was called
mock_tracer_provider.force_flush.assert_called_once()


def test_end_agent_span_with_empty_error_message_uses_exception_name(mock_span):
"""Test that agent spans fall back to the exception type name for empty errors."""
tracer = Tracer()
error = Exception()

tracer.end_agent_span(mock_span, error=error)

mock_span.set_status.assert_called_once_with(StatusCode.ERROR, "Exception")
mock_span.record_exception.assert_called_once_with(error)
mock_span.end.assert_called_once()


def test_end_tool_call_span_with_empty_error_message_uses_exception_name(mock_span):
"""Test that tool call spans fall back to the exception type name for empty errors."""
tracer = Tracer()
error = Exception()

tracer.end_tool_call_span(mock_span, None, error=error)

mock_span.set_status.assert_called_once_with(StatusCode.ERROR, "Exception")
mock_span.record_exception.assert_called_once_with(error)
mock_span.end.assert_called_once()


def test_end_tool_call_span_with_none(mock_span):
"""Test ending a tool call span with None result."""
tracer = Tracer()
Expand Down