diff --git a/src/strands/agent/conversation_manager/__init__.py b/src/strands/agent/conversation_manager/__init__.py index c59623215..fa74f0f2e 100644 --- a/src/strands/agent/conversation_manager/__init__.py +++ b/src/strands/agent/conversation_manager/__init__.py @@ -8,11 +8,13 @@ size while preserving conversation coherence - SummarizingConversationManager: An implementation that summarizes older context instead of simply trimming it +- estimate_tokens: Lightweight token estimation utility (chars/4 heuristic) for conversation managers Conversation managers help control memory usage and context length while maintaining relevant conversation state, which is critical for effective agent interactions. """ +from ._token_utils import TokenCounter, estimate_tokens from .conversation_manager import ConversationManager from .null_conversation_manager import NullConversationManager from .sliding_window_conversation_manager import SlidingWindowConversationManager @@ -23,4 +25,6 @@ "NullConversationManager", "SlidingWindowConversationManager", "SummarizingConversationManager", + "TokenCounter", + "estimate_tokens", ] diff --git a/src/strands/agent/conversation_manager/_token_utils.py b/src/strands/agent/conversation_manager/_token_utils.py new file mode 100644 index 000000000..7e51529c3 --- /dev/null +++ b/src/strands/agent/conversation_manager/_token_utils.py @@ -0,0 +1,84 @@ +"""Lightweight token estimation utilities for conversation managers.""" + +import json +from collections.abc import Callable +from typing import Any + +from ...types.content import Messages + +IMAGE_CHAR_ESTIMATE = 4000 + +TokenCounter = Callable[[Messages], int] + + +def estimate_tokens(messages: Messages) -> int: + """Approximate token count for a message list using a chars/4 heuristic. + + This is deliberately conservative (overestimates for English text, underestimates for CJK). + For model-specific accuracy, pass a custom ``token_counter`` to the conversation manager. + + Args: + messages: The conversation message history. + + Returns: + Estimated token count. + """ + total_chars = 0 + for msg in messages: + for block in msg.get("content", []): + total_chars += _estimate_block_chars(block) + return total_chars // 4 + + +def _estimate_block_chars(block: Any) -> int: + """Estimate character count for a single content block.""" + if "text" in block: + return len(block["text"]) + + if "toolResult" in block: + result = block["toolResult"] + chars = 0 + for item in result.get("content", []): + if "text" in item: + chars += len(item["text"]) + elif "image" in item: + chars += IMAGE_CHAR_ESTIMATE + return chars + + if "toolUse" in block: + tool_use = block["toolUse"] + chars = len(tool_use.get("name", "")) + tool_input = tool_use.get("input", {}) + if isinstance(tool_input, str): + chars += len(tool_input) + else: + chars += len(json.dumps(tool_input, default=str)) + return chars + + if "image" in block: + return IMAGE_CHAR_ESTIMATE + + # NOTE (M3): len(bytes) returns raw binary size, not extractable text length. + # A 100KB PDF may contain only 5KB of text — this overestimates for binary + # documents but stays in the same accuracy class as the chars/4 heuristic. + if "document" in block: + doc = block["document"] + source = doc.get("source", {}) + data = source.get("bytes", b"") + return len(data) if data else 200 + + # NOTE (L1): Rough placeholder — actual video token cost varies enormously by + # duration/resolution. Treat as an order-of-magnitude estimate. + if "video" in block: + return IMAGE_CHAR_ESTIMATE * 10 + + if "reasoningContent" in block: + rc = block["reasoningContent"] + if isinstance(rc, dict) and "reasoningText" in rc: + return len(rc["reasoningText"].get("text", "")) + return len(str(rc)) + + if "cachePoint" in block or "guardContent" in block or "citationsContent" in block: + return 0 + + return len(str(block)) diff --git a/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py b/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py index 94446380b..d9d235058 100644 --- a/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py +++ b/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py @@ -10,6 +10,7 @@ from ...types.content import ContentBlock, Messages from ...types.exceptions import ContextWindowOverflowException from ...types.tools import ToolResultContent +from ._token_utils import IMAGE_CHAR_ESTIMATE, TokenCounter, estimate_tokens from .conversation_manager import ConversationManager logger = logging.getLogger(__name__) @@ -37,6 +38,9 @@ def __init__( should_truncate_results: bool = True, *, per_turn: bool | int = False, + max_context_tokens: int | None = None, + token_counter: TokenCounter | None = None, + compactable_after_messages: int | None = None, ): """Initialize the sliding window conversation manager. @@ -54,19 +58,45 @@ def __init__( manage message history and prevent the agent loop from slowing down. Start with per_turn=True and adjust to a specific frequency (e.g., per_turn=5) if needed for performance tuning. + max_context_tokens: Optional maximum token budget for the conversation context. + When set, the manager checks both message count and estimated token count, + trimming oldest messages when the budget is exceeded. Uses the configured + ``token_counter`` heuristic (chars/4 by default). Note: when both + ``max_context_tokens`` and ``window_size`` are set, either limit can + independently trigger context reduction. + token_counter: Optional custom token counting function. Takes a Messages list + and returns an integer token count. When not provided, the built-in + ``estimate_tokens`` heuristic (chars/4) is used. + compactable_after_messages: Optional message age after which tool results are + replaced with a short stub (``[Tool result cleared — re-run if needed]``). + This reclaims token budget from stale, re-runnable tool output while + preserving the toolUse/toolResult pair structure required by model APIs. Raises: - ValueError: If per_turn is 0 or a negative integer. + ValueError: If per_turn is 0 or a negative integer, or if compactable_after_messages + is not a positive integer. """ if isinstance(per_turn, int) and not isinstance(per_turn, bool) and per_turn <= 0: raise ValueError(f"per_turn must be a positive integer, True, or False, got {per_turn}") + if max_context_tokens is not None and max_context_tokens <= 0: + raise ValueError(f"max_context_tokens must be a positive integer, got {max_context_tokens}") + + if compactable_after_messages is not None and compactable_after_messages <= 0: + raise ValueError( + f"compactable_after_messages must be a positive integer, got {compactable_after_messages}" + ) + super().__init__() self.window_size = window_size self.should_truncate_results = should_truncate_results self.per_turn = per_turn + self.max_context_tokens = max_context_tokens + self.token_counter: TokenCounter = token_counter or estimate_tokens + self.compactable_after_messages = compactable_after_messages self._model_call_count = 0 + self._last_compacted_index = 0 def register_hooks(self, registry: "HookRegistry", **kwargs: Any) -> None: """Register hook callbacks for per-turn conversation management. @@ -77,34 +107,44 @@ def register_hooks(self, registry: "HookRegistry", **kwargs: Any) -> None: """ super().register_hooks(registry, **kwargs) - # Always register the callback - per_turn check happens in the callback + # Always register — per_turn and max_context_tokens checks happen in the callback registry.add_callback(BeforeModelCallEvent, self._on_before_model_call) def _on_before_model_call(self, event: BeforeModelCallEvent) -> None: - """Handle before model call event for per-turn management. + """Handle before model call event for per-turn management and token budget enforcement. - This callback is invoked before each model call. It tracks the model call count and applies message management - based on the per_turn configuration. + This callback is invoked before each model call. It applies management when either + the token budget is exceeded or per-turn management is due. A single + ``apply_management`` call handles both token budget and message count limits, so + at most one call is made per hook invocation. Args: event: The before model call event containing the agent and model execution details. """ - # Check if per_turn is enabled - if self.per_turn is False: - return + needs_apply = False + + if self.max_context_tokens is not None: + current_tokens = self._get_current_token_count(event.agent) + if current_tokens > self.max_context_tokens: + logger.debug( + "current_tokens=<%d>, max_context_tokens=<%d> | token budget exceeded", + current_tokens, + self.max_context_tokens, + ) + needs_apply = True - self._model_call_count += 1 + if self.per_turn is not False: + self._model_call_count += 1 - # Determine if we should apply management - should_apply = False - if self.per_turn is True: - should_apply = True - elif isinstance(self.per_turn, int) and self.per_turn > 0: - should_apply = self._model_call_count % self.per_turn == 0 + if self.per_turn is True: + needs_apply = True + elif isinstance(self.per_turn, int) and self.per_turn > 0: + if self._model_call_count % self.per_turn == 0: + needs_apply = True - if should_apply: + if needs_apply: logger.debug( - "model_call_count=<%d>, per_turn=<%s> | applying per-turn conversation management", + "model_call_count=<%d>, per_turn=<%s> | applying conversation management", self._model_call_count, self.per_turn, ) @@ -118,6 +158,7 @@ def get_state(self) -> dict[str, Any]: """ state = super().get_state() state["model_call_count"] = self._model_call_count + state["last_compacted_index"] = self._last_compacted_index return state def restore_from_session(self, state: dict[str, Any]) -> list | None: @@ -131,13 +172,15 @@ def restore_from_session(self, state: dict[str, Any]) -> list | None: """ result = super().restore_from_session(state) self._model_call_count = state.get("model_call_count", 0) + self._last_compacted_index = state.get("last_compacted_index", 0) return result def apply_management(self, agent: "Agent", **kwargs: Any) -> None: """Apply the sliding window to the agent's messages array to maintain a manageable history size. - This method is called after every event loop cycle to apply a sliding window if the message count - exceeds the window size. + This method is called after every event loop cycle. It applies micro-compaction for stale tool + results (if configured), then loops ``reduce_context`` until both message count and token budget + limits are satisfied (or no further reduction is possible). Args: agent: The agent whose messages will be managed. @@ -146,12 +189,36 @@ def apply_management(self, agent: "Agent", **kwargs: Any) -> None: """ messages = agent.messages - if len(messages) <= self.window_size: - logger.debug( - "message_count=<%s>, window_size=<%s> | skipping context reduction", len(messages), self.window_size + # Micro-compact stale tool results before checking limits + if self.compactable_after_messages is not None: + self._micro_compact(messages) + + # Bound by len(messages) — each iteration must remove at least one message or + # tool-result truncation, and the no-progress guard below catches stalls. + max_iterations = len(messages) + for _ in range(max_iterations): + over_message_limit = len(messages) > self.window_size + over_token_limit = ( + self.max_context_tokens is not None and self._get_current_token_count(agent) > self.max_context_tokens ) - return - self.reduce_context(agent) + + if not over_message_limit and not over_token_limit: + logger.debug( + "message_count=<%s>, window_size=<%s> | context within limits", + len(messages), + self.window_size, + ) + return + + prev_len = len(messages) + self.reduce_context(agent) + if len(messages) >= prev_len: + logger.warning( + "message_count=<%s>, window_size=<%s> | reduce_context made no progress, stopping", + len(messages), + self.window_size, + ) + return def reduce_context(self, agent: "Agent", e: Exception | None = None, **kwargs: Any) -> None: """Trim the oldest messages to reduce the conversation context size. @@ -229,9 +296,77 @@ def reduce_context(self, agent: "Agent", e: Exception | None = None, **kwargs: A # trim_index represents the number of messages being removed from the agents messages array self.removed_message_count += trim_index + # Adjust compaction tracking index + self._last_compacted_index = max(0, self._last_compacted_index - trim_index) + # Overwrite message history messages[:] = messages[trim_index:] + def _get_current_token_count(self, agent: "Agent") -> int: + """Estimate the current token count for the conversation context. + + Always uses the configured ``token_counter`` heuristic rather than model-reported + ``latest_context_size``, because the model-reported value reflects the *previous* + cycle and becomes stale after any reduction — leading to over-reduction spirals. + + Args: + agent: The agent whose context size is being measured. + + Returns: + The estimated token count. + """ + return self.token_counter(agent.messages) + + _COMPACT_STUB = "[Tool result cleared — re-run if needed]" + + def _micro_compact(self, messages: Messages) -> int: + """Replace old tool results with compact stubs to reclaim token budget. + + Tool results older than ``compactable_after_messages`` messages from the end of the + conversation are replaced with a short stub. The toolUse/toolResult pair structure + is preserved — only the content within toolResult blocks is replaced. + + Tracks ``_last_compacted_index`` to skip already-processed messages on subsequent calls. + + Args: + messages: The conversation message history (modified in-place). + + Returns: + Estimated number of tokens reclaimed. + """ + if self.compactable_after_messages is None: + return 0 + + # NOTE (M1): Clamp index in case messages were externally replaced or shortened + # between calls (e.g., manual agent.messages reset, session restore mismatch). + self._last_compacted_index = min(self._last_compacted_index, len(messages)) + + reclaimed_chars = 0 + cutoff = len(messages) - self.compactable_after_messages + + # NOTE (M2): reclaimed_chars may overcount if text was already truncated by + # _truncate_tool_results — the return value is an estimate, not used for decisions. + stub_len = len(self._COMPACT_STUB) + for i in range(self._last_compacted_index, max(0, cutoff)): + msg = messages[i] + for block in msg.get("content", []): + if "toolResult" not in block: + continue + result = block["toolResult"] + items = result.get("content", []) + for j, item in enumerate(items): + if "text" in item and item["text"] != self._COMPACT_STUB: + reclaimed_chars += max(0, len(item["text"]) - stub_len) + items[j] = {"text": self._COMPACT_STUB} + elif "image" in item: + reclaimed_chars += IMAGE_CHAR_ESTIMATE + items[j] = {"text": self._COMPACT_STUB} + + if cutoff > 0: + self._last_compacted_index = max(self._last_compacted_index, cutoff) + + return reclaimed_chars // 4 + def _truncate_tool_results(self, messages: Messages, msg_idx: int) -> bool: """Truncate tool results and replace image blocks in a message to reduce context size. diff --git a/src/strands/agent/conversation_manager/summarizing_conversation_manager.py b/src/strands/agent/conversation_manager/summarizing_conversation_manager.py index abd4d08b5..b924be43e 100644 --- a/src/strands/agent/conversation_manager/summarizing_conversation_manager.py +++ b/src/strands/agent/conversation_manager/summarizing_conversation_manager.py @@ -7,11 +7,13 @@ from ..._async import run_async from ...event_loop.streaming import process_stream +from ...hooks import BeforeModelCallEvent, HookRegistry from ...tools._tool_helpers import noop_tool from ...tools.registry import ToolRegistry from ...types.content import Message from ...types.exceptions import ContextWindowOverflowException from ...types.tools import AgentTool +from ._token_utils import TokenCounter, estimate_tokens from .conversation_manager import ConversationManager if TYPE_CHECKING: @@ -65,6 +67,10 @@ def __init__( preserve_recent_messages: int = 10, summarization_agent: Optional["Agent"] = None, summarization_system_prompt: str | None = None, + *, + max_context_tokens: int | None = None, + proactive_threshold: float = 0.8, + token_counter: TokenCounter | None = None, ): """Initialize the summarizing conversation manager. @@ -77,6 +83,19 @@ def __init__( If provided, this agent can use tools as part of the summarization process. summarization_system_prompt: Optional system prompt override for summarization. If None, uses the default summarization prompt. + max_context_tokens: Optional maximum token budget for the conversation context. + When set, summarization is triggered proactively when estimated token usage + exceeds ``max_context_tokens * proactive_threshold``, instead of waiting + for a ``ContextWindowOverflowException``. + proactive_threshold: Fraction of ``max_context_tokens`` at which proactive + summarization is triggered. Defaults to 0.8 (80%). Tune this based on + your model's context window — models with larger windows may benefit from + a higher threshold (e.g. 0.9) to maximize context utilization, while + models with tight windows may need a lower threshold (e.g. 0.6) to leave + headroom for the summarization request itself. + token_counter: Optional custom token counting function. Takes a Messages list + and returns an integer token count. When not provided, the built-in + ``estimate_tokens`` heuristic (chars/4) is used. """ super().__init__() if summarization_agent is not None and summarization_system_prompt is not None: @@ -85,11 +104,56 @@ def __init__( "Agents come with their own system prompt." ) + if max_context_tokens is not None and max_context_tokens <= 0: + raise ValueError(f"max_context_tokens must be a positive integer, got {max_context_tokens}") + self.summary_ratio = max(0.1, min(0.8, summary_ratio)) self.preserve_recent_messages = preserve_recent_messages self.summarization_agent = summarization_agent self.summarization_system_prompt = summarization_system_prompt + self.max_context_tokens = max_context_tokens + self.proactive_threshold = max(0.1, min(1.0, proactive_threshold)) + self.token_counter: TokenCounter = token_counter or estimate_tokens self._summary_message: Message | None = None + self._last_summarized_msg_count: int | None = None + + def register_hooks(self, registry: "HookRegistry", **kwargs: Any) -> None: + """Register hook callbacks for proactive token-budget management. + + Only registers the before-model-call hook when ``max_context_tokens`` is set. + + Args: + registry: The hook registry to register callbacks with. + **kwargs: Additional keyword arguments for future extensibility. + """ + super().register_hooks(registry, **kwargs) + if self.max_context_tokens is not None: + registry.add_callback(BeforeModelCallEvent, self._on_before_model_call) + + def _on_before_model_call(self, event: BeforeModelCallEvent) -> None: + """Check token budget before each model call and trigger proactive summarization if needed.""" + if self.max_context_tokens is None: + return + + current_tokens = self._get_current_token_count(event.agent) + threshold = int(self.max_context_tokens * self.proactive_threshold) + + if current_tokens > threshold: + logger.debug( + "current_tokens=<%d>, threshold=<%d> | proactive summarization triggered", + current_tokens, + threshold, + ) + self._do_proactive_summarization(event.agent) + + def _get_current_token_count(self, agent: "Agent") -> int: + """Estimate the current token count using the configured heuristic. + + Always uses the ``token_counter`` heuristic rather than model-reported + ``latest_context_size``, because the model-reported value reflects the previous + cycle and becomes stale after any reduction. + """ + return self.token_counter(agent.messages) @override def restore_from_session(self, state: dict[str, Any]) -> list[Message] | None: @@ -110,18 +174,49 @@ def get_state(self) -> dict[str, Any]: return {"summary_message": self._summary_message, **super().get_state()} def apply_management(self, agent: "Agent", **kwargs: Any) -> None: - """Apply management strategy to conversation history. + """Apply token-budget management to the conversation history. - For the summarizing conversation manager, no proactive management is performed. - Summarization only occurs when there's a context overflow that triggers reduce_context. + When ``max_context_tokens`` is configured, checks the current token usage against the + proactive threshold and triggers summarization if exceeded. Skips if summarization + already ran this cycle (e.g., from the ``_on_before_model_call`` hook). Without + ``max_context_tokens``, this is a no-op — context reduction only happens reactively + via ``ContextWindowOverflowException``. Args: agent: The agent whose conversation history will be managed. - The agent's messages list is modified in-place. **kwargs: Additional keyword arguments for future extensibility. """ - # No proactive management - summarization only happens on context overflow - pass + if self.max_context_tokens is None: + return + + current_tokens = self._get_current_token_count(agent) + threshold = int(self.max_context_tokens * self.proactive_threshold) + + if current_tokens > threshold: + logger.debug( + "current_tokens=<%d>, threshold=<%d> | apply_management triggering summarization", + current_tokens, + threshold, + ) + self._do_proactive_summarization(agent) + + def _do_proactive_summarization(self, agent: "Agent") -> None: + """Run reduce_context with a guard against double-summarization in the same cycle. + + The before-model-call hook and apply_management (called in the agent's finally block) + can both trigger summarization in the same agent cycle. This method skips the second + call if the message count hasn't changed since the last summarization. + """ + msg_count = len(agent.messages) + if self._last_summarized_msg_count == msg_count: + logger.debug("skipping summarization — already summarized at message_count=<%d>", msg_count) + return + + try: + self.reduce_context(agent) + self._last_summarized_msg_count = len(agent.messages) + except Exception: + logger.warning("proactive summarization failed", exc_info=True) def reduce_context(self, agent: "Agent", e: Exception | None = None, **kwargs: Any) -> None: """Reduce context using summarization. @@ -173,7 +268,9 @@ def reduce_context(self, agent: "Agent", e: Exception | None = None, **kwargs: A except Exception as summarization_error: logger.error("Summarization failed: %s", summarization_error) - raise summarization_error from e + if e is not None: + raise summarization_error from e + raise def _generate_summary(self, messages: list[Message], agent: "Agent") -> Message: """Generate a summary of the provided messages. diff --git a/src/strands/hooks/events.py b/src/strands/hooks/events.py index 9186e0e70..238ed3af4 100644 --- a/src/strands/hooks/events.py +++ b/src/strands/hooks/events.py @@ -147,6 +147,12 @@ class BeforeToolCallEvent(HookEvent, _Interruptible): cancel_tool: A user defined message that when set, will cancel the tool call. The message will be placed into a tool result with an error status. If set to `True`, Strands will cancel the tool call and use a default cancel message. + tool_is_read_only: Convenience property. True if the selected tool is read-only, False otherwise + (including when selected_tool is None). + tool_is_destructive: Convenience property. True if the selected tool is destructive, False otherwise + (including when selected_tool is None). + tool_requires_confirmation: Convenience property. True if the selected tool requires confirmation, + False otherwise (including when selected_tool is None). """ selected_tool: AgentTool | None @@ -157,6 +163,21 @@ class BeforeToolCallEvent(HookEvent, _Interruptible): def _can_write(self, name: str) -> bool: return name in ["cancel_tool", "selected_tool", "tool_use"] + @property + def tool_is_read_only(self) -> bool: + """Whether the selected tool only reads state. False when selected_tool is None.""" + return self.selected_tool is not None and self.selected_tool.is_read_only + + @property + def tool_is_destructive(self) -> bool: + """Whether the selected tool performs irreversible actions. False when selected_tool is None.""" + return self.selected_tool is not None and self.selected_tool.is_destructive + + @property + def tool_requires_confirmation(self) -> bool: + """Whether the selected tool requires user confirmation. False when selected_tool is None.""" + return self.selected_tool is not None and self.selected_tool.requires_confirmation + @override def _interrupt_id(self, name: str) -> str: """Unique id for the interrupt. diff --git a/src/strands/tools/decorator.py b/src/strands/tools/decorator.py index 9207df9b8..9df0fcc20 100644 --- a/src/strands/tools/decorator.py +++ b/src/strands/tools/decorator.py @@ -459,6 +459,10 @@ def __init__( tool_spec: ToolSpec, tool_func: Callable[P, R], metadata: FunctionToolMetadata, + *, + read_only: bool = False, + destructive: bool = False, + requires_confirmation: bool = False, ): """Initialize the decorated function tool. @@ -467,6 +471,9 @@ def __init__( tool_spec: The tool specification containing metadata for Agent integration. tool_func: The original function being decorated. metadata: The FunctionToolMetadata object with extracted function information. + read_only: Whether this tool only reads state without modification. + destructive: Whether this tool performs irreversible actions. + requires_confirmation: Whether this tool should require user confirmation before execution. """ super().__init__() @@ -474,6 +481,9 @@ def __init__( self._tool_spec = tool_spec self._tool_func = tool_func self._metadata = metadata + self._read_only = read_only + self._destructive = destructive + self._requires_confirmation = requires_confirmation functools.update_wrapper(wrapper=self, wrapped=self._tool_func) @@ -506,7 +516,15 @@ def my_tool(): if instance is not None and not inspect.ismethod(self._tool_func): # Create a bound method tool_func = self._tool_func.__get__(instance, instance.__class__) - return DecoratedFunctionTool(self._tool_name, self._tool_spec, tool_func, self._metadata) + return DecoratedFunctionTool( + self._tool_name, + self._tool_spec, + tool_func, + self._metadata, + read_only=self._read_only, + destructive=self._destructive, + requires_confirmation=self._requires_confirmation, + ) return self @@ -577,6 +595,24 @@ def tool_type(self) -> str: """ return "function" + @property + @override + def is_read_only(self) -> bool: + """Whether this tool only reads state without modification.""" + return self._read_only + + @property + @override + def is_destructive(self) -> bool: + """Whether this tool performs irreversible actions.""" + return self._destructive + + @property + @override + def requires_confirmation(self) -> bool: + """Whether this tool should require user confirmation before execution.""" + return self._requires_confirmation + @override async def stream(self, tool_use: ToolUse, invocation_state: dict[str, Any], **kwargs: Any) -> ToolGenerator: """Stream the tool with a tool use specification. @@ -725,6 +761,9 @@ def tool( inputSchema: JSONSchema | None = None, name: str | None = None, context: bool | str = False, + read_only: bool = False, + destructive: bool = False, + requires_confirmation: bool = False, ) -> Callable[[Callable[P, R]], DecoratedFunctionTool[P, R]]: ... # Suppressing the type error because we want callers to be able to use both `tool` and `tool()` at the # call site, but the actual implementation handles that and it's not representable via the type-system @@ -734,6 +773,9 @@ def tool( # type: ignore inputSchema: JSONSchema | None = None, name: str | None = None, context: bool | str = False, + read_only: bool = False, + destructive: bool = False, + requires_confirmation: bool = False, ) -> DecoratedFunctionTool[P, R] | Callable[[Callable[P, R]], DecoratedFunctionTool[P, R]]: """Decorator that transforms a Python function into a Strands tool. @@ -762,6 +804,10 @@ def tool( # type: ignore context: When provided, places an object in the designated parameter. If True, the param name defaults to 'tool_context', or if an override is needed, set context equal to a string to designate the param name. + read_only: Whether this tool only reads state without modification. Defaults to False. + destructive: Whether this tool performs irreversible actions. Defaults to False. + requires_confirmation: Whether this tool should require user confirmation before execution. + Defaults to False. Returns: An AgentTool that also mimics the original function when invoked @@ -816,13 +862,27 @@ def decorator(f: T) -> "DecoratedFunctionTool[P, R]": tool_spec["description"] = description if inputSchema is not None: tool_spec["inputSchema"] = inputSchema + if read_only: + tool_spec["readOnly"] = True + if destructive: + tool_spec["destructive"] = True + if requires_confirmation: + tool_spec["requiresConfirmation"] = True tool_name = tool_spec.get("name", f.__name__) if not isinstance(tool_name, str): raise ValueError(f"Tool name must be a string, got {type(tool_name)}") - return DecoratedFunctionTool(tool_name, tool_spec, f, tool_meta) + return DecoratedFunctionTool( + tool_name, + tool_spec, + f, + tool_meta, + read_only=read_only, + destructive=destructive, + requires_confirmation=requires_confirmation, + ) # Handle both @tool and @tool() syntax if func is None: diff --git a/src/strands/tools/mcp/mcp_agent_tool.py b/src/strands/tools/mcp/mcp_agent_tool.py index bedd93f24..a535502ac 100644 --- a/src/strands/tools/mcp/mcp_agent_tool.py +++ b/src/strands/tools/mcp/mcp_agent_tool.py @@ -35,6 +35,10 @@ def __init__( mcp_client: "MCPClient", name_override: str | None = None, timeout: timedelta | None = None, + *, + read_only: bool | None = None, + destructive: bool | None = None, + requires_confirmation: bool | None = None, ) -> None: """Initialize a new MCPAgentTool instance. @@ -44,6 +48,12 @@ def __init__( name_override: Optional name to use for the agent tool (for disambiguation) If None, uses the original MCP tool name timeout: Optional timeout duration for tool execution + read_only: Override for read-only classification. When None, falls back to the + tool spec's ``readOnly`` field if present, otherwise False. + destructive: Override for destructive classification. When None, falls back to the + tool spec's ``destructive`` field if present, otherwise False. + requires_confirmation: Override for confirmation requirement. When None, falls back + to the tool spec's ``requiresConfirmation`` field if present, otherwise False. """ super().__init__() logger.debug("tool_name=<%s> | creating mcp agent tool", mcp_tool.name) @@ -51,6 +61,9 @@ def __init__( self.mcp_client = mcp_client self._agent_tool_name = name_override or mcp_tool.name self.timeout = timeout + self._read_only_override = read_only + self._destructive_override = destructive + self._requires_confirmation_override = requires_confirmation @property def tool_name(self) -> str: @@ -93,6 +106,24 @@ def tool_type(self) -> str: """ return "python" + @property + @override + def is_read_only(self) -> bool: + """Whether this tool only reads state. Set via constructor override.""" + return self._read_only_override if self._read_only_override is not None else False + + @property + @override + def is_destructive(self) -> bool: + """Whether this tool performs irreversible actions. Set via constructor override.""" + return self._destructive_override if self._destructive_override is not None else False + + @property + @override + def requires_confirmation(self) -> bool: + """Whether this tool requires user confirmation. Set via constructor override.""" + return self._requires_confirmation_override if self._requires_confirmation_override is not None else False + @override async def stream(self, tool_use: ToolUse, invocation_state: dict[str, Any], **kwargs: Any) -> ToolGenerator: """Stream the MCP tool. diff --git a/src/strands/tools/registry.py b/src/strands/tools/registry.py index 9a0f0f722..85d35b603 100644 --- a/src/strands/tools/registry.py +++ b/src/strands/tools/registry.py @@ -270,6 +270,8 @@ def register_tool(self, tool: AgentTool) -> None: " Cannot add a duplicate tool which differs by a '-' or '_'" ) + self._validate_security_metadata(tool) + # Register in main registry self.registry[tool.tool_name] = tool @@ -288,6 +290,24 @@ def register_tool(self, tool: AgentTool) -> None: list(self.dynamic_tools.keys()), ) + def _validate_security_metadata(self, tool: AgentTool) -> None: + """Validate that a tool's security metadata is internally consistent. + + Args: + tool: The tool to validate. + + Raises: + ValueError: If the tool has contradictory security metadata. + """ + if tool.is_read_only and tool.is_destructive: + raise ValueError(f"Tool '{tool.tool_name}' cannot be both read_only and destructive") + + if tool.is_destructive and not tool.requires_confirmation: + logger.warning( + "tool_name=<%s> | tool is marked destructive but does not require confirmation", + tool.tool_name, + ) + def replace(self, new_tool: AgentTool) -> None: """Replace an existing tool with a new implementation. @@ -305,6 +325,8 @@ def replace(self, new_tool: AgentTool) -> None: if tool_name not in self.registry: raise ValueError(f"Cannot replace tool '{tool_name}' - tool does not exist") + self._validate_security_metadata(new_tool) + # Update main registry self.registry[tool_name] = new_tool diff --git a/src/strands/tools/tools.py b/src/strands/tools/tools.py index ccfeac323..f5fc04e20 100644 --- a/src/strands/tools/tools.py +++ b/src/strands/tools/tools.py @@ -240,6 +240,24 @@ def tool_type(self) -> str: """ return "python" + @property + @override + def is_read_only(self) -> bool: + """Whether this tool only reads state, derived from its ToolSpec.""" + return self._tool_spec.get("readOnly") is True + + @property + @override + def is_destructive(self) -> bool: + """Whether this tool performs irreversible actions, derived from its ToolSpec.""" + return self._tool_spec.get("destructive") is True + + @property + @override + def requires_confirmation(self) -> bool: + """Whether this tool requires user confirmation, derived from its ToolSpec.""" + return self._tool_spec.get("requiresConfirmation") is True + @override async def stream(self, tool_use: ToolUse, invocation_state: dict[str, Any], **kwargs: Any) -> ToolGenerator: """Stream the Python function with the given tool use request. diff --git a/src/strands/types/tools.py b/src/strands/types/tools.py index 088c83bdb..d977aab4c 100644 --- a/src/strands/types/tools.py +++ b/src/strands/types/tools.py @@ -30,12 +30,19 @@ class ToolSpec(TypedDict): outputSchema: Optional JSON Schema defining the expected output format. Note: Not all model providers support this field. Providers that don't support it should filter it out before sending to their API. + readOnly: Optional flag indicating the tool only reads state without modification. + destructive: Optional flag indicating the tool performs irreversible actions. + requiresConfirmation: Optional flag indicating the tool should require user + confirmation before execution. """ description: str inputSchema: JSONSchema name: str outputSchema: NotRequired[JSONSchema] + readOnly: NotRequired[bool] + destructive: NotRequired[bool] + requiresConfirmation: NotRequired[bool] class Tool(TypedDict): @@ -255,6 +262,33 @@ def supports_hot_reload(self) -> bool: """ return False + @property + def is_read_only(self) -> bool: + """Whether this tool only reads state without modification. + + Returns: + False by default. Override in subclasses or set via @tool(read_only=True). + """ + return False + + @property + def is_destructive(self) -> bool: + """Whether this tool performs irreversible actions. + + Returns: + False by default. Override in subclasses or set via @tool(destructive=True). + """ + return False + + @property + def requires_confirmation(self) -> bool: + """Whether this tool should require user confirmation before execution. + + Returns: + False by default. Override in subclasses or set via @tool(requires_confirmation=True). + """ + return False + @abstractmethod # pragma: no cover def stream(self, tool_use: ToolUse, invocation_state: dict[str, Any], **kwargs: Any) -> ToolGenerator: diff --git a/tests/strands/agent/test_token_aware_context_management.py b/tests/strands/agent/test_token_aware_context_management.py new file mode 100644 index 000000000..fe3177365 --- /dev/null +++ b/tests/strands/agent/test_token_aware_context_management.py @@ -0,0 +1,1232 @@ +"""Tests for token-aware context management features. + +Covers: +- _estimate_tokens utility +- SlidingWindowConversationManager: max_context_tokens, token_counter, micro-compaction +- SummarizingConversationManager: proactive token-budget summarization +""" + +from typing import cast +from unittest.mock import MagicMock, Mock, patch + +import pytest + +from strands.agent.agent import Agent +from strands.agent.conversation_manager._token_utils import IMAGE_CHAR_ESTIMATE, estimate_tokens +from strands.agent.conversation_manager.sliding_window_conversation_manager import SlidingWindowConversationManager +from strands.agent.conversation_manager.summarizing_conversation_manager import SummarizingConversationManager +from strands.hooks.events import BeforeModelCallEvent +from strands.hooks.registry import HookRegistry +from strands.types.content import Messages + +# ============================================================================== +# estimate_tokens utility tests +# ============================================================================== + + +class TestEstimateTokens: + def test_empty_messages(self): + assert estimate_tokens([]) == 0 + + def test_text_messages(self): + messages: Messages = [ + {"role": "user", "content": [{"text": "Hello world"}]}, + {"role": "assistant", "content": [{"text": "Hi there, how can I help?"}]}, + ] + result = estimate_tokens(messages) + total_chars = len("Hello world") + len("Hi there, how can I help?") + assert result == total_chars // 4 + + def test_tool_result_messages(self): + messages: Messages = [ + { + "role": "user", + "content": [ + { + "toolResult": { + "toolUseId": "123", + "content": [{"text": "A" * 1000}], + "status": "success", + } + } + ], + } + ] + result = estimate_tokens(messages) + assert result == 1000 // 4 + + def test_tool_use_messages(self): + messages: Messages = [ + { + "role": "assistant", + "content": [ + { + "toolUse": { + "toolUseId": "123", + "name": "read_file", + "input": {"path": "/foo/bar.py"}, + } + } + ], + } + ] + result = estimate_tokens(messages) + assert result > 0 + + def test_image_in_tool_result(self): + messages: Messages = [ + { + "role": "user", + "content": [ + { + "toolResult": { + "toolUseId": "123", + "content": [{"image": {"format": "png", "source": {"bytes": b"data"}}}], + "status": "success", + } + } + ], + } + ] + result = estimate_tokens(messages) + assert result == IMAGE_CHAR_ESTIMATE // 4 + + def test_standalone_image_block(self): + messages: Messages = [ + { + "role": "user", + "content": [{"image": {"format": "png", "source": {"bytes": b"data"}}}], + } + ] + result = estimate_tokens(messages) + assert result == IMAGE_CHAR_ESTIMATE // 4 + + def test_document_block(self): + messages: Messages = [ + { + "role": "user", + "content": [ + { + "document": { + "format": "pdf", + "name": "test.pdf", + "source": {"bytes": b"x" * 8000}, + } + } + ], + } + ] + result = estimate_tokens(messages) + assert result == 8000 // 4 + + def test_video_block(self): + messages: Messages = [ + { + "role": "user", + "content": [{"video": {"format": "mp4", "source": {"bytes": b"v"}}}], + } + ] + result = estimate_tokens(messages) + assert result == (IMAGE_CHAR_ESTIMATE * 10) // 4 + + def test_cache_point_block_zero_tokens(self): + messages: Messages = [ + {"role": "user", "content": [{"cachePoint": {"type": "default"}}]}, + ] + assert estimate_tokens(messages) == 0 + + def test_guard_content_block_zero_tokens(self): + messages: Messages = [ + {"role": "user", "content": [{"guardContent": {"text": {"text": "check"}}}]}, + ] + assert estimate_tokens(messages) == 0 + + def test_tool_use_input_uses_json_serialization(self): + messages: Messages = [ + { + "role": "assistant", + "content": [ + { + "toolUse": { + "toolUseId": "1", + "name": "tool", + "input": {"key": "value"}, + } + } + ], + } + ] + result = estimate_tokens(messages) + # json.dumps produces '{"key": "value"}' (18 chars) + "tool" (4 chars) = 22 chars + # str() would produce "{'key': 'value'}" (16 chars) — different + expected_chars = len("tool") + len('{"key": "value"}') + assert result == expected_chars // 4 + + def test_mixed_content(self): + messages: Messages = [ + { + "role": "user", + "content": [ + {"text": "A" * 400}, + ], + }, + { + "role": "assistant", + "content": [ + {"toolUse": {"toolUseId": "1", "name": "tool", "input": {}}}, + ], + }, + { + "role": "user", + "content": [ + {"toolResult": {"toolUseId": "1", "content": [{"text": "B" * 800}], "status": "success"}}, + ], + }, + ] + result = estimate_tokens(messages) + assert result > 0 + + def test_empty_content_blocks(self): + messages: Messages = [{"role": "user", "content": []}] + assert estimate_tokens(messages) == 0 + + +# ============================================================================== +# SlidingWindowConversationManager — max_context_tokens tests +# ============================================================================== + + +class TestSlidingWindowTokenBudget: + def test_default_no_token_budget(self): + manager = SlidingWindowConversationManager() + assert manager.max_context_tokens is None + + def test_apply_management_skips_when_under_both_limits(self): + manager = SlidingWindowConversationManager( + window_size=100, + max_context_tokens=10000, + should_truncate_results=False, + ) + messages: Messages = [ + {"role": "user", "content": [{"text": "short"}]}, + {"role": "assistant", "content": [{"text": "also short"}]}, + ] + agent = Agent(messages=messages) + original = messages.copy() + manager.apply_management(agent) + assert messages == original + + def test_apply_management_triggers_on_token_budget_exceeded(self): + manager = SlidingWindowConversationManager( + window_size=100, + max_context_tokens=10, + should_truncate_results=False, + ) + messages: Messages = [ + {"role": "user", "content": [{"text": "A" * 200}]}, + {"role": "assistant", "content": [{"text": "B" * 200}]}, + {"role": "user", "content": [{"text": "C" * 200}]}, + ] + agent = Agent(messages=messages) + manager.apply_management(agent) + assert len(messages) < 3 + + def test_apply_management_triggers_on_message_limit_even_without_token_budget(self): + manager = SlidingWindowConversationManager( + window_size=2, + should_truncate_results=False, + ) + messages: Messages = [ + {"role": "user", "content": [{"text": "First"}]}, + {"role": "assistant", "content": [{"text": "Response"}]}, + {"role": "user", "content": [{"text": "Second"}]}, + ] + agent = Agent(messages=messages) + manager.apply_management(agent) + assert len(messages) <= 2 + + def test_custom_token_counter(self): + call_count = 0 + + def my_counter(msgs): + nonlocal call_count + call_count += 1 + return 99999 + + manager = SlidingWindowConversationManager( + window_size=100, + max_context_tokens=100, + token_counter=my_counter, + should_truncate_results=False, + ) + messages: Messages = [ + {"role": "user", "content": [{"text": "Hello"}]}, + {"role": "assistant", "content": [{"text": "Hi"}]}, + {"role": "user", "content": [{"text": "More"}]}, + ] + agent = Agent(messages=messages) + manager.apply_management(agent) + assert call_count > 0 + assert len(messages) < 3 + + def test_token_budget_always_uses_heuristic(self): + manager = SlidingWindowConversationManager( + window_size=100, + max_context_tokens=5000, + ) + mock_agent = MagicMock() + mock_agent.messages = [ + {"role": "user", "content": [{"text": "A" * 400}]}, + ] + # Even when model reports tokens, heuristic is used to avoid staleness + mock_agent.event_loop_metrics.latest_context_size = 6000 + + current = manager._get_current_token_count(mock_agent) + assert current == 100 # 400 chars / 4, NOT 6000 + + def test_before_model_call_enforces_token_budget(self): + manager = SlidingWindowConversationManager( + window_size=100, + max_context_tokens=50, + should_truncate_results=False, + ) + registry = HookRegistry() + manager.register_hooks(registry) + + mock_agent = MagicMock() + mock_agent.messages = [ + {"role": "user", "content": [{"text": "A" * 400}]}, + {"role": "assistant", "content": [{"text": "B" * 400}]}, + {"role": "user", "content": [{"text": "C" * 400}]}, + ] + mock_agent.event_loop_metrics.latest_context_size = None + event = BeforeModelCallEvent(agent=mock_agent, invocation_state={}) + + with patch.object(manager, "apply_management") as mock_apply: + registry.invoke_callbacks(event) + mock_apply.assert_called_once_with(mock_agent) + + def test_before_model_call_skips_when_under_budget(self): + manager = SlidingWindowConversationManager( + window_size=100, + max_context_tokens=100000, + ) + registry = HookRegistry() + manager.register_hooks(registry) + + mock_agent = MagicMock() + mock_agent.messages = [{"role": "user", "content": [{"text": "short"}]}] + mock_agent.event_loop_metrics.latest_context_size = None + event = BeforeModelCallEvent(agent=mock_agent, invocation_state={}) + + with patch.object(manager, "apply_management") as mock_apply: + registry.invoke_callbacks(event) + mock_apply.assert_not_called() + + def test_backward_compatibility_no_token_params(self): + manager = SlidingWindowConversationManager(window_size=40) + assert manager.max_context_tokens is None + assert manager.compactable_after_messages is None + messages: Messages = [ + {"role": "user", "content": [{"text": "Hello"}]}, + {"role": "assistant", "content": [{"text": "Hi"}]}, + ] + agent = Agent(messages=messages) + manager.apply_management(agent) + assert len(messages) == 2 + + +# ============================================================================== +# SlidingWindowConversationManager — micro-compaction tests +# ============================================================================== + + +class TestMicroCompaction: + def test_compactable_after_messages_validation(self): + with pytest.raises(ValueError, match="compactable_after_messages"): + SlidingWindowConversationManager(compactable_after_messages=0) + with pytest.raises(ValueError, match="compactable_after_messages"): + SlidingWindowConversationManager(compactable_after_messages=-3) + + def test_micro_compact_replaces_old_tool_results(self): + manager = SlidingWindowConversationManager( + window_size=100, + compactable_after_messages=2, + should_truncate_results=False, + ) + messages: Messages = [ + {"role": "assistant", "content": [{"toolUse": {"toolUseId": "1", "name": "read", "input": {}}}]}, + { + "role": "user", + "content": [ + {"toolResult": {"toolUseId": "1", "content": [{"text": "A" * 5000}], "status": "success"}} + ], + }, + {"role": "user", "content": [{"text": "Recent message 1"}]}, + {"role": "assistant", "content": [{"text": "Recent response"}]}, + ] + reclaimed = manager._micro_compact(messages) + assert reclaimed > 0 + assert messages[1]["content"][0]["toolResult"]["content"][0]["text"] == manager._COMPACT_STUB + + def test_micro_compact_preserves_recent_results(self): + manager = SlidingWindowConversationManager( + window_size=100, + compactable_after_messages=3, + should_truncate_results=False, + ) + original_text = "B" * 5000 + messages: Messages = [ + {"role": "user", "content": [{"text": "Old message"}]}, + {"role": "assistant", "content": [{"toolUse": {"toolUseId": "2", "name": "read", "input": {}}}]}, + { + "role": "user", + "content": [ + {"toolResult": {"toolUseId": "2", "content": [{"text": original_text}], "status": "success"}} + ], + }, + ] + manager._micro_compact(messages) + # All 3 messages are within the compactable_after_messages=3 window + assert messages[2]["content"][0]["toolResult"]["content"][0]["text"] == original_text + + def test_micro_compact_does_not_double_compact(self): + manager = SlidingWindowConversationManager( + window_size=100, + compactable_after_messages=1, + should_truncate_results=False, + ) + messages: Messages = [ + { + "role": "user", + "content": [ + { + "toolResult": { + "toolUseId": "1", + "content": [{"text": manager._COMPACT_STUB}], + "status": "success", + } + } + ], + }, + {"role": "user", "content": [{"text": "Recent"}]}, + ] + reclaimed = manager._micro_compact(messages) + assert reclaimed == 0 + + def test_micro_compact_preserves_tool_pair_structure(self): + manager = SlidingWindowConversationManager( + window_size=100, + compactable_after_messages=1, + should_truncate_results=False, + ) + messages: Messages = [ + {"role": "assistant", "content": [{"toolUse": {"toolUseId": "1", "name": "tool", "input": {}}}]}, + { + "role": "user", + "content": [ + {"toolResult": {"toolUseId": "1", "content": [{"text": "big result"}], "status": "success"}} + ], + }, + {"role": "user", "content": [{"text": "Latest"}]}, + ] + manager._micro_compact(messages) + assert "toolResult" in messages[1]["content"][0] + assert messages[1]["content"][0]["toolResult"]["toolUseId"] == "1" + assert messages[1]["content"][0]["toolResult"]["status"] == "success" + + def test_micro_compact_handles_empty_messages(self): + manager = SlidingWindowConversationManager( + window_size=100, + compactable_after_messages=5, + should_truncate_results=False, + ) + reclaimed = manager._micro_compact([]) + assert reclaimed == 0 + + def test_micro_compact_runs_in_apply_management(self): + manager = SlidingWindowConversationManager( + window_size=100, + compactable_after_messages=2, + should_truncate_results=False, + ) + messages: Messages = [ + {"role": "assistant", "content": [{"toolUse": {"toolUseId": "1", "name": "tool", "input": {}}}]}, + { + "role": "user", + "content": [ + {"toolResult": {"toolUseId": "1", "content": [{"text": "X" * 10000}], "status": "success"}} + ], + }, + {"role": "user", "content": [{"text": "Recent 1"}]}, + {"role": "assistant", "content": [{"text": "Recent 2"}]}, + ] + agent = Agent(messages=messages) + manager.apply_management(agent) + assert messages[1]["content"][0]["toolResult"]["content"][0]["text"] == manager._COMPACT_STUB + + def test_micro_compact_skips_non_tool_result_blocks(self): + manager = SlidingWindowConversationManager( + window_size=100, + compactable_after_messages=1, + should_truncate_results=False, + ) + messages: Messages = [ + {"role": "user", "content": [{"text": "plain text should not be compacted"}]}, + {"role": "user", "content": [{"text": "Recent"}]}, + ] + original_text = messages[0]["content"][0]["text"] + manager._micro_compact(messages) + assert messages[0]["content"][0]["text"] == original_text + + def test_micro_compact_replaces_image_blocks(self): + manager = SlidingWindowConversationManager( + window_size=100, + compactable_after_messages=1, + should_truncate_results=False, + ) + messages: Messages = [ + { + "role": "user", + "content": [ + { + "toolResult": { + "toolUseId": "1", + "content": [{"image": {"format": "png", "source": {"bytes": b"bigdata"}}}], + "status": "success", + } + } + ], + }, + {"role": "user", "content": [{"text": "Recent"}]}, + ] + reclaimed = manager._micro_compact(messages) + assert reclaimed == IMAGE_CHAR_ESTIMATE // 4 + assert messages[0]["content"][0]["toolResult"]["content"][0]["text"] == manager._COMPACT_STUB + + def test_micro_compact_reclaimed_subtracts_stub_length(self): + manager = SlidingWindowConversationManager( + window_size=100, + compactable_after_messages=1, + should_truncate_results=False, + ) + original_text = "A" * 200 + messages: Messages = [ + { + "role": "user", + "content": [ + { + "toolResult": { + "toolUseId": "1", + "content": [{"text": original_text}], + "status": "success", + } + } + ], + }, + {"role": "user", "content": [{"text": "Recent"}]}, + ] + reclaimed = manager._micro_compact(messages) + stub_len = len(manager._COMPACT_STUB) + assert reclaimed == (len(original_text) - stub_len) // 4 + + def test_micro_compact_skips_already_processed_messages(self): + """Issue #9: _last_compacted_index prevents re-scanning already compacted messages.""" + manager = SlidingWindowConversationManager( + window_size=100, + compactable_after_messages=1, + should_truncate_results=False, + ) + messages: Messages = [ + { + "role": "user", + "content": [ + {"toolResult": {"toolUseId": "1", "content": [{"text": "X" * 5000}], "status": "success"}} + ], + }, + {"role": "user", "content": [{"text": "msg2"}]}, + { + "role": "user", + "content": [ + {"toolResult": {"toolUseId": "2", "content": [{"text": "Y" * 5000}], "status": "success"}} + ], + }, + {"role": "user", "content": [{"text": "msg4"}]}, + ] + # First call compacts messages 0-2 (cutoff = 4-1=3) + reclaimed1 = manager._micro_compact(messages) + assert reclaimed1 > 0 + assert manager._last_compacted_index == 3 + + # Add more messages and compact again — should only process new range + messages.append({"role": "user", "content": [{"text": "msg5"}]}) + messages.append( + { + "role": "user", + "content": [ + {"toolResult": {"toolUseId": "3", "content": [{"text": "Z" * 5000}], "status": "success"}} + ], + } + ) + messages.append({"role": "user", "content": [{"text": "msg7"}]}) + reclaimed2 = manager._micro_compact(messages) + # Should compact messages at indices 3-5 (cutoff=7-1=6), but _last_compacted_index=3 + # so it starts from 3 + assert reclaimed2 > 0 + + def test_micro_compact_then_truncation_interaction(self): + """Issue #11: micro-compaction + truncation work together without conflict.""" + manager = SlidingWindowConversationManager( + window_size=100, + compactable_after_messages=2, + should_truncate_results=True, + max_context_tokens=10, + ) + messages: Messages = [ + {"role": "assistant", "content": [{"toolUse": {"toolUseId": "1", "name": "tool", "input": {}}}]}, + { + "role": "user", + "content": [ + {"toolResult": {"toolUseId": "1", "content": [{"text": "X" * 10000}], "status": "success"}} + ], + }, + {"role": "user", "content": [{"text": "Recent 1"}]}, + {"role": "assistant", "content": [{"text": "Recent 2"}]}, + {"role": "user", "content": [{"text": "Recent 3"}]}, + ] + agent = Agent(messages=messages) + + # First micro-compaction should replace old tool result + manager._micro_compact(messages) + assert messages[1]["content"][0]["toolResult"]["content"][0]["text"] == manager._COMPACT_STUB + + # Then apply_management which checks token budget — should still work + manager.apply_management(agent) + # Messages should still be valid (no crash) + assert len(messages) > 0 + + +# ============================================================================== +# Parameter validation tests +# ============================================================================== + + +class TestParameterValidation: + def test_max_context_tokens_zero_raises_sliding_window(self): + with pytest.raises(ValueError, match="max_context_tokens"): + SlidingWindowConversationManager(max_context_tokens=0) + + def test_max_context_tokens_negative_raises_sliding_window(self): + with pytest.raises(ValueError, match="max_context_tokens"): + SlidingWindowConversationManager(max_context_tokens=-100) + + def test_max_context_tokens_zero_raises_summarizing(self): + with pytest.raises(ValueError, match="max_context_tokens"): + SummarizingConversationManager(max_context_tokens=0) + + def test_max_context_tokens_negative_raises_summarizing(self): + with pytest.raises(ValueError, match="max_context_tokens"): + SummarizingConversationManager(max_context_tokens=-50) + + def test_max_context_tokens_positive_accepted(self): + sw = SlidingWindowConversationManager(max_context_tokens=1000) + assert sw.max_context_tokens == 1000 + sm = SummarizingConversationManager(max_context_tokens=5000) + assert sm.max_context_tokens == 5000 + + +# ============================================================================== +# SummarizingConversationManager — proactive token-budget tests +# ============================================================================== + + +async def _mock_model_stream(response_text): + yield {"messageStart": {"role": "assistant"}} + yield {"contentBlockStart": {"start": {}}} + yield {"contentBlockDelta": {"delta": {"text": response_text}}} + yield {"contentBlockStop": {}} + yield {"messageStop": {"stopReason": "end_turn"}} + + +class MockSummarizationAgent: + def __init__(self, summary_response="Summary of conversation."): + self.summary_response = summary_response + self.system_prompt = None + self.messages = [] + self.model = Mock() + self.model.stream = Mock(side_effect=lambda *a, **kw: _mock_model_stream(self.summary_response)) + self.call_tracker = Mock() + self.tool_registry = Mock() + self.tool_names = [] + self._default_structured_output_model = None + + +def create_mock_agent(summary_response="Summary of conversation.") -> "Agent": + return cast("Agent", MockSummarizationAgent(summary_response)) + + +class TestSummarizingTokenBudget: + def test_default_no_token_budget(self): + manager = SummarizingConversationManager() + assert manager.max_context_tokens is None + assert manager.proactive_threshold == 0.8 + + def test_proactive_threshold_clamped(self): + manager = SummarizingConversationManager(max_context_tokens=1000, proactive_threshold=0.05) + assert manager.proactive_threshold == 0.1 + + manager = SummarizingConversationManager(max_context_tokens=1000, proactive_threshold=1.5) + assert manager.proactive_threshold == 1.0 + + def test_apply_management_triggers_when_over_budget(self): + """apply_management checks token budget and triggers summarization when exceeded.""" + manager = SummarizingConversationManager( + max_context_tokens=100, + proactive_threshold=0.5, + preserve_recent_messages=1, + ) + mock_agent = create_mock_agent() + mock_agent.messages = [ + {"role": "user", "content": [{"text": "A" * 10000}]}, + {"role": "assistant", "content": [{"text": "B" * 10000}]}, + ] + mock_agent.event_loop_metrics = MagicMock() + mock_agent.event_loop_metrics.latest_context_size = None + + with patch.object(manager, "reduce_context") as mock_reduce: + manager.apply_management(mock_agent) + mock_reduce.assert_called_once() + + def test_apply_management_noop_without_token_budget(self): + manager = SummarizingConversationManager() + mock_agent = create_mock_agent() + mock_agent.messages = [ + {"role": "user", "content": [{"text": "A" * 10000}]}, + {"role": "assistant", "content": [{"text": "B" * 10000}]}, + ] + original = mock_agent.messages.copy() + manager.apply_management(mock_agent) + assert mock_agent.messages == original + + def test_before_model_call_proactive_summarization(self): + manager = SummarizingConversationManager( + max_context_tokens=100, + proactive_threshold=0.5, + preserve_recent_messages=1, + ) + registry = HookRegistry() + manager.register_hooks(registry) + + mock_agent = MagicMock() + mock_agent.messages = [ + {"role": "user", "content": [{"text": "A" * 1000}]}, + ] + mock_agent.event_loop_metrics.latest_context_size = None + event = BeforeModelCallEvent(agent=mock_agent, invocation_state={}) + + with patch.object(manager, "reduce_context") as mock_reduce: + registry.invoke_callbacks(event) + mock_reduce.assert_called_once() + + def test_before_model_call_skips_without_token_budget(self): + """Issue #16: hook is not registered when max_context_tokens is None.""" + manager = SummarizingConversationManager() + registry = HookRegistry() + manager.register_hooks(registry) + + mock_agent = MagicMock() + mock_agent.messages = [{"role": "user", "content": [{"text": "A" * 10000}]}] + mock_agent.event_loop_metrics.latest_context_size = None + event = BeforeModelCallEvent(agent=mock_agent, invocation_state={}) + + with patch.object(manager, "reduce_context") as mock_reduce: + registry.invoke_callbacks(event) + mock_reduce.assert_not_called() + + def test_token_count_always_uses_heuristic(self): + """Issue #3: _get_current_token_count always uses heuristic, never stale model-reported value.""" + manager = SummarizingConversationManager( + max_context_tokens=5000, + proactive_threshold=0.8, + ) + mock_agent = MagicMock() + mock_agent.messages = [{"role": "user", "content": [{"text": "A" * 400}]}] + mock_agent.event_loop_metrics.latest_context_size = 4500 + + current = manager._get_current_token_count(mock_agent) + assert current == 100 # 400 chars / 4, NOT 4500 + + def test_custom_token_counter(self): + def always_big(msgs): + return 999999 + + manager = SummarizingConversationManager( + max_context_tokens=100, + token_counter=always_big, + preserve_recent_messages=1, + ) + registry = HookRegistry() + manager.register_hooks(registry) + + mock_agent = MagicMock() + mock_agent.messages = [ + {"role": "user", "content": [{"text": "A"}]}, + {"role": "assistant", "content": [{"text": "B"}]}, + ] + event = BeforeModelCallEvent(agent=mock_agent, invocation_state={}) + + with patch.object(manager, "reduce_context") as mock_reduce: + registry.invoke_callbacks(event) + mock_reduce.assert_called_once() + + def test_backward_compatibility(self): + manager = SummarizingConversationManager( + summary_ratio=0.5, + preserve_recent_messages=2, + ) + assert manager.max_context_tokens is None + + mock_agent = create_mock_agent() + mock_agent.messages = [ + {"role": "user", "content": [{"text": "A" * 10000}]}, + {"role": "assistant", "content": [{"text": "B" * 10000}]}, + ] + original = mock_agent.messages.copy() + manager.apply_management(mock_agent) + assert mock_agent.messages == original + + def test_proactive_summarization_catches_all_exceptions(self): + """Issue #4: hook catches Exception, not just ContextWindowOverflowException.""" + manager = SummarizingConversationManager( + max_context_tokens=10, + proactive_threshold=0.5, + preserve_recent_messages=1, + ) + registry = HookRegistry() + manager.register_hooks(registry) + + mock_agent = MagicMock() + mock_agent.messages = [ + {"role": "user", "content": [{"text": "A" * 400}]}, + ] + event = BeforeModelCallEvent(agent=mock_agent, invocation_state={}) + + with patch.object(manager, "reduce_context", side_effect=RuntimeError("model timeout")): + # Should not raise — gracefully logs warning + registry.invoke_callbacks(event) + + def test_reduce_context_preserves_cause_chain_with_exception(self): + """Issue #13: raise from e only when e is not None.""" + manager = SummarizingConversationManager( + preserve_recent_messages=100, + ) + mock_agent = create_mock_agent() + mock_agent.messages = [ + {"role": "user", "content": [{"text": "short"}]}, + ] + mock_agent.event_loop_metrics = MagicMock() + mock_agent.event_loop_metrics.latest_context_size = None + + # When e=None, should raise without "from None" (preserves natural __cause__) + with pytest.raises(Exception) as exc_info: + manager.reduce_context(mock_agent, e=None) + assert exc_info.value.__cause__ is None + + +# ============================================================================== +# _model_call_count semantics regression tests +# ============================================================================== + + +class TestModelCallCountSemantics: + """Issue #18: _model_call_count should only increment when per_turn is enabled.""" + + def test_model_call_count_not_incremented_when_per_turn_false(self): + manager = SlidingWindowConversationManager(per_turn=False) + registry = HookRegistry() + manager.register_hooks(registry) + + mock_agent = MagicMock() + mock_agent.messages = [{"role": "user", "content": [{"text": "hi"}]}] + mock_agent.event_loop_metrics.latest_context_size = None + event = BeforeModelCallEvent(agent=mock_agent, invocation_state={}) + + registry.invoke_callbacks(event) + registry.invoke_callbacks(event) + registry.invoke_callbacks(event) + assert manager._model_call_count == 0 + + def test_model_call_count_incremented_when_per_turn_true(self): + manager = SlidingWindowConversationManager(per_turn=True) + registry = HookRegistry() + manager.register_hooks(registry) + + mock_agent = MagicMock() + mock_agent.messages = [{"role": "user", "content": [{"text": "hi"}]}] + mock_agent.event_loop_metrics.latest_context_size = None + event = BeforeModelCallEvent(agent=mock_agent, invocation_state={}) + + with patch.object(manager, "apply_management"): + registry.invoke_callbacks(event) + registry.invoke_callbacks(event) + assert manager._model_call_count == 2 + + def test_model_call_count_incremented_when_per_turn_int(self): + manager = SlidingWindowConversationManager(per_turn=3) + registry = HookRegistry() + manager.register_hooks(registry) + + mock_agent = MagicMock() + mock_agent.messages = [{"role": "user", "content": [{"text": "hi"}]}] + mock_agent.event_loop_metrics.latest_context_size = None + event = BeforeModelCallEvent(agent=mock_agent, invocation_state={}) + + with patch.object(manager, "apply_management"): + for _ in range(5): + registry.invoke_callbacks(event) + assert manager._model_call_count == 5 + + def test_per_turn_int_modulo_applies_correctly(self): + applied_count = 0 + + manager = SlidingWindowConversationManager(per_turn=3) + + def counting_apply(agent, **kwargs): + nonlocal applied_count + applied_count += 1 + + registry = HookRegistry() + manager.register_hooks(registry) + + mock_agent = MagicMock() + mock_agent.messages = [{"role": "user", "content": [{"text": "hi"}]}] + mock_agent.event_loop_metrics.latest_context_size = None + event = BeforeModelCallEvent(agent=mock_agent, invocation_state={}) + + with patch.object(manager, "apply_management", side_effect=counting_apply): + for _ in range(9): + registry.invoke_callbacks(event) + # Should apply at calls 3, 6, 9 + assert applied_count == 3 + + +# ============================================================================== +# Integration: hook -> apply_management -> reduce_context flow +# ============================================================================== + + +class TestIntegrationHookFlow: + """Issue #10: Integration test for the full hook -> apply_management -> reduce_context flow.""" + + def test_sliding_window_hook_triggers_full_management_pipeline(self): + manager = SlidingWindowConversationManager( + window_size=100, + max_context_tokens=50, + compactable_after_messages=2, + should_truncate_results=False, + ) + registry = HookRegistry() + manager.register_hooks(registry) + + messages: Messages = [ + {"role": "assistant", "content": [{"toolUse": {"toolUseId": "1", "name": "tool", "input": {}}}]}, + { + "role": "user", + "content": [ + {"toolResult": {"toolUseId": "1", "content": [{"text": "X" * 5000}], "status": "success"}} + ], + }, + {"role": "user", "content": [{"text": "msg3"}]}, + {"role": "assistant", "content": [{"text": "msg4"}]}, + {"role": "user", "content": [{"text": "msg5"}]}, + ] + agent = Agent(messages=messages) + event = BeforeModelCallEvent(agent=agent, invocation_state={}) + + registry.invoke_callbacks(event) + + # Micro-compaction should have run on old tool result + # The exact state depends on whether reduce_context also trimmed + # but the key invariant is: no crash and messages are valid + assert len(messages) > 0 + for msg in messages: + assert msg["role"] in ("user", "assistant") + + def test_last_compacted_index_adjusted_after_trim(self): + """_last_compacted_index is adjusted when messages are trimmed by reduce_context.""" + manager = SlidingWindowConversationManager( + window_size=2, + compactable_after_messages=1, + should_truncate_results=False, + ) + initial_index = 5 + manager._last_compacted_index = initial_index + + messages: Messages = [ + {"role": "user", "content": [{"text": "a"}]}, + {"role": "assistant", "content": [{"text": "b"}]}, + {"role": "user", "content": [{"text": "c"}]}, + {"role": "assistant", "content": [{"text": "d"}]}, + {"role": "user", "content": [{"text": "e"}]}, + ] + original_len = len(messages) + agent = Agent(messages=messages) + + manager.reduce_context(agent) + trimmed_count = original_len - len(messages) + assert trimmed_count > 0 + assert manager._last_compacted_index == max(0, initial_index - trimmed_count) + + +# ============================================================================== +# Token budget convergence tests +# ============================================================================== + + +class TestTokenBudgetConvergence: + """Bug #1: apply_management must loop reduce_context until under token budget.""" + + def test_converges_when_under_window_size_but_over_token_budget(self): + """Messages under window_size but over max_context_tokens — must reduce repeatedly. + + 5 messages x 400 chars = 2000 chars / 4 = 500 tokens, budget = 100. + Each reduce_context trims 2 messages (default when under window_size). + After 1st: 3 msgs, 300 tokens. After 2nd: 1 msg, 100 tokens. Converges. + """ + manager = SlidingWindowConversationManager( + window_size=100, + max_context_tokens=100, + should_truncate_results=False, + ) + messages: Messages = [ + {"role": "user", "content": [{"text": "A" * 400}]}, + {"role": "assistant", "content": [{"text": "B" * 400}]}, + {"role": "user", "content": [{"text": "C" * 400}]}, + {"role": "assistant", "content": [{"text": "D" * 400}]}, + {"role": "user", "content": [{"text": "E" * 400}]}, + ] + agent = Agent(messages=messages) + manager.apply_management(agent) + current_tokens = manager.token_counter(agent.messages) + assert current_tokens <= 100 + assert len(messages) < 5 + + def test_stops_when_no_progress(self): + """If reduce_context can't shrink further, apply_management should not loop forever.""" + manager = SlidingWindowConversationManager( + window_size=100, + max_context_tokens=1, + should_truncate_results=False, + ) + messages: Messages = [ + {"role": "user", "content": [{"text": "A" * 400}]}, + ] + agent = Agent(messages=messages) + manager.apply_management(agent) + assert len(messages) >= 1 + + def test_loop_terminates_when_reduce_makes_no_progress(self): + """apply_management stops looping when reduce_context can't shrink further. + + Patches reduce_context to never actually remove messages, simulating a stuck state. + The loop must detect no-progress and break rather than spinning. + """ + reduce_call_count = 0 + + def noop_reduce(agent, **kwargs): + nonlocal reduce_call_count + reduce_call_count += 1 + + manager = SlidingWindowConversationManager( + window_size=100, + max_context_tokens=1, + should_truncate_results=False, + ) + messages: Messages = [ + {"role": "user", "content": [{"text": "A" * 400}]}, + {"role": "assistant", "content": [{"text": "B" * 400}]}, + {"role": "user", "content": [{"text": "C" * 400}]}, + ] + agent = Agent(messages=messages) + + with patch.object(manager, "reduce_context", side_effect=noop_reduce): + manager.apply_management(agent) + + assert reduce_call_count == 1 + + +# ============================================================================== +# State round-trip tests +# ============================================================================== + + +class TestStateRoundTrip: + """Test gap #13: get_state / restore_from_session round-trip for new fields.""" + + def test_sliding_window_state_round_trip(self): + manager = SlidingWindowConversationManager( + window_size=40, + compactable_after_messages=5, + per_turn=3, + ) + manager._model_call_count = 7 + manager._last_compacted_index = 12 + manager.removed_message_count = 3 + + state = manager.get_state() + assert state["model_call_count"] == 7 + assert state["last_compacted_index"] == 12 + assert state["removed_message_count"] == 3 + + new_manager = SlidingWindowConversationManager( + window_size=40, + compactable_after_messages=5, + per_turn=3, + ) + new_manager.restore_from_session(state) + + assert new_manager._model_call_count == 7 + assert new_manager._last_compacted_index == 12 + assert new_manager.removed_message_count == 3 + + def test_sliding_window_state_defaults_for_missing_keys(self): + """Backward compat: old session state without new keys should use defaults.""" + manager = SlidingWindowConversationManager(window_size=40) + state = { + "__name__": "SlidingWindowConversationManager", + "removed_message_count": 5, + } + manager.restore_from_session(state) + assert manager._model_call_count == 0 + assert manager._last_compacted_index == 0 + assert manager.removed_message_count == 5 + + +# ============================================================================== +# SummarizingConversationManager — apply_management contract tests +# ============================================================================== + + +class TestSummarizingApplyManagement: + """Design #5: apply_management should honor the token budget contract.""" + + def test_apply_management_triggers_summarization_when_over_budget(self): + manager = SummarizingConversationManager( + max_context_tokens=100, + proactive_threshold=0.5, + preserve_recent_messages=1, + ) + mock_agent = create_mock_agent() + mock_agent.messages = [ + {"role": "user", "content": [{"text": "A" * 10000}]}, + {"role": "assistant", "content": [{"text": "B" * 10000}]}, + ] + + with patch.object(manager, "reduce_context") as mock_reduce: + manager.apply_management(mock_agent) + mock_reduce.assert_called_once() + + def test_apply_management_noop_when_under_budget(self): + manager = SummarizingConversationManager( + max_context_tokens=100000, + proactive_threshold=0.8, + ) + mock_agent = create_mock_agent() + mock_agent.messages = [ + {"role": "user", "content": [{"text": "short"}]}, + ] + + with patch.object(manager, "reduce_context") as mock_reduce: + manager.apply_management(mock_agent) + mock_reduce.assert_not_called() + + def test_apply_management_noop_without_max_context_tokens(self): + manager = SummarizingConversationManager() + mock_agent = create_mock_agent() + mock_agent.messages = [ + {"role": "user", "content": [{"text": "A" * 10000}]}, + ] + + with patch.object(manager, "reduce_context") as mock_reduce: + manager.apply_management(mock_agent) + mock_reduce.assert_not_called() + + def test_apply_management_catches_exceptions(self): + manager = SummarizingConversationManager( + max_context_tokens=10, + proactive_threshold=0.5, + preserve_recent_messages=1, + ) + mock_agent = create_mock_agent() + mock_agent.messages = [ + {"role": "user", "content": [{"text": "A" * 400}]}, + ] + + with patch.object(manager, "reduce_context", side_effect=RuntimeError("model timeout")): + manager.apply_management(mock_agent) + + def test_no_double_summarization_in_same_cycle(self): + """Hook and apply_management in same cycle should not both call reduce_context.""" + reduce_count = 0 + + def counting_reduce(agent, **kwargs): + nonlocal reduce_count + reduce_count += 1 + # Simulate successful summarization by shrinking messages + agent.messages[:] = agent.messages[-1:] + + manager = SummarizingConversationManager( + max_context_tokens=10, + proactive_threshold=0.5, + preserve_recent_messages=1, + ) + registry = HookRegistry() + manager.register_hooks(registry) + + mock_agent = MagicMock() + mock_agent.messages = [ + {"role": "user", "content": [{"text": "A" * 400}]}, + {"role": "assistant", "content": [{"text": "B" * 400}]}, + ] + event = BeforeModelCallEvent(agent=mock_agent, invocation_state={}) + + with patch.object(manager, "reduce_context", side_effect=counting_reduce): + # Hook fires — triggers first summarization + registry.invoke_callbacks(event) + assert reduce_count == 1 + + # apply_management fires (as agent's finally block would) — should skip + manager.apply_management(mock_agent) + assert reduce_count == 1 + + def test_summarization_runs_again_after_new_messages(self): + """After new messages arrive, summarization should fire again.""" + reduce_count = 0 + + def counting_reduce(agent, **kwargs): + nonlocal reduce_count + reduce_count += 1 + agent.messages[:] = agent.messages[-1:] + + manager = SummarizingConversationManager( + max_context_tokens=10, + proactive_threshold=0.5, + preserve_recent_messages=1, + ) + + mock_agent = MagicMock() + mock_agent.messages = [ + {"role": "user", "content": [{"text": "A" * 400}]}, + {"role": "assistant", "content": [{"text": "B" * 400}]}, + ] + + with patch.object(manager, "reduce_context", side_effect=counting_reduce): + manager.apply_management(mock_agent) + assert reduce_count == 1 + + # Simulate new messages arriving + mock_agent.messages.append({"role": "user", "content": [{"text": "C" * 400}]}) + mock_agent.messages.append({"role": "assistant", "content": [{"text": "D" * 400}]}) + + manager.apply_management(mock_agent) + assert reduce_count == 2 diff --git a/tests/strands/tools/mcp/test_mcp_client_tool_provider.py b/tests/strands/tools/mcp/test_mcp_client_tool_provider.py index 9cb90167d..dea484c5c 100644 --- a/tests/strands/tools/mcp/test_mcp_client_tool_provider.py +++ b/tests/strands/tools/mcp/test_mcp_client_tool_provider.py @@ -54,6 +54,9 @@ def create_mock_tool(tool_name: str, mcp_tool_name: str | None = None) -> MagicM tool.mcp_tool = MagicMock(spec=MCPTool) tool.mcp_tool.name = mcp_tool_name or tool_name tool.mcp_tool.description = f"Description for {tool_name}" + tool.is_read_only = False + tool.is_destructive = False + tool.requires_confirmation = False return tool diff --git a/tests/strands/tools/test_registry.py b/tests/strands/tools/test_registry.py index 3723f381b..ee9f294c6 100644 --- a/tests/strands/tools/test_registry.py +++ b/tests/strands/tools/test_registry.py @@ -130,15 +130,21 @@ def function() -> str: def test_register_tool_duplicate_name_without_hot_reload(): """Test that registering a tool with duplicate name raises ValueError when hot reload is not supported.""" # Create mock tools that don't support hot reload - tool_1 = MagicMock() + tool_1 = MagicMock(spec=PythonAgentTool) tool_1.tool_name = "duplicate_tool" tool_1.supports_hot_reload = False tool_1.is_dynamic = False + tool_1.is_read_only = False + tool_1.is_destructive = False + tool_1.requires_confirmation = False - tool_2 = MagicMock() + tool_2 = MagicMock(spec=PythonAgentTool) tool_2.tool_name = "duplicate_tool" tool_2.supports_hot_reload = False tool_2.is_dynamic = False + tool_2.is_read_only = False + tool_2.is_destructive = False + tool_2.requires_confirmation = False tool_registry = ToolRegistry() tool_registry.register_tool(tool_1) @@ -156,11 +162,17 @@ def test_register_tool_duplicate_name_with_hot_reload(): tool_1.tool_name = "hot_reload_tool" tool_1.supports_hot_reload = True tool_1.is_dynamic = False + tool_1.is_read_only = False + tool_1.is_destructive = False + tool_1.requires_confirmation = False tool_2 = MagicMock(spec=PythonAgentTool) tool_2.tool_name = "hot_reload_tool" tool_2.supports_hot_reload = True tool_2.is_dynamic = False + tool_2.is_read_only = False + tool_2.is_destructive = False + tool_2.requires_confirmation = False tool_registry = ToolRegistry() tool_registry.register_tool(tool_1) @@ -519,10 +531,16 @@ def test_tool_registry_replace_existing_tool(): old_tool.tool_name = "my_tool" old_tool.is_dynamic = False old_tool.supports_hot_reload = False + old_tool.is_read_only = False + old_tool.is_destructive = False + old_tool.requires_confirmation = False new_tool = MagicMock() new_tool.tool_name = "my_tool" new_tool.is_dynamic = False + new_tool.is_read_only = False + new_tool.is_destructive = False + new_tool.requires_confirmation = False registry = ToolRegistry() registry.register_tool(old_tool) @@ -535,6 +553,9 @@ def test_tool_registry_replace_nonexistent_tool(): """Test replacing a tool that doesn't exist raises ValueError.""" new_tool = MagicMock() new_tool.tool_name = "my_tool" + new_tool.is_read_only = False + new_tool.is_destructive = False + new_tool.requires_confirmation = False registry = ToolRegistry() @@ -548,10 +569,16 @@ def test_tool_registry_replace_dynamic_tool(): old_tool.tool_name = "dynamic_tool" old_tool.is_dynamic = True old_tool.supports_hot_reload = True + old_tool.is_read_only = False + old_tool.is_destructive = False + old_tool.requires_confirmation = False new_tool = MagicMock() new_tool.tool_name = "dynamic_tool" new_tool.is_dynamic = True + new_tool.is_read_only = False + new_tool.is_destructive = False + new_tool.requires_confirmation = False registry = ToolRegistry() registry.register_tool(old_tool) @@ -567,10 +594,16 @@ def test_tool_registry_replace_dynamic_with_non_dynamic(): old_tool.tool_name = "my_tool" old_tool.is_dynamic = True old_tool.supports_hot_reload = True + old_tool.is_read_only = False + old_tool.is_destructive = False + old_tool.requires_confirmation = False new_tool = MagicMock() new_tool.tool_name = "my_tool" new_tool.is_dynamic = False + new_tool.is_read_only = False + new_tool.is_destructive = False + new_tool.requires_confirmation = False registry = ToolRegistry() registry.register_tool(old_tool) @@ -589,10 +622,16 @@ def test_tool_registry_replace_non_dynamic_with_dynamic(): old_tool.tool_name = "my_tool" old_tool.is_dynamic = False old_tool.supports_hot_reload = False + old_tool.is_read_only = False + old_tool.is_destructive = False + old_tool.requires_confirmation = False new_tool = MagicMock() new_tool.tool_name = "my_tool" new_tool.is_dynamic = True + new_tool.is_read_only = False + new_tool.is_destructive = False + new_tool.requires_confirmation = False registry = ToolRegistry() registry.register_tool(old_tool) diff --git a/tests/strands/tools/test_tool_security_metadata.py b/tests/strands/tools/test_tool_security_metadata.py new file mode 100644 index 000000000..cbdaff427 --- /dev/null +++ b/tests/strands/tools/test_tool_security_metadata.py @@ -0,0 +1,426 @@ +"""Tests for tool security metadata (is_read_only, is_destructive, requires_confirmation). + +Covers: +- AgentTool base class defaults +- @tool decorator parameters +- ToolSpec round-trip (PythonAgentTool reads from spec) +- MCPAgentTool with overrides and spec fallback +- ToolRegistry contradiction rejection and destructive-without-confirmation warning +- BeforeToolCallEvent convenience properties +- Hook-based permission gate integration test +""" + +import logging +from unittest.mock import MagicMock + +import pytest + +from strands.hooks.events import BeforeToolCallEvent +from strands.tools.decorator import DecoratedFunctionTool, tool +from strands.tools.mcp.mcp_agent_tool import MCPAgentTool +from strands.tools.registry import ToolRegistry +from strands.tools.tools import PythonAgentTool + +# --------------------------------------------------------------------------- +# 1. AgentTool base class defaults +# --------------------------------------------------------------------------- + + +class _MinimalTool(PythonAgentTool): + pass + + +def _make_spec(name="test_tool", **extra): + spec = {"name": name, "description": "A test tool", "inputSchema": {"json": {"type": "object", "properties": {}}}} + spec.update(extra) + return spec + + +def test_agent_tool_defaults_all_false(): + t = PythonAgentTool(tool_name="t", tool_spec=_make_spec("t"), tool_func=lambda tu, **kw: None) + assert t.is_read_only is False + assert t.is_destructive is False + assert t.requires_confirmation is False + + +# --------------------------------------------------------------------------- +# 2. @tool decorator parameters +# --------------------------------------------------------------------------- + + +def test_tool_decorator_read_only(): + @tool(read_only=True) + def list_files(directory: str) -> str: + """List files in a directory.""" + return "" + + assert list_files.is_read_only is True + assert list_files.is_destructive is False + assert list_files.requires_confirmation is False + + +def test_tool_decorator_destructive_with_confirmation(): + @tool(destructive=True, requires_confirmation=True) + def delete_file(path: str) -> str: + """Delete a file permanently.""" + return "" + + assert delete_file.is_read_only is False + assert delete_file.is_destructive is True + assert delete_file.requires_confirmation is True + + +def test_tool_decorator_bare_has_defaults(): + @tool + def noop() -> str: + """Does nothing.""" + return "" + + assert noop.is_read_only is False + assert noop.is_destructive is False + assert noop.requires_confirmation is False + + +def test_tool_decorator_only_requires_confirmation(): + @tool(requires_confirmation=True) + def sensitive_op(data: str) -> str: + """A sensitive operation.""" + return data + + assert sensitive_op.is_read_only is False + assert sensitive_op.is_destructive is False + assert sensitive_op.requires_confirmation is True + + +# --------------------------------------------------------------------------- +# 3. ToolSpec round-trip (PythonAgentTool reads from spec) +# --------------------------------------------------------------------------- + + +def test_python_agent_tool_reads_read_only_from_spec(): + spec = _make_spec("reader", readOnly=True) + t = PythonAgentTool(tool_name="reader", tool_spec=spec, tool_func=lambda tu, **kw: None) + assert t.is_read_only is True + assert t.is_destructive is False + + +def test_python_agent_tool_reads_destructive_from_spec(): + spec = _make_spec("destroyer", destructive=True, requiresConfirmation=True) + t = PythonAgentTool(tool_name="destroyer", tool_spec=spec, tool_func=lambda tu, **kw: None) + assert t.is_destructive is True + assert t.requires_confirmation is True + assert t.is_read_only is False + + +def test_python_agent_tool_no_security_fields_in_spec(): + spec = _make_spec("plain") + t = PythonAgentTool(tool_name="plain", tool_spec=spec, tool_func=lambda tu, **kw: None) + assert t.is_read_only is False + assert t.is_destructive is False + assert t.requires_confirmation is False + + +# --------------------------------------------------------------------------- +# 4. MCPAgentTool with overrides and spec fallback +# --------------------------------------------------------------------------- + + +def _make_mcp_tool(name="mcp_test", input_schema=None): + mcp_tool = MagicMock() + mcp_tool.name = name + mcp_tool.description = f"MCP tool {name}" + mcp_tool.inputSchema = input_schema or {"type": "object", "properties": {}} + mcp_tool.outputSchema = None + return mcp_tool + + +def test_mcp_tool_defaults(): + mcp_tool = _make_mcp_tool() + t = MCPAgentTool(mcp_tool=mcp_tool, mcp_client=MagicMock()) + assert t.is_read_only is False + assert t.is_destructive is False + assert t.requires_confirmation is False + + +def test_mcp_tool_constructor_overrides(): + mcp_tool = _make_mcp_tool() + t = MCPAgentTool(mcp_tool=mcp_tool, mcp_client=MagicMock(), read_only=True) + assert t.is_read_only is True + assert t.is_destructive is False + + +def test_mcp_tool_destructive_override(): + mcp_tool = _make_mcp_tool() + t = MCPAgentTool( + mcp_tool=mcp_tool, + mcp_client=MagicMock(), + destructive=True, + requires_confirmation=True, + ) + assert t.is_destructive is True + assert t.requires_confirmation is True + assert t.is_read_only is False + + +def test_mcp_tool_override_takes_precedence_over_spec(): + """Constructor override should win even if the spec says differently.""" + mcp_tool = _make_mcp_tool() + t = MCPAgentTool( + mcp_tool=mcp_tool, + mcp_client=MagicMock(), + read_only=False, + ) + assert t.is_read_only is False + + +def test_mcp_tool_no_override_defaults_to_false(): + """Without constructor overrides, MCP tools default to False for all security properties.""" + mcp_tool = _make_mcp_tool() + t = MCPAgentTool(mcp_tool=mcp_tool, mcp_client=MagicMock()) + assert t.is_read_only is False + assert t.is_destructive is False + assert t.requires_confirmation is False + + +# --------------------------------------------------------------------------- +# 5. ToolRegistry validation +# --------------------------------------------------------------------------- + + +def test_registry_rejects_read_only_and_destructive(): + @tool(read_only=True, destructive=True) + def bad_tool() -> str: + """Contradictory tool.""" + return "" + + registry = ToolRegistry() + with pytest.raises(ValueError, match="cannot be both read_only and destructive"): + registry.register_tool(bad_tool) + + +def test_registry_warns_destructive_without_confirmation(caplog): + @tool(destructive=True) + def risky_tool() -> str: + """A risky tool without confirmation.""" + return "" + + registry = ToolRegistry() + with caplog.at_level(logging.WARNING): + registry.register_tool(risky_tool) + + assert "destructive but does not require confirmation" in caplog.text + + +def test_registry_accepts_destructive_with_confirmation(): + @tool(destructive=True, requires_confirmation=True) + def safe_destructive() -> str: + """A destructive tool that requires confirmation.""" + return "" + + registry = ToolRegistry() + registry.register_tool(safe_destructive) + assert "safe_destructive" in registry.registry + + +def test_registry_accepts_read_only(): + @tool(read_only=True) + def reader() -> str: + """A read-only tool.""" + return "" + + registry = ToolRegistry() + registry.register_tool(reader) + assert "reader" in registry.registry + + +def test_registry_accepts_default_metadata(): + @tool + def plain() -> str: + """A plain tool.""" + return "" + + registry = ToolRegistry() + registry.register_tool(plain) + assert "plain" in registry.registry + + +def test_registry_replace_rejects_contradictory_metadata(): + """ToolRegistry.replace() must also validate security metadata.""" + + @tool + def my_tool() -> str: + """A normal tool.""" + return "" + + @tool(read_only=True, destructive=True, name="my_tool") + def bad_replacement() -> str: + """Contradictory replacement.""" + return "" + + registry = ToolRegistry() + registry.register_tool(my_tool) + + with pytest.raises(ValueError, match="cannot be both read_only and destructive"): + registry.replace(bad_replacement) + + +# --------------------------------------------------------------------------- +# 6. BeforeToolCallEvent convenience properties +# --------------------------------------------------------------------------- + + +def _make_before_event(selected_tool=None): + return BeforeToolCallEvent( + agent=MagicMock(), + selected_tool=selected_tool, + tool_use={"name": "test", "toolUseId": "id-1", "input": {}}, + invocation_state={}, + ) + + +def test_event_convenience_props_with_none_tool(): + event = _make_before_event(selected_tool=None) + assert event.tool_is_read_only is False + assert event.tool_is_destructive is False + assert event.tool_requires_confirmation is False + + +def test_event_convenience_props_with_read_only_tool(): + @tool(read_only=True) + def reader() -> str: + """Read only.""" + return "" + + event = _make_before_event(selected_tool=reader) + assert event.tool_is_read_only is True + assert event.tool_is_destructive is False + assert event.tool_requires_confirmation is False + + +def test_event_convenience_props_with_destructive_tool(): + @tool(destructive=True, requires_confirmation=True) + def destroyer() -> str: + """Destructive.""" + return "" + + event = _make_before_event(selected_tool=destroyer) + assert event.tool_is_read_only is False + assert event.tool_is_destructive is True + assert event.tool_requires_confirmation is True + + +# --------------------------------------------------------------------------- +# 7. Integration: hook-based permission gate +# --------------------------------------------------------------------------- + + +def test_hook_cancels_destructive_tool(): + """Simulate a BeforeToolCallEvent hook that cancels destructive tools.""" + + @tool(destructive=True, requires_confirmation=True) + def delete_db() -> str: + """Delete the database.""" + return "" + + event = _make_before_event(selected_tool=delete_db) + + # Hook logic: cancel destructive tools + if event.tool_is_destructive: + event.cancel_tool = "Destructive tool requires approval" + + assert event.cancel_tool == "Destructive tool requires approval" + + +def test_hook_allows_read_only_tool(): + """Simulate a BeforeToolCallEvent hook that allows read-only tools.""" + + @tool(read_only=True) + def list_items() -> str: + """List items.""" + return "" + + event = _make_before_event(selected_tool=list_items) + + # Hook logic: only cancel non-read-only tools + if not event.tool_is_read_only: + event.cancel_tool = "Non-read-only tool blocked" + + assert event.cancel_tool is False + + +# --------------------------------------------------------------------------- +# 8. Backward compatibility +# --------------------------------------------------------------------------- + + +def test_decorated_tool_get_preserves_security_metadata(): + """Verify __get__ (descriptor protocol) propagates security metadata.""" + + class MyClass: + @tool(destructive=True, requires_confirmation=True) + def my_method(self, x: str) -> str: + """A method tool.""" + return x + + instance = MyClass() + bound_tool = instance.my_method + + assert isinstance(bound_tool, DecoratedFunctionTool) + assert bound_tool.is_destructive is True + assert bound_tool.requires_confirmation is True + assert bound_tool.is_read_only is False + + +def test_tool_spec_typed_dict_accepts_security_fields(): + """Verify ToolSpec TypedDict accepts the new NotRequired fields.""" + from strands.types.tools import ToolSpec + + spec: ToolSpec = { + "name": "test", + "description": "test", + "inputSchema": {}, + "readOnly": True, + "destructive": False, + "requiresConfirmation": False, + } + assert spec["readOnly"] is True + assert spec["destructive"] is False + + +def test_decorated_tool_writes_security_fields_to_spec(): + """@tool should write security fields into its ToolSpec for serialization consistency.""" + + @tool(read_only=True) + def reader(x: str) -> str: + """Read something.""" + return x + + assert reader.tool_spec.get("readOnly") is True + assert "destructive" not in reader.tool_spec + assert "requiresConfirmation" not in reader.tool_spec + + +def test_decorated_tool_destructive_fields_in_spec(): + """@tool(destructive=True, requires_confirmation=True) should write both fields to spec.""" + + @tool(destructive=True, requires_confirmation=True) + def deleter(x: str) -> str: + """Delete something.""" + return x + + assert deleter.tool_spec.get("destructive") is True + assert deleter.tool_spec.get("requiresConfirmation") is True + assert "readOnly" not in deleter.tool_spec + + +def test_bare_decorator_omits_security_fields_from_spec(): + """Plain @tool should not pollute ToolSpec with False security fields.""" + + @tool + def plain(x: str) -> str: + """Do something.""" + return x + + assert "readOnly" not in plain.tool_spec + assert "destructive" not in plain.tool_spec + assert "requiresConfirmation" not in plain.tool_spec