Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions sagemaker-core/src/sagemaker/core/modules/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

from sagemaker.core.shapes import Unassigned
from sagemaker.core.modules import logger
from sagemaker.core.helper.pipeline_variable import PipelineVariable


def _is_valid_s3_uri(path: str, path_type: Literal["File", "Directory", "Any"] = "Any") -> bool:
Expand Down Expand Up @@ -129,19 +130,24 @@ def safe_serialize(data):

This function handles the following cases:
1. If `data` is a string, it returns the string as-is without wrapping in quotes.
2. If `data` is serializable (e.g., a dictionary, list, int, float), it returns
2. If `data` is of type `PipelineVariable`, it returns the PipelineVariable object
as-is for pipeline serialization.
3. If `data` is serializable (e.g., a dictionary, list, int, float), it returns
the JSON-encoded string using `json.dumps()`.
3. If `data` cannot be serialized (e.g., a custom object), it returns the string
4. If `data` cannot be serialized (e.g., a custom object), it returns the string
representation of the data using `str(data)`.

Args:
data (Any): The data to serialize.

Returns:
str: The serialized JSON-compatible string or the string representation of the input.
str | PipelineVariable: The serialized JSON-compatible string, the string
representation of the input, or the PipelineVariable object as-is.
"""
if isinstance(data, str):
Comment thread
aviruthen marked this conversation as resolved.
Comment thread
aviruthen marked this conversation as resolved.
return data
elif isinstance(data, PipelineVariable):
return data
try:
return json.dumps(data)
except TypeError:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Tests for safe_serialize in sagemaker.core.modules.utils with PipelineVariable support.

Verifies that safe_serialize correctly handles PipelineVariable objects
(e.g., ParameterInteger, ParameterString) by returning them as-is rather
than attempting str() conversion which would raise TypeError.

See: https://github.com/aws/sagemaker-python-sdk/issues/5504
"""
from __future__ import annotations

import pytest

Comment thread
aviruthen marked this conversation as resolved.
from sagemaker.core.modules.utils import safe_serialize
from sagemaker.core.helper.pipeline_variable import PipelineVariable
from sagemaker.core.workflow.parameters import ParameterInteger, ParameterString
Comment thread
aviruthen marked this conversation as resolved.
Comment thread
aviruthen marked this conversation as resolved.


class TestSafeSerializeWithPipelineVariables:
"""Test safe_serialize handles PipelineVariable objects correctly."""

@pytest.mark.parametrize("param", [
ParameterInteger(name="MaxDepth", default_value=5),
ParameterString(name="Algorithm", default_value="xgboost"),
])
def test_safe_serialize_returns_pipeline_variable_as_is(self, param):
"""PipelineVariable objects should be returned as-is (identity preserved)."""
result = safe_serialize(param)
assert result is param
assert isinstance(result, PipelineVariable)

def test_pipeline_variable_str_raises_type_error(self):
"""Confirm PipelineVariable.__str__ raises TypeError (the root cause of the bug)."""
param = ParameterInteger(name="TestParam", default_value=10)
with pytest.raises(TypeError):
str(param)


class TestSafeSerializeBasicTypes:
"""Regression tests: verify basic types still work after PipelineVariable support."""

def test_safe_serialize_with_string(self):
"""Strings should be returned as-is without JSON wrapping."""
assert safe_serialize("hello") == "hello"

def test_safe_serialize_with_int(self):
"""Integers should be JSON-serialized to string."""
assert safe_serialize(42) == "42"

def test_safe_serialize_with_dict(self):
"""Dicts should be JSON-serialized."""
result = safe_serialize({"key": "val"})
assert result == '{"key": "val"}'

def test_safe_serialize_with_bool(self):
"""Booleans should be JSON-serialized."""
assert safe_serialize(True) == "true"
assert safe_serialize(False) == "false"

def test_safe_serialize_with_none(self):
"""None should be JSON-serialized to 'null'."""
assert safe_serialize(None) == "null"

def test_safe_serialize_with_custom_object(self):
"""Custom objects should fall back to str()."""

class CustomObj:
def __str__(self):
return "custom"

assert safe_serialize(CustomObj()) == "custom"
Comment thread
aviruthen marked this conversation as resolved.
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@

from sagemaker.core.helper.session_helper import Session
from sagemaker.core.helper.pipeline_variable import PipelineVariable, StrPipeVar
from sagemaker.core.workflow.parameters import ParameterString
from sagemaker.core.workflow.parameters import ParameterString, ParameterInteger
from sagemaker.train.model_trainer import ModelTrainer, Mode
from sagemaker.train.configs import (
Compute,
Expand Down Expand Up @@ -176,3 +176,68 @@ def test_training_image_rejects_invalid_type(self):
stopping_condition=DEFAULT_STOPPING,
output_data_config=DEFAULT_OUTPUT,
)


class TestModelTrainerPipelineVariableHyperparameters:
"""Test that PipelineVariable objects work correctly in ModelTrainer hyperparameters."""

