diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index e8ea3c9bc..8078b6bf5 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -29,6 +29,7 @@ from .._async import run_async from ..event_loop._retry import ModelRetryStrategy from ..event_loop.event_loop import INITIAL_DELAY, MAX_ATTEMPTS, MAX_DELAY, event_loop_cycle +from ..experimental.checkpoint import Checkpoint from ..tools._tool_helpers import generate_missing_tool_result_content from ..types._snapshot import ( SNAPSHOT_SCHEMA_VERSION, @@ -146,6 +147,7 @@ def __init__( tool_executor: ToolExecutor | None = None, retry_strategy: ModelRetryStrategy | _DefaultRetryStrategySentinel | None = _DEFAULT_RETRY_STRATEGY, concurrent_invocation_mode: ConcurrentInvocationMode = ConcurrentInvocationMode.THROW, + checkpointing: bool = False, ): """Initialize the Agent with the specified configuration. @@ -214,6 +216,11 @@ def __init__( Set to "unsafe_reentrant" to skip lock acquisition entirely, allowing concurrent invocations. Warning: "unsafe_reentrant" makes no guarantees about resulting behavior and is provided only for advanced use cases where the caller understands the risks. + checkpointing: When True, the event loop pauses at cycle boundaries (after model call, + after all tools execute) and returns an AgentResult with stop_reason="checkpoint" + and a populated ``checkpoint`` field. Persist the checkpoint and resume by passing + ``[{"checkpointResume": {"checkpoint": checkpoint.to_dict()}}]`` as the next prompt. + Defaults to False. See :mod:`strands.experimental.checkpoint` for usage and limitations. Raises: ValueError: If agent id contains path separators. 
@@ -304,6 +311,10 @@ def __init__( self._interrupt_state = _InterruptState() + # Checkpointing: when True, event loop pauses at cycle boundaries + self._checkpointing: bool = checkpointing + self._checkpoint_resume_context: Checkpoint | None = None + # Runtime state for model providers (e.g., server-side response ids) self._model_state: dict[str, Any] = {} @@ -374,12 +385,18 @@ def cancel(self) -> None: This method is thread-safe and can be called from any context (e.g., another thread, web request handler, background task). - The agent will stop gracefully at the next checkpoint: + The agent will stop gracefully at the next cancellation-safe point: - During model response streaming - Before tool execution The agent will return a result with stop_reason="cancelled". + Note: + "Cancellation-safe point" is distinct from + :class:`~strands.experimental.checkpoint.Checkpoint` boundaries. + Cancel takes precedence: a cancel signal at either checkpoint boundary + surfaces as ``stop_reason="cancelled"``, not ``"checkpoint"``. + Example: ```python agent = Agent(model=model) @@ -994,10 +1011,57 @@ async def _execute_event_loop_cycle( if structured_output_context: structured_output_context.cleanup(self.tool_registry) + def _try_consume_checkpoint_resume(self, prompt: list[Any]) -> bool: + """Detect, validate, and consume a ``checkpointResume`` prompt block. + + Returns True if the prompt was a resume block (state restored, caller + should skip normal message conversion). Returns False if no resume + block is present. Raises on malformed or misconfigured input. + + Follows interrupt-resume error conventions: TypeError for shape issues, + KeyError for missing keys, ValueError for misconfig. A schema mismatch + in the checkpoint payload raises ``CheckpointException``. 
+ """ + has_checkpoint_resume = any(isinstance(content, dict) and "checkpointResume" in content for content in prompt) + if not has_checkpoint_resume: + return False + + if not self._checkpointing: + raise ValueError( + "Received checkpointResume block but agent was created with " + "checkpointing=False. Pass checkpointing=True when constructing " + "the Agent to enable durable execution." + ) + + invalid_types = [ + key for content in prompt if isinstance(content, dict) for key in content if key != "checkpointResume" + ] + if invalid_types: + raise TypeError( + f"content_types=<{invalid_types}> | checkpointResume cannot be mixed with other content types" + ) + + if len(prompt) != 1: + raise TypeError(f"block_count=<{len(prompt)}> | only one checkpointResume block permitted per prompt") + + resume_block = prompt[0].get("checkpointResume", {}) + if not isinstance(resume_block, dict) or "checkpoint" not in resume_block: + raise KeyError("checkpoint | missing required key in checkpointResume block") + + checkpoint = Checkpoint.from_dict(resume_block["checkpoint"]) + self.load_snapshot(Snapshot.from_dict(checkpoint.snapshot)) + self._checkpoint_resume_context = checkpoint + return True + async def _convert_prompt_to_messages(self, prompt: AgentInput) -> Messages: if self._interrupt_state.activated: return [] + # Resume detection — must run before existing shape handling so checkpointResume + # blocks aren't misinterpreted as content blocks. 
+ if isinstance(prompt, list) and prompt and self._try_consume_checkpoint_resume(prompt): + return [] + messages: Messages | None = None if prompt is not None: # Check if the latest message is toolUse diff --git a/src/strands/agent/agent_result.py b/src/strands/agent/agent_result.py index f0a399f81..9d077d803 100644 --- a/src/strands/agent/agent_result.py +++ b/src/strands/agent/agent_result.py @@ -9,6 +9,7 @@ from pydantic import BaseModel +from ..experimental.checkpoint import Checkpoint from ..interrupt import Interrupt from ..telemetry.metrics import EventLoopMetrics from ..types.content import Message @@ -26,6 +27,9 @@ class AgentResult: state: Additional state information from the event loop. interrupts: List of interrupts if raised by user. structured_output: Parsed structured output when structured_output_model was specified. + checkpoint: Checkpoint captured when the agent paused for durable execution. + Populated only when stop_reason == "checkpoint". See + strands.experimental.checkpoint for usage. 
""" stop_reason: StopReason @@ -34,6 +38,7 @@ class AgentResult: state: Any interrupts: Sequence[Interrupt] | None = None structured_output: BaseModel | None = None + checkpoint: Checkpoint | None = None @property def context_size(self) -> int | None: @@ -85,15 +90,23 @@ def from_dict(cls, data: dict[str, Any]) -> "AgentResult": Returns: AgentResult instance Raises: - TypeError: If the data format is invalid@ + TypeError: If the data format is invalid """ if data.get("type") != "agent_result": raise TypeError(f"AgentResult.from_dict: unexpected type {data.get('type')!r}") message = cast(Message, data.get("message")) stop_reason = cast(StopReason, data.get("stop_reason")) + checkpoint_data = data.get("checkpoint") + checkpoint = Checkpoint.from_dict(checkpoint_data) if checkpoint_data else None - return cls(message=message, stop_reason=stop_reason, metrics=EventLoopMetrics(), state={}) + return cls( + message=message, + stop_reason=stop_reason, + metrics=EventLoopMetrics(), + state={}, + checkpoint=checkpoint, + ) def to_dict(self) -> dict[str, Any]: """Convert this AgentResult to JSON-serializable dictionary. 
@@ -105,4 +118,5 @@ def to_dict(self) -> dict[str, Any]: "type": "agent_result", "message": self.message, "stop_reason": self.stop_reason, + "checkpoint": self.checkpoint.to_dict() if self.checkpoint else None, } diff --git a/src/strands/event_loop/event_loop.py b/src/strands/event_loop/event_loop.py index bf1cc7a84..ec9289930 100644 --- a/src/strands/event_loop/event_loop.py +++ b/src/strands/event_loop/event_loop.py @@ -15,6 +15,7 @@ from opentelemetry import trace as trace_api +from ..experimental.checkpoint import Checkpoint, CheckpointPosition from ..hooks import AfterModelCallEvent, BeforeModelCallEvent, MessageAddedEvent from ..telemetry.metrics import Trace from ..telemetry.tracer import Tracer, get_tracer @@ -75,6 +76,28 @@ def _has_tool_use_in_latest_message(messages: "Messages") -> bool: return False +def _build_checkpoint_stop_event( + agent: "Agent", + position: "CheckpointPosition", + cycle_index: int, + message: Message, + request_state: Any, +) -> EventLoopStopEvent: + """Build a checkpoint stop event. Used at ``after_model`` and ``after_tools``.""" + checkpoint = Checkpoint( + position=position, + cycle_index=cycle_index, + snapshot=agent.take_snapshot(preset="session").to_dict(), + ) + return EventLoopStopEvent( + "checkpoint", + message, + agent.event_loop_metrics, + request_state, + checkpoint=checkpoint, + ) + + async def event_loop_cycle( agent: "Agent", invocation_state: dict[str, Any], @@ -103,12 +126,16 @@ async def event_loop_cycle( structured_output_context: Optional context for structured output management. Yields: - Model and tool stream events. The last event is a tuple containing: + Model and tool stream events. 
The final ``EventLoopStopEvent`` payload + (``event["stop"]``) is a 7-tuple: - - StopReason: Reason the model stopped generating (e.g., "tool_use") + - StopReason: Reason the model stopped generating (e.g., "tool_use", "checkpoint") - Message: The generated message from the model - EventLoopMetrics: Updated metrics for the event loop - Any: Updated request state + - Sequence[Interrupt] | None: Interrupts raised during the cycle, if any + - BaseModel | None: Structured output result, if any + - Checkpoint | None: Checkpoint captured when stop_reason == "checkpoint" Raises: EventLoopException: If an error occurs during execution @@ -122,6 +149,18 @@ async def event_loop_cycle( # Initialize state and get cycle trace if "request_state" not in invocation_state: invocation_state["request_state"] = {} + + # Consume checkpoint resume context (one-shot). + resume_context = agent._checkpoint_resume_context + if resume_context is not None: + agent._checkpoint_resume_context = None + # after_tools completed that cycle, so the next cycle starts at +1 + next_cycle = ( + resume_context.cycle_index + 1 if resume_context.position == "after_tools" else resume_context.cycle_index + ) + invocation_state["_checkpoint_cycle_index"] = next_cycle + invocation_state["_checkpoint_resume_position"] = resume_context.position + attributes = {"event_loop_cycle_id": str(invocation_state.get("event_loop_cycle_id"))} cycle_start_time, cycle_trace = agent.event_loop_metrics.start_cycle(attributes=attributes) invocation_state["event_loop_cycle_trace"] = cycle_trace @@ -181,6 +220,25 @@ async def event_loop_cycle( ) if stop_reason == "tool_use": + # Checkpoint after model call, before tools. Cancel takes precedence. 
+ if agent._checkpointing and not agent._cancel_signal.is_set(): + resume_position = invocation_state.pop("_checkpoint_resume_position", None) + if resume_position == "after_model": + pass # Just resumed here — skip re-checkpoint + else: + cycle_index = invocation_state.get("_checkpoint_cycle_index", 0) + agent.event_loop_metrics.end_cycle(cycle_start_time, cycle_trace) + if cycle_span: + tracer.end_event_loop_cycle_span(span=cycle_span, message=message) + yield _build_checkpoint_stop_event( + agent=agent, + position="after_model", + cycle_index=cycle_index, + message=message, + request_state=invocation_state["request_state"], + ) + return + # Handle tool execution tool_events = _handle_tool_execution( stop_reason, @@ -590,6 +648,34 @@ async def _handle_tool_execution( ) return + # Checkpoint after all tools complete, before the next model call. + # Only emitted on tool_use cycles; end_turn on the first call completes + # normally with no checkpoint. Cancel takes precedence. + if agent._checkpointing and not agent._cancel_signal.is_set(): + cycle_index = invocation_state.get("_checkpoint_cycle_index", 0) + invocation_state["_checkpoint_cycle_index"] = cycle_index + 1 + yield _build_checkpoint_stop_event( + agent=agent, + position="after_tools", + cycle_index=cycle_index, + message=message, + request_state=invocation_state["request_state"], + ) + return + + # Checkpointing-only: if cancel suppressed the after_tools checkpoint above, + # surface it as "cancelled" now rather than recursing into another model call + # that would also cancel. Non-checkpointing callers fall through to + # recurse_event_loop so the existing cancel-during-model-stream path handles it. 
+ if agent._checkpointing and agent._cancel_signal.is_set(): + yield EventLoopStopEvent( + "cancelled", + message, + agent.event_loop_metrics, + invocation_state["request_state"], + ) + return + events = recurse_event_loop( agent=agent, invocation_state=invocation_state, structured_output_context=structured_output_context ) diff --git a/src/strands/experimental/checkpoint/checkpoint.py b/src/strands/experimental/checkpoint/checkpoint.py index f37e403c9..0011bd947 100644 --- a/src/strands/experimental/checkpoint/checkpoint.py +++ b/src/strands/experimental/checkpoint/checkpoint.py @@ -1,40 +1,43 @@ """Checkpoint system for durable agent execution. -Checkpoints enable crash-resilient agent workflows by capturing agent state at -cycle boundaries in the agent loop. A durability provider (e.g. Temporal) can -persist checkpoints and resume from them after failures. - -Two checkpoint positions per ReAct cycle: -- after_model: model call completed, tools not yet executed. -- after_tools: all tools executed, next model call pending. - -Per-tool granularity is handled by the ToolExecutor abstraction (e.g. -TemporalToolExecutor routes each tool to a separate Temporal activity). -The SDK checkpoint operates at cycle boundaries. - -User-facing pattern (same as interrupts): -- Pause via stop_reason="checkpoint" on AgentResult -- State via AgentResult.checkpoint field -- Resume via checkpointResume content block in next agent() call - -V0 Known Limitations: -- Metrics reset on each resume call. The caller is responsible for aggregating - metrics across a durable run. EventLoopMetrics reflects only the current call. -- OpenAIResponsesModel(stateful=True) is not supported. The server-side - response_id (_model_state) is not captured in the snapshot. -- When position is "after_tools", AgentResult.message is the assistant message - that requested the tools; tool results are in the snapshot messages. 
-- BeforeInvocationEvent and AfterInvocationEvent fire on every resume call, - same as interrupts. Hooks counting invocations will see each resume as a - separate invocation. -- Per-tool granularity within a cycle requires a custom ToolExecutor - (e.g. TemporalToolExecutor). +Checkpoints capture agent state at cycle boundaries so a durability provider +(e.g. Temporal) can persist them and resume after failures. + +Positions per ReAct cycle: +- ``after_model``: model call completed, tools not yet executed. +- ``after_tools``: all tools executed, next model call pending. + +Per-tool granularity within a cycle is the ToolExecutor's responsibility +(e.g. TemporalToolExecutor routes each tool to a separate activity). + +Usage (mirrors the interrupt pattern): +- Pause: ``AgentResult`` with ``stop_reason="checkpoint"`` and a populated + ``checkpoint`` field. +- Resume: pass ``[{"checkpointResume": {"checkpoint": checkpoint.to_dict()}}]`` + as the next prompt. + +Precedence: +- Interrupts > checkpoint: an interrupt raised during a checkpointing cycle + returns ``stop_reason="interrupt"`` and skips the ``after_tools`` checkpoint. +- Cancel > checkpoint: a cancel signal set at either checkpoint boundary + suppresses emission and surfaces as ``stop_reason="cancelled"``. + +Known limitations: +- ``EventLoopMetrics`` resets per invocation; aggregate across resumes yourself. +- ``OpenAIResponsesModel(stateful=True)`` is not supported — the server-side + ``response_id`` is not captured. +- At ``after_tools``, ``AgentResult.message`` is the assistant's tool-use + message; tool results live in the snapshot. +- ``BeforeInvocationEvent`` / ``AfterInvocationEvent`` fire on every resume, + same as interrupts. 
""" import logging from dataclasses import asdict, dataclass, field from typing import Any, Literal +from ...types.exceptions import CheckpointException + logger = logging.getLogger(__name__) CHECKPOINT_SCHEMA_VERSION = "1.0" @@ -42,23 +45,18 @@ CheckpointPosition = Literal["after_model", "after_tools"] -@dataclass +@dataclass(frozen=True) class Checkpoint: """Pause point in the agent loop. Treat as opaque — pass back to resume. Attributes: - position: What just completed (after_model or after_tools). - cycle_index: Which ReAct loop cycle (0-based). - snapshot: Serialized agent state as a dict, produced by ``Snapshot.to_dict()``. - Stored as ``dict[str, Any]`` (not a ``Snapshot`` object) because checkpoints - must be JSON-serializable for cross-process persistence. The consumer - reconstructs via ``Snapshot.from_dict()`` on resume. - app_data: Application-level internal state data. The SDK does not read - or modify this. Applications can store arbitrary data needed across - checkpoint boundaries (e.g. session context, workflow metadata). - Separate from ``Snapshot.app_data`` which captures agent-state-level - data managed by the SDK. - schema_version: Rejects mismatches on resume across schema versions. + position: What just completed (``after_model`` or ``after_tools``). + cycle_index: ReAct loop cycle (0-based). + snapshot: Serialized agent state from ``Snapshot.to_dict()``. Stored as + a dict (not a ``Snapshot``) so the checkpoint is JSON-serializable. + app_data: Opaque application-level state. The SDK does not read or + modify this. Distinct from ``Snapshot.app_data`` (agent-level). + schema_version: Used to reject incompatible checkpoints on resume. """ position: CheckpointPosition @@ -79,11 +77,11 @@ def from_dict(cls, data: dict[str, Any]) -> "Checkpoint": data: Serialized checkpoint data. Raises: - ValueError: If schema_version doesn't match the current version. + CheckpointException: If schema_version doesn't match the current version. 
""" version = data.get("schema_version", "") if version != CHECKPOINT_SCHEMA_VERSION: - raise ValueError( + raise CheckpointException( f"Checkpoints with schema version {version!r} are not compatible " f"with current version {CHECKPOINT_SCHEMA_VERSION}." ) diff --git a/src/strands/types/_events.py b/src/strands/types/_events.py index 1d5a5de79..01f2c9e3e 100644 --- a/src/strands/types/_events.py +++ b/src/strands/types/_events.py @@ -22,6 +22,7 @@ if TYPE_CHECKING: from ..agent import AgentResult from ..agent._agent_as_tool import _AgentAsTool + from ..experimental.checkpoint import Checkpoint from ..multiagent.base import MultiAgentResult, NodeResult @@ -227,6 +228,7 @@ def __init__( request_state: Any, interrupts: Sequence[Interrupt] | None = None, structured_output: BaseModel | None = None, + checkpoint: "Checkpoint | None" = None, ) -> None: """Initialize with the final execution results. @@ -237,8 +239,11 @@ def __init__( request_state: Final state of the agent execution interrupts: Interrupts raised by user during agent execution. structured_output: Optional structured output result + checkpoint: Optional checkpoint when stop_reason == "checkpoint". 
""" - super().__init__({"stop": (stop_reason, message, metrics, request_state, interrupts, structured_output)}) + super().__init__( + {"stop": (stop_reason, message, metrics, request_state, interrupts, structured_output, checkpoint)} + ) @property @override diff --git a/src/strands/types/exceptions.py b/src/strands/types/exceptions.py index 5db80a26e..4f766a77e 100644 --- a/src/strands/types/exceptions.py +++ b/src/strands/types/exceptions.py @@ -111,3 +111,9 @@ class ConcurrencyException(Exception): """ pass + + +class CheckpointException(Exception): + """Exception raised when checkpoint operations fail (e.g., incompatible schema version).""" + + pass diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index 1e27274a1..dff4b4b64 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -29,7 +29,12 @@ from strands.types._events import EventLoopStopEvent, ModelStreamEvent from strands.types.agent import ConcurrentInvocationMode from strands.types.content import Messages -from strands.types.exceptions import ConcurrencyException, ContextWindowOverflowException, EventLoopException +from strands.types.exceptions import ( + CheckpointException, + ConcurrencyException, + ContextWindowOverflowException, + EventLoopException, +) from strands.types.session import Session, SessionAgent, SessionMessage, SessionType from tests.fixtures.mock_session_repository import MockedSessionRepository from tests.fixtures.mocked_model_provider import MockedModelProvider @@ -2773,3 +2778,66 @@ def test_as_tool_defaults_description_when_agent_has_none(): tool = agent.as_tool() assert tool.tool_spec["description"] == "Use the researcher agent as a tool by providing a natural language input" + + +# ============================================================================ +# Checkpointing Tests +# ============================================================================ + + +def test_agent_checkpointing_defaults_to_false() -> 
None: + agent = Agent() + assert agent._checkpointing is False + assert agent._checkpoint_resume_context is None + + +def test_agent_checkpointing_flag_stored() -> None: + agent = Agent(checkpointing=True) + assert agent._checkpointing is True + assert agent._checkpoint_resume_context is None + + +@pytest.mark.asyncio +async def test_checkpoint_resume_without_checkpointing_flag_raises_value_error() -> None: + agent = Agent(checkpointing=False) + prompt = [{"checkpointResume": {"checkpoint": {}}}] + with pytest.raises(ValueError, match="checkpointing=True"): + await agent.invoke_async(prompt) + + +@pytest.mark.asyncio +async def test_checkpoint_resume_mixed_content_raises_type_error() -> None: + agent = Agent(checkpointing=True) + prompt = [ + {"checkpointResume": {"checkpoint": {}}}, + {"text": "bogus"}, + ] + with pytest.raises(TypeError, match="content_types"): + await agent.invoke_async(prompt) + + +@pytest.mark.asyncio +async def test_checkpoint_resume_multiple_blocks_raises_type_error() -> None: + agent = Agent(checkpointing=True) + prompt = [ + {"checkpointResume": {"checkpoint": {}}}, + {"checkpointResume": {"checkpoint": {}}}, + ] + with pytest.raises(TypeError, match="block_count"): + await agent.invoke_async(prompt) + + +@pytest.mark.asyncio +async def test_checkpoint_resume_missing_checkpoint_key_raises_key_error() -> None: + agent = Agent(checkpointing=True) + prompt = [{"checkpointResume": {}}] + with pytest.raises(KeyError, match="checkpoint"): + await agent.invoke_async(prompt) + + +@pytest.mark.asyncio +async def test_checkpoint_resume_schema_mismatch_raises_checkpoint_exception() -> None: + agent = Agent(checkpointing=True) + prompt = [{"checkpointResume": {"checkpoint": {"schema_version": "0.1", "position": "after_model"}}}] + with pytest.raises(CheckpointException, match="schema version"): + await agent.invoke_async(prompt) diff --git a/tests/strands/agent/test_agent_result.py b/tests/strands/agent/test_agent_result.py index 64391f299..3ea15d016 
100644 --- a/tests/strands/agent/test_agent_result.py +++ b/tests/strands/agent/test_agent_result.py @@ -5,6 +5,7 @@ from pydantic import BaseModel from strands.agent.agent_result import AgentResult +from strands.experimental.checkpoint import Checkpoint from strands.interrupt import Interrupt from strands.telemetry.metrics import EventLoopMetrics from strands.types.content import Message @@ -110,6 +111,7 @@ def test_to_dict(mock_metrics, simple_message: Message): "type": "agent_result", "message": simple_message, "stop_reason": "end_turn", + "checkpoint": None, } @@ -384,3 +386,85 @@ def test_context_size_none_when_no_data(mock_metrics, simple_message: Message): mock_metrics.latest_context_size = None result = AgentResult(stop_reason="end_turn", message=simple_message, metrics=mock_metrics, state={}) assert result.context_size is None + + +# ========================================================================= +# Checkpoint field and round-trip serialization (Part B) +# +# Covers the V0 durable-execution contract: when stop_reason == "checkpoint", +# AgentResult carries a Checkpoint that round-trips through to_dict/from_dict. 
+# ========================================================================= + + +def test_agent_result_checkpoint_field_default_none() -> None: + result = AgentResult( + stop_reason="end_turn", + message={"role": "assistant", "content": [{"text": "hi"}]}, + metrics=EventLoopMetrics(), + state={}, + ) + assert result.checkpoint is None + + +def test_agent_result_accepts_checkpoint() -> None: + checkpoint = Checkpoint(position="after_model", cycle_index=0) + result = AgentResult( + stop_reason="checkpoint", + message={"role": "assistant", "content": [{"toolUse": {"toolUseId": "1", "name": "t", "input": {}}}]}, + metrics=EventLoopMetrics(), + state={}, + checkpoint=checkpoint, + ) + assert result.checkpoint is checkpoint + assert result.checkpoint.position == "after_model" + + +def test_agent_result_to_dict_includes_checkpoint() -> None: + checkpoint = Checkpoint(position="after_model", cycle_index=0) + result = AgentResult( + stop_reason="checkpoint", + message={"role": "assistant", "content": [{"toolUse": {"toolUseId": "1", "name": "t", "input": {}}}]}, + metrics=EventLoopMetrics(), + state={}, + checkpoint=checkpoint, + ) + d = result.to_dict() + assert d["checkpoint"] is not None + assert d["checkpoint"]["position"] == "after_model" + assert d["checkpoint"]["cycle_index"] == 0 + + +def test_agent_result_to_dict_checkpoint_none_when_absent() -> None: + result = AgentResult( + stop_reason="end_turn", + message={"role": "assistant", "content": [{"text": "hi"}]}, + metrics=EventLoopMetrics(), + state={}, + ) + d = result.to_dict() + assert d["checkpoint"] is None + + +def test_agent_result_from_dict_round_trips_checkpoint() -> None: + original = AgentResult( + stop_reason="checkpoint", + message={"role": "assistant", "content": [{"toolUse": {"toolUseId": "1", "name": "t", "input": {}}}]}, + metrics=EventLoopMetrics(), + state={}, + checkpoint=Checkpoint(position="after_tools", cycle_index=3), + ) + restored = AgentResult.from_dict(original.to_dict()) + assert 
restored.checkpoint is not None + assert restored.checkpoint.position == "after_tools" + assert restored.checkpoint.cycle_index == 3 + + +def test_agent_result_from_dict_handles_missing_checkpoint() -> None: + restored = AgentResult.from_dict( + { + "type": "agent_result", + "message": {"role": "assistant", "content": [{"text": "done"}]}, + "stop_reason": "end_turn", + } + ) + assert restored.checkpoint is None diff --git a/tests/strands/event_loop/test_event_loop.py b/tests/strands/event_loop/test_event_loop.py index 871371f5f..70fe93c4d 100644 --- a/tests/strands/event_loop/test_event_loop.py +++ b/tests/strands/event_loop/test_event_loop.py @@ -13,6 +13,7 @@ import strands.telemetry from strands import Agent from strands.event_loop._retry import ModelRetryStrategy +from strands.experimental.checkpoint import Checkpoint from strands.hooks import ( AfterModelCallEvent, BeforeModelCallEvent, @@ -157,6 +158,8 @@ def agent(model, system_prompt, messages, tool_registry, thread_pool, hook_regis mock._interrupt_state = _InterruptState() mock._cancel_signal = threading.Event() mock._model_state = {} + mock._checkpointing = False + mock._checkpoint_resume_context = None mock.trace_attributes = {} mock.retry_strategy = ModelRetryStrategy() @@ -190,7 +193,7 @@ async def test_event_loop_cycle_text_response( invocation_state={}, ) events = await alist(stream) - tru_stop_reason, tru_message, _, tru_request_state, _, _ = events[-1]["stop"] + tru_stop_reason, tru_message, _, tru_request_state, _, _, _ = events[-1]["stop"] exp_stop_reason = "end_turn" exp_message = {"role": "assistant", "content": [{"text": "test text"}], "metadata": ANY} @@ -222,7 +225,7 @@ async def test_event_loop_cycle_text_response_throttling( invocation_state={}, ) events = await alist(stream) - tru_stop_reason, tru_message, _, tru_request_state, _, _ = events[-1]["stop"] + tru_stop_reason, tru_message, _, tru_request_state, _, _, _ = events[-1]["stop"] exp_stop_reason = "end_turn" exp_message = {"role": 
"assistant", "content": [{"text": "test text"}], "metadata": ANY} @@ -260,7 +263,7 @@ async def test_event_loop_cycle_exponential_backoff( invocation_state={}, ) events = await alist(stream) - tru_stop_reason, tru_message, _, tru_request_state, _, _ = events[-1]["stop"] + tru_stop_reason, tru_message, _, tru_request_state, _, _, _ = events[-1]["stop"] # Verify the final response assert tru_stop_reason == "end_turn" @@ -351,7 +354,7 @@ async def test_event_loop_cycle_tool_result( invocation_state={}, ) events = await alist(stream) - tru_stop_reason, tru_message, _, tru_request_state, _, _ = events[-1]["stop"] + tru_stop_reason, tru_message, _, tru_request_state, _, _, _ = events[-1]["stop"] exp_stop_reason = "end_turn" exp_message = {"role": "assistant", "content": [{"text": "test text"}], "metadata": ANY} @@ -469,7 +472,7 @@ async def test_event_loop_cycle_stop( invocation_state={"request_state": {"stop_event_loop": True}}, ) events = await alist(stream) - tru_stop_reason, tru_message, _, tru_request_state, _, _ = events[-1]["stop"] + tru_stop_reason, tru_message, _, tru_request_state, _, _, _ = events[-1]["stop"] exp_stop_reason = "tool_use" exp_message = { @@ -833,7 +836,7 @@ async def test_request_state_initialization(alist): invocation_state={}, ) events = await alist(stream) - _, _, _, tru_request_state, _, _ = events[-1]["stop"] + _, _, _, tru_request_state, _, _, _ = events[-1]["stop"] # Verify request_state was initialized to empty dict assert tru_request_state == {} @@ -845,7 +848,7 @@ async def test_request_state_initialization(alist): invocation_state={"request_state": initial_request_state}, ) events = await alist(stream) - _, _, _, tru_request_state, _, _ = events[-1]["stop"] + _, _, _, tru_request_state, _, _, _ = events[-1]["stop"] # Verify existing request_state was preserved assert tru_request_state == initial_request_state @@ -969,7 +972,7 @@ def interrupt_callback(event): stream = strands.event_loop.event_loop.event_loop_cycle(agent, 
invocation_state={}) events = await alist(stream) - tru_stop_reason, _, _, _, tru_interrupts, _ = events[-1]["stop"] + tru_stop_reason, _, _, _, tru_interrupts, _, _ = events[-1]["stop"] exp_stop_reason = "interrupt" exp_interrupts = [ Interrupt( @@ -1064,7 +1067,7 @@ def interrupt_callback(event): stream = strands.event_loop.event_loop.event_loop_cycle(agent, invocation_state={}) events = await alist(stream) - tru_stop_reason, _, _, _, _, _ = events[-1]["stop"] + tru_stop_reason, _, _, _, _, _, _ = events[-1]["stop"] exp_stop_reason = "end_turn" assert tru_stop_reason == exp_stop_reason @@ -1196,5 +1199,273 @@ async def test_event_loop_metrics_recorded_before_recursion( assert mock_end_cycle.call_count == 2 # Verify the event loop completed successfully - tru_stop_reason, _, _, _, _, _ = events[-1]["stop"] + tru_stop_reason, _, _, _, _, _, _ = events[-1]["stop"] assert tru_stop_reason == "end_turn" + + +# --- Checkpoint event loop integration (Tasks 9-10) --- + + +@pytest.mark.asyncio +async def test_event_loop_cycle_checkpoint_after_model( + agent, + model, + tool_stream, + agenerator, + alist, +): + """With checkpointing=True, tool_use stop_reason yields after_model checkpoint instead of running tools.""" + agent._checkpointing = True + agent._checkpoint_resume_context = None + agent.take_snapshot = unittest.mock.Mock( + return_value=unittest.mock.Mock(to_dict=lambda: {"data": {"messages": []}}) + ) + + model.stream.return_value = agenerator(tool_stream) + + stream = strands.event_loop.event_loop.event_loop_cycle( + agent=agent, + invocation_state={}, + ) + events = await alist(stream) + stop = events[-1]["stop"] + tru_stop_reason, _, _, _, _, _, tru_checkpoint = stop + + assert tru_stop_reason == "checkpoint" + assert tru_checkpoint is not None + assert tru_checkpoint.position == "after_model" + assert tru_checkpoint.cycle_index == 0 + + +@pytest.mark.asyncio +async def test_event_loop_cycle_checkpoint_after_tools( + agent, + model, + tool, + tool_stream, + 
agenerator, + alist, +): + """With checkpointing=True and resume from after_model, tools execute then yield after_tools checkpoint.""" + agent._checkpointing = True + agent.take_snapshot = unittest.mock.Mock( + return_value=unittest.mock.Mock(to_dict=lambda: {"data": {"messages": []}}) + ) + agent._checkpoint_resume_context = Checkpoint(position="after_model", cycle_index=0) + + model.stream.return_value = agenerator(tool_stream) + + stream = strands.event_loop.event_loop.event_loop_cycle( + agent=agent, + invocation_state={}, + ) + events = await alist(stream) + tru_stop_reason, _, _, _, _, _, tru_checkpoint = events[-1]["stop"] + + assert tru_stop_reason == "checkpoint" + assert tru_checkpoint is not None + assert tru_checkpoint.position == "after_tools" + assert tru_checkpoint.cycle_index == 0 + + +@pytest.mark.asyncio +async def test_event_loop_cycle_checkpoint_resume_after_tools_increments_cycle( + agent, + model, + tool_stream, + agenerator, + alist, +): + """Resuming from after_tools sets cycle_index to previous + 1 for the next after_model checkpoint.""" + agent._checkpointing = True + agent._checkpoint_resume_context = Checkpoint(position="after_tools", cycle_index=2) + agent.take_snapshot = unittest.mock.Mock( + return_value=unittest.mock.Mock(to_dict=lambda: {"data": {"messages": []}}) + ) + + model.stream.return_value = agenerator(tool_stream) + + stream = strands.event_loop.event_loop.event_loop_cycle( + agent=agent, + invocation_state={}, + ) + events = await alist(stream) + tru_stop_reason, _, _, _, _, _, tru_checkpoint = events[-1]["stop"] + + assert tru_stop_reason == "checkpoint" + assert tru_checkpoint.position == "after_model" + assert tru_checkpoint.cycle_index == 3 + + +@pytest.mark.asyncio +async def test_event_loop_cycle_cancel_beats_after_model_checkpoint( + agent, + model, + tool_stream, + agenerator, + alist, +): + """When a cancel signal is set after model call, cancel wins over after_model checkpoint. 
+ + A user who calls agent.cancel() expects stop_reason="cancelled", not a stray + "checkpoint" with a snapshot they never asked for. Documented in Agent.cancel(). + """ + agent._checkpointing = True + agent._checkpoint_resume_context = None + agent.take_snapshot = unittest.mock.Mock( + return_value=unittest.mock.Mock(to_dict=lambda: {"data": {"messages": []}}) + ) + + # Cancel the agent before invoking. Model streams tool_use — the emission site + # that would normally fire an after_model checkpoint must yield "cancelled" instead. + agent._cancel_signal.set() + model.stream.return_value = agenerator(tool_stream) + + stream = strands.event_loop.event_loop.event_loop_cycle( + agent=agent, + invocation_state={}, + ) + events = await alist(stream) + tru_stop_reason, _, _, _, _, _, tru_checkpoint = events[-1]["stop"] + + assert tru_stop_reason == "cancelled" + assert tru_checkpoint is None + + +@pytest.mark.asyncio +async def test_event_loop_cycle_cancel_mid_cycle_beats_after_model_checkpoint( + agent, + model, + tool_stream, + agenerator, + alist, +): + """Cancel signal set between model completion and after_model emission yields 'cancelled', not 'checkpoint'.""" + agent._checkpointing = True + agent._checkpoint_resume_context = None + agent.take_snapshot = unittest.mock.Mock( + return_value=unittest.mock.Mock(to_dict=lambda: {"data": {"messages": []}}) + ) + + # Stub the model stream so that after it yields its stop event, we simulate + # a cancel signal arriving between model completion and the after_model check. 
+ original_stream = agenerator(tool_stream) + + async def stream_with_mid_cycle_cancel(): + async for item in original_stream: + yield item + agent._cancel_signal.set() + + model.stream.return_value = stream_with_mid_cycle_cancel() + + stream = strands.event_loop.event_loop.event_loop_cycle( + agent=agent, + invocation_state={}, + ) + events = await alist(stream) + tru_stop_reason, _, _, _, _, _, tru_checkpoint = events[-1]["stop"] + + assert tru_stop_reason == "cancelled" + assert tru_checkpoint is None + + +@pytest.mark.asyncio +async def test_event_loop_cycle_cancel_mid_cycle_beats_after_tools_checkpoint( + agent, + model, + tool, + tool_stream, + agenerator, + alist, +): + """Cancel set after tools complete but before after_tools emission yields 'cancelled'.""" + from strands.experimental.checkpoint import Checkpoint + + agent._checkpointing = True + agent._checkpoint_resume_context = Checkpoint(position="after_model", cycle_index=0) + agent.take_snapshot = unittest.mock.Mock( + return_value=unittest.mock.Mock(to_dict=lambda: {"data": {"messages": []}}) + ) + + # Wrap the tool executor so that after tools complete, cancel is signaled. + # The real gap: cancel arriving between tool completion and checkpoint emission. 
+ original_execute = agent.tool_executor._execute + + def execute_then_cancel(*args, **kwargs): + stream = original_execute(*args, **kwargs) + + async def wrapped(): + async for event in stream: + yield event + agent._cancel_signal.set() + + return wrapped() + + agent.tool_executor._execute = execute_then_cancel + model.stream.return_value = agenerator(tool_stream) + + stream = strands.event_loop.event_loop.event_loop_cycle( + agent=agent, + invocation_state={}, + ) + events = await alist(stream) + tru_stop_reason, _, _, _, _, _, tru_checkpoint = events[-1]["stop"] + + assert tru_stop_reason == "cancelled" + assert tru_checkpoint is None + + +@pytest.mark.asyncio +async def test_event_loop_cycle_cancel_mid_cycle_after_tools_non_checkpointing( + agent, + model, + tool, + tool_stream, + agenerator, + alist, +): + """Non-checkpointing cancel set after tools complete preserves pre-PR behavior. + + The after_tools cancel emission added in this PR is gated on checkpointing=True + specifically so non-checkpointing callers continue to see the cancel surface + from the existing cancel-during-model-stream path (via recurse_event_loop). + This test pins that invariant: after tools finish with cancel set, the loop + must recurse into another event_loop_cycle (model.stream call_count increases). + """ + assert agent._checkpointing is False + + # Cancel signal is raised after tools complete, before the cycle ends. + original_execute = agent.tool_executor._execute + + def execute_then_cancel(*args, **kwargs): + stream = original_execute(*args, **kwargs) + + async def wrapped(): + async for event in stream: + yield event + agent._cancel_signal.set() + + return wrapped() + + agent.tool_executor._execute = execute_then_cancel + # First model call returns tool_use; the second call is the recursion we want + # to observe. Its behavior is irrelevant to this test — we only care that the + # recursion happens at all (i.e., non-checkpointing cancel did not short-circuit). 
+ model.stream.side_effect = [agenerator(tool_stream), agenerator([])] + + stream_call_count_before = model.stream.call_count + + stream = strands.event_loop.event_loop.event_loop_cycle( + agent=agent, + invocation_state={}, + ) + # Drain the stream. The exact terminal event is handled by the pre-existing + # cancel-during-model-stream path; this test only pins that recursion occurred. + try: + await alist(stream) + except BaseException: # noqa: BLE001 + pass + + # Invariant: the non-checkpointing path must recurse — model.stream was called + # a second time via recurse_event_loop. + assert model.stream.call_count >= stream_call_count_before + 2 diff --git a/tests/strands/event_loop/test_event_loop_structured_output.py b/tests/strands/event_loop/test_event_loop_structured_output.py index 2d1150712..b8db1fd3d 100644 --- a/tests/strands/event_loop/test_event_loop_structured_output.py +++ b/tests/strands/event_loop/test_event_loop_structured_output.py @@ -60,6 +60,9 @@ def mock_agent(): agent._interrupt_state.activated = False agent._interrupt_state.context = {} agent._cancel_signal = threading.Event() + agent._model_state = {} + agent._checkpointing = False + agent._checkpoint_resume_context = None return agent diff --git a/tests/strands/experimental/checkpoint/test_checkpoint.py b/tests/strands/experimental/checkpoint/test_checkpoint.py index 4435fb3db..a2979ba8f 100644 --- a/tests/strands/experimental/checkpoint/test_checkpoint.py +++ b/tests/strands/experimental/checkpoint/test_checkpoint.py @@ -2,52 +2,209 @@ import pytest +from strands import Agent, tool from strands.experimental.checkpoint import CHECKPOINT_SCHEMA_VERSION, Checkpoint +from strands.types.exceptions import CheckpointException +from tests.fixtures.mocked_model_provider import MockedModelProvider -class TestCheckpoint: - """Checkpoint dataclass serialization tests.""" - - def test_round_trip(self): - checkpoint = Checkpoint( - position="after_model", - cycle_index=1, - snapshot={"messages": []}, - 
app_data={"workflow_id": "wf-123"}, - ) - data = checkpoint.to_dict() - restored = Checkpoint.from_dict(data) - - assert restored.position == checkpoint.position - assert restored.cycle_index == checkpoint.cycle_index - assert restored.snapshot == checkpoint.snapshot - assert restored.app_data == checkpoint.app_data - assert restored.schema_version == CHECKPOINT_SCHEMA_VERSION - - def test_schema_version_immutable(self): - checkpoint = Checkpoint(position="after_tools") - assert checkpoint.schema_version == CHECKPOINT_SCHEMA_VERSION - - def test_schema_version_mismatch_raises(self): - data = Checkpoint(position="after_model").to_dict() - data["schema_version"] = "0.0" - with pytest.raises(ValueError, match="not compatible with current version"): - Checkpoint.from_dict(data) - - def test_defaults(self): - checkpoint = Checkpoint(position="after_model") - assert checkpoint.cycle_index == 0 - assert checkpoint.snapshot == {} - assert checkpoint.app_data == {} - - def test_from_dict_warns_on_unknown_fields(self, caplog): - data = Checkpoint(position="after_tools").to_dict() - data["unknown_future_field"] = "something" - restored = Checkpoint.from_dict(data) - assert restored.position == "after_tools" - assert "unknown_future_field" in caplog.text - - def test_from_dict_missing_schema_version_raises(self): - data = {"position": "after_model", "cycle_index": 0, "snapshot": {}, "app_data": {}} - with pytest.raises(ValueError, match="not compatible with current version"): - Checkpoint.from_dict(data) +def test_checkpoint_to_dict_from_dict_round_trip(): + checkpoint = Checkpoint( + position="after_model", + cycle_index=1, + snapshot={"messages": []}, + app_data={"workflow_id": "wf-123"}, + ) + data = checkpoint.to_dict() + restored = Checkpoint.from_dict(data) + + # Full-object equality catches any future-added field that isn't round-tripped + # correctly, without requiring this test to be updated for every new field. 
+ assert restored == checkpoint + # schema_version is init=False, so it is always set to the current constant — + # asserted once explicitly since dataclass equality covers it via __eq__. + assert restored.schema_version == CHECKPOINT_SCHEMA_VERSION + + +def test_checkpoint_init_schema_version_immutable(): + checkpoint = Checkpoint(position="after_tools") + assert checkpoint.schema_version == CHECKPOINT_SCHEMA_VERSION + + +def test_checkpoint_init_defaults(): + checkpoint = Checkpoint(position="after_model") + assert checkpoint.cycle_index == 0 + assert checkpoint.snapshot == {} + assert checkpoint.app_data == {} + + +def test_checkpoint_from_dict_schema_version_mismatch_raises(): + data = Checkpoint(position="after_model").to_dict() + data["schema_version"] = "0.0" + with pytest.raises(CheckpointException, match="not compatible with current version"): + Checkpoint.from_dict(data) + + +def test_checkpoint_from_dict_missing_schema_version_raises(): + data = {"position": "after_model", "cycle_index": 0, "snapshot": {}, "app_data": {}} + with pytest.raises(CheckpointException, match="not compatible with current version"): + Checkpoint.from_dict(data) + + +def test_checkpoint_from_dict_unknown_fields_warns(caplog): + data = Checkpoint(position="after_tools").to_dict() + data["unknown_future_field"] = "something" + restored = Checkpoint.from_dict(data) + assert restored.position == "after_tools" + assert "unknown_future_field" in caplog.text + + +# ========================================================================= +# End-to-end integration tests (Part B) +# +# These tests exercise the full pause/resume cycle through agent.invoke_async, +# using real Agent instances (not mocks) and a scripted model provider. They prove: +# +# 1. Checkpoints round-trip through to_dict/from_dict across fresh Agent instances. +# 2. cycle_index is preserved across process-restart-style resumes. +# 3. Completed tool work survives worker loss — tools do not re-execute on resume. 
+# +# They do NOT cover mid-tool crashes (orchestrator responsibility) or stateful +# model server-side state (documented V0 limitation). +# ========================================================================= + + +def _assistant_tool_use(tool_use_id: str, name: str, input_data: dict) -> dict: + """Build a scripted assistant message that invokes a single tool.""" + return { + "role": "assistant", + "content": [{"toolUse": {"toolUseId": tool_use_id, "name": name, "input": input_data}}], + } + + +def _assistant_text(text: str) -> dict: + return {"role": "assistant", "content": [{"text": text}]} + + +@pytest.mark.asyncio +async def test_checkpoint_round_trip_across_cycles() -> None: + """Fresh Agent pauses, serialize/deserialize, new Agent resumes, runs to completion.""" + call_log: list[str] = [] + + @tool + def noop(step: str) -> str: + call_log.append(step) + return f"ran-{step}" + + scripted_model = MockedModelProvider( + [ + _assistant_tool_use("t1", "noop", {"step": "one"}), + _assistant_text("done"), + ] + ) + + agent_a = Agent(model=scripted_model, tools=[noop], checkpointing=True) + + # Cycle 0 — model requests tool, pause at after_model. + result_after_model = await agent_a.invoke_async("please run a tool") + assert result_after_model.stop_reason == "checkpoint" + assert result_after_model.checkpoint is not None + assert result_after_model.checkpoint.position == "after_model" + assert result_after_model.checkpoint.cycle_index == 0 + assert call_log == [] # tool has not yet run + + # Serialize/deserialize — simulates crossing a process or activity boundary. + checkpoint_wire = result_after_model.checkpoint.to_dict() + resumed_checkpoint = Checkpoint.from_dict(checkpoint_wire) + + # Fresh Agent instance resumes and runs tools, pausing at after_tools. 
+ agent_b = Agent(model=scripted_model, tools=[noop], checkpointing=True) + result_after_tools = await agent_b.invoke_async( + [{"checkpointResume": {"checkpoint": resumed_checkpoint.to_dict()}}] + ) + assert result_after_tools.stop_reason == "checkpoint" + assert result_after_tools.checkpoint is not None + assert result_after_tools.checkpoint.position == "after_tools" + assert result_after_tools.checkpoint.cycle_index == 0 + assert call_log == ["one"] # tool ran exactly once + + # Resume once more — model returns end_turn, agent completes. + agent_c = Agent(model=scripted_model, tools=[noop], checkpointing=True) + result_done = await agent_c.invoke_async( + [{"checkpointResume": {"checkpoint": result_after_tools.checkpoint.to_dict()}}] + ) + assert result_done.stop_reason == "end_turn" + assert result_done.checkpoint is None + # Tool still only ran once across the whole durable run. + assert call_log == ["one"] + + +@pytest.mark.asyncio +async def test_crash_after_tools_does_not_rerun_completed_tools() -> None: + """3 tools run, agent is discarded ('crash'), fresh agent resumes, tools do not re-run.""" + calls_alpha: list[str] = [] + calls_beta: list[str] = [] + calls_gamma: list[str] = [] + + @tool + def alpha(payload: str) -> str: + calls_alpha.append(payload) + return f"alpha-{payload}" + + @tool + def beta(payload: str) -> str: + calls_beta.append(payload) + return f"beta-{payload}" + + @tool + def gamma(payload: str) -> str: + calls_gamma.append(payload) + return f"gamma-{payload}" + + # One assistant message requests all three tools, then an end_turn response. 
+ scripted_model = MockedModelProvider( + [ + { + "role": "assistant", + "content": [ + {"toolUse": {"toolUseId": "t1", "name": "alpha", "input": {"payload": "a"}}}, + {"toolUse": {"toolUseId": "t2", "name": "beta", "input": {"payload": "b"}}}, + {"toolUse": {"toolUseId": "t3", "name": "gamma", "input": {"payload": "c"}}}, + ], + }, + _assistant_text("all done"), + ] + ) + + # Pre-crash agent: runs through after_model and after_tools. + pre_crash = Agent(model=scripted_model, tools=[alpha, beta, gamma], checkpointing=True) + after_model = await pre_crash.invoke_async("run the three tools") + assert after_model.stop_reason == "checkpoint" + assert after_model.checkpoint.position == "after_model" + + # Resume to run the tools. + pre_crash_b = Agent(model=scripted_model, tools=[alpha, beta, gamma], checkpointing=True) + after_tools = await pre_crash_b.invoke_async( + [{"checkpointResume": {"checkpoint": after_model.checkpoint.to_dict()}}] + ) + assert after_tools.stop_reason == "checkpoint" + assert after_tools.checkpoint.position == "after_tools" + # Exactly one call each, no double-runs. + assert calls_alpha == ["a"] + assert calls_beta == ["b"] + assert calls_gamma == ["c"] + + # "Crash": discard pre_crash_b entirely. Persist only the serialized checkpoint. + persisted = after_tools.checkpoint.to_dict() + del pre_crash, pre_crash_b + + # Post-crash: brand-new agent resumes from the after_tools checkpoint. + # The next model response is end_turn — no more tool use. + post_crash = Agent(model=scripted_model, tools=[alpha, beta, gamma], checkpointing=True) + final = await post_crash.invoke_async([{"checkpointResume": {"checkpoint": persisted}}]) + + assert final.stop_reason == "end_turn" + # No tool re-executed: call counts are unchanged. 
+ assert calls_alpha == ["a"] + assert calls_beta == ["b"] + assert calls_gamma == ["c"] diff --git a/tests/strands/models/test_model.py b/tests/strands/models/test_model.py index 11d4c10b9..ed5a218c8 100644 --- a/tests/strands/models/test_model.py +++ b/tests/strands/models/test_model.py @@ -550,15 +550,21 @@ def test_all_content_types(self): messages = [ {"role": "user", "content": [{"text": "hello world!"}]}, - {"role": "assistant", "content": [ - {"toolUse": {"toolUseId": "1", "name": "my_tool", "input": {"q": "test"}}}, - {"reasoningContent": {"reasoningText": {"text": "Let me think."}}}, - {"guardContent": {"text": {"text": "Filtered."}}}, - {"citationsContent": {"content": [{"text": "Citation."}]}}, - ]}, - {"role": "user", "content": [ - {"toolResult": {"toolUseId": "1", "content": [{"text": "tool output here"}]}}, - ]}, + { + "role": "assistant", + "content": [ + {"toolUse": {"toolUseId": "1", "name": "my_tool", "input": {"q": "test"}}}, + {"reasoningContent": {"reasoningText": {"text": "Let me think."}}}, + {"guardContent": {"text": {"text": "Filtered."}}}, + {"citationsContent": {"content": [{"text": "Citation."}]}}, + ], + }, + { + "role": "user", + "content": [ + {"toolResult": {"toolUseId": "1", "content": [{"text": "tool output here"}]}}, + ], + }, ] result = _estimate_tokens_with_heuristic( messages=messages, @@ -574,9 +580,12 @@ def test_non_serializable_inputs(self): result = _estimate_tokens_with_heuristic( messages=[ - {"role": "assistant", "content": [ - {"toolUse": {"toolUseId": "1", "name": "my_tool", "input": {"data": b"bytes"}}}, - ]}, + { + "role": "assistant", + "content": [ + {"toolUse": {"toolUseId": "1", "name": "my_tool", "input": {"data": b"bytes"}}}, + ], + }, ], tool_specs=[{"name": "t", "inputSchema": {"json": {"default": b"bytes"}}}], ) @@ -598,9 +607,7 @@ def _block_tiktoken(name, *args, **kwargs): monkeypatch.setattr("builtins.__import__", _block_tiktoken) try: - result = await model.count_tokens( - messages=[{"role": "user", 
"content": [{"text": "hello world!"}]}] - ) + result = await model.count_tokens(messages=[{"role": "user", "content": [{"text": "hello world!"}]}]) assert result == 3 # ceil(12 / 4) finally: model_module._get_encoding.cache_clear() diff --git a/tests/strands/types/test__events.py b/tests/strands/types/test__events.py index 48465e1f6..4f10bcce9 100644 --- a/tests/strands/types/test__events.py +++ b/tests/strands/types/test__events.py @@ -4,6 +4,7 @@ from pydantic import BaseModel +from strands.experimental.checkpoint import Checkpoint from strands.telemetry import EventLoopMetrics from strands.types._events import ( AgentAsToolStreamEvent, @@ -279,7 +280,7 @@ def test_initialization_without_structured_output(self): request_state = {"state": "final"} event = EventLoopStopEvent(stop_reason, message, metrics, request_state) - assert event["stop"] == (stop_reason, message, metrics, request_state, None, None) + assert event["stop"] == (stop_reason, message, metrics, request_state, None, None, None) assert event.is_callback_event is False def test_initialization_with_structured_output(self): @@ -291,7 +292,7 @@ def test_initialization_with_structured_output(self): structured_output = SampleModel(name="test", value=42) event = EventLoopStopEvent(stop_reason, message, metrics, request_state, structured_output) - assert event["stop"] == (stop_reason, message, metrics, request_state, structured_output, None) + assert event["stop"] == (stop_reason, message, metrics, request_state, structured_output, None, None) assert event.is_callback_event is False @@ -502,3 +503,35 @@ def test_is_tool_stream_event_subclass(self): assert isinstance(event, ToolStreamEvent) assert isinstance(event, TypedEvent) assert type(event) is AgentAsToolStreamEvent + + +# ========================================================================= +# EventLoopStopEvent checkpoint kwarg (Part B) +# ========================================================================= + + +def 
test_event_loop_stop_event_carries_checkpoint() -> None: + checkpoint = Checkpoint(position="after_model", cycle_index=0) + event = EventLoopStopEvent( + "checkpoint", + {"role": "assistant", "content": [{"text": "hi"}]}, + EventLoopMetrics(), + {}, + checkpoint=checkpoint, + ) + stop = event["stop"] + assert len(stop) == 7 + assert stop[0] == "checkpoint" + assert stop[6] is checkpoint + + +def test_event_loop_stop_event_checkpoint_defaults_to_none() -> None: + event = EventLoopStopEvent( + "end_turn", + {"role": "assistant", "content": [{"text": "done"}]}, + EventLoopMetrics(), + {}, + ) + stop = event["stop"] + assert len(stop) == 7 + assert stop[6] is None diff --git a/tests_integ/test_agent_checkpoint.py b/tests_integ/test_agent_checkpoint.py new file mode 100644 index 000000000..35de405f9 --- /dev/null +++ b/tests_integ/test_agent_checkpoint.py @@ -0,0 +1,206 @@ +"""Integration tests for agent checkpointing with Amazon Bedrock. + +These tests exercise the end-to-end durability contract: an agent with +``checkpointing=True`` pauses at ReAct cycle boundaries, returns an +``AgentResult`` with ``stop_reason="checkpoint"`` and a populated +``checkpoint`` field, and a fresh ``Agent`` instance resumes from the +persisted checkpoint through a ``checkpointResume`` content block. + +Requires valid AWS credentials and may incur API costs. 
+ +To run: + hatch run test-integ tests_integ/test_agent_checkpoint.py +""" + +import json +import os + +import pytest + +from strands import Agent, tool +from strands.experimental.checkpoint import Checkpoint +from strands.models import BedrockModel + +# Skip all tests if no AWS region is configured (boto3 accepts either env var) +pytestmark = [ + pytest.mark.skipif( + not (os.getenv("AWS_REGION") or os.getenv("AWS_DEFAULT_REGION")), + reason="AWS credentials not available", + ), + pytest.mark.asyncio, +] + + +MODEL_ID = "anthropic.claude-3-haiku-20240307-v1:0" + + +def _build_agent(tools: list) -> Agent: + """Build a checkpointing agent with a deterministic tool-using system prompt.""" + return Agent( + model=BedrockModel(model_id=MODEL_ID), + tools=tools, + system_prompt=( + "You are a helpful assistant. When a user asks a factual question, " + "you MUST call the provided tools to answer. Do not answer from memory." + ), + checkpointing=True, + ) + + +async def _drive_to_completion( + tools: list, + first_prompt: str, + max_resumes: int = 10, +) -> tuple[Agent, list[Checkpoint]]: + """Drive a checkpointing agent to end_turn across fresh Agent instances. + + Each time the agent pauses on a checkpoint, we serialize the checkpoint, + discard the Agent, build a fresh one, and resume. Returns the final Agent + (so callers can inspect ``messages``) and the ordered list of checkpoints + that were observed along the way. + """ + agent = _build_agent(tools) + result = await agent.invoke_async(first_prompt) + + checkpoints: list[Checkpoint] = [] + resumes = 0 + while result.stop_reason == "checkpoint": + assert result.checkpoint is not None, "checkpoint field must be populated on pause" + checkpoints.append(result.checkpoint) + + # Serialize through JSON to prove the checkpoint is durable across a + # process boundary (simulated here by round-tripping through bytes). + persisted = json.loads(json.dumps(result.checkpoint.to_dict())) + + # Discard the Agent entirely. 
A fresh instance resumes from scratch, + # holding no in-memory state from the previous invocation. + del agent + agent = _build_agent(tools) + + result = await agent.invoke_async([{"checkpointResume": {"checkpoint": persisted}}]) + + resumes += 1 + if resumes > max_resumes: + raise AssertionError(f"exceeded max_resumes={max_resumes} without reaching end_turn") + + assert result.stop_reason == "end_turn", f"unexpected terminal stop_reason: {result.stop_reason}" + return agent, checkpoints + + +async def test_checkpoint_roundtrip_completes_through_fresh_agent(): + """Pause at a cycle boundary, resume on a fresh Agent, reach end_turn. + + Uses a simple single-tool prompt so the agent is forced through at least + one after_model + after_tools pair before the final end_turn cycle. + """ + + @tool + def get_color_of_sky() -> str: + """Return the color of the sky.""" + return "blue" + + final_agent, checkpoints = await _drive_to_completion( + tools=[get_color_of_sky], + first_prompt="What color is the sky? Use the get_color_of_sky tool.", + ) + + # At least one checkpoint was emitted on the way to completion. + assert len(checkpoints) >= 1 + + # All checkpoints are at one of the two defined boundaries. + assert all(cp.position in ("after_model", "after_tools") for cp in checkpoints) + + # Cycle indices are non-decreasing across the run. + cycle_indices = [cp.cycle_index for cp in checkpoints] + assert cycle_indices == sorted(cycle_indices), f"cycle indices not monotonic: {cycle_indices}" + + # The final agent's message history contains the tool result. + tool_result_texts = [ + block["toolResult"]["content"][0]["text"] + for message in final_agent.messages + for block in message["content"] + if "toolResult" in block + ] + assert "blue" in tool_result_texts + + # The assistant's final message references the tool output. 
+    final_message_text = json.dumps(final_agent.messages[-1]).lower()
+    assert "blue" in final_message_text
+
+
+async def test_checkpoint_survives_process_boundary_no_tool_rerun():
+    """The durability invariant: completed tool calls are not re-run on resume.
+
+    Uses a shared counter dict that each tool increments on every call. After
+    driving the agent through multiple resume cycles, each tool must have been
+    called exactly once — proof that resuming from ``after_tools`` skips the
+    tools that already ran rather than re-executing them.
+    """
+    call_counts = {"time": 0, "day": 0, "weather": 0}
+
+    @tool
+    def get_time() -> str:
+        """Return the current time."""
+        call_counts["time"] += 1
+        return "12:01"
+
+    @tool
+    def get_day() -> str:
+        """Return the current day of the week."""
+        call_counts["day"] += 1
+        return "monday"
+
+    @tool
+    def get_weather() -> str:
+        """Return the current weather."""
+        call_counts["weather"] += 1
+        return "sunny"
+
+    final_agent, checkpoints = await _drive_to_completion(
+        tools=[get_time, get_day, get_weather],
+        first_prompt=("What is the time, the day, and the weather? Use the get_time, get_day, and get_weather tools."),
+    )
+
+    # Each tool ran exactly once across the entire durable run. Resuming from
+    # a checkpoint must not re-execute tools that already completed.
+    assert call_counts == {"time": 1, "day": 1, "weather": 1}, (
+        f"tools were re-executed on resume — counts: {call_counts}"
+    )
+
+    # At least one after_tools checkpoint was observed (the scenario the
+    # durability invariant protects).
+    assert any(cp.position == "after_tools" for cp in checkpoints), (
+        f"no after_tools checkpoint observed: {[cp.position for cp in checkpoints]}"
+    )
+
+    # Final message references all three tool outputs.
+ final_message_text = json.dumps(final_agent.messages[-1]).lower() + assert all(s in final_message_text for s in ["12:01", "monday", "sunny"]) + + +async def test_checkpoint_resume_preserves_conversation_history(): + """After resume, agent.messages contains the full pre-crash conversation. + + The snapshot-based state transfer must restore not only the pending tool + results but the entire message history (user prompt, assistant tool_use, + tool results). Otherwise the resumed model call would be missing context. + """ + + @tool + def get_favorite_number() -> int: + """Return the user's favorite number.""" + return 42 + + final_agent, _ = await _drive_to_completion( + tools=[get_favorite_number], + first_prompt="What is my favorite number? Use the get_favorite_number tool.", + ) + + # The user's original prompt survived the full checkpoint/resume cycle. + user_messages = [m for m in final_agent.messages if m["role"] == "user"] + first_user_message_text = json.dumps(user_messages[0]).lower() + assert "favorite number" in first_user_message_text + + # The assistant reached a terminal response. + assert final_agent.messages[-1]["role"] == "assistant" + assert "42" in json.dumps(final_agent.messages[-1])