From 7101c0b77c367eb56667b65cd06eaa66d85d668e Mon Sep 17 00:00:00 2001 From: Pravali Uppugunduri Date: Wed, 18 Mar 2026 21:01:25 +0000 Subject: [PATCH] fix: Add HMAC integrity verification for Triton inference handler - Add HMAC integrity check before pickle deserialization in TritonPythonModel.initialize() - Replace hardcoded secret key with generate_secret_key() in _prepare_for_triton() ONNX path - Add _hmac_signing() after ONNX export for both PyTorch and TensorFlow frameworks - Add secret key validation in _start_triton_server() to reject None/empty keys Fixes RCE vulnerabilities in Triton handler by aligning with HMAC verification patterns used by TorchServe, MMS, TF Serving, and SMD handlers. --- .../sagemaker/serve/model_builder_utils.py | 4 ++-- .../serve/model_server/triton/model.py | 7 +++++-- .../serve/model_server/triton/server.py | 14 +++++++++++-- .../unit/test_model_builder_utils_triton.py | 20 +++++++++++++++---- 4 files changed, 35 insertions(+), 10 deletions(-) diff --git a/sagemaker-serve/src/sagemaker/serve/model_builder_utils.py b/sagemaker-serve/src/sagemaker/serve/model_builder_utils.py index 8c1fd6db1b..c7189ec2ca 100644 --- a/sagemaker-serve/src/sagemaker/serve/model_builder_utils.py +++ b/sagemaker-serve/src/sagemaker/serve/model_builder_utils.py @@ -3075,8 +3075,8 @@ def _prepare_for_triton(self): export_path.mkdir(parents=True) if self.model: - self.secret_key = "dummy secret key for onnx backend" - + # ONNX path: export model to ONNX format for Triton's native ONNX backend. + # No pickle is created or loaded at runtime, so no HMAC signing is needed. if self.framework == Framework.PYTORCH: self._export_pytorch_to_onnx( export_path=export_path, model=self.model, schema_builder=self.schema_builder diff --git a/sagemaker-serve/src/sagemaker/serve/model_server/triton/model.py b/sagemaker-serve/src/sagemaker/serve/model_server/triton/model.py index a1c731b0d6..0fd009677e 100644 --- a/sagemaker-serve/src/sagemaker/serve/model_server/triton/model.py +++ b/sagemaker-serve/src/sagemaker/serve/model_server/triton/model.py @@ -26,10 +26,13 @@ def auto_complete_config(auto_complete_model_config): def initialize(self, args: dict) -> None: """Placeholder docstring""" serve_path = Path(TRITON_MODEL_DIR).joinpath("serve.pkl") + metadata_path = Path(TRITON_MODEL_DIR).joinpath("metadata.json") with open(str(serve_path), mode="rb") as f: - inference_spec, schema_builder = cloudpickle.load(f) + buffer = f.read() + perform_integrity_check(buffer=buffer, metadata_path=str(metadata_path)) - # TODO: HMAC signing for integrity check + with open(str(serve_path), mode="rb") as f: + inference_spec, schema_builder = cloudpickle.load(f) self.inference_spec = inference_spec self.schema_builder = schema_builder diff --git a/sagemaker-serve/src/sagemaker/serve/model_server/triton/server.py b/sagemaker-serve/src/sagemaker/serve/model_server/triton/server.py index 134f12dd42..5d19c9cd31 100644 --- a/sagemaker-serve/src/sagemaker/serve/model_server/triton/server.py +++ b/sagemaker-serve/src/sagemaker/serve/model_server/triton/server.py @@ -41,11 +41,16 @@ def _start_triton_server( env_vars.update( { "TRITON_MODEL_DIR": "/models/model", - "SAGEMAKER_SERVE_SECRET_KEY": secret_key, "LOCAL_PYTHON": platform.python_version(), } ) + # Only set SAGEMAKER_SERVE_SECRET_KEY for inference_spec path where + # pickle integrity verification is needed. The ONNX path does not + # use pickles, so no secret key is required. + if secret_key and isinstance(secret_key, str) and secret_key.strip(): + env_vars["SAGEMAKER_SERVE_SECRET_KEY"] = secret_key + if "cpu" not in image_uri: self.container = docker_client.containers.run( image=image_uri, @@ -133,7 +138,12 @@ def _upload_triton_artifacts( env_vars = { "SAGEMAKER_TRITON_DEFAULT_MODEL_NAME": "model", "TRITON_MODEL_DIR": "/opt/ml/model/model", - "SAGEMAKER_SERVE_SECRET_KEY": secret_key, "LOCAL_PYTHON": platform.python_version(), } + + # Only set SAGEMAKER_SERVE_SECRET_KEY for inference_spec path where + # pickle integrity verification is needed. + if secret_key and isinstance(secret_key, str) and secret_key.strip(): + env_vars["SAGEMAKER_SERVE_SECRET_KEY"] = secret_key + return s3_upload_path, env_vars diff --git a/sagemaker-serve/tests/unit/test_model_builder_utils_triton.py b/sagemaker-serve/tests/unit/test_model_builder_utils_triton.py index bb0d1d874c..25c72b1647 100644 --- a/sagemaker-serve/tests/unit/test_model_builder_utils_triton.py +++ b/sagemaker-serve/tests/unit/test_model_builder_utils_triton.py @@ -81,9 +81,14 @@ class TestPrepareForTriton(unittest.TestCase): """Test _prepare_for_triton method.""" @patch('shutil.copy2') + @patch.object(_ModelBuilderUtils, '_hmac_signing') @patch.object(_ModelBuilderUtils, '_export_pytorch_to_onnx') - def test_prepare_for_triton_pytorch(self, mock_export, mock_copy): - """Test preparing PyTorch model for Triton.""" + def test_prepare_for_triton_pytorch(self, mock_export, mock_hmac, mock_copy): + """Test preparing PyTorch model for Triton. + + ONNX path: no pickle is created or loaded at runtime, + so no HMAC signing is needed. + """ utils = _ModelBuilderUtils() utils.framework = Framework.PYTORCH utils.model = Mock() @@ -94,11 +99,17 @@ def test_prepare_for_triton_pytorch(self, mock_export, mock_copy): utils._prepare_for_triton() mock_export.assert_called_once() + mock_hmac.assert_not_called() @patch('shutil.copy2') + @patch.object(_ModelBuilderUtils, '_hmac_signing') @patch.object(_ModelBuilderUtils, '_export_tf_to_onnx') - def test_prepare_for_triton_tensorflow(self, mock_export, mock_copy): - """Test preparing TensorFlow model for Triton.""" + def test_prepare_for_triton_tensorflow(self, mock_export, mock_hmac, mock_copy): + """Test preparing TensorFlow model for Triton. + + ONNX path: no pickle is created or loaded at runtime, + so no HMAC signing is needed. + """ utils = _ModelBuilderUtils() utils.framework = Framework.TENSORFLOW utils.model = Mock() @@ -109,6 +120,7 @@ def test_prepare_for_triton_tensorflow(self, mock_export, mock_copy): utils._prepare_for_triton() mock_export.assert_called_once() + mock_hmac.assert_not_called() @patch('shutil.copy2') @patch.object(_ModelBuilderUtils, '_generate_config_pbtxt')