Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ agent = Agent(tools=[calculator])
agent("What is the square root of 1764")
```

> **Note**: For the default Amazon Bedrock model provider, you'll need AWS credentials configured and model access enabled for Claude 4 Sonnet in the us-west-2 region. See the [Quickstart Guide](https://strandsagents.com/) for details on configuring other model providers.
> **Note**: For the default Amazon Bedrock model provider, you'll need AWS credentials configured and model access enabled for Claude Sonnet 4.6 via the `global.anthropic.claude-sonnet-4-6` cross-region inference profile. See the [Quickstart Guide](https://strandsagents.com/) for details on configuring other model providers.

## Installation

Expand Down
67 changes: 10 additions & 57 deletions src/strands/models/bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,7 @@

logger = logging.getLogger(__name__)

# See: `BedrockModel._get_default_model_with_warning` for why we need both
DEFAULT_BEDROCK_MODEL_ID = "us.anthropic.claude-sonnet-4-20250514-v1:0"
_DEFAULT_BEDROCK_MODEL_ID = "{}.anthropic.claude-sonnet-4-20250514-v1:0"
DEFAULT_BEDROCK_MODEL_ID = "global.anthropic.claude-sonnet-4-6"
DEFAULT_BEDROCK_REGION = "us-west-2"

BEDROCK_CONTEXT_WINDOW_OVERFLOW_MESSAGES = [
Expand Down Expand Up @@ -90,7 +88,7 @@ class BedrockConfig(BaseModelConfig, total=False):
guardrail_latest_message: Flag to send only the lastest user message to guardrails.
Defaults to False.
max_tokens: Maximum number of tokens to generate in the response
model_id: The Bedrock model ID (e.g., "us.anthropic.claude-sonnet-4-20250514-v1:0")
model_id: The Bedrock model ID (e.g., "global.anthropic.claude-sonnet-4-6")
include_tool_result_status: Flag to include status field in tool results.
True includes status, False removes status, "auto" determines based on model_id. Defaults to "auto".
service_tier: Service tier for the request, controlling the trade-off between latency and cost.
Expand Down Expand Up @@ -151,12 +149,16 @@ def __init__(

session = boto_session or boto3.Session()
resolved_region = region_name or session.region_name or os.environ.get("AWS_REGION") or DEFAULT_BEDROCK_REGION
self.config = BedrockModel.BedrockConfig(
model_id=BedrockModel._get_default_model_with_warning(resolved_region, model_config),
include_tool_result_status="auto",
)
self.config = BedrockModel.BedrockConfig(model_id=DEFAULT_BEDROCK_MODEL_ID, include_tool_result_status="auto")
self.update_config(**model_config)

if self.config.get("model_id") == DEFAULT_BEDROCK_MODEL_ID:
warnings.warn(
f"You're using default model '{DEFAULT_BEDROCK_MODEL_ID}', which is subject to change. "
"Specify a model explicitly to pin the model target.",
stacklevel=2,
)

logger.debug("config=<%s> | initializing", self.config)

# Add strands-agents to the request user agent
Expand Down Expand Up @@ -1081,52 +1083,3 @@ async def structured_output(
raise ValueError("No valid tool use or tool use input was found in the Bedrock response.")

yield {"output": output_model(**output_response)}

@staticmethod
def _get_default_model_with_warning(region_name: str, model_config: BedrockConfig | None = None) -> str:
"""Get the default Bedrock modelId based on region.

If the region is not **known** to support inference then we show a helpful warning
that compliments the exception that Bedrock will throw.
If the customer provided a model_id in their config or they overrode the `DEFAULT_BEDROCK_MODEL_ID`
then we should not process further.

Args:
region_name (str): region for bedrock model
model_config (Optional[dict[str, Any]]): Model Config that caller passes in on init
"""
if DEFAULT_BEDROCK_MODEL_ID != _DEFAULT_BEDROCK_MODEL_ID.format("us"):
return DEFAULT_BEDROCK_MODEL_ID

model_config = model_config or {}
if model_config.get("model_id"):
return model_config["model_id"]

prefix_inference_map = {"ap": "apac"} # some inference endpoints can be a bit different than the region prefix

prefix = "-".join(region_name.split("-")[:-2]).lower() # handles `us-east-1` or `us-gov-east-1`
if prefix not in {"us", "eu", "ap", "us-gov"}:
warnings.warn(
f"""
================== WARNING ==================

This region {region_name} does not support
our default inference endpoint: {_DEFAULT_BEDROCK_MODEL_ID.format(prefix)}.
Update the agent to pass in a 'model_id' like so:
```
Agent(..., model='valid_model_id', ...)
````
Documentation: https://docs.aws.amazon.com/bedrock/latest/userguide/inference-profiles-support.html

