Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 6 additions & 9 deletions sagemaker-serve/src/sagemaker/serve/model_builder_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,6 @@ def build(self):
from sagemaker.serve.detector.pickler import save_pkl
from sagemaker.serve.builder.requirements_manager import RequirementsManager
from sagemaker.serve.validations.check_integrity import (
generate_secret_key,
compute_hash,
)
from sagemaker.core.remote_function.core.serialization import _MetaData
Expand Down Expand Up @@ -2884,20 +2883,17 @@ def _save_inference_spec(self) -> None:
pkl_path = Path(self.model_path).joinpath("model_repository").joinpath("model")
save_pkl(pkl_path, (self.inference_spec, self.schema_builder))

def _hmac_signing(self):
        """Perform HMAC signing on pickle file for integrity check"""
secret_key = generate_secret_key()
def _compute_integrity_hash(self):
"""Compute SHA-256 hash of serve.pkl and store in metadata.json for integrity check."""
pkl_path = Path(self.model_path).joinpath("model_repository").joinpath("model")

with open(str(pkl_path.joinpath("serve.pkl")), "rb") as f:
buffer = f.read()
hash_value = compute_hash(buffer=buffer, secret_key=secret_key)
hash_value = compute_hash(buffer=buffer)

with open(str(pkl_path.joinpath("metadata.json")), "wb") as metadata:
metadata.write(_MetaData(hash_value).to_json())

self.secret_key = secret_key

def _generate_config_pbtxt(self, pkl_path: Path):
"""Generate Triton config.pbtxt file."""
config_path = pkl_path.joinpath("config.pbtxt")
Expand Down Expand Up @@ -3075,7 +3071,8 @@ def _prepare_for_triton(self):
export_path.mkdir(parents=True)

if self.model:
self.secret_key = "dummy secret key for onnx backend"
# ONNX path: no pickle serialization, no serve.pkl, no integrity check needed.
# Do not set secret_key — there is nothing to sign.

