diff --git a/sagemaker-core/src/sagemaker/core/__init__.py b/sagemaker-core/src/sagemaker/core/__init__.py
index 97192083a7..f25f18009d 100644
--- a/sagemaker-core/src/sagemaker/core/__init__.py
+++ b/sagemaker-core/src/sagemaker/core/__init__.py
@@ -15,5 +15,8 @@
 # Partner App
 from sagemaker.core.partner_app.auth_provider import PartnerAppAuthProvider  # noqa: F401
 
+# Attribution
+from sagemaker.core.telemetry.attribution import Attribution, set_attribution  # noqa: F401
+
 # Note: HyperparameterTuner and WarmStartTypes are in sagemaker.train.tuner
 # They are not re-exported from core to avoid circular dependencies
diff --git a/sagemaker-core/src/sagemaker/core/telemetry/attribution.py b/sagemaker-core/src/sagemaker/core/telemetry/attribution.py
new file mode 100644
index 0000000000..1ba016f434
--- /dev/null
+++ b/sagemaker-core/src/sagemaker/core/telemetry/attribution.py
@@ -0,0 +1,41 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+"""Attribution module for tracking the provenance of SDK usage."""
+from __future__ import absolute_import
+import os
+from enum import Enum
+
+_CREATED_BY_ENV_VAR = "SAGEMAKER_PYSDK_CREATED_BY"
+
+
+class Attribution(Enum):
+    """Enumeration of known SDK attribution sources."""
+
+    SAGEMAKER_AGENT_PLUGIN = "awslabs/agent-plugins/sagemaker-ai"
+
+
+def set_attribution(attribution: Attribution):
+    """Store the attribution source in the SAGEMAKER_PYSDK_CREATED_BY environment variable.
+
+    Call this at the top of scripts generated by an agent or integration
+    to enable accurate telemetry attribution.
+
+    Args:
+        attribution (Attribution): The attribution source to set.
+
+    Raises:
+        TypeError: If attribution is not an Attribution enum member.
+    """
+    if not isinstance(attribution, Attribution):
+        raise TypeError(f"attribution must be an Attribution enum member, got {type(attribution)}")
+    os.environ[_CREATED_BY_ENV_VAR] = attribution.value
diff --git a/sagemaker-core/src/sagemaker/core/telemetry/telemetry_logging.py b/sagemaker-core/src/sagemaker/core/telemetry/telemetry_logging.py
index 6388cdfe5e..45475b563b 100644
--- a/sagemaker-core/src/sagemaker/core/telemetry/telemetry_logging.py
+++ b/sagemaker-core/src/sagemaker/core/telemetry/telemetry_logging.py
@@ -13,15 +13,18 @@
 """Telemetry module for SageMaker Python SDK to collect usage data and metrics."""
 from __future__ import absolute_import
 import logging
+import os
 import platform
 import sys
 from time import perf_counter
 from typing import List
 import functools
 import requests
+from urllib.parse import quote
 
 import boto3
 from sagemaker.core.helper.session_helper import Session
