Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 52 additions & 4 deletions src/strands/event_loop/event_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

from opentelemetry import trace as trace_api

from ..hooks import AfterModelCallEvent, BeforeModelCallEvent, MessageAddedEvent
from ..hooks import AfterModelCallEvent, AfterToolCallEvent, BeforeModelCallEvent, MessageAddedEvent
from ..telemetry.metrics import Trace
from ..telemetry.tracer import Tracer, get_tracer
from ..tools._validator import validate_and_prepare_tools
Expand Down Expand Up @@ -476,15 +476,48 @@ async def _handle_tool_execution(
validate_and_prepare_tools(message, tool_uses, tool_results, invalid_tool_use_ids)
tool_uses = [tool_use for tool_use in tool_uses if tool_use.get("toolUseId") not in invalid_tool_use_ids]

interrupts = []

if agent._interrupt_state.activated:
tool_results.extend(agent._interrupt_state.context["tool_results"])

# Replay after-tool-call hooks for tools that were interrupted after execution.
# The tool result is already preserved in tool_results; we re-fire the hook so
# the callback receives the human response via event.interrupt() return value.
after_tool_snapshots: list[dict[str, Any]] = agent._interrupt_state.context.get("after_tool_events", [])
for snapshot in after_tool_snapshots:
tool_name = snapshot["tool_use"]["name"]
tool_func = agent.tool_registry.dynamic_tools.get(tool_name) or agent.tool_registry.registry.get(tool_name)
original_exception = Exception(snapshot["exception"]) if snapshot.get("exception") else None
original_event = AfterToolCallEvent(
agent=agent,
selected_tool=tool_func,
tool_use=snapshot["tool_use"],
invocation_state=invocation_state,
result=snapshot["result"],
exception=original_exception,
cancel_message=snapshot.get("cancel_message"),
)
replayed, new_interrupts = await agent.hooks.invoke_callbacks_async(original_event)
if new_interrupts:
interrupts.extend(new_interrupts)
continue
tool_use_id = original_event.tool_use["toolUseId"]
if getattr(replayed, "retry", False):
# Hook wants to re-execute the tool — remove preserved result and re-queue
tool_results[:] = [tr for tr in tool_results if tr["toolUseId"] != tool_use_id]
tool_uses.append(original_event.tool_use)
else:
# Update result in case the hook modified it
for i, tr in enumerate(tool_results):
if tr["toolUseId"] == tool_use_id:
tool_results[i] = replayed.result
break

# Filter to only the interrupted tools when resuming from interrupt (tool uses without results)
tool_use_ids = {tool_result["toolUseId"] for tool_result in tool_results}
tool_uses = [tool_use for tool_use in tool_uses if tool_use["toolUseId"] not in tool_use_ids]

interrupts = []

# Check for cancellation before tool execution
# Add tool_result for each tool_use to maintain valid conversation state
if agent._cancel_signal.is_set():
Expand Down Expand Up @@ -528,9 +561,20 @@ async def _handle_tool_execution(
tool_events = agent.tool_executor._execute(
agent, tool_uses, tool_results, cycle_trace, cycle_span, invocation_state, structured_output_context
)
after_tool_interrupt_events: list[dict[str, Any]] = []
async for tool_event in tool_events:
if isinstance(tool_event, ToolInterruptEvent):
interrupts.extend(tool_event["tool_interrupt_event"]["interrupts"])
if isinstance(tool_event.source_event, AfterToolCallEvent):
evt = tool_event.source_event
after_tool_interrupt_events.append(
{
"tool_use": evt.tool_use,
"result": evt.result,
"cancel_message": evt.cancel_message,
"exception": str(evt.exception) if evt.exception else None,
}
)

yield tool_event

Expand All @@ -544,7 +588,11 @@ async def _handle_tool_execution(

if interrupts:
# Session state stored on AfterInvocationEvent.
agent._interrupt_state.context = {"tool_use_message": message, "tool_results": tool_results}
agent._interrupt_state.context = {
"tool_use_message": message,
"tool_results": tool_results,
"after_tool_events": after_tool_interrupt_events,
}
agent._interrupt_state.activate()

agent.event_loop_metrics.end_cycle(cycle_start_time, cycle_trace)
Expand Down
20 changes: 19 additions & 1 deletion src/strands/hooks/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ def _interrupt_id(self, name: str) -> str:


@dataclass
class AfterToolCallEvent(HookEvent):
class AfterToolCallEvent(HookEvent, _Interruptible):
"""Event triggered after a tool invocation completes.

This event is fired after the agent has finished executing a tool,
Expand All @@ -193,6 +193,12 @@ class AfterToolCallEvent(HookEvent):
- ToolResultEvent is NOT emitted for discarded attempts - only the final attempt's
result is emitted and added to the conversation history

Interrupts:
Hook callbacks can call ``event.interrupt(name, reason)`` to pause agent execution
and request human input. The tool result is preserved and the tool will not be
re-executed on resume. See :func:`strands.event_loop.event_loop._handle_tool_execution`
for the replay mechanism.

Attributes:
selected_tool: The tool that was invoked. It may be None if tool lookup failed.
tool_use: The tool parameters that were passed to the tool invoked.
Expand Down Expand Up @@ -221,6 +227,18 @@ def should_reverse_callbacks(self) -> bool:
"""True to invoke callbacks in reverse order."""
return True

@override
def _interrupt_id(self, name: str) -> str:
"""Unique id for the interrupt.

Args:
name: User defined name for the interrupt.

Returns:
Interrupt id.
"""
return f"v1:after_tool_call:{self.tool_use['toolUseId']}:{uuid.uuid5(uuid.NAMESPACE_OID, name)}"


@dataclass
class BeforeModelCallEvent(HookEvent):
Expand Down
31 changes: 27 additions & 4 deletions src/strands/tools/executors/_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ async def _stream(
"content": [{"text": cancel_message}],
}

after_event, _ = await ToolExecutor._invoke_after_tool_call_hook(
after_event, interrupts = await ToolExecutor._invoke_after_tool_call_hook(
agent,
None,
tool_use,
Expand All @@ -179,6 +179,12 @@ async def _stream(
exception=Exception(cancel_message),
cancel_message=cancel_message,
)

if interrupts:
tool_results.append(after_event.result)
yield ToolInterruptEvent(tool_use, interrupts, source_event=after_event)
return

yield ToolResultEvent(after_event.result, exception=after_event.exception)
tool_results.append(after_event.result)
return
Expand Down Expand Up @@ -209,9 +215,15 @@ async def _stream(
}

unknown_tool_error = Exception(f"Unknown tool: {tool_name}")
after_event, _ = await ToolExecutor._invoke_after_tool_call_hook(
after_event, interrupts = await ToolExecutor._invoke_after_tool_call_hook(
agent, selected_tool, tool_use, invocation_state, result, exception=unknown_tool_error
)

if interrupts:
tool_results.append(after_event.result)
yield ToolInterruptEvent(tool_use, interrupts, source_event=after_event)
return

# Check if retry requested for unknown tool error
# Use getattr because BidiAfterToolCallEvent doesn't have retry attribute
if getattr(after_event, "retry", False):
Expand Down Expand Up @@ -256,10 +268,15 @@ async def _stream(

result = cast(ToolResult, event)

after_event, _ = await ToolExecutor._invoke_after_tool_call_hook(
after_event, interrupts = await ToolExecutor._invoke_after_tool_call_hook(
agent, selected_tool, tool_use, invocation_state, result, exception=exception
)

if interrupts:
tool_results.append(after_event.result)
yield ToolInterruptEvent(tool_use, interrupts, source_event=after_event)
return

# Check if retry requested (getattr for BidiAfterToolCallEvent compatibility)
if getattr(after_event, "retry", False):
logger.debug("tool_name=<%s> | retry requested, retrying tool call", tool_name)
Expand All @@ -277,9 +294,15 @@ async def _stream(
"content": [{"text": f"Error: {str(e)}"}],
}

after_event, _ = await ToolExecutor._invoke_after_tool_call_hook(
after_event, interrupts = await ToolExecutor._invoke_after_tool_call_hook(
agent, selected_tool, tool_use, invocation_state, error_result, exception=e
)

if interrupts:
tool_results.append(after_event.result)
yield ToolInterruptEvent(tool_use, interrupts, source_event=after_event)
return

# Check if retry requested (getattr for BidiAfterToolCallEvent compatibility)
if getattr(after_event, "retry", False):
logger.debug("tool_name=<%s> | retry requested after exception, retrying tool call", tool_name)
Expand Down
19 changes: 17 additions & 2 deletions src/strands/types/_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
if TYPE_CHECKING:
from ..agent import AgentResult
from ..agent._agent_as_tool import _AgentAsTool
from ..hooks.registry import BaseHookEvent
from ..multiagent.base import MultiAgentResult, NodeResult


Expand Down Expand Up @@ -373,11 +374,25 @@ def message(self) -> str:


class ToolInterruptEvent(TypedEvent):
"""Event emitted when a tool is interrupted."""
"""Event emitted when a tool is interrupted.

def __init__(self, tool_use: ToolUse, interrupts: list[Interrupt]) -> None:
Attributes:
source_event: The hook event that raised the interrupt, if available. For interrupts
raised from AfterToolCallEvent, this preserves the original event so the after-hook
can be replayed on resume without re-executing the tool.
"""

def __init__(
self, tool_use: ToolUse, interrupts: list[Interrupt], source_event: "BaseHookEvent | None" = None
) -> None:
"""Set interrupt in the event payload."""
super().__init__({"tool_interrupt_event": {"tool_use": tool_use, "interrupts": interrupts}})
self._source_event = source_event

@property
def source_event(self) -> "BaseHookEvent | None":
"""The hook event that raised the interrupt."""
return self._source_event

@property
def tool_use_id(self) -> str:
Expand Down
Loading
Loading