def test_hyperparameters_with_parameter_integer(self):
"""ParameterInteger in hyperparameters should be preserved through _create_training_job_args.

This test documents the exact bug scenario from GH#5504: safe_serialize
would fall back to str(data) for PipelineVariable objects, but
PipelineVariable.__str__ intentionally raises TypeError.
Before the fix, this call would have raised TypeError.
"""
max_depth = ParameterInteger(name="MaxDepth", default_value=5)
trainer = ModelTrainer(
training_image=DEFAULT_IMAGE,
role=DEFAULT_ROLE,
compute=DEFAULT_COMPUTE,
stopping_condition=DEFAULT_STOPPING,
Comment thread
aviruthen marked this conversation as resolved.
output_data_config=DEFAULT_OUTPUT,
hyperparameters={"max_depth": max_depth},
)
# This call would have raised TypeError before the fix (GH#5504)
args = trainer._create_training_job_args()
# PipelineVariable should be preserved as-is, not stringified
Comment thread
aviruthen marked this conversation as resolved.
assert args["hyper_parameters"]["max_depth"] is max_depth

def test_hyperparameters_with_parameter_string(self):
Comment thread
aviruthen marked this conversation as resolved.
"""ParameterString in hyperparameters should be preserved."""
algo = ParameterString(name="Algorithm", default_value="xgboost")
trainer = ModelTrainer(
training_image=DEFAULT_IMAGE,
role=DEFAULT_ROLE,
compute=DEFAULT_COMPUTE,
stopping_condition=DEFAULT_STOPPING,
output_data_config=DEFAULT_OUTPUT,
hyperparameters={"algorithm": algo},
)
args = trainer._create_training_job_args()
assert args["hyper_parameters"]["algorithm"] is algo

def test_hyperparameters_with_mixed_pipeline_and_static_values(self):
"""Mixed PipelineVariable and static values should both be handled correctly."""
max_depth = ParameterInteger(name="MaxDepth", default_value=5)
trainer = ModelTrainer(
training_image=DEFAULT_IMAGE,
role=DEFAULT_ROLE,
compute=DEFAULT_COMPUTE,
stopping_condition=DEFAULT_STOPPING,
output_data_config=DEFAULT_OUTPUT,
hyperparameters={
"max_depth": max_depth,
"eta": 0.1,
"objective": "binary:logistic",
"num_round": 100,
},
)
args = trainer._create_training_job_args()
hp = args["hyper_parameters"]
# PipelineVariable preserved as-is
assert hp["max_depth"] is max_depth
# Static values serialized to strings
assert hp["eta"] == "0.1"
assert hp["objective"] == "binary:logistic"
assert hp["num_round"] == "100"
92 changes: 92 additions & 0 deletions sagemaker-train/tests/unit/train/test_safe_serialize.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Tests for safe_serialize with PipelineVariable support.

Verifies that safe_serialize in sagemaker.train.utils correctly handles
PipelineVariable objects (e.g., ParameterInteger, ParameterString) by
returning them as-is rather than attempting str() conversion which would
raise TypeError.

See: https://github.com/aws/sagemaker-python-sdk/issues/5504
"""
from __future__ import annotations

import pytest

Comment thread
aviruthen marked this conversation as resolved.
from sagemaker.train.utils import safe_serialize
from sagemaker.core.helper.pipeline_variable import PipelineVariable
from sagemaker.core.workflow.parameters import ParameterInteger, ParameterString


class TestSafeSerializeWithPipelineVariables:
"""Test safe_serialize handles PipelineVariable objects correctly."""

@pytest.mark.parametrize("param", [
Comment thread
aviruthen marked this conversation as resolved.
ParameterInteger(name="MaxDepth", default_value=5),
ParameterString(name="Algorithm", default_value="xgboost"),
])
def test_safe_serialize_returns_pipeline_variable_as_is(self, param):
"""PipelineVariable objects should be returned as-is (identity preserved)."""
result = safe_serialize(param)
assert result is param
assert isinstance(result, PipelineVariable)

def test_pipeline_variable_str_raises_type_error(self):
"""Confirm PipelineVariable.__str__ raises TypeError (the root cause of the bug)."""
param = ParameterInteger(name="TestParam", default_value=10)
with pytest.raises(TypeError):
str(param)


class TestSafeSerializeBasicTypes:
"""Regression tests: verify basic types still work after PipelineVariable support."""

def test_safe_serialize_with_string(self):
"""Strings should be returned as-is without JSON wrapping."""
assert safe_serialize("hello") == "hello"
assert safe_serialize("12345") == "12345"

def test_safe_serialize_with_int(self):
"""Integers should be JSON-serialized to string."""
assert safe_serialize(42) == "42"

def test_safe_serialize_with_float(self):
"""Floats should be JSON-serialized to string."""
assert safe_serialize(3.14) == "3.14"

def test_safe_serialize_with_dict(self):
"""Dicts should be JSON-serialized."""
result = safe_serialize({"key": "val"})
assert result == '{"key": "val"}'

def test_safe_serialize_with_bool(self):
"""Booleans should be JSON-serialized."""
assert safe_serialize(True) == "true"
assert safe_serialize(False) == "false"

def test_safe_serialize_with_none(self):
"""None should be JSON-serialized to 'null'."""
assert safe_serialize(None) == "null"

def test_safe_serialize_with_list(self):
"""Lists should be JSON-serialized."""
assert safe_serialize([1, 2, 3]) == "[1, 2, 3]"

def test_safe_serialize_with_custom_object(self):
"""Custom objects should fall back to str()."""

class CustomObj:
def __str__(self):
return "custom"

assert safe_serialize(CustomObj()) == "custom"
Loading