==================================================
""",
stacklevel=2,
)

default_model_id = _DEFAULT_BEDROCK_MODEL_ID.format(prefix_inference_map.get(prefix, prefix))
warnings.warn(
f"You're using default model '{default_model_id}', which is subject to change. "
"Specify a model explicitly to pin the model target.",
stacklevel=2,
)
return default_model_id
7 changes: 2 additions & 5 deletions tests/strands/agent/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,6 @@
from tests.fixtures.mock_session_repository import MockedSessionRepository
from tests.fixtures.mocked_model_provider import MockedModelProvider

# For unit testing we will use the the us inference
FORMATTED_DEFAULT_MODEL_ID = DEFAULT_BEDROCK_MODEL_ID.format("us")


@pytest.fixture
def mock_model(request):
Expand Down Expand Up @@ -244,7 +241,7 @@ def test_agent__init__with_default_model():
agent = Agent()

assert isinstance(agent.model, BedrockModel)
assert agent.model.config["model_id"] == FORMATTED_DEFAULT_MODEL_ID
assert agent.model.config["model_id"] == DEFAULT_BEDROCK_MODEL_ID


def test_agent__init__with_explicit_model(mock_model):
Expand Down Expand Up @@ -854,7 +851,7 @@ def test_agent_tool_names(tools, agent):
def test_agent_init_with_no_model_or_model_id():
agent = Agent()
assert agent.model is not None
assert agent.model.get_config().get("model_id") == FORMATTED_DEFAULT_MODEL_ID
assert agent.model.get_config().get("model_id") == DEFAULT_BEDROCK_MODEL_ID


def test_agent_with_none_callback_handler_prints_nothing():
Expand Down
89 changes: 13 additions & 76 deletions tests/strands/models/test_bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,13 @@
from strands import _exception_notes
from strands.models import BedrockModel, CacheConfig
from strands.models.bedrock import (
_DEFAULT_BEDROCK_MODEL_ID,
DEFAULT_BEDROCK_MODEL_ID,
DEFAULT_BEDROCK_REGION,
DEFAULT_READ_TIMEOUT,
)
from strands.types.exceptions import ContextWindowOverflowException, ModelThrottledException
from strands.types.tools import ToolSpec

FORMATTED_DEFAULT_MODEL_ID = DEFAULT_BEDROCK_MODEL_ID.format("us")


@pytest.fixture
def session_cls():
Expand Down Expand Up @@ -132,7 +129,7 @@ def test__init__default_model_id(bedrock_client):
model = BedrockModel()

tru_model_id = model.get_config().get("model_id")
exp_model_id = FORMATTED_DEFAULT_MODEL_ID
exp_model_id = DEFAULT_BEDROCK_MODEL_ID

assert tru_model_id == exp_model_id

Expand Down Expand Up @@ -2165,84 +2162,24 @@ def test_tool_choice_none_no_warning(model, messages, captured_warnings):
assert len(captured_warnings) == 0


def test_get_default_model_with_warning_supported_regions_shows_no_warning(captured_warnings):
"""Test get_model_prefix_with_warning doesn't warn for supported region prefixes."""
BedrockModel._get_default_model_with_warning("us-west-2")
BedrockModel._get_default_model_with_warning("eu-west-2")
assert all("does not support" not in str(w.message) for w in captured_warnings)


def test_get_default_model_for_supported_eu_region_returns_correct_model_id(captured_warnings):
model_id = BedrockModel._get_default_model_with_warning("eu-west-1")
assert model_id == "eu.anthropic.claude-sonnet-4-20250514-v1:0"
assert all("does not support" not in str(w.message) for w in captured_warnings)


