From 7bd5df9827520703338a57465bcd1500e62faacb Mon Sep 17 00:00:00 2001 From: guanweim Date: Mon, 4 May 2026 19:03:03 +0000 Subject: [PATCH 1/9] feat(train): Add SequenceLength support for SFT, DPO, RLVR, RLAIF trainers Add optional sequence_length parameter to all four trainers that enables customers to specify their desired context length for serverless training jobs. The parameter is passed in ServerlessJobConfig for recipe filtering. During trainer initialization, _get_fine_tuning_options_and_model_arn filters recipes by SequenceLength field, picking the smallest recipe with context length >= the requested value. Raises ValueError if no sufficient recipe exists or if recipes lack SequenceLength metadata. Changes: - ServerlessJobConfig: add sequence_length field - _parse_context_length: parse values like '8K' to integers - _get_fine_tuning_options_and_model_arn: filter by SequenceLength - _create_serverless_config: conditionally include sequence_length - SFTTrainer, DPOTrainer, RLVRTrainer, RLAIFTrainer: accept and thread sequence_length through init and train methods - Unit tests for all new functionality --- .../src/sagemaker/core/shapes/shapes.py | 3 +- .../train/common_utils/finetune_utils.py | 79 +++++++++++-- .../src/sagemaker/train/dpo_trainer.py | 36 +++--- .../src/sagemaker/train/rlaif_trainer.py | 35 +++--- .../src/sagemaker/train/rlvr_trainer.py | 37 +++--- .../src/sagemaker/train/sft_trainer.py | 35 +++--- .../train/common_utils/test_finetune_utils.py | 105 +++++++++++++++++- .../tests/unit/train/test_dpo_trainer.py | 65 ++++++++++- .../tests/unit/train/test_rlaif_trainer.py | 68 +++++++++++- .../tests/unit/train/test_rlvr_trainer.py | 65 ++++++++++- .../tests/unit/train/test_sft_trainer.py | 65 ++++++++++- 11 files changed, 525 insertions(+), 68 deletions(-) diff --git a/sagemaker-core/src/sagemaker/core/shapes/shapes.py b/sagemaker-core/src/sagemaker/core/shapes/shapes.py index 8f99bcba8c..d79d33a5b2 100644 --- a/sagemaker-core/src/sagemaker/core/shapes/shapes.py +++ b/sagemaker-core/src/sagemaker/core/shapes/shapes.py @@ -9588,6 +9588,7 @@ class ServerlessJobConfig(Base): peft: The parameter-efficient fine-tuning configuration. evaluation_type: The evaluation job type. Required when serverless job type is Evaluation. evaluator_arn: The evaluator Amazon Resource Name (ARN) used as reward function or reward prompt. + sequence_length: The sequence length for the training job. Valid values are "1K", "2K", "4K", "8K", "16K", "32K", "64K", "128K". """ base_model_arn: StrPipeVar @@ -9597,7 +9598,7 @@ class ServerlessJobConfig(Base): peft: Optional[StrPipeVar] = Unassigned() evaluation_type: Optional[StrPipeVar] = Unassigned() evaluator_arn: Optional[StrPipeVar] = Unassigned() - + sequence_length: Optional[StrPipeVar] = Unassigned() class MlflowConfig(Base): """ diff --git a/sagemaker-train/src/sagemaker/train/common_utils/finetune_utils.py b/sagemaker-train/src/sagemaker/train/common_utils/finetune_utils.py index 0ea74ee207..1a42b2a5be 100644 --- a/sagemaker-train/src/sagemaker/train/common_utils/finetune_utils.py +++ b/sagemaker-train/src/sagemaker/train/common_utils/finetune_utils.py @@ -318,10 +318,44 @@ def _resolve_model_package_arn(model_package) -> Optional[str]: return None -def _get_fine_tuning_options_and_model_arn(model_name: str, customization_technique: str, training_type, sagemaker_session, - hub_name: Optional[str] = None) -> tuple: +def _parse_context_length(value) -> int: + """Parse a context length value like '8K', '32K', '128K' into an integer (e.g., 8192). + + Returns 0 if value is None or unparseable. + """ + if not value: + return 0 + value = str(value).strip().upper() + if value.endswith("K"): + try: + return int(value[:-1]) * 1024 + except ValueError: + return 0 + try: + return int(value) + except ValueError: + return 0 + + +def _get_fine_tuning_options_and_model_arn( + model_name: str, + customization_technique: str, + training_type, + sagemaker_session, + sequence_length=None, + hub_name: str = "SageMakerPublicHub" +) -> tuple: """Get fine-tuning options and model ARN for given customization technique. + Args: + model_name: Name of the model in the hub. + customization_technique: Technique (e.g., "SFT", "DPO", "RLVR", "RLAIF"). + training_type: TrainingType enum or string ("LORA", "FULL"). + sagemaker_session: SageMaker session for API calls. + sequence_length: Optional sequence length (e.g., "8K"). When provided, filters + recipes by MaxContextLength >= the requested value. + hub_name: Hub name (default: "SageMakerPublicHub"). + Returns: tuple: (FineTuningOptions, model_arn, is_gated_model) """ @@ -362,9 +396,34 @@ def _get_fine_tuning_options_and_model_arn(model_name: str, customization_techni # Collect override_params from ALL matching recipes (standard + subscription) recipe = None if (isinstance(training_type, TrainingType) and training_type == TrainingType.LORA) or training_type == "LORA": - recipe = next((r for r in recipes_with_template if r.get("Peft") and not r.get("IsSubscriptionModel")), None) + candidates = [r for r in recipes_with_template if r.get("Peft") and not r.get("IsSubscriptionModel")] elif (isinstance(training_type, TrainingType) and training_type == TrainingType.FULL) or training_type == "FULL": - recipe = next((r for r in recipes_with_template if not r.get("Peft") and not r.get("IsSubscriptionModel")), None) + candidates = [r for r in recipes_with_template if not r.get("Peft") and not r.get("IsSubscriptionModel")] + else: + candidates = [] + + # Filter by SequenceLength if sequence_length is provided + if sequence_length and candidates: + requested = _parse_context_length(sequence_length) + candidates_with_context = [r for r in candidates if r.get("SequenceLength")] + if candidates_with_context: + filtered = [r for r in candidates_with_context if _parse_context_length(r.get("SequenceLength")) >= requested] + if filtered: + filtered.sort(key=lambda r: _parse_context_length(r.get("SequenceLength"))) + recipe = filtered[0] + else: + available = sorted(set(r.get("SequenceLength") for r in candidates_with_context)) + raise ValueError( + f"No recipes found with SequenceLength >= {sequence_length}. " + f"Available sequence lengths: {available}" + ) + else: + raise ValueError( + f"No recipes found with Smtj for technique: {customization_technique},training_type:{training_type}, " + f"and sequence length:{sequence_length}" + ) + elif candidates: + recipe = candidates[0] if not recipe: raise ValueError(f"No recipes found with Smtj for technique: {customization_technique},training_type:{training_type}") @@ -519,7 +578,8 @@ def _resolve_model_and_name(model, sagemaker_session=None): def _create_serverless_config(model_arn, customization_technique, - training_type, accept_eula, evaluator_arn=None, job_type=JOB_TYPE) -> Optional['ServerlessJobConfig']: + training_type, accept_eula, evaluator_arn=None, + sequence_length=None, job_type=JOB_TYPE) -> Optional['ServerlessJobConfig']: """Create serverless job configuration for fine-tuning. Args: @@ -528,6 +588,7 @@ def _create_serverless_config(model_arn, customization_technique, training_type: Training type (TrainingType enum or string) accept_eula: Boolean indicating if EULA is accepted evaluator_arn: Optional evaluator ARN for RLVR/RLAIF + sequence_length: Optional sequence length enum value (e.g., "1K", "2K", "4K", "8K", "16K", "32K", "64K", "128K") job_type: Type of job (default: "FineTuning") Returns: @@ -537,14 +598,18 @@ def _create_serverless_config(model_arn, customization_technique, else (training_type.value if isinstance(training_type, TrainingType) else training_type) # Create ServerlessJobConfig using shapes - serverless_config = ServerlessJobConfig( + config_kwargs = dict( job_type=job_type, base_model_arn=model_arn, customization_technique=customization_technique, peft=peft, evaluator_arn=evaluator_arn, - accept_eula=accept_eula + accept_eula=accept_eula, ) + if sequence_length is not None: + config_kwargs["sequence_length"] = sequence_length + + serverless_config = ServerlessJobConfig(**config_kwargs) return serverless_config diff --git a/sagemaker-train/src/sagemaker/train/dpo_trainer.py b/sagemaker-train/src/sagemaker/train/dpo_trainer.py index bd5d9a11bd..8e3bc17d5e 100644 --- a/sagemaker-train/src/sagemaker/train/dpo_trainer.py +++ b/sagemaker-train/src/sagemaker/train/dpo_trainer.py @@ -100,6 +100,10 @@ class DPOTrainer(BaseTrainer): stopping_condition (Optional[StoppingCondition]): The stopping condition to override training runtime limit. If not specified, uses SageMaker service default (24 hours for serverless training). + sequence_length (Optional[str]): + The sequence length for the training job. Valid values are + "1K", "2K", "4K", "8K", "16K", "32K", "64K", "128K". + If not specified, the service will use default recipe selection behavior. """ def __init__( self, @@ -116,6 +120,7 @@ def __init__( networking: Optional[VpcConfig] = None, accept_eula: bool = False, stopping_condition: Optional[StoppingCondition] = None, + sequence_length: Optional[str] = None, **kwargs, ): super().__init__(**kwargs) @@ -134,16 +139,17 @@ def __init__( self.kms_key_id = kms_key_id self.networking = networking self.stopping_condition = stopping_condition + self.sequence_length = sequence_length # Initialize fine-tuning options with beta session fallback - self.hyperparameters, self._model_arn, is_gated_model = _get_fine_tuning_options_and_model_arn(self._model_name, - CustomizationTechnique.DPO.value, - self.training_type, - self.sagemaker_session or TrainDefaults.get_sagemaker_session( - sagemaker_session=self.sagemaker_session - - )) - + self.hyperparameters, self._model_arn, is_gated_model = _get_fine_tuning_options_and_model_arn( + self._model_name, + CustomizationTechnique.DPO.value, + self.training_type, + self.sagemaker_session or TrainDefaults.get_sagemaker_session(sagemaker_session=self.sagemaker_session), + sequence_length=self.sequence_length + ) + # Process hyperparameters self._process_hyperparameters() @@ -227,12 +233,14 @@ def train(self, kms_key_id=self.kms_key_id ) - serverless_config = _create_serverless_config(model_arn=self._model_arn, - customization_technique=CustomizationTechnique.DPO.value, - training_type=self.training_type, - accept_eula=self.accept_eula, - job_type=JOB_TYPE - ) + serverless_config = _create_serverless_config( + model_arn=self._model_arn, + customization_technique=CustomizationTechnique.DPO.value, + training_type=self.training_type, + accept_eula=self.accept_eula, + sequence_length=self.sequence_length, + job_type=JOB_TYPE + ) mlflow_config = _create_mlflow_config( sagemaker_session, diff --git a/sagemaker-train/src/sagemaker/train/rlaif_trainer.py b/sagemaker-train/src/sagemaker/train/rlaif_trainer.py index f2d8460989..5d782d8fa3 100644 --- a/sagemaker-train/src/sagemaker/train/rlaif_trainer.py +++ b/sagemaker-train/src/sagemaker/train/rlaif_trainer.py @@ -114,6 +114,10 @@ class RLAIFTrainer(BaseTrainer): stopping_condition (Optional[StoppingCondition]): The stopping condition to override training runtime limit. If not specified, uses SageMaker service default (24 hours for serverless training). + sequence_length (Optional[str]): + The sequence length for the training job. Valid values are + "1K", "2K", "4K", "8K", "16K", "32K", "64K", "128K". + If not specified, the service will use default recipe selection behavior. """ def __init__( @@ -135,6 +139,7 @@ def __init__( networking: Optional[VpcConfig] = None, accept_eula: bool = False, stopping_condition: Optional[StoppingCondition] = None, + sequence_length: Optional[str] = None, **kwargs, ): super().__init__(**kwargs) @@ -156,14 +161,16 @@ def __init__( self.kms_key_id = kms_key_id self.networking = networking self.stopping_condition = stopping_condition + self.sequence_length = sequence_length # Initialize fine-tuning options with beta session fallback - self.hyperparameters, self._model_arn, is_gated_model = _get_fine_tuning_options_and_model_arn(self._model_name, - CustomizationTechnique.RLAIF.value, - self.training_type, - self.sagemaker_session or TrainDefaults.get_sagemaker_session( - sagemaker_session=self.sagemaker_session - )) + self.hyperparameters, self._model_arn, is_gated_model = _get_fine_tuning_options_and_model_arn( + self._model_name, + CustomizationTechnique.RLAIF.value, + self.training_type, + self.sagemaker_session or TrainDefaults.get_sagemaker_session(sagemaker_session=self.sagemaker_session), + sequence_length=self.sequence_length + ) # Validate and set EULA acceptance self.accept_eula = _validate_eula_for_gated_model(model, accept_eula, is_gated_model) @@ -242,13 +249,15 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None, validati ) evaluator_arn = getattr(self, '_evaluator_arn', None) - serverless_config = _create_serverless_config(model_arn=self._model_arn, - customization_technique=CustomizationTechnique.RLAIF.value, - training_type=self.training_type, - accept_eula=self.accept_eula, - evaluator_arn=evaluator_arn, - job_type=JOB_TYPE - ) + serverless_config = _create_serverless_config( + model_arn=self._model_arn, + customization_technique=CustomizationTechnique.RLAIF.value, + training_type=self.training_type, + accept_eula=self.accept_eula, + evaluator_arn=evaluator_arn, + sequence_length=self.sequence_length, + job_type=JOB_TYPE + ) mlflow_config = _create_mlflow_config( sagemaker_session, diff --git a/sagemaker-train/src/sagemaker/train/rlvr_trainer.py b/sagemaker-train/src/sagemaker/train/rlvr_trainer.py index 333a93fc55..53029155f2 100644 --- a/sagemaker-train/src/sagemaker/train/rlvr_trainer.py +++ b/sagemaker-train/src/sagemaker/train/rlvr_trainer.py @@ -106,6 +106,10 @@ class RLVRTrainer(BaseTrainer): stopping_condition (Optional[StoppingCondition]): The stopping condition to override training runtime limit. If not specified, uses SageMaker service default (24 hours for serverless training). + sequence_length (Optional[str]): + The sequence length for the training job. Valid values are + "1K", "2K", "4K", "8K", "16K", "32K", "64K", "128K". + If not specified, the service will use default recipe selection behavior. """ def __init__( @@ -126,6 +130,7 @@ def __init__( networking: Optional[VpcConfig] = None, accept_eula: bool = False, stopping_condition: Optional[StoppingCondition] = None, + sequence_length: Optional[str] = None, **kwargs, ): super().__init__(**kwargs) @@ -146,15 +151,17 @@ def __init__( self.kms_key_id = kms_key_id self.networking = networking self.stopping_condition = stopping_condition + self.sequence_length = sequence_length # Initialize fine-tuning options with beta session fallback - self.hyperparameters, self._model_arn, is_gated_model = _get_fine_tuning_options_and_model_arn(self._model_name, - CustomizationTechnique.RLVR.value, - self.training_type, - self.sagemaker_session or TrainDefaults.get_sagemaker_session( - sagemaker_session=self.sagemaker_session - )) - + self.hyperparameters, self._model_arn, is_gated_model = _get_fine_tuning_options_and_model_arn( + self._model_name, + CustomizationTechnique.RLVR.value, + self.training_type, + self.sagemaker_session or TrainDefaults.get_sagemaker_session(sagemaker_session=self.sagemaker_session), + sequence_length=self.sequence_length + ) + # Remove constructor-handled hyperparameters self._process_hyperparameters() @@ -233,13 +240,15 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None, # Extract and validate evaluator ARN evaluator_arn = _extract_evaluator_arn(self.custom_reward_function) if self.custom_reward_function else None - serverless_config = _create_serverless_config(model_arn=self._model_arn, - customization_technique=CustomizationTechnique.RLVR.value, - training_type=self.training_type, - accept_eula=self.accept_eula, - evaluator_arn=evaluator_arn, - job_type=JOB_TYPE - ) + serverless_config = _create_serverless_config( + model_arn=self._model_arn, + customization_technique=CustomizationTechnique.RLVR.value, + training_type=self.training_type, + accept_eula=self.accept_eula, + evaluator_arn=evaluator_arn, + sequence_length=self.sequence_length, + job_type=JOB_TYPE + ) mlflow_config = _create_mlflow_config( sagemaker_session, mlflow_resource_arn=self.mlflow_resource_arn, diff --git a/sagemaker-train/src/sagemaker/train/sft_trainer.py b/sagemaker-train/src/sagemaker/train/sft_trainer.py index 233f169d0f..e2193f0b9b 100644 --- a/sagemaker-train/src/sagemaker/train/sft_trainer.py +++ b/sagemaker-train/src/sagemaker/train/sft_trainer.py @@ -102,6 +102,10 @@ class SFTTrainer(BaseTrainer): stopping_condition (Optional[StoppingCondition]): The stopping condition to override training runtime limit. If not specified, uses SageMaker service default (24 hours for serverless training). + sequence_length (Optional[str]): + The sequence length for the training job. Valid values are + "1K", "2K", "4K", "8K", "16K", "32K", "64K", "128K". + If not specified, the service will use default recipe selection behavior. """ def __init__( @@ -119,6 +123,7 @@ def __init__( networking: Optional[VpcConfig] = None, accept_eula: Optional[bool] = False, stopping_condition: Optional[StoppingCondition] = None, + sequence_length: Optional[str] = None, **kwargs, ): super().__init__(**kwargs) @@ -138,15 +143,17 @@ def __init__( self.kms_key_id = kms_key_id self.networking = networking self.stopping_condition = stopping_condition + self.sequence_length = sequence_length # Initialize fine-tuning options with beta session fallback - self.hyperparameters, self._model_arn, is_gated_model = _get_fine_tuning_options_and_model_arn(self._model_name, - CustomizationTechnique.SFT.value, - self.training_type, - self.sagemaker_session or TrainDefaults.get_sagemaker_session( - sagemaker_session=self.sagemaker_session - )) - + self.hyperparameters, self._model_arn, is_gated_model = _get_fine_tuning_options_and_model_arn( + self._model_name, + CustomizationTechnique.SFT.value, + self.training_type, + self.sagemaker_session or TrainDefaults.get_sagemaker_session(sagemaker_session=self.sagemaker_session), + sequence_length=self.sequence_length + ) + # Process hyperparameters self._process_hyperparameters() @@ -225,12 +232,14 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None, validati kms_key_id=self.kms_key_id ) - serverless_config = _create_serverless_config(model_arn=self._model_arn, - customization_technique=CustomizationTechnique.SFT.value, - training_type=self.training_type, - accept_eula=self.accept_eula, - job_type=JOB_TYPE - ) + serverless_config = _create_serverless_config( + model_arn=self._model_arn, + customization_technique=CustomizationTechnique.SFT.value, + training_type=self.training_type, + accept_eula=self.accept_eula, + sequence_length=self.sequence_length, + job_type=JOB_TYPE + ) mlflow_config = _create_mlflow_config( sagemaker_session, mlflow_resource_arn=self.mlflow_resource_arn, diff --git a/sagemaker-train/tests/unit/train/common_utils/test_finetune_utils.py b/sagemaker-train/tests/unit/train/common_utils/test_finetune_utils.py index 7a63e36234..1f9be51896 100644 --- a/sagemaker-train/tests/unit/train/common_utils/test_finetune_utils.py +++ b/sagemaker-train/tests/unit/train/common_utils/test_finetune_utils.py @@ -27,9 +27,11 @@ _create_mlflow_config, _validate_eula_for_gated_model, _validate_model_region_availability, - _validate_s3_path_exists + _validate_s3_path_exists, + _parse_context_length ) -from sagemaker.core.resources import ModelPackage, ModelPackageGroup +from sagemaker.core.resources import ModelPackage +from sagemaker.core.utils.utils import Unassigned, ModelPackageGroup from sagemaker.ai_registry.dataset import DataSet from sagemaker.train.common import TrainingType from sagemaker.train.configs import InputData @@ -460,6 +462,7 @@ def test__convert_input_data_to_channels(self): def test__validate_eula_for_gated_model_with_model_package(self): """Test EULA validation returns True for ModelPackage input""" from sagemaker.core.resources import ModelPackage +from sagemaker.core.utils.utils import Unassigned model_package = Mock(spec=ModelPackage) result = _validate_eula_for_gated_model(model_package, False, True) @@ -691,3 +694,101 @@ def test__get_fine_tuning_options_subscription_enabled_but_not_subscribed(self, # Should still have standard params, just not datamix ones assert "max_steps" in options._specs assert "customer_data_percent" not in options._specs + + def test__create_serverless_config_with_sequence_length(self): + config = _create_serverless_config("model-arn", "SFT", TrainingType.LORA, accept_eula=True, sequence_length="8K") + + assert config.sequence_length == "8K" + assert config.base_model_arn == "model-arn" + + def test__create_serverless_config_without_sequence_length(self): + config = _create_serverless_config("model-arn", "SFT", TrainingType.LORA, accept_eula=True) + + # sequence_length should remain Unassigned (not set), not None + assert isinstance(config.sequence_length, Unassigned) + + def test__parse_context_length_with_k_suffix(self): + assert _parse_context_length("8K") == 8192 + assert _parse_context_length("32K") == 32768 + assert _parse_context_length("128K") == 131072 + + def test__parse_context_length_with_lowercase(self): + assert _parse_context_length("8k") == 8192 + + def test__parse_context_length_with_integer(self): + assert _parse_context_length("4096") == 4096 + + def test__parse_context_length_with_none(self): + assert _parse_context_length(None) == 0 + + def test__parse_context_length_with_empty(self): + assert _parse_context_length("") == 0 + + @patch('sagemaker.train.common_utils.finetune_utils._get_hub_content_metadata') + @patch('boto3.client') + def test__get_fine_tuning_options_filters_by_sequence_length(self, mock_boto_client, mock_get_hub_content): + mock_session = Mock() + mock_session.boto_session.region_name = "us-east-1" + + mock_get_hub_content.return_value = { + 'hub_content_arn': "arn:aws:sagemaker:us-east-1:123456789012:model/test-model", + 'hub_content_document': { + "GatedBucket": False, + "RecipeCollection": [ + { + "CustomizationTechnique": "SFT", + "SmtjRecipeTemplateS3Uri": "s3://bucket/template-4k.json", + "SmtjOverrideParamsS3Uri": "s3://bucket/params-4k.json", + "Peft": True, + "SequenceLength": "4K" + }, + { + "CustomizationTechnique": "SFT", + "SmtjRecipeTemplateS3Uri": "s3://bucket/template-32k.json", + "SmtjOverrideParamsS3Uri": "s3://bucket/params-32k.json", + "Peft": True, + "SequenceLength": "32K" + } + ] + } + } + + mock_s3_client = Mock() + mock_boto_client.return_value = mock_s3_client + mock_s3_client.get_object.return_value = { + "Body": Mock(read=Mock(return_value=b'{"max_length": {"default": 32768}}')) + } + + result = _get_fine_tuning_options_and_model_arn("test-model", "SFT", "LORA", mock_session, sequence_length="8K") + + if result is not None: + options, model_arn, is_gated_model = result + # Should pick the 32K recipe (smallest >= 8K) + mock_s3_client.get_object.assert_called_once() + call_args = mock_s3_client.get_object.call_args[1] + assert "params-32k" in call_args["Key"] + + @patch('sagemaker.train.common_utils.finetune_utils._get_hub_content_metadata') + def test__get_fine_tuning_options_raises_when_no_sufficient_context_length(self, mock_get_hub_content): + mock_session = Mock() + mock_session.boto_session.region_name = "us-east-1" + + mock_get_hub_content.return_value = { + 'hub_content_arn': "arn:aws:sagemaker:us-east-1:123456789012:model/test-model", + 'hub_content_document': { + "GatedBucket": False, + "RecipeCollection": [ + { + "CustomizationTechnique": "SFT", + "SmtjRecipeTemplateS3Uri": "s3://bucket/template-4k.json", + "SmtjOverrideParamsS3Uri": "s3://bucket/params-4k.json", + "Peft": True, + "SequenceLength": "4K" + } + ] + } + } + + # Requesting 128K but only 4K available — should raise + with pytest.raises(ValueError, match="No recipes found with SequenceLength >= 128K"): + _get_fine_tuning_options_and_model_arn("test-model", "SFT", "LORA", mock_session, sequence_length="128K") diff --git a/sagemaker-train/tests/unit/train/test_dpo_trainer.py b/sagemaker-train/tests/unit/train/test_dpo_trainer.py index 1b70e0bf89..7648b46e35 100644 --- a/sagemaker-train/tests/unit/train/test_dpo_trainer.py +++ b/sagemaker-train/tests/unit/train/test_dpo_trainer.py @@ -506,4 +506,67 @@ def test_train_wait_false_skips_wait(self, mock_training_job_create, mock_model_ trainer = DPOTrainer(model="test-model", model_package_group="test-group", training_dataset="s3://bucket/train") trainer.train(wait=False, wait_timeout=600) - mock_wait.assert_not_called() \ No newline at end of file + mock_wait.assert_not_called() + + @patch('sagemaker.train.dpo_trainer._validate_and_resolve_model_package_group') + @patch('sagemaker.train.dpo_trainer._get_fine_tuning_options_and_model_arn') + def test_init_sequence_length_default_none(self, mock_finetuning_options, mock_validate_group): + mock_validate_group.return_value = "test-group" + mock_hyperparams = Mock() + mock_hyperparams.to_dict.return_value = {} + mock_finetuning_options.return_value = (mock_hyperparams, "model-arn", False) + trainer = DPOTrainer(model="test-model", model_package_group="test-group") + assert trainer.sequence_length is None + + @patch('sagemaker.train.dpo_trainer._validate_and_resolve_model_package_group') + @patch('sagemaker.train.dpo_trainer._get_fine_tuning_options_and_model_arn') + def test_init_with_sequence_length(self, mock_finetuning_options, mock_validate_group): + mock_validate_group.return_value = "test-group" + mock_hyperparams = Mock() + mock_hyperparams.to_dict.return_value = {} + mock_finetuning_options.return_value = (mock_hyperparams, "model-arn", False) + trainer = DPOTrainer(model="test-model", model_package_group="test-group", sequence_length="8K") + assert trainer.sequence_length == "8K" + + @patch('sagemaker.train.dpo_trainer._resolve_model_and_name') + @patch('sagemaker.train.dpo_trainer._get_fine_tuning_options_and_model_arn') + @patch('sagemaker.train.dpo_trainer.TrainDefaults.get_role') + @patch('sagemaker.train.dpo_trainer.TrainDefaults.get_sagemaker_session') + @patch('sagemaker.train.dpo_trainer._get_unique_name') + @patch('sagemaker.train.dpo_trainer._validate_and_resolve_model_package_group') + @patch('sagemaker.train.dpo_trainer._create_input_data_config') + @patch('sagemaker.train.dpo_trainer._convert_input_data_to_channels') + @patch('sagemaker.train.dpo_trainer._create_output_config') + @patch('sagemaker.train.dpo_trainer._create_serverless_config') + @patch('sagemaker.train.dpo_trainer._create_mlflow_config') + @patch('sagemaker.train.dpo_trainer._create_model_package_config') + @patch('sagemaker.core.resources.TrainingJob.create') + def test_train_passes_sequence_length_to_serverless_config(self, mock_training_job_create, + mock_model_package_config, mock_mlflow_config, mock_serverless_config, + mock_output_config, mock_convert_channels, mock_input_config, + mock_validate_group, mock_unique_name, mock_get_sagemaker_session, + mock_get_role, mock_get_options, mock_resolve_model): + mock_validate_group.return_value = "test-group" + mock_resolve_model.return_value = ("test-model", "test-model") + mock_get_sagemaker_session.return_value = Mock() + mock_fine_tuning_options = Mock() + mock_fine_tuning_options.to_dict.return_value = {} + mock_get_options.return_value = (mock_fine_tuning_options, "model-arn", False) + mock_get_role.return_value = "test-role" + mock_unique_name.return_value = "test-job-name" + mock_input_config.return_value = [Mock()] + mock_convert_channels.return_value = [Mock()] + mock_output_config.return_value = Mock() + mock_serverless_config.return_value = Mock() + mock_mlflow_config.return_value = Mock() + mock_model_package_config.return_value = Mock() + mock_training_job = Mock() + mock_training_job_create.return_value = mock_training_job + + trainer = DPOTrainer(model="test-model", model_package_group="test-group", + training_dataset="s3://bucket/train", sequence_length="16K") + trainer.train(wait=False) + + mock_serverless_config.assert_called_once() + call_kwargs = mock_serverless_config.call_args[1] + assert call_kwargs["sequence_length"] == "16K" diff --git a/sagemaker-train/tests/unit/train/test_rlaif_trainer.py b/sagemaker-train/tests/unit/train/test_rlaif_trainer.py index e5666883e8..6811c45540 100644 --- a/sagemaker-train/tests/unit/train/test_rlaif_trainer.py +++ b/sagemaker-train/tests/unit/train/test_rlaif_trainer.py @@ -682,4 +682,70 @@ def test_train_wait_false_skips_wait(self, mock_training_job_create, mock_model_ trainer = RLAIFTrainer(model="test-model", model_package_group="test-group", training_dataset="s3://bucket/train") trainer.train(wait=False, wait_timeout=600) - mock_wait.assert_not_called() \ No newline at end of file + mock_wait.assert_not_called() + + @patch('sagemaker.train.rlaif_trainer._validate_and_resolve_model_package_group') + @patch('sagemaker.train.rlaif_trainer._get_fine_tuning_options_and_model_arn') + def test_init_sequence_length_default_none(self, mock_finetuning_options, mock_validate_group): + mock_validate_group.return_value = "test-group" + mock_hyperparams = Mock() + mock_hyperparams.to_dict.return_value = {} + mock_hyperparams._specs = {} + mock_finetuning_options.return_value = (mock_hyperparams, "model-arn", False) + trainer = RLAIFTrainer(model="test-model", model_package_group="test-group") + assert trainer.sequence_length is None + + @patch('sagemaker.train.rlaif_trainer._validate_and_resolve_model_package_group') + @patch('sagemaker.train.rlaif_trainer._get_fine_tuning_options_and_model_arn') + def test_init_with_sequence_length(self, mock_finetuning_options, mock_validate_group): + mock_validate_group.return_value = "test-group" + mock_hyperparams = Mock() + mock_hyperparams.to_dict.return_value = {} + mock_hyperparams._specs = {} + mock_finetuning_options.return_value = (mock_hyperparams, "model-arn", False) + trainer = RLAIFTrainer(model="test-model", model_package_group="test-group", sequence_length="128K") + assert trainer.sequence_length == "128K" + + @patch('sagemaker.train.rlaif_trainer._resolve_model_and_name') + @patch('sagemaker.train.rlaif_trainer._get_fine_tuning_options_and_model_arn') + @patch('sagemaker.train.rlaif_trainer.TrainDefaults.get_role') + @patch('sagemaker.train.rlaif_trainer.TrainDefaults.get_sagemaker_session') + @patch('sagemaker.train.rlaif_trainer._get_unique_name') + @patch('sagemaker.train.rlaif_trainer._validate_and_resolve_model_package_group') + @patch('sagemaker.train.rlaif_trainer._create_input_data_config') + @patch('sagemaker.train.rlaif_trainer._convert_input_data_to_channels') + @patch('sagemaker.train.rlaif_trainer._create_output_config') + @patch('sagemaker.train.rlaif_trainer._create_serverless_config') + @patch('sagemaker.train.rlaif_trainer._create_mlflow_config') + @patch('sagemaker.train.rlaif_trainer._create_model_package_config') + @patch('sagemaker.core.resources.TrainingJob.create') + def test_train_passes_sequence_length_to_serverless_config(self, mock_training_job_create, + mock_model_package_config, mock_mlflow_config, mock_serverless_config, + mock_output_config, mock_convert_channels, mock_input_config, + mock_validate_group, mock_unique_name, mock_get_sagemaker_session, + mock_get_role, mock_get_options, mock_resolve_model): + mock_validate_group.return_value = "test-group" + mock_resolve_model.return_value = ("test-model", "test-model") + mock_get_sagemaker_session.return_value = Mock() + mock_fine_tuning_options = Mock() + mock_fine_tuning_options.to_dict.return_value = {} + mock_fine_tuning_options._specs = {} + mock_get_options.return_value = (mock_fine_tuning_options, "model-arn", False) + mock_get_role.return_value = "test-role" + mock_unique_name.return_value = "test-job-name" + mock_input_config.return_value = [Mock()] + mock_convert_channels.return_value = [Mock()] + mock_output_config.return_value = Mock() + mock_serverless_config.return_value = Mock() + mock_mlflow_config.return_value = Mock() + mock_model_package_config.return_value = Mock() + mock_training_job = Mock() + mock_training_job_create.return_value = mock_training_job + + trainer = RLAIFTrainer(model="test-model", model_package_group="test-group", + training_dataset="s3://bucket/train", sequence_length="64K") + trainer.train(wait=False) + + mock_serverless_config.assert_called_once() + call_kwargs = mock_serverless_config.call_args[1] + assert call_kwargs["sequence_length"] == "64K" diff --git a/sagemaker-train/tests/unit/train/test_rlvr_trainer.py b/sagemaker-train/tests/unit/train/test_rlvr_trainer.py index 320b81555d..b4c01385e2 100644 --- a/sagemaker-train/tests/unit/train/test_rlvr_trainer.py +++ b/sagemaker-train/tests/unit/train/test_rlvr_trainer.py @@ -509,4 +509,67 @@ def test_train_wait_false_skips_wait(self, mock_training_job_create, mock_model_ trainer = RLVRTrainer(model="test-model", model_package_group="test-group", training_dataset="s3://bucket/train") trainer.train(wait=False, wait_timeout=600) - mock_wait.assert_not_called() \ No newline at end of file + mock_wait.assert_not_called() + + @patch('sagemaker.train.rlvr_trainer._validate_and_resolve_model_package_group') + @patch('sagemaker.train.rlvr_trainer._get_fine_tuning_options_and_model_arn') + def test_init_sequence_length_default_none(self, mock_finetuning_options, mock_validate_group): + mock_validate_group.return_value = "test-group" + mock_hyperparams = Mock() + mock_hyperparams.to_dict.return_value = {} + mock_finetuning_options.return_value = (mock_hyperparams, "model-arn", False) + trainer = RLVRTrainer(model="test-model", model_package_group="test-group") + assert trainer.sequence_length is None + + @patch('sagemaker.train.rlvr_trainer._validate_and_resolve_model_package_group') + @patch('sagemaker.train.rlvr_trainer._get_fine_tuning_options_and_model_arn') + def test_init_with_sequence_length(self, mock_finetuning_options, mock_validate_group): + mock_validate_group.return_value = "test-group" + mock_hyperparams = Mock() + mock_hyperparams.to_dict.return_value = {} + mock_finetuning_options.return_value = (mock_hyperparams, "model-arn", False) + trainer = RLVRTrainer(model="test-model", model_package_group="test-group", sequence_length="32K") + assert trainer.sequence_length == "32K" + + @patch('sagemaker.train.rlvr_trainer._resolve_model_and_name') + @patch('sagemaker.train.rlvr_trainer._get_fine_tuning_options_and_model_arn') + @patch('sagemaker.train.rlvr_trainer.TrainDefaults.get_role') + @patch('sagemaker.train.rlvr_trainer.TrainDefaults.get_sagemaker_session') + @patch('sagemaker.train.rlvr_trainer._get_unique_name') + @patch('sagemaker.train.rlvr_trainer._validate_and_resolve_model_package_group') + @patch('sagemaker.train.rlvr_trainer._create_input_data_config') + @patch('sagemaker.train.rlvr_trainer._convert_input_data_to_channels') + @patch('sagemaker.train.rlvr_trainer._create_output_config') + @patch('sagemaker.train.rlvr_trainer._create_serverless_config') + @patch('sagemaker.train.rlvr_trainer._create_mlflow_config') + @patch('sagemaker.train.rlvr_trainer._create_model_package_config') + @patch('sagemaker.core.resources.TrainingJob.create') + def test_train_passes_sequence_length_to_serverless_config(self, mock_training_job_create, + mock_model_package_config, mock_mlflow_config, mock_serverless_config, + mock_output_config, mock_convert_channels, mock_input_config, + mock_validate_group, mock_unique_name, mock_get_sagemaker_session, + mock_get_role, mock_get_options, mock_resolve_model): + mock_validate_group.return_value = "test-group" + mock_resolve_model.return_value = ("test-model", "test-model") + mock_get_sagemaker_session.return_value = Mock() + mock_fine_tuning_options = Mock() + mock_fine_tuning_options.to_dict.return_value = {} + mock_get_options.return_value = (mock_fine_tuning_options, "model-arn", False) + mock_get_role.return_value = "test-role" + mock_unique_name.return_value = "test-job-name" + mock_input_config.return_value = [Mock()] + mock_convert_channels.return_value = [Mock()] + mock_output_config.return_value = Mock() + mock_serverless_config.return_value = Mock() + mock_mlflow_config.return_value = Mock() + mock_model_package_config.return_value = Mock() + mock_training_job = Mock() + mock_training_job_create.return_value = mock_training_job + + trainer = RLVRTrainer(model="test-model", model_package_group="test-group", + training_dataset="s3://bucket/train", sequence_length="4K") + trainer.train(wait=False) + + mock_serverless_config.assert_called_once() + call_kwargs = mock_serverless_config.call_args[1] + assert call_kwargs["sequence_length"] == "4K" diff --git a/sagemaker-train/tests/unit/train/test_sft_trainer.py b/sagemaker-train/tests/unit/train/test_sft_trainer.py index 108990f839..01fc21f4bd 100644 --- a/sagemaker-train/tests/unit/train/test_sft_trainer.py +++ b/sagemaker-train/tests/unit/train/test_sft_trainer.py @@ -520,4 +520,67 @@ def test_train_wait_false_skips_wait(self, mock_training_job_create, mock_model_ trainer = SFTTrainer(model="test-model", model_package_group="test-group", training_dataset="s3://bucket/train") trainer.train(wait=False, wait_timeout=600) - mock_wait.assert_not_called() \ No newline at end of file + mock_wait.assert_not_called() + + @patch('sagemaker.train.sft_trainer._validate_and_resolve_model_package_group') + @patch('sagemaker.train.sft_trainer._get_fine_tuning_options_and_model_arn') + def test_init_sequence_length_default_none(self, mock_finetuning_options, mock_validate_group): + mock_validate_group.return_value = "test-group" + mock_hyperparams = Mock() + mock_hyperparams.to_dict.return_value = {} + mock_finetuning_options.return_value = (mock_hyperparams, "model-arn", False) + trainer = SFTTrainer(model="test-model", model_package_group="test-group") + assert trainer.sequence_length is None + + @patch('sagemaker.train.sft_trainer._validate_and_resolve_model_package_group') + @patch('sagemaker.train.sft_trainer._get_fine_tuning_options_and_model_arn') + def test_init_with_sequence_length(self, mock_finetuning_options, mock_validate_group): + mock_validate_group.return_value = "test-group" + mock_hyperparams = Mock() + mock_hyperparams.to_dict.return_value = {} + mock_finetuning_options.return_value = (mock_hyperparams, "model-arn", False) + trainer = SFTTrainer(model="test-model", model_package_group="test-group", sequence_length="8K") + assert trainer.sequence_length == "8K" + + @patch('sagemaker.train.sft_trainer._resolve_model_and_name') + @patch('sagemaker.train.sft_trainer._get_fine_tuning_options_and_model_arn') + @patch('sagemaker.train.sft_trainer.TrainDefaults.get_role') + @patch('sagemaker.train.sft_trainer.TrainDefaults.get_sagemaker_session') + @patch('sagemaker.train.sft_trainer._get_unique_name') + @patch('sagemaker.train.sft_trainer._validate_and_resolve_model_package_group') + @patch('sagemaker.train.sft_trainer._create_input_data_config') + @patch('sagemaker.train.sft_trainer._convert_input_data_to_channels') + @patch('sagemaker.train.sft_trainer._create_output_config') + @patch('sagemaker.train.sft_trainer._create_serverless_config') + @patch('sagemaker.train.sft_trainer._create_mlflow_config') + @patch('sagemaker.train.sft_trainer._create_model_package_config') + @patch('sagemaker.core.resources.TrainingJob.create') + def test_train_passes_sequence_length_to_serverless_config(self, mock_training_job_create, + mock_model_package_config, mock_mlflow_config, mock_serverless_config, + mock_output_config, mock_convert_channels, mock_input_config, + mock_validate_group, mock_unique_name, mock_get_sagemaker_session, + mock_get_role, mock_get_options, mock_resolve_model): + mock_validate_group.return_value = "test-group" + mock_resolve_model.return_value = ("test-model", "test-model") + mock_get_sagemaker_session.return_value = Mock() + mock_fine_tuning_options = Mock() + mock_fine_tuning_options.to_dict.return_value = {} + mock_get_options.return_value = (mock_fine_tuning_options, "model-arn", False) + mock_get_role.return_value = "test-role" + mock_unique_name.return_value = "test-job-name" + mock_input_config.return_value = [Mock()] + mock_convert_channels.return_value = [Mock()] + mock_output_config.return_value = Mock() + mock_serverless_config.return_value = Mock() + mock_mlflow_config.return_value = Mock() + mock_model_package_config.return_value = Mock() + mock_training_job = Mock() + mock_training_job_create.return_value = mock_training_job + + trainer = SFTTrainer(model="test-model", model_package_group="test-group", + training_dataset="s3://bucket/train", sequence_length="16K") + trainer.train(wait=False) + + mock_serverless_config.assert_called_once() + call_kwargs = mock_serverless_config.call_args[1] + assert call_kwargs["sequence_length"] == "16K" From 13e0f499e71ef028ec6e61c1db7f63f821ecbe57 Mon Sep 17 00:00:00 2001 From: guanweim Date: Thu, 28 May 2026 00:45:54 +0000 Subject: [PATCH 2/9] fix: use codegen for SequenceLength shape instead of manual shapes.py edit Add SequenceLength to service-2.json and regenerate shapes.py via codegen (python -m sagemaker.core.tools.codegen) instead of editing shapes.py manually. --- .../sample/sagemaker/2017-07-24/service-2.json | 17 +++++++++++++++++ .../src/sagemaker/core/shapes/shapes.py | 3 ++- .../core/utils/code_injection/shape_dag.py | 1 + 3 files changed, 20 insertions(+), 1 deletion(-) diff --git a/sagemaker-core/sample/sagemaker/2017-07-24/service-2.json b/sagemaker-core/sample/sagemaker/2017-07-24/service-2.json index a7640e3f46..1dbcf64d36 100644 --- a/sagemaker-core/sample/sagemaker/2017-07-24/service-2.json +++ b/sagemaker-core/sample/sagemaker/2017-07-24/service-2.json @@ -43236,6 +43236,10 @@ "EvaluatorArn":{ "shape":"EvaluatorArn", "documentation":"

The evaluator Amazon Resource Name (ARN) used as reward function or reward prompt.

" + }, + "SequenceLength":{ + "shape":"SequenceLength", + "documentation":"

The sequence length for the training job.

" } }, "documentation":"

The configuration for the serverless training job.

" @@ -43247,6 +43251,19 @@ "Evaluation" ] }, + "SequenceLength":{ + "type":"string", + "enum":[ + "1K", + "2K", + "4K", + "8K", + "16K", + "32K", + "64K", + "128K" + ] + }, "ServerlessMaxConcurrency":{ "type":"integer", "box":true, diff --git a/sagemaker-core/src/sagemaker/core/shapes/shapes.py b/sagemaker-core/src/sagemaker/core/shapes/shapes.py index d79d33a5b2..9749b2a4b5 100644 --- a/sagemaker-core/src/sagemaker/core/shapes/shapes.py +++ b/sagemaker-core/src/sagemaker/core/shapes/shapes.py @@ -9588,7 +9588,7 @@ class ServerlessJobConfig(Base): peft: The parameter-efficient fine-tuning configuration. evaluation_type: The evaluation job type. Required when serverless job type is Evaluation. evaluator_arn: The evaluator Amazon Resource Name (ARN) used as reward function or reward prompt. - sequence_length: The sequence length for the training job. Valid values are "1K", "2K", "4K", "8K", "16K", "32K", "64K", "128K". + sequence_length: The sequence length for the training job. """ base_model_arn: StrPipeVar @@ -9600,6 +9600,7 @@ class ServerlessJobConfig(Base): evaluator_arn: Optional[StrPipeVar] = Unassigned() sequence_length: Optional[StrPipeVar] = Unassigned() + class MlflowConfig(Base): """ MlflowConfig diff --git a/sagemaker-core/src/sagemaker/core/utils/code_injection/shape_dag.py b/sagemaker-core/src/sagemaker/core/utils/code_injection/shape_dag.py index d7ad54ee25..f34a282348 100644 --- a/sagemaker-core/src/sagemaker/core/utils/code_injection/shape_dag.py +++ b/sagemaker-core/src/sagemaker/core/utils/code_injection/shape_dag.py @@ -15877,6 +15877,7 @@ {"name": "Peft", "shape": "Peft", "type": "string"}, {"name": "EvaluationType", "shape": "EvaluationType", "type": "string"}, {"name": "EvaluatorArn", "shape": "EvaluatorArn", "type": "string"}, + {"name": "SequenceLength", "shape": "SequenceLength", "type": "string"}, ], "type": "structure", }, From 37a599650e7c2d4427d7cede050881da24da8d01 Mon Sep 17 00:00:00 2001 From: guanweim Date: Thu, 28 May 2026 17:14:21 +0000 Subject: [PATCH 3/9] refactor: preserve original recipe selection path when sequence_length not provided Keep the existing `next(...)` logic untouched for the default case (no sequence_length). Only build the candidates list and filter when sequence_length is explicitly requested, ensuring zero behavioral change for existing callers. --- .../train/common_utils/finetune_utils.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/sagemaker-train/src/sagemaker/train/common_utils/finetune_utils.py b/sagemaker-train/src/sagemaker/train/common_utils/finetune_utils.py index 1a42b2a5be..5dc5b444a1 100644 --- a/sagemaker-train/src/sagemaker/train/common_utils/finetune_utils.py +++ b/sagemaker-train/src/sagemaker/train/common_utils/finetune_utils.py @@ -396,14 +396,19 @@ def _get_fine_tuning_options_and_model_arn( # Collect override_params from ALL matching recipes (standard + subscription) recipe = None if (isinstance(training_type, TrainingType) and training_type == TrainingType.LORA) or training_type == "LORA": - candidates = [r for r in recipes_with_template if r.get("Peft") and not r.get("IsSubscriptionModel")] + recipe = next((r for r in recipes_with_template if r.get("Peft") and not r.get("IsSubscriptionModel")), None) elif (isinstance(training_type, TrainingType) and training_type == TrainingType.FULL) or training_type == "FULL": - candidates = [r for r in recipes_with_template if not r.get("Peft") and not r.get("IsSubscriptionModel")] - else: - candidates = [] + recipe = next((r for r in recipes_with_template if not r.get("Peft") and not r.get("IsSubscriptionModel")), None) + + # Override recipe selection when sequence_length is explicitly requested + if sequence_length: + if (isinstance(training_type, TrainingType) and training_type == TrainingType.LORA) or training_type == "LORA": + candidates = [r for r in recipes_with_template if r.get("Peft") and not r.get("IsSubscriptionModel")] + elif (isinstance(training_type, TrainingType) and training_type == TrainingType.FULL) or training_type == "FULL": + candidates = [r for r in recipes_with_template if not r.get("Peft") and not r.get("IsSubscriptionModel")] + else: + candidates = [] - # Filter by SequenceLength if sequence_length is provided - if sequence_length and candidates: requested = _parse_context_length(sequence_length) candidates_with_context = [r for r in candidates if r.get("SequenceLength")] if candidates_with_context: @@ -422,8 +427,6 @@ def _get_fine_tuning_options_and_model_arn( f"No recipes found with Smtj for technique: {customization_technique},training_type:{training_type}, " f"and sequence length:{sequence_length}" ) - elif candidates: - recipe = candidates[0] if not recipe: raise ValueError(f"No recipes found with Smtj for technique: {customization_technique},training_type:{training_type}") From 1c6f0ef9dbaad1b3b918a4ee50ee4de6e35e565b Mon Sep 17 00:00:00 2001 From: guanweim Date: Thu, 28 May 2026 17:31:36 +0000 Subject: [PATCH 4/9] fix: correct test imports and mock setup for sequence_length tests --- .../train/common_utils/test_finetune_utils.py | 23 ++++++++----------- 1 file changed, 10 insertions(+), 13 deletions(-) diff --git a/sagemaker-train/tests/unit/train/common_utils/test_finetune_utils.py b/sagemaker-train/tests/unit/train/common_utils/test_finetune_utils.py index 1f9be51896..0c34dae260 100644 --- a/sagemaker-train/tests/unit/train/common_utils/test_finetune_utils.py +++ b/sagemaker-train/tests/unit/train/common_utils/test_finetune_utils.py @@ -30,8 +30,8 @@ _validate_s3_path_exists, _parse_context_length ) -from sagemaker.core.resources import ModelPackage -from sagemaker.core.utils.utils import Unassigned, ModelPackageGroup +from sagemaker.core.resources import ModelPackage, ModelPackageGroup +from sagemaker.core.utils.utils import Unassigned from sagemaker.ai_registry.dataset import DataSet from sagemaker.train.common import TrainingType from sagemaker.train.configs import InputData @@ -462,7 +462,6 @@ def test__convert_input_data_to_channels(self): def test__validate_eula_for_gated_model_with_model_package(self): """Test EULA validation returns True for ModelPackage input""" from sagemaker.core.resources import ModelPackage -from sagemaker.core.utils.utils import Unassigned model_package = Mock(spec=ModelPackage) result = _validate_eula_for_gated_model(model_package, False, True) @@ -725,10 +724,14 @@ def test__parse_context_length_with_empty(self): assert _parse_context_length("") == 0 @patch('sagemaker.train.common_utils.finetune_utils._get_hub_content_metadata') - @patch('boto3.client') - def test__get_fine_tuning_options_filters_by_sequence_length(self, mock_boto_client, mock_get_hub_content): + def test__get_fine_tuning_options_filters_by_sequence_length(self, mock_get_hub_content): mock_session = Mock() mock_session.boto_session.region_name = "us-east-1" + mock_s3 = Mock() + mock_s3.get_object.return_value = { + "Body": Mock(read=Mock(return_value=b'{"max_length": {"default": 32768}}')) + } + mock_session.boto_session.client.return_value = mock_s3 mock_get_hub_content.return_value = { 'hub_content_arn': "arn:aws:sagemaker:us-east-1:123456789012:model/test-model", @@ -753,19 +756,13 @@ def test__get_fine_tuning_options_filters_by_sequence_length(self, mock_boto_cli } } - mock_s3_client = Mock() - mock_boto_client.return_value = mock_s3_client - mock_s3_client.get_object.return_value = { - "Body": Mock(read=Mock(return_value=b'{"max_length": {"default": 32768}}')) - } - result = _get_fine_tuning_options_and_model_arn("test-model", "SFT", "LORA", mock_session, sequence_length="8K") if result is not None: options, model_arn, is_gated_model = result # Should pick the 32K recipe (smallest >= 8K) - mock_s3_client.get_object.assert_called_once() - call_args = mock_s3_client.get_object.call_args[1] + mock_s3.get_object.assert_called_once() + call_args = mock_s3.get_object.call_args[1] assert "params-32k" in call_args["Key"] @patch('sagemaker.train.common_utils.finetune_utils._get_hub_content_metadata') From 3d36503b4f359a8b6ea436447fd831e47b5020e7 Mon Sep 17 00:00:00 2001 From: guanweim Date: Thu, 28 May 2026 21:59:20 +0000 Subject: [PATCH 5/9] address PR review: sequence_length as recipe pre-filter and simplify config - Move sequence_length filtering above recipe selection to reduce recipes_with_template before existing logic runs - Always pass sequence_length to ServerlessJobConfig (no None guard) --- .../train/common_utils/finetune_utils.py | 36 +++++++------------ .../train/common_utils/test_finetune_utils.py | 3 +- 2 files changed, 14 insertions(+), 25 deletions(-) diff --git a/sagemaker-train/src/sagemaker/train/common_utils/finetune_utils.py b/sagemaker-train/src/sagemaker/train/common_utils/finetune_utils.py index 5dc5b444a1..54b6497061 100644 --- a/sagemaker-train/src/sagemaker/train/common_utils/finetune_utils.py +++ b/sagemaker-train/src/sagemaker/train/common_utils/finetune_utils.py @@ -392,30 +392,15 @@ def _get_fine_tuning_options_and_model_arn( if not recipes_with_template: raise ValueError(f"No recipes found with Smtj for technique: {customization_technique}") - # Select recipe based on training type - # Collect override_params from ALL matching recipes (standard + subscription) - recipe = None - if (isinstance(training_type, TrainingType) and training_type == TrainingType.LORA) or training_type == "LORA": - recipe = next((r for r in recipes_with_template if r.get("Peft") and not r.get("IsSubscriptionModel")), None) - elif (isinstance(training_type, TrainingType) and training_type == TrainingType.FULL) or training_type == "FULL": - recipe = next((r for r in recipes_with_template if not r.get("Peft") and not r.get("IsSubscriptionModel")), None) - - # Override recipe selection when sequence_length is explicitly requested + # Filter by SequenceLength before recipe selection if sequence_length is requested if sequence_length: - if (isinstance(training_type, TrainingType) and training_type == TrainingType.LORA) or training_type == "LORA": - candidates = [r for r in recipes_with_template if r.get("Peft") and not r.get("IsSubscriptionModel")] - elif (isinstance(training_type, TrainingType) and training_type == TrainingType.FULL) or training_type == "FULL": - candidates = [r for r in recipes_with_template if not r.get("Peft") and not r.get("IsSubscriptionModel")] - else: - candidates = [] - requested = _parse_context_length(sequence_length) - candidates_with_context = [r for r in candidates if r.get("SequenceLength")] + candidates_with_context = [r for r in recipes_with_template if r.get("SequenceLength")] if candidates_with_context: filtered = [r for r in candidates_with_context if _parse_context_length(r.get("SequenceLength")) >= requested] if filtered: filtered.sort(key=lambda r: _parse_context_length(r.get("SequenceLength"))) - recipe = filtered[0] + recipes_with_template = filtered else: available = sorted(set(r.get("SequenceLength") for r in candidates_with_context)) raise ValueError( @@ -428,6 +413,14 @@ def _get_fine_tuning_options_and_model_arn( f"and sequence length:{sequence_length}" ) + # Select recipe based on training type + # Collect override_params from ALL matching recipes (standard + subscription) + recipe = None + if (isinstance(training_type, TrainingType) and training_type == TrainingType.LORA) or training_type == "LORA": + recipe = next((r for r in recipes_with_template if r.get("Peft") and not r.get("IsSubscriptionModel")), None) + elif (isinstance(training_type, TrainingType) and training_type == TrainingType.FULL) or training_type == "FULL": + recipe = next((r for r in recipes_with_template if not r.get("Peft") and not r.get("IsSubscriptionModel")), None) + if not recipe: raise ValueError(f"No recipes found with Smtj for technique: {customization_technique},training_type:{training_type}") @@ -601,18 +594,15 @@ def _create_serverless_config(model_arn, customization_technique, else (training_type.value if isinstance(training_type, TrainingType) else training_type) # Create ServerlessJobConfig using shapes - config_kwargs = dict( + serverless_config = ServerlessJobConfig( job_type=job_type, base_model_arn=model_arn, customization_technique=customization_technique, peft=peft, evaluator_arn=evaluator_arn, accept_eula=accept_eula, + sequence_length=sequence_length, ) - if sequence_length is not None: - config_kwargs["sequence_length"] = sequence_length - - serverless_config = ServerlessJobConfig(**config_kwargs) return serverless_config diff --git a/sagemaker-train/tests/unit/train/common_utils/test_finetune_utils.py b/sagemaker-train/tests/unit/train/common_utils/test_finetune_utils.py index 0c34dae260..50ae26a800 100644 --- a/sagemaker-train/tests/unit/train/common_utils/test_finetune_utils.py +++ b/sagemaker-train/tests/unit/train/common_utils/test_finetune_utils.py @@ -703,8 +703,7 @@ def test__create_serverless_config_with_sequence_length(self): def test__create_serverless_config_without_sequence_length(self): config = _create_serverless_config("model-arn", "SFT", TrainingType.LORA, accept_eula=True) - # sequence_length should remain Unassigned (not set), not None - assert isinstance(config.sequence_length, Unassigned) + assert config.sequence_length is None def test__parse_context_length_with_k_suffix(self): assert _parse_context_length("8K") == 8192 From c5747adaee80861e8e88940cabdcbdf12a45613b Mon Sep 17 00:00:00 2001 From: guanweim Date: Thu, 28 May 2026 22:02:51 +0000 Subject: [PATCH 6/9] fix: change hub_name default to None for consistency Use Optional[str] = None instead of hardcoded "SageMakerPublicHub" default, letting get_sagemaker_hub_name() resolve it at runtime. --- .../sagemaker/train/common_utils/finetune_utils.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/sagemaker-train/src/sagemaker/train/common_utils/finetune_utils.py b/sagemaker-train/src/sagemaker/train/common_utils/finetune_utils.py index 54b6497061..23fd8e3383 100644 --- a/sagemaker-train/src/sagemaker/train/common_utils/finetune_utils.py +++ b/sagemaker-train/src/sagemaker/train/common_utils/finetune_utils.py @@ -338,12 +338,12 @@ def _parse_context_length(value) -> int: def _get_fine_tuning_options_and_model_arn( - model_name: str, - customization_technique: str, - training_type, - sagemaker_session, - sequence_length=None, - hub_name: str = "SageMakerPublicHub" + model_name: str, + customization_technique: str, + training_type, + sagemaker_session, + sequence_length=None, + hub_name: Optional[str] = None ) -> tuple: """Get fine-tuning options and model ARN for given customization technique. From 6dbac4838b74472aebee54bb0709a3d5cd238ca0 Mon Sep 17 00:00:00 2001 From: guanweim Date: Thu, 28 May 2026 22:05:18 +0000 Subject: [PATCH 7/9] test: add integration test for SFT trainer with sequence_length --- .../train/test_sft_trainer_integration.py | 36 +++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/sagemaker-train/tests/integ/train/test_sft_trainer_integration.py b/sagemaker-train/tests/integ/train/test_sft_trainer_integration.py index 93be84a738..eb7ebda1e3 100644 --- a/sagemaker-train/tests/integ/train/test_sft_trainer_integration.py +++ b/sagemaker-train/tests/integ/train/test_sft_trainer_integration.py @@ -135,3 +135,39 @@ def test_sft_trainer_nova_workflow(sagemaker_session_us_east_1): assert training_job.training_job_status == "Completed" assert hasattr(training_job, 'output_model_package_arn') assert training_job.output_model_package_arn is not None + + +@pytest.mark.gpu_intensive +def test_sft_trainer_lora_with_sequence_length(sagemaker_session): + """Test SFT training workflow with LORA and sequence_length specified.""" + unique_id = f"{int(time.time())}-{random.randint(1000, 9999)}" + + sft_trainer = SFTTrainer( + model="meta-textgeneration-llama-3-2-1b-instruct", + training_type=TrainingType.LORA, + model_package_group="arn:aws:sagemaker:us-west-2:729646638167:model-package-group/sdk-test-finetuned-models", + training_dataset="s3://mc-flows-sdk-testing/input_data/sft/sample_data_256_final.jsonl", + s3_output_path="s3://mc-flows-sdk-testing/output/", + accept_eula=True, + sequence_length="8K", + base_job_name=f"sft-seqlen-integ-{unique_id}", + ) + + training_job = sft_trainer.train(wait=False) + + max_wait_time = 3600 + poll_interval = 30 + start_time = time.time() + + while time.time() - start_time < max_wait_time: + training_job.refresh() + status = training_job.training_job_status + + if status in ["Completed", "Failed", "Stopped"]: + break + + time.sleep(poll_interval) + + assert training_job.training_job_status == "Completed" + assert hasattr(training_job, 'output_model_package_arn') + assert training_job.output_model_package_arn is not None From 303fa79c178c8fd70c8d5afafd71ae09d613b407 Mon Sep 17 00:00:00 2001 From: guanweim Date: Thu, 28 May 2026 22:16:28 +0000 Subject: [PATCH 8/9] style: apply black formatting and remove unused import --- .../train/common_utils/finetune_utils.py | 444 +++++++++++------- .../src/sagemaker/train/dpo_trainer.py | 147 +++--- .../src/sagemaker/train/rlaif_trainer.py | 167 ++++--- .../src/sagemaker/train/rlvr_trainer.py | 114 +++-- .../src/sagemaker/train/sft_trainer.py | 112 +++-- .../train/test_sft_trainer_integration.py | 53 ++- .../train/common_utils/test_finetune_utils.py | 426 ++++++++++------- 7 files changed, 841 insertions(+), 622 deletions(-) diff --git a/sagemaker-train/src/sagemaker/train/common_utils/finetune_utils.py b/sagemaker-train/src/sagemaker/train/common_utils/finetune_utils.py index 23fd8e3383..0c04b29d2a 100644 --- a/sagemaker-train/src/sagemaker/train/common_utils/finetune_utils.py +++ b/sagemaker-train/src/sagemaker/train/common_utils/finetune_utils.py @@ -12,7 +12,13 @@ from sagemaker.core.helper.session_helper import Session from sagemaker.train.common_utils.recipe_utils import _get_hub_content_metadata from sagemaker.train.common import TrainingType, CustomizationTechnique, JOB_TYPE, FineTuningOptions -from sagemaker.core.shapes import ServerlessJobConfig, Channel, DataSource, ModelPackageConfig, MlflowConfig +from sagemaker.core.shapes import ( + ServerlessJobConfig, + Channel, + DataSource, + ModelPackageConfig, + MlflowConfig, +) from sagemaker.train.configs import InputData, OutputDataConfig from sagemaker.train.defaults import TrainDefaults from sagemaker.train.constants import get_sagemaker_hub_name @@ -20,11 +26,17 @@ logger = logging.getLogger(__name__) # Region mappings for model availability -OPEN_WEIGHTS_REGIONS = ['us-east-1', 'us-west-2', 'ap-northeast-1', 'eu-west-1'] # IAD, PDX, NRT, DUB -NOVA_REGIONS = ['us-east-1', 'us-west-2'] # IAD, PDX +OPEN_WEIGHTS_REGIONS = [ + "us-east-1", + "us-west-2", + "ap-northeast-1", + "eu-west-1", +] # IAD, PDX, NRT, DUB +NOVA_REGIONS = ["us-east-1", "us-west-2"] # IAD, PDX # Constants DEFAULT_REGION = "us-west-2" + def _validate_model_region_availability(model_name: str, region_name: str): """Validate if the model is available in the specified region.""" if "nova" in model_name.lower(): @@ -48,26 +60,24 @@ def _validate_model_region_availability(model_name: str, region_name: str): ) - - def _get_beta_session(): """Create a SageMaker session with beta endpoint for demo purposes.""" - sm_client = boto3.client('sagemaker', region_name=DEFAULT_REGION) + sm_client = boto3.client("sagemaker", region_name=DEFAULT_REGION) return Session(sagemaker_client=sm_client) def _read_domain_id_from_metadata() -> Optional[str]: """Read domain ID from Studio metadata file. - + This is the standard location for domain information in Studio with Spaces. Returns None if not running in Studio or if metadata file doesn't exist. """ try: - metadata_path = '/opt/ml/metadata/resource-metadata.json' + metadata_path = "/opt/ml/metadata/resource-metadata.json" if os.path.exists(metadata_path): - with open(metadata_path, 'r') as f: + with open(metadata_path, "r") as f: metadata = json.load(f) - return metadata.get('DomainId') + return metadata.get("DomainId") except Exception as e: logger.debug(f"Could not read Studio metadata file: {e}") return None @@ -75,78 +85,88 @@ def _read_domain_id_from_metadata() -> Optional[str]: def _get_current_domain_id(sagemaker_session) -> Optional[str]: """Get current SageMaker Studio domain ID. - + Checks multiple sources in order of reliability: 1. Studio metadata file (Studio with Spaces - newer architecture) 2. User profile ARN (Studio Classic with User Profiles - legacy) - + Returns None if not running in a Studio environment with domain. """ # Try metadata file first (Studio with Spaces) domain_id = _read_domain_id_from_metadata() if domain_id: return domain_id - + # Fallback to original logic (Studio Classic with User Profiles) try: user_profile_arn = sagemaker_session.get_caller_identity_arn() - if user_profile_arn and 'user-profile' in user_profile_arn: + if user_profile_arn and "user-profile" in user_profile_arn: # ARN format: arn:aws:sagemaker:region:account:user-profile/domain-id/profile-name - return user_profile_arn.split('/')[1] + return user_profile_arn.split("/")[1] except Exception as e: logger.debug(f"Could not extract domain ID from user profile ARN: {e}") - + return None -def _resolve_mlflow_resource_arn(sagemaker_session, mlflow_resource_arn: Optional[str] = None) -> Optional[str]: +def _resolve_mlflow_resource_arn( + sagemaker_session, mlflow_resource_arn: Optional[str] = None +) -> Optional[str]: """Resolve MLflow resource ARN using default experience logic.""" if mlflow_resource_arn: return mlflow_resource_arn - + try: - + mlflow_apps = MlflowApp.get_all( session=sagemaker_session.boto_session, - region=sagemaker_session.boto_session.region_name + region=sagemaker_session.boto_session.region_name, ) - + mlflow_apps_list = list(mlflow_apps) current_domain_id = _get_current_domain_id(sagemaker_session) - + # Check for domain match if current_domain_id: - domain_match = next((app for app in mlflow_apps_list - if isinstance(app.default_domain_id_list, list) - and current_domain_id in app.default_domain_id_list), None) + domain_match = next( + ( + app + for app in mlflow_apps_list + if isinstance(app.default_domain_id_list, list) + and current_domain_id in app.default_domain_id_list + ), + None, + ) if domain_match: logger.info("Using domain-matched MLflow app: %s", domain_match.arn) return domain_match.arn - + # Check for account default - account_default = next((app for app in mlflow_apps_list - if app.account_default_status == "ENABLED"), None) + account_default = next( + (app for app in mlflow_apps_list if app.account_default_status == "ENABLED"), None + ) if account_default: logger.info("Using account default MLflow app: %s", account_default.arn) return account_default.arn - + # Use first available with ready status if mlflow_apps_list: - ready_app = next((app for app in mlflow_apps_list - if app.status in ["Created", "Updated"]), None) + ready_app = next( + (app for app in mlflow_apps_list if app.status in ["Created", "Updated"]), None + ) if ready_app: logger.info("Using first available ready MLflow app: %s", ready_app.arn) return ready_app.arn - + # Create new app new_app = _create_mlflow_app(sagemaker_session) if new_app: logger.info("Created new MLflow app: %s", new_app.arn) return new_app.arn - + logger.warning("Failed to create MLflow app. MLflow tracking disabled.") return None - + except Exception as e: logger.error("Error resolving MLflow resource ARN: %s", e) return None @@ -156,45 +176,46 @@ def _create_mlflow_app(sagemaker_session) -> Optional[MlflowApp]: """Create a new MLflow app with minimal configuration.""" try: app_name = f"finetune-mlflow-{int(time.time())}" - account_id = sagemaker_session.boto_session.client('sts').get_caller_identity()['Account'] + account_id = sagemaker_session.boto_session.client("sts").get_caller_identity()["Account"] region = sagemaker_session.boto_session.region_name artifact_store_uri = f"s3://sagemaker-{region}-{account_id}/mlflow-artifacts" role_arn = TrainDefaults.get_role(role=None, sagemaker_session=sagemaker_session) - + # Ensure S3 bucket and prefix exist - s3_client = sagemaker_session.boto_session.client('s3') + s3_client = sagemaker_session.boto_session.client("s3") bucket_name = f"sagemaker-{region}-{account_id}" - + try: # Check if prefix exists - response = s3_client.list_objects_v2(Bucket=bucket_name, Prefix="mlflow-artifacts/", MaxKeys=1) - if 'Contents' not in response: + response = s3_client.list_objects_v2( + Bucket=bucket_name, Prefix="mlflow-artifacts/", MaxKeys=1 + ) + if "Contents" not in response: s3_client.put_object(Bucket=bucket_name, Key="mlflow-artifacts/") except s3_client.exceptions.NoSuchBucket: # Bucket doesn't exist, create bucket and prefix - if region == 'us-east-1': + if region == "us-east-1": s3_client.create_bucket(Bucket=bucket_name) else: s3_client.create_bucket( - Bucket=bucket_name, - CreateBucketConfiguration={'LocationConstraint': region} + Bucket=bucket_name, CreateBucketConfiguration={"LocationConstraint": region} ) s3_client.put_object(Bucket=bucket_name, Key="mlflow-artifacts/") - + new_app = MlflowApp.create( name=app_name, artifact_store_uri=artifact_store_uri, role_arn=role_arn, account_default_status="DISABLED", session=sagemaker_session.boto_session, - region=region + region=region, ) - + # Wait for app to reach Created/Updated state max_wait_time = 600 # 10 minutes - poll_interval = 10 # 10 seconds + poll_interval = 10 # 10 seconds start_time = time.time() - + while time.time() - start_time < max_wait_time: new_app.refresh() if new_app.status in ["Created", "Updated"]: @@ -202,18 +223,18 @@ def _create_mlflow_app(sagemaker_session) -> Optional[MlflowApp]: elif new_app.status in ["Failed", "Stopped"]: # Get detailed error from MLflow app error_msg = f"MLflow app creation failed with status: {new_app.status}" - if hasattr(new_app, 'failure_reason') and new_app.failure_reason: + if hasattr(new_app, "failure_reason") and new_app.failure_reason: error_msg += f". Reason: {new_app.failure_reason}" raise RuntimeError(error_msg) time.sleep(poll_interval) - + # Timeout case - get current status and any error details new_app.refresh() error_msg = f"MLflow app creation failed. Current status: {new_app.status}" - if hasattr(new_app, 'failure_reason') and new_app.failure_reason: + if hasattr(new_app, "failure_reason") and new_app.failure_reason: error_msg += f". Reason: {new_app.failure_reason}" raise RuntimeError(error_msg) - + except Exception as e: logger.error("Failed to create MLflow app: %s", e) return None @@ -229,14 +250,18 @@ def _validate_dataset_arn(dataset: str, param_name: str): def _validate_evaluator_arn(evaluator_arn: str, param_name: str): """Validate that evaluator_arn is in correct ARN format.""" arn_pattern = r"^arn:aws:sagemaker:[^:]+:\d+:hub-content/[^/]+/JsonDoc/[^/]+/[\d\.]+$" - if not evaluator_arn.startswith("arn:aws:sagemaker:") or not re.match(arn_pattern, evaluator_arn): + if not evaluator_arn.startswith("arn:aws:sagemaker:") or not re.match( + arn_pattern, evaluator_arn + ): raise ValueError(f"{param_name} must be a valid SageMaker hub-content evaluator ARN") def _validate_model_package_group_requirement(model, model_package_group_name): """Validate model_package_group_name when source_model_package_arn is not available.""" if not isinstance(model, ModelPackage) and not model_package_group_name: - raise ValueError("model_package_group_name must be provided when source_model_package_arn is not available") + raise ValueError( + "model_package_group_name must be provided when source_model_package_arn is not available" + ) def _resolve_model_package_group_arn(model_package_group_name_or_arn, sagemaker_session) -> str: @@ -244,7 +269,7 @@ def _resolve_model_package_group_arn(model_package_group_name_or_arn, sagemaker_ if isinstance(model_package_group_name_or_arn, str): # Check if it's already an ARN using pattern matching arn_pattern = r"^arn:aws:sagemaker:[^:]+:\d+:model-package-group/[^/]+$" - + if re.match(arn_pattern, model_package_group_name_or_arn): # It's already an ARN return model_package_group_name_or_arn @@ -253,7 +278,7 @@ def _resolve_model_package_group_arn(model_package_group_name_or_arn, sagemaker_ model_package_group = ModelPackageGroup.get( model_package_group_name=model_package_group_name_or_arn, session=sagemaker_session.boto_session, - region=sagemaker_session.boto_session.region_name + region=sagemaker_session.boto_session.region_name, ) return model_package_group.model_package_group_arn else: @@ -263,7 +288,7 @@ def _resolve_model_package_group_arn(model_package_group_name_or_arn, sagemaker_ def _get_default_s3_output_path(sagemaker_session) -> str: """Generate default S3 output path: s3://sagemaker--/output""" - account_id = sagemaker_session.boto_session.client('sts').get_caller_identity()['Account'] + account_id = sagemaker_session.boto_session.client("sts").get_caller_identity()["Account"] region = sagemaker_session.boto_session.region_name return f"s3://sagemaker-{region}-{account_id}/output" @@ -295,17 +320,21 @@ def _resolve_model_name(model_package) -> str: if model_package: try: # Extract base model from InferenceSpecification - if (model_package.inference_specification and - model_package.inference_specification.containers): + if ( + model_package.inference_specification + and model_package.inference_specification.containers + ): container = model_package.inference_specification.containers[0] - if hasattr(container, 'base_model') and container.base_model: + if hasattr(container, "base_model") and container.base_model: return container.base_model.hub_content_name - - raise ValueError("Continued fine tuning is only allowed on model packages fine tuned with sagemaker 1p models") + + raise ValueError( + "Continued fine tuning is only allowed on model packages fine tuned with sagemaker 1p models" + ) except Exception as e: logger.error("Failed to resolve model_name from model package: %s", e) raise - + raise ValueError("model name or package must be provided") @@ -320,7 +349,7 @@ def _resolve_model_package_arn(model_package) -> Optional[str]: def _parse_context_length(value) -> int: """Parse a context length value like '8K', '32K', '128K' into an integer (e.g., 8192). - + Returns 0 if value is None or unparseable. """ if not value: @@ -343,10 +372,10 @@ def _get_fine_tuning_options_and_model_arn( training_type, sagemaker_session, sequence_length=None, - hub_name: Optional[str] = None + hub_name: Optional[str] = None, ) -> tuple: """Get fine-tuning options and model ARN for given customization technique. - + Args: model_name: Name of the model in the hub. customization_technique: Technique (e.g., "SFT", "DPO", "RLVR", "RLAIF"). @@ -355,7 +384,7 @@ def _get_fine_tuning_options_and_model_arn( sequence_length: Optional sequence length (e.g., "8K"). When provided, filters recipes by MaxContextLength >= the requested value. hub_name: Hub name (default: "SageMakerPublicHub"). - + Returns: tuple: (FineTuningOptions, model_arn, is_gated_model) """ @@ -366,29 +395,35 @@ def _get_fine_tuning_options_and_model_arn( hub_content = _get_hub_content_metadata( hub_name=hub_name, - hub_content_type="Model", + hub_content_type="Model", hub_content_name=model_name, session=sagemaker_session.boto_session, - region=sagemaker_session.boto_session.region_name + region=sagemaker_session.boto_session.region_name, ) - - model_arn = hub_content.get('hub_content_arn') - document = hub_content.get('hub_content_document') - + + model_arn = hub_content.get("hub_content_arn") + document = hub_content.get("hub_content_document") + # Check if model is gated is_gated_model = document.get("GatedBucket", False) - + recipe_collection = document.get("RecipeCollection", []) - + # Filter recipes by customization technique - matching_recipes = [r for r in recipe_collection if r.get("CustomizationTechnique") == customization_technique] - + matching_recipes = [ + r + for r in recipe_collection + if r.get("CustomizationTechnique") == customization_technique + ] + if not matching_recipes: - raise ValueError(f"No recipes found for customization technique: {customization_technique}") - + raise ValueError( + f"No recipes found for customization technique: {customization_technique}" + ) + # Filter recipes that have SmtjRecipeTemplateS3Uri key recipes_with_template = [r for r in matching_recipes if r.get("SmtjRecipeTemplateS3Uri")] - + if not recipes_with_template: raise ValueError(f"No recipes found with Smtj for technique: {customization_technique}") @@ -397,12 +432,18 @@ def _get_fine_tuning_options_and_model_arn( requested = _parse_context_length(sequence_length) candidates_with_context = [r for r in recipes_with_template if r.get("SequenceLength")] if candidates_with_context: - filtered = [r for r in candidates_with_context if _parse_context_length(r.get("SequenceLength")) >= requested] + filtered = [ + r + for r in candidates_with_context + if _parse_context_length(r.get("SequenceLength")) >= requested + ] if filtered: filtered.sort(key=lambda r: _parse_context_length(r.get("SequenceLength"))) recipes_with_template = filtered else: - available = sorted(set(r.get("SequenceLength") for r in candidates_with_context)) + available = sorted( + set(r.get("SequenceLength") for r in candidates_with_context) + ) raise ValueError( f"No recipes found with SequenceLength >= {sequence_length}. " f"Available sequence lengths: {available}" @@ -416,13 +457,33 @@ def _get_fine_tuning_options_and_model_arn( # Select recipe based on training type # Collect override_params from ALL matching recipes (standard + subscription) recipe = None - if (isinstance(training_type, TrainingType) and training_type == TrainingType.LORA) or training_type == "LORA": - recipe = next((r for r in recipes_with_template if r.get("Peft") and not r.get("IsSubscriptionModel")), None) - elif (isinstance(training_type, TrainingType) and training_type == TrainingType.FULL) or training_type == "FULL": - recipe = next((r for r in recipes_with_template if not r.get("Peft") and not r.get("IsSubscriptionModel")), None) + if ( + isinstance(training_type, TrainingType) and training_type == TrainingType.LORA + ) or training_type == "LORA": + recipe = next( + ( + r + for r in recipes_with_template + if r.get("Peft") and not r.get("IsSubscriptionModel") + ), + None, + ) + elif ( + isinstance(training_type, TrainingType) and training_type == TrainingType.FULL + ) or training_type == "FULL": + recipe = next( + ( + r + for r in recipes_with_template + if not r.get("Peft") and not r.get("IsSubscriptionModel") + ), + None, + ) if not recipe: - raise ValueError(f"No recipes found with Smtj for technique: {customization_technique},training_type:{training_type}") + raise ValueError( + f"No recipes found with Smtj for technique: {customization_technique},training_type:{training_type}" + ) # Start with the standard recipe's override_params options_dict = {} @@ -435,14 +496,33 @@ def _get_fine_tuning_options_and_model_arn( options_dict = json.loads(obj["Body"].read()) # Auto-detect and merge subscription recipe's override_params if available - if (isinstance(training_type, TrainingType) and training_type == TrainingType.LORA) or training_type == "LORA": - sub_recipe = next((r for r in recipes_with_template if r.get("Peft") and r.get("IsSubscriptionModel")), None) + if ( + isinstance(training_type, TrainingType) and training_type == TrainingType.LORA + ) or training_type == "LORA": + sub_recipe = next( + ( + r + for r in recipes_with_template + if r.get("Peft") and r.get("IsSubscriptionModel") + ), + None, + ) else: - sub_recipe = next((r for r in recipes_with_template if not r.get("Peft") and r.get("IsSubscriptionModel")), None) + sub_recipe = next( + ( + r + for r in recipes_with_template + if not r.get("Peft") and r.get("IsSubscriptionModel") + ), + None, + ) if sub_recipe and sub_recipe.get("SmtjOverrideParamsS3Uri"): try: - sub_s3_uri = sub_recipe["SmtjOverrideParamsS3Uri"].replace("{customer_id}", sagemaker_session.boto_session.client("sts").get_caller_identity()["Account"]) + sub_s3_uri = sub_recipe["SmtjOverrideParamsS3Uri"].replace( + "{customer_id}", + sagemaker_session.boto_session.client("sts").get_caller_identity()["Account"], + ) sub_uri_path = sub_s3_uri.replace("s3://", "") # Handle access point ARN URIs if sub_uri_path.startswith("arn:"): @@ -460,73 +540,77 @@ def _get_fine_tuning_options_and_model_arn( if k not in options_dict: v_copy = v.copy() if isinstance(v, dict) else v if isinstance(v_copy, dict): - v_copy['default'] = None # No default — won't appear in to_dict() unless set + v_copy["default"] = ( + None # No default — won't appear in to_dict() unless set + ) options_dict[k] = v_copy except Exception as e: - logger.debug(f"Could not fetch subscription recipe override_params: {type(e).__name__}: {e}") + logger.debug( + f"Could not fetch subscription recipe override_params: {type(e).__name__}: {e}" + ) if options_dict: return FineTuningOptions(options_dict), model_arn, is_gated_model else: return FineTuningOptions({}), model_arn, is_gated_model - + except Exception as e: logger.error("Exception getting fine-tuning options: %s", e) raise -def _create_input_channels(dataset: str, content_type: Optional[str] = None, - input_compression_type: Optional[str] = None, - record_wrapper_type: Optional[str] = None, - input_mode: Optional[str] = None): +def _create_input_channels( + dataset: str, + content_type: Optional[str] = None, + input_compression_type: Optional[str] = None, + record_wrapper_type: Optional[str] = None, + input_mode: Optional[str] = None, +): """Create input channels from dataset (S3 URI or dataset ARN). - + Args: dataset: S3 URI (s3://bucket/key) or dataset ARN (arn:aws:sagemaker:...) - + Returns: list: List of Channel objects """ channels = [] - if dataset.startswith("s3://"): # S3 URI - create S3DataSource data_source = DataSource( s3_data_source={ "s3_uri": dataset, "s3_data_type": "S3Prefix", - "s3_data_distribution_type": "FullyReplicated" + "s3_data_distribution_type": "FullyReplicated", } ) else: # Dataset ARN - validate and create dataset source _validate_dataset_arn(dataset, "dataset") - data_source = DataSource( - dataset_source={"dataset_arn": dataset} - ) - + data_source = DataSource(dataset_source={"dataset_arn": dataset}) + channel = Channel( channel_name="train", data_source=data_source, content_type=content_type, compression_type=input_compression_type, record_wrapper_type=record_wrapper_type, - input_mode=input_mode - ) + input_mode=input_mode, + ) channels.append(channel) - + return channels def _resolve_model_and_name(model, sagemaker_session=None): """Resolve model and extract model name from string, ARN, or ModelPackage object. - + Args: model: Can be a model name (str), model package ARN (str), or ModelPackage object sagemaker_session: SageMaker session for API calls (required for ARN resolution) - + Returns: tuple: (resolved_model, model_name) """ @@ -536,14 +620,15 @@ def _resolve_model_and_name(model, sagemaker_session=None): region_name = sagemaker_session.boto_region_name else: # Try to get region from SAGEMAKER_REGION env var, then boto3 session, then AWS_DEFAULT_REGION - region_name = os.environ.get('SAGEMAKER_REGION') + region_name = os.environ.get("SAGEMAKER_REGION") if not region_name: try: import boto3 - region_name = boto3.Session().region_name or os.environ.get('AWS_DEFAULT_REGION') + + region_name = boto3.Session().region_name or os.environ.get("AWS_DEFAULT_REGION") except: pass - + if isinstance(model, str): # Check if it's a model package ARN if model.startswith("arn:aws:sagemaker:") and ":model-package/" in model: @@ -551,7 +636,7 @@ def _resolve_model_and_name(model, sagemaker_session=None): model_package = ModelPackage.get( model_package_name=model, session=sagemaker_session.boto_session if sagemaker_session else None, - region=sagemaker_session.boto_session.region_name if sagemaker_session else None + region=sagemaker_session.boto_session.region_name if sagemaker_session else None, ) model_name = _resolve_model_name(model_package) # Validate region availability @@ -573,11 +658,17 @@ def _resolve_model_and_name(model, sagemaker_session=None): return model, model_name -def _create_serverless_config(model_arn, customization_technique, - training_type, accept_eula, evaluator_arn=None, - sequence_length=None, job_type=JOB_TYPE) -> Optional['ServerlessJobConfig']: +def _create_serverless_config( + model_arn, + customization_technique, + training_type, + accept_eula, + evaluator_arn=None, + sequence_length=None, + job_type=JOB_TYPE, +) -> Optional["ServerlessJobConfig"]: """Create serverless job configuration for fine-tuning. - + Args: model_arn: ARN of the base model customization_technique: Technique used (e.g., "SFT", "DPO", "RLVR", "RLAIF") @@ -586,12 +677,15 @@ def _create_serverless_config(model_arn, customization_technique, evaluator_arn: Optional evaluator ARN for RLVR/RLAIF sequence_length: Optional sequence length enum value (e.g., "1K", "2K", "4K", "8K", "16K", "32K", "64K", "128K") job_type: Type of job (default: "FineTuning") - + Returns: ServerlessJobConfig object or None if required parameters are missing """ - peft = None if (isinstance(training_type, TrainingType) and training_type == TrainingType.FULL) \ + peft = ( + None + if (isinstance(training_type, TrainingType) and training_type == TrainingType.FULL) else (training_type.value if isinstance(training_type, TrainingType) else training_type) + ) # Create ServerlessJobConfig using shapes serverless_config = ServerlessJobConfig( @@ -609,44 +703,41 @@ def _create_serverless_config(model_arn, customization_technique, def _create_input_data_config(training_dataset, validation_dataset=None): """Create input data configuration from training and validation datasets. - + Args: training_dataset: Training dataset (method parameter takes priority over class attribute) validation_dataset: Validation dataset (method parameter takes priority over class attribute) - + Returns: List of InputData objects for training job configuration """ # Extract and validate training dataset final_training_dataset = _extract_dataset_source(training_dataset, "training_dataset") - - input_data_config = [ - InputData(channel_name="train", data_source=final_training_dataset) - ] - + + input_data_config = [InputData(channel_name="train", data_source=final_training_dataset)] + # Add validation dataset if provided if validation_dataset: final_validation_dataset = _extract_dataset_source(validation_dataset, "validation_dataset") input_data_config.append( InputData(channel_name="validation", data_source=final_validation_dataset) ) - - return input_data_config + return input_data_config def _create_model_package_config(model_package_group_name, model, sagemaker_session): """Create model package configuration with resolved ARNs. - + Args: model_package_group_name: Model package group name to resolve model: Model object (used to resolve source model package ARN if it's a ModelPackage) sagemaker_session: SageMaker session for API calls - + Returns: ModelPackageConfig object or None if no model package group name provided """ - + model_package_group_arn = None if model_package_group_name: model_package_group_arn = _resolve_model_package_group_arn( @@ -663,22 +754,21 @@ def _create_model_package_config(model_package_group_name, model, sagemaker_sess ) - -def _create_mlflow_config(sagemaker_session, mlflow_resource_arn=None, - mlflow_experiment_name=None, mlflow_run_name=None): +def _create_mlflow_config( + sagemaker_session, mlflow_resource_arn=None, mlflow_experiment_name=None, mlflow_run_name=None +): """Create MLflow configuration with resolved resource ARN. - + Args: sagemaker_session: SageMaker session for resolving MLflow ARN mlflow_resource_arn: MLflow resource ARN (if None, uses default experience) mlflow_experiment_name: MLflow experiment name mlflow_run_name: MLflow run name - + Returns: MlflowConfig object or None if no MLflow resource ARN is resolved """ - # Derive mlflow_resource_arn with default experience resolved_mlflow_arn = _resolve_mlflow_resource_arn(sagemaker_session, mlflow_resource_arn) logger.info(f"MLflow resource ARN: {resolved_mlflow_arn}") @@ -691,18 +781,18 @@ def _create_mlflow_config(sagemaker_session, mlflow_resource_arn=None, mlflow_experiment_name=mlflow_experiment_name, mlflow_run_name=mlflow_run_name, ) - + return mlflow_config -def _create_output_config(sagemaker_session,s3_output_path=None, kms_key_id=None): +def _create_output_config(sagemaker_session, s3_output_path=None, kms_key_id=None): """Create output data configuration with default S3 path if needed. - + Args: s3_output_path: S3 output path (if None, generates default path) sagemaker_session: SageMaker session for generating default path kms_key_id: Optional KMS key ID for encryption - + Returns: OutputDataConfig object """ @@ -710,7 +800,7 @@ def _create_output_config(sagemaker_session,s3_output_path=None, kms_key_id=None # Use default S3 output path if none provided if s3_output_path is None: s3_output_path = _get_default_s3_output_path(sagemaker_session) - + # Validate S3 path exists _validate_s3_path_exists(s3_output_path, sagemaker_session) @@ -720,16 +810,16 @@ def _create_output_config(sagemaker_session,s3_output_path=None, kms_key_id=None ) -def _convert_input_data_to_channels(input_data_config ): +def _convert_input_data_to_channels(input_data_config): """Convert InputData objects to Channel objects with S3 and dataset ARN support. - + Args: input_data_config: List of InputData objects - + Returns: List of Channel objects """ - + channels = [] for input_data in input_data_config: if input_data.data_source.startswith("s3://"): @@ -738,21 +828,19 @@ def _convert_input_data_to_channels(input_data_config ): s3_data_source={ "s3_uri": input_data.data_source, "s3_data_type": "S3Prefix", - "s3_data_distribution_type": "FullyReplicated" + "s3_data_distribution_type": "FullyReplicated", } ) else: # Dataset ARN - create dataset source - data_source = DataSource( - dataset_source={"dataset_arn": input_data.data_source} - ) + data_source = DataSource(dataset_source={"dataset_arn": input_data.data_source}) channel = Channel( channel_name=input_data.channel_name, data_source=data_source, ) channels.append(channel) - + return channels @@ -761,41 +849,45 @@ def _validate_and_resolve_model_package_group(model, model_package_group_name): # If model_package_group_name is already provided, return it as-is if model_package_group_name: return model_package_group_name - + # Try to resolve from ModelPackage if available if isinstance(model, ModelPackage): return model.model_package_group_name - + # Only validate if model_package_group_name is None and model is not ModelPackage - raise ValueError("model_package_group_name must be provided when model given is " - "not a ModelPackage artifact/not continued finetuning") + raise ValueError( + "model_package_group_name must be provided when model given is " + "not a ModelPackage artifact/not continued finetuning" + ) def _validate_eula_for_gated_model(model, accept_eula, is_gated_model): """Validate EULA acceptance for gated models. - + Args: model: Original model input (string, ARN, or ModelPackage) accept_eula: Boolean indicating if EULA is accepted is_gated_model: Boolean indicating if the model is gated - + Returns: bool: True if EULA is accepted (either explicitly or by default for ARN/ModelPackage) - + Raises: ValueError: If model is gated but accept_eula is False """ # For ModelPackage/ARN inputs, EULA is assumed accepted by default - if isinstance(model, ModelPackage) or (isinstance(model, str) and model.startswith("arn:aws:sagemaker:")): + if isinstance(model, ModelPackage) or ( + isinstance(model, str) and model.startswith("arn:aws:sagemaker:") + ): return True - + # Validate EULA acceptance for gated models if is_gated_model and not accept_eula: raise ValueError( f"Model '{model}' is a gated model and requires EULA acceptance. " "Please set accept_eula=True to proceed with training." ) - + return accept_eula @@ -803,14 +895,14 @@ def _validate_s3_path_exists(s3_path: str, sagemaker_session): """Validate S3 path and create bucket/prefix if they don't exist.""" if not s3_path.startswith("s3://"): raise ValueError(f"Invalid S3 path format: {s3_path}") - + # Parse S3 URI s3_parts = s3_path.replace("s3://", "").split("/", 1) bucket_name = s3_parts[0] prefix = s3_parts[1] if len(s3_parts) > 1 else "" - - s3_client = sagemaker_session.boto_session.client('s3') - + + s3_client = sagemaker_session.boto_session.client("s3") + try: # Check if bucket exists, create if it doesn't try: @@ -819,25 +911,24 @@ def _validate_s3_path_exists(s3_path: str, sagemaker_session): if "NoSuchBucket" in str(e) or "Not Found" in str(e): # Create bucket region = sagemaker_session.boto_region_name - if region == 'us-east-1': + if region == "us-east-1": s3_client.create_bucket(Bucket=bucket_name) else: s3_client.create_bucket( - Bucket=bucket_name, - CreateBucketConfiguration={'LocationConstraint': region} + Bucket=bucket_name, CreateBucketConfiguration={"LocationConstraint": region} ) else: raise - + # If prefix is provided, check if it exists, create if it doesn't if prefix: response = s3_client.list_objects_v2(Bucket=bucket_name, Prefix=prefix, MaxKeys=1) - if 'Contents' not in response: + if "Contents" not in response: # Create the prefix by putting an empty object - if not prefix.endswith('/'): - prefix += '/' - s3_client.put_object(Bucket=bucket_name, Key=prefix, Body=b'') - + if not prefix.endswith("/"): + prefix += "/" + s3_client.put_object(Bucket=bucket_name, Key=prefix, Body=b"") + except Exception as e: raise ValueError(f"Failed to validate/create S3 path '{s3_path}': {str(e)}") @@ -845,6 +936,7 @@ def _validate_s3_path_exists(s3_path: str, sagemaker_session): def _validate_hyperparameter_values(hyperparameters: dict): """Validate hyperparameter values for allowed characters.""" import re + allowed_chars = r"^[a-zA-Z0-9/_.:,\-\s'\"\[\]]*$" for key, value in hyperparameters.items(): if isinstance(value, str) and not re.match(allowed_chars, value): diff --git a/sagemaker-train/src/sagemaker/train/dpo_trainer.py b/sagemaker-train/src/sagemaker/train/dpo_trainer.py index 8e3bc17d5e..6a6f3f07bd 100644 --- a/sagemaker-train/src/sagemaker/train/dpo_trainer.py +++ b/sagemaker-train/src/sagemaker/train/dpo_trainer.py @@ -19,7 +19,7 @@ _create_mlflow_config, _create_model_package_config, _validate_eula_for_gated_model, - _validate_hyperparameter_values + _validate_hyperparameter_values, ) from sagemaker.core.telemetry.telemetry_logging import _telemetry_emitter from sagemaker.core.telemetry.constants import Feature @@ -53,19 +53,19 @@ class DPOTrainer(BaseTrainer): model="meta-llama/Llama-2-7b-hf", model_package_group="my-dpo-models" ) - + # Create training job (non-blocking) training_job = trainer.train( training_dataset="s3://bucket/preference_data.jsonl", wait=False ) - + # Wait for completion training_job.wait() - + # Refresh job status training_job.refresh() - + # Get the fine-tuned model package ARN model_package_arn = training_job.output_model_package_arn @@ -105,31 +105,33 @@ class DPOTrainer(BaseTrainer): "1K", "2K", "4K", "8K", "16K", "32K", "64K", "128K". If not specified, the service will use default recipe selection behavior. """ + def __init__( - self, - model: Union[str, ModelPackage], - training_type: Union[TrainingType, str] = TrainingType.LORA, - model_package_group: Optional[Union[str, ModelPackageGroup]] = None, - mlflow_resource_arn: Optional[str] = None, - mlflow_experiment_name: Optional[str] = None, - mlflow_run_name: Optional[str] = None, - training_dataset: Optional[Union[str, DataSet]] = None, - validation_dataset: Optional[Union[str, DataSet]] = None, - s3_output_path: Optional[str] = None, - kms_key_id: Optional[str] = None, - networking: Optional[VpcConfig] = None, - accept_eula: bool = False, - stopping_condition: Optional[StoppingCondition] = None, - sequence_length: Optional[str] = None, - **kwargs, + self, + model: Union[str, ModelPackage], + training_type: Union[TrainingType, str] = TrainingType.LORA, + model_package_group: Optional[Union[str, ModelPackageGroup]] = None, + mlflow_resource_arn: Optional[str] = None, + mlflow_experiment_name: Optional[str] = None, + mlflow_run_name: Optional[str] = None, + training_dataset: Optional[Union[str, DataSet]] = None, + validation_dataset: Optional[Union[str, DataSet]] = None, + s3_output_path: Optional[str] = None, + kms_key_id: Optional[str] = None, + networking: Optional[VpcConfig] = None, + accept_eula: bool = False, + stopping_condition: Optional[StoppingCondition] = None, + sequence_length: Optional[str] = None, + **kwargs, ): super().__init__(**kwargs) - + self.model, self._model_name = _resolve_model_and_name(model, self.sagemaker_session) self.training_type = training_type - self.model_package_group = _validate_and_resolve_model_package_group(model, - model_package_group) + self.model_package_group = _validate_and_resolve_model_package_group( + model, model_package_group + ) self.mlflow_resource_arn = mlflow_resource_arn self.mlflow_experiment_name = mlflow_experiment_name self.mlflow_run_name = mlflow_run_name @@ -142,17 +144,20 @@ def __init__( self.sequence_length = sequence_length # Initialize fine-tuning options with beta session fallback - self.hyperparameters, self._model_arn, is_gated_model = _get_fine_tuning_options_and_model_arn( - self._model_name, - CustomizationTechnique.DPO.value, - self.training_type, - self.sagemaker_session or TrainDefaults.get_sagemaker_session(sagemaker_session=self.sagemaker_session), - sequence_length=self.sequence_length + self.hyperparameters, self._model_arn, is_gated_model = ( + _get_fine_tuning_options_and_model_arn( + self._model_name, + CustomizationTechnique.DPO.value, + self.training_type, + self.sagemaker_session + or TrainDefaults.get_sagemaker_session(sagemaker_session=self.sagemaker_session), + sequence_length=self.sequence_length, + ) ) # Process hyperparameters self._process_hyperparameters() - + # Validate and set EULA acceptance self.accept_eula = _validate_eula_for_gated_model(model, accept_eula, is_gated_model) @@ -160,35 +165,37 @@ def _process_hyperparameters(self): """Remove hyperparameter keys that are handled by constructor inputs.""" if self.hyperparameters: # Remove keys that are handled by constructor inputs - if hasattr(self.hyperparameters, 'data_path'): - delattr(self.hyperparameters, 'data_path') - self.hyperparameters._specs.pop('data_path', None) - if hasattr(self.hyperparameters, 'output_path'): - delattr(self.hyperparameters, 'output_path') - self.hyperparameters._specs.pop('output_path', None) - if hasattr(self.hyperparameters, 'data_s3_path'): - delattr(self.hyperparameters, 'data_s3_path') - self.hyperparameters._specs.pop('data_s3_path', None) - if hasattr(self.hyperparameters, 'output_s3_path'): - delattr(self.hyperparameters, 'output_s3_path') - self.hyperparameters._specs.pop('output_s3_path', None) - if hasattr(self.hyperparameters, 'training_data_name'): - delattr(self.hyperparameters, 'training_data_name') - self.hyperparameters._specs.pop('training_data_name', None) - if hasattr(self.hyperparameters, 'validation_data_name'): - delattr(self.hyperparameters, 'validation_data_name') - self.hyperparameters._specs.pop('validation_data_name', None) - if hasattr(self.hyperparameters, 'validation_data_path'): - delattr(self.hyperparameters, 'validation_data_path') - self.hyperparameters._specs.pop('validation_data_path', None) + if hasattr(self.hyperparameters, "data_path"): + delattr(self.hyperparameters, "data_path") + self.hyperparameters._specs.pop("data_path", None) + if hasattr(self.hyperparameters, "output_path"): + delattr(self.hyperparameters, "output_path") + self.hyperparameters._specs.pop("output_path", None) + if hasattr(self.hyperparameters, "data_s3_path"): + delattr(self.hyperparameters, "data_s3_path") + self.hyperparameters._specs.pop("data_s3_path", None) + if hasattr(self.hyperparameters, "output_s3_path"): + delattr(self.hyperparameters, "output_s3_path") + self.hyperparameters._specs.pop("output_s3_path", None) + if hasattr(self.hyperparameters, "training_data_name"): + delattr(self.hyperparameters, "training_data_name") + self.hyperparameters._specs.pop("training_data_name", None) + if hasattr(self.hyperparameters, "validation_data_name"): + delattr(self.hyperparameters, "validation_data_name") + self.hyperparameters._specs.pop("validation_data_name", None) + if hasattr(self.hyperparameters, "validation_data_path"): + delattr(self.hyperparameters, "validation_data_path") + self.hyperparameters._specs.pop("validation_data_path", None) @_telemetry_emitter(feature=Feature.MODEL_CUSTOMIZATION, func_name="DPOTrainer.train") - def train(self, - training_dataset: Optional[Union[str, DataSet]] = None, - validation_dataset: Optional[Union[str, DataSet]] = None, - wait: bool = True, - wait_timeout: Optional[int] = None, - poll: int = 5): + def train( + self, + training_dataset: Optional[Union[str, DataSet]] = None, + validation_dataset: Optional[Union[str, DataSet]] = None, + wait: bool = True, + wait_timeout: Optional[int] = None, + poll: int = 5, + ): """Execute the DPO training job. Parameters: @@ -221,16 +228,16 @@ def train(self, logger.info(f"Training Job Name: {current_training_job_name}") print(f"Training Job Name: {current_training_job_name}") - #data - input_data_config = _create_input_data_config(training_dataset or self.training_dataset, - validation_dataset or self.validation_dataset - ) + # data + input_data_config = _create_input_data_config( + training_dataset or self.training_dataset, validation_dataset or self.validation_dataset + ) channels = _convert_input_data_to_channels(input_data_config) output_config = _create_output_config( s3_output_path=self.s3_output_path, sagemaker_session=sagemaker_session, - kms_key_id=self.kms_key_id + kms_key_id=self.kms_key_id, ) serverless_config = _create_serverless_config( @@ -239,7 +246,7 @@ def train(self, training_type=self.training_type, accept_eula=self.accept_eula, sequence_length=self.sequence_length, - job_type=JOB_TYPE + job_type=JOB_TYPE, ) mlflow_config = _create_mlflow_config( @@ -255,7 +262,7 @@ def train(self, model_package_config = _create_model_package_config( model_package_group_name=self.model_package_group, model=self.model, - sagemaker_session=sagemaker_session + sagemaker_session=sagemaker_session, ) vpc_config = self.networking if self.networking else None @@ -276,7 +283,7 @@ def train(self, "region": sagemaker_session.boto_session.region_name, "tags": tags, } - + # Only pass stopping_condition if explicitly provided by user if self.stopping_condition is not None: create_args["stopping_condition"] = self.stopping_condition @@ -290,15 +297,15 @@ def train(self, if wait: from sagemaker.train.common_utils.trainer_wait import wait as _wait from sagemaker.core.utils.exceptions import TimeoutExceededError - try : + + try: wait_kwargs = {} if wait_timeout is not None: - wait_kwargs['timeout'] = wait_timeout - wait_kwargs['poll'] = poll + wait_kwargs["timeout"] = wait_timeout + wait_kwargs["poll"] = poll _wait(training_job, **wait_kwargs) except TimeoutExceededError as e: logger.error("Error: %s", e) self.latest_training_job = training_job return training_job - diff --git a/sagemaker-train/src/sagemaker/train/rlaif_trainer.py b/sagemaker-train/src/sagemaker/train/rlaif_trainer.py index 5d782d8fa3..09084359f1 100644 --- a/sagemaker-train/src/sagemaker/train/rlaif_trainer.py +++ b/sagemaker-train/src/sagemaker/train/rlaif_trainer.py @@ -2,7 +2,12 @@ import logging from sagemaker.train.base_trainer import BaseTrainer from sagemaker.train.common import TrainingType, CustomizationTechnique, JOB_TYPE -from sagemaker.core.resources import TrainingJob, ModelPackageGroup, MlflowTrackingServer, ModelPackage +from sagemaker.core.resources import ( + TrainingJob, + ModelPackageGroup, + MlflowTrackingServer, + ModelPackage, +) from sagemaker.core.shapes import VpcConfig from sagemaker.train.defaults import TrainDefaults from sagemaker.train.utils import _get_unique_name, _get_studio_tags @@ -23,7 +28,7 @@ _create_mlflow_config, _create_model_package_config, _validate_eula_for_gated_model, - _validate_hyperparameter_values + _validate_hyperparameter_values, ) from sagemaker.core.telemetry.telemetry_logging import _telemetry_emitter from sagemaker.core.telemetry.constants import Feature @@ -60,19 +65,19 @@ class RLAIFTrainer(BaseTrainer): reward_model_id="reward-model-id", reward_prompt="Rate the helpfulness of this response on a scale of 1-10" ) - + # Create training job (non-blocking) training_job = trainer.train( training_dataset="s3://bucket/rlaif_data.jsonl", wait=False ) - + # Wait for completion training_job.wait() - + # Refresh job status training_job.refresh() - + # Get the fine-tuned model package ARN model_package_arn = training_job.output_model_package_arn @@ -148,8 +153,9 @@ def __init__( self.model, self._model_name = _resolve_model_and_name(model, self.sagemaker_session) self.training_type = training_type - self.model_package_group = _validate_and_resolve_model_package_group(model, - model_package_group) + self.model_package_group = _validate_and_resolve_model_package_group( + model, model_package_group + ) self.reward_model_id = self._validate_reward_model_id(reward_model_id) self.reward_prompt = reward_prompt self.mlflow_resource_arn = mlflow_resource_arn @@ -164,17 +170,20 @@ def __init__( self.sequence_length = sequence_length # Initialize fine-tuning options with beta session fallback - self.hyperparameters, self._model_arn, is_gated_model = _get_fine_tuning_options_and_model_arn( - self._model_name, - CustomizationTechnique.RLAIF.value, - self.training_type, - self.sagemaker_session or TrainDefaults.get_sagemaker_session(sagemaker_session=self.sagemaker_session), - sequence_length=self.sequence_length + self.hyperparameters, self._model_arn, is_gated_model = ( + _get_fine_tuning_options_and_model_arn( + self._model_name, + CustomizationTechnique.RLAIF.value, + self.training_type, + self.sagemaker_session + or TrainDefaults.get_sagemaker_session(sagemaker_session=self.sagemaker_session), + sequence_length=self.sequence_length, + ) ) - + # Validate and set EULA acceptance self.accept_eula = _validate_eula_for_gated_model(model, accept_eula, is_gated_model) - + # Process reward_prompt parameter self._process_hyperparameters() @@ -188,23 +197,33 @@ def _validate_reward_model_id(self, reward_model_id): f"Invalid reward_model_id '{reward_model_id}'. " f"Available models are: {list(_ALLOWED_REWARD_MODEL_IDS.keys())}" ) - + # Check region compatibility - session = self.sagemaker_session if hasattr(self, 'sagemaker_session') and self.sagemaker_session else TrainDefaults.get_sagemaker_session() + session = ( + self.sagemaker_session + if hasattr(self, "sagemaker_session") and self.sagemaker_session + else TrainDefaults.get_sagemaker_session() + ) current_region = session.boto_region_name allowed_regions = _ALLOWED_REWARD_MODEL_IDS[reward_model_id] - + if current_region not in allowed_regions: raise ValueError( f"Reward model '{reward_model_id}' is not available in region '{current_region}'. " f"Available regions for this model: {allowed_regions}" ) - + return reward_model_id - @_telemetry_emitter(feature=Feature.MODEL_CUSTOMIZATION, func_name="RLAIFTrainer.train") - def train(self, training_dataset: Optional[Union[str, DataSet]] = None, validation_dataset: Optional[Union[str, DataSet]] = None, wait: bool = True, wait_timeout: Optional[int] = None, poll: int = 5): + def train( + self, + training_dataset: Optional[Union[str, DataSet]] = None, + validation_dataset: Optional[Union[str, DataSet]] = None, + wait: bool = True, + wait_timeout: Optional[int] = None, + poll: int = 5, + ): """Execute the RLAIF training job. Parameters: @@ -236,19 +255,19 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None, validati logger.info(f"Training Job Name: {current_training_job_name}") - #data - input_data_config = _create_input_data_config(training_dataset or self.training_dataset, - validation_dataset or self.validation_dataset - ) + # data + input_data_config = _create_input_data_config( + training_dataset or self.training_dataset, validation_dataset or self.validation_dataset + ) channels = _convert_input_data_to_channels(input_data_config) output_config = _create_output_config( s3_output_path=self.s3_output_path, sagemaker_session=sagemaker_session, - kms_key_id=self.kms_key_id + kms_key_id=self.kms_key_id, ) - evaluator_arn = getattr(self, '_evaluator_arn', None) + evaluator_arn = getattr(self, "_evaluator_arn", None) serverless_config = _create_serverless_config( model_arn=self._model_arn, customization_technique=CustomizationTechnique.RLAIF.value, @@ -256,7 +275,7 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None, validati accept_eula=self.accept_eula, evaluator_arn=evaluator_arn, sequence_length=self.sequence_length, - job_type=JOB_TYPE + job_type=JOB_TYPE, ) mlflow_config = _create_mlflow_config( @@ -273,7 +292,7 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None, validati model_package_config = _create_model_package_config( model_package_group_name=self.model_package_group, model=self.model, - sagemaker_session=sagemaker_session + sagemaker_session=sagemaker_session, ) vpc_config = self.networking if self.networking else None @@ -294,7 +313,7 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None, validati "region": sagemaker_session.boto_session.region_name, "tags": tags, } - + # Only pass stopping_condition if explicitly provided by user if self.stopping_condition is not None: create_args["stopping_condition"] = self.stopping_condition @@ -308,11 +327,12 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None, validati if wait: from sagemaker.train.common_utils.trainer_wait import wait as _wait from sagemaker.core.utils.exceptions import TimeoutExceededError - try : + + try: wait_kwargs = {} if wait_timeout is not None: - wait_kwargs['timeout'] = wait_timeout - wait_kwargs['poll'] = poll + wait_kwargs["timeout"] = wait_timeout + wait_kwargs["poll"] = poll _wait(training_job, **wait_kwargs) except TimeoutExceededError as e: logger.error("Error: %s", e) @@ -322,27 +342,31 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None, validati def _process_hyperparameters(self): """Update hyperparameters based on constructor inputs and process reward_prompt.""" - if not self.hyperparameters or not hasattr(self.hyperparameters, '_specs') or not self.hyperparameters._specs: + if ( + not self.hyperparameters + or not hasattr(self.hyperparameters, "_specs") + or not self.hyperparameters._specs + ): return - + # Remove keys that are handled by constructor inputs - if hasattr(self.hyperparameters, 'output_path'): - delattr(self.hyperparameters, 'output_path') - self.hyperparameters._specs.pop('output_path', None) - if hasattr(self.hyperparameters, 'data_path'): - delattr(self.hyperparameters, 'data_path') - self.hyperparameters._specs.pop('data_path', None) - if hasattr(self.hyperparameters, 'validation_data_path'): - delattr(self.hyperparameters, 'validation_data_path') - self.hyperparameters._specs.pop('validation_data_path', None) - + if hasattr(self.hyperparameters, "output_path"): + delattr(self.hyperparameters, "output_path") + self.hyperparameters._specs.pop("output_path", None) + if hasattr(self.hyperparameters, "data_path"): + delattr(self.hyperparameters, "data_path") + self.hyperparameters._specs.pop("data_path", None) + if hasattr(self.hyperparameters, "validation_data_path"): + delattr(self.hyperparameters, "validation_data_path") + self.hyperparameters._specs.pop("validation_data_path", None) + # Update judge_model_id if reward_model_id is provided - if hasattr(self, 'reward_model_id') and self.reward_model_id: + if hasattr(self, "reward_model_id") and self.reward_model_id: judge_model_value = f"bedrock/{self.reward_model_id}" self.hyperparameters.judge_model_id = judge_model_value - + # Process reward_prompt parameter - if hasattr(self, 'reward_prompt') and self.reward_prompt: + if hasattr(self, "reward_prompt") and self.reward_prompt: if isinstance(self.reward_prompt, str): if self.reward_prompt.startswith("Builtin"): # Handle builtin reward prompts @@ -352,9 +376,9 @@ def _process_hyperparameters(self): self._process_non_builtin_reward_prompt() else: # Handle evaluator object - if hasattr(self.hyperparameters, 'judge_prompt_template'): - delattr(self.hyperparameters, 'judge_prompt_template') - self.hyperparameters._specs.pop('judge_prompt_template', None) + if hasattr(self.hyperparameters, "judge_prompt_template"): + delattr(self.hyperparameters, "judge_prompt_template") + self.hyperparameters._specs.pop("judge_prompt_template", None) evaluator_arn = _extract_evaluator_arn(self.reward_prompt, "reward_prompt") self._evaluator_arn = evaluator_arn @@ -362,10 +386,10 @@ def _process_hyperparameters(self): def _process_non_builtin_reward_prompt(self): """Process non-builtin reward prompt (ARN or hub content name).""" # Remove judge_prompt_template for non-builtin prompts - if hasattr(self.hyperparameters, 'judge_prompt_template'): - delattr(self.hyperparameters, 'judge_prompt_template') - self.hyperparameters._specs.pop('judge_prompt_template', None) - + if hasattr(self.hyperparameters, "judge_prompt_template"): + delattr(self.hyperparameters, "judge_prompt_template") + self.hyperparameters._specs.pop("judge_prompt_template", None) + if self.reward_prompt.startswith("arn:aws:sagemaker:"): # Validate and assign ARN evaluator_arn = _extract_evaluator_arn(self.reward_prompt, "reward_prompt") @@ -373,39 +397,39 @@ def _process_non_builtin_reward_prompt(self): else: try: session = TrainDefaults.get_sagemaker_session( - sagemaker_session=self.sagemaker_session - ) + sagemaker_session=self.sagemaker_session + ) hub_content = _get_hub_content_metadata( hub_name=get_sagemaker_hub_name(), hub_content_type="JsonDoc", hub_content_name=self.reward_prompt, session=session.boto_session, - region=session.boto_session.region_name + region=session.boto_session.region_name, ) # Store ARN for evaluator_arn self._evaluator_arn = hub_content.hub_content_arn except Exception as e: - raise ValueError(f"Custom prompt '{self.reward_prompt}' not found in HubContent: {e}") - - + raise ValueError( + f"Custom prompt '{self.reward_prompt}' not found in HubContent: {e}" + ) def _update_judge_prompt_template_direct(self, reward_prompt): """Update judge_prompt_template based on Builtin reward function.""" # Get available templates from hyperparameters specs - judge_prompt_spec = self.hyperparameters._specs.get('judge_prompt_template', {}) - available_templates = judge_prompt_spec.get('enum', []) - + judge_prompt_spec = self.hyperparameters._specs.get("judge_prompt_template", {}) + available_templates = judge_prompt_spec.get("enum", []) + if not available_templates: # If no enum found, use the current value as the only available option - current_value = getattr(self.hyperparameters, 'judge_prompt_template', None) + current_value = getattr(self.hyperparameters, "judge_prompt_template", None) if current_value: available_templates = [current_value] else: return - + # Extract template name after "Builtin." and convert to lowercase template_name = reward_prompt.split(".", 1)[1].lower() - + # Find matching template by extracting filename without extension matching_template = None for template in available_templates: @@ -413,14 +437,15 @@ def _update_judge_prompt_template_direct(self, reward_prompt): if template_filename == template_name: matching_template = template break - + if matching_template: self.hyperparameters.judge_prompt_template = matching_template else: - available_options = [f"Builtin.{t.split('/')[-1].replace('.jinja', '')}" for t in available_templates] + available_options = [ + f"Builtin.{t.split('/')[-1].replace('.jinja', '')}" for t in available_templates + ] raise ValueError( f"Selected reward function option '{reward_prompt}' is not available. " f"Choose one from the available options: {available_options}. " f"Example: reward_prompt='Builtin.summarize'" ) - diff --git a/sagemaker-train/src/sagemaker/train/rlvr_trainer.py b/sagemaker-train/src/sagemaker/train/rlvr_trainer.py index 53029155f2..3abcbbf47e 100644 --- a/sagemaker-train/src/sagemaker/train/rlvr_trainer.py +++ b/sagemaker-train/src/sagemaker/train/rlvr_trainer.py @@ -2,7 +2,12 @@ import logging from sagemaker.train.base_trainer import BaseTrainer from sagemaker.train.common import TrainingType, CustomizationTechnique, JOB_TYPE -from sagemaker.core.resources import TrainingJob, ModelPackageGroup, MlflowTrackingServer, ModelPackage +from sagemaker.core.resources import ( + TrainingJob, + ModelPackageGroup, + MlflowTrackingServer, + ModelPackage, +) from sagemaker.core.shapes import VpcConfig from sagemaker.train.defaults import TrainDefaults from sagemaker.train.utils import _get_unique_name, _get_studio_tags @@ -21,7 +26,7 @@ _create_mlflow_config, _create_model_package_config, _validate_eula_for_gated_model, - _validate_hyperparameter_values + _validate_hyperparameter_values, ) from sagemaker.core.telemetry.telemetry_logging import _telemetry_emitter from sagemaker.core.telemetry.constants import Feature @@ -56,19 +61,19 @@ class RLVRTrainer(BaseTrainer): model_package_group="my-rlvr-models", custom_reward_function="arn:aws:sagemaker:us-east-1:123456789012:hub-content/SageMakerPublicHub/JsonDoc/my-evaluator/1.0" ) - + # Create training job (non-blocking) training_job = trainer.train( training_dataset="s3://bucket/rlvr_data.jsonl", wait=False ) - + # Wait for completion training_job.wait() - + # Refresh job status training_job.refresh() - + # Get the fine-tuned model package ARN model_package_arn = training_job.output_model_package_arn @@ -139,8 +144,9 @@ def __init__( self.model, self._model_name = _resolve_model_and_name(model, self.sagemaker_session) self.training_type = training_type - self.model_package_group = _validate_and_resolve_model_package_group(model, - model_package_group) + self.model_package_group = _validate_and_resolve_model_package_group( + model, model_package_group + ) self.custom_reward_function = custom_reward_function self.mlflow_resource_arn = mlflow_resource_arn self.mlflow_experiment_name = mlflow_experiment_name @@ -154,17 +160,20 @@ def __init__( self.sequence_length = sequence_length # Initialize fine-tuning options with beta session fallback - self.hyperparameters, self._model_arn, is_gated_model = _get_fine_tuning_options_and_model_arn( - self._model_name, - CustomizationTechnique.RLVR.value, - self.training_type, - self.sagemaker_session or TrainDefaults.get_sagemaker_session(sagemaker_session=self.sagemaker_session), - sequence_length=self.sequence_length + self.hyperparameters, self._model_arn, is_gated_model = ( + _get_fine_tuning_options_and_model_arn( + self._model_name, + CustomizationTechnique.RLVR.value, + self.training_type, + self.sagemaker_session + or TrainDefaults.get_sagemaker_session(sagemaker_session=self.sagemaker_session), + sequence_length=self.sequence_length, + ) ) # Remove constructor-handled hyperparameters self._process_hyperparameters() - + # Validate and set EULA acceptance self.accept_eula = _validate_eula_for_gated_model(model, accept_eula, is_gated_model) @@ -172,28 +181,34 @@ def _process_hyperparameters(self): """Remove hyperparameter keys that are handled by constructor inputs.""" if self.hyperparameters: # Remove keys that are handled by constructor inputs - if hasattr(self.hyperparameters, 'data_s3_path'): - delattr(self.hyperparameters, 'data_s3_path') - self.hyperparameters._specs.pop('data_s3_path', None) - if hasattr(self.hyperparameters, 'reward_lambda_arn'): - delattr(self.hyperparameters, 'reward_lambda_arn') - self.hyperparameters._specs.pop('reward_lambda_arn', None) - if hasattr(self.hyperparameters, 'preset_reward_function'): - delattr(self.hyperparameters, 'preset_reward_function') - self.hyperparameters._specs.pop('preset_reward_function', None) - if hasattr(self.hyperparameters, 'data_path'): - delattr(self.hyperparameters, 'data_path') - self.hyperparameters._specs.pop('data_path', None) - if hasattr(self.hyperparameters, 'validation_data_path'): - delattr(self.hyperparameters, 'validation_data_path') - self.hyperparameters._specs.pop('validation_data_path', None) - if hasattr(self.hyperparameters, 'output_path'): - delattr(self.hyperparameters, 'output_path') - self.hyperparameters._specs.pop('output_path', None) + if hasattr(self.hyperparameters, "data_s3_path"): + delattr(self.hyperparameters, "data_s3_path") + self.hyperparameters._specs.pop("data_s3_path", None) + if hasattr(self.hyperparameters, "reward_lambda_arn"): + delattr(self.hyperparameters, "reward_lambda_arn") + self.hyperparameters._specs.pop("reward_lambda_arn", None) + if hasattr(self.hyperparameters, "preset_reward_function"): + delattr(self.hyperparameters, "preset_reward_function") + self.hyperparameters._specs.pop("preset_reward_function", None) + if hasattr(self.hyperparameters, "data_path"): + delattr(self.hyperparameters, "data_path") + self.hyperparameters._specs.pop("data_path", None) + if hasattr(self.hyperparameters, "validation_data_path"): + delattr(self.hyperparameters, "validation_data_path") + self.hyperparameters._specs.pop("validation_data_path", None) + if hasattr(self.hyperparameters, "output_path"): + delattr(self.hyperparameters, "output_path") + self.hyperparameters._specs.pop("output_path", None) @_telemetry_emitter(feature=Feature.MODEL_CUSTOMIZATION, func_name="RLVRTrainer.train") - def train(self, training_dataset: Optional[Union[str, DataSet]] = None, - validation_dataset: Optional[Union[str, DataSet]] = None, wait: bool = True, wait_timeout: Optional[int] = None, poll: int = 5): + def train( + self, + training_dataset: Optional[Union[str, DataSet]] = None, + validation_dataset: Optional[Union[str, DataSet]] = None, + wait: bool = True, + wait_timeout: Optional[int] = None, + poll: int = 5, + ): """Execute the RLVR training job. Parameters: @@ -226,20 +241,24 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None, logger.info(f"Training Job Name: {current_training_job_name}") - #data - input_data_config = _create_input_data_config(training_dataset or self.training_dataset, - validation_dataset or self.validation_dataset - ) + # data + input_data_config = _create_input_data_config( + training_dataset or self.training_dataset, validation_dataset or self.validation_dataset + ) channels = _convert_input_data_to_channels(input_data_config) output_config = _create_output_config( s3_output_path=self.s3_output_path, sagemaker_session=sagemaker_session, - kms_key_id=self.kms_key_id + kms_key_id=self.kms_key_id, ) # Extract and validate evaluator ARN - evaluator_arn = _extract_evaluator_arn(self.custom_reward_function) if self.custom_reward_function else None + evaluator_arn = ( + _extract_evaluator_arn(self.custom_reward_function) + if self.custom_reward_function + else None + ) serverless_config = _create_serverless_config( model_arn=self._model_arn, customization_technique=CustomizationTechnique.RLVR.value, @@ -247,7 +266,7 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None, accept_eula=self.accept_eula, evaluator_arn=evaluator_arn, sequence_length=self.sequence_length, - job_type=JOB_TYPE + job_type=JOB_TYPE, ) mlflow_config = _create_mlflow_config( sagemaker_session, @@ -257,14 +276,14 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None, ) final_hyperparameters = self.hyperparameters.to_dict() - + # Validate hyperparameter values _validate_hyperparameter_values(final_hyperparameters) model_package_config = _create_model_package_config( model_package_group_name=self.model_package_group, model=self.model, - sagemaker_session=sagemaker_session + sagemaker_session=sagemaker_session, ) vpc_config = self.networking if self.networking else None @@ -285,7 +304,7 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None, "region": sagemaker_session.boto_session.region_name, "tags": tags, } - + # Only pass stopping_condition if explicitly provided by user if self.stopping_condition is not None: create_args["stopping_condition"] = self.stopping_condition @@ -299,11 +318,12 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None, if wait: from sagemaker.train.common_utils.trainer_wait import wait as _wait from sagemaker.core.utils.exceptions import TimeoutExceededError + try: wait_kwargs = {} if wait_timeout is not None: - wait_kwargs['timeout'] = wait_timeout - wait_kwargs['poll'] = poll + wait_kwargs["timeout"] = wait_timeout + wait_kwargs["poll"] = poll _wait(training_job, **wait_kwargs) except TimeoutExceededError as e: logger.error("Error: %s", e) diff --git a/sagemaker-train/src/sagemaker/train/sft_trainer.py b/sagemaker-train/src/sagemaker/train/sft_trainer.py index e2193f0b9b..9d47e9742a 100644 --- a/sagemaker-train/src/sagemaker/train/sft_trainer.py +++ b/sagemaker-train/src/sagemaker/train/sft_trainer.py @@ -20,7 +20,7 @@ _create_mlflow_config, _create_model_package_config, _validate_eula_for_gated_model, - _validate_hyperparameter_values + _validate_hyperparameter_values, ) from sagemaker.core.telemetry.telemetry_logging import _telemetry_emitter from sagemaker.core.telemetry.constants import Feature @@ -55,19 +55,19 @@ class SFTTrainer(BaseTrainer): model="meta-llama/Llama-2-7b-hf", model_package_group="my-fine-tuned-models" ) - + # Create training job (non-blocking) training_job = trainer.train( training_dataset="s3://bucket/train.jsonl", wait=False ) - + # Wait for completion training_job.wait() - + # Refresh job status training_job.refresh() - + # Get the fine-tuned model artifacts ARN model_package_arn = training_job.output_model_package_arn @@ -127,13 +127,14 @@ def __init__( **kwargs, ): super().__init__(**kwargs) - + # Resolve model and model name self.model, self._model_name = _resolve_model_and_name(model, self.sagemaker_session) self.training_type = training_type - self.model_package_group = _validate_and_resolve_model_package_group(model, - model_package_group) + self.model_package_group = _validate_and_resolve_model_package_group( + model, model_package_group + ) self.mlflow_resource_arn = mlflow_resource_arn self.mlflow_experiment_name = mlflow_experiment_name self.mlflow_run_name = mlflow_run_name @@ -146,17 +147,20 @@ def __init__( self.sequence_length = sequence_length # Initialize fine-tuning options with beta session fallback - self.hyperparameters, self._model_arn, is_gated_model = _get_fine_tuning_options_and_model_arn( - self._model_name, - CustomizationTechnique.SFT.value, - self.training_type, - self.sagemaker_session or TrainDefaults.get_sagemaker_session(sagemaker_session=self.sagemaker_session), - sequence_length=self.sequence_length + self.hyperparameters, self._model_arn, is_gated_model = ( + _get_fine_tuning_options_and_model_arn( + self._model_name, + CustomizationTechnique.SFT.value, + self.training_type, + self.sagemaker_session + or TrainDefaults.get_sagemaker_session(sagemaker_session=self.sagemaker_session), + sequence_length=self.sequence_length, + ) ) # Process hyperparameters self._process_hyperparameters() - + # Validate and set EULA acceptance self.accept_eula = _validate_eula_for_gated_model(model, accept_eula, is_gated_model) @@ -164,30 +168,37 @@ def _process_hyperparameters(self): """Remove hyperparameter keys that are handled by constructor inputs.""" if self.hyperparameters: # Remove keys that are handled by constructor inputs - if hasattr(self.hyperparameters, 'data_path'): - delattr(self.hyperparameters, 'data_path') - self.hyperparameters._specs.pop('data_path', None) - if hasattr(self.hyperparameters, 'output_path'): - delattr(self.hyperparameters, 'output_path') - self.hyperparameters._specs.pop('output_path', None) - if hasattr(self.hyperparameters, 'data_s3_path'): - delattr(self.hyperparameters, 'data_s3_path') - self.hyperparameters._specs.pop('data_s3_path', None) - if hasattr(self.hyperparameters, 'output_s3_path'): - delattr(self.hyperparameters, 'output_s3_path') - self.hyperparameters._specs.pop('output_s3_path', None) - if hasattr(self.hyperparameters, 'training_data_name'): - delattr(self.hyperparameters, 'training_data_name') - self.hyperparameters._specs.pop('training_data_name', None) - if hasattr(self.hyperparameters, 'validation_data_name'): - delattr(self.hyperparameters, 'validation_data_name') - self.hyperparameters._specs.pop('validation_data_name', None) - if hasattr(self.hyperparameters, 'validation_data_path'): - delattr(self.hyperparameters, 'validation_data_path') - self.hyperparameters._specs.pop('validation_data_path', None) + if hasattr(self.hyperparameters, "data_path"): + delattr(self.hyperparameters, "data_path") + self.hyperparameters._specs.pop("data_path", None) + if hasattr(self.hyperparameters, "output_path"): + delattr(self.hyperparameters, "output_path") + self.hyperparameters._specs.pop("output_path", None) + if hasattr(self.hyperparameters, "data_s3_path"): + delattr(self.hyperparameters, "data_s3_path") + self.hyperparameters._specs.pop("data_s3_path", None) + if hasattr(self.hyperparameters, "output_s3_path"): + delattr(self.hyperparameters, "output_s3_path") + self.hyperparameters._specs.pop("output_s3_path", None) + if hasattr(self.hyperparameters, "training_data_name"): + delattr(self.hyperparameters, "training_data_name") + self.hyperparameters._specs.pop("training_data_name", None) + if hasattr(self.hyperparameters, "validation_data_name"): + delattr(self.hyperparameters, "validation_data_name") + self.hyperparameters._specs.pop("validation_data_name", None) + if hasattr(self.hyperparameters, "validation_data_path"): + delattr(self.hyperparameters, "validation_data_path") + self.hyperparameters._specs.pop("validation_data_path", None) @_telemetry_emitter(feature=Feature.MODEL_CUSTOMIZATION, func_name="SFTTrainer.train") - def train(self, training_dataset: Optional[Union[str, DataSet]] = None, validation_dataset: Optional[Union[str, DataSet]] = None, wait: bool = True, wait_timeout: Optional[int] = None, poll: int = 5): + def train( + self, + training_dataset: Optional[Union[str, DataSet]] = None, + validation_dataset: Optional[Union[str, DataSet]] = None, + wait: bool = True, + wait_timeout: Optional[int] = None, + poll: int = 5, + ): """Execute the SFT training job. Parameters: @@ -220,16 +231,16 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None, validati logger.info(f"Training Job Name: {current_training_job_name}") - #data - input_data_config = _create_input_data_config(training_dataset or self.training_dataset, - validation_dataset or self.validation_dataset - ) + # data + input_data_config = _create_input_data_config( + training_dataset or self.training_dataset, validation_dataset or self.validation_dataset + ) channels = _convert_input_data_to_channels(input_data_config) output_config = _create_output_config( s3_output_path=self.s3_output_path, sagemaker_session=sagemaker_session, - kms_key_id=self.kms_key_id + kms_key_id=self.kms_key_id, ) serverless_config = _create_serverless_config( @@ -238,7 +249,7 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None, validati training_type=self.training_type, accept_eula=self.accept_eula, sequence_length=self.sequence_length, - job_type=JOB_TYPE + job_type=JOB_TYPE, ) mlflow_config = _create_mlflow_config( sagemaker_session, @@ -248,14 +259,14 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None, validati ) final_hyperparameters = self.hyperparameters.to_dict() - + # Validate hyperparameter values _validate_hyperparameter_values(final_hyperparameters) model_package_config = _create_model_package_config( model_package_group_name=self.model_package_group, model=self.model, - sagemaker_session=sagemaker_session + sagemaker_session=sagemaker_session, ) vpc_config = self.networking if self.networking else None @@ -276,7 +287,7 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None, validati "region": sagemaker_session.boto_session.region_name, "tags": tags, } - + # Only pass stopping_condition if explicitly provided by user if self.stopping_condition is not None: create_args["stopping_condition"] = self.stopping_condition @@ -290,16 +301,15 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None, validati if wait: from sagemaker.train.common_utils.trainer_wait import wait as _wait from sagemaker.core.utils.exceptions import TimeoutExceededError - try : + + try: wait_kwargs = {} if wait_timeout is not None: - wait_kwargs['timeout'] = wait_timeout - wait_kwargs['poll'] = poll + wait_kwargs["timeout"] = wait_timeout + wait_kwargs["poll"] = poll _wait(training_job, **wait_kwargs) except TimeoutExceededError as e: logger.error("Error: %s", e) self._latest_training_job = training_job return training_job - - diff --git a/sagemaker-train/tests/integ/train/test_sft_trainer_integration.py b/sagemaker-train/tests/integ/train/test_sft_trainer_integration.py index eb7ebda1e3..b09b608ae3 100644 --- a/sagemaker-train/tests/integ/train/test_sft_trainer_integration.py +++ b/sagemaker-train/tests/integ/train/test_sft_trainer_integration.py @@ -21,11 +21,12 @@ from sagemaker.train.sft_trainer import SFTTrainer from sagemaker.train.common import TrainingType + @pytest.mark.gpu_intensive def test_sft_trainer_lora_complete_workflow(sagemaker_session): """Test complete SFT training workflow with LORA.""" unique_id = f"{int(time.time())}-{random.randint(1000, 9999)}" - + sft_trainer = SFTTrainer( model="meta-textgeneration-llama-3-2-1b-instruct", training_type=TrainingType.LORA, @@ -35,27 +36,27 @@ def test_sft_trainer_lora_complete_workflow(sagemaker_session): accept_eula=True, base_job_name=f"sft-lora-integ-{unique_id}", ) - + # Create training job training_job = sft_trainer.train(wait=False) - + # Manual wait loop to avoid resource_config issue max_wait_time = 3600 # 1 hour timeout - poll_interval = 30 # Check every 30 seconds + poll_interval = 30 # Check every 30 seconds start_time = time.time() - + while time.time() - start_time < max_wait_time: training_job.refresh() status = training_job.training_job_status - + if status in ["Completed", "Failed", "Stopped"]: break - + time.sleep(poll_interval) - + # Verify job completed successfully assert training_job.training_job_status == "Completed" - assert hasattr(training_job, 'output_model_package_arn') + assert hasattr(training_job, "output_model_package_arn") assert training_job.output_model_package_arn is not None @@ -73,26 +74,26 @@ def test_sft_trainer_with_validation_dataset(sagemaker_session): accept_eula=True, base_job_name=f"sft-val-integ-{unique_id}", ) - + training_job = sft_trainer.train(wait=False) - + # Manual wait loop max_wait_time = 3600 poll_interval = 30 start_time = time.time() - + while time.time() - start_time < max_wait_time: training_job.refresh() status = training_job.training_job_status - + if status in ["Completed", "Failed", "Stopped"]: break - + time.sleep(poll_interval) - + # Verify job completed successfully assert training_job.training_job_status == "Completed" - assert hasattr(training_job, 'output_model_package_arn') + assert hasattr(training_job, "output_model_package_arn") @pytest.mark.gpu_intensive @@ -104,7 +105,7 @@ def test_sft_trainer_nova_workflow(sagemaker_session_us_east_1): unique_id = f"{int(time.time())}-{random.randint(1000, 9999)}" sft_trainer_nova = SFTTrainer( model="nova-textgeneration-lite-v2", - training_type=TrainingType.LORA, + training_type=TrainingType.LORA, model_package_group="sdk-test-finetuned-models", mlflow_experiment_name="test-nova-finetuned-models-exp", mlflow_run_name="test-nova-finetuned-models-run", @@ -113,27 +114,27 @@ def test_sft_trainer_nova_workflow(sagemaker_session_us_east_1): sagemaker_session=sagemaker_session_us_east_1, base_job_name=f"sft-nova-integ-{unique_id}", ) - + # Create training job training_job = sft_trainer_nova.train(wait=False) - + # Manual wait loop max_wait_time = 3600 # 1 hour timeout - poll_interval = 30 # Check every 30 seconds + poll_interval = 30 # Check every 30 seconds start_time = time.time() - + while time.time() - start_time < max_wait_time: training_job.refresh() status = training_job.training_job_status - + if status in ["Completed", "Failed", "Stopped"]: break - + time.sleep(poll_interval) - + # Verify job completed successfully assert training_job.training_job_status == "Completed" - assert hasattr(training_job, 'output_model_package_arn') + assert hasattr(training_job, "output_model_package_arn") assert training_job.output_model_package_arn is not None @@ -169,5 +170,5 @@ def test_sft_trainer_lora_with_sequence_length(sagemaker_session): time.sleep(poll_interval) assert training_job.training_job_status == "Completed" - assert hasattr(training_job, 'output_model_package_arn') + assert hasattr(training_job, "output_model_package_arn") assert training_job.output_model_package_arn is not None diff --git a/sagemaker-train/tests/unit/train/common_utils/test_finetune_utils.py b/sagemaker-train/tests/unit/train/common_utils/test_finetune_utils.py index 50ae26a800..f1e92ac025 100644 --- a/sagemaker-train/tests/unit/train/common_utils/test_finetune_utils.py +++ b/sagemaker-train/tests/unit/train/common_utils/test_finetune_utils.py @@ -28,10 +28,9 @@ _validate_eula_for_gated_model, _validate_model_region_availability, _validate_s3_path_exists, - _parse_context_length + _parse_context_length, ) from sagemaker.core.resources import ModelPackage, ModelPackageGroup -from sagemaker.core.utils.utils import Unassigned from sagemaker.ai_registry.dataset import DataSet from sagemaker.train.common import TrainingType from sagemaker.train.configs import InputData @@ -39,47 +38,53 @@ class TestFinetuneUtils: - @patch('sagemaker.train.common_utils.finetune_utils.boto3.client') - @patch('sagemaker.train.common_utils.finetune_utils.Session') + @patch("sagemaker.train.common_utils.finetune_utils.boto3.client") + @patch("sagemaker.train.common_utils.finetune_utils.Session") def test__get_beta_session(self, mock_session, mock_boto_client): mock_client = Mock() mock_boto_client.return_value = mock_client mock_sagemaker_session = Mock() mock_session.return_value = mock_sagemaker_session - + result = _get_beta_session() - + assert result == mock_sagemaker_session mock_boto_client.assert_called_once() def test_get_current_domain_id_with_studio_arn(self): mock_session = Mock() - mock_session.get_caller_identity_arn.return_value = "arn:aws:sts::123456789012:assumed-role/SageMakerStudioExecutionRole/SageMaker" - + mock_session.get_caller_identity_arn.return_value = ( + "arn:aws:sts::123456789012:assumed-role/SageMakerStudioExecutionRole/SageMaker" + ) + result = _get_current_domain_id(mock_session) - + assert result is None def test_get_current_domain_id_with_domain_arn(self): mock_session = Mock() - mock_session.get_caller_identity_arn.return_value = "arn:aws:sagemaker:us-east-1:123456789012:user-profile/d-123456789/test-user" - + mock_session.get_caller_identity_arn.return_value = ( + "arn:aws:sagemaker:us-east-1:123456789012:user-profile/d-123456789/test-user" + ) + result = _get_current_domain_id(mock_session) - + assert result == "d-123456789" def test__resolve_mlflow_resource_arn_with_provided_arn(self): mock_session = Mock() provided_arn = "arn:aws:mlflow:us-east-1:123456789012:tracking-server/test" - + result = _resolve_mlflow_resource_arn(mock_session, provided_arn) - + assert result == provided_arn - @patch('sagemaker.train.common_utils.finetune_utils._get_current_domain_id') - @patch('sagemaker.train.common_utils.finetune_utils._create_mlflow_app') - @patch('sagemaker.core.resources.MlflowApp.get_all') - def test__resolve_mlflow_resource_arn_creates_new_app(self, mock_get_all, mock_create_app, mock_get_domain): + @patch("sagemaker.train.common_utils.finetune_utils._get_current_domain_id") + @patch("sagemaker.train.common_utils.finetune_utils._create_mlflow_app") + @patch("sagemaker.core.resources.MlflowApp.get_all") + def test__resolve_mlflow_resource_arn_creates_new_app( + self, mock_get_all, mock_create_app, mock_get_domain + ): mock_session = Mock() mock_session.boto_session.region_name = "us-east-1" mock_get_domain.return_value = "d-123456789" @@ -87,13 +92,13 @@ def test__resolve_mlflow_resource_arn_creates_new_app(self, mock_get_all, mock_c mock_app = Mock() mock_app.arn = "arn:aws:mlflow:us-east-1:123456789012:tracking-server/new-app" mock_create_app.return_value = mock_app - + result = _resolve_mlflow_resource_arn(mock_session, None) - + assert result == mock_app.arn - @patch('sagemaker.train.common_utils.finetune_utils.TrainDefaults.get_role') - @patch('sagemaker.core.resources.MlflowApp.create') + @patch("sagemaker.train.common_utils.finetune_utils.TrainDefaults.get_role") + @patch("sagemaker.core.resources.MlflowApp.create") def test_create_mlflow_app_success(self, mock_create, mock_get_role): mock_session = Mock() mock_session.region_name = "us-east-1" @@ -101,63 +106,67 @@ def test_create_mlflow_app_success(self, mock_create, mock_get_role): mock_sts_client.get_caller_identity.return_value = {"Account": "123456789012"} mock_s3_client = Mock() mock_s3_client.list_objects_v2.return_value = {"Contents": [{"Key": "mlflow-artifacts/"}]} - + def mock_client(service_name): - if service_name == 'sts': + if service_name == "sts": return mock_sts_client - elif service_name == 's3': + elif service_name == "s3": return mock_s3_client return Mock() - + mock_session.boto_session.client.side_effect = mock_client mock_session.boto_session.region_name = "us-east-1" mock_get_role.return_value = "arn:aws:iam::123456789012:role/test-role" mock_app = Mock() mock_app.status = "Created" mock_create.return_value = mock_app - + result = _create_mlflow_app(mock_session) - + assert result == mock_app mock_create.assert_called_once() mock_app.refresh.assert_called() - @patch('sagemaker.core.resources.MlflowApp.create') + @patch("sagemaker.core.resources.MlflowApp.create") def test_create_mlflow_app_failure(self, mock_create): mock_session = Mock() mock_create.side_effect = Exception("Creation failed") - + result = _create_mlflow_app(mock_session) - + assert result is None def test__validate_dataset_arn_valid(self): valid_arn = "arn:aws:sagemaker:us-east-1:123456789012:hub-content/SageMakerPublicHub/DataSet/test-dataset/1.0" - + # Should not raise exception _validate_dataset_arn(valid_arn, "test_dataset") def test__validate_dataset_arn_invalid(self): invalid_arn = "invalid-arn" - - with pytest.raises(ValueError, match="test_dataset must be a valid SageMaker hub-content DataSet ARN"): + + with pytest.raises( + ValueError, match="test_dataset must be a valid SageMaker hub-content DataSet ARN" + ): _validate_dataset_arn(invalid_arn, "test_dataset") def test_validate_evaluator_arn_valid(self): valid_arn = "arn:aws:sagemaker:us-east-1:123456789012:hub-content/SageMakerPublicHub/JsonDoc/test-evaluator/1.0" - + # Should not raise exception _validate_evaluator_arn(valid_arn, "test_evaluator") def test_validate_evaluator_arn_invalid(self): invalid_arn = "invalid-arn" - - with pytest.raises(ValueError, match="test_evaluator must be a valid SageMaker hub-content evaluator ARN"): + + with pytest.raises( + ValueError, match="test_evaluator must be a valid SageMaker hub-content evaluator ARN" + ): _validate_evaluator_arn(invalid_arn, "test_evaluator") def test__validate_model_package_group_requirement_with_model_package(self): model_package = Mock(spec=ModelPackage) - + # Should not raise exception _validate_model_package_group_requirement(model_package, None) @@ -165,33 +174,37 @@ def test__validate_model_package_group_requirement_without_group_name(self): with pytest.raises(ValueError, match="model_package_group_name must be provided"): _validate_model_package_group_requirement("string-model", None) - @patch('sagemaker.core.resources.ModelPackageGroup.get') + @patch("sagemaker.core.resources.ModelPackageGroup.get") def test__resolve_model_package_group_arn_with_name(self, mock_get): mock_session = Mock() mock_session.boto_session.region_name = "us-east-1" mock_group = Mock() - mock_group.model_package_group_arn = "arn:aws:sagemaker:us-east-1:123456789012:model-package-group/test-group" + mock_group.model_package_group_arn = ( + "arn:aws:sagemaker:us-east-1:123456789012:model-package-group/test-group" + ) mock_get.return_value = mock_group - + result = _resolve_model_package_group_arn("test-group", mock_session) - + assert result == mock_group.model_package_group_arn def test__resolve_model_package_group_arn_with_arn(self): mock_session = Mock() arn = "arn:aws:sagemaker:us-east-1:123456789012:model-package-group/test-group" - + result = _resolve_model_package_group_arn(arn, mock_session) - + assert result == arn def test__resolve_model_package_group_arn_with_object(self): mock_session = Mock() mock_group = Mock(spec=ModelPackageGroup) - mock_group.model_package_group_arn = "arn:aws:sagemaker:us-east-1:123456789012:model-package-group/test-group" - + mock_group.model_package_group_arn = ( + "arn:aws:sagemaker:us-east-1:123456789012:model-package-group/test-group" + ) + result = _resolve_model_package_group_arn(mock_group, mock_session) - + assert result == mock_group.model_package_group_arn def test__get_default_s3_output_path(self): @@ -200,50 +213,50 @@ def test__get_default_s3_output_path(self): mock_sts_client.get_caller_identity.return_value = {"Account": "123456789012"} mock_session.boto_session.client.return_value = mock_sts_client mock_session.boto_session.region_name = "us-east-1" - + result = _get_default_s3_output_path(mock_session) - + assert result == "s3://sagemaker-us-east-1-123456789012/output" def test__extract_dataset_source_s3_uri(self): s3_uri = "s3://bucket/dataset" - + result = _extract_dataset_source(s3_uri, "test_dataset") - + assert result == s3_uri - @patch('sagemaker.train.common_utils.finetune_utils._validate_dataset_arn') + @patch("sagemaker.train.common_utils.finetune_utils._validate_dataset_arn") def test__extract_dataset_source_arn(self, mock_validate): arn = "arn:aws:sagemaker:us-east-1:123456789012:hub-content/SageMakerPublicHub/DataSet/test/1.0" - + result = _extract_dataset_source(arn, "test_dataset") - + assert result == arn mock_validate.assert_called_once_with(arn, "test_dataset") def test__extract_dataset_source_dataset_object(self): mock_dataset = Mock(spec=DataSet) mock_dataset.arn = "arn:aws:sagemaker:us-east-1:123456789012:hub-content/SageMakerPublicHub/DataSet/test/1.0" - + result = _extract_dataset_source(mock_dataset, "test_dataset") - + assert result == mock_dataset.arn - @patch('sagemaker.train.common_utils.finetune_utils._validate_evaluator_arn') + @patch("sagemaker.train.common_utils.finetune_utils._validate_evaluator_arn") def test_extract_evaluator_arn_string(self, mock_validate): arn = "arn:aws:sagemaker:us-east-1:123456789012:hub-content/SageMakerPublicHub/JsonDoc/test/1.0" - + result = _extract_evaluator_arn(arn, "test_evaluator") - + assert result == arn mock_validate.assert_called_once_with(arn, "test_evaluator") def test_extract_evaluator_arn_object(self): mock_evaluator = Mock() mock_evaluator.arn = "arn:aws:sagemaker:us-east-1:123456789012:hub-content/SageMakerPublicHub/JsonDoc/test/1.0" - + result = _extract_evaluator_arn(mock_evaluator, "test_evaluator") - + assert result == mock_evaluator.arn def test__resolve_model_name_with_model_package(self): @@ -253,9 +266,9 @@ def test__resolve_model_name_with_model_package(self): mock_base_model.hub_content_name = "test-model" mock_container.base_model = mock_base_model mock_model_package.inference_specification.containers = [mock_container] - + result = _resolve_model_name(mock_model_package) - + assert result == "test-model" def test__resolve_model_name_with_none(self): @@ -266,41 +279,41 @@ def test__resolve_model_package_arn_success(self): mock_model_package = Mock() expected_arn = "arn:aws:sagemaker:us-east-1:123456789012:model-package/test-package" mock_model_package.model_package_arn = expected_arn - + result = _resolve_model_package_arn(mock_model_package) - + assert result == expected_arn def test__resolve_model_package_arn_failure(self): mock_model_package = Mock() mock_model_package.model_package_arn = None - + result = _resolve_model_package_arn(mock_model_package) - + assert result is None - @patch('sagemaker.train.common_utils.finetune_utils._get_hub_content_metadata') - @patch('boto3.client') + @patch("sagemaker.train.common_utils.finetune_utils._get_hub_content_metadata") + @patch("boto3.client") def test__get_fine_tuning_options_and_model_arn(self, mock_boto_client, mock_get_hub_content): mock_session = Mock() mock_session.boto_session.region_name = "us-east-1" - + # Mock hub content metadata mock_get_hub_content.return_value = { - 'hub_content_arn': "arn:aws:sagemaker:us-east-1:123456789012:model/test-model", - 'hub_content_document': { + "hub_content_arn": "arn:aws:sagemaker:us-east-1:123456789012:model/test-model", + "hub_content_document": { "GatedBucket": False, "RecipeCollection": [ { "CustomizationTechnique": "SFT", "SmtjRecipeTemplateS3Uri": "s3://bucket/template.json", "SmtjOverrideParamsS3Uri": "s3://bucket/params.json", - "Peft": True + "Peft": True, } - ] - } + ], + }, } - + # Mock S3 client mock_s3_client = Mock() mock_boto_client.return_value = mock_s3_client @@ -309,9 +322,9 @@ def test__get_fine_tuning_options_and_model_arn(self, mock_boto_client, mock_get } mock_session.boto_session.client.return_value = mock_s3_client mock_session.boto_session.client.return_value = mock_s3_client - + result = _get_fine_tuning_options_and_model_arn("test-model", "SFT", "LORA", mock_session) - + # Handle case where function might return None if result is not None: options, model_arn, is_gated_model = result @@ -324,7 +337,7 @@ def test__get_fine_tuning_options_and_model_arn(self, mock_boto_client, mock_get def test_create_input_channels_s3_uri(self): result = _create_input_channels("s3://bucket/data", "application/json") - + assert len(result) == 1 assert result[0].channel_name == "train" assert result[0].data_source.s3_data_source.s3_uri == "s3://bucket/data" @@ -332,9 +345,9 @@ def test_create_input_channels_s3_uri(self): def test_create_input_channels_dataset_arn(self): arn = "arn:aws:sagemaker:us-east-1:123456789012:hub-content/SageMakerPublicHub/DataSet/test/1.0" - + result = _create_input_channels(arn) - + assert len(result) == 1 assert result[0].channel_name == "train" assert result[0].data_source.dataset_source.dataset_arn == arn @@ -342,24 +355,24 @@ def test_create_input_channels_dataset_arn(self): def test__validate_and_resolve_model_package_group_with_provided_name(self): model = "test-model" group_name = "test-group" - + result = _validate_and_resolve_model_package_group(model, group_name) - + assert result == group_name def test__validate_and_resolve_model_package_group_from_model_package(self): mock_model = Mock(spec=ModelPackage) mock_model.model_package_group_name = "extracted-group" - + result = _validate_and_resolve_model_package_group(mock_model, None) - + assert result == "extracted-group" def test__validate_and_resolve_model_package_group_missing_both(self): with pytest.raises(ValueError, match="model_package_group_name must be provided"): _validate_and_resolve_model_package_group("string-model", None) - @patch('sagemaker.core.resources.ModelPackage.get') + @patch("sagemaker.core.resources.ModelPackage.get") def test__resolve_model_and_name_with_model_package_arn(self, mock_get): mock_session = Mock() mock_session.boto_region_name = "us-east-1" # Set valid region @@ -371,15 +384,17 @@ def test__resolve_model_and_name_with_model_package_arn(self, mock_get): mock_model_package.inference_specification = Mock() mock_model_package.inference_specification.containers = [mock_container] mock_get.return_value = mock_model_package - - model, name = _resolve_model_and_name("arn:aws:sagemaker:us-east-1:123456789012:model-package/test", mock_session) - + + model, name = _resolve_model_and_name( + "arn:aws:sagemaker:us-east-1:123456789012:model-package/test", mock_session + ) + assert model == mock_model_package assert name == "test-model" def test__resolve_model_and_name_with_string(self): model, name = _resolve_model_and_name("test-model") - + assert model == "test-model" assert name == "test-model" @@ -391,15 +406,15 @@ def test__resolve_model_and_name_with_model_package_object(self): mock_container.base_model = mock_base_model mock_model_package.inference_specification = Mock() mock_model_package.inference_specification.containers = [mock_container] - + model, name = _resolve_model_and_name(mock_model_package) - + assert model == mock_model_package assert name == "test-model" def test__create_serverless_config_with_lora(self): config = _create_serverless_config("model-arn", "SFT", TrainingType.LORA, accept_eula=True) - + assert config.job_type == "FineTuning" assert config.base_model_arn == "model-arn" assert config.customization_technique == "SFT" @@ -407,14 +422,13 @@ def test__create_serverless_config_with_lora(self): def test__create_serverless_config_with_full(self): config = _create_serverless_config("model-arn", "SFT", TrainingType.FULL, accept_eula=True) - + assert config.peft is None def test__create_input_data_config(self): - config = _create_input_data_config("s3://bucket/train", "s3://bucket/val") - + assert len(config) == 2 assert config[0].channel_name == "train" assert config[1].channel_name == "validation" @@ -423,30 +437,34 @@ def test__create_model_package_config(self): mock_session = Mock() mock_model = Mock(spec=ModelPackage) mock_model.model_package_arn = "source-arn" - - with patch('sagemaker.train.common_utils.finetune_utils._resolve_model_package_group_arn') as mock_resolve: + + with patch( + "sagemaker.train.common_utils.finetune_utils._resolve_model_package_group_arn" + ) as mock_resolve: mock_resolve.return_value = "group-arn" config = _create_model_package_config("test-group", mock_model, mock_session) - + assert config.model_package_group_arn == "group-arn" assert config.source_model_package_arn == "source-arn" def test__create_mlflow_config(self): mock_session = Mock() - - with patch('sagemaker.train.common_utils.finetune_utils._resolve_mlflow_resource_arn') as mock_resolve: + + with patch( + "sagemaker.train.common_utils.finetune_utils._resolve_mlflow_resource_arn" + ) as mock_resolve: mock_resolve.return_value = "mlflow-arn" config = _create_mlflow_config(mock_session, mlflow_experiment_name="test-exp") - + assert config.mlflow_resource_arn == "mlflow-arn" assert config.mlflow_experiment_name == "test-exp" - @patch('sagemaker.train.common_utils.finetune_utils._validate_s3_path_exists') + @patch("sagemaker.train.common_utils.finetune_utils._validate_s3_path_exists") def test__create_output_config(self, mock_validate_s3): mock_session = Mock() - + config = _create_output_config(mock_session, "s3://bucket/output", "kms-key") - + assert config.s3_output_path == "s3://bucket/output" assert config.kms_key_id == "kms-key" mock_validate_s3.assert_called_once_with("s3://bucket/output", mock_session) @@ -455,22 +473,23 @@ def test__convert_input_data_to_channels(self): input_data = [InputData(channel_name="train", data_source="s3://bucket/data")] channels = _convert_input_data_to_channels(input_data) - + assert len(channels) == 1 assert channels[0].channel_name == "train" def test__validate_eula_for_gated_model_with_model_package(self): """Test EULA validation returns True for ModelPackage input""" from sagemaker.core.resources import ModelPackage + model_package = Mock(spec=ModelPackage) - + result = _validate_eula_for_gated_model(model_package, False, True) assert result == True def test__validate_eula_for_gated_model_with_arn(self): """Test EULA validation returns True for ARN input""" model_arn = "arn:aws:sagemaker:us-east-1:123456789012:model-package/test/1" - + result = _validate_eula_for_gated_model(model_arn, False, True) assert result == True @@ -497,7 +516,9 @@ def test__validate_model_region_availability_nova_valid_region(self): def test__validate_model_region_availability_nova_invalid_region(self): """Test Nova model validation fails for invalid region""" - with pytest.raises(ValueError, match="Region 'eu-west-1' does not support model customization"): + with pytest.raises( + ValueError, match="Region 'eu-west-1' does not support model customization" + ): _validate_model_region_availability("nova-textgeneration-lite-v2", "eu-west-1") def test__validate_model_region_availability_open_weights_valid_region(self): @@ -507,57 +528,63 @@ def test__validate_model_region_availability_open_weights_valid_region(self): def test__validate_model_region_availability_open_weights_invalid_region(self): """Test open weights model validation fails for invalid region""" - with pytest.raises(ValueError, match="Region 'us-west-1' does not support model customization"): + with pytest.raises( + ValueError, match="Region 'us-west-1' does not support model customization" + ): _validate_model_region_availability("meta-textgeneration-llama-3-2-1b", "us-west-1") def test__validate_s3_path_exists_invalid_format(self): """Test S3 path validation fails for invalid format""" mock_session = Mock() - + with pytest.raises(ValueError, match="Invalid S3 path format"): _validate_s3_path_exists("invalid-path", mock_session) - @patch('boto3.client') + @patch("boto3.client") def test__validate_s3_path_exists_bucket_only_success(self, mock_boto_client): """Test S3 path validation succeeds for bucket-only path""" mock_session = Mock() mock_s3_client = Mock() mock_session.boto_session.client.return_value = mock_s3_client - + _validate_s3_path_exists("s3://test-bucket", mock_session) - + mock_s3_client.head_bucket.assert_called_once_with(Bucket="test-bucket") - @patch('boto3.client') + @patch("boto3.client") def test__validate_s3_path_exists_with_prefix_exists(self, mock_boto_client): """Test S3 path validation succeeds when prefix exists""" mock_session = Mock() mock_s3_client = Mock() mock_session.boto_session.client.return_value = mock_s3_client mock_s3_client.list_objects_v2.return_value = {"Contents": [{"Key": "prefix/file.txt"}]} - + _validate_s3_path_exists("s3://test-bucket/prefix/", mock_session) - + mock_s3_client.head_bucket.assert_called_once_with(Bucket="test-bucket") - mock_s3_client.list_objects_v2.assert_called_once_with(Bucket="test-bucket", Prefix="prefix/", MaxKeys=1) + mock_s3_client.list_objects_v2.assert_called_once_with( + Bucket="test-bucket", Prefix="prefix/", MaxKeys=1 + ) - @patch('boto3.client') + @patch("boto3.client") def test__validate_s3_path_exists_with_prefix_not_exists(self, mock_boto_client): """Test S3 path validation creates prefix when it doesn't exist""" mock_session = Mock() mock_s3_client = Mock() mock_session.boto_session.client.return_value = mock_s3_client mock_s3_client.list_objects_v2.return_value = {} # No contents - - _validate_s3_path_exists("s3://test-bucket/prefix", mock_session) - - mock_s3_client.head_bucket.assert_called_once_with(Bucket="test-bucket") - mock_s3_client.list_objects_v2.assert_called_once_with(Bucket="test-bucket", Prefix="prefix", MaxKeys=1) - mock_s3_client.put_object.assert_called_once_with(Bucket="test-bucket", Key="prefix/", Body=b'') + _validate_s3_path_exists("s3://test-bucket/prefix", mock_session) + mock_s3_client.head_bucket.assert_called_once_with(Bucket="test-bucket") + mock_s3_client.list_objects_v2.assert_called_once_with( + Bucket="test-bucket", Prefix="prefix", MaxKeys=1 + ) + mock_s3_client.put_object.assert_called_once_with( + Bucket="test-bucket", Key="prefix/", Body=b"" + ) - @patch('sagemaker.train.common_utils.finetune_utils._get_hub_content_metadata') + @patch("sagemaker.train.common_utils.finetune_utils._get_hub_content_metadata") def test__get_fine_tuning_options_with_subscription_recipe_enabled(self, mock_get_hub_content): """When and user is subscribed, datamix HPs are available.""" mock_session = Mock() @@ -565,34 +592,40 @@ def test__get_fine_tuning_options_with_subscription_recipe_enabled(self, mock_ge mock_s3 = Mock() mock_sts = Mock() mock_sts.get_caller_identity.return_value = {"Account": "123456789012"} - mock_session.boto_session.client.side_effect = lambda service, **kwargs: mock_s3 if service == "s3" else mock_sts + mock_session.boto_session.client.side_effect = lambda service, **kwargs: ( + mock_s3 if service == "s3" else mock_sts + ) mock_get_hub_content.return_value = { - 'hub_content_arn': "arn:aws:sagemaker:us-east-1:123456789012:model/test-model", - 'hub_content_document': { + "hub_content_arn": "arn:aws:sagemaker:us-east-1:123456789012:model/test-model", + "hub_content_document": { "GatedBucket": False, "RecipeCollection": [ { "CustomizationTechnique": "SFT", "SmtjRecipeTemplateS3Uri": "s3://bucket/template.yaml", "SmtjOverrideParamsS3Uri": "s3://bucket/standard_params.json", - "Name": "standard_sft" + "Name": "standard_sft", }, { "CustomizationTechnique": "SFT", "SmtjRecipeTemplateS3Uri": "s3://arn:aws:s3:us-east-1:334772094012:accesspoint/recipes-123456789012/source/template.yaml", "SmtjOverrideParamsS3Uri": "s3://arn:aws:s3:us-east-1:334772094012:accesspoint/recipes-{customer_id}/source/params.json", "Name": "datamix_sft", - "IsSubscriptionModel": True - } - ] - } + "IsSubscriptionModel": True, + }, + ], + }, } # Standard recipe returns base params - standard_params = json.dumps({"max_steps": {"type": "integer", "required": True, "default": 100}}) + standard_params = json.dumps( + {"max_steps": {"type": "integer", "required": True, "default": 100}} + ) # Subscription recipe returns datamix params - datamix_params = json.dumps({"customer_data_percent": {"type": "integer", "required": False, "default": 50}}) + datamix_params = json.dumps( + {"customer_data_percent": {"type": "integer", "required": False, "default": 50}} + ) mock_s3.get_object.side_effect = [ {"Body": Mock(read=Mock(return_value=standard_params.encode()))}, @@ -600,15 +633,22 @@ def test__get_fine_tuning_options_with_subscription_recipe_enabled(self, mock_ge ] options, model_arn, is_gated = _get_fine_tuning_options_and_model_arn( - "test-model", "SFT", "FULL", mock_session, + "test-model", + "SFT", + "FULL", + mock_session, ) assert "max_steps" in options._specs assert "customer_data_percent" in options._specs - assert options._specs["customer_data_percent"]["default"] is None # defaults are None so they dont serialize unless explicitly set - - @patch('sagemaker.train.common_utils.finetune_utils._get_hub_content_metadata') - def test__get_fine_tuning_options_subscription_disabled_no_datamix_hps(self, mock_get_hub_content): + assert ( + options._specs["customer_data_percent"]["default"] is None + ) # defaults are None so they dont serialize unless explicitly set + + @patch("sagemaker.train.common_utils.finetune_utils._get_hub_content_metadata") + def test__get_fine_tuning_options_subscription_disabled_no_datamix_hps( + self, mock_get_hub_content + ): """When (default), datamix HPs are NOT available.""" mock_session = Mock() mock_session.boto_session.region_name = "us-east-1" @@ -616,70 +656,83 @@ def test__get_fine_tuning_options_subscription_disabled_no_datamix_hps(self, moc mock_session.boto_session.client.side_effect = lambda service, **kwargs: mock_s3 mock_get_hub_content.return_value = { - 'hub_content_arn': "arn:aws:sagemaker:us-east-1:123456789012:model/test-model", - 'hub_content_document': { + "hub_content_arn": "arn:aws:sagemaker:us-east-1:123456789012:model/test-model", + "hub_content_document": { "GatedBucket": False, "RecipeCollection": [ { "CustomizationTechnique": "SFT", "SmtjRecipeTemplateS3Uri": "s3://bucket/template.yaml", "SmtjOverrideParamsS3Uri": "s3://bucket/standard_params.json", - "Name": "standard_sft" + "Name": "standard_sft", }, { "CustomizationTechnique": "SFT", "SmtjRecipeTemplateS3Uri": "s3://arn:aws:s3:us-east-1:334772094012:accesspoint/recipes-{customer_id}/source/template.yaml", "SmtjOverrideParamsS3Uri": "s3://arn:aws:s3:us-east-1:334772094012:accesspoint/recipes-{customer_id}/source/params.json", "Name": "datamix_sft", - "IsSubscriptionModel": True - } - ] - } + "IsSubscriptionModel": True, + }, + ], + }, } - standard_params = json.dumps({"max_steps": {"type": "integer", "required": True, "default": 100}}) - mock_s3.get_object.return_value = {"Body": Mock(read=Mock(return_value=standard_params.encode()))} + standard_params = json.dumps( + {"max_steps": {"type": "integer", "required": True, "default": 100}} + ) + mock_s3.get_object.return_value = { + "Body": Mock(read=Mock(return_value=standard_params.encode())) + } options, model_arn, is_gated = _get_fine_tuning_options_and_model_arn( - "test-model", "SFT", "FULL", mock_session, + "test-model", + "SFT", + "FULL", + mock_session, ) assert "max_steps" in options._specs assert "customer_data_percent" not in options._specs - @patch('sagemaker.train.common_utils.finetune_utils._get_hub_content_metadata') - def test__get_fine_tuning_options_subscription_enabled_but_not_subscribed(self, mock_get_hub_content): + @patch("sagemaker.train.common_utils.finetune_utils._get_hub_content_metadata") + def test__get_fine_tuning_options_subscription_enabled_but_not_subscribed( + self, mock_get_hub_content + ): """When but user is NOT subscribed, falls back gracefully.""" mock_session = Mock() mock_session.boto_session.region_name = "us-east-1" mock_s3 = Mock() mock_sts = Mock() mock_sts.get_caller_identity.return_value = {"Account": "999999999999"} - mock_session.boto_session.client.side_effect = lambda service, **kwargs: mock_s3 if service == "s3" else mock_sts + mock_session.boto_session.client.side_effect = lambda service, **kwargs: ( + mock_s3 if service == "s3" else mock_sts + ) mock_get_hub_content.return_value = { - 'hub_content_arn': "arn:aws:sagemaker:us-east-1:123456789012:model/test-model", - 'hub_content_document': { + "hub_content_arn": "arn:aws:sagemaker:us-east-1:123456789012:model/test-model", + "hub_content_document": { "GatedBucket": False, "RecipeCollection": [ { "CustomizationTechnique": "SFT", "SmtjRecipeTemplateS3Uri": "s3://bucket/template.yaml", "SmtjOverrideParamsS3Uri": "s3://bucket/standard_params.json", - "Name": "standard_sft" + "Name": "standard_sft", }, { "CustomizationTechnique": "SFT", "SmtjRecipeTemplateS3Uri": "s3://arn:aws:s3:us-east-1:334772094012:accesspoint/recipes-{customer_id}/source/template.yaml", "SmtjOverrideParamsS3Uri": "s3://arn:aws:s3:us-east-1:334772094012:accesspoint/recipes-{customer_id}/source/params.json", "Name": "datamix_sft", - "IsSubscriptionModel": True - } - ] - } + "IsSubscriptionModel": True, + }, + ], + }, } - standard_params = json.dumps({"max_steps": {"type": "integer", "required": True, "default": 100}}) + standard_params = json.dumps( + {"max_steps": {"type": "integer", "required": True, "default": 100}} + ) # First call succeeds (standard recipe), second call fails (access denied) mock_s3.get_object.side_effect = [ {"Body": Mock(read=Mock(return_value=standard_params.encode()))}, @@ -687,7 +740,10 @@ def test__get_fine_tuning_options_subscription_enabled_but_not_subscribed(self, ] options, model_arn, is_gated = _get_fine_tuning_options_and_model_arn( - "test-model", "SFT", "FULL", mock_session, + "test-model", + "SFT", + "FULL", + mock_session, ) # Should still have standard params, just not datamix ones @@ -695,7 +751,9 @@ def test__get_fine_tuning_options_subscription_enabled_but_not_subscribed(self, assert "customer_data_percent" not in options._specs def test__create_serverless_config_with_sequence_length(self): - config = _create_serverless_config("model-arn", "SFT", TrainingType.LORA, accept_eula=True, sequence_length="8K") + config = _create_serverless_config( + "model-arn", "SFT", TrainingType.LORA, accept_eula=True, sequence_length="8K" + ) assert config.sequence_length == "8K" assert config.base_model_arn == "model-arn" @@ -722,7 +780,7 @@ def test__parse_context_length_with_none(self): def test__parse_context_length_with_empty(self): assert _parse_context_length("") == 0 - @patch('sagemaker.train.common_utils.finetune_utils._get_hub_content_metadata') + @patch("sagemaker.train.common_utils.finetune_utils._get_hub_content_metadata") def test__get_fine_tuning_options_filters_by_sequence_length(self, mock_get_hub_content): mock_session = Mock() mock_session.boto_session.region_name = "us-east-1" @@ -733,8 +791,8 @@ def test__get_fine_tuning_options_filters_by_sequence_length(self, mock_get_hub_ mock_session.boto_session.client.return_value = mock_s3 mock_get_hub_content.return_value = { - 'hub_content_arn': "arn:aws:sagemaker:us-east-1:123456789012:model/test-model", - 'hub_content_document': { + "hub_content_arn": "arn:aws:sagemaker:us-east-1:123456789012:model/test-model", + "hub_content_document": { "GatedBucket": False, "RecipeCollection": [ { @@ -742,20 +800,22 @@ def test__get_fine_tuning_options_filters_by_sequence_length(self, mock_get_hub_ "SmtjRecipeTemplateS3Uri": "s3://bucket/template-4k.json", "SmtjOverrideParamsS3Uri": "s3://bucket/params-4k.json", "Peft": True, - "SequenceLength": "4K" + "SequenceLength": "4K", }, { "CustomizationTechnique": "SFT", "SmtjRecipeTemplateS3Uri": "s3://bucket/template-32k.json", "SmtjOverrideParamsS3Uri": "s3://bucket/params-32k.json", "Peft": True, - "SequenceLength": "32K" - } - ] - } + "SequenceLength": "32K", + }, + ], + }, } - result = _get_fine_tuning_options_and_model_arn("test-model", "SFT", "LORA", mock_session, sequence_length="8K") + result = _get_fine_tuning_options_and_model_arn( + "test-model", "SFT", "LORA", mock_session, sequence_length="8K" + ) if result is not None: options, model_arn, is_gated_model = result @@ -764,14 +824,16 @@ def test__get_fine_tuning_options_filters_by_sequence_length(self, mock_get_hub_ call_args = mock_s3.get_object.call_args[1] assert "params-32k" in call_args["Key"] - @patch('sagemaker.train.common_utils.finetune_utils._get_hub_content_metadata') - def test__get_fine_tuning_options_raises_when_no_sufficient_context_length(self, mock_get_hub_content): + @patch("sagemaker.train.common_utils.finetune_utils._get_hub_content_metadata") + def test__get_fine_tuning_options_raises_when_no_sufficient_context_length( + self, mock_get_hub_content + ): mock_session = Mock() mock_session.boto_session.region_name = "us-east-1" mock_get_hub_content.return_value = { - 'hub_content_arn': "arn:aws:sagemaker:us-east-1:123456789012:model/test-model", - 'hub_content_document': { + "hub_content_arn": "arn:aws:sagemaker:us-east-1:123456789012:model/test-model", + "hub_content_document": { "GatedBucket": False, "RecipeCollection": [ { @@ -779,12 +841,14 @@ def test__get_fine_tuning_options_raises_when_no_sufficient_context_length(self, "SmtjRecipeTemplateS3Uri": "s3://bucket/template-4k.json", "SmtjOverrideParamsS3Uri": "s3://bucket/params-4k.json", "Peft": True, - "SequenceLength": "4K" + "SequenceLength": "4K", } - ] - } + ], + }, } # Requesting 128K but only 4K available — should raise with pytest.raises(ValueError, match="No recipes found with SequenceLength >= 128K"): - _get_fine_tuning_options_and_model_arn("test-model", "SFT", "LORA", mock_session, sequence_length="128K") + _get_fine_tuning_options_and_model_arn( + "test-model", "SFT", "LORA", mock_session, sequence_length="128K" + ) From 973b62d1ab30cca6c7b1c39a61a759b5644955db Mon Sep 17 00:00:00 2001 From: guanweim Date: Thu, 28 May 2026 22:29:07 +0000 Subject: [PATCH 9/9] test: add more unit tests for sequence_length coverage - _parse_context_length with invalid K value and non-numeric string - sequence_length filter when no recipes have SequenceLength field - sequence_length filter with FULL training type - verify smallest sufficient sequence_length is selected - verify no sequence_length uses first recipe (backward compat) --- .../train/common_utils/test_finetune_utils.py | 174 ++++++++++++++++++ 1 file changed, 174 insertions(+) diff --git a/sagemaker-train/tests/unit/train/common_utils/test_finetune_utils.py b/sagemaker-train/tests/unit/train/common_utils/test_finetune_utils.py index f1e92ac025..0f5ff6f92e 100644 --- a/sagemaker-train/tests/unit/train/common_utils/test_finetune_utils.py +++ b/sagemaker-train/tests/unit/train/common_utils/test_finetune_utils.py @@ -852,3 +852,177 @@ def test__get_fine_tuning_options_raises_when_no_sufficient_context_length( _get_fine_tuning_options_and_model_arn( "test-model", "SFT", "LORA", mock_session, sequence_length="128K" ) + + def test__parse_context_length_with_invalid_k_value(self): + assert _parse_context_length("abcK") == 0 + + def test__parse_context_length_with_non_numeric_string(self): + assert _parse_context_length("hello") == 0 + + @patch("sagemaker.train.common_utils.finetune_utils._get_hub_content_metadata") + def test__get_fine_tuning_options_raises_when_no_recipes_have_sequence_length( + self, mock_get_hub_content + ): + mock_session = Mock() + mock_session.boto_session.region_name = "us-east-1" + + mock_get_hub_content.return_value = { + "hub_content_arn": "arn:aws:sagemaker:us-east-1:123456789012:model/test-model", + "hub_content_document": { + "GatedBucket": False, + "RecipeCollection": [ + { + "CustomizationTechnique": "SFT", + "SmtjRecipeTemplateS3Uri": "s3://bucket/template.json", + "SmtjOverrideParamsS3Uri": "s3://bucket/params.json", + "Peft": True, + } + ], + }, + } + + with pytest.raises(ValueError, match="and sequence length"): + _get_fine_tuning_options_and_model_arn( + "test-model", "SFT", "LORA", mock_session, sequence_length="8K" + ) + + @patch("sagemaker.train.common_utils.finetune_utils._get_hub_content_metadata") + def test__get_fine_tuning_options_filters_by_sequence_length_full_training( + self, mock_get_hub_content + ): + mock_session = Mock() + mock_session.boto_session.region_name = "us-east-1" + mock_s3 = Mock() + mock_s3.get_object.return_value = { + "Body": Mock(read=Mock(return_value=b'{"max_length": {"default": 8192}}')) + } + mock_session.boto_session.client.return_value = mock_s3 + + mock_get_hub_content.return_value = { + "hub_content_arn": "arn:aws:sagemaker:us-east-1:123456789012:model/test-model", + "hub_content_document": { + "GatedBucket": False, + "RecipeCollection": [ + { + "CustomizationTechnique": "SFT", + "SmtjRecipeTemplateS3Uri": "s3://bucket/template-8k.json", + "SmtjOverrideParamsS3Uri": "s3://bucket/params-8k.json", + "SequenceLength": "8K", + }, + { + "CustomizationTechnique": "SFT", + "SmtjRecipeTemplateS3Uri": "s3://bucket/template-32k.json", + "SmtjOverrideParamsS3Uri": "s3://bucket/params-32k.json", + "SequenceLength": "32K", + }, + ], + }, + } + + result = _get_fine_tuning_options_and_model_arn( + "test-model", "SFT", "FULL", mock_session, sequence_length="8K" + ) + + if result is not None: + options, model_arn, is_gated_model = result + mock_s3.get_object.assert_called_once() + call_args = mock_s3.get_object.call_args[1] + assert "params-8k" in call_args["Key"] + + @patch("sagemaker.train.common_utils.finetune_utils._get_hub_content_metadata") + def test__get_fine_tuning_options_selects_smallest_sufficient_sequence_length( + self, mock_get_hub_content + ): + mock_session = Mock() + mock_session.boto_session.region_name = "us-east-1" + mock_s3 = Mock() + mock_s3.get_object.return_value = { + "Body": Mock(read=Mock(return_value=b'{"max_length": {"default": 16384}}')) + } + mock_session.boto_session.client.return_value = mock_s3 + + mock_get_hub_content.return_value = { + "hub_content_arn": "arn:aws:sagemaker:us-east-1:123456789012:model/test-model", + "hub_content_document": { + "GatedBucket": False, + "RecipeCollection": [ + { + "CustomizationTechnique": "SFT", + "SmtjRecipeTemplateS3Uri": "s3://bucket/template-4k.json", + "SmtjOverrideParamsS3Uri": "s3://bucket/params-4k.json", + "Peft": True, + "SequenceLength": "4K", + }, + { + "CustomizationTechnique": "SFT", + "SmtjRecipeTemplateS3Uri": "s3://bucket/template-16k.json", + "SmtjOverrideParamsS3Uri": "s3://bucket/params-16k.json", + "Peft": True, + "SequenceLength": "16K", + }, + { + "CustomizationTechnique": "SFT", + "SmtjRecipeTemplateS3Uri": "s3://bucket/template-128k.json", + "SmtjOverrideParamsS3Uri": "s3://bucket/params-128k.json", + "Peft": True, + "SequenceLength": "128K", + }, + ], + }, + } + + result = _get_fine_tuning_options_and_model_arn( + "test-model", "SFT", "LORA", mock_session, sequence_length="8K" + ) + + if result is not None: + options, model_arn, is_gated_model = result + # Should pick 16K (smallest >= 8K), not 128K + mock_s3.get_object.assert_called_once() + call_args = mock_s3.get_object.call_args[1] + assert "params-16k" in call_args["Key"] + + @patch("sagemaker.train.common_utils.finetune_utils._get_hub_content_metadata") + def test__get_fine_tuning_options_without_sequence_length_uses_first_recipe( + self, mock_get_hub_content + ): + """Verify that when no sequence_length is provided, existing behavior is unchanged.""" + mock_session = Mock() + mock_session.boto_session.region_name = "us-east-1" + mock_s3 = Mock() + mock_s3.get_object.return_value = { + "Body": Mock(read=Mock(return_value=b'{"max_length": {"default": 4096}}')) + } + mock_session.boto_session.client.return_value = mock_s3 + + mock_get_hub_content.return_value = { + "hub_content_arn": "arn:aws:sagemaker:us-east-1:123456789012:model/test-model", + "hub_content_document": { + "GatedBucket": False, + "RecipeCollection": [ + { + "CustomizationTechnique": "SFT", + "SmtjRecipeTemplateS3Uri": "s3://bucket/template-first.json", + "SmtjOverrideParamsS3Uri": "s3://bucket/params-first.json", + "Peft": True, + "SequenceLength": "4K", + }, + { + "CustomizationTechnique": "SFT", + "SmtjRecipeTemplateS3Uri": "s3://bucket/template-second.json", + "SmtjOverrideParamsS3Uri": "s3://bucket/params-second.json", + "Peft": True, + "SequenceLength": "32K", + }, + ], + }, + } + + result = _get_fine_tuning_options_and_model_arn("test-model", "SFT", "LORA", mock_session) + + if result is not None: + options, model_arn, is_gated_model = result + # Without sequence_length, should pick the first matching recipe + mock_s3.get_object.assert_called_once() + call_args = mock_s3.get_object.call_args[1] + assert "params-first" in call_args["Key"]