From c517f249c2a891d4ba554213851f482ee79db640 Mon Sep 17 00:00:00 2001 From: aviruthen <91846056+aviruthen@users.noreply.github.com> Date: Tue, 7 Apr 2026 14:18:26 -0400 Subject: [PATCH 1/4] fix: [Bug] Pipeline parameters (ParameterInteger, ParameterString) fail in ModelTrain (5504) --- .../src/sagemaker/core/modules/utils.py | 9 +- .../core/modules/test_utils_safe_serialize.py | 87 +++++++++++++++++ .../test_model_trainer_pipeline_variable.py | 60 +++++++++++- .../tests/unit/train/test_safe_serialize.py | 97 +++++++++++++++++++ 4 files changed, 250 insertions(+), 3 deletions(-) create mode 100644 sagemaker-core/tests/unit/core/modules/test_utils_safe_serialize.py create mode 100644 sagemaker-train/tests/unit/train/test_safe_serialize.py diff --git a/sagemaker-core/src/sagemaker/core/modules/utils.py b/sagemaker-core/src/sagemaker/core/modules/utils.py index 94dc2dff22..d1a0a3ea09 100644 --- a/sagemaker-core/src/sagemaker/core/modules/utils.py +++ b/sagemaker-core/src/sagemaker/core/modules/utils.py @@ -24,6 +24,7 @@ from sagemaker.core.shapes import Unassigned from sagemaker.core.modules import logger +from sagemaker.core.helper.pipeline_variable import PipelineVariable def _is_valid_s3_uri(path: str, path_type: Literal["File", "Directory", "Any"] = "Any") -> bool: @@ -129,9 +130,11 @@ def safe_serialize(data): This function handles the following cases: 1. If `data` is a string, it returns the string as-is without wrapping in quotes. - 2. If `data` is serializable (e.g., a dictionary, list, int, float), it returns + 2. If `data` is of type `PipelineVariable`, it returns the PipelineVariable object + as-is for pipeline serialization. + 3. If `data` is serializable (e.g., a dictionary, list, int, float), it returns the JSON-encoded string using `json.dumps()`. - 3. If `data` cannot be serialized (e.g., a custom object), it returns the string + 4. If `data` cannot be serialized (e.g., a custom object), it returns the string representation of the data using `str(data)`. Args: @@ -142,6 +145,8 @@ def safe_serialize(data): """ if isinstance(data, str): return data + elif isinstance(data, PipelineVariable): + return data try: return json.dumps(data) except TypeError: diff --git a/sagemaker-core/tests/unit/core/modules/test_utils_safe_serialize.py b/sagemaker-core/tests/unit/core/modules/test_utils_safe_serialize.py new file mode 100644 index 0000000000..57a8fe2a5d --- /dev/null +++ b/sagemaker-core/tests/unit/core/modules/test_utils_safe_serialize.py @@ -0,0 +1,87 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Tests for safe_serialize in sagemaker.core.modules.utils with PipelineVariable support. + +Verifies that safe_serialize correctly handles PipelineVariable objects +(e.g., ParameterInteger, ParameterString) by returning them as-is rather +than attempting str() conversion which would raise TypeError. + +See: https://github.com/aws/sagemaker-python-sdk/issues/5504 +""" +from __future__ import absolute_import + +import pytest + +from sagemaker.core.modules.utils import safe_serialize +from sagemaker.core.helper.pipeline_variable import PipelineVariable +from sagemaker.core.workflow.parameters import ParameterInteger, ParameterString + + +class TestSafeSerializeWithPipelineVariables: + """Test safe_serialize handles PipelineVariable objects correctly.""" + + def test_safe_serialize_with_parameter_integer(self): + """ParameterInteger should be returned as-is (identity preserved).""" + param = ParameterInteger(name="MaxDepth", default_value=5) + result = safe_serialize(param) + assert result is param + assert isinstance(result, PipelineVariable) + + def test_safe_serialize_with_parameter_string(self): + """ParameterString should be returned as-is (identity preserved).""" + param = ParameterString(name="Algorithm", default_value="xgboost") + result = safe_serialize(param) + assert result is param + assert isinstance(result, PipelineVariable) + + def test_safe_serialize_does_not_call_str_on_pipeline_variable(self): + """Verify that PipelineVariable.__str__ is never invoked (would raise TypeError).""" + param = ParameterInteger(name="TestParam", default_value=10) + # This should NOT raise TypeError + result = safe_serialize(param) + assert result is param + + +class TestSafeSerializeBasicTypes: + """Regression tests: verify basic types still work after PipelineVariable support.""" + + def test_safe_serialize_with_string(self): + """Strings should be returned as-is without JSON wrapping.""" + assert safe_serialize("hello") == "hello" + + def test_safe_serialize_with_int(self): + """Integers should be JSON-serialized to string.""" + assert safe_serialize(42) == "42" + + def test_safe_serialize_with_dict(self): + """Dicts should be JSON-serialized.""" + result = safe_serialize({"key": "val"}) + assert result == '{"key": "val"}' + + def test_safe_serialize_with_bool(self): + """Booleans should be JSON-serialized.""" + assert safe_serialize(True) == "true" + assert safe_serialize(False) == "false" + + def test_safe_serialize_with_none(self): + """None should be JSON-serialized to 'null'.""" + assert safe_serialize(None) == "null" + + def test_safe_serialize_with_custom_object(self): + """Custom objects should fall back to str().""" + + class CustomObj: + def __str__(self): + return "custom" + + assert safe_serialize(CustomObj()) == "custom" diff --git a/sagemaker-train/tests/unit/train/test_model_trainer_pipeline_variable.py b/sagemaker-train/tests/unit/train/test_model_trainer_pipeline_variable.py index 3fd34fa47b..55ac754017 100644 --- a/sagemaker-train/tests/unit/train/test_model_trainer_pipeline_variable.py +++ b/sagemaker-train/tests/unit/train/test_model_trainer_pipeline_variable.py @@ -26,7 +26,7 @@ from sagemaker.core.helper.session_helper import Session from sagemaker.core.helper.pipeline_variable import PipelineVariable, StrPipeVar -from sagemaker.core.workflow.parameters import ParameterString +from sagemaker.core.workflow.parameters import ParameterString, ParameterInteger from sagemaker.train.model_trainer import ModelTrainer, Mode from sagemaker.train.configs import ( Compute, @@ -176,3 +176,61 @@ def test_training_image_rejects_invalid_type(self): stopping_condition=DEFAULT_STOPPING, output_data_config=DEFAULT_OUTPUT, ) + + +class TestModelTrainerPipelineVariableHyperparameters: + """Test that PipelineVariable objects work correctly in ModelTrainer hyperparameters.""" + + def test_hyperparameters_with_parameter_integer(self): + """ParameterInteger in hyperparameters should be preserved through _create_training_job_args.""" + max_depth = ParameterInteger(name="MaxDepth", default_value=5) + trainer = ModelTrainer( + training_image=DEFAULT_IMAGE, + role=DEFAULT_ROLE, + compute=DEFAULT_COMPUTE, + stopping_condition=DEFAULT_STOPPING, + output_data_config=DEFAULT_OUTPUT, + hyperparameters={"max_depth": max_depth}, + ) + args = trainer._create_training_job_args() + # PipelineVariable should be preserved as-is, not stringified + assert args["hyper_parameters"]["max_depth"] is max_depth + + def test_hyperparameters_with_parameter_string(self): + """ParameterString in hyperparameters should be preserved through _create_training_job_args.""" + algo = ParameterString(name="Algorithm", default_value="xgboost") + trainer = ModelTrainer( + training_image=DEFAULT_IMAGE, + role=DEFAULT_ROLE, + compute=DEFAULT_COMPUTE, + stopping_condition=DEFAULT_STOPPING, + output_data_config=DEFAULT_OUTPUT, + hyperparameters={"algorithm": algo}, + ) + args = trainer._create_training_job_args() + assert args["hyper_parameters"]["algorithm"] is algo + + def test_hyperparameters_with_mixed_pipeline_and_static_values(self): + """Mixed PipelineVariable and static values should both be handled correctly.""" + max_depth = ParameterInteger(name="MaxDepth", default_value=5) + trainer = ModelTrainer( + training_image=DEFAULT_IMAGE, + role=DEFAULT_ROLE, + compute=DEFAULT_COMPUTE, + stopping_condition=DEFAULT_STOPPING, + output_data_config=DEFAULT_OUTPUT, + hyperparameters={ + "max_depth": max_depth, + "eta": 0.1, + "objective": "binary:logistic", + "num_round": 100, + }, + ) + args = trainer._create_training_job_args() + hp = args["hyper_parameters"] + # PipelineVariable preserved as-is + assert hp["max_depth"] is max_depth + # Static values serialized to strings + assert hp["eta"] == "0.1" + assert hp["objective"] == "binary:logistic" + assert hp["num_round"] == "100" diff --git a/sagemaker-train/tests/unit/train/test_safe_serialize.py b/sagemaker-train/tests/unit/train/test_safe_serialize.py new file mode 100644 index 0000000000..0f20e21efd --- /dev/null +++ b/sagemaker-train/tests/unit/train/test_safe_serialize.py @@ -0,0 +1,97 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Tests for safe_serialize with PipelineVariable support. + +Verifies that safe_serialize in sagemaker.train.utils correctly handles +PipelineVariable objects (e.g., ParameterInteger, ParameterString) by +returning them as-is rather than attempting str() conversion which would +raise TypeError. + +See: https://github.com/aws/sagemaker-python-sdk/issues/5504 +""" +from __future__ import absolute_import + +import pytest + +from sagemaker.train.utils import safe_serialize +from sagemaker.core.helper.pipeline_variable import PipelineVariable +from sagemaker.core.workflow.parameters import ParameterInteger, ParameterString + + +class TestSafeSerializeWithPipelineVariables: + """Test safe_serialize handles PipelineVariable objects correctly.""" + + def test_safe_serialize_with_parameter_integer(self): + """ParameterInteger should be returned as-is (identity preserved).""" + param = ParameterInteger(name="MaxDepth", default_value=5) + result = safe_serialize(param) + assert result is param + assert isinstance(result, PipelineVariable) + + def test_safe_serialize_with_parameter_string(self): + """ParameterString should be returned as-is (identity preserved).""" + param = ParameterString(name="Algorithm", default_value="xgboost") + result = safe_serialize(param) + assert result is param + assert isinstance(result, PipelineVariable) + + def test_safe_serialize_does_not_call_str_on_pipeline_variable(self): + """Verify that PipelineVariable.__str__ is never invoked (would raise TypeError).""" + param = ParameterInteger(name="TestParam", default_value=10) + # This should NOT raise TypeError + result = safe_serialize(param) + assert result is param + + +class TestSafeSerializeBasicTypes: + """Regression tests: verify basic types still work after PipelineVariable support.""" + + def test_safe_serialize_with_string(self): + """Strings should be returned as-is without JSON wrapping.""" + assert safe_serialize("hello") == "hello" + assert safe_serialize("12345") == "12345" + + def test_safe_serialize_with_int(self): + """Integers should be JSON-serialized to string.""" + assert safe_serialize(42) == "42" + + def test_safe_serialize_with_float(self): + """Floats should be JSON-serialized to string.""" + assert safe_serialize(3.14) == "3.14" + + def test_safe_serialize_with_dict(self): + """Dicts should be JSON-serialized.""" + result = safe_serialize({"key": "val"}) + assert result == '{"key": "val"}' + + def test_safe_serialize_with_bool(self): + """Booleans should be JSON-serialized.""" + assert safe_serialize(True) == "true" + assert safe_serialize(False) == "false" + + def test_safe_serialize_with_none(self): + """None should be JSON-serialized to 'null'.""" + assert safe_serialize(None) == "null" + + def test_safe_serialize_with_list(self): + """Lists should be JSON-serialized.""" + assert safe_serialize([1, 2, 3]) == "[1, 2, 3]" + + def test_safe_serialize_with_custom_object(self): + """Custom objects should fall back to str().""" + + class CustomObj: + def __str__(self): + return "custom" + + assert safe_serialize(CustomObj()) == "custom" From 109214aa51e33a559bca636bda405b62ec5a07dc Mon Sep 17 00:00:00 2001 From: aviruthen <91846056+aviruthen@users.noreply.github.com> Date: Tue, 7 Apr 2026 14:26:41 -0400 Subject: [PATCH 2/4] fix: address review comments (iteration #1) --- .../src/sagemaker/core/modules/utils.py | 8 ++++-- .../core/modules/test_utils_safe_serialize.py | 8 +++++- .../test_model_trainer_pipeline_variable.py | 26 +++++++++++++++++++ .../tests/unit/train/test_safe_serialize.py | 8 +++++- 4 files changed, 46 insertions(+), 4 deletions(-) diff --git a/sagemaker-core/src/sagemaker/core/modules/utils.py b/sagemaker-core/src/sagemaker/core/modules/utils.py index d1a0a3ea09..a33d31677b 100644 --- a/sagemaker-core/src/sagemaker/core/modules/utils.py +++ b/sagemaker-core/src/sagemaker/core/modules/utils.py @@ -24,7 +24,11 @@ from sagemaker.core.shapes import Unassigned from sagemaker.core.modules import logger -from sagemaker.core.helper.pipeline_variable import PipelineVariable + +try: + from sagemaker.core.helper.pipeline_variable import PipelineVariable +except ImportError: + PipelineVariable = None def _is_valid_s3_uri(path: str, path_type: Literal["File", "Directory", "Any"] = "Any") -> bool: @@ -145,7 +149,7 @@ def safe_serialize(data): """ if isinstance(data, str): return data - elif isinstance(data, PipelineVariable): + elif PipelineVariable is not None and isinstance(data, PipelineVariable): return data try: return json.dumps(data) diff --git a/sagemaker-core/tests/unit/core/modules/test_utils_safe_serialize.py b/sagemaker-core/tests/unit/core/modules/test_utils_safe_serialize.py index 57a8fe2a5d..05d358ce37 100644 --- a/sagemaker-core/tests/unit/core/modules/test_utils_safe_serialize.py +++ b/sagemaker-core/tests/unit/core/modules/test_utils_safe_serialize.py @@ -18,7 +18,7 @@ See: https://github.com/aws/sagemaker-python-sdk/issues/5504 """ -from __future__ import absolute_import +from __future__ import annotations import pytest @@ -51,6 +51,12 @@ def test_safe_serialize_does_not_call_str_on_pipeline_variable(self): result = safe_serialize(param) assert result is param + def test_pipeline_variable_str_raises_type_error(self): + """Confirm PipelineVariable.__str__ raises TypeError (the root cause of the bug).""" + param = ParameterInteger(name="TestParam", default_value=10) + with pytest.raises(TypeError): + str(param) + class TestSafeSerializeBasicTypes: """Regression tests: verify basic types still work after PipelineVariable support.""" diff --git a/sagemaker-train/tests/unit/train/test_model_trainer_pipeline_variable.py b/sagemaker-train/tests/unit/train/test_model_trainer_pipeline_variable.py index 55ac754017..2679fe2995 100644 --- a/sagemaker-train/tests/unit/train/test_model_trainer_pipeline_variable.py +++ b/sagemaker-train/tests/unit/train/test_model_trainer_pipeline_variable.py @@ -210,6 +210,32 @@ def test_hyperparameters_with_parameter_string(self): args = trainer._create_training_job_args() assert args["hyper_parameters"]["algorithm"] is algo + def test_hyperparameters_with_parameter_integer_does_not_raise(self): + """Verify ParameterInteger in hyperparameters does NOT raise TypeError. + + This test documents the exact bug scenario from GH#5504: safe_serialize + would fall back to str(data) for PipelineVariable objects, but + PipelineVariable.__str__ intentionally raises TypeError. + """ + max_depth = ParameterInteger(name="MaxDepth", default_value=5) + trainer = ModelTrainer( + training_image=DEFAULT_IMAGE, + role=DEFAULT_ROLE, + compute=DEFAULT_COMPUTE, + stopping_condition=DEFAULT_STOPPING, + output_data_config=DEFAULT_OUTPUT, + hyperparameters={"max_depth": max_depth}, + ) + # This call would have raised TypeError before the fix + try: + args = trainer._create_training_job_args() + except TypeError: + pytest.fail( + "safe_serialize raised TypeError on PipelineVariable - " + "this is the bug described in GH#5504" + ) + assert args["hyper_parameters"]["max_depth"] is max_depth + def test_hyperparameters_with_mixed_pipeline_and_static_values(self): """Mixed PipelineVariable and static values should both be handled correctly.""" max_depth = ParameterInteger(name="MaxDepth", default_value=5) diff --git a/sagemaker-train/tests/unit/train/test_safe_serialize.py b/sagemaker-train/tests/unit/train/test_safe_serialize.py index 0f20e21efd..c0a26c4e88 100644 --- a/sagemaker-train/tests/unit/train/test_safe_serialize.py +++ b/sagemaker-train/tests/unit/train/test_safe_serialize.py @@ -19,7 +19,7 @@ See: https://github.com/aws/sagemaker-python-sdk/issues/5504 """ -from __future__ import absolute_import +from __future__ import annotations import pytest @@ -52,6 +52,12 @@ def test_safe_serialize_does_not_call_str_on_pipeline_variable(self): result = safe_serialize(param) assert result is param + def test_pipeline_variable_str_raises_type_error(self): + """Confirm PipelineVariable.__str__ raises TypeError (the root cause of the bug).""" + param = ParameterInteger(name="TestParam", default_value=10) + with pytest.raises(TypeError): + str(param) + class TestSafeSerializeBasicTypes: """Regression tests: verify basic types still work after PipelineVariable support.""" From 0d1195dfb1529432dee789bb17ffd811ec37ec00 Mon Sep 17 00:00:00 2001 From: aviruthen <91846056+aviruthen@users.noreply.github.com> Date: Tue, 7 Apr 2026 14:30:33 -0400 Subject: [PATCH 3/4] fix: address review comments (iteration #2) --- .../test_model_trainer_pipeline_variable.py | 35 +++++-------------- .../tests/unit/train/test_safe_serialize.py | 23 ++++-------- 2 files changed, 14 insertions(+), 44 deletions(-) diff --git a/sagemaker-train/tests/unit/train/test_model_trainer_pipeline_variable.py b/sagemaker-train/tests/unit/train/test_model_trainer_pipeline_variable.py index 2679fe2995..3c84e2c63a 100644 --- a/sagemaker-train/tests/unit/train/test_model_trainer_pipeline_variable.py +++ b/sagemaker-train/tests/unit/train/test_model_trainer_pipeline_variable.py @@ -182,7 +182,13 @@ class TestModelTrainerPipelineVariableHyperparameters: """Test that PipelineVariable objects work correctly in ModelTrainer hyperparameters.""" def test_hyperparameters_with_parameter_integer(self): - """ParameterInteger in hyperparameters should be preserved through _create_training_job_args.""" + """ParameterInteger in hyperparameters should be preserved through _create_training_job_args. + + This test documents the exact bug scenario from GH#5504: safe_serialize + would fall back to str(data) for PipelineVariable objects, but + PipelineVariable.__str__ intentionally raises TypeError. + Before the fix, this call would have raised TypeError. + """ max_depth = ParameterInteger(name="MaxDepth", default_value=5) trainer = ModelTrainer( training_image=DEFAULT_IMAGE, @@ -192,6 +198,7 @@ def test_hyperparameters_with_parameter_integer(self): output_data_config=DEFAULT_OUTPUT, hyperparameters={"max_depth": max_depth}, ) + # This call would have raised TypeError before the fix (GH#5504) args = trainer._create_training_job_args() # PipelineVariable should be preserved as-is, not stringified assert args["hyper_parameters"]["max_depth"] is max_depth @@ -210,32 +217,6 @@ def test_hyperparameters_with_parameter_string(self): args = trainer._create_training_job_args() assert args["hyper_parameters"]["algorithm"] is algo - def test_hyperparameters_with_parameter_integer_does_not_raise(self): - """Verify ParameterInteger in hyperparameters does NOT raise TypeError. - - This test documents the exact bug scenario from GH#5504: safe_serialize - would fall back to str(data) for PipelineVariable objects, but - PipelineVariable.__str__ intentionally raises TypeError. - """ - max_depth = ParameterInteger(name="MaxDepth", default_value=5) - trainer = ModelTrainer( - training_image=DEFAULT_IMAGE, - role=DEFAULT_ROLE, - compute=DEFAULT_COMPUTE, - stopping_condition=DEFAULT_STOPPING, - output_data_config=DEFAULT_OUTPUT, - hyperparameters={"max_depth": max_depth}, - ) - # This call would have raised TypeError before the fix - try: - args = trainer._create_training_job_args() - except TypeError: - pytest.fail( - "safe_serialize raised TypeError on PipelineVariable - " - "this is the bug described in GH#5504" - ) - assert args["hyper_parameters"]["max_depth"] is max_depth - def test_hyperparameters_with_mixed_pipeline_and_static_values(self): """Mixed PipelineVariable and static values should both be handled correctly.""" max_depth = ParameterInteger(name="MaxDepth", default_value=5) diff --git a/sagemaker-train/tests/unit/train/test_safe_serialize.py b/sagemaker-train/tests/unit/train/test_safe_serialize.py index c0a26c4e88..1ad227a278 100644 --- a/sagemaker-train/tests/unit/train/test_safe_serialize.py +++ b/sagemaker-train/tests/unit/train/test_safe_serialize.py @@ -31,27 +31,16 @@ class TestSafeSerializeWithPipelineVariables: """Test safe_serialize handles PipelineVariable objects correctly.""" - def test_safe_serialize_with_parameter_integer(self): - """ParameterInteger should be returned as-is (identity preserved).""" - param = ParameterInteger(name="MaxDepth", default_value=5) + @pytest.mark.parametrize("param", [ + ParameterInteger(name="MaxDepth", default_value=5), + ParameterString(name="Algorithm", default_value="xgboost"), + ]) + def test_safe_serialize_returns_pipeline_variable_as_is(self, param): + """PipelineVariable objects should be returned as-is (identity preserved).""" result = safe_serialize(param) assert result is param assert isinstance(result, PipelineVariable) - def test_safe_serialize_with_parameter_string(self): - """ParameterString should be returned as-is (identity preserved).""" - param = ParameterString(name="Algorithm", default_value="xgboost") - result = safe_serialize(param) - assert result is param - assert isinstance(result, PipelineVariable) - - def test_safe_serialize_does_not_call_str_on_pipeline_variable(self): - """Verify that PipelineVariable.__str__ is never invoked (would raise TypeError).""" - param = ParameterInteger(name="TestParam", default_value=10) - # This should NOT raise TypeError - result = safe_serialize(param) - assert result is param - def test_pipeline_variable_str_raises_type_error(self): """Confirm PipelineVariable.__str__ raises TypeError (the root cause of the bug).""" param = ParameterInteger(name="TestParam", default_value=10) From e3b2d8202480a2f212f0286e3749eb5e4a229d13 Mon Sep 17 00:00:00 2001 From: aviruthen <91846056+aviruthen@users.noreply.github.com> Date: Tue, 7 Apr 2026 14:38:30 -0400 Subject: [PATCH 4/4] fix: address review comments (iteration #3) --- .../src/sagemaker/core/modules/utils.py | 11 ++++----- .../core/modules/test_utils_safe_serialize.py | 23 +++++-------------- .../test_model_trainer_pipeline_variable.py | 2 +- 3 files changed, 11 insertions(+), 25 deletions(-) diff --git a/sagemaker-core/src/sagemaker/core/modules/utils.py b/sagemaker-core/src/sagemaker/core/modules/utils.py index a33d31677b..3cb884d475 100644 --- a/sagemaker-core/src/sagemaker/core/modules/utils.py +++ b/sagemaker-core/src/sagemaker/core/modules/utils.py @@ -24,11 +24,7 @@ from sagemaker.core.shapes import Unassigned from sagemaker.core.modules import logger - -try: - from sagemaker.core.helper.pipeline_variable import PipelineVariable -except ImportError: - PipelineVariable = None +from sagemaker.core.helper.pipeline_variable import PipelineVariable def _is_valid_s3_uri(path: str, path_type: Literal["File", "Directory", "Any"] = "Any") -> bool: @@ -145,11 +141,12 @@ def safe_serialize(data): data (Any): The data to serialize. Returns: - str: The serialized JSON-compatible string or the string representation of the input. + str | PipelineVariable: The serialized JSON-compatible string, the string + representation of the input, or the PipelineVariable object as-is. """ if isinstance(data, str): return data - elif PipelineVariable is not None and isinstance(data, PipelineVariable): + elif isinstance(data, PipelineVariable): return data try: return json.dumps(data) diff --git a/sagemaker-core/tests/unit/core/modules/test_utils_safe_serialize.py b/sagemaker-core/tests/unit/core/modules/test_utils_safe_serialize.py index 05d358ce37..ad54f3c28f 100644 --- a/sagemaker-core/tests/unit/core/modules/test_utils_safe_serialize.py +++ b/sagemaker-core/tests/unit/core/modules/test_utils_safe_serialize.py @@ -30,27 +30,16 @@ class TestSafeSerializeWithPipelineVariables: """Test safe_serialize handles PipelineVariable objects correctly.""" - def test_safe_serialize_with_parameter_integer(self): - """ParameterInteger should be returned as-is (identity preserved).""" - param = ParameterInteger(name="MaxDepth", default_value=5) + @pytest.mark.parametrize("param", [ + ParameterInteger(name="MaxDepth", default_value=5), + ParameterString(name="Algorithm", default_value="xgboost"), + ]) + def test_safe_serialize_returns_pipeline_variable_as_is(self, param): + """PipelineVariable objects should be returned as-is (identity preserved).""" result = safe_serialize(param) assert result is param assert isinstance(result, PipelineVariable) - def test_safe_serialize_with_parameter_string(self): - """ParameterString should be returned as-is (identity preserved).""" - param = ParameterString(name="Algorithm", default_value="xgboost") - result = safe_serialize(param) - assert result is param - assert isinstance(result, PipelineVariable) - - def test_safe_serialize_does_not_call_str_on_pipeline_variable(self): - """Verify that PipelineVariable.__str__ is never invoked (would raise TypeError).""" - param = ParameterInteger(name="TestParam", default_value=10) - # This should NOT raise TypeError - result = safe_serialize(param) - assert result is param - def test_pipeline_variable_str_raises_type_error(self): """Confirm PipelineVariable.__str__ raises TypeError (the root cause of the bug).""" param = ParameterInteger(name="TestParam", default_value=10) diff --git a/sagemaker-train/tests/unit/train/test_model_trainer_pipeline_variable.py b/sagemaker-train/tests/unit/train/test_model_trainer_pipeline_variable.py index 3c84e2c63a..e6aacf13f4 100644 --- a/sagemaker-train/tests/unit/train/test_model_trainer_pipeline_variable.py +++ b/sagemaker-train/tests/unit/train/test_model_trainer_pipeline_variable.py @@ -204,7 +204,7 @@ def test_hyperparameters_with_parameter_integer(self): assert args["hyper_parameters"]["max_depth"] is max_depth def test_hyperparameters_with_parameter_string(self): - """ParameterString in hyperparameters should be preserved through _create_training_job_args.""" + """ParameterString in hyperparameters should be preserved.""" algo = ParameterString(name="Algorithm", default_value="xgboost") trainer = ModelTrainer( training_image=DEFAULT_IMAGE,