def test_get_default_model_for_supported_us_region_returns_correct_model_id(captured_warnings):
model_id = BedrockModel._get_default_model_with_warning("us-east-1")
assert model_id == "us.anthropic.claude-sonnet-4-20250514-v1:0"
assert all("does not support" not in str(w.message) for w in captured_warnings)


def test_get_default_model_for_supported_gov_region_returns_correct_model_id(captured_warnings):
model_id = BedrockModel._get_default_model_with_warning("us-gov-west-1")
assert model_id == "us-gov.anthropic.claude-sonnet-4-20250514-v1:0"
assert all("does not support" not in str(w.message) for w in captured_warnings)


def test_get_model_prefix_for_ap_region_converts_to_apac_endpoint(captured_warnings):
"""Test _get_default_model_with_warning warns for APAC regions since 'ap' is not in supported prefixes."""
model_id = BedrockModel._get_default_model_with_warning("ap-southeast-1")
assert model_id == "apac.anthropic.claude-sonnet-4-20250514-v1:0"


def test_get_default_model_with_warning_unsupported_region_warns(captured_warnings):
"""Test _get_default_model_with_warning warns for unsupported regions."""
BedrockModel._get_default_model_with_warning("ca-central-1")
region_warnings = [w for w in captured_warnings if "does not support" in str(w.message)]
assert len(region_warnings) == 1
assert "This region ca-central-1 does not support" in str(region_warnings[0].message)
assert "our default inference endpoint" in str(region_warnings[0].message)


def test_get_default_model_with_warning_no_warning_with_custom_model_id(captured_warnings):
"""Test _get_default_model_with_warning doesn't warn when custom model_id provided."""
model_config = {"model_id": "custom-model"}
model_id = BedrockModel._get_default_model_with_warning("ca-central-1", model_config)

assert model_id == "custom-model"
assert len(captured_warnings) == 0


def test_init_with_unsupported_region_warns(session_cls, captured_warnings):
"""Test BedrockModel initialization warns for unsupported regions."""
BedrockModel(region_name="ca-central-1")

region_warnings = [w for w in captured_warnings if "does not support" in str(w.message)]
assert len(region_warnings) == 1
assert "This region ca-central-1 does not support" in str(region_warnings[0].message)


def test_init_with_unsupported_region_custom_model_no_warning(session_cls, captured_warnings):
"""Test BedrockModel initialization doesn't warn when custom model_id provided."""
BedrockModel(region_name="ca-central-1", model_id="custom-model")
assert len(captured_warnings) == 0
def test_init_with_default_model_warns_subject_to_change(session_cls, captured_warnings):
"""Test BedrockModel initialization warns when relying on the default model id."""
BedrockModel()

default_warnings = [w for w in captured_warnings if "which is subject to change" in str(w.message)]
assert len(default_warnings) == 1
assert f"You're using default model '{DEFAULT_BEDROCK_MODEL_ID}'" in str(default_warnings[0].message)

def test_override_default_model_id_uses_the_overriden_value(captured_warnings):
with unittest.mock.patch("strands.models.bedrock.DEFAULT_BEDROCK_MODEL_ID", "custom-overridden-model"):
model_id = BedrockModel._get_default_model_with_warning("us-east-1")
assert model_id == "custom-overridden-model"

def test_init_with_custom_model_id_no_default_warning(session_cls, captured_warnings):
"""Test BedrockModel initialization doesn't emit the default-model warning when model_id is provided."""
BedrockModel(model_id="custom-model")

def test_no_override_uses_formatted_default_model_id(captured_warnings):
model_id = BedrockModel._get_default_model_with_warning("us-east-1")
assert model_id == "us.anthropic.claude-sonnet-4-20250514-v1:0"
assert model_id != _DEFAULT_BEDROCK_MODEL_ID
assert all("does not support" not in str(w.message) for w in captured_warnings)
default_warnings = [w for w in captured_warnings if "which is subject to change" in str(w.message)]
assert len(default_warnings) == 0


def test_custom_model_id_not_overridden_by_region_formatting(session_cls):
def test_custom_model_id_not_overridden(session_cls):
"""Test that custom model_id is not overridden by region formatting."""
custom_model_id = "custom.model.id"

Expand Down