From eb20ac970ab7edd0a126f5ae619de515b954ccd7 Mon Sep 17 00:00:00 2001 From: Alistair Knox Date: Tue, 21 Apr 2026 15:34:04 +0000 Subject: [PATCH] feat: support interrupts from AfterToolCallEvent (#1165) Add interrupt support to AfterToolCallEvent, enabling human-in-the-loop workflows after tool execution. This allows hook callbacks to pause agent execution after inspecting a tool result (e.g., on failure) and resume with a human response. --- src/strands/event_loop/event_loop.py | 56 +++- src/strands/hooks/events.py | 20 +- src/strands/tools/executors/_executor.py | 31 +- src/strands/types/_events.py | 19 +- .../test_after_tool_call_interrupt.py | 283 ++++++++++++++++++ tests/strands/event_loop/test_event_loop.py | 1 + .../test_after_tool_call_interrupt.py | 258 ++++++++++++++++ 7 files changed, 657 insertions(+), 11 deletions(-) create mode 100644 tests/strands/event_loop/test_after_tool_call_interrupt.py create mode 100644 tests/strands/tools/executors/test_after_tool_call_interrupt.py diff --git a/src/strands/event_loop/event_loop.py b/src/strands/event_loop/event_loop.py index bf1cc7a84..0704a125b 100644 --- a/src/strands/event_loop/event_loop.py +++ b/src/strands/event_loop/event_loop.py @@ -15,7 +15,7 @@ from opentelemetry import trace as trace_api -from ..hooks import AfterModelCallEvent, BeforeModelCallEvent, MessageAddedEvent +from ..hooks import AfterModelCallEvent, AfterToolCallEvent, BeforeModelCallEvent, MessageAddedEvent from ..telemetry.metrics import Trace from ..telemetry.tracer import Tracer, get_tracer from ..tools._validator import validate_and_prepare_tools @@ -476,15 +476,48 @@ async def _handle_tool_execution( validate_and_prepare_tools(message, tool_uses, tool_results, invalid_tool_use_ids) tool_uses = [tool_use for tool_use in tool_uses if tool_use.get("toolUseId") not in invalid_tool_use_ids] + interrupts = [] + if agent._interrupt_state.activated: tool_results.extend(agent._interrupt_state.context["tool_results"]) + # Replay after-tool-call hooks for tools that were interrupted after execution. + # The tool result is already preserved in tool_results; we re-fire the hook so + # the callback receives the human response via event.interrupt() return value. + after_tool_snapshots: list[dict[str, Any]] = agent._interrupt_state.context.get("after_tool_events", []) + for snapshot in after_tool_snapshots: + tool_name = snapshot["tool_use"]["name"] + tool_func = agent.tool_registry.dynamic_tools.get(tool_name) or agent.tool_registry.registry.get(tool_name) + original_exception = Exception(snapshot["exception"]) if snapshot.get("exception") else None + original_event = AfterToolCallEvent( + agent=agent, + selected_tool=tool_func, + tool_use=snapshot["tool_use"], + invocation_state=invocation_state, + result=snapshot["result"], + exception=original_exception, + cancel_message=snapshot.get("cancel_message"), + ) + replayed, new_interrupts = await agent.hooks.invoke_callbacks_async(original_event) + if new_interrupts: + interrupts.extend(new_interrupts) + continue + tool_use_id = original_event.tool_use["toolUseId"] + if getattr(replayed, "retry", False): + # Hook wants to re-execute the tool — remove preserved result and re-queue + tool_results[:] = [tr for tr in tool_results if tr["toolUseId"] != tool_use_id] + tool_uses.append(original_event.tool_use) + else: + # Update result in case the hook modified it + for i, tr in enumerate(tool_results): + if tr["toolUseId"] == tool_use_id: + tool_results[i] = replayed.result + break + # Filter to only the interrupted tools when resuming from interrupt (tool uses without results) tool_use_ids = {tool_result["toolUseId"] for tool_result in tool_results} tool_uses = [tool_use for tool_use in tool_uses if tool_use["toolUseId"] not in tool_use_ids] - interrupts = [] - # Check for cancellation before tool execution # Add tool_result for each tool_use to maintain valid conversation state if agent._cancel_signal.is_set(): @@ -528,9 +561,20 @@ async def _handle_tool_execution( tool_events = agent.tool_executor._execute( agent, tool_uses, tool_results, cycle_trace, cycle_span, invocation_state, structured_output_context ) + after_tool_interrupt_events: list[dict[str, Any]] = [] async for tool_event in tool_events: if isinstance(tool_event, ToolInterruptEvent): interrupts.extend(tool_event["tool_interrupt_event"]["interrupts"]) + if isinstance(tool_event.source_event, AfterToolCallEvent): + evt = tool_event.source_event + after_tool_interrupt_events.append( + { + "tool_use": evt.tool_use, + "result": evt.result, + "cancel_message": evt.cancel_message, + "exception": str(evt.exception) if evt.exception else None, + } + ) yield tool_event @@ -544,7 +588,11 @@ async def _handle_tool_execution( if interrupts: # Session state stored on AfterInvocationEvent. - agent._interrupt_state.context = {"tool_use_message": message, "tool_results": tool_results} + agent._interrupt_state.context = { + "tool_use_message": message, + "tool_results": tool_results, + "after_tool_events": after_tool_interrupt_events, + } agent._interrupt_state.activate() agent.event_loop_metrics.end_cycle(cycle_start_time, cycle_trace) diff --git a/src/strands/hooks/events.py b/src/strands/hooks/events.py index 9186e0e70..1e1476cec 100644 --- a/src/strands/hooks/events.py +++ b/src/strands/hooks/events.py @@ -171,7 +171,7 @@ def _interrupt_id(self, name: str) -> str: @dataclass -class AfterToolCallEvent(HookEvent): +class AfterToolCallEvent(HookEvent, _Interruptible): """Event triggered after a tool invocation completes. This event is fired after the agent has finished executing a tool, @@ -193,6 +193,12 @@ class AfterToolCallEvent(HookEvent): - ToolResultEvent is NOT emitted for discarded attempts - only the final attempt's result is emitted and added to the conversation history + Interrupts: + Hook callbacks can call ``event.interrupt(name, reason)`` to pause agent execution + and request human input. The tool result is preserved and the tool will not be + re-executed on resume. See :func:`strands.event_loop.event_loop._handle_tool_execution` + for the replay mechanism. + Attributes: selected_tool: The tool that was invoked. It may be None if tool lookup failed. tool_use: The tool parameters that were passed to the tool invoked. @@ -221,6 +227,18 @@ def should_reverse_callbacks(self) -> bool: """True to invoke callbacks in reverse order.""" return True + @override + def _interrupt_id(self, name: str) -> str: + """Unique id for the interrupt. + + Args: + name: User defined name for the interrupt. + + Returns: + Interrupt id. + """ + return f"v1:after_tool_call:{self.tool_use['toolUseId']}:{uuid.uuid5(uuid.NAMESPACE_OID, name)}" + @dataclass class BeforeModelCallEvent(HookEvent): diff --git a/src/strands/tools/executors/_executor.py b/src/strands/tools/executors/_executor.py index 2c602a560..dfbd47e6a 100644 --- a/src/strands/tools/executors/_executor.py +++ b/src/strands/tools/executors/_executor.py @@ -170,7 +170,7 @@ async def _stream( "content": [{"text": cancel_message}], } - after_event, _ = await ToolExecutor._invoke_after_tool_call_hook( + after_event, interrupts = await ToolExecutor._invoke_after_tool_call_hook( agent, None, tool_use, @@ -179,6 +179,12 @@ async def _stream( exception=Exception(cancel_message), cancel_message=cancel_message, ) + + if interrupts: + tool_results.append(after_event.result) + yield ToolInterruptEvent(tool_use, interrupts, source_event=after_event) + return + yield ToolResultEvent(after_event.result, exception=after_event.exception) tool_results.append(after_event.result) return @@ -209,9 +215,15 @@ async def _stream( } unknown_tool_error = Exception(f"Unknown tool: {tool_name}") - after_event, _ = await ToolExecutor._invoke_after_tool_call_hook( + after_event, interrupts = await ToolExecutor._invoke_after_tool_call_hook( agent, selected_tool, tool_use, invocation_state, result, exception=unknown_tool_error ) + + if interrupts: + tool_results.append(after_event.result) + yield ToolInterruptEvent(tool_use, interrupts, source_event=after_event) + return + # Check if retry requested for unknown tool error # Use getattr because BidiAfterToolCallEvent doesn't have retry attribute if getattr(after_event, "retry", False): @@ -256,10 +268,15 @@ async def _stream( result = cast(ToolResult, event) - after_event, _ = await ToolExecutor._invoke_after_tool_call_hook( + after_event, interrupts = await ToolExecutor._invoke_after_tool_call_hook( agent, selected_tool, tool_use, invocation_state, result, exception=exception ) + if interrupts: + tool_results.append(after_event.result) + yield ToolInterruptEvent(tool_use, interrupts, source_event=after_event) + return + # Check if retry requested (getattr for BidiAfterToolCallEvent compatibility) if getattr(after_event, "retry", False): logger.debug("tool_name=<%s> | retry requested, retrying tool call", tool_name) @@ -277,9 +294,15 @@ async def _stream( "content": [{"text": f"Error: {str(e)}"}], } - after_event, _ = await ToolExecutor._invoke_after_tool_call_hook( + after_event, interrupts = await ToolExecutor._invoke_after_tool_call_hook( agent, selected_tool, tool_use, invocation_state, error_result, exception=e ) + + if interrupts: + tool_results.append(after_event.result) + yield ToolInterruptEvent(tool_use, interrupts, source_event=after_event) + return + # Check if retry requested (getattr for BidiAfterToolCallEvent compatibility) if getattr(after_event, "retry", False): logger.debug("tool_name=<%s> | retry requested after exception, retrying tool call", tool_name) diff --git a/src/strands/types/_events.py b/src/strands/types/_events.py index 1d5a5de79..858897f7e 100644 --- a/src/strands/types/_events.py +++ b/src/strands/types/_events.py @@ -22,6 +22,7 @@ if TYPE_CHECKING: from ..agent import AgentResult from ..agent._agent_as_tool import _AgentAsTool + from ..hooks.registry import BaseHookEvent from ..multiagent.base import MultiAgentResult, NodeResult @@ -373,11 +374,25 @@ def message(self) -> str: class ToolInterruptEvent(TypedEvent): - """Event emitted when a tool is interrupted.""" + """Event emitted when a tool is interrupted. - def __init__(self, tool_use: ToolUse, interrupts: list[Interrupt]) -> None: + Attributes: + source_event: The hook event that raised the interrupt, if available. For interrupts + raised from AfterToolCallEvent, this preserves the original event so the after-hook + can be replayed on resume without re-executing the tool. + """ + + def __init__( + self, tool_use: ToolUse, interrupts: list[Interrupt], source_event: "BaseHookEvent | None" = None + ) -> None: """Set interrupt in the event payload.""" super().__init__({"tool_interrupt_event": {"tool_use": tool_use, "interrupts": interrupts}}) + self._source_event = source_event + + @property + def source_event(self) -> "BaseHookEvent | None": + """The hook event that raised the interrupt.""" + return self._source_event @property def tool_use_id(self) -> str: diff --git a/tests/strands/event_loop/test_after_tool_call_interrupt.py b/tests/strands/event_loop/test_after_tool_call_interrupt.py new file mode 100644 index 000000000..f9430722f --- /dev/null +++ b/tests/strands/event_loop/test_after_tool_call_interrupt.py @@ -0,0 +1,283 @@ +"""Tests for AfterToolCallEvent interrupt resume in the event loop. + +Covers the replay logic in _handle_tool_execution: +- Resume replays after-hook, callback gets response +- Resume with retry=True re-queues tool for execution +- Resume with result modification preserves the modified result +- Multiple tools where one has after-interrupt +""" + +import threading +import unittest.mock + +import pytest + +import strands +import strands.event_loop.event_loop +from strands import Agent +from strands.event_loop._retry import ModelRetryStrategy +from strands.hooks import AfterToolCallEvent, HookRegistry +from strands.interrupt import Interrupt, _InterruptState +from strands.telemetry.metrics import EventLoopMetrics +from strands.tools.executors import SequentialToolExecutor +from strands.tools.registry import ToolRegistry + + +@pytest.fixture +def tool_registry(): + return ToolRegistry() + + +@pytest.fixture +def tool(tool_registry): + @strands.tool + def tool_for_testing(random_string: str): + return random_string + + tool_registry.register_tool(tool_for_testing) + return tool_for_testing + + +@pytest.fixture +def hook_registry(): + registry = HookRegistry() + retry_strategy = ModelRetryStrategy() + retry_strategy.register_hooks(registry) + return registry + + +@pytest.fixture +def model(): + return unittest.mock.Mock() + + +@pytest.fixture +def agent(model, tool_registry, hook_registry): + mock = unittest.mock.Mock(name="agent") + mock.__class__ = Agent + mock.config.cache_points = [] + mock.model = model + mock.system_prompt = "test" + mock.messages = [{"role": "user", "content": [{"text": "Hello"}]}] + mock.tool_registry = tool_registry + mock.thread_pool = None + mock.event_loop_metrics = EventLoopMetrics() + mock.event_loop_metrics.reset_usage_metrics() + mock.hooks = hook_registry + mock.tool_executor = SequentialToolExecutor() + mock._interrupt_state = _InterruptState() + mock._cancel_signal = threading.Event() + mock._model_state = {} + mock.trace_attributes = {} + mock.retry_strategy = ModelRetryStrategy() + return mock + + +@pytest.mark.asyncio +async def test_after_tool_interrupt_and_resume(agent, model, tool, agenerator, alist): + """Full cycle: tool runs → after-hook interrupts → resume → after-hook replays with response.""" + + # Step 1: First invocation — tool runs, after-hook interrupts + model.stream.return_value = agenerator( + [ + {"contentBlockStart": {"start": {"toolUse": {"toolUseId": "t1", "name": "tool_for_testing"}}}}, + {"contentBlockDelta": {"delta": {"toolUse": {"input": '{"random_string": "hello"}'}}}}, + {"contentBlockStop": {}}, + {"messageStop": {"stopReason": "tool_use"}}, + ] + ) + + def interrupt_after(event): + if isinstance(event, AfterToolCallEvent): + event.interrupt("approval", reason="needs review") + + agent.hooks.add_callback(AfterToolCallEvent, interrupt_after) + + stream = strands.event_loop.event_loop.event_loop_cycle(agent, invocation_state={}) + events = await alist(stream) + + stop_reason = events[-1]["stop"][0] + interrupts = events[-1]["stop"][4] + assert stop_reason == "interrupt" + assert len(interrupts) == 1 + assert interrupts[0].name == "approval" + + # Verify after_tool_events saved in context + assert "after_tool_events" in agent._interrupt_state.context + assert len(agent._interrupt_state.context["after_tool_events"]) == 1 + + # Verify tool result preserved + assert len(agent._interrupt_state.context["tool_results"]) == 1 + assert agent._interrupt_state.context["tool_results"][0]["status"] == "success" + + # Step 2: Resume — provide response, after-hook replays + interrupt_id = interrupts[0].id + agent._interrupt_state.interrupts[interrupt_id].response = "APPROVED" + + # Remove the interrupt hook and add one that captures the response + captured = {} + + # Replace hook: on resume, interrupt() returns the response + agent.hooks = HookRegistry() + retry_strategy = ModelRetryStrategy() + retry_strategy.register_hooks(agent.hooks) + + def capture_response(event): + if isinstance(event, AfterToolCallEvent): + captured["response"] = event.interrupt("approval", reason="needs review") + + agent.hooks.add_callback(AfterToolCallEvent, capture_response) + + model.stream.return_value = agenerator([{"contentBlockStop": {}}]) + + stream = strands.event_loop.event_loop.event_loop_cycle(agent, invocation_state={}) + events = await alist(stream) + + stop_reason = events[-1]["stop"][0] + assert stop_reason == "end_turn" + assert captured["response"] == "APPROVED" + + # Interrupt state cleared + assert not agent._interrupt_state.activated + + +@pytest.mark.asyncio +async def test_after_tool_interrupt_resume_with_retry(agent, model, tool, agenerator, alist): + """On resume, if after-hook sets retry=True, tool re-executes.""" + tool_use = {"toolUseId": "t1", "name": "tool_for_testing", "input": {"random_string": "attempt1"}} + tool_use_message = { + "role": "assistant", + "content": [{"toolUse": tool_use}], + } + + # Set up interrupt state as if first invocation already happened + original_result = {"toolUseId": "t1", "status": "error", "content": [{"text": "failed"}]} + + interrupt = Interrupt( + id="v1:after_tool_call:t1:7eb5933b-ed83-5e65-84e6-fa22d85940c9", + name="retry_check", + reason="tool failed", + response="RETRY", + ) + + agent._interrupt_state.context = { + "tool_use_message": tool_use_message, + "tool_results": [original_result], + "after_tool_events": [{"tool_use": tool_use, "result": original_result, "cancel_message": None}], + } + agent._interrupt_state.interrupts[interrupt.id] = interrupt + agent._interrupt_state.activate() + + # On resume, after-hook gets response and sets retry + def retry_on_response(event): + if isinstance(event, AfterToolCallEvent) and event.result.get("status") == "error": + response = event.interrupt("retry_check", reason="tool failed") + if response == "RETRY": + event.retry = True + + agent.hooks.add_callback(AfterToolCallEvent, retry_on_response) + + # Model responds with end_turn after the retried tool completes + model.stream.side_effect = [agenerator([{"contentBlockStop": {}}])] + + stream = strands.event_loop.event_loop.event_loop_cycle(agent, invocation_state={}) + events = await alist(stream) + + stop_reason = events[-1]["stop"][0] + assert stop_reason == "end_turn" + + # The tool re-executed (retry removed the error result and re-queued), + # so the final result message should contain a success result + result_messages = [ + m + for m in agent.messages + if isinstance(m, dict) and m.get("role") == "user" and any("toolResult" in c for c in m.get("content", [])) + ] + assert len(result_messages) > 0 + last_tool_result = result_messages[-1]["content"][0]["toolResult"] + assert last_tool_result["status"] == "success" + + +@pytest.mark.asyncio +async def test_after_tool_interrupt_resume_modifies_result(agent, model, tool, agenerator, alist): + """On resume, after-hook can modify the result without retry.""" + tool_use = {"toolUseId": "t1", "name": "tool_for_testing", "input": {"random_string": "test"}} + tool_use_message = { + "role": "assistant", + "content": [{"toolUse": tool_use}], + } + + original_result = {"toolUseId": "t1", "status": "error", "content": [{"text": "original error"}]} + + interrupt = Interrupt( + id="v1:after_tool_call:t1:6124fc2a-cbe6-5805-84ac-5847c3fe6953", + name="fix_result", + reason="error", + response="USE_DEFAULT", + ) + + agent._interrupt_state.context = { + "tool_use_message": tool_use_message, + "tool_results": [original_result], + "after_tool_events": [{"tool_use": tool_use, "result": original_result, "cancel_message": None}], + } + agent._interrupt_state.interrupts[interrupt.id] = interrupt + agent._interrupt_state.activate() + + def modify_on_response(event): + if isinstance(event, AfterToolCallEvent): + response = event.interrupt("fix_result", reason="error") + if response == "USE_DEFAULT": + event.result = {"toolUseId": "t1", "status": "success", "content": [{"text": "default value"}]} + + agent.hooks.add_callback(AfterToolCallEvent, modify_on_response) + + model.stream.return_value = agenerator([{"contentBlockStop": {}}]) + + stream = strands.event_loop.event_loop.event_loop_cycle(agent, invocation_state={}) + events = await alist(stream) + + stop_reason = events[-1]["stop"][0] + assert stop_reason == "end_turn" + + # The modified result should be in the conversation + result_messages = [ + m + for m in agent.messages + if isinstance(m, dict) and m.get("role") == "user" and any("toolResult" in c for c in m.get("content", [])) + ] + assert len(result_messages) > 0 + last_tool_result = result_messages[-1]["content"][0]["toolResult"] + assert last_tool_result["status"] == "success" + assert last_tool_result["content"][0]["text"] == "default value" + + +@pytest.mark.asyncio +async def test_after_tool_interrupt_context_saved_correctly(agent, model, tool, agenerator, alist): + """Interrupt context includes after_tool_events list.""" + model.stream.return_value = agenerator( + [ + {"contentBlockStart": {"start": {"toolUse": {"toolUseId": "t1", "name": "tool_for_testing"}}}}, + {"contentBlockDelta": {"delta": {"toolUse": {"input": '{"random_string": "x"}'}}}}, + {"contentBlockStop": {}}, + {"messageStop": {"stopReason": "tool_use"}}, + ] + ) + + def interrupt_after(event): + if isinstance(event, AfterToolCallEvent): + event.interrupt("check", reason="test") + + agent.hooks.add_callback(AfterToolCallEvent, interrupt_after) + + stream = strands.event_loop.event_loop.event_loop_cycle(agent, invocation_state={}) + await alist(stream) + + ctx = agent._interrupt_state.context + assert "after_tool_events" in ctx + assert len(ctx["after_tool_events"]) == 1 + snapshot = ctx["after_tool_events"][0] + assert isinstance(snapshot, dict) + assert "tool_use" in snapshot + assert "result" in snapshot + assert snapshot["result"]["toolUseId"] == "t1" diff --git a/tests/strands/event_loop/test_event_loop.py b/tests/strands/event_loop/test_event_loop.py index 871371f5f..3a4b51af0 100644 --- a/tests/strands/event_loop/test_event_loop.py +++ b/tests/strands/event_loop/test_event_loop.py @@ -999,6 +999,7 @@ def interrupt_callback(event): "role": "assistant", "metadata": ANY, }, + "after_tool_events": [], }, "interrupts": { "v1:before_tool_call:t1:78714d6c-613c-5cf4-bf25-7037569941f9": { diff --git a/tests/strands/tools/executors/test_after_tool_call_interrupt.py b/tests/strands/tools/executors/test_after_tool_call_interrupt.py new file mode 100644 index 000000000..c4c02ba25 --- /dev/null +++ b/tests/strands/tools/executors/test_after_tool_call_interrupt.py @@ -0,0 +1,258 @@ +"""Tests for AfterToolCallEvent interrupt support. + +Covers all 4 interrupt paths in _executor.py: +- Success path: tool succeeds, after-hook interrupts +- Exception path: tool raises, after-hook interrupts +- Cancel path: before-hook cancels tool, after-hook interrupts +- Unknown tool path: tool not found, after-hook interrupts + +Also covers resume behavior and retry-on-resume. +""" + +import pytest + +import strands +from strands.hooks import AfterToolCallEvent, BeforeToolCallEvent +from strands.interrupt import Interrupt +from strands.tools.executors._executor import ToolExecutor +from strands.types._events import ToolCancelEvent, ToolInterruptEvent, ToolResultEvent +from strands.types.tools import ToolUse + + +@pytest.fixture +def executor(): + class TestExecutor(ToolExecutor): + def _execute(self, _agent, _tool_uses, _tool_results, _invocation_state): + raise NotImplementedError + + return TestExecutor() + + +# -- Success path -- + + +@pytest.mark.asyncio +async def test_after_tool_call_interrupt_on_success(executor, agent, tool_results, invocation_state, alist): + """AfterToolCallEvent interrupt on successful tool execution yields ToolInterruptEvent with source_event.""" + tool_use: ToolUse = {"name": "weather_tool", "toolUseId": "t1", "input": {}} + + def interrupt_after(event): + if isinstance(event, AfterToolCallEvent): + event.interrupt("review", reason="check result") + + agent.hooks.add_callback(AfterToolCallEvent, interrupt_after) + + events = await alist(executor._stream(agent, tool_use, tool_results, invocation_state)) + + assert len(events) == 1 + assert isinstance(events[0], ToolInterruptEvent) + assert events[0].interrupts[0].name == "review" + assert isinstance(events[0].source_event, AfterToolCallEvent) + assert events[0].source_event.result["status"] == "success" + + # Result preserved in tool_results for resume + assert len(tool_results) == 1 + assert tool_results[0]["toolUseId"] == "t1" + assert tool_results[0]["status"] == "success" + + +@pytest.mark.asyncio +async def test_after_tool_call_interrupt_resume_on_success(executor, agent, tool_results, invocation_state, alist): + """On resume, after-hook re-fires and callback gets the interrupt response.""" + tool_use: ToolUse = {"name": "weather_tool", "toolUseId": "t1", "input": {}} + + interrupt = Interrupt( + id="v1:after_tool_call:t1:fd6381ef-9533-5ce1-8a4d-75db796edf35", + name="review", + reason="check result", + response="APPROVED", + ) + agent._interrupt_state.interrupts[interrupt.id] = interrupt + + captured = {} + + def interrupt_after(event): + if isinstance(event, AfterToolCallEvent): + captured["response"] = event.interrupt("review", reason="check result") + + agent.hooks.add_callback(AfterToolCallEvent, interrupt_after) + + events = await alist(executor._stream(agent, tool_use, tool_results, invocation_state)) + + # No interrupt this time — response was available + assert len(events) == 1 + assert isinstance(events[0], ToolResultEvent) + assert captured["response"] == "APPROVED" + + +# -- Exception path -- + + +@pytest.mark.asyncio +async def test_after_tool_call_interrupt_on_exception(executor, agent, tool_results, invocation_state, alist): + """AfterToolCallEvent interrupt when tool raises an exception.""" + tool_use: ToolUse = {"name": "exception_tool", "toolUseId": "t1", "input": {}} + + def interrupt_on_error(event): + if isinstance(event, AfterToolCallEvent) and event.exception: + event.interrupt("error_review", reason=str(event.exception)) + + agent.hooks.add_callback(AfterToolCallEvent, interrupt_on_error) + + events = await alist(executor._stream(agent, tool_use, tool_results, invocation_state)) + + assert len(events) == 1 + assert isinstance(events[0], ToolInterruptEvent) + assert events[0].interrupts[0].name == "error_review" + assert isinstance(events[0].source_event, AfterToolCallEvent) + assert events[0].source_event.result["status"] == "error" + + # Result preserved + assert len(tool_results) == 1 + assert tool_results[0]["status"] == "error" + + +# -- Cancel path -- + + +@pytest.mark.asyncio +async def test_after_tool_call_interrupt_on_cancel(executor, agent, tool_results, invocation_state, alist): + """AfterToolCallEvent interrupt when tool was cancelled by before-hook.""" + + def cancel_tool(event): + if isinstance(event, BeforeToolCallEvent): + event.cancel_tool = True + + def interrupt_after_cancel(event): + if isinstance(event, AfterToolCallEvent) and event.cancel_message: + event.interrupt("cancel_review", reason="cancelled") + + agent.hooks.add_callback(BeforeToolCallEvent, cancel_tool) + agent.hooks.add_callback(AfterToolCallEvent, interrupt_after_cancel) + + tool_use: ToolUse = {"name": "weather_tool", "toolUseId": "t1", "input": {}} + events = await alist(executor._stream(agent, tool_use, tool_results, invocation_state)) + + # ToolCancelEvent then ToolInterruptEvent + assert isinstance(events[0], ToolCancelEvent) + assert isinstance(events[1], ToolInterruptEvent) + assert events[1].interrupts[0].name == "cancel_review" + assert isinstance(events[1].source_event, AfterToolCallEvent) + + assert len(tool_results) == 1 + assert tool_results[0]["status"] == "error" + + +# -- Unknown tool path -- + + +@pytest.mark.asyncio +async def test_after_tool_call_interrupt_on_unknown_tool(executor, agent, tool_results, invocation_state, alist): + """AfterToolCallEvent interrupt when tool is not found.""" + + def interrupt_on_unknown(event): + if isinstance(event, AfterToolCallEvent) and event.exception: + event.interrupt("unknown_review", reason="tool missing") + + agent.hooks.add_callback(AfterToolCallEvent, interrupt_on_unknown) + + tool_use: ToolUse = {"name": "nonexistent_tool", "toolUseId": "t1", "input": {}} + events = await alist(executor._stream(agent, tool_use, tool_results, invocation_state)) + + assert len(events) == 1 + assert isinstance(events[0], ToolInterruptEvent) + assert events[0].interrupts[0].name == "unknown_review" + + assert len(tool_results) == 1 + assert "Unknown tool" in tool_results[0]["content"][0]["text"] + + +# -- Interrupt ID uniqueness -- + + +@pytest.mark.asyncio +async def test_after_tool_call_interrupt_id_uses_tool_use_id(executor, agent, tool_results, invocation_state, alist): + """Interrupt ID includes toolUseId so different tool calls produce different IDs.""" + tool_use: ToolUse = {"name": "weather_tool", "toolUseId": "unique_id_123", "input": {}} + + def interrupt_after(event): + if isinstance(event, AfterToolCallEvent): + event.interrupt("check", reason="test") + + agent.hooks.add_callback(AfterToolCallEvent, interrupt_after) + + events = await alist(executor._stream(agent, tool_use, tool_results, invocation_state)) + + interrupt = events[0].interrupts[0] + assert "unique_id_123" in interrupt.id + assert interrupt.id.startswith("v1:after_tool_call:") + + +# -- source_event typing -- + + +@pytest.mark.asyncio +async def test_tool_interrupt_event_source_event_none_for_before_hook( + executor, agent, tool_results, invocation_state, alist +): + """BeforeToolCallEvent interrupts produce ToolInterruptEvent with source_event=None.""" + tool_use: ToolUse = {"name": "weather_tool", "toolUseId": "t1", "input": {}} + + def interrupt_before(event): + if isinstance(event, BeforeToolCallEvent): + event.interrupt("before_check", reason="test") + + agent.hooks.add_callback(BeforeToolCallEvent, interrupt_before) + + events = await alist(executor._stream(agent, tool_use, tool_results, invocation_state)) + + assert isinstance(events[0], ToolInterruptEvent) + assert events[0].source_event is None + + +# -- No interrupt when not raised -- + + +@pytest.mark.asyncio +async def test_after_tool_call_no_interrupt_when_not_raised(executor, agent, tool_results, invocation_state, alist): + """Normal after-hook without interrupt proceeds as usual.""" + tool_use: ToolUse = {"name": "weather_tool", "toolUseId": "t1", "input": {}} + + events = await alist(executor._stream(agent, tool_use, tool_results, invocation_state)) + + assert len(events) == 1 + assert isinstance(events[0], ToolResultEvent) + assert events[0].tool_result["status"] == "success" + + +# -- Interrupt takes precedence over retry -- + + +@pytest.mark.asyncio +async def test_after_tool_call_interrupt_takes_precedence_over_retry( + executor, agent, tool_results, invocation_state, alist +): + """When interrupt is raised, retry flag on the event is never checked.""" + call_count = {"n": 0} + + @strands.tool(name="counted_tool") + def counted_tool(): + call_count["n"] += 1 + return "done" + + agent.tool_registry.register_tool(counted_tool) + + def interrupt_and_retry(event): + if isinstance(event, AfterToolCallEvent): + # interrupt() raises InterruptException before retry is read + event.retry = True + event.interrupt("block", reason="paused") + + agent.hooks.add_callback(AfterToolCallEvent, interrupt_and_retry) + + tool_use: ToolUse = {"name": "counted_tool", "toolUseId": "t1", "input": {}} + events = await alist(executor._stream(agent, tool_use, tool_results, invocation_state)) + + # Tool ran once, interrupt stopped the loop + assert call_count["n"] == 1 + assert isinstance(events[0], ToolInterruptEvent)