68 changes: 60 additions & 8 deletions src/strands/models/anthropic.py
@@ -16,7 +16,8 @@

from ..event_loop.streaming import process_stream
from ..tools.structured_output.structured_output_utils import convert_pydantic_to_tool_spec
from ..types.content import ContentBlock, Messages
from ..types.content import ContentBlock, Messages, SystemContentBlock
from ..types.event_loop import Usage
from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException
from ..types.streaming import StreamEvent
from ..types.tools import ToolChoice, ToolChoiceToolDict, ToolSpec
@@ -201,12 +202,38 @@ def _format_request_messages(self, messages: Messages) -> list[dict[str, Any]]:

return formatted_messages

@staticmethod
def _format_system_prompt_content(
system_prompt_content: list[SystemContentBlock],
) -> list[dict[str, Any]]:
"""Convert system prompt content blocks to Anthropic list-form system array.

A ``cachePoint`` block attaches ``cache_control: {"type": "ephemeral"}`` to
the immediately preceding text block, mirroring the convention already used
by ``_format_request_messages``. This lets callers mark the static prefix of
the system prompt as cacheable while leaving dynamic suffixes uncached.

Args:
system_prompt_content: System prompt content blocks.

Returns:
Anthropic list-form system array.
"""
formatted: list[dict[str, Any]] = []
for block in system_prompt_content:
if "text" in block:
formatted.append({"type": "text", "text": block["text"]})
elif "cachePoint" in block and formatted and formatted[-1].get("type") == "text":
formatted[-1]["cache_control"] = {"type": "ephemeral"}
return formatted

def format_request(
self,
messages: Messages,
tool_specs: list[ToolSpec] | None = None,
system_prompt: str | None = None,
tool_choice: ToolChoice | None = None,
system_prompt_content: list[SystemContentBlock] | None = None,
) -> dict[str, Any]:
"""Format an Anthropic streaming request.

@@ -215,6 +242,9 @@ def format_request(
tool_specs: List of tool specifications to make available to the model.
system_prompt: System prompt to provide context to the model.
tool_choice: Selection strategy for tool invocation.
system_prompt_content: System prompt content blocks. When provided, takes
precedence over ``system_prompt`` and enables prompt caching via
``cachePoint`` blocks translated to ``cache_control: ephemeral``.

Returns:
An Anthropic streaming request.
@@ -223,6 +253,12 @@
TypeError: If a message contains a content block type that cannot be converted to an Anthropic-compatible
format.
"""
system_field: str | list[dict[str, Any]] | None = None
if system_prompt_content:
system_field = self._format_system_prompt_content(system_prompt_content) or None
elif system_prompt:
system_field = system_prompt

return {
"max_tokens": self.config["max_tokens"],
"messages": self._format_request_messages(messages),
@@ -236,7 +272,7 @@
for tool_spec in tool_specs or []
],
**(self._format_tool_choice(tool_choice)),
**({"system": system_prompt} if system_prompt else {}),
**({"system": system_field} if system_field else {}),
**(self.config.get("params") or {}),
}

@@ -354,14 +390,20 @@ def format_chunk(self, event: dict[str, Any]) -> StreamEvent:

case "metadata":
usage = event["usage"]
usage_out: Usage = {
"inputTokens": usage["input_tokens"],
"outputTokens": usage["output_tokens"],
"totalTokens": usage["input_tokens"] + usage["output_tokens"],
}
cache_read = usage.get("cache_read_input_tokens") or 0
cache_write = usage.get("cache_creation_input_tokens") or 0
if cache_read or cache_write:
usage_out["cacheReadInputTokens"] = cache_read
usage_out["cacheWriteInputTokens"] = cache_write

return {
"metadata": {
"usage": {
"inputTokens": usage["input_tokens"],
"outputTokens": usage["output_tokens"],
"totalTokens": usage["input_tokens"] + usage["output_tokens"],
},
"usage": usage_out,
"metrics": {
"latencyMs": 0, # TODO
},
@@ -379,6 +421,7 @@ async def stream(
system_prompt: str | None = None,
*,
tool_choice: ToolChoice | None = None,
system_prompt_content: list[SystemContentBlock] | None = None,
**kwargs: Any,
) -> AsyncGenerator[StreamEvent, None]:
"""Stream conversation with the Anthropic model.
@@ -388,6 +431,9 @@
tool_specs: List of tool specifications to make available to the model.
system_prompt: System prompt to provide context to the model.
tool_choice: Selection strategy for tool invocation.
system_prompt_content: System prompt content blocks. When provided, takes
precedence over ``system_prompt`` and enables prompt caching via
``cachePoint`` blocks translated to ``cache_control: ephemeral``.
**kwargs: Additional keyword arguments for future extensibility.

Yields:
@@ -398,7 +444,13 @@
ModelThrottledException: If the request is throttled by Anthropic.
"""
logger.debug("formatting request")
request = self.format_request(messages, tool_specs, system_prompt, tool_choice)
request = self.format_request(
messages,
tool_specs,
system_prompt,
tool_choice,
system_prompt_content=system_prompt_content,
)
logger.debug("request=<%s>", request)

logger.debug("invoking model")
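Note on the change above: the new _format_system_prompt_content helper is the core of the feature. As a hedged, standalone sketch (the format_system name and the example values are illustrative, not part of the library), the translation it performs looks like this:

from typing import Any

def format_system(blocks: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Mirror _format_system_prompt_content: text/cachePoint blocks -> Anthropic list-form system."""
    formatted: list[dict[str, Any]] = []
    for block in blocks:
        if "text" in block:
            formatted.append({"type": "text", "text": block["text"]})
        elif "cachePoint" in block and formatted and formatted[-1].get("type") == "text":
            # A cachePoint marks the immediately preceding text block as an ephemeral cache prefix.
            formatted[-1]["cache_control"] = {"type": "ephemeral"}
    return formatted

print(format_system([{"text": "static prefix"}, {"cachePoint": {"type": "default"}}, {"text": "dynamic suffix"}]))
# [{'type': 'text', 'text': 'static prefix', 'cache_control': {'type': 'ephemeral'}},
#  {'type': 'text', 'text': 'dynamic suffix'}]

A cachePoint with no preceding text block is dropped, so an all-cachePoint list formats to an empty system array, which format_request then treats as no system field at all.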
98 changes: 98 additions & 0 deletions tests/strands/models/test_anthropic.py
@@ -418,6 +418,56 @@ def test_format_request_with_cache_point(model, model_id, max_tokens):
assert tru_request == exp_request


def test_format_request_with_system_prompt_content_cache_point(model, messages, model_id, max_tokens):
"""cachePoint in system_prompt_content emits Anthropic list-form system with cache_control."""
system_prompt_content = [
{"text": "static prefix"},
{"cachePoint": {"type": "default"}},
{"text": "dynamic suffix"},
]

tru_request = model.format_request(messages, system_prompt_content=system_prompt_content)
exp_request = {
"max_tokens": max_tokens,
"messages": [{"role": "user", "content": [{"type": "text", "text": "test"}]}],
"model": model_id,
"system": [
{"type": "text", "text": "static prefix", "cache_control": {"type": "ephemeral"}},
{"type": "text", "text": "dynamic suffix"},
],
"tools": [],
}

assert tru_request == exp_request


def test_format_request_with_system_prompt_content_no_cache_point(model, messages, model_id, max_tokens):
"""system_prompt_content with only text blocks emits list-form system without cache_control."""
system_prompt_content = [{"text": "plain system"}]

tru_request = model.format_request(messages, system_prompt_content=system_prompt_content)
exp_request = {
"max_tokens": max_tokens,
"messages": [{"role": "user", "content": [{"type": "text", "text": "test"}]}],
"model": model_id,
"system": [{"type": "text", "text": "plain system"}],
"tools": [],
}

assert tru_request == exp_request


def test_format_request_system_prompt_content_precedes_system_prompt(model, messages, model_id, max_tokens):
"""system_prompt_content takes precedence over system_prompt when both are supplied."""
tru_request = model.format_request(
messages,
system_prompt="ignored",
system_prompt_content=[{"text": "used"}],
)

assert tru_request["system"] == [{"type": "text", "text": "used"}]


def test_format_request_with_empty_content(model, model_id, max_tokens):
messages = [
{
@@ -703,6 +753,54 @@ def test_format_chunk_metadata(model):
assert tru_chunk == exp_chunk


def test_format_chunk_metadata_with_cache_tokens(model):
event = {
"type": "metadata",
"usage": {
"input_tokens": 10,
"output_tokens": 5,
"cache_read_input_tokens": 100,
"cache_creation_input_tokens": 200,
},
}

tru_chunk = model.format_chunk(event)
exp_chunk = {
"metadata": {
"usage": {
"inputTokens": 10,
"outputTokens": 5,
"totalTokens": 15,
"cacheReadInputTokens": 100,
"cacheWriteInputTokens": 200,
},
"metrics": {
"latencyMs": 0,
},
},
}

assert tru_chunk == exp_chunk


def test_format_chunk_metadata_without_cache_tokens_unchanged(model):
"""When cache fields are absent or zero the usage shape is unchanged."""
event = {
"type": "metadata",
"usage": {
"input_tokens": 1,
"output_tokens": 2,
"cache_read_input_tokens": 0,
"cache_creation_input_tokens": 0,
},
}

tru_chunk = model.format_chunk(event)

assert "cacheReadInputTokens" not in tru_chunk["metadata"]["usage"]
assert "cacheWriteInputTokens" not in tru_chunk["metadata"]["usage"]


def test_format_chunk_unknown(model):
event = {"type": "unknown"}

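The two metadata tests above pin down the usage mapping added in format_chunk. As a hedged standalone sketch (the map_usage name is illustrative; the field names are taken from the diff), the mapping is:

from typing import Any

def map_usage(usage: dict[str, Any]) -> dict[str, int]:
    """Mirror the metadata branch of format_chunk: Anthropic usage counters -> Strands usage keys."""
    out = {
        "inputTokens": usage["input_tokens"],
        "outputTokens": usage["output_tokens"],
        "totalTokens": usage["input_tokens"] + usage["output_tokens"],
    }
    cache_read = usage.get("cache_read_input_tokens") or 0
    cache_write = usage.get("cache_creation_input_tokens") or 0
    if cache_read or cache_write:
        # Cache counters are only surfaced when the provider reports a non-zero value,
        # so the existing usage shape is unchanged for non-caching requests.
        out["cacheReadInputTokens"] = cache_read
        out["cacheWriteInputTokens"] = cache_write
    return out

print(map_usage({"input_tokens": 10, "output_tokens": 5, "cache_read_input_tokens": 100, "cache_creation_input_tokens": 200}))
# {'inputTokens': 10, 'outputTokens': 5, 'totalTokens': 15, 'cacheReadInputTokens': 100, 'cacheWriteInputTokens': 200}

Note that totalTokens only sums input and output tokens; the cache read/write counters are reported alongside it rather than folded into the total.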