From 228240a09fe53e5ce9630a1bab301f6fcfab1348 Mon Sep 17 00:00:00 2001 From: aviruthen <91846056+aviruthen@users.noreply.github.com> Date: Tue, 7 Apr 2026 17:09:46 -0400 Subject: [PATCH 1/2] fix: bug: ModelBuilder overwrites user-provided HF_MODEL_ID for DJL Serving, preventi (5529) --- .../sagemaker/serve/model_builder_servers.py | 14 +- .../test_model_builder_servers_hf_model_id.py | 275 ++++++++++++++++++ 2 files changed, 282 insertions(+), 7 deletions(-) create mode 100644 sagemaker-serve/tests/unit/test_model_builder_servers_hf_model_id.py diff --git a/sagemaker-serve/src/sagemaker/serve/model_builder_servers.py b/sagemaker-serve/src/sagemaker/serve/model_builder_servers.py index 43af8b4f7a..48b8e0b307 100644 --- a/sagemaker-serve/src/sagemaker/serve/model_builder_servers.py +++ b/sagemaker-serve/src/sagemaker/serve/model_builder_servers.py @@ -136,7 +136,7 @@ def _build_for_torchserve(self) -> Model: if isinstance(self.model, str): # Configure HuggingFace model support if not self._is_jumpstart_model_id(): - self.env_vars.update({"HF_MODEL_ID": self.model}) + self.env_vars.setdefault("HF_MODEL_ID", self.model) # Add HuggingFace token if available if self.env_vars.get("HUGGING_FACE_HUB_TOKEN"): @@ -212,7 +212,7 @@ def _build_for_tgi(self) -> Model: if isinstance(self.model, str) and not self._is_jumpstart_model_id(): # Configure HuggingFace model for TGI - self.env_vars.update({"HF_MODEL_ID": self.model}) + self.env_vars.setdefault("HF_MODEL_ID", self.model) self.hf_model_config = _get_model_config_properties_from_hf( self.model, self.env_vars.get("HUGGING_FACE_HUB_TOKEN") @@ -320,7 +320,7 @@ def _build_for_djl(self) -> Model: if isinstance(self.model, str) and not self._is_jumpstart_model_id(): # Configure HuggingFace model for DJL - self.env_vars.update({"HF_MODEL_ID": self.model}) + self.env_vars.setdefault("HF_MODEL_ID", self.model) # Get model configuration for DJL optimization self.hf_model_config = _get_model_config_properties_from_hf( @@ -426,7 +426,7 @@ def _build_for_triton(self) -> Model: self.env_vars.update({"HF_TASK": model_task}) # Configure HuggingFace authentication - self.env_vars.update({"HF_MODEL_ID": self.model}) + self.env_vars.setdefault("HF_MODEL_ID", self.model) if self.env_vars.get("HUGGING_FACE_HUB_TOKEN"): self.env_vars["HF_TOKEN"] = self.env_vars.get("HUGGING_FACE_HUB_TOKEN") @@ -532,7 +532,7 @@ def _build_for_tei(self) -> Model: if isinstance(self.model, str) and not self._is_jumpstart_model_id(): # Configure HuggingFace model for TEI - self.env_vars.update({"HF_MODEL_ID": self.model}) + self.env_vars.setdefault("HF_MODEL_ID", self.model) self.hf_model_config = _get_model_config_properties_from_hf( self.model, self.env_vars.get("HUGGING_FACE_HUB_TOKEN") @@ -676,7 +676,7 @@ def _build_for_transformers(self) -> Model: if self.inference_spec is not None: hf_model_id = self.inference_spec.get_model() if isinstance(hf_model_id, str): # Only if it's a valid HF model ID - self.env_vars.update({"HF_MODEL_ID": hf_model_id}) + self.env_vars.setdefault("HF_MODEL_ID", hf_model_id) # Get HF config only for string model IDs if hasattr(self.env_vars, "HF_API_TOKEN"): self.hf_model_config = _get_model_config_properties_from_hf( @@ -687,7 +687,7 @@ def _build_for_transformers(self) -> Model: hf_model_id, self.env_vars.get("HUGGING_FACE_HUB_TOKEN") ) elif isinstance(self.model, str): # Only set HF_MODEL_ID if model is a string - self.env_vars.update({"HF_MODEL_ID": self.model}) + self.env_vars.setdefault("HF_MODEL_ID", self.model) # Get HF config for string model IDs if hasattr(self.env_vars, "HF_API_TOKEN"): self.hf_model_config = _get_model_config_properties_from_hf( diff --git a/sagemaker-serve/tests/unit/test_model_builder_servers_hf_model_id.py b/sagemaker-serve/tests/unit/test_model_builder_servers_hf_model_id.py new file mode 100644 index 0000000000..0b06dd9877 --- /dev/null +++ b/sagemaker-serve/tests/unit/test_model_builder_servers_hf_model_id.py @@ -0,0 +1,275 @@ +"""Unit tests to verify HF_MODEL_ID is not overwritten when user provides it.""" +import unittest +from unittest.mock import Mock, patch, MagicMock, PropertyMock + +from sagemaker.serve.model_builder_servers import _ModelBuilderServers +from sagemaker.serve.utils.types import ModelServer +from sagemaker.serve.mode.function_pointers import Mode + + +def _create_mock_builder(env_vars=None, model="Qwen/Qwen3-VL-4B-Instruct"): + """Create a mock builder with common attributes set.""" + builder = MagicMock(spec=_ModelBuilderServers) + builder.model = model + builder.env_vars = env_vars if env_vars is not None else {} + builder.model_path = "/tmp/test_model_path" + builder.mode = Mode.SAGEMAKER_ENDPOINT + builder.model_server = ModelServer.DJL_SERVING + builder.secret_key = "" + builder.s3_upload_path = None + builder.s3_model_data_url = None + builder.shared_libs = [] + builder.dependencies = {} + builder.image_uri = "test-image-uri" + builder.instance_type = "ml.g5.2xlarge" + builder.sagemaker_session = Mock() + builder.schema_builder = MagicMock() + builder.schema_builder.sample_input = {"inputs": "Hello", "parameters": {}} + builder.inference_spec = None + builder.hf_model_config = {} + builder.model_data_download_timeout = None + builder._user_provided_instance_type = True + builder._is_jumpstart_model_id = Mock(return_value=False) + builder._auto_detect_image_uri = Mock() + builder._prepare_for_mode = Mock(return_value=("s3://model-data", None)) + builder._create_model = Mock(return_value=Mock()) + builder._optimizing = False + builder._validate_djl_serving_sample_data = Mock() + builder._validate_tgi_serving_sample_data = Mock() + builder._validate_for_triton = Mock() + builder.get_huggingface_model_metadata = Mock(return_value={"pipeline_tag": "text-generation"}) + builder.role_arn = "arn:aws:iam::123456789012:role/SageMakerRole" + return builder + + +class TestDjlPreservesHfModelId(unittest.TestCase): + """Test that _build_for_djl preserves user-provided HF_MODEL_ID.""" + + @patch("sagemaker.serve.model_builder_servers._get_model_config_properties_from_hf") + @patch("sagemaker.serve.model_builder_servers._get_default_djl_configurations") + @patch("sagemaker.serve.model_builder_servers._get_nb_instance", return_value=None) + @patch("sagemaker.serve.model_builder_servers._get_gpu_info", return_value=1) + @patch("sagemaker.serve.model_builder_servers._get_default_tensor_parallel_degree", return_value=1) + def test_preserves_user_provided_s3_uri(self, mock_tp, mock_gpu, mock_nb, mock_djl_config, mock_hf_config): + """User-provided S3 URI for HF_MODEL_ID should not be overwritten.""" + mock_hf_config.return_value = {} + mock_djl_config.return_value = ({}, 256) + + s3_path = "s3://my-bucket/models/Qwen/" + builder = _create_mock_builder(env_vars={"HF_MODEL_ID": s3_path}) + + with patch("sagemaker.serve.model_server.djl_serving.prepare._create_dir_structure"): + _ModelBuilderServers._build_for_djl(builder) + + self.assertEqual(builder.env_vars["HF_MODEL_ID"], s3_path) + + @patch("sagemaker.serve.model_builder_servers._get_model_config_properties_from_hf") + @patch("sagemaker.serve.model_builder_servers._get_default_djl_configurations") + @patch("sagemaker.serve.model_builder_servers._get_nb_instance", return_value=None) + @patch("sagemaker.serve.model_builder_servers._get_gpu_info", return_value=1) + @patch("sagemaker.serve.model_builder_servers._get_default_tensor_parallel_degree", return_value=1) + def test_sets_hf_model_id_when_not_provided(self, mock_tp, mock_gpu, mock_nb, mock_djl_config, mock_hf_config): + """HF_MODEL_ID should be set from self.model when user doesn't provide it.""" + mock_hf_config.return_value = {} + mock_djl_config.return_value = ({}, 256) + + builder = _create_mock_builder(env_vars={}) + + with patch("sagemaker.serve.model_server.djl_serving.prepare._create_dir_structure"): + _ModelBuilderServers._build_for_djl(builder) + + self.assertEqual(builder.env_vars["HF_MODEL_ID"], "Qwen/Qwen3-VL-4B-Instruct") + + +class TestTgiPreservesHfModelId(unittest.TestCase): + """Test that _build_for_tgi preserves user-provided HF_MODEL_ID.""" + + @patch("sagemaker.serve.model_builder_servers._get_model_config_properties_from_hf") + @patch("sagemaker.serve.model_builder_servers._get_default_tgi_configurations") + @patch("sagemaker.serve.model_builder_servers._get_nb_instance", return_value=None) + @patch("sagemaker.serve.model_builder_servers._get_gpu_info", return_value=1) + @patch("sagemaker.serve.model_builder_servers._get_default_tensor_parallel_degree", return_value=1) + def test_preserves_user_provided_s3_uri(self, mock_tp, mock_gpu, mock_nb, mock_tgi_config, mock_hf_config): + """User-provided S3 URI for HF_MODEL_ID should not be overwritten.""" + mock_hf_config.return_value = {} + mock_tgi_config.return_value = ({}, 256) + + s3_path = "s3://my-bucket/models/Qwen/" + builder = _create_mock_builder(env_vars={"HF_MODEL_ID": s3_path}) + builder.model_server = ModelServer.TGI + + with patch("sagemaker.serve.model_server.tgi.prepare._create_dir_structure"): + _ModelBuilderServers._build_for_tgi(builder) + + self.assertEqual(builder.env_vars["HF_MODEL_ID"], s3_path) + + @patch("sagemaker.serve.model_builder_servers._get_model_config_properties_from_hf") + @patch("sagemaker.serve.model_builder_servers._get_default_tgi_configurations") + @patch("sagemaker.serve.model_builder_servers._get_nb_instance", return_value=None) + @patch("sagemaker.serve.model_builder_servers._get_gpu_info", return_value=1) + @patch("sagemaker.serve.model_builder_servers._get_default_tensor_parallel_degree", return_value=1) + def test_sets_hf_model_id_when_not_provided(self, mock_tp, mock_gpu, mock_nb, mock_tgi_config, mock_hf_config): + """HF_MODEL_ID should be set from self.model when user doesn't provide it.""" + mock_hf_config.return_value = {} + mock_tgi_config.return_value = ({}, 256) + + builder = _create_mock_builder(env_vars={}) + builder.model_server = ModelServer.TGI + + with patch("sagemaker.serve.model_server.tgi.prepare._create_dir_structure"): + _ModelBuilderServers._build_for_tgi(builder) + + self.assertEqual(builder.env_vars["HF_MODEL_ID"], "Qwen/Qwen3-VL-4B-Instruct") + + +class TestTeiPreservesHfModelId(unittest.TestCase): + """Test that _build_for_tei preserves user-provided HF_MODEL_ID.""" + + @patch("sagemaker.serve.model_builder_servers._get_model_config_properties_from_hf") + @patch("sagemaker.serve.model_builder_servers._get_nb_instance", return_value=None) + def test_preserves_user_provided_s3_uri(self, mock_nb, mock_hf_config): + """User-provided S3 URI for HF_MODEL_ID should not be overwritten.""" + mock_hf_config.return_value = {} + + s3_path = "s3://my-bucket/models/embedding-model/" + builder = _create_mock_builder(env_vars={"HF_MODEL_ID": s3_path}) + builder.model_server = ModelServer.TEI + + with patch("sagemaker.serve.model_server.tgi.prepare._create_dir_structure"): + _ModelBuilderServers._build_for_tei(builder) + + self.assertEqual(builder.env_vars["HF_MODEL_ID"], s3_path) + + @patch("sagemaker.serve.model_builder_servers._get_model_config_properties_from_hf") + @patch("sagemaker.serve.model_builder_servers._get_nb_instance", return_value=None) + def test_sets_hf_model_id_when_not_provided(self, mock_nb, mock_hf_config): + """HF_MODEL_ID should be set from self.model when user doesn't provide it.""" + mock_hf_config.return_value = {} + + builder = _create_mock_builder(env_vars={}) + builder.model_server = ModelServer.TEI + + with patch("sagemaker.serve.model_server.tgi.prepare._create_dir_structure"): + _ModelBuilderServers._build_for_tei(builder) + + self.assertEqual(builder.env_vars["HF_MODEL_ID"], "Qwen/Qwen3-VL-4B-Instruct") + + +class TestTorchservePreservesHfModelId(unittest.TestCase): + """Test that _build_for_torchserve preserves user-provided HF_MODEL_ID.""" + + def test_preserves_user_provided_s3_uri(self): + """User-provided S3 URI for HF_MODEL_ID should not be overwritten.""" + s3_path = "s3://my-bucket/models/my-model/" + builder = _create_mock_builder(env_vars={"HF_MODEL_ID": s3_path}) + builder.model_server = ModelServer.TORCHSERVE + builder.mode = Mode.SAGEMAKER_ENDPOINT + builder._save_model_inference_spec = Mock() + + _ModelBuilderServers._build_for_torchserve(builder) + + self.assertEqual(builder.env_vars["HF_MODEL_ID"], s3_path) + + def test_sets_hf_model_id_when_not_provided(self): + """HF_MODEL_ID should be set from self.model when user doesn't provide it.""" + builder = _create_mock_builder(env_vars={}) + builder.model_server = ModelServer.TORCHSERVE + builder.mode = Mode.SAGEMAKER_ENDPOINT + builder._save_model_inference_spec = Mock() + + _ModelBuilderServers._build_for_torchserve(builder) + + self.assertEqual(builder.env_vars["HF_MODEL_ID"], "Qwen/Qwen3-VL-4B-Instruct") + + +class TestTritonPreservesHfModelId(unittest.TestCase): + """Test that _build_for_triton preserves user-provided HF_MODEL_ID.""" + + def test_preserves_user_provided_s3_uri(self): + """User-provided S3 URI for HF_MODEL_ID should not be overwritten.""" + s3_path = "s3://my-bucket/models/my-model/" + builder = _create_mock_builder(env_vars={"HF_MODEL_ID": s3_path}) + builder.model_server = ModelServer.TRITON + builder._save_inference_spec = Mock() + builder._prepare_for_triton = Mock() + builder._auto_detect_image_for_triton = Mock() + + _ModelBuilderServers._build_for_triton(builder) + + self.assertEqual(builder.env_vars["HF_MODEL_ID"], s3_path) + + def test_sets_hf_model_id_when_not_provided(self): + """HF_MODEL_ID should be set from self.model when user doesn't provide it.""" + builder = _create_mock_builder(env_vars={}) + builder.model_server = ModelServer.TRITON + builder._save_inference_spec = Mock() + builder._prepare_for_triton = Mock() + builder._auto_detect_image_for_triton = Mock() + + _ModelBuilderServers._build_for_triton(builder) + + self.assertEqual(builder.env_vars["HF_MODEL_ID"], "Qwen/Qwen3-VL-4B-Instruct") + + +class TestTransformersPreservesHfModelId(unittest.TestCase): + """Test that _build_for_transformers preserves user-provided HF_MODEL_ID.""" + + @patch("sagemaker.serve.model_builder_servers._get_model_config_properties_from_hf") + @patch("sagemaker.serve.model_builder_servers._get_nb_instance", return_value=None) + def test_preserves_user_provided_s3_uri_with_model_string(self, mock_nb, mock_hf_config): + """User-provided S3 URI for HF_MODEL_ID should not be overwritten when model is a string.""" + mock_hf_config.return_value = {} + + s3_path = "s3://my-bucket/models/my-model/" + builder = _create_mock_builder(env_vars={"HF_MODEL_ID": s3_path}) + builder.model_server = ModelServer.MMS + builder.mode = Mode.SAGEMAKER_ENDPOINT + builder.model_data_download_timeout = None + + with patch("sagemaker.serve.model_server.multi_model_server.prepare._create_dir_structure"): + _ModelBuilderServers._build_for_transformers(builder) + + self.assertEqual(builder.env_vars["HF_MODEL_ID"], s3_path) + + @patch("sagemaker.serve.model_builder_servers._get_model_config_properties_from_hf") + @patch("sagemaker.serve.model_builder_servers._get_nb_instance", return_value=None) + def test_sets_hf_model_id_when_not_provided_with_model_string(self, mock_nb, mock_hf_config): + """HF_MODEL_ID should be set from self.model when user doesn't provide it.""" + mock_hf_config.return_value = {} + + builder = _create_mock_builder(env_vars={}) + builder.model_server = ModelServer.MMS + builder.mode = Mode.SAGEMAKER_ENDPOINT + builder.model_data_download_timeout = None + + with patch("sagemaker.serve.model_server.multi_model_server.prepare._create_dir_structure"): + _ModelBuilderServers._build_for_transformers(builder) + + self.assertEqual(builder.env_vars["HF_MODEL_ID"], "Qwen/Qwen3-VL-4B-Instruct") + + @patch("sagemaker.serve.model_builder_servers._get_model_config_properties_from_hf") + @patch("sagemaker.serve.model_builder_servers._get_nb_instance", return_value=None) + @patch("sagemaker.serve.model_builder_servers.save_pkl") + def test_preserves_user_provided_hf_model_id_with_inference_spec(self, mock_pkl, mock_nb, mock_hf_config): + """User-provided HF_MODEL_ID should not be overwritten when inference_spec provides a model ID.""" + mock_hf_config.return_value = {} + + s3_path = "s3://my-bucket/models/my-model/" + builder = _create_mock_builder(env_vars={"HF_MODEL_ID": s3_path}) + builder.model_server = ModelServer.MMS + builder.mode = Mode.SAGEMAKER_ENDPOINT + builder.model_data_download_timeout = None + builder.model = None # No model string, using inference_spec + builder.inference_spec = Mock() + builder.inference_spec.get_model.return_value = "some-hf-model-id" + builder._is_jumpstart_model_id = Mock(return_value=False) + + with patch("sagemaker.serve.model_server.multi_model_server.prepare._create_dir_structure"): + with patch("os.makedirs"): + _ModelBuilderServers._build_for_transformers(builder) + + self.assertEqual(builder.env_vars["HF_MODEL_ID"], s3_path) + + +if __name__ == "__main__": + unittest.main() From 4badbdd00bf3e8bbf76eec0aa96d71b7626d2221 Mon Sep 17 00:00:00 2001 From: aviruthen <91846056+aviruthen@users.noreply.github.com> Date: Tue, 7 Apr 2026 17:49:41 -0400 Subject: [PATCH 2/2] fix: address review comments (iteration #1) --- .../test_model_builder_servers_hf_model_id.py | 521 +++++++++++------- 1 file changed, 325 insertions(+), 196 deletions(-) diff --git a/sagemaker-serve/tests/unit/test_model_builder_servers_hf_model_id.py b/sagemaker-serve/tests/unit/test_model_builder_servers_hf_model_id.py index 0b06dd9877..1af9891cc5 100644 --- a/sagemaker-serve/tests/unit/test_model_builder_servers_hf_model_id.py +++ b/sagemaker-serve/tests/unit/test_model_builder_servers_hf_model_id.py @@ -1,13 +1,20 @@ """Unit tests to verify HF_MODEL_ID is not overwritten when user provides it.""" -import unittest -from unittest.mock import Mock, patch, MagicMock, PropertyMock +from __future__ import annotations + +from typing import Optional +from unittest.mock import Mock, patch, MagicMock + +import pytest from sagemaker.serve.model_builder_servers import _ModelBuilderServers from sagemaker.serve.utils.types import ModelServer from sagemaker.serve.mode.function_pointers import Mode -def _create_mock_builder(env_vars=None, model="Qwen/Qwen3-VL-4B-Instruct"): +def _create_mock_builder( + env_vars: Optional[dict[str, str]] = None, + model: str = "Qwen/Qwen3-VL-4B-Instruct", +) -> MagicMock: """Create a mock builder with common attributes set.""" builder = MagicMock(spec=_ModelBuilderServers) builder.model = model @@ -24,252 +31,374 @@ def _create_mock_builder(env_vars=None, model="Qwen/Qwen3-VL-4B-Instruct"): builder.instance_type = "ml.g5.2xlarge" builder.sagemaker_session = Mock() builder.schema_builder = MagicMock() - builder.schema_builder.sample_input = {"inputs": "Hello", "parameters": {}} + builder.schema_builder.sample_input = { + "inputs": "Hello", + "parameters": {}, + } builder.inference_spec = None builder.hf_model_config = {} builder.model_data_download_timeout = None builder._user_provided_instance_type = True builder._is_jumpstart_model_id = Mock(return_value=False) builder._auto_detect_image_uri = Mock() - builder._prepare_for_mode = Mock(return_value=("s3://model-data", None)) + builder._prepare_for_mode = Mock( + return_value=("s3://model-data", None) + ) builder._create_model = Mock(return_value=Mock()) builder._optimizing = False builder._validate_djl_serving_sample_data = Mock() builder._validate_tgi_serving_sample_data = Mock() builder._validate_for_triton = Mock() - builder.get_huggingface_model_metadata = Mock(return_value={"pipeline_tag": "text-generation"}) - builder.role_arn = "arn:aws:iam::123456789012:role/SageMakerRole" + builder.get_huggingface_model_metadata = Mock( + return_value={"pipeline_tag": "text-generation"} + ) + builder.role_arn = ( + "arn:aws:iam::123456789012:role/SageMakerRole" + ) return builder -class TestDjlPreservesHfModelId(unittest.TestCase): - """Test that _build_for_djl preserves user-provided HF_MODEL_ID.""" - - @patch("sagemaker.serve.model_builder_servers._get_model_config_properties_from_hf") - @patch("sagemaker.serve.model_builder_servers._get_default_djl_configurations") - @patch("sagemaker.serve.model_builder_servers._get_nb_instance", return_value=None) - @patch("sagemaker.serve.model_builder_servers._get_gpu_info", return_value=1) - @patch("sagemaker.serve.model_builder_servers._get_default_tensor_parallel_degree", return_value=1) - def test_preserves_user_provided_s3_uri(self, mock_tp, mock_gpu, mock_nb, mock_djl_config, mock_hf_config): - """User-provided S3 URI for HF_MODEL_ID should not be overwritten.""" - mock_hf_config.return_value = {} - mock_djl_config.return_value = ({}, 256) - - s3_path = "s3://my-bucket/models/Qwen/" - builder = _create_mock_builder(env_vars={"HF_MODEL_ID": s3_path}) - - with patch("sagemaker.serve.model_server.djl_serving.prepare._create_dir_structure"): +@pytest.fixture +def mock_builder() -> MagicMock: + """Create a mock builder with default (empty) env_vars.""" + return _create_mock_builder(env_vars={}) + + +@pytest.fixture +def mock_builder_with_s3() -> MagicMock: + """Create a mock builder with user-provided S3 HF_MODEL_ID.""" + return _create_mock_builder( + env_vars={"HF_MODEL_ID": "s3://my-bucket/models/Qwen/"} + ) + + +S3_PATH = "s3://my-bucket/models/Qwen/" +DEFAULT_MODEL = "Qwen/Qwen3-VL-4B-Instruct" + + +# --------------------------------------------------------------------------- +# DJL Serving +# --------------------------------------------------------------------------- +class TestBuildForDjlHfModelId: + """Test _build_for_djl preserves user-provided HF_MODEL_ID.""" + + _patches = [ + patch( + "sagemaker.serve.model_builder_servers" + "._get_default_tensor_parallel_degree", + return_value=1, + ), + patch( + "sagemaker.serve.model_builder_servers._get_gpu_info", + return_value=1, + ), + patch( + "sagemaker.serve.model_builder_servers._get_nb_instance", + return_value=None, + ), + patch( + "sagemaker.serve.model_builder_servers" + "._get_default_djl_configurations", + return_value=({}, 256), + ), + patch( + "sagemaker.serve.model_builder_servers" + "._get_model_config_properties_from_hf", + return_value={}, + ), + patch( + "sagemaker.serve.model_server.djl_serving" + ".prepare._create_dir_structure", + ), + ] + + def test_preserves_user_provided_s3_uri( + self, mock_builder_with_s3 + ): + """User-provided S3 URI should not be overwritten.""" + builder = mock_builder_with_s3 + for p in self._patches: + p.start() + try: _ModelBuilderServers._build_for_djl(builder) - - self.assertEqual(builder.env_vars["HF_MODEL_ID"], s3_path) - - @patch("sagemaker.serve.model_builder_servers._get_model_config_properties_from_hf") - @patch("sagemaker.serve.model_builder_servers._get_default_djl_configurations") - @patch("sagemaker.serve.model_builder_servers._get_nb_instance", return_value=None) - @patch("sagemaker.serve.model_builder_servers._get_gpu_info", return_value=1) - @patch("sagemaker.serve.model_builder_servers._get_default_tensor_parallel_degree", return_value=1) - def test_sets_hf_model_id_when_not_provided(self, mock_tp, mock_gpu, mock_nb, mock_djl_config, mock_hf_config): - """HF_MODEL_ID should be set from self.model when user doesn't provide it.""" - mock_hf_config.return_value = {} - mock_djl_config.return_value = ({}, 256) - - builder = _create_mock_builder(env_vars={}) - - with patch("sagemaker.serve.model_server.djl_serving.prepare._create_dir_structure"): + finally: + for p in self._patches: + p.stop() + assert builder.env_vars["HF_MODEL_ID"] == S3_PATH + + def test_sets_default_when_not_provided( + self, mock_builder + ): + """HF_MODEL_ID should default to self.model.""" + builder = mock_builder + for p in self._patches: + p.start() + try: _ModelBuilderServers._build_for_djl(builder) - - self.assertEqual(builder.env_vars["HF_MODEL_ID"], "Qwen/Qwen3-VL-4B-Instruct") - - -class TestTgiPreservesHfModelId(unittest.TestCase): - """Test that _build_for_tgi preserves user-provided HF_MODEL_ID.""" - - @patch("sagemaker.serve.model_builder_servers._get_model_config_properties_from_hf") - @patch("sagemaker.serve.model_builder_servers._get_default_tgi_configurations") - @patch("sagemaker.serve.model_builder_servers._get_nb_instance", return_value=None) - @patch("sagemaker.serve.model_builder_servers._get_gpu_info", return_value=1) - @patch("sagemaker.serve.model_builder_servers._get_default_tensor_parallel_degree", return_value=1) - def test_preserves_user_provided_s3_uri(self, mock_tp, mock_gpu, mock_nb, mock_tgi_config, mock_hf_config): - """User-provided S3 URI for HF_MODEL_ID should not be overwritten.""" - mock_hf_config.return_value = {} - mock_tgi_config.return_value = ({}, 256) - - s3_path = "s3://my-bucket/models/Qwen/" - builder = _create_mock_builder(env_vars={"HF_MODEL_ID": s3_path}) + finally: + for p in self._patches: + p.stop() + assert builder.env_vars["HF_MODEL_ID"] == DEFAULT_MODEL + + +# --------------------------------------------------------------------------- +# TGI +# --------------------------------------------------------------------------- +class TestBuildForTgiHfModelId: + """Test _build_for_tgi preserves user-provided HF_MODEL_ID.""" + + _patches = [ + patch( + "sagemaker.serve.model_builder_servers" + "._get_default_tensor_parallel_degree", + return_value=1, + ), + patch( + "sagemaker.serve.model_builder_servers._get_gpu_info", + return_value=1, + ), + patch( + "sagemaker.serve.model_builder_servers._get_nb_instance", + return_value=None, + ), + patch( + "sagemaker.serve.model_builder_servers" + "._get_default_tgi_configurations", + return_value=({}, 256), + ), + patch( + "sagemaker.serve.model_builder_servers" + "._get_model_config_properties_from_hf", + return_value={}, + ), + patch( + "sagemaker.serve.model_server.tgi" + ".prepare._create_dir_structure", + ), + ] + + def test_preserves_user_provided_s3_uri( + self, mock_builder_with_s3 + ): + builder = mock_builder_with_s3 builder.model_server = ModelServer.TGI - - with patch("sagemaker.serve.model_server.tgi.prepare._create_dir_structure"): + for p in self._patches: + p.start() + try: _ModelBuilderServers._build_for_tgi(builder) - - self.assertEqual(builder.env_vars["HF_MODEL_ID"], s3_path) - - @patch("sagemaker.serve.model_builder_servers._get_model_config_properties_from_hf") - @patch("sagemaker.serve.model_builder_servers._get_default_tgi_configurations") - @patch("sagemaker.serve.model_builder_servers._get_nb_instance", return_value=None) - @patch("sagemaker.serve.model_builder_servers._get_gpu_info", return_value=1) - @patch("sagemaker.serve.model_builder_servers._get_default_tensor_parallel_degree", return_value=1) - def test_sets_hf_model_id_when_not_provided(self, mock_tp, mock_gpu, mock_nb, mock_tgi_config, mock_hf_config): - """HF_MODEL_ID should be set from self.model when user doesn't provide it.""" - mock_hf_config.return_value = {} - mock_tgi_config.return_value = ({}, 256) - - builder = _create_mock_builder(env_vars={}) + finally: + for p in self._patches: + p.stop() + assert builder.env_vars["HF_MODEL_ID"] == S3_PATH + + def test_sets_default_when_not_provided( + self, mock_builder + ): + builder = mock_builder builder.model_server = ModelServer.TGI - - with patch("sagemaker.serve.model_server.tgi.prepare._create_dir_structure"): + for p in self._patches: + p.start() + try: _ModelBuilderServers._build_for_tgi(builder) - - self.assertEqual(builder.env_vars["HF_MODEL_ID"], "Qwen/Qwen3-VL-4B-Instruct") - - -class TestTeiPreservesHfModelId(unittest.TestCase): - """Test that _build_for_tei preserves user-provided HF_MODEL_ID.""" - - @patch("sagemaker.serve.model_builder_servers._get_model_config_properties_from_hf") - @patch("sagemaker.serve.model_builder_servers._get_nb_instance", return_value=None) - def test_preserves_user_provided_s3_uri(self, mock_nb, mock_hf_config): - """User-provided S3 URI for HF_MODEL_ID should not be overwritten.""" - mock_hf_config.return_value = {} - - s3_path = "s3://my-bucket/models/embedding-model/" - builder = _create_mock_builder(env_vars={"HF_MODEL_ID": s3_path}) + finally: + for p in self._patches: + p.stop() + assert builder.env_vars["HF_MODEL_ID"] == DEFAULT_MODEL + + +# --------------------------------------------------------------------------- +# TEI +# --------------------------------------------------------------------------- +class TestBuildForTeiHfModelId: + """Test _build_for_tei preserves user-provided HF_MODEL_ID.""" + + _patches = [ + patch( + "sagemaker.serve.model_builder_servers._get_nb_instance", + return_value=None, + ), + patch( + "sagemaker.serve.model_builder_servers" + "._get_model_config_properties_from_hf", + return_value={}, + ), + patch( + "sagemaker.serve.model_server.tgi" + ".prepare._create_dir_structure", + ), + ] + + def test_preserves_user_provided_s3_uri( + self, mock_builder_with_s3 + ): + builder = mock_builder_with_s3 builder.model_server = ModelServer.TEI - - with patch("sagemaker.serve.model_server.tgi.prepare._create_dir_structure"): + for p in self._patches: + p.start() + try: _ModelBuilderServers._build_for_tei(builder) - - self.assertEqual(builder.env_vars["HF_MODEL_ID"], s3_path) - - @patch("sagemaker.serve.model_builder_servers._get_model_config_properties_from_hf") - @patch("sagemaker.serve.model_builder_servers._get_nb_instance", return_value=None) - def test_sets_hf_model_id_when_not_provided(self, mock_nb, mock_hf_config): - """HF_MODEL_ID should be set from self.model when user doesn't provide it.""" - mock_hf_config.return_value = {} - - builder = _create_mock_builder(env_vars={}) + finally: + for p in self._patches: + p.stop() + assert builder.env_vars["HF_MODEL_ID"] == S3_PATH + + def test_sets_default_when_not_provided( + self, mock_builder + ): + builder = mock_builder builder.model_server = ModelServer.TEI - - with patch("sagemaker.serve.model_server.tgi.prepare._create_dir_structure"): + for p in self._patches: + p.start() + try: _ModelBuilderServers._build_for_tei(builder) - - self.assertEqual(builder.env_vars["HF_MODEL_ID"], "Qwen/Qwen3-VL-4B-Instruct") - - -class TestTorchservePreservesHfModelId(unittest.TestCase): - """Test that _build_for_torchserve preserves user-provided HF_MODEL_ID.""" - - def test_preserves_user_provided_s3_uri(self): - """User-provided S3 URI for HF_MODEL_ID should not be overwritten.""" - s3_path = "s3://my-bucket/models/my-model/" - builder = _create_mock_builder(env_vars={"HF_MODEL_ID": s3_path}) + finally: + for p in self._patches: + p.stop() + assert builder.env_vars["HF_MODEL_ID"] == DEFAULT_MODEL + + +# --------------------------------------------------------------------------- +# TorchServe +# --------------------------------------------------------------------------- +class TestBuildForTorchserveHfModelId: + """Test _build_for_torchserve preserves user-provided HF_MODEL_ID.""" + + def test_preserves_user_provided_s3_uri( + self, mock_builder_with_s3 + ): + builder = mock_builder_with_s3 builder.model_server = ModelServer.TORCHSERVE - builder.mode = Mode.SAGEMAKER_ENDPOINT builder._save_model_inference_spec = Mock() - _ModelBuilderServers._build_for_torchserve(builder) + assert builder.env_vars["HF_MODEL_ID"] == S3_PATH - self.assertEqual(builder.env_vars["HF_MODEL_ID"], s3_path) - - def test_sets_hf_model_id_when_not_provided(self): - """HF_MODEL_ID should be set from self.model when user doesn't provide it.""" - builder = _create_mock_builder(env_vars={}) + def test_sets_default_when_not_provided( + self, mock_builder + ): + builder = mock_builder builder.model_server = ModelServer.TORCHSERVE - builder.mode = Mode.SAGEMAKER_ENDPOINT builder._save_model_inference_spec = Mock() - _ModelBuilderServers._build_for_torchserve(builder) - - self.assertEqual(builder.env_vars["HF_MODEL_ID"], "Qwen/Qwen3-VL-4B-Instruct") + assert builder.env_vars["HF_MODEL_ID"] == DEFAULT_MODEL -class TestTritonPreservesHfModelId(unittest.TestCase): - """Test that _build_for_triton preserves user-provided HF_MODEL_ID.""" +# --------------------------------------------------------------------------- +# Triton +# --------------------------------------------------------------------------- +class TestBuildForTritonHfModelId: + """Test _build_for_triton preserves user-provided HF_MODEL_ID.""" - def test_preserves_user_provided_s3_uri(self): - """User-provided S3 URI for HF_MODEL_ID should not be overwritten.""" - s3_path = "s3://my-bucket/models/my-model/" - builder = _create_mock_builder(env_vars={"HF_MODEL_ID": s3_path}) + def test_preserves_user_provided_s3_uri( + self, mock_builder_with_s3 + ): + builder = mock_builder_with_s3 builder.model_server = ModelServer.TRITON builder._save_inference_spec = Mock() builder._prepare_for_triton = Mock() builder._auto_detect_image_for_triton = Mock() - _ModelBuilderServers._build_for_triton(builder) + assert builder.env_vars["HF_MODEL_ID"] == S3_PATH - self.assertEqual(builder.env_vars["HF_MODEL_ID"], s3_path) - - def test_sets_hf_model_id_when_not_provided(self): - """HF_MODEL_ID should be set from self.model when user doesn't provide it.""" - builder = _create_mock_builder(env_vars={}) + def test_sets_default_when_not_provided( + self, mock_builder + ): + builder = mock_builder builder.model_server = ModelServer.TRITON builder._save_inference_spec = Mock() builder._prepare_for_triton = Mock() builder._auto_detect_image_for_triton = Mock() - _ModelBuilderServers._build_for_triton(builder) - - self.assertEqual(builder.env_vars["HF_MODEL_ID"], "Qwen/Qwen3-VL-4B-Instruct") - - -class TestTransformersPreservesHfModelId(unittest.TestCase): - """Test that _build_for_transformers preserves user-provided HF_MODEL_ID.""" - - @patch("sagemaker.serve.model_builder_servers._get_model_config_properties_from_hf") - @patch("sagemaker.serve.model_builder_servers._get_nb_instance", return_value=None) - def test_preserves_user_provided_s3_uri_with_model_string(self, mock_nb, mock_hf_config): - """User-provided S3 URI for HF_MODEL_ID should not be overwritten when model is a string.""" - mock_hf_config.return_value = {} - - s3_path = "s3://my-bucket/models/my-model/" - builder = _create_mock_builder(env_vars={"HF_MODEL_ID": s3_path}) + assert builder.env_vars["HF_MODEL_ID"] == DEFAULT_MODEL + + +# --------------------------------------------------------------------------- +# Transformers (MMS) +# --------------------------------------------------------------------------- +class TestBuildForTransformersHfModelId: + """Test _build_for_transformers preserves user-provided HF_MODEL_ID.""" + + _patches = [ + patch( + "sagemaker.serve.model_builder_servers._get_nb_instance", + return_value=None, + ), + patch( + "sagemaker.serve.model_builder_servers" + "._get_model_config_properties_from_hf", + return_value={}, + ), + patch( + "sagemaker.serve.model_server.multi_model_server" + ".prepare._create_dir_structure", + ), + ] + + def test_preserves_user_provided_s3_uri( + self, mock_builder_with_s3 + ): + builder = mock_builder_with_s3 builder.model_server = ModelServer.MMS - builder.mode = Mode.SAGEMAKER_ENDPOINT builder.model_data_download_timeout = None - - with patch("sagemaker.serve.model_server.multi_model_server.prepare._create_dir_structure"): + for p in self._patches: + p.start() + try: _ModelBuilderServers._build_for_transformers(builder) - - self.assertEqual(builder.env_vars["HF_MODEL_ID"], s3_path) - - @patch("sagemaker.serve.model_builder_servers._get_model_config_properties_from_hf") - @patch("sagemaker.serve.model_builder_servers._get_nb_instance", return_value=None) - def test_sets_hf_model_id_when_not_provided_with_model_string(self, mock_nb, mock_hf_config): - """HF_MODEL_ID should be set from self.model when user doesn't provide it.""" - mock_hf_config.return_value = {} - - builder = _create_mock_builder(env_vars={}) + finally: + for p in self._patches: + p.stop() + assert builder.env_vars["HF_MODEL_ID"] == S3_PATH + + def test_sets_default_when_not_provided( + self, mock_builder + ): + builder = mock_builder builder.model_server = ModelServer.MMS - builder.mode = Mode.SAGEMAKER_ENDPOINT builder.model_data_download_timeout = None - - with patch("sagemaker.serve.model_server.multi_model_server.prepare._create_dir_structure"): + for p in self._patches: + p.start() + try: _ModelBuilderServers._build_for_transformers(builder) + finally: + for p in self._patches: + p.stop() + assert builder.env_vars["HF_MODEL_ID"] == DEFAULT_MODEL - self.assertEqual(builder.env_vars["HF_MODEL_ID"], "Qwen/Qwen3-VL-4B-Instruct") - - @patch("sagemaker.serve.model_builder_servers._get_model_config_properties_from_hf") - @patch("sagemaker.serve.model_builder_servers._get_nb_instance", return_value=None) @patch("sagemaker.serve.model_builder_servers.save_pkl") - def test_preserves_user_provided_hf_model_id_with_inference_spec(self, mock_pkl, mock_nb, mock_hf_config): - """User-provided HF_MODEL_ID should not be overwritten when inference_spec provides a model ID.""" - mock_hf_config.return_value = {} - - s3_path = "s3://my-bucket/models/my-model/" - builder = _create_mock_builder(env_vars={"HF_MODEL_ID": s3_path}) + @patch( + "sagemaker.serve.model_builder_servers" + "._get_model_config_properties_from_hf", + return_value={}, + ) + @patch( + "sagemaker.serve.model_builder_servers._get_nb_instance", + return_value=None, + ) + @patch( + "sagemaker.serve.model_server.multi_model_server" + ".prepare._create_dir_structure", + ) + @patch("os.makedirs") + def test_preserves_with_inference_spec( + self, + _mock_makedirs, + _mock_dir, + _mock_nb, + _mock_hf, + _mock_pkl, + ): + """User-provided HF_MODEL_ID preserved with inference_spec.""" + builder = _create_mock_builder( + env_vars={"HF_MODEL_ID": S3_PATH} + ) builder.model_server = ModelServer.MMS - builder.mode = Mode.SAGEMAKER_ENDPOINT builder.model_data_download_timeout = None - builder.model = None # No model string, using inference_spec + builder.model = None builder.inference_spec = Mock() - builder.inference_spec.get_model.return_value = "some-hf-model-id" - builder._is_jumpstart_model_id = Mock(return_value=False) - - with patch("sagemaker.serve.model_server.multi_model_server.prepare._create_dir_structure"): - with patch("os.makedirs"): - _ModelBuilderServers._build_for_transformers(builder) - - self.assertEqual(builder.env_vars["HF_MODEL_ID"], s3_path) - - -if __name__ == "__main__": - unittest.main() + builder.inference_spec.get_model.return_value = ( + "some-hf-model-id" + ) + builder._is_jumpstart_model_id = Mock( + return_value=False + ) + _ModelBuilderServers._build_for_transformers(builder) + assert builder.env_vars["HF_MODEL_ID"] == S3_PATH