+from sagemaker.core.telemetry.attribution import _CREATED_BY_ENV_VAR
 from sagemaker.core.common_utils import resolve_value_from_config
 from sagemaker.core.config.config_schema import TELEMETRY_OPT_OUT_PATH
 from sagemaker.core.telemetry.constants import (
@@ -137,6 +140,11 @@ def wrapper(*args, **kwargs):
                 if hasattr(sagemaker_session, "endpoint_arn") and sagemaker_session.endpoint_arn:
                     extra += f"&x-endpointArn={sagemaker_session.endpoint_arn}"
 
+                # Append x-createdBy (percent-encoded) when the env var is set and non-empty
+                created_by = os.environ.get(_CREATED_BY_ENV_VAR, "")
+                if created_by:
+                    extra += f"&x-createdBy={quote(created_by, safe='')}"
+
                 start_timer = perf_counter()
                 try:
                     # Call the original function
diff --git a/sagemaker-core/tests/unit/telemetry/test_attribution.py b/sagemaker-core/tests/unit/telemetry/test_attribution.py
new file mode 100644
index 0000000000..bd7cc3a907
--- /dev/null
+++ b/sagemaker-core/tests/unit/telemetry/test_attribution.py
@@ -0,0 +1,42 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+from __future__ import absolute_import
+import os
+from unittest.mock import patch
+import pytest
+from sagemaker.core.telemetry.attribution import (
+    _CREATED_BY_ENV_VAR,
+    Attribution,
+    set_attribution,
+)
+
+
+@pytest.fixture(autouse=True)
+def clean_env():
+    """Run each test with the variable unset and restore the original environment afterwards."""
+    with patch.dict(os.environ):
+        os.environ.pop(_CREATED_BY_ENV_VAR, None)
+        yield
+
+
+def test_set_attribution_sagemaker_agent_plugin():
+    """set_attribution stores the enum member's value in the environment variable."""
+    set_attribution(Attribution.SAGEMAKER_AGENT_PLUGIN)
+    assert os.environ[_CREATED_BY_ENV_VAR] == Attribution.SAGEMAKER_AGENT_PLUGIN.value
+
+
+def test_set_attribution_invalid_type_raises():
+    """A plain string is rejected with TypeError and the variable stays unset."""
+    with pytest.raises(TypeError):
+        set_attribution("awslabs/agent-plugins/sagemaker-ai")
+    assert _CREATED_BY_ENV_VAR not in os.environ
diff --git a/sagemaker-core/tests/unit/telemetry/test_telemetry_logging.py b/sagemaker-core/tests/unit/telemetry/test_telemetry_logging.py
index 41af998d73..2973b12c44 100644
--- a/sagemaker-core/tests/unit/telemetry/test_telemetry_logging.py
+++ b/sagemaker-core/tests/unit/telemetry/test_telemetry_logging.py
@@ -11,6 +11,7 @@
 # ANY KIND, either express or implied. See the License for the specific
 # language governing permissions and limitations under the License.
 from __future__ import absolute_import
+import os
 import unittest
 import pytest
 import requests
@@ -18,6 +19,7 @@
 import boto3
 import sagemaker
 from sagemaker.core.telemetry.constants import Feature
+from sagemaker.core.telemetry.attribution import _CREATED_BY_ENV_VAR
 from sagemaker.core.telemetry.telemetry_logging import (
     _send_telemetry_request,
     _telemetry_emitter,
@@ -33,16 +35,23 @@
 
 # Try to import sagemaker-serve exceptions, skip tests if not available
 try:
-    from sagemaker.serve.utils.exceptions import ModelBuilderException, LocalModelOutOfMemoryException
+    from sagemaker.serve.utils.exceptions import (
+        ModelBuilderException,
+        LocalModelOutOfMemoryException,
+    )
+
     SAGEMAKER_SERVE_AVAILABLE = True
 except ImportError:
     SAGEMAKER_SERVE_AVAILABLE = False
+
     # Create mock exceptions for type hints
     class ModelBuilderException(Exception):
         pass
+
     class LocalModelOutOfMemoryException(Exception):
         pass
 
+
 MOCK_SESSION = Mock()
 MOCK_EXCEPTION = LocalModelOutOfMemoryException("mock raise ex")
 MOCK_FEATURE = Feature.SDK_DEFAULTS
@@ -158,10 +167,7 @@ def test_telemetry_emitter_decorator_success(
             1, [11, 12], MOCK_SESSION, None, None, expected_extra_str
         )
 
-    @pytest.mark.skipif(
-        not SAGEMAKER_SERVE_AVAILABLE,
-        reason="Requires sagemaker-serve package"
-    )
+    @pytest.mark.skipif(not SAGEMAKER_SERVE_AVAILABLE, reason="Requires sagemaker-serve package")
     @patch("sagemaker.core.telemetry.telemetry_logging._send_telemetry_request")
     @patch("sagemaker.core.telemetry.telemetry_logging.resolve_value_from_config")
     def test_telemetry_emitter_decorator_handle_exception_success(
@@ -194,7 +200,7 @@ def test_telemetry_emitter_decorator_handle_exception_success(
 
         mock_send_telemetry_request.assert_called_once_with(
             0,
-            [1, 2],
+            [11, 12],
             MOCK_SESSION,
             str(mock_exception_obj),
             mock_exception_obj.__class__.__name__,
@@ -357,3 +363,116 @@ def test_send_telemetry_request_invalid_region(self, mock_get_region, mock_get_a
             _send_telemetry_request(1, [1, 2], mock_session)
             # Assert telemetry request was not sent
             mock_requests_helper.assert_not_called()
+
+    @patch.dict(os.environ, {_CREATED_BY_ENV_VAR: "awslabs/agent-plugins/sagemaker-ai"})
+    @patch("sagemaker.core.telemetry.telemetry_logging._send_telemetry_request")
+    @patch("sagemaker.core.telemetry.telemetry_logging.resolve_value_from_config")
+    def test_telemetry_emitter_with_created_by_env_var(
+        self, mock_resolve_config, mock_send_telemetry_request
+    ):
+        """Test that x-createdBy is included when SAGEMAKER_PYSDK_CREATED_BY env var is set"""
+        mock_resolve_config.return_value = False
+
+        mock_local_client = LocalSagemakerClientMock()
+        mock_local_client.mock_create_model()
+
+        args = mock_send_telemetry_request.call_args.args
+        extra_str = str(args[5])
+
+        # Verify x-createdBy is in the extra string with URL encoding
+        self.assertIn("x-createdBy=awslabs%2Fagent-plugins%2Fsagemaker-ai", extra_str)
+
+        # Verify forward slashes are encoded as %2F
+        self.assertNotIn("x-createdBy=awslabs/agent-plugins", extra_str)
+
+    @patch.dict(os.environ)
+    @patch("sagemaker.core.telemetry.telemetry_logging._send_telemetry_request")
+    @patch("sagemaker.core.telemetry.telemetry_logging.resolve_value_from_config")
+    def test_telemetry_emitter_without_created_by_env_var(
+        self, mock_resolve_config, mock_send_telemetry_request
+    ):
+        """Test that x-createdBy is NOT included when env var is not set"""
+        mock_resolve_config.return_value = False
+
+        # Ensure the variable is unset; patch.dict restores the original environment afterwards
+        os.environ.pop(_CREATED_BY_ENV_VAR, None)
+
+        mock_local_client = LocalSagemakerClientMock()
+        mock_local_client.mock_create_model()
+
+        args = mock_send_telemetry_request.call_args.args
+        extra_str = str(args[5])
+
+        # Verify x-createdBy is NOT in the extra string
+        self.assertNotIn("x-createdBy", extra_str)
+
+    @patch.dict(os.environ, {_CREATED_BY_ENV_VAR: "My App & Tools (v2.0)"})
+    @patch("sagemaker.core.telemetry.telemetry_logging._send_telemetry_request")
+    @patch("sagemaker.core.telemetry.telemetry_logging.resolve_value_from_config")
+    def test_telemetry_emitter_created_by_with_special_chars(
+        self, mock_resolve_config, mock_send_telemetry_request
+    ):
+        """Test that x-createdBy properly URL-encodes special characters"""
+        mock_resolve_config.return_value = False
+
+        mock_local_client = LocalSagemakerClientMock()
+        mock_local_client.mock_create_model()
+
+        args = mock_send_telemetry_request.call_args.args
+        extra_str = str(args[5])
+
+        # Verify special characters are URL-encoded
+        self.assertIn("x-createdBy=My%20App%20%26%20Tools%20%28v2.0%29", extra_str)
+
+        # Verify raw special characters are NOT in the URL
+        self.assertNotIn("My App & Tools", extra_str)
+        self.assertNotIn("(v2.0)", extra_str)
+
+    @patch.dict(os.environ, {_CREATED_BY_ENV_VAR: ""})
+    @patch("sagemaker.core.telemetry.telemetry_logging._send_telemetry_request")
+    @patch("sagemaker.core.telemetry.telemetry_logging.resolve_value_from_config")
+    def test_telemetry_emitter_created_by_empty_string(
+        self, mock_resolve_config, mock_send_telemetry_request
+    ):
+        """Test that x-createdBy is NOT included when env var is empty string"""
+        mock_resolve_config.return_value = False
+
+        mock_local_client = LocalSagemakerClientMock()
+        mock_local_client.mock_create_model()
+
+        args = mock_send_telemetry_request.call_args.args
+        extra_str = str(args[5])
+
+        # Verify x-createdBy is NOT added for empty string
+        self.assertNotIn("x-createdBy", extra_str)
+
+    def test_construct_url_with_created_by(self):
+        """Test URL construction includes x-createdBy in extra_info"""
+        mock_accountId = "123456789012"
+        mock_region = "us-west-2"
+        mock_status = "1"
+        mock_feature = "15"
+        mock_extra_info = (
+            "DataSet.create&x-sdkVersion=3.0&x-createdBy=awslabs%2Fagent-plugins%2Fsagemaker-ai"
+        )
+
+        url = _construct_url(
+            accountId=mock_accountId,
+            region=mock_region,
+            status=mock_status,
+            feature=mock_feature,
+            failure_reason=None,
+            failure_type=None,
+            extra_info=mock_extra_info,
+        )
+
+        expected_url = (
+            f"https://sm-pysdk-t-{mock_region}.s3.{mock_region}.amazonaws.com/telemetry?"
+            f"x-accountId={mock_accountId}"
+            f"&x-status={mock_status}"
+            f"&x-feature={mock_feature}"
+            f"&x-extra={mock_extra_info}"
+        )
+
+        self.assertEqual(url, expected_url)
+        self.assertIn("x-createdBy=awslabs%2Fagent-plugins%2Fsagemaker-ai", url)