4 changes: 4 additions & 0 deletions src/strands/agent/conversation_manager/__init__.py
@@ -8,11 +8,13 @@
size while preserving conversation coherence
- SummarizingConversationManager: An implementation that summarizes older context instead
of simply trimming it
- estimate_tokens: Lightweight token estimation utility (chars/4 heuristic) for conversation managers
- TokenCounter: Type alias for pluggable token-counting callables (``Messages -> int``)

Conversation managers help control memory usage and context length while maintaining relevant conversation state, which
is critical for effective agent interactions.
"""

from ._token_utils import TokenCounter, estimate_tokens
from .conversation_manager import ConversationManager
from .null_conversation_manager import NullConversationManager
from .sliding_window_conversation_manager import SlidingWindowConversationManager
@@ -23,4 +25,6 @@
"NullConversationManager",
"SlidingWindowConversationManager",
"SummarizingConversationManager",
"TokenCounter",
"estimate_tokens",
]
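
# Illustrative usage sketch (not part of this file; the parameter values below
# are assumptions chosen for demonstration):
#
#     from strands.agent.conversation_manager import (
#         SlidingWindowConversationManager,
#         estimate_tokens,
#     )
#
#     manager = SlidingWindowConversationManager(
#         max_context_tokens=8_000,        # budget checked with estimate_tokens by default
#         compactable_after_messages=20,   # stub out tool results older than 20 messages
#     )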
84 changes: 84 additions & 0 deletions src/strands/agent/conversation_manager/_token_utils.py
@@ -0,0 +1,84 @@
"""Lightweight token estimation utilities for conversation managers."""

import json
from collections.abc import Callable
from typing import Any

from ...types.content import Messages

# Rough character-equivalent charged for an image block (~1000 tokens under the chars/4 heuristic).
IMAGE_CHAR_ESTIMATE = 4000

# Signature for pluggable token counters: takes the message history, returns an estimated count.
TokenCounter = Callable[[Messages], int]


def estimate_tokens(messages: Messages) -> int:
"""Approximate token count for a message list using a chars/4 heuristic.

This is deliberately simple rather than exact: it tends to overestimate for typical English text and to substantially underestimate for CJK text.
For model-specific accuracy, pass a custom ``token_counter`` to the conversation manager.

Args:
messages: The conversation message history.

Returns:
Estimated token count.
"""
total_chars = 0
for msg in messages:
for block in msg.get("content", []):
total_chars += _estimate_block_chars(block)
return total_chars // 4
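
# Illustrative sketch of a model-specific counter (``tiktoken`` is an assumed
# optional dependency used here for demonstration; this module does not require it):
#
#     import tiktoken
#
#     _enc = tiktoken.get_encoding("cl100k_base")
#
#     def tiktoken_counter(messages: Messages) -> int:
#         # Counts only text blocks; extend for tool/image blocks as needed.
#         return sum(
#             len(_enc.encode(block["text"]))
#             for msg in messages
#             for block in msg.get("content", [])
#             if "text" in block
#         )
#
#     # Pass it to a conversation manager in place of the chars/4 default:
#     # SlidingWindowConversationManager(max_context_tokens=8_000, token_counter=tiktoken_counter)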


def _estimate_block_chars(block: Any) -> int:
"""Estimate character count for a single content block."""
if "text" in block:
return len(block["text"])

if "toolResult" in block:
result = block["toolResult"]
chars = 0
for item in result.get("content", []):
if "text" in item:
chars += len(item["text"])
elif "image" in item:
chars += IMAGE_CHAR_ESTIMATE
return chars

if "toolUse" in block:
tool_use = block["toolUse"]
chars = len(tool_use.get("name", ""))
tool_input = tool_use.get("input", {})
if isinstance(tool_input, str):
chars += len(tool_input)
else:
chars += len(json.dumps(tool_input, default=str))
return chars

if "image" in block:
return IMAGE_CHAR_ESTIMATE

# NOTE (M3): len(bytes) returns raw binary size, not extractable text length.
# A 100KB PDF may contain only 5KB of text — this overestimates for binary
# documents but stays in the same accuracy class as the chars/4 heuristic.
if "document" in block:
doc = block["document"]
source = doc.get("source", {})
data = source.get("bytes", b"")
return len(data) if data else 200

# NOTE (L1): Rough placeholder — actual video token cost varies enormously by
# duration/resolution. Treat as an order-of-magnitude estimate.
if "video" in block:
return IMAGE_CHAR_ESTIMATE * 10

if "reasoningContent" in block:
rc = block["reasoningContent"]
if isinstance(rc, dict) and "reasoningText" in rc:
return len(rc["reasoningText"].get("text", ""))
return len(str(rc))

if "cachePoint" in block or "guardContent" in block or "citationsContent" in block:
return 0

return len(str(block))
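
# Quick illustrative check of the heuristic (approximate by design):
#
#     >>> estimate_tokens([{"role": "user", "content": [{"text": "hello world"}]}])
#     2
#
# "hello world" is 11 characters, and 11 // 4 == 2.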
src/strands/agent/conversation_manager/sliding_window_conversation_manager.py
@@ -10,6 +10,7 @@
from ...types.content import ContentBlock, Messages
from ...types.exceptions import ContextWindowOverflowException
from ...types.tools import ToolResultContent
from ._token_utils import IMAGE_CHAR_ESTIMATE, TokenCounter, estimate_tokens
from .conversation_manager import ConversationManager

logger = logging.getLogger(__name__)
@@ -37,6 +38,9 @@ def __init__(
should_truncate_results: bool = True,
*,
per_turn: bool | int = False,
max_context_tokens: int | None = None,
token_counter: TokenCounter | None = None,
compactable_after_messages: int | None = None,
):
"""Initialize the sliding window conversation manager.

@@ -54,19 +58,45 @@
manage message history and prevent the agent loop from slowing down. Start with
per_turn=True and adjust to a specific frequency (e.g., per_turn=5) if needed
for performance tuning.
max_context_tokens: Optional maximum token budget for the conversation context.
When set, the manager checks both message count and estimated token count,
trimming oldest messages when the budget is exceeded. Uses the configured
``token_counter`` heuristic (chars/4 by default). Note: when both
``max_context_tokens`` and ``window_size`` are set, either limit can
independently trigger context reduction.
token_counter: Optional custom token counting function. Takes a Messages list
and returns an integer token count. When not provided, the built-in
``estimate_tokens`` heuristic (chars/4) is used.
compactable_after_messages: Optional message age after which tool results are
replaced with a short stub (``[Tool result cleared — re-run if needed]``).
This reclaims token budget from stale, re-runnable tool output while
preserving the toolUse/toolResult pair structure required by model APIs.

Raises:
ValueError: If per_turn is 0 or a negative integer.
ValueError: If per_turn is 0 or a negative integer, or if max_context_tokens or
compactable_after_messages is not a positive integer.
"""
if isinstance(per_turn, int) and not isinstance(per_turn, bool) and per_turn <= 0:
raise ValueError(f"per_turn must be a positive integer, True, or False, got {per_turn}")

if max_context_tokens is not None and max_context_tokens <= 0:
raise ValueError(f"max_context_tokens must be a positive integer, got {max_context_tokens}")

if compactable_after_messages is not None and compactable_after_messages <= 0:
raise ValueError(
f"compactable_after_messages must be a positive integer, got {compactable_after_messages}"
)

super().__init__()

self.window_size = window_size
self.should_truncate_results = should_truncate_results
self.per_turn = per_turn
self.max_context_tokens = max_context_tokens
self.token_counter: TokenCounter = token_counter or estimate_tokens
self.compactable_after_messages = compactable_after_messages
self._model_call_count = 0
self._last_compacted_index = 0

def register_hooks(self, registry: "HookRegistry", **kwargs: Any) -> None:
"""Register hook callbacks for per-turn conversation management.
@@ -77,34 +107,44 @@ def register_hooks(self, registry: "HookRegistry", **kwargs: Any) -> None:
"""
super().register_hooks(registry, **kwargs)

# Always register the callback - per_turn check happens in the callback
# Always register — per_turn and max_context_tokens checks happen in the callback
registry.add_callback(BeforeModelCallEvent, self._on_before_model_call)

def _on_before_model_call(self, event: BeforeModelCallEvent) -> None:
"""Handle before model call event for per-turn management.
"""Handle before model call event for per-turn management and token budget enforcement.

This callback is invoked before each model call. It tracks the model call count and applies message management
based on the per_turn configuration.
This callback is invoked before each model call. It applies management when either
the token budget is exceeded or per-turn management is due. A single
``apply_management`` call handles both token budget and message count limits, so
at most one call is made per hook invocation.

Args:
event: The before model call event containing the agent and model execution details.
"""
# Check if per_turn is enabled
if self.per_turn is False:
return
needs_apply = False

if self.max_context_tokens is not None:
current_tokens = self._get_current_token_count(event.agent)
if current_tokens > self.max_context_tokens:
logger.debug(
"current_tokens=<%d>, max_context_tokens=<%d> | token budget exceeded",
current_tokens,
self.max_context_tokens,
)
needs_apply = True

self._model_call_count += 1
if self.per_turn is not False:
self._model_call_count += 1

# Determine if we should apply management
should_apply = False
if self.per_turn is True:
should_apply = True
elif isinstance(self.per_turn, int) and self.per_turn > 0:
should_apply = self._model_call_count % self.per_turn == 0
if self.per_turn is True:
needs_apply = True
elif isinstance(self.per_turn, int) and self.per_turn > 0:
if self._model_call_count % self.per_turn == 0:
needs_apply = True

if should_apply:
if needs_apply:
logger.debug(
"model_call_count=<%d>, per_turn=<%s> | applying per-turn conversation management",
"model_call_count=<%d>, per_turn=<%s> | applying conversation management",
self._model_call_count,
self.per_turn,
)
@@ -118,6 +158,7 @@ def get_state(self) -> dict[str, Any]:
"""
state = super().get_state()
state["model_call_count"] = self._model_call_count
state["last_compacted_index"] = self._last_compacted_index
return state

def restore_from_session(self, state: dict[str, Any]) -> list | None:
@@ -131,13 +172,15 @@ def restore_from_session(self, state: dict[str, Any]) -> list | None:
"""
result = super().restore_from_session(state)
self._model_call_count = state.get("model_call_count", 0)
self._last_compacted_index = state.get("last_compacted_index", 0)
return result
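
# Illustrative persistence round-trip (a sketch; the construction arguments
# are assumptions):
#
#     state = manager.get_state()  # includes "model_call_count" and "last_compacted_index"
#     restored = SlidingWindowConversationManager(window_size=manager.window_size)
#     restored.restore_from_session(state)  # counters survive into the new instance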

def apply_management(self, agent: "Agent", **kwargs: Any) -> None:
"""Apply the sliding window to the agent's messages array to maintain a manageable history size.

This method is called after every event loop cycle to apply a sliding window if the message count
exceeds the window size.
This method is called after every event loop cycle. It applies micro-compaction for stale tool
results (if configured), then loops ``reduce_context`` until both message count and token budget
limits are satisfied (or no further reduction is possible).

Args:
agent: The agent whose messages will be managed.
@@ -146,12 +189,36 @@
"""
messages = agent.messages

if len(messages) <= self.window_size:
logger.debug(
"message_count=<%s>, window_size=<%s> | skipping context reduction", len(messages), self.window_size
# Micro-compact stale tool results before checking limits
if self.compactable_after_messages is not None:
self._micro_compact(messages)

# Bound by len(messages): each iteration is expected to remove at least one
# message (or truncate tool results); the no-progress guard below catches stalls.
max_iterations = len(messages)
for _ in range(max_iterations):
over_message_limit = len(messages) > self.window_size
over_token_limit = (
self.max_context_tokens is not None and self._get_current_token_count(agent) > self.max_context_tokens
)
return
self.reduce_context(agent)

if not over_message_limit and not over_token_limit:
logger.debug(
"message_count=<%s>, window_size=<%s> | context within limits",
len(messages),
self.window_size,
)
return

prev_len = len(messages)
self.reduce_context(agent)
if len(messages) >= prev_len:
logger.warning(
"message_count=<%s>, window_size=<%s> | reduce_context made no progress, stopping",
len(messages),
self.window_size,
)
return
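
# Illustrative behavior sketch (hypothetical numbers): either limit alone can
# trigger reduction, and both are re-checked after every pass.
#
#     manager = SlidingWindowConversationManager(window_size=3, max_context_tokens=100)
#     manager.apply_management(agent)
#     # loops reduce_context until len(agent.messages) <= 3 AND the estimated
#     # token count is <= 100, or until no further progress can be made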

def reduce_context(self, agent: "Agent", e: Exception | None = None, **kwargs: Any) -> None:
"""Trim the oldest messages to reduce the conversation context size.
@@ -229,9 +296,77 @@ def reduce_context(self, agent: "Agent", e: Exception | None = None, **kwargs: Any) -> None:
# trim_index represents the number of messages being removed from the agents messages array
self.removed_message_count += trim_index

# Adjust compaction tracking index
self._last_compacted_index = max(0, self._last_compacted_index - trim_index)

# Overwrite message history
messages[:] = messages[trim_index:]

def _get_current_token_count(self, agent: "Agent") -> int:
"""Estimate the current token count for the conversation context.

Always uses the configured ``token_counter`` heuristic rather than model-reported
``latest_context_size``, because the model-reported value reflects the *previous*
cycle and becomes stale after any reduction — leading to over-reduction spirals.

Args:
agent: The agent whose context size is being measured.

Returns:
The estimated token count.
"""
return self.token_counter(agent.messages)

_COMPACT_STUB = "[Tool result cleared — re-run if needed]"

def _micro_compact(self, messages: Messages) -> int:
"""Replace old tool results with compact stubs to reclaim token budget.

Tool results older than ``compactable_after_messages`` messages from the end of the
conversation are replaced with a short stub. The toolUse/toolResult pair structure
is preserved — only the content within toolResult blocks is replaced.

Tracks ``_last_compacted_index`` to skip already-processed messages on subsequent calls.

Args:
messages: The conversation message history (modified in-place).

Returns:
Estimated number of tokens reclaimed.
"""
if self.compactable_after_messages is None:
return 0

# NOTE (M1): Clamp index in case messages were externally replaced or shortened
# between calls (e.g., manual agent.messages reset, session restore mismatch).
self._last_compacted_index = min(self._last_compacted_index, len(messages))

reclaimed_chars = 0
cutoff = len(messages) - self.compactable_after_messages

# NOTE (M2): reclaimed_chars may overcount if text was already truncated by
# _truncate_tool_results — the return value is an estimate, not used for decisions.
stub_len = len(self._COMPACT_STUB)
for i in range(self._last_compacted_index, max(0, cutoff)):
msg = messages[i]
for block in msg.get("content", []):
if "toolResult" not in block:
continue
result = block["toolResult"]
items = result.get("content", [])
for j, item in enumerate(items):
if "text" in item and item["text"] != self._COMPACT_STUB:
reclaimed_chars += max(0, len(item["text"]) - stub_len)
items[j] = {"text": self._COMPACT_STUB}
elif "image" in item:
reclaimed_chars += IMAGE_CHAR_ESTIMATE
items[j] = {"text": self._COMPACT_STUB}

if cutoff > 0:
self._last_compacted_index = max(self._last_compacted_index, cutoff)

return reclaimed_chars // 4
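
# Illustrative before/after for micro-compaction (block shapes abbreviated):
#
#     # before: a stale toolResult holding a large command output
#     {"toolResult": {"content": [{"text": "<thousands of chars of output>"}], ...}}
#
#     # after: only the inner content is replaced; the toolUse/toolResult
#     # pairing required by model APIs stays intact
#     {"toolResult": {"content": [{"text": "[Tool result cleared — re-run if needed]"}], ...}}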

def _truncate_tool_results(self, messages: Messages, msg_idx: int) -> bool:
"""Truncate tool results and replace image blocks in a message to reduce context size.
