From ffe5001ff0ec20bb3f7d2bff03ec2693ac6dcc31 Mon Sep 17 00:00:00 2001 From: Ryan Tanaka Date: Fri, 20 Mar 2026 00:56:54 -0700 Subject: [PATCH 1/3] feature: add telemetry attribution module for SDK usage provenance --- sagemaker-core/src/sagemaker/core/__init__.py | 3 + .../sagemaker/core/telemetry/attribution.py | 41 +++++ .../core/telemetry/telemetry_logging.py | 8 + .../tests/unit/telemetry/test_attribution.py | 37 +++++ .../unit/telemetry/test_telemetry_logging.py | 150 +++++++++++++++++- 5 files changed, 233 insertions(+), 6 deletions(-) create mode 100644 sagemaker-core/src/sagemaker/core/telemetry/attribution.py create mode 100644 sagemaker-core/tests/unit/telemetry/test_attribution.py diff --git a/sagemaker-core/src/sagemaker/core/__init__.py b/sagemaker-core/src/sagemaker/core/__init__.py index 97192083a7..f25f18009d 100644 --- a/sagemaker-core/src/sagemaker/core/__init__.py +++ b/sagemaker-core/src/sagemaker/core/__init__.py @@ -15,5 +15,8 @@ # Partner App from sagemaker.core.partner_app.auth_provider import PartnerAppAuthProvider # noqa: F401 +# Attribution +from sagemaker.core.telemetry.attribution import Attribution, set_attribution # noqa: F401 + # Note: HyperparameterTuner and WarmStartTypes are in sagemaker.train.tuner # They are not re-exported from core to avoid circular dependencies diff --git a/sagemaker-core/src/sagemaker/core/telemetry/attribution.py b/sagemaker-core/src/sagemaker/core/telemetry/attribution.py new file mode 100644 index 0000000000..1ba016f434 --- /dev/null +++ b/sagemaker-core/src/sagemaker/core/telemetry/attribution.py @@ -0,0 +1,41 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Attribution module for tracking the provenance of SDK usage.""" +from __future__ import absolute_import +import os +from enum import Enum + +_CREATED_BY_ENV_VAR = "SAGEMAKER_PYSDK_CREATED_BY" + + +class Attribution(Enum): + """Enumeration of known SDK attribution sources.""" + + SAGEMAKER_AGENT_PLUGIN = "awslabs/agent-plugins/sagemaker-ai" + + +def set_attribution(attribution: Attribution): + """Sets the SDK usage attribution to the specified source. + + Call this at the top of scripts generated by an agent or integration + to enable accurate telemetry attribution. + + Args: + attribution (Attribution): The attribution source to set. + + Raises: + TypeError: If attribution is not an Attribution enum member. + """ + if not isinstance(attribution, Attribution): + raise TypeError(f"attribution must be an Attribution enum member, got {type(attribution)}") + os.environ[_CREATED_BY_ENV_VAR] = attribution.value diff --git a/sagemaker-core/src/sagemaker/core/telemetry/telemetry_logging.py b/sagemaker-core/src/sagemaker/core/telemetry/telemetry_logging.py index 6388cdfe5e..45475b563b 100644 --- a/sagemaker-core/src/sagemaker/core/telemetry/telemetry_logging.py +++ b/sagemaker-core/src/sagemaker/core/telemetry/telemetry_logging.py @@ -13,15 +13,18 @@ """Telemetry module for SageMaker Python SDK to collect usage data and metrics.""" from __future__ import absolute_import import logging +import os import platform import sys from time import perf_counter from typing import List import functools import requests +from urllib.parse import quote import boto3 from sagemaker.core.helper.session_helper import Session +from sagemaker.core.telemetry.attribution import _CREATED_BY_ENV_VAR from sagemaker.core.common_utils import resolve_value_from_config from sagemaker.core.config.config_schema import TELEMETRY_OPT_OUT_PATH from sagemaker.core.telemetry.constants import ( @@ -137,6 +140,11 @@ def wrapper(*args, **kwargs): if hasattr(sagemaker_session, "endpoint_arn") and sagemaker_session.endpoint_arn: extra += f"&x-endpointArn={sagemaker_session.endpoint_arn}" + # Add created_by from environment variable if available + created_by = os.environ.get(_CREATED_BY_ENV_VAR, "") + if created_by: + extra += f"&x-createdBy={quote(created_by, safe='')}" + start_timer = perf_counter() try: # Call the original function diff --git a/sagemaker-core/tests/unit/telemetry/test_attribution.py b/sagemaker-core/tests/unit/telemetry/test_attribution.py new file mode 100644 index 0000000000..bd7cc3a907 --- /dev/null +++ b/sagemaker-core/tests/unit/telemetry/test_attribution.py @@ -0,0 +1,37 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import +import os +import pytest +from sagemaker.core.telemetry.attribution import ( + _CREATED_BY_ENV_VAR, + Attribution, + set_attribution, +) + + +@pytest.fixture(autouse=True) +def clean_env(): + yield + if _CREATED_BY_ENV_VAR in os.environ: + del os.environ[_CREATED_BY_ENV_VAR] + + +def test_set_attribution_sagemaker_agent_plugin(): + set_attribution(Attribution.SAGEMAKER_AGENT_PLUGIN) + assert os.environ[_CREATED_BY_ENV_VAR] == Attribution.SAGEMAKER_AGENT_PLUGIN.value + + +def test_set_attribution_invalid_type_raises(): + with pytest.raises(TypeError): + set_attribution("awslabs/agent-plugins/sagemaker-ai") diff --git a/sagemaker-core/tests/unit/telemetry/test_telemetry_logging.py b/sagemaker-core/tests/unit/telemetry/test_telemetry_logging.py index 41af998d73..2973b12c44 100644 --- a/sagemaker-core/tests/unit/telemetry/test_telemetry_logging.py +++ b/sagemaker-core/tests/unit/telemetry/test_telemetry_logging.py @@ -11,6 +11,7 @@ # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. from __future__ import absolute_import +import os import unittest import pytest import requests @@ -18,6 +19,7 @@ import boto3 import sagemaker from sagemaker.core.telemetry.constants import Feature +from sagemaker.core.telemetry.attribution import _CREATED_BY_ENV_VAR from sagemaker.core.telemetry.telemetry_logging import ( _send_telemetry_request, _telemetry_emitter, @@ -33,16 +35,23 @@ # Try to import sagemaker-serve exceptions, skip tests if not available try: - from sagemaker.serve.utils.exceptions import ModelBuilderException, LocalModelOutOfMemoryException + from sagemaker.serve.utils.exceptions import ( + ModelBuilderException, + LocalModelOutOfMemoryException, + ) + SAGEMAKER_SERVE_AVAILABLE = True except ImportError: SAGEMAKER_SERVE_AVAILABLE = False + # Create mock exceptions for type hints class ModelBuilderException(Exception): pass + class LocalModelOutOfMemoryException(Exception): pass + MOCK_SESSION = Mock() MOCK_EXCEPTION = LocalModelOutOfMemoryException("mock raise ex") MOCK_FEATURE = Feature.SDK_DEFAULTS @@ -158,10 +167,7 @@ def test_telemetry_emitter_decorator_success( 1, [11, 12], MOCK_SESSION, None, None, expected_extra_str ) - @pytest.mark.skipif( - not SAGEMAKER_SERVE_AVAILABLE, - reason="Requires sagemaker-serve package" - ) + @pytest.mark.skipif(not SAGEMAKER_SERVE_AVAILABLE, reason="Requires sagemaker-serve package") @patch("sagemaker.core.telemetry.telemetry_logging._send_telemetry_request") @patch("sagemaker.core.telemetry.telemetry_logging.resolve_value_from_config") def test_telemetry_emitter_decorator_handle_exception_success( @@ -194,7 +200,7 @@ def test_telemetry_emitter_decorator_handle_exception_success( mock_send_telemetry_request.assert_called_once_with( 0, - [1, 2], + [11, 12], MOCK_SESSION, str(mock_exception_obj), mock_exception_obj.__class__.__name__, @@ -357,3 +363,135 @@ def test_send_telemetry_request_invalid_region(self, mock_get_region, mock_get_a _send_telemetry_request(1, [1, 2], mock_session) # Assert telemetry request was not sent mock_requests_helper.assert_not_called() + + @patch("sagemaker.core.telemetry.telemetry_logging._send_telemetry_request") + @patch("sagemaker.core.telemetry.telemetry_logging.resolve_value_from_config") + def test_telemetry_emitter_with_created_by_env_var( + self, mock_resolve_config, mock_send_telemetry_request + ): + """Test that x-createdBy is included when SAGEMAKER_PYSDK_CREATED_BY env var is set""" + mock_resolve_config.return_value = False + + # Set environment variable + os.environ[_CREATED_BY_ENV_VAR] = "awslabs/agent-plugins/sagemaker-ai" + + try: + mock_local_client = LocalSagemakerClientMock() + mock_local_client.mock_create_model() + + args = mock_send_telemetry_request.call_args.args + extra_str = str(args[5]) + + # Verify x-createdBy is in the extra string with URL encoding + self.assertIn("x-createdBy=awslabs%2Fagent-plugins%2Fsagemaker-ai", extra_str) + + # Verify forward slashes are encoded as %2F + self.assertNotIn("x-createdBy=awslabs/agent-plugins", extra_str) + finally: + # Clean up environment variable + if _CREATED_BY_ENV_VAR in os.environ: + del os.environ[_CREATED_BY_ENV_VAR] + + @patch("sagemaker.core.telemetry.telemetry_logging._send_telemetry_request") + @patch("sagemaker.core.telemetry.telemetry_logging.resolve_value_from_config") + def test_telemetry_emitter_without_created_by_env_var( + self, mock_resolve_config, mock_send_telemetry_request + ): + """Test that x-createdBy is NOT included when env var is not set""" + mock_resolve_config.return_value = False + + # Ensure environment variable is not set + if _CREATED_BY_ENV_VAR in os.environ: + del os.environ[_CREATED_BY_ENV_VAR] + + mock_local_client = LocalSagemakerClientMock() + mock_local_client.mock_create_model() + + args = mock_send_telemetry_request.call_args.args + extra_str = str(args[5]) + + # Verify x-createdBy is NOT in the extra string + self.assertNotIn("x-createdBy", extra_str) + + @patch("sagemaker.core.telemetry.telemetry_logging._send_telemetry_request") + @patch("sagemaker.core.telemetry.telemetry_logging.resolve_value_from_config") + def test_telemetry_emitter_created_by_with_special_chars( + self, mock_resolve_config, mock_send_telemetry_request + ): + """Test that x-createdBy properly URL-encodes special characters""" + mock_resolve_config.return_value = False + + # Set environment variable with special characters + os.environ[_CREATED_BY_ENV_VAR] = "My App & Tools (v2.0)" + + try: + mock_local_client = LocalSagemakerClientMock() + mock_local_client.mock_create_model() + + args = mock_send_telemetry_request.call_args.args + extra_str = str(args[5]) + + # Verify special characters are URL-encoded + self.assertIn("x-createdBy=My%20App%20%26%20Tools%20%28v2.0%29", extra_str) + + # Verify raw special characters are NOT in the URL + self.assertNotIn("My App & Tools", extra_str) + self.assertNotIn("(v2.0)", extra_str) + finally: + if _CREATED_BY_ENV_VAR in os.environ: + del os.environ[_CREATED_BY_ENV_VAR] + + @patch("sagemaker.core.telemetry.telemetry_logging._send_telemetry_request") + @patch("sagemaker.core.telemetry.telemetry_logging.resolve_value_from_config") + def test_telemetry_emitter_created_by_empty_string( + self, mock_resolve_config, mock_send_telemetry_request + ): + """Test that x-createdBy is NOT included when env var is empty string""" + mock_resolve_config.return_value = False + + # Set environment variable to empty string + os.environ[_CREATED_BY_ENV_VAR] = "" + + try: + mock_local_client = LocalSagemakerClientMock() + mock_local_client.mock_create_model() + + args = mock_send_telemetry_request.call_args.args + extra_str = str(args[5]) + + # Verify x-createdBy is NOT added for empty string + self.assertNotIn("x-createdBy", extra_str) + finally: + if _CREATED_BY_ENV_VAR in os.environ: + del os.environ[_CREATED_BY_ENV_VAR] + + def test_construct_url_with_created_by(self): + """Test URL construction includes x-createdBy in extra_info""" + mock_accountId = "123456789012" + mock_region = "us-west-2" + mock_status = "1" + mock_feature = "15" + mock_extra_info = ( + "DataSet.create&x-sdkVersion=3.0&x-createdBy=awslabs%2Fagent-plugins%2Fsagemaker-ai" + ) + + url = _construct_url( + accountId=mock_accountId, + region=mock_region, + status=mock_status, + feature=mock_feature, + failure_reason=None, + failure_type=None, + extra_info=mock_extra_info, + ) + + expected_url = ( + f"https://sm-pysdk-t-{mock_region}.s3.{mock_region}.amazonaws.com/telemetry?" + f"x-accountId={mock_accountId}" + f"&x-status={mock_status}" + f"&x-feature={mock_feature}" + f"&x-extra={mock_extra_info}" + ) + + self.assertEqual(url, expected_url) + self.assertIn("x-createdBy=awslabs%2Fagent-plugins%2Fsagemaker-ai", url) From 392881f40ce4735b255c6e9a863ec5bf97d7c761 Mon Sep 17 00:00:00 2001 From: Ryan Tanaka Date: Mon, 23 Mar 2026 01:16:17 -0700 Subject: [PATCH 2/3] feature: add TrainingJob ARN to telemetry for training jobs and fixed bug with telemetry not being sent for *Trainer.train() if sagemaker_session is not provided --- .../core/telemetry/resource_creation.py | 47 ++++++++++++ .../core/telemetry/telemetry_logging.py | 8 +- .../unit/telemetry/test_resource_creation.py | 74 +++++++++++++++++++ .../unit/telemetry/test_telemetry_logging.py | 47 ++++++++++++ 4 files changed, 175 insertions(+), 1 deletion(-) create mode 100644 sagemaker-core/src/sagemaker/core/telemetry/resource_creation.py create mode 100644 sagemaker-core/tests/unit/telemetry/test_resource_creation.py diff --git a/sagemaker-core/src/sagemaker/core/telemetry/resource_creation.py b/sagemaker-core/src/sagemaker/core/telemetry/resource_creation.py new file mode 100644 index 0000000000..c73d138004 --- /dev/null +++ b/sagemaker-core/src/sagemaker/core/telemetry/resource_creation.py @@ -0,0 +1,47 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Resource creation module for tracking ARNs of resources created via SDK calls.""" +from __future__ import absolute_import + +# Maps class name (string) to the attribute name holding the resource ARN. +# String-based keys avoid cross-package imports and circular dependencies. +_RESOURCE_ARN_ATTRIBUTES = { + "TrainingJob": "training_job_arn", +} + + +def get_resource_arn(response): + """Extract the ARN from a SDK response object if available. + + Uses string-based type name lookup to avoid cross-package imports. + + Args: + response: The return value of a _telemetry_emitter-decorated function. + + Returns: + str: The ARN string if available, otherwise None. + """ + if response is None: + return None + + arn_attr = _RESOURCE_ARN_ATTRIBUTES.get(type(response).__name__) + if not arn_attr: + return None + + arn = getattr(response, arn_attr, None) + + # Guard against Unassigned sentinel used in resources.py + if not arn or type(arn).__name__ == "Unassigned": + return None + + return str(arn) diff --git a/sagemaker-core/src/sagemaker/core/telemetry/telemetry_logging.py b/sagemaker-core/src/sagemaker/core/telemetry/telemetry_logging.py index 45475b563b..ad54e95014 100644 --- a/sagemaker-core/src/sagemaker/core/telemetry/telemetry_logging.py +++ b/sagemaker-core/src/sagemaker/core/telemetry/telemetry_logging.py @@ -25,6 +25,7 @@ import boto3 from sagemaker.core.helper.session_helper import Session from sagemaker.core.telemetry.attribution import _CREATED_BY_ENV_VAR +from sagemaker.core.telemetry.resource_creation import get_resource_arn from sagemaker.core.common_utils import resolve_value_from_config from sagemaker.core.config.config_schema import TELEMETRY_OPT_OUT_PATH from sagemaker.core.telemetry.constants import ( @@ -84,7 +85,7 @@ def wrapper(*args, **kwargs): sagemaker_session = None if len(args) > 0 and hasattr(args[0], "sagemaker_session"): # Get the sagemaker_session from the instance method args - sagemaker_session = args[0].sagemaker_session + sagemaker_session = args[0].sagemaker_session or _get_default_sagemaker_session() elif len(args) > 0 and hasattr(args[0], "_sagemaker_session"): # Get the sagemaker_session from the instance method args (private attribute) sagemaker_session = args[0]._sagemaker_session @@ -152,6 +153,11 @@ def wrapper(*args, **kwargs): stop_timer = perf_counter() elapsed = stop_timer - start_timer extra += f"&x-latency={round(elapsed, 2)}" + # For specified response types (e.g., TrainingJob), obtain the ARN of the + # resource created if present so that it can be included. + resource_arn = get_resource_arn(response) + if resource_arn: + extra += f"&x-resourceArn={resource_arn}" if not telemetry_opt_out_flag: _send_telemetry_request( STATUS_TO_CODE[str(Status.SUCCESS)], diff --git a/sagemaker-core/tests/unit/telemetry/test_resource_creation.py b/sagemaker-core/tests/unit/telemetry/test_resource_creation.py new file mode 100644 index 0000000000..d3ff58a0f8 --- /dev/null +++ b/sagemaker-core/tests/unit/telemetry/test_resource_creation.py @@ -0,0 +1,74 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import +import pytest +from unittest.mock import MagicMock +from sagemaker.core.utils.utils import Unassigned +from sagemaker.core.telemetry.resource_creation import _RESOURCE_ARN_ATTRIBUTES, get_resource_arn + + +# Each entry: (class_name, arn_attr, arn_value) +_RESOURCE_TEST_CASES = [ + ( + "TrainingJob", + "training_job_arn", + "arn:aws:sagemaker:us-west-2:123456789012:training-job/my-job", + ), +] + + +def test_get_resource_arn_none_response(): + assert get_resource_arn(None) is None + + +def test_get_resource_arn_unknown_type(): + assert get_resource_arn("some string") is None + assert get_resource_arn(42) is None + + +@pytest.mark.parametrize("class_name,arn_attr,arn_value", _RESOURCE_TEST_CASES) +def test_get_resource_arn_with_valid_arn(class_name, arn_attr, arn_value): + mock_resource = MagicMock() + mock_resource.__class__.__name__ = class_name + setattr(mock_resource, arn_attr, arn_value) + assert get_resource_arn(mock_resource) == arn_value + + +@pytest.mark.parametrize("class_name,arn_attr,arn_value", _RESOURCE_TEST_CASES) +def test_get_resource_arn_with_unassigned(class_name, arn_attr, arn_value): + mock_resource = MagicMock() + mock_resource.__class__.__name__ = class_name + setattr(mock_resource, arn_attr, Unassigned()) + assert get_resource_arn(mock_resource) is None + + +@pytest.mark.parametrize("class_name,arn_attr,arn_value", _RESOURCE_TEST_CASES) +def test_get_resource_arn_with_none_arn(class_name, arn_attr, arn_value): + mock_resource = MagicMock() + mock_resource.__class__.__name__ = class_name + setattr(mock_resource, arn_attr, None) + assert get_resource_arn(mock_resource) is None + + +# Verify string keys in _RESOURCE_ARN_ATTRIBUTES match actual class names +@pytest.mark.parametrize("class_name,arn_attr,arn_value", _RESOURCE_TEST_CASES) +def test_resource_class_name_matches_dict_key(class_name, arn_attr, arn_value): + from sagemaker.core.resources import TrainingJob + + _CLASS_MAP = { + "TrainingJob": TrainingJob, + } + cls = _CLASS_MAP.get(class_name) + assert cls is not None, f"No class found for key '{class_name}'" + assert cls.__name__ == class_name + assert class_name in _RESOURCE_ARN_ATTRIBUTES diff --git a/sagemaker-core/tests/unit/telemetry/test_telemetry_logging.py b/sagemaker-core/tests/unit/telemetry/test_telemetry_logging.py index 2973b12c44..6ce1cb3269 100644 --- a/sagemaker-core/tests/unit/telemetry/test_telemetry_logging.py +++ b/sagemaker-core/tests/unit/telemetry/test_telemetry_logging.py @@ -495,3 +495,50 @@ def test_construct_url_with_created_by(self): self.assertEqual(url, expected_url) self.assertIn("x-createdBy=awslabs%2Fagent-plugins%2Fsagemaker-ai", url) + + + @patch("sagemaker.core.telemetry.telemetry_logging._send_telemetry_request") + @patch("sagemaker.core.telemetry.telemetry_logging.resolve_value_from_config") + def test_telemetry_emitter_with_resource_arn( + self, mock_resolve_config, mock_send_telemetry_request + ): + """Test that x-resourceArn is included when decorated function returns a TrainingJob.""" + mock_resolve_config.return_value = False + + mock_training_job = Mock() + mock_training_job.__class__.__name__ = "TrainingJob" + mock_training_job.training_job_arn = ( + "arn:aws:sagemaker:us-west-2:123456789012:training-job/my-job" + ) + + class TrainingJobReturningMock: + def __init__(self): + self.sagemaker_session = MOCK_SESSION + + @_telemetry_emitter(MOCK_FEATURE, MOCK_FUNC_NAME) + def mock_train(self): + return mock_training_job + + TrainingJobReturningMock().mock_train() + + args = mock_send_telemetry_request.call_args.args + extra_str = str(args[5]) + self.assertIn( + "x-resourceArn=arn:aws:sagemaker:us-west-2:123456789012:training-job/my-job", + extra_str, + ) + + @patch("sagemaker.core.telemetry.telemetry_logging._send_telemetry_request") + @patch("sagemaker.core.telemetry.telemetry_logging.resolve_value_from_config") + def test_telemetry_emitter_without_resource_arn( + self, mock_resolve_config, mock_send_telemetry_request + ): + """Test that x-resourceArn is NOT included when response has no registered ARN.""" + mock_resolve_config.return_value = False + + mock_local_client = LocalSagemakerClientMock() + mock_local_client.mock_create_model() + + args = mock_send_telemetry_request.call_args.args + extra_str = str(args[5]) + self.assertNotIn("x-resourceArn", extra_str) From ae2da621c90ef68652e88d10d1ad0a2934da0dd7 Mon Sep 17 00:00:00 2001 From: Ryan Tanaka Date: Mon, 23 Mar 2026 01:22:45 -0700 Subject: [PATCH 3/3] adding createdBy metadata to user agent string if attribution env var has been set to aid in resource attribution --- .../src/sagemaker/core/utils/user_agent.py | 29 +++++++++ .../tests/unit/generated/test_user_agent.py | 62 +++++++++++++++---- 2 files changed, 78 insertions(+), 13 deletions(-) diff --git a/sagemaker-core/src/sagemaker/core/utils/user_agent.py b/sagemaker-core/src/sagemaker/core/utils/user_agent.py index 4b5ee2cf24..2c24d0fb92 100644 --- a/sagemaker-core/src/sagemaker/core/utils/user_agent.py +++ b/sagemaker-core/src/sagemaker/core/utils/user_agent.py @@ -17,7 +17,31 @@ import importlib_metadata +from string import ascii_letters, digits + +from sagemaker.core.telemetry.attribution import _CREATED_BY_ENV_VAR + SagemakerCore_PREFIX = "AWS-SageMakerCore" + +_USERAGENT_ALLOWED_CHARACTERS = ascii_letters + digits + "!$%&'*+-.^_`|~," + + +def sanitize_user_agent_string_component(raw_str, allow_hash=False): + """Sanitize a User-Agent string component by replacing disallowed characters with '-'. + + Args: + raw_str (str): The input string to sanitize. + allow_hash (bool): Whether '#' is considered an allowed character. + + Returns: + str: The sanitized string. + """ + return "".join( + c if c in _USERAGENT_ALLOWED_CHARACTERS or (allow_hash and c == "#") else "-" + for c in raw_str + ) + + STUDIO_PREFIX = "AWS-SageMaker-Studio" NOTEBOOK_PREFIX = "AWS-SageMaker-Notebook-Instance" @@ -74,4 +98,9 @@ def get_user_agent_extra_suffix() -> str: if studio_app_type: suffix = "{} md/{}#{}".format(suffix, STUDIO_PREFIX, studio_app_type) + # Add created_by metadata if attribution has been set + created_by = os.environ.get(_CREATED_BY_ENV_VAR) + if created_by: + suffix = "{} md/{}#{}".format(suffix, "createdBy", sanitize_user_agent_string_component(created_by)) + return suffix diff --git a/sagemaker-core/tests/unit/generated/test_user_agent.py b/sagemaker-core/tests/unit/generated/test_user_agent.py index 53b9e6cf99..8ebb5721a9 100644 --- a/sagemaker-core/tests/unit/generated/test_user_agent.py +++ b/sagemaker-core/tests/unit/generated/test_user_agent.py @@ -1,21 +1,12 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. from __future__ import absolute_import import json +import os from mock import patch, mock_open +import pytest +from sagemaker.core.telemetry.attribution import _CREATED_BY_ENV_VAR from sagemaker.core.utils.user_agent import ( SagemakerCore_PREFIX, SagemakerCore_VERSION, @@ -24,8 +15,15 @@ process_notebook_metadata_file, process_studio_metadata_file, get_user_agent_extra_suffix, + sanitize_user_agent_string_component, ) -from sagemaker.core.utils.user_agent import SagemakerCore_PREFIX + + +@pytest.fixture(autouse=True) +def clean_env(): + yield + if _CREATED_BY_ENV_VAR in os.environ: + del os.environ[_CREATED_BY_ENV_VAR] # Test process_notebook_metadata_file function @@ -58,6 +56,27 @@ def test_process_studio_metadata_file_not_exists(tmp_path): assert process_studio_metadata_file() is None +# Test sanitize_user_agent_string_component function +def test_sanitize_replaces_slash_with_dash(): + assert sanitize_user_agent_string_component("awslabs/agent-plugins/sagemaker-ai") == "awslabs-agent-plugins-sagemaker-ai" + + +def test_sanitize_allows_alphanumeric(): + assert sanitize_user_agent_string_component("abc123") == "abc123" + + +def test_sanitize_replaces_hash_when_not_allowed(): + assert sanitize_user_agent_string_component("foo#bar") == "foo-bar" + + +def test_sanitize_allows_hash_when_permitted(): + assert sanitize_user_agent_string_component("foo#bar", allow_hash=True) == "foo#bar" + + +def test_sanitize_replaces_space_with_dash(): + assert sanitize_user_agent_string_component("foo bar") == "foo-bar" + + # Test get_user_agent_extra_suffix function def test_get_user_agent_extra_suffix(): assert get_user_agent_extra_suffix() == f"lib/{SagemakerCore_PREFIX}#{SagemakerCore_VERSION}" @@ -78,3 +97,20 @@ def test_get_user_agent_extra_suffix(): get_user_agent_extra_suffix() == f"lib/{SagemakerCore_PREFIX}#{SagemakerCore_VERSION} md/{STUDIO_PREFIX}#studio_type" ) + + +def test_get_user_agent_extra_suffix_without_created_by(): + suffix = get_user_agent_extra_suffix() + assert "createdBy" not in suffix + + +def test_get_user_agent_extra_suffix_with_created_by(): + os.environ[_CREATED_BY_ENV_VAR] = "awslabs/agent-plugins/sagemaker-ai" + suffix = get_user_agent_extra_suffix() + assert "md/createdBy#awslabs-agent-plugins-sagemaker-ai" in suffix + + +def test_get_user_agent_extra_suffix_created_by_sanitized(): + os.environ[_CREATED_BY_ENV_VAR] = "my agent/v1.0 (test)" + suffix = get_user_agent_extra_suffix() + assert "md/createdBy#my-agent-v1.0--test-" in suffix