if self.framework == Framework.PYTORCH:
self._export_pytorch_to_onnx(
Expand All @@ -3099,7 +3096,7 @@ def _prepare_for_triton(self):

self._pack_conda_env(pkl_path=pkl_path)

self._hmac_signing()
self._compute_integrity_hash()

return

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
from sagemaker.serve.spec.inference_spec import InferenceSpec
from sagemaker.serve.detector.dependency_manager import capture_dependencies
from sagemaker.serve.validations.check_integrity import (
generate_secret_key,
compute_hash,
)
from sagemaker.core.remote_function.core.serialization import _MetaData
Expand Down Expand Up @@ -119,11 +118,10 @@ def prepare_for_mms(

capture_dependencies(dependencies=dependencies, work_dir=code_dir)

secret_key = generate_secret_key()
with open(str(code_dir.joinpath("serve.pkl")), "rb") as f:
buffer = f.read()
hash_value = compute_hash(buffer=buffer, secret_key=secret_key)
hash_value = compute_hash(buffer=buffer)
with open(str(code_dir.joinpath("metadata.json")), "wb") as metadata:
metadata.write(_MetaData(hash_value).to_json())

return secret_key
return ""
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ def _start_serving(
env = {
"SAGEMAKER_SUBMIT_DIRECTORY": "/opt/ml/model/code",
"SAGEMAKER_PROGRAM": "inference.py",
"SAGEMAKER_SERVE_SECRET_KEY": secret_key,
"LOCAL_PYTHON": platform.python_version(),
}
if env_vars:
Expand Down Expand Up @@ -131,7 +130,6 @@ def _upload_server_artifacts(
env_vars = {
"SAGEMAKER_SUBMIT_DIRECTORY": "/opt/ml/model/code",
"SAGEMAKER_PROGRAM": "inference.py",
"SAGEMAKER_SERVE_SECRET_KEY": secret_key,
"SAGEMAKER_REGION": sagemaker_session.boto_region_name,
"SAGEMAKER_CONTAINER_LOG_LEVEL": "10",
"LOCAL_PYTHON": platform.python_version(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
from sagemaker.serve.spec.inference_spec import InferenceSpec
from sagemaker.serve.detector.dependency_manager import capture_dependencies
from sagemaker.serve.validations.check_integrity import (
generate_secret_key,
compute_hash,
)
from sagemaker.core.remote_function.core.serialization import _MetaData
Expand Down Expand Up @@ -64,11 +63,10 @@ def prepare_for_smd(

capture_dependencies(dependencies=dependencies, work_dir=code_dir)

secret_key = generate_secret_key()
with open(str(code_dir.joinpath("serve.pkl")), "rb") as f:
buffer = f.read()
hash_value = compute_hash(buffer=buffer, secret_key=secret_key)
hash_value = compute_hash(buffer=buffer)
with open(str(code_dir.joinpath("metadata.json")), "wb") as metadata:
metadata.write(_MetaData(hash_value).to_json())

return secret_key
return ""
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@ def _upload_smd_artifacts(
"SAGEMAKER_INFERENCE_CODE_DIRECTORY": "/opt/ml/model/code",
"SAGEMAKER_INFERENCE_CODE": "inference.handler",
"SAGEMAKER_REGION": sagemaker_session.boto_region_name,
"SAGEMAKER_SERVE_SECRET_KEY": secret_key,
"LOCAL_PYTHON": platform.python_version(),
}
return s3_upload_path, env_vars
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,6 @@ def _start_tei_serving(
secret_key: Secret key to use for authentication
env_vars: Environment variables to set
"""
if env_vars and secret_key:
env_vars["SAGEMAKER_SERVE_SECRET_KEY"] = secret_key

self.container = client.containers.run(
image,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
)
from sagemaker.serve.detector.dependency_manager import capture_dependencies
from sagemaker.serve.validations.check_integrity import (
generate_secret_key,
compute_hash,
)
from sagemaker.core.remote_function.core.serialization import _MetaData
Expand Down Expand Up @@ -57,11 +56,10 @@ def prepare_for_tf_serving(
raise ValueError("SavedModel is not found for Tensorflow or Keras flavor.")
_move_contents(src_dir=mlflow_saved_model_dir, dest_dir=saved_model_bundle_dir)

secret_key = generate_secret_key()
with open(str(code_dir.joinpath("serve.pkl")), "rb") as f:
buffer = f.read()
hash_value = compute_hash(buffer=buffer, secret_key=secret_key)
hash_value = compute_hash(buffer=buffer)
with open(str(code_dir.joinpath("metadata.json")), "wb") as metadata:
metadata.write(_MetaData(hash_value).to_json())

return secret_key
return ""
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@ def _start_tensorflow_serving(
environment={
"SAGEMAKER_SUBMIT_DIRECTORY": "/opt/ml/model/code",
"SAGEMAKER_PROGRAM": "inference.py",
"SAGEMAKER_SERVE_SECRET_KEY": secret_key,
"LOCAL_PYTHON": platform.python_version(),
**env_vars,
},
Expand Down Expand Up @@ -124,7 +123,6 @@ def _upload_tensorflow_serving_artifacts(
"SAGEMAKER_PROGRAM": "inference.py",
"SAGEMAKER_REGION": sagemaker_session.boto_region_name,
"SAGEMAKER_CONTAINER_LOG_LEVEL": "10",
"SAGEMAKER_SERVE_SECRET_KEY": secret_key,
"LOCAL_PYTHON": platform.python_version(),
}
return s3_upload_path, env_vars
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from sagemaker.serve.spec.inference_spec import InferenceSpec
from sagemaker.serve.detector.dependency_manager import capture_dependencies
from sagemaker.serve.validations.check_integrity import (
generate_secret_key,
compute_hash,
)
from sagemaker.serve.validations.check_image_uri import is_1p_image_uri
Expand Down Expand Up @@ -67,11 +66,10 @@ def prepare_for_torchserve(

capture_dependencies(dependencies=dependencies, work_dir=code_dir)

secret_key = generate_secret_key()
with open(str(code_dir.joinpath("serve.pkl")), "rb") as f:
buffer = f.read()
hash_value = compute_hash(buffer=buffer, secret_key=secret_key)
hash_value = compute_hash(buffer=buffer)
with open(str(code_dir.joinpath("metadata.json")), "wb") as metadata:
metadata.write(_MetaData(hash_value).to_json())

return secret_key
return ""
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ def _start_torch_serve(
environment={
"SAGEMAKER_SUBMIT_DIRECTORY": "/opt/ml/model/code",
"SAGEMAKER_PROGRAM": "inference.py",
"SAGEMAKER_SERVE_SECRET_KEY": secret_key,
"LOCAL_PYTHON": platform.python_version(),
**env_vars,
},
Expand Down Expand Up @@ -103,7 +102,6 @@ def _upload_torchserve_artifacts(
"SAGEMAKER_PROGRAM": "inference.py",
"SAGEMAKER_REGION": sagemaker_session.boto_region_name,
"SAGEMAKER_CONTAINER_LOG_LEVEL": "10",
"SAGEMAKER_SERVE_SECRET_KEY": secret_key,
"LOCAL_PYTHON": platform.python_version(),
}
return s3_upload_path, env_vars
10 changes: 7 additions & 3 deletions sagemaker-serve/src/sagemaker/serve/model_server/triton/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,14 @@ def auto_complete_config(auto_complete_model_config):
def initialize(self, args: dict) -> None:
"""Placeholder docstring"""
serve_path = Path(TRITON_MODEL_DIR).joinpath("serve.pkl")
with open(str(serve_path), mode="rb") as f:
inference_spec, schema_builder = cloudpickle.load(f)
metadata_path = Path(TRITON_MODEL_DIR).joinpath("metadata.json")

# TODO: HMAC signing for integrity check
# Integrity check BEFORE deserialization to prevent RCE via malicious pickle
with open(str(serve_path), "rb") as f:
buffer = f.read()
perform_integrity_check(buffer=buffer, metadata_path=metadata_path)

inference_spec, schema_builder = cloudpickle.loads(buffer)

self.inference_spec = inference_spec
self.schema_builder = schema_builder
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ def _start_triton_server(
env_vars.update(
{
"TRITON_MODEL_DIR": "/models/model",
"SAGEMAKER_SERVE_SECRET_KEY": secret_key,
"LOCAL_PYTHON": platform.python_version(),
}
)
Expand Down Expand Up @@ -133,7 +132,6 @@ def _upload_triton_artifacts(
env_vars = {
"SAGEMAKER_TRITON_DEFAULT_MODEL_NAME": "model",
"TRITON_MODEL_DIR": "/opt/ml/model/model",
"SAGEMAKER_SERVE_SECRET_KEY": secret_key,
"LOCAL_PYTHON": platform.python_version(),
}
return s3_upload_path, env_vars
20 changes: 6 additions & 14 deletions sagemaker-serve/src/sagemaker/serve/validations/check_integrity.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,21 @@
"""Validates the integrity of pickled file with HMAC signing."""
"""Validates the integrity of pickled file with SHA-256 hash."""

from __future__ import absolute_import
import secrets
import hmac
import hashlib
import os
from pathlib import Path

from sagemaker.core.remote_function.core.serialization import _MetaData


def generate_secret_key(nbytes: int = 32) -> str:
"""Generates secret key"""
return secrets.token_hex(nbytes)


def compute_hash(buffer: bytes, secret_key: str) -> str:
"""Compute hash value using HMAC"""
return hmac.new(secret_key.encode(), msg=buffer, digestmod=hashlib.sha256).hexdigest()
def compute_hash(buffer: bytes) -> str:
"""Compute SHA-256 hash of the given buffer."""
return hashlib.sha256(buffer).hexdigest()


def perform_integrity_check(buffer: bytes, metadata_path: Path):
"""Validates the integrity of bytes by comparing the hash value"""
secret_key = os.environ.get("SAGEMAKER_SERVE_SECRET_KEY")
actual_hash_value = compute_hash(buffer=buffer, secret_key=secret_key)
"""Validates the integrity of bytes by comparing the hash value."""
actual_hash_value = compute_hash(buffer=buffer)

if not Path.exists(metadata_path):
raise ValueError("Path to metadata.json does not exist")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,10 +68,9 @@ def test_prepare_mms_js_resources(self, mock_create_dir, mock_copy_js):

@patch('builtins.input', return_value='')
@patch('sagemaker.serve.model_server.multi_model_server.prepare.compute_hash')
@patch('sagemaker.serve.model_server.multi_model_server.prepare.generate_secret_key')
@patch('sagemaker.serve.model_server.multi_model_server.prepare.capture_dependencies')
@patch('shutil.copy2')
def test_prepare_for_mms_creates_structure(self, mock_copy, mock_capture, mock_gen_key, mock_hash, mock_input):
def test_prepare_for_mms_creates_structure(self, mock_copy, mock_capture, mock_hash, mock_input):
"""Test prepare_for_mms creates directory structure and files."""
from sagemaker.serve.model_server.multi_model_server.prepare import prepare_for_mms

Expand All @@ -83,7 +82,6 @@ def test_prepare_for_mms_creates_structure(self, mock_copy, mock_capture, mock_g
serve_pkl = code_dir / "serve.pkl"
serve_pkl.write_bytes(b"test data")

mock_gen_key.return_value = "test-secret-key"
mock_hash.return_value = "test-hash"
mock_session = Mock()
mock_inference_spec = Mock()
Expand All @@ -98,16 +96,14 @@ def test_prepare_for_mms_creates_structure(self, mock_copy, mock_capture, mock_g
inference_spec=mock_inference_spec
)

self.assertEqual(secret_key, "test-secret-key")
mock_inference_spec.prepare.assert_called_once_with(str(model_path))
mock_capture.assert_called_once()

@patch('builtins.input', return_value='')
@patch('sagemaker.serve.model_server.multi_model_server.prepare.compute_hash')
@patch('sagemaker.serve.model_server.multi_model_server.prepare.generate_secret_key')
@patch('sagemaker.serve.model_server.multi_model_server.prepare.capture_dependencies')
@patch('shutil.copy2')
def test_prepare_for_mms_raises_on_invalid_dir(self, mock_copy, mock_capture, mock_gen_key, mock_hash, mock_input):
def test_prepare_for_mms_raises_on_invalid_dir(self, mock_copy, mock_capture, mock_hash, mock_input):
"""Test prepare_for_mms raises exception for invalid directory."""
from sagemaker.serve.model_server.multi_model_server.prepare import prepare_for_mms

Expand All @@ -128,10 +124,9 @@ def test_prepare_for_mms_raises_on_invalid_dir(self, mock_copy, mock_capture, mo

@patch('builtins.input', return_value='')
@patch('sagemaker.serve.model_server.multi_model_server.prepare.compute_hash')
@patch('sagemaker.serve.model_server.multi_model_server.prepare.generate_secret_key')
@patch('sagemaker.serve.model_server.multi_model_server.prepare.capture_dependencies')
@patch('shutil.copy2')
def test_prepare_for_mms_copies_shared_libs(self, mock_copy, mock_capture, mock_gen_key, mock_hash, mock_input):
def test_prepare_for_mms_copies_shared_libs(self, mock_copy, mock_capture, mock_hash, mock_input):
"""Test prepare_for_mms copies shared libraries."""
from sagemaker.serve.model_server.multi_model_server.prepare import prepare_for_mms

Expand All @@ -145,7 +140,6 @@ def test_prepare_for_mms_copies_shared_libs(self, mock_copy, mock_capture, mock_
shared_lib = Path(self.temp_dir) / "lib.so"
shared_lib.touch()

mock_gen_key.return_value = "test-key"
mock_hash.return_value = "test-hash"
mock_session = Mock()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,7 @@ def test_start_serving_creates_container(self, mock_path):
self.assertEqual(server.container, mock_container)
mock_client.containers.run.assert_called_once()
call_kwargs = mock_client.containers.run.call_args[1]
self.assertIn("SAGEMAKER_SERVE_SECRET_KEY", call_kwargs["environment"])
self.assertEqual(call_kwargs["environment"]["SAGEMAKER_SERVE_SECRET_KEY"], "test-secret")
self.assertIn("SAGEMAKER_SUBMIT_DIRECTORY", call_kwargs["environment"])

@patch('sagemaker.serve.model_server.multi_model_server.server.Path')
def test_start_serving_with_no_env_vars(self, mock_path):
Expand Down Expand Up @@ -166,8 +165,7 @@ def test_upload_server_artifacts_uploads_to_s3(self, mock_path, mock_is_s3, mock
)

self.assertIsNotNone(model_data)
self.assertIn("SAGEMAKER_SERVE_SECRET_KEY", env_vars)
self.assertEqual(env_vars["SAGEMAKER_SERVE_SECRET_KEY"], "test-key")
self.assertIn("SAGEMAKER_SUBMIT_DIRECTORY", env_vars)

@patch('sagemaker.serve.model_server.multi_model_server.server._is_s3_uri')
def test_upload_server_artifacts_no_upload(self, mock_is_s3):
Expand All @@ -187,7 +185,7 @@ def test_upload_server_artifacts_no_upload(self, mock_is_s3):
)

self.assertIsNone(model_data)
self.assertIn("SAGEMAKER_SERVE_SECRET_KEY", env_vars)
self.assertIn("SAGEMAKER_SUBMIT_DIRECTORY", env_vars)


class TestUpdateEnvVars(unittest.TestCase):
Expand Down
Loading
Loading