Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
from sagemaker.serve.spec.inference_spec import InferenceSpec
from sagemaker.serve.detector.dependency_manager import capture_dependencies
from sagemaker.serve.validations.check_integrity import (
generate_secret_key,
compute_hash,
)
from sagemaker.remote_function.core.serialization import _MetaData
Expand Down Expand Up @@ -120,11 +119,10 @@ def prepare_for_mms(

capture_dependencies(dependencies=dependencies, work_dir=code_dir)

secret_key = generate_secret_key()
with open(str(code_dir.joinpath("serve.pkl")), "rb") as f:
buffer = f.read()
hash_value = compute_hash(buffer=buffer, secret_key=secret_key)
hash_value = compute_hash(buffer=buffer)
with open(str(code_dir.joinpath("metadata.json")), "wb") as metadata:
metadata.write(_MetaData(hash_value).to_json())

return secret_key
return ""
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ def _start_serving(
env = {
"SAGEMAKER_SUBMIT_DIRECTORY": "/opt/ml/model/code",
"SAGEMAKER_PROGRAM": "inference.py",
"SAGEMAKER_SERVE_SECRET_KEY": secret_key,
"LOCAL_PYTHON": platform.python_version(),
}
if env_vars:
Expand Down Expand Up @@ -145,7 +144,6 @@ def _upload_server_artifacts(
env_vars = {
"SAGEMAKER_SUBMIT_DIRECTORY": "/opt/ml/model/code",
"SAGEMAKER_PROGRAM": "inference.py",
"SAGEMAKER_SERVE_SECRET_KEY": secret_key,
"SAGEMAKER_REGION": sagemaker_session.boto_region_name,
"SAGEMAKER_CONTAINER_LOG_LEVEL": "10",
"LOCAL_PYTHON": platform.python_version(),
Expand Down
6 changes: 2 additions & 4 deletions src/sagemaker/serve/model_server/smd/prepare.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
from sagemaker.serve.spec.inference_spec import InferenceSpec
from sagemaker.serve.detector.dependency_manager import capture_dependencies
from sagemaker.serve.validations.check_integrity import (
generate_secret_key,
compute_hash,
)
from sagemaker.remote_function.core.serialization import _MetaData
Expand Down Expand Up @@ -64,11 +63,10 @@ def prepare_for_smd(

capture_dependencies(dependencies=dependencies, work_dir=code_dir)

secret_key = generate_secret_key()
with open(str(code_dir.joinpath("serve.pkl")), "rb") as f:
buffer = f.read()
hash_value = compute_hash(buffer=buffer, secret_key=secret_key)
hash_value = compute_hash(buffer=buffer)
with open(str(code_dir.joinpath("metadata.json")), "wb") as metadata:
metadata.write(_MetaData(hash_value).to_json())

return secret_key
return ""
1 change: 0 additions & 1 deletion src/sagemaker/serve/model_server/smd/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@ def _upload_smd_artifacts(
"SAGEMAKER_INFERENCE_CODE_DIRECTORY": "/opt/ml/model/code",
"SAGEMAKER_INFERENCE_CODE": "inference.handler",
"SAGEMAKER_REGION": sagemaker_session.boto_region_name,
"SAGEMAKER_SERVE_SECRET_KEY": secret_key,
"LOCAL_PYTHON": platform.python_version(),
}
return s3_upload_path, env_vars
3 changes: 0 additions & 3 deletions src/sagemaker/serve/model_server/tei/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,6 @@ def _start_tei_serving(
secret_key: Secret key to use for authentication
env_vars: Environment variables to set
"""
if env_vars and secret_key:
env_vars["SAGEMAKER_SERVE_SECRET_KEY"] = secret_key

self.container = client.containers.run(
image,
shm_size=_SHM_SIZE,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
)
from sagemaker.serve.detector.dependency_manager import capture_dependencies
from sagemaker.serve.validations.check_integrity import (
generate_secret_key,
compute_hash,
)
from sagemaker.remote_function.core.serialization import _MetaData
Expand Down Expand Up @@ -57,11 +56,10 @@ def prepare_for_tf_serving(
raise ValueError("SavedModel is not found for Tensorflow or Keras flavor.")
_move_contents(src_dir=mlflow_saved_model_dir, dest_dir=saved_model_bundle_dir)

secret_key = generate_secret_key()
with open(str(code_dir.joinpath("serve.pkl")), "rb") as f:
buffer = f.read()
hash_value = compute_hash(buffer=buffer, secret_key=secret_key)
hash_value = compute_hash(buffer=buffer)
with open(str(code_dir.joinpath("metadata.json")), "wb") as metadata:
metadata.write(_MetaData(hash_value).to_json())

return secret_key
return ""
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@ def _start_tensorflow_serving(
environment={
"SAGEMAKER_SUBMIT_DIRECTORY": "/opt/ml/model/code",
"SAGEMAKER_PROGRAM": "inference.py",
"SAGEMAKER_SERVE_SECRET_KEY": secret_key,
"LOCAL_PYTHON": platform.python_version(),
**env_vars,
},
Expand Down Expand Up @@ -142,7 +141,6 @@ def _upload_tensorflow_serving_artifacts(
"SAGEMAKER_PROGRAM": "inference.py",
"SAGEMAKER_REGION": sagemaker_session.boto_region_name,
"SAGEMAKER_CONTAINER_LOG_LEVEL": "10",
"SAGEMAKER_SERVE_SECRET_KEY": secret_key,
"LOCAL_PYTHON": platform.python_version(),
}
return s3_upload_path, env_vars
6 changes: 2 additions & 4 deletions src/sagemaker/serve/model_server/torchserve/prepare.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from sagemaker.serve.spec.inference_spec import InferenceSpec
from sagemaker.serve.detector.dependency_manager import capture_dependencies
from sagemaker.serve.validations.check_integrity import (
generate_secret_key,
compute_hash,
)
from sagemaker.serve.validations.check_image_uri import is_1p_image_uri
Expand Down Expand Up @@ -69,11 +68,10 @@ def prepare_for_torchserve(

capture_dependencies(dependencies=dependencies, work_dir=code_dir)

secret_key = generate_secret_key()
with open(str(code_dir.joinpath("serve.pkl")), "rb") as f:
buffer = f.read()
hash_value = compute_hash(buffer=buffer, secret_key=secret_key)
hash_value = compute_hash(buffer=buffer)
with open(str(code_dir.joinpath("metadata.json")), "wb") as metadata:
metadata.write(_MetaData(hash_value).to_json())

return secret_key
return ""
2 changes: 0 additions & 2 deletions src/sagemaker/serve/model_server/torchserve/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@ def _start_torch_serve(
environment={
"SAGEMAKER_SUBMIT_DIRECTORY": "/opt/ml/model/code",
"SAGEMAKER_PROGRAM": "inference.py",
"SAGEMAKER_SERVE_SECRET_KEY": secret_key,
"LOCAL_PYTHON": platform.python_version(),
**env_vars,
},
Expand Down Expand Up @@ -116,7 +115,6 @@ def _upload_torchserve_artifacts(
"SAGEMAKER_PROGRAM": "inference.py",
"SAGEMAKER_REGION": sagemaker_session.boto_region_name,
"SAGEMAKER_CONTAINER_LOG_LEVEL": "10",
"SAGEMAKER_SERVE_SECRET_KEY": secret_key,
"LOCAL_PYTHON": platform.python_version(),
}
return s3_upload_path, env_vars
6 changes: 4 additions & 2 deletions src/sagemaker/serve/model_server/triton/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,12 @@ def auto_complete_config(auto_complete_model_config):
def initialize(self, args: dict) -> None:
"""Placeholder docstring"""
serve_path = Path(TRITON_MODEL_DIR).joinpath("serve.pkl")
metadata_path = Path(TRITON_MODEL_DIR).joinpath("metadata.json")
with open(str(serve_path), mode="rb") as f:
inference_spec, schema_builder = cloudpickle.load(f)
buffer = f.read()

# TODO: HMAC signing for integrity check
perform_integrity_check(buffer=buffer, metadata_path=metadata_path)
inference_spec, schema_builder = cloudpickle.loads(buffer)

self.inference_spec = inference_spec
self.schema_builder = schema_builder
Expand Down
2 changes: 0 additions & 2 deletions src/sagemaker/serve/model_server/triton/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@ def _start_triton_server(
env_vars.update(
{
"TRITON_MODEL_DIR": "/models/model",
"SAGEMAKER_SERVE_SECRET_KEY": secret_key,
"LOCAL_PYTHON": platform.python_version(),
}
)
Expand Down Expand Up @@ -146,7 +145,6 @@ def _upload_triton_artifacts(
env_vars = {
"SAGEMAKER_TRITON_DEFAULT_MODEL_NAME": "model",
"TRITON_MODEL_DIR": "/opt/ml/model/model",
"SAGEMAKER_SERVE_SECRET_KEY": secret_key,
"LOCAL_PYTHON": platform.python_version(),
}
return s3_upload_path, env_vars
14 changes: 5 additions & 9 deletions src/sagemaker/serve/model_server/triton/triton_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
from sagemaker.serve.detector.pickler import save_pkl
from sagemaker.serve.model_server.triton.config_template import CONFIG_TEMPLATE
from sagemaker.serve.validations.check_integrity import (
generate_secret_key,
compute_hash,
)

Expand Down Expand Up @@ -213,7 +212,7 @@ def _prepare_for_triton(self):
export_path.mkdir(parents=True)

if self.model:
self.secret_key = "dummy secret key for onnx backend"
# ONNX path: no pickle serialization, no serve.pkl, no integrity check needed.

if self._framework == "pytorch":
self._export_pytorch_to_onnx(
Expand All @@ -237,26 +236,23 @@ def _prepare_for_triton(self):

self._pack_conda_env(pkl_path=pkl_path)

self._hmac_signing()
self._compute_integrity_hash()

return

raise ValueError("Either model or inference_spec should be provided to ModelBuilder.")

def _hmac_signing(self):
"""Perform HMAC signing on picke file for integrity check"""
secret_key = generate_secret_key()
def _compute_integrity_hash(self):
"""Compute SHA-256 integrity hash on pickle file for integrity check"""
pkl_path = Path(self.model_path).joinpath("model_repository").joinpath("model")

with open(str(pkl_path.joinpath("serve.pkl")), "rb") as f:
buffer = f.read()
hash_value = compute_hash(buffer=buffer, secret_key=secret_key)
hash_value = compute_hash(buffer=buffer)

with open(str(pkl_path.joinpath("metadata.json")), "wb") as metadata:
metadata.write(_MetaData(hash_value).to_json())

self.secret_key = secret_key

def _generate_config_pbtxt(self, pkl_path: Path):
config_path = pkl_path.joinpath("config.pbtxt")

Expand Down
20 changes: 6 additions & 14 deletions src/sagemaker/serve/validations/check_integrity.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,21 @@
"""Validates the integrity of pickled file with HMAC signing."""
"""Validates the integrity of pickled file with SHA-256 hash."""

from __future__ import absolute_import
import secrets
import hmac
import hashlib
import os
from pathlib import Path

from sagemaker.remote_function.core.serialization import _MetaData


def generate_secret_key(nbytes: int = 32) -> str:
"""Generates secret key"""
return secrets.token_hex(nbytes)


def compute_hash(buffer: bytes, secret_key: str) -> str:
"""Compute hash value using HMAC"""
return hmac.new(secret_key.encode(), msg=buffer, digestmod=hashlib.sha256).hexdigest()
def compute_hash(buffer: bytes) -> str:
    """Return the hex-encoded SHA-256 digest of *buffer*."""
    digest = hashlib.sha256()
    digest.update(buffer)
    return digest.hexdigest()


def perform_integrity_check(buffer: bytes, metadata_path: Path):
"""Validates the integrity of bytes by comparing the hash value"""
secret_key = os.environ.get("SAGEMAKER_SERVE_SECRET_KEY")
actual_hash_value = compute_hash(buffer=buffer, secret_key=secret_key)
"""Validates the integrity of bytes by comparing the hash value."""
actual_hash_value = compute_hash(buffer=buffer)

if not Path.exists(metadata_path):
raise ValueError("Path to metadata.json does not exist")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,6 @@ def test_start_invoke_destroy_local_multi_model_server(self):
"KEY": "VALUE",
"SAGEMAKER_SUBMIT_DIRECTORY": "/opt/ml/model/code",
"SAGEMAKER_PROGRAM": "inference.py",
"SAGEMAKER_SERVE_SECRET_KEY": "secret_key",
"LOCAL_PYTHON": platform.python_version(),
},
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,6 @@ def test_start_invoke_destroy_local_tei_server(self, mock_requests):
"HF_HOME": "/opt/ml/model/",
"HUGGINGFACE_HUB_CACHE": "/opt/ml/model/",
"KEY": "VALUE",
"SAGEMAKER_SERVE_SECRET_KEY": "secret_key",
},
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@ def setUp(self):
)
@patch("sagemaker.serve.model_server.tensorflow_serving.prepare._MetaData")
@patch("sagemaker.serve.model_server.tensorflow_serving.prepare.compute_hash")
@patch("sagemaker.serve.model_server.tensorflow_serving.prepare.generate_secret_key")
@patch("sagemaker.serve.model_server.tensorflow_serving.prepare.capture_dependencies")
@patch("sagemaker.serve.model_server.tensorflow_serving.prepare.shutil")
@patch("sagemaker.serve.model_server.tensorflow_serving.prepare.Path")
Expand All @@ -52,7 +51,6 @@ def test_prepare_happy(
mock_path,
mock_shutil,
mock_capture_dependencies,
mock_generate_secret_key,
mock_compute_hash,
mock_metadata,
mock_get_saved_model_path,
Expand All @@ -65,16 +63,14 @@ def test_prepare_happy(
mock_path_instance.joinpath.return_value = Mock()
mock_get_saved_model_path.return_value = MODEL_PATH + "/1/"

mock_generate_secret_key.return_value = SECRET_KEY

secret_key = prepare_for_tf_serving(
model_path=MODEL_PATH,
shared_libs=SHARED_LIBS,
dependencies=DEPENDENCIES,
)

mock_path_instance.mkdir.assert_not_called()
self.assertEqual(secret_key, SECRET_KEY)
self.assertEqual(secret_key, "")

@patch("builtins.open", new_callable=mock_open, read_data=b"{}")
@patch("sagemaker.serve.model_server.tensorflow_serving.prepare._move_contents")
Expand All @@ -84,7 +80,6 @@ def test_prepare_happy(
)
@patch("sagemaker.serve.model_server.tensorflow_serving.prepare._MetaData")
@patch("sagemaker.serve.model_server.tensorflow_serving.prepare.compute_hash")
@patch("sagemaker.serve.model_server.tensorflow_serving.prepare.generate_secret_key")
@patch("sagemaker.serve.model_server.tensorflow_serving.prepare.capture_dependencies")
@patch("sagemaker.serve.model_server.tensorflow_serving.prepare.shutil")
@patch("sagemaker.serve.model_server.tensorflow_serving.prepare.Path")
Expand All @@ -93,7 +88,6 @@ def test_prepare_saved_model_not_found(
mock_path,
mock_shutil,
mock_capture_dependencies,
mock_generate_secret_key,
mock_compute_hash,
mock_metadata,
mock_get_saved_model_path,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,6 @@ def test_start_invoke_destroy_local_tensorflow_serving_server(self):
environment={
"SAGEMAKER_SUBMIT_DIRECTORY": "/opt/ml/model/code",
"SAGEMAKER_PROGRAM": "inference.py",
"SAGEMAKER_SERVE_SECRET_KEY": "secret_key",
"LOCAL_PYTHON": platform.python_version(),
"KEY": "VALUE",
},
Expand Down Expand Up @@ -97,5 +96,4 @@ def test_upload_artifacts_sagemaker_triton_server(self, mock_upload, mock_platfo

mock_upload.assert_called_once_with(mock_session, MODEL_PATH, "mock_model_data_uri", ANY)
self.assertEqual(s3_upload_path, S3_URI)
self.assertEqual(env_vars.get("SAGEMAKER_SERVE_SECRET_KEY"), SECRET_KEY)
self.assertEqual(env_vars.get("LOCAL_PYTHON"), "3.8")
Loading
Loading