Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 65 additions & 1 deletion src/strands/agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from .._async import run_async
from ..event_loop._retry import ModelRetryStrategy
from ..event_loop.event_loop import INITIAL_DELAY, MAX_ATTEMPTS, MAX_DELAY, event_loop_cycle
from ..experimental.checkpoint import Checkpoint
from ..tools._tool_helpers import generate_missing_tool_result_content
from ..types._snapshot import (
SNAPSHOT_SCHEMA_VERSION,
Expand Down Expand Up @@ -146,6 +147,7 @@ def __init__(
tool_executor: ToolExecutor | None = None,
retry_strategy: ModelRetryStrategy | _DefaultRetryStrategySentinel | None = _DEFAULT_RETRY_STRATEGY,
concurrent_invocation_mode: ConcurrentInvocationMode = ConcurrentInvocationMode.THROW,
checkpointing: bool = False,
Comment thread
JackYPCOnline marked this conversation as resolved.
):
"""Initialize the Agent with the specified configuration.

Expand Down Expand Up @@ -214,6 +216,11 @@ def __init__(
Set to "unsafe_reentrant" to skip lock acquisition entirely, allowing concurrent invocations.
Warning: "unsafe_reentrant" makes no guarantees about resulting behavior and is provided
only for advanced use cases where the caller understands the risks.
checkpointing: When True, the event loop pauses at cycle boundaries (after model call,
after all tools execute) and returns an AgentResult with stop_reason="checkpoint"
and a populated ``checkpoint`` field. Persist the checkpoint and resume by passing
``[{"checkpointResume": {"checkpoint": checkpoint.to_dict()}}]`` as the next prompt.
Defaults to False. See :mod:`strands.experimental.checkpoint` for usage and limitations.

Raises:
ValueError: If agent id contains path separators.
Expand Down Expand Up @@ -304,6 +311,10 @@ def __init__(

self._interrupt_state = _InterruptState()

# Checkpointing: when True, event loop pauses at cycle boundaries
self._checkpointing: bool = checkpointing
self._checkpoint_resume_context: Checkpoint | None = None

# Runtime state for model providers (e.g., server-side response ids)
self._model_state: dict[str, Any] = {}

Expand Down Expand Up @@ -374,12 +385,18 @@ def cancel(self) -> None:
This method is thread-safe and can be called from any context
(e.g., another thread, web request handler, background task).

The agent will stop gracefully at the next checkpoint:
The agent will stop gracefully at the next cancellation-safe point:
- During model response streaming
- Before tool execution

The agent will return a result with stop_reason="cancelled".

Note:
"Cancellation-safe point" is distinct from
:class:`~strands.experimental.checkpoint.Checkpoint` boundaries.
Cancel takes precedence: a cancel signal at either checkpoint boundary
surfaces as ``stop_reason="cancelled"``, not ``"checkpoint"``.

Example:
```python
agent = Agent(model=model)
Expand Down Expand Up @@ -994,10 +1011,57 @@ async def _execute_event_loop_cycle(
if structured_output_context:
structured_output_context.cleanup(self.tool_registry)

def _try_consume_checkpoint_resume(self, prompt: list[Any]) -> bool:
"""Detect, validate, and consume a ``checkpointResume`` prompt block.

Returns True if the prompt was a resume block (state restored, caller
should skip normal message conversion). Returns False if no resume
block is present. Raises on malformed or misconfigured input.

Follows interrupt-resume error conventions: TypeError for shape issues,
KeyError for missing keys, ValueError for misconfig. A schema mismatch
in the checkpoint payload raises ``CheckpointException``.
"""
has_checkpoint_resume = any(isinstance(content, dict) and "checkpointResume" in content for content in prompt)
if not has_checkpoint_resume:
return False

if not self._checkpointing:
raise ValueError(
"Received checkpointResume block but agent was created with "
"checkpointing=False. Pass checkpointing=True when constructing "
"the Agent to enable durable execution."
)

invalid_types = [
key for content in prompt if isinstance(content, dict) for key in content if key != "checkpointResume"
]
if invalid_types:
raise TypeError(
f"content_types=<{invalid_types}> | checkpointResume cannot be mixed with other content types"
)

if len(prompt) != 1:
raise TypeError(f"block_count=<{len(prompt)}> | only one checkpointResume block permitted per prompt")

resume_block = prompt[0].get("checkpointResume", {})
if not isinstance(resume_block, dict) or "checkpoint" not in resume_block:
raise KeyError("checkpoint | missing required key in checkpointResume block")

checkpoint = Checkpoint.from_dict(resume_block["checkpoint"])
self.load_snapshot(Snapshot.from_dict(checkpoint.snapshot))
self._checkpoint_resume_context = checkpoint
return True

async def _convert_prompt_to_messages(self, prompt: AgentInput) -> Messages:
if self._interrupt_state.activated:
return []

# Resume detection — must run before existing shape handling so checkpointResume
# blocks aren't misinterpreted as content blocks.
if isinstance(prompt, list) and prompt and self._try_consume_checkpoint_resume(prompt):
return []

messages: Messages | None = None
if prompt is not None:
# Check if the latest message is toolUse
Expand Down
18 changes: 16 additions & 2 deletions src/strands/agent/agent_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from pydantic import BaseModel

from ..experimental.checkpoint import Checkpoint
from ..interrupt import Interrupt
from ..telemetry.metrics import EventLoopMetrics
from ..types.content import Message
Expand All @@ -26,6 +27,9 @@ class AgentResult:
state: Additional state information from the event loop.
interrupts: List of interrupts if raised by user.
structured_output: Parsed structured output when structured_output_model was specified.
checkpoint: Checkpoint captured when the agent paused for durable execution.
Populated only when stop_reason == "checkpoint". See
strands.experimental.checkpoint for usage.
"""

stop_reason: StopReason
Expand All @@ -34,6 +38,7 @@ class AgentResult:
state: Any
interrupts: Sequence[Interrupt] | None = None
structured_output: BaseModel | None = None
checkpoint: Checkpoint | None = None

@property
def context_size(self) -> int | None:
Expand Down Expand Up @@ -85,15 +90,23 @@ def from_dict(cls, data: dict[str, Any]) -> "AgentResult":
Returns:
AgentResult instance
Raises:
TypeError: If the data format is invalid@
TypeError: If the data format is invalid
"""
if data.get("type") != "agent_result":
raise TypeError(f"AgentResult.from_dict: unexpected type {data.get('type')!r}")

message = cast(Message, data.get("message"))
stop_reason = cast(StopReason, data.get("stop_reason"))
checkpoint_data = data.get("checkpoint")
checkpoint = Checkpoint.from_dict(checkpoint_data) if checkpoint_data else None

return cls(message=message, stop_reason=stop_reason, metrics=EventLoopMetrics(), state={})
return cls(
message=message,
stop_reason=stop_reason,
metrics=EventLoopMetrics(),
state={},
checkpoint=checkpoint,
)

def to_dict(self) -> dict[str, Any]:
"""Convert this AgentResult to JSON-serializable dictionary.
Expand All @@ -105,4 +118,5 @@ def to_dict(self) -> dict[str, Any]:
"type": "agent_result",
"message": self.message,
"stop_reason": self.stop_reason,
"checkpoint": self.checkpoint.to_dict() if self.checkpoint else None,
}
90 changes: 88 additions & 2 deletions src/strands/event_loop/event_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

