diff --git a/README.md b/README.md index 173adc006..ad81f0248 100644 --- a/README.md +++ b/README.md @@ -55,7 +55,7 @@ agent = Agent(tools=[calculator]) agent("What is the square root of 1764") ``` -> **Note**: For the default Amazon Bedrock model provider, you'll need AWS credentials configured and model access enabled for Claude 4 Sonnet in the us-west-2 region. See the [Quickstart Guide](https://strandsagents.com/) for details on configuring other model providers. +> **Note**: For the default Amazon Bedrock model provider, you'll need AWS credentials configured and model access enabled for Claude Sonnet 4.6 via the `global.anthropic.claude-sonnet-4-6` cross-region inference profile. See the [Quickstart Guide](https://strandsagents.com/) for details on configuring other model providers. ## Installation diff --git a/src/strands/models/bedrock.py b/src/strands/models/bedrock.py index 7ff3024a8..e04cb4b38 100644 --- a/src/strands/models/bedrock.py +++ b/src/strands/models/bedrock.py @@ -35,9 +35,7 @@ logger = logging.getLogger(__name__) -# See: `BedrockModel._get_default_model_with_warning` for why we need both -DEFAULT_BEDROCK_MODEL_ID = "us.anthropic.claude-sonnet-4-20250514-v1:0" -_DEFAULT_BEDROCK_MODEL_ID = "{}.anthropic.claude-sonnet-4-20250514-v1:0" +DEFAULT_BEDROCK_MODEL_ID = "global.anthropic.claude-sonnet-4-6" DEFAULT_BEDROCK_REGION = "us-west-2" BEDROCK_CONTEXT_WINDOW_OVERFLOW_MESSAGES = [ @@ -90,7 +88,7 @@ class BedrockConfig(BaseModelConfig, total=False): guardrail_latest_message: Flag to send only the lastest user message to guardrails. Defaults to False. max_tokens: Maximum number of tokens to generate in the response - model_id: The Bedrock model ID (e.g., "us.anthropic.claude-sonnet-4-20250514-v1:0") + model_id: The Bedrock model ID (e.g., "global.anthropic.claude-sonnet-4-6") include_tool_result_status: Flag to include status field in tool results. True includes status, False removes status, "auto" determines based on model_id. Defaults to "auto". service_tier: Service tier for the request, controlling the trade-off between latency and cost. @@ -151,12 +149,16 @@ def __init__( session = boto_session or boto3.Session() resolved_region = region_name or session.region_name or os.environ.get("AWS_REGION") or DEFAULT_BEDROCK_REGION - self.config = BedrockModel.BedrockConfig( - model_id=BedrockModel._get_default_model_with_warning(resolved_region, model_config), - include_tool_result_status="auto", - ) + self.config = BedrockModel.BedrockConfig(model_id=DEFAULT_BEDROCK_MODEL_ID, include_tool_result_status="auto") self.update_config(**model_config) + if self.config.get("model_id") == DEFAULT_BEDROCK_MODEL_ID: + warnings.warn( + f"You're using default model '{DEFAULT_BEDROCK_MODEL_ID}', which is subject to change. " + "Specify a model explicitly to pin the model target.", + stacklevel=2, + ) + logger.debug("config=<%s> | initializing", self.config) # Add strands-agents to the request user agent @@ -1081,52 +1083,3 @@ async def structured_output( raise ValueError("No valid tool use or tool use input was found in the Bedrock response.") yield {"output": output_model(**output_response)} - - @staticmethod - def _get_default_model_with_warning(region_name: str, model_config: BedrockConfig | None = None) -> str: - """Get the default Bedrock modelId based on region. - - If the region is not **known** to support inference then we show a helpful warning - that compliments the exception that Bedrock will throw. - If the customer provided a model_id in their config or they overrode the `DEFAULT_BEDROCK_MODEL_ID` - then we should not process further. - - Args: - region_name (str): region for bedrock model - model_config (Optional[dict[str, Any]]): Model Config that caller passes in on init - """ - if DEFAULT_BEDROCK_MODEL_ID != _DEFAULT_BEDROCK_MODEL_ID.format("us"): - return DEFAULT_BEDROCK_MODEL_ID - - model_config = model_config or {} - if model_config.get("model_id"): - return model_config["model_id"] - - prefix_inference_map = {"ap": "apac"} # some inference endpoints can be a bit different than the region prefix - - prefix = "-".join(region_name.split("-")[:-2]).lower() # handles `us-east-1` or `us-gov-east-1` - if prefix not in {"us", "eu", "ap", "us-gov"}: - warnings.warn( - f""" - ================== WARNING ================== - - This region {region_name} does not support - our default inference endpoint: {_DEFAULT_BEDROCK_MODEL_ID.format(prefix)}. - Update the agent to pass in a 'model_id' like so: - ``` - Agent(..., model='valid_model_id', ...) - ```` - Documentation: https://docs.aws.amazon.com/bedrock/latest/userguide/inference-profiles-support.html - - ================================================== - """, - stacklevel=2, - ) - - default_model_id = _DEFAULT_BEDROCK_MODEL_ID.format(prefix_inference_map.get(prefix, prefix)) - warnings.warn( - f"You're using default model '{default_model_id}', which is subject to change. " - "Specify a model explicitly to pin the model target.", - stacklevel=2, - ) - return default_model_id diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index 1e27274a1..c05385b16 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -34,9 +34,6 @@ from tests.fixtures.mock_session_repository import MockedSessionRepository from tests.fixtures.mocked_model_provider import MockedModelProvider -# For unit testing we will use the the us inference -FORMATTED_DEFAULT_MODEL_ID = DEFAULT_BEDROCK_MODEL_ID.format("us") - @pytest.fixture def mock_model(request): @@ -244,7 +241,7 @@ def test_agent__init__with_default_model(): agent = Agent() assert isinstance(agent.model, BedrockModel) - assert agent.model.config["model_id"] == FORMATTED_DEFAULT_MODEL_ID + assert agent.model.config["model_id"] == DEFAULT_BEDROCK_MODEL_ID def test_agent__init__with_explicit_model(mock_model): @@ -854,7 +851,7 @@ def test_agent_tool_names(tools, agent): def test_agent_init_with_no_model_or_model_id(): agent = Agent() assert agent.model is not None - assert agent.model.get_config().get("model_id") == FORMATTED_DEFAULT_MODEL_ID + assert agent.model.get_config().get("model_id") == DEFAULT_BEDROCK_MODEL_ID def test_agent_with_none_callback_handler_prints_nothing(): diff --git a/tests/strands/models/test_bedrock.py b/tests/strands/models/test_bedrock.py index 384ee05e1..483ce4991 100644 --- a/tests/strands/models/test_bedrock.py +++ b/tests/strands/models/test_bedrock.py @@ -16,7 +16,6 @@ from strands import _exception_notes from strands.models import BedrockModel, CacheConfig from strands.models.bedrock import ( - _DEFAULT_BEDROCK_MODEL_ID, DEFAULT_BEDROCK_MODEL_ID, DEFAULT_BEDROCK_REGION, DEFAULT_READ_TIMEOUT, @@ -24,8 +23,6 @@ from strands.types.exceptions import ContextWindowOverflowException, ModelThrottledException from strands.types.tools import ToolSpec -FORMATTED_DEFAULT_MODEL_ID = DEFAULT_BEDROCK_MODEL_ID.format("us") - @pytest.fixture def session_cls(): @@ -132,7 +129,7 @@ def test__init__default_model_id(bedrock_client): model = BedrockModel() tru_model_id = model.get_config().get("model_id") - exp_model_id = FORMATTED_DEFAULT_MODEL_ID + exp_model_id = DEFAULT_BEDROCK_MODEL_ID assert tru_model_id == exp_model_id @@ -2165,84 +2162,24 @@ def test_tool_choice_none_no_warning(model, messages, captured_warnings): assert len(captured_warnings) == 0 -def test_get_default_model_with_warning_supported_regions_shows_no_warning(captured_warnings): - """Test get_model_prefix_with_warning doesn't warn for supported region prefixes.""" - BedrockModel._get_default_model_with_warning("us-west-2") - BedrockModel._get_default_model_with_warning("eu-west-2") - assert all("does not support" not in str(w.message) for w in captured_warnings) - - -def test_get_default_model_for_supported_eu_region_returns_correct_model_id(captured_warnings): - model_id = BedrockModel._get_default_model_with_warning("eu-west-1") - assert model_id == "eu.anthropic.claude-sonnet-4-20250514-v1:0" - assert all("does not support" not in str(w.message) for w in captured_warnings) - - -def test_get_default_model_for_supported_us_region_returns_correct_model_id(captured_warnings): - model_id = BedrockModel._get_default_model_with_warning("us-east-1") - assert model_id == "us.anthropic.claude-sonnet-4-20250514-v1:0" - assert all("does not support" not in str(w.message) for w in captured_warnings) - - -def test_get_default_model_for_supported_gov_region_returns_correct_model_id(captured_warnings): - model_id = BedrockModel._get_default_model_with_warning("us-gov-west-1") - assert model_id == "us-gov.anthropic.claude-sonnet-4-20250514-v1:0" - assert all("does not support" not in str(w.message) for w in captured_warnings) - - -def test_get_model_prefix_for_ap_region_converts_to_apac_endpoint(captured_warnings): - """Test _get_default_model_with_warning warns for APAC regions since 'ap' is not in supported prefixes.""" - model_id = BedrockModel._get_default_model_with_warning("ap-southeast-1") - assert model_id == "apac.anthropic.claude-sonnet-4-20250514-v1:0" - - -def test_get_default_model_with_warning_unsupported_region_warns(captured_warnings): - """Test _get_default_model_with_warning warns for unsupported regions.""" - BedrockModel._get_default_model_with_warning("ca-central-1") - region_warnings = [w for w in captured_warnings if "does not support" in str(w.message)] - assert len(region_warnings) == 1 - assert "This region ca-central-1 does not support" in str(region_warnings[0].message) - assert "our default inference endpoint" in str(region_warnings[0].message) - - -def test_get_default_model_with_warning_no_warning_with_custom_model_id(captured_warnings): - """Test _get_default_model_with_warning doesn't warn when custom model_id provided.""" - model_config = {"model_id": "custom-model"} - model_id = BedrockModel._get_default_model_with_warning("ca-central-1", model_config) - - assert model_id == "custom-model" - assert len(captured_warnings) == 0 - - -def test_init_with_unsupported_region_warns(session_cls, captured_warnings): - """Test BedrockModel initialization warns for unsupported regions.""" - BedrockModel(region_name="ca-central-1") - - region_warnings = [w for w in captured_warnings if "does not support" in str(w.message)] - assert len(region_warnings) == 1 - assert "This region ca-central-1 does not support" in str(region_warnings[0].message) - - -def test_init_with_unsupported_region_custom_model_no_warning(session_cls, captured_warnings): - """Test BedrockModel initialization doesn't warn when custom model_id provided.""" - BedrockModel(region_name="ca-central-1", model_id="custom-model") - assert len(captured_warnings) == 0 +def test_init_with_default_model_warns_subject_to_change(session_cls, captured_warnings): + """Test BedrockModel initialization warns when relying on the default model id.""" + BedrockModel() + default_warnings = [w for w in captured_warnings if "which is subject to change" in str(w.message)] + assert len(default_warnings) == 1 + assert f"You're using default model '{DEFAULT_BEDROCK_MODEL_ID}'" in str(default_warnings[0].message) -def test_override_default_model_id_uses_the_overriden_value(captured_warnings): - with unittest.mock.patch("strands.models.bedrock.DEFAULT_BEDROCK_MODEL_ID", "custom-overridden-model"): - model_id = BedrockModel._get_default_model_with_warning("us-east-1") - assert model_id == "custom-overridden-model" +def test_init_with_custom_model_id_no_default_warning(session_cls, captured_warnings): + """Test BedrockModel initialization doesn't emit the default-model warning when model_id is provided.""" + BedrockModel(model_id="custom-model") -def test_no_override_uses_formatted_default_model_id(captured_warnings): - model_id = BedrockModel._get_default_model_with_warning("us-east-1") - assert model_id == "us.anthropic.claude-sonnet-4-20250514-v1:0" - assert model_id != _DEFAULT_BEDROCK_MODEL_ID - assert all("does not support" not in str(w.message) for w in captured_warnings) + default_warnings = [w for w in captured_warnings if "which is subject to change" in str(w.message)] + assert len(default_warnings) == 0 -def test_custom_model_id_not_overridden_by_region_formatting(session_cls): +def test_custom_model_id_not_overridden(session_cls): """Test that custom model_id is not overridden by region formatting.""" custom_model_id = "custom.model.id"