Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions sagemaker-core/src/sagemaker/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,8 @@
# Partner App
from sagemaker.core.partner_app.auth_provider import PartnerAppAuthProvider # noqa: F401

# Attribution
from sagemaker.core.telemetry.attribution import Attribution, set_attribution # noqa: F401

# Note: HyperparameterTuner and WarmStartTypes are in sagemaker.train.tuner
# They are not re-exported from core to avoid circular dependencies
41 changes: 41 additions & 0 deletions sagemaker-core/src/sagemaker/core/telemetry/attribution.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Attribution module for tracking the provenance of SDK usage."""
from __future__ import absolute_import
import os
from enum import Enum

_CREATED_BY_ENV_VAR = "SAGEMAKER_PYSDK_CREATED_BY"


class Attribution(Enum):
    """Known sources that SDK usage can be attributed to.

    Each member's value is the identifier string recorded for that source.
    """

    # Usage originating from the SageMaker AI agent plugin.
    SAGEMAKER_AGENT_PLUGIN = "awslabs/agent-plugins/sagemaker-ai"


def set_attribution(attribution: Attribution):
    """Record which known source this SDK usage should be attributed to.

    Intended to be invoked near the start of a script produced by an agent
    or integration so that telemetry can be attributed accurately. The
    member's value is stored in the process environment under the name held
    by ``_CREATED_BY_ENV_VAR``.

    Args:
        attribution (Attribution): The attribution source to record.

    Raises:
        TypeError: If ``attribution`` is not a member of ``Attribution``.
    """
    if isinstance(attribution, Attribution):
        os.environ[_CREATED_BY_ENV_VAR] = attribution.value
        return
    # Anything else (including the raw string value) is rejected.
    raise TypeError(f"attribution must be an Attribution enum member, got {type(attribution)}")
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,18 @@
"""Telemetry module for SageMaker Python SDK to collect usage data and metrics."""
from __future__ import absolute_import
import logging
import os
import platform
import sys
from time import perf_counter
from typing import List
import functools
import requests
from urllib.parse import quote

import boto3
from sagemaker.core.helper.session_helper import Session
from sagemaker.core.telemetry.attribution import _CREATED_BY_ENV_VAR
from sagemaker.core.common_utils import resolve_value_from_config
from sagemaker.core.config.config_schema import TELEMETRY_OPT_OUT_PATH
from sagemaker.core.telemetry.constants import (
Expand Down Expand Up @@ -137,6 +140,11 @@ def wrapper(*args, **kwargs):
if hasattr(sagemaker_session, "endpoint_arn") and sagemaker_session.endpoint_arn:
extra += f"&x-endpointArn={sagemaker_session.endpoint_arn}"

# Add created_by from environment variable if available
created_by = os.environ.get(_CREATED_BY_ENV_VAR, "")
if created_by:
extra += f"&x-createdBy={quote(created_by, safe='')}"

start_timer = perf_counter()
try:
# Call the original function
Expand Down
37 changes: 37 additions & 0 deletions sagemaker-core/tests/unit/telemetry/test_attribution.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from __future__ import absolute_import
import os
import pytest
from sagemaker.core.telemetry.attribution import (
_CREATED_BY_ENV_VAR,
Attribution,
set_attribution,
)


@pytest.fixture(autouse=True)
def clean_env():
    """Isolate every test in this module from the attribution environment variable.

    Before the test runs, any value already present in the environment is
    removed (and remembered) so the test starts from a known-clean state.
    After the test, whatever the test wrote is discarded and the remembered
    value, if there was one, is put back so the surrounding process is left
    exactly as it was found.
    """
    # pop() with a default never raises, whether or not the variable is set.
    previous = os.environ.pop(_CREATED_BY_ENV_VAR, None)
    yield
    # Teardown: drop anything the test set, then restore the prior value.
    os.environ.pop(_CREATED_BY_ENV_VAR, None)
    if previous is not None:
        os.environ[_CREATED_BY_ENV_VAR] = previous


def test_set_attribution_sagemaker_agent_plugin():
    """set_attribution stores the enum member's value in the environment variable."""
    source = Attribution.SAGEMAKER_AGENT_PLUGIN
    set_attribution(source)
    recorded = os.environ[_CREATED_BY_ENV_VAR]
    assert recorded == source.value


