Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion sagemaker-core/src/sagemaker/core/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,7 @@ def pascal_to_snake(pascal_str):


def is_not_primitive(obj):
return not isinstance(obj, (int, float, str, bool, datetime.datetime))
return not isinstance(obj, (int, float, str, bool, datetime.datetime, bytes))


def is_not_str_dict(obj):
Expand Down
5 changes: 5 additions & 0 deletions sagemaker-serve/src/sagemaker/serve/builder/schema_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,11 @@ def _get_deserializer(self, obj):
return StringDeserializer()
if _is_jsonable(obj):
return JSONDeserializer()
if isinstance(obj, dict) and "content_type" in obj:
try:
return BytesDeserializer()
except ValueError as e:
logger.error(e)

raise ValueError(
(
Expand Down
13 changes: 13 additions & 0 deletions sagemaker-serve/src/sagemaker/serve/model_builder_servers.py
Original file line number Diff line number Diff line change
Expand Up @@ -687,7 +687,20 @@ def _build_for_transformers(self) -> Model:
hf_model_id, self.env_vars.get("HUGGING_FACE_HUB_TOKEN")
)
elif isinstance(self.model, str): # Only set HF_MODEL_ID if model is a string
# Get model metadata for task detection
hf_model_md = self.get_huggingface_model_metadata(
self.model, self.env_vars.get("HUGGING_FACE_HUB_TOKEN")
)
model_task = hf_model_md.get("pipeline_tag")
if model_task:
self.env_vars.update({"HF_TASK": model_task})

self.env_vars.update({"HF_MODEL_ID": self.model})

# Add HuggingFace token if available
if self.env_vars.get("HUGGING_FACE_HUB_TOKEN"):
self.env_vars["HF_TOKEN"] = self.env_vars.get("HUGGING_FACE_HUB_TOKEN")

# Get HF config for string model IDs
if hasattr(self.env_vars, "HF_API_TOKEN"):
self.hf_model_config = _get_model_config_properties_from_hf(
Expand Down
5 changes: 5 additions & 0 deletions sagemaker-serve/src/sagemaker/serve/model_builder_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1004,6 +1004,11 @@ def _hf_schema_builder_init(self, model_task: str) -> None:
sample_inputs,
sample_outputs,
) = remote_hf_schema_helper.get_resolved_hf_schema_for_task(model_task)
# Unwrap list outputs for binary tasks (text-to-image, audio, etc.)
# Remote schema retriever returns [{'data': b'...', 'content_type': '...'}]
# but SchemaBuilder expects {'data': b'...', 'content_type': '...'}
if isinstance(sample_outputs, list) and len(sample_outputs) > 0:
sample_outputs = sample_outputs[0]

self.schema_builder = SchemaBuilder(sample_inputs, sample_outputs)

Expand Down
52 changes: 31 additions & 21 deletions sagemaker-serve/src/sagemaker/serve/utils/hf_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Utility functions for fetching model information from HuggingFace Hub"""

from __future__ import absolute_import
import json
import urllib.request
Expand All @@ -24,30 +25,39 @@
def _get_model_config_properties_from_hf(model_id: str, hf_hub_token: str = None):
"""Placeholder docstring"""

config_url = f"https://huggingface.co/{model_id}/raw/main/config.json"
config_files = ["config.json", "model_index.json", "adapter_config.json"]

model_config = None
try:
if hf_hub_token:
config_url = urllib.request.Request(
config_url, headers={"Authorization": "Bearer " + hf_hub_token}
)
with urllib.request.urlopen(config_url) as response:
model_config = json.load(response)
except (HTTPError, URLError, TimeoutError, JSONDecodeError) as e:
if "HTTP Error 401: Unauthorized" in str(e):
raise ValueError(
"Trying to access a gated/private HuggingFace model without valid credentials. "
"Please provide a HUGGING_FACE_HUB_TOKEN in env_vars"
for config_file in config_files:
config_url = f"https://huggingface.co/{model_id}/raw/main/{config_file}"
request = config_url

try:
if hf_hub_token:
request = urllib.request.Request(
config_url, headers={"Authorization": "Bearer " + hf_hub_token}
)

with urllib.request.urlopen(request) as response:
model_config = json.load(response)
break
except (HTTPError, URLError, TimeoutError, JSONDecodeError) as e:
if "HTTP Error 401: Unauthorized" in str(e):
raise ValueError(
"Trying to access a gated/private HuggingFace model without valid credentials. "
"Please provide a HUGGING_FACE_HUB_TOKEN in env_vars"
)

logger.warning(
"Exception encountered while trying to read config file %s. Details: %s",
config_url,
e,
)
logger.warning(
"Exception encountered while trying to read config file %s. " "Details: %s",
config_url,
e,
)

if not model_config:
allowed_files = ", ".join(config_files)
raise ValueError(
f"Did not find a config.json or model_index.json file in huggingface hub for "
f"{model_id}. Please make sure a config.json exists (or model_index.json for Stable "
f"Diffusion Models) for this model in the huggingface hub"
f"Did not find any supported model config file in Hugging Face Hub for {model_id}. "
f"Expected one of: {allowed_files}"
)
return model_config
Original file line number Diff line number Diff line change
Expand Up @@ -781,6 +781,10 @@ def test_build_with_hf_model_string(
result = self.builder._build_for_transformers()

self.assertEqual(self.builder.env_vars["HF_MODEL_ID"], "gpt2")
mock_hf_config.assert_called_once_with(
"gpt2",
"token",
)
mock_create.assert_called_once()

@patch("sagemaker.serve.model_builder_servers._get_nb_instance")
Expand Down
102 changes: 90 additions & 12 deletions sagemaker-serve/tests/unit/utils/test_hf_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,9 +75,9 @@ def test_get_model_config_http_error(self, mock_logger, mock_urlopen):

with self.assertRaises(ValueError) as context:
_get_model_config_properties_from_hf("non-existent-model")
self.assertIn("Did not find a config.json", str(context.exception))
mock_logger.warning.assert_called_once()

self.assertIn("Did not find any supported model config file", str(context.exception))
self.assertEqual(mock_logger.warning.call_count, 3)

@patch('urllib.request.urlopen')
@patch('sagemaker.serve.utils.hf_utils.logger')
Expand All @@ -87,9 +87,9 @@ def test_get_model_config_url_error(self, mock_logger, mock_urlopen):

with self.assertRaises(ValueError) as context:
_get_model_config_properties_from_hf("model-id")
self.assertIn("Did not find a config.json", str(context.exception))
mock_logger.warning.assert_called_once()

self.assertIn("Did not find any supported model config file", str(context.exception))
self.assertEqual(mock_logger.warning.call_count, 3)

@patch('urllib.request.urlopen')
@patch('sagemaker.serve.utils.hf_utils.logger')
Expand All @@ -99,9 +99,9 @@ def test_get_model_config_timeout_error(self, mock_logger, mock_urlopen):

with self.assertRaises(ValueError) as context:
_get_model_config_properties_from_hf("model-id")
self.assertIn("Did not find a config.json", str(context.exception))
mock_logger.warning.assert_called_once()

self.assertIn("Did not find any supported model config file", str(context.exception))
self.assertEqual(mock_logger.warning.call_count, 3)

@patch('urllib.request.urlopen')
@patch('sagemaker.serve.utils.hf_utils.logger')
Expand All @@ -115,9 +115,9 @@ def test_get_model_config_json_decode_error(self, mock_logger, mock_urlopen):
with patch('json.load', side_effect=JSONDecodeError("msg", "doc", 0)):
with self.assertRaises(ValueError) as context:
_get_model_config_properties_from_hf("model-id")
self.assertIn("Did not find a config.json", str(context.exception))
mock_logger.warning.assert_called_once()

self.assertIn("Did not find any supported model config file", str(context.exception))
self.assertEqual(mock_logger.warning.call_count, 3)

@patch('urllib.request.urlopen')
def test_get_model_config_url_format(self, mock_urlopen):
Expand All @@ -137,6 +137,84 @@ def test_get_model_config_url_format(self, mock_urlopen):
actual_url = mock_urlopen.call_args[0][0]
self.assertEqual(actual_url, expected_url)

@patch("urllib.request.urlopen")
def test_get_model_config_falls_back_to_model_index(self, mock_urlopen):
    """Test fallback to model_index.json when config.json is missing."""
    missing_config = HTTPError(
        "https://huggingface.co/org/model/raw/main/config.json", 404, "Not Found", {}, None
    )
    expected_config = {"_class_name": "FluxPipeline", "_diffusers_version": "0.31.0"}

    # Context-manager mock standing in for the successful model_index.json response.
    index_response = Mock()
    index_response.__enter__ = Mock(return_value=index_response)
    index_response.__exit__ = Mock(return_value=False)

    def _fake_urlopen(request):
        # The helper passes a Request object when a token is set, else a bare URL string.
        url = getattr(request, "full_url", request)
        if url.endswith("/config.json"):
            raise missing_config
        if url.endswith("/model_index.json"):
            return index_response
        raise AssertionError(f"Unexpected URL called: {url}")

    mock_urlopen.side_effect = _fake_urlopen

    # json.load is patched so the mock response body never has to be real JSON.
    with patch("json.load", side_effect=[expected_config]):
        result = _get_model_config_properties_from_hf("org/model-name")

    self.assertEqual(result, expected_config)

@patch("urllib.request.urlopen")
@patch("sagemaker.serve.utils.hf_utils.logger")
def test_get_model_config_dual_file_error_when_both_missing(self, mock_logger, mock_urlopen):
    """Test error when all known config files are missing."""
    # Every candidate config URL 404s, so the helper must exhaust its list and raise.
    mock_urlopen.side_effect = HTTPError("url", 404, "Not Found", {}, None)

    with self.assertRaises(ValueError) as raised:
        _get_model_config_properties_from_hf("model-id")

    # One fetch attempt and one warning per candidate config file.
    self.assertEqual(mock_urlopen.call_count, 3)
    self.assertEqual(mock_logger.warning.call_count, 3)
    self.assertIn(
        "Expected one of: config.json, model_index.json, adapter_config.json",
        str(raised.exception),
    )

@patch("urllib.request.urlopen")
def test_get_model_config_falls_back_to_adapter_config(self, mock_urlopen):
    """Test fallback to adapter_config.json when config/model_index are missing."""
    errors_by_suffix = {
        # Both earlier candidates 404 so the helper falls through to adapter_config.json.
        "/config.json": HTTPError(
            "https://huggingface.co/org/model/raw/main/config.json", 404, "Not Found", {}, None
        ),
        "/model_index.json": HTTPError(
            "https://huggingface.co/org/model/raw/main/model_index.json", 404, "Not Found", {}, None
        ),
    }
    expected_config = {
        "base_model_name_or_path": "LiquidAI/LFM2.5-1.2B-Instruct",
        "peft_type": "LORA",
    }

    # Context-manager mock standing in for the successful adapter_config.json response.
    adapter_response = Mock()
    adapter_response.__enter__ = Mock(return_value=adapter_response)
    adapter_response.__exit__ = Mock(return_value=False)

    def _fake_urlopen(request):
        # The helper passes a Request object when a token is set, else a bare URL string.
        url = getattr(request, "full_url", request)
        for suffix, err in errors_by_suffix.items():
            if url.endswith(suffix):
                raise err
        if url.endswith("/adapter_config.json"):
            return adapter_response
        raise AssertionError(f"Unexpected URL called: {url}")

    mock_urlopen.side_effect = _fake_urlopen

    # json.load is patched so the mock response body never has to be real JSON.
    with patch("json.load", side_effect=[expected_config]):
        result = _get_model_config_properties_from_hf("org/model-name")

    self.assertEqual(result, expected_config)


# Allow running this test module directly (outside a pytest/unittest discovery run).
if __name__ == "__main__":
    unittest.main()
Loading