Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion sagemaker-serve/src/sagemaker/serve/model_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -2331,7 +2331,9 @@ def _build_single_modelbuilder(
)
model_package = self._fetch_model_package()
# Fetch recipe config first to set image_uri, instance_type, env_vars, and s3_upload_path
self._fetch_and_cache_recipe_config()
base_model = model_package.inference_specification.containers[0].base_model
if base_model is not None:
self._fetch_and_cache_recipe_config()

# Nova models use a completely different deployment architecture
if self._is_nova_model():
Expand Down
41 changes: 41 additions & 0 deletions sagemaker-serve/tests/unit/test_model_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,6 +462,47 @@ def test_build_single_modelbuilder_with_model_customization(self, mock_is_1p, mo
mock_model_class.create.assert_called_once()
self.assertEqual(result, mock_created_model)

@patch('sagemaker.serve.model_builder.Model')
@patch('sagemaker.serve.model_builder.is_1p_image_uri')
def test_build_single_modelbuilder_with_model_customization_no_jumpstart(self, mock_is_1p, mock_model_class):
"""Test _build_single_modelbuilder skips _fetch_and_cache_recipe_config when base_model is None."""
mock_is_1p.return_value = True

# Setup mock model package with base_model = None (custom model package, not JumpStart)
mock_model_package = Mock()
mock_container = Mock()
mock_container.base_model = None
mock_container.model_data_source.s3_data_source.s3_uri = "s3://bucket/model"
mock_model_package.inference_specification.containers = [mock_container]

# Setup training job with model_package_config
self.mock_training_job.model_package_config = Mock()
self.mock_training_job.model_package_config.source_model_package_arn = (
"arn:aws:sagemaker:us-east-1:123456789012:model-package/source"
)

mock_created_model = Mock()
mock_model_class.create.return_value = mock_created_model

builder = ModelBuilder(
model=self.mock_training_job,
role_arn="arn:aws:iam::123456789012:role/SageMakerRole",
sagemaker_session=self.mock_session,
image_uri="test-image:latest",
instance_type="ml.g5.2xlarge"
)

with patch.object(builder, '_fetch_model_package', return_value=mock_model_package):
with patch.object(builder, '_fetch_and_cache_recipe_config') as mock_recipe:
with patch.object(builder, '_get_serve_setting', return_value=Mock()):
with patch.object(builder, '_is_nova_model', return_value=False):
with patch.object(builder, '_fetch_peft', return_value=None):
result = builder._build_single_modelbuilder()

mock_recipe.assert_not_called()
mock_model_class.create.assert_called_once()
self.assertEqual(result, mock_created_model)

def test_deploy_model_customization_new_endpoint(self):
"""Test _deploy_model_customization for new endpoint creation."""
from sagemaker.core.shapes import InferenceComponentComputeResourceRequirements
Expand Down
Loading