def test_set_attribution_invalid_type_raises():
    """A plain string, even one equal to a member's value, is rejected with TypeError."""
    raw_value = "awslabs/agent-plugins/sagemaker-ai"
    with pytest.raises(TypeError):
        set_attribution(raw_value)
150 changes: 144 additions & 6 deletions sagemaker-core/tests/unit/telemetry/test_telemetry_logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,15 @@
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from __future__ import absolute_import
import os
import unittest
import pytest
import requests
from unittest.mock import Mock, patch, MagicMock
import boto3
import sagemaker
from sagemaker.core.telemetry.constants import Feature
from sagemaker.core.telemetry.attribution import _CREATED_BY_ENV_VAR
from sagemaker.core.telemetry.telemetry_logging import (
_send_telemetry_request,
_telemetry_emitter,
Expand All @@ -33,16 +35,23 @@

# Try to import sagemaker-serve exceptions, skip tests if not available
try:
    # Single import of the optional sagemaker-serve exception types.
    from sagemaker.serve.utils.exceptions import (
        ModelBuilderException,
        LocalModelOutOfMemoryException,
    )

    SAGEMAKER_SERVE_AVAILABLE = True
except ImportError:
    # sagemaker-serve is not installed; tests that need it check this flag.
    SAGEMAKER_SERVE_AVAILABLE = False

    # Create mock exceptions for type hints
    class ModelBuilderException(Exception):
        """Stand-in defined when the sagemaker-serve import fails."""

    class LocalModelOutOfMemoryException(Exception):
        """Stand-in defined when the sagemaker-serve import fails."""


MOCK_SESSION = Mock()
MOCK_EXCEPTION = LocalModelOutOfMemoryException("mock raise ex")
MOCK_FEATURE = Feature.SDK_DEFAULTS
Expand Down Expand Up @@ -158,10 +167,7 @@ def test_telemetry_emitter_decorator_success(
1, [11, 12], MOCK_SESSION, None, None, expected_extra_str
)

@pytest.mark.skipif(
not SAGEMAKER_SERVE_AVAILABLE,
reason="Requires sagemaker-serve package"
)
@pytest.mark.skipif(not SAGEMAKER_SERVE_AVAILABLE, reason="Requires sagemaker-serve package")
@patch("sagemaker.core.telemetry.telemetry_logging._send_telemetry_request")
@patch("sagemaker.core.telemetry.telemetry_logging.resolve_value_from_config")
def test_telemetry_emitter_decorator_handle_exception_success(
Expand Down Expand Up @@ -194,7 +200,7 @@ def test_telemetry_emitter_decorator_handle_exception_success(

mock_send_telemetry_request.assert_called_once_with(
0,
[1, 2],
[11, 12],
MOCK_SESSION,
str(mock_exception_obj),
mock_exception_obj.__class__.__name__,
Expand Down Expand Up @@ -357,3 +363,135 @@ def test_send_telemetry_request_invalid_region(self, mock_get_region, mock_get_a
_send_telemetry_request(1, [1, 2], mock_session)
# Assert telemetry request was not sent
mock_requests_helper.assert_not_called()

@patch("sagemaker.core.telemetry.telemetry_logging._send_telemetry_request")
@patch("sagemaker.core.telemetry.telemetry_logging.resolve_value_from_config")
def test_telemetry_emitter_with_created_by_env_var(
    self, mock_resolve_config, mock_send_telemetry_request
):
    """Test that x-createdBy is included when SAGEMAKER_PYSDK_CREATED_BY env var is set"""
    mock_resolve_config.return_value = False

    # patch.dict puts os.environ back to its prior contents on exit, so a value
    # that was set before the test is restored rather than deleted.
    with patch.dict(os.environ, {_CREATED_BY_ENV_VAR: "awslabs/agent-plugins/sagemaker-ai"}):
        mock_local_client = LocalSagemakerClientMock()
        mock_local_client.mock_create_model()

    # Without this, a missing telemetry call surfaces as an AttributeError on None.
    mock_send_telemetry_request.assert_called()
    args = mock_send_telemetry_request.call_args.args
    extra_str = str(args[5])

    # Verify x-createdBy is in the extra string with URL encoding
    self.assertIn("x-createdBy=awslabs%2Fagent-plugins%2Fsagemaker-ai", extra_str)

    # Verify forward slashes are encoded as %2F
    self.assertNotIn("x-createdBy=awslabs/agent-plugins", extra_str)

@patch("sagemaker.core.telemetry.telemetry_logging._send_telemetry_request")
@patch("sagemaker.core.telemetry.telemetry_logging.resolve_value_from_config")
def test_telemetry_emitter_without_created_by_env_var(
    self, mock_resolve_config, mock_send_telemetry_request
):
    """Test that x-createdBy is NOT included when env var is not set"""
    mock_resolve_config.return_value = False

    # Remove the variable only for the duration of the call; patch.dict restores
    # os.environ afterwards so a value set outside the test is not lost.
    with patch.dict(os.environ):
        os.environ.pop(_CREATED_BY_ENV_VAR, None)

        mock_local_client = LocalSagemakerClientMock()
        mock_local_client.mock_create_model()

    # Without this, a missing telemetry call surfaces as an AttributeError on None.
    mock_send_telemetry_request.assert_called()
    args = mock_send_telemetry_request.call_args.args
    extra_str = str(args[5])

    # Verify x-createdBy is NOT in the extra string
    self.assertNotIn("x-createdBy", extra_str)

@patch("sagemaker.core.telemetry.telemetry_logging._send_telemetry_request")
@patch("sagemaker.core.telemetry.telemetry_logging.resolve_value_from_config")
def test_telemetry_emitter_created_by_with_special_chars(
    self, mock_resolve_config, mock_send_telemetry_request
):
    """Test that x-createdBy properly URL-encodes special characters"""
    mock_resolve_config.return_value = False

    # Value with spaces, an ampersand and parentheses. patch.dict restores
    # os.environ on exit instead of deleting whatever was there before.
    with patch.dict(os.environ, {_CREATED_BY_ENV_VAR: "My App & Tools (v2.0)"}):
        mock_local_client = LocalSagemakerClientMock()
        mock_local_client.mock_create_model()

    # Without this, a missing telemetry call surfaces as an AttributeError on None.
    mock_send_telemetry_request.assert_called()
    args = mock_send_telemetry_request.call_args.args
    extra_str = str(args[5])

    # Verify special characters are URL-encoded
    self.assertIn("x-createdBy=My%20App%20%26%20Tools%20%28v2.0%29", extra_str)

    # Verify raw special characters are NOT in the URL
    self.assertNotIn("My App & Tools", extra_str)
    self.assertNotIn("(v2.0)", extra_str)

@patch("sagemaker.core.telemetry.telemetry_logging._send_telemetry_request")
@patch("sagemaker.core.telemetry.telemetry_logging.resolve_value_from_config")
def test_telemetry_emitter_created_by_empty_string(
    self, mock_resolve_config, mock_send_telemetry_request
):
    """Test that x-createdBy is NOT included when env var is empty string"""
    mock_resolve_config.return_value = False

    # Set the variable to an empty string for the call only; patch.dict restores
    # os.environ on exit instead of deleting whatever was there before.
    with patch.dict(os.environ, {_CREATED_BY_ENV_VAR: ""}):
        mock_local_client = LocalSagemakerClientMock()
        mock_local_client.mock_create_model()

    # Without this, a missing telemetry call surfaces as an AttributeError on None.
    mock_send_telemetry_request.assert_called()
    args = mock_send_telemetry_request.call_args.args
    extra_str = str(args[5])

    # Verify x-createdBy is NOT added for empty string
    self.assertNotIn("x-createdBy", extra_str)

def test_construct_url_with_created_by(self):
    """Check that _construct_url carries x-createdBy through as part of extra_info."""
    account_id = "123456789012"
    region = "us-west-2"
    status = "1"
    feature = "15"
    # extra_info already contains the URL-encoded x-createdBy parameter.
    extra_info = (
        "DataSet.create&x-sdkVersion=3.0&x-createdBy=awslabs%2Fagent-plugins%2Fsagemaker-ai"
    )

    # Build the expectation first, then compare it with the constructed URL.
    expected_url = (
        f"https://sm-pysdk-t-{region}.s3.{region}.amazonaws.com/telemetry?"
        f"x-accountId={account_id}"
        f"&x-status={status}"
        f"&x-feature={feature}"
        f"&x-extra={extra_info}"
    )

    url = _construct_url(
        accountId=account_id,
        region=region,
        status=status,
        feature=feature,
        failure_reason=None,
        failure_type=None,
        extra_info=extra_info,
    )

    self.assertEqual(url, expected_url)
    self.assertIn("x-createdBy=awslabs%2Fagent-plugins%2Fsagemaker-ai", url)
Loading