from opentelemetry import trace as trace_api

from ..experimental.checkpoint import Checkpoint, CheckpointPosition
from ..hooks import AfterModelCallEvent, BeforeModelCallEvent, MessageAddedEvent
from ..telemetry.metrics import Trace
from ..telemetry.tracer import Tracer, get_tracer
Expand Down Expand Up @@ -75,6 +76,28 @@ def _has_tool_use_in_latest_message(messages: "Messages") -> bool:
return False


def _build_checkpoint_stop_event(
    agent: "Agent",
    position: "CheckpointPosition",
    cycle_index: int,
    message: Message,
    request_state: Any,
) -> EventLoopStopEvent:
    """Build a checkpoint stop event. Used at ``after_model`` and ``after_tools``."""
    # Capture a session-preset snapshot of the agent so the caller can
    # persist it and later resume from exactly this boundary.
    session_snapshot = agent.take_snapshot(preset="session")
    checkpoint = Checkpoint(
        position=position,
        cycle_index=cycle_index,
        snapshot=session_snapshot.to_dict(),
    )
    stop_event = EventLoopStopEvent(
        "checkpoint",
        message,
        agent.event_loop_metrics,
        request_state,
        checkpoint=checkpoint,
    )
    return stop_event


async def event_loop_cycle(
agent: "Agent",
invocation_state: dict[str, Any],
Expand Down Expand Up @@ -103,12 +126,16 @@ async def event_loop_cycle(
structured_output_context: Optional context for structured output management.

Yields:
Model and tool stream events. The last event is a tuple containing:
Model and tool stream events. The final ``EventLoopStopEvent`` payload
(``event["stop"]``) is a 7-tuple:

- StopReason: Reason the model stopped generating (e.g., "tool_use")
- StopReason: Reason the model stopped generating (e.g., "tool_use", "checkpoint")
- Message: The generated message from the model
- EventLoopMetrics: Updated metrics for the event loop
- Any: Updated request state
- Sequence[Interrupt] | None: Interrupts raised during the cycle, if any
- BaseModel | None: Structured output result, if any
- Checkpoint | None: Checkpoint captured when stop_reason == "checkpoint"

Raises:
EventLoopException: If an error occurs during execution
Expand All @@ -122,6 +149,18 @@ async def event_loop_cycle(
# Initialize state and get cycle trace
if "request_state" not in invocation_state:
invocation_state["request_state"] = {}

# Consume checkpoint resume context (one-shot).
resume_context = agent._checkpoint_resume_context
if resume_context is not None:
agent._checkpoint_resume_context = None
# after_tools completed that cycle, so the next cycle starts at +1
next_cycle = (
resume_context.cycle_index + 1 if resume_context.position == "after_tools" else resume_context.cycle_index
)
invocation_state["_checkpoint_cycle_index"] = next_cycle
invocation_state["_checkpoint_resume_position"] = resume_context.position

attributes = {"event_loop_cycle_id": str(invocation_state.get("event_loop_cycle_id"))}
cycle_start_time, cycle_trace = agent.event_loop_metrics.start_cycle(attributes=attributes)
invocation_state["event_loop_cycle_trace"] = cycle_trace
Expand Down Expand Up @@ -181,6 +220,25 @@ async def event_loop_cycle(
)

if stop_reason == "tool_use":
# Checkpoint after model call, before tools. Cancel takes precedence.
if agent._checkpointing and not agent._cancel_signal.is_set():
resume_position = invocation_state.pop("_checkpoint_resume_position", None)
if resume_position == "after_model":
pass # Just resumed here — skip re-checkpoint
else:
cycle_index = invocation_state.get("_checkpoint_cycle_index", 0)
Comment thread
JackYPCOnline marked this conversation as resolved.
agent.event_loop_metrics.end_cycle(cycle_start_time, cycle_trace)
if cycle_span:
tracer.end_event_loop_cycle_span(span=cycle_span, message=message)
yield _build_checkpoint_stop_event(
agent=agent,
position="after_model",
cycle_index=cycle_index,
message=message,
request_state=invocation_state["request_state"],
)
return

# Handle tool execution
tool_events = _handle_tool_execution(
stop_reason,
Expand Down Expand Up @@ -590,6 +648,34 @@ async def _handle_tool_execution(
)
return

# Checkpoint after all tools complete, before the next model call.
# Only emitted on tool_use cycles; end_turn on the first call completes
# normally with no checkpoint. Cancel takes precedence.
if agent._checkpointing and not agent._cancel_signal.is_set():
cycle_index = invocation_state.get("_checkpoint_cycle_index", 0)
invocation_state["_checkpoint_cycle_index"] = cycle_index + 1
yield _build_checkpoint_stop_event(
agent=agent,
position="after_tools",
cycle_index=cycle_index,
message=message,
request_state=invocation_state["request_state"],
)
return

# Checkpointing-only: if cancel suppressed the after_tools checkpoint above,
# surface it as "cancelled" now rather than recursing into another model call
# that would also cancel. Non-checkpointing callers fall through to
# recurse_event_loop so the existing cancel-during-model-stream path handles it.
if agent._checkpointing and agent._cancel_signal.is_set():
yield EventLoopStopEvent(
"cancelled",
message,
agent.event_loop_metrics,
invocation_state["request_state"],
)
return

events = recurse_event_loop(
agent=agent, invocation_state=invocation_state, structured_output_context=structured_output_context
)
Expand Down
Loading
Loading