From d6488a439047ff7ab88c6047feae53776688a594 Mon Sep 17 00:00:00 2001 From: guanweim Date: Mon, 4 May 2026 19:03:03 +0000 Subject: [PATCH 1/7] feat(train): Add SequenceLength support for SFT, DPO, RLVR, RLAIF trainers Add optional sequence_length parameter to all four trainers that enables customers to specify their desired context length for serverless training jobs. The parameter is passed in ServerlessJobConfig for recipe filtering. During trainer initialization, _get_fine_tuning_options_and_model_arn filters recipes by SequenceLength field, picking the smallest recipe with context length >= the requested value. Raises ValueError if no sufficient recipe exists or if recipes lack SequenceLength metadata. Changes: - ServerlessJobConfig: add sequence_length field - _parse_context_length: parse values like '8K' to integers - _get_fine_tuning_options_and_model_arn: filter by SequenceLength - _create_serverless_config: conditionally include sequence_length - SFTTrainer, DPOTrainer, RLVRTrainer, RLAIFTrainer: accept and thread sequence_length through init and train methods - Unit tests for all new functionality --- .../src/sagemaker/core/shapes/shapes.py | 3 +- .../train/common_utils/finetune_utils.py | 79 +++++++++++-- .../src/sagemaker/train/dpo_trainer.py | 36 +++--- .../src/sagemaker/train/rlaif_trainer.py | 35 +++--- .../src/sagemaker/train/rlvr_trainer.py | 37 +++--- .../src/sagemaker/train/sft_trainer.py | 35 +++--- .../train/common_utils/test_finetune_utils.py | 105 +++++++++++++++++- .../tests/unit/train/test_dpo_trainer.py | 65 ++++++++++- .../tests/unit/train/test_rlaif_trainer.py | 68 +++++++++++- .../tests/unit/train/test_rlvr_trainer.py | 65 ++++++++++- .../tests/unit/train/test_sft_trainer.py | 65 ++++++++++- 11 files changed, 525 insertions(+), 68 deletions(-) diff --git a/sagemaker-core/src/sagemaker/core/shapes/shapes.py b/sagemaker-core/src/sagemaker/core/shapes/shapes.py index ce25c890dd..5e8217f463 100644 --- 
a/sagemaker-core/src/sagemaker/core/shapes/shapes.py +++ b/sagemaker-core/src/sagemaker/core/shapes/shapes.py @@ -9717,6 +9717,7 @@ class ServerlessJobConfig(Base): peft: The parameter-efficient fine-tuning configuration. evaluation_type: The evaluation job type. Required when serverless job type is Evaluation. evaluator_arn: The evaluator Amazon Resource Name (ARN) used as reward function or reward prompt. + sequence_length: The sequence length for the training job. Valid values are "1K", "2K", "4K", "8K", "16K", "32K", "64K", "128K". """ base_model_arn: StrPipeVar @@ -9726,7 +9727,7 @@ class ServerlessJobConfig(Base): peft: Optional[StrPipeVar] = Unassigned() evaluation_type: Optional[StrPipeVar] = Unassigned() evaluator_arn: Optional[StrPipeVar] = Unassigned() - + sequence_length: Optional[StrPipeVar] = Unassigned() class MlflowConfig(Base): """ diff --git a/sagemaker-train/src/sagemaker/train/common_utils/finetune_utils.py b/sagemaker-train/src/sagemaker/train/common_utils/finetune_utils.py index 6479e803bd..a7f1570e05 100644 --- a/sagemaker-train/src/sagemaker/train/common_utils/finetune_utils.py +++ b/sagemaker-train/src/sagemaker/train/common_utils/finetune_utils.py @@ -407,10 +407,44 @@ def _resolve_model_package_arn(model_package) -> Optional[str]: return None -def _get_fine_tuning_options_and_model_arn(model_name: str, customization_technique: str, training_type, sagemaker_session, - hub_name: Optional[str] = None) -> tuple: +def _parse_context_length(value) -> int: + """Parse a context length value like '8K', '32K', '128K' into an integer (e.g., 8192). + + Returns 0 if value is None or unparseable. 
+ """ + if not value: + return 0 + value = str(value).strip().upper() + if value.endswith("K"): + try: + return int(value[:-1]) * 1024 + except ValueError: + return 0 + try: + return int(value) + except ValueError: + return 0 + + +def _get_fine_tuning_options_and_model_arn( + model_name: str, + customization_technique: str, + training_type, + sagemaker_session, + sequence_length=None, + hub_name: str = "SageMakerPublicHub" +) -> tuple: """Get fine-tuning options and model ARN for given customization technique. + Args: + model_name: Name of the model in the hub. + customization_technique: Technique (e.g., "SFT", "DPO", "RLVR", "RLAIF"). + training_type: TrainingType enum or string ("LORA", "FULL"). + sagemaker_session: SageMaker session for API calls. + sequence_length: Optional sequence length (e.g., "8K"). When provided, filters + recipes by SequenceLength >= the requested value. + hub_name: Hub name (default: "SageMakerPublicHub"). + Returns: tuple: (FineTuningOptions, model_arn, is_gated_model) """ @@ -451,9 +485,34 @@ def _get_fine_tuning_options_and_model_arn(model_name: str, customization_techni # Collect override_params from ALL matching recipes (standard + subscription) recipe = None if (isinstance(training_type, TrainingType) and training_type == TrainingType.LORA) or training_type == "LORA": - recipe = next((r for r in recipes_with_template if r.get("Peft") and not r.get("IsSubscriptionModel")), None) + candidates = [r for r in recipes_with_template if r.get("Peft") and not r.get("IsSubscriptionModel")] elif (isinstance(training_type, TrainingType) and training_type == TrainingType.FULL) or training_type == "FULL": - recipe = next((r for r in recipes_with_template if not r.get("Peft") and not r.get("IsSubscriptionModel")), None) + candidates = [r for r in recipes_with_template if not r.get("Peft") and not r.get("IsSubscriptionModel")] + else: + candidates = [] + + # Filter by SequenceLength if sequence_length is provided + if sequence_length and
candidates: + requested = _parse_context_length(sequence_length) + candidates_with_context = [r for r in candidates if r.get("SequenceLength")] + if candidates_with_context: + filtered = [r for r in candidates_with_context if _parse_context_length(r.get("SequenceLength")) >= requested] + if filtered: + filtered.sort(key=lambda r: _parse_context_length(r.get("SequenceLength"))) + recipe = filtered[0] + else: + available = sorted(set(r.get("SequenceLength") for r in candidates_with_context)) + raise ValueError( + f"No recipes found with SequenceLength >= {sequence_length}. " + f"Available sequence lengths: {available}" + ) + else: + raise ValueError( + f"No recipes found with Smtj for technique: {customization_technique},training_type:{training_type}, " + f"and sequence length:{sequence_length}" + ) + elif candidates: + recipe = candidates[0] if not recipe: raise ValueError(f"No recipes found with Smtj for technique: {customization_technique},training_type:{training_type}") @@ -608,7 +667,8 @@ def _resolve_model_and_name(model, sagemaker_session=None): def _create_serverless_config(model_arn, customization_technique, - training_type, accept_eula, evaluator_arn=None, job_type=JOB_TYPE) -> Optional['ServerlessJobConfig']: + training_type, accept_eula, evaluator_arn=None, + sequence_length=None, job_type=JOB_TYPE) -> Optional['ServerlessJobConfig']: """Create serverless job configuration for fine-tuning. 
Args: @@ -617,6 +677,7 @@ def _create_serverless_config(model_arn, customization_technique, training_type: Training type (TrainingType enum or string) accept_eula: Boolean indicating if EULA is accepted evaluator_arn: Optional evaluator ARN for RLVR/RLAIF + sequence_length: Optional sequence length enum value (e.g., "1K", "2K", "4K", "8K", "16K", "32K", "64K", "128K") job_type: Type of job (default: "FineTuning") Returns: @@ -626,14 +687,18 @@ def _create_serverless_config(model_arn, customization_technique, else (training_type.value if isinstance(training_type, TrainingType) else training_type) # Create ServerlessJobConfig using shapes - serverless_config = ServerlessJobConfig( + config_kwargs = dict( job_type=job_type, base_model_arn=model_arn, customization_technique=customization_technique, peft=peft, evaluator_arn=evaluator_arn, - accept_eula=accept_eula + accept_eula=accept_eula, ) + if sequence_length is not None: + config_kwargs["sequence_length"] = sequence_length + + serverless_config = ServerlessJobConfig(**config_kwargs) return serverless_config diff --git a/sagemaker-train/src/sagemaker/train/dpo_trainer.py b/sagemaker-train/src/sagemaker/train/dpo_trainer.py index bd5d9a11bd..8e3bc17d5e 100644 --- a/sagemaker-train/src/sagemaker/train/dpo_trainer.py +++ b/sagemaker-train/src/sagemaker/train/dpo_trainer.py @@ -100,6 +100,10 @@ class DPOTrainer(BaseTrainer): stopping_condition (Optional[StoppingCondition]): The stopping condition to override training runtime limit. If not specified, uses SageMaker service default (24 hours for serverless training). + sequence_length (Optional[str]): + The sequence length for the training job. Valid values are + "1K", "2K", "4K", "8K", "16K", "32K", "64K", "128K". + If not specified, the service will use default recipe selection behavior. 
""" def __init__( self, @@ -116,6 +120,7 @@ def __init__( networking: Optional[VpcConfig] = None, accept_eula: bool = False, stopping_condition: Optional[StoppingCondition] = None, + sequence_length: Optional[str] = None, **kwargs, ): super().__init__(**kwargs) @@ -134,16 +139,17 @@ def __init__( self.kms_key_id = kms_key_id self.networking = networking self.stopping_condition = stopping_condition + self.sequence_length = sequence_length # Initialize fine-tuning options with beta session fallback - self.hyperparameters, self._model_arn, is_gated_model = _get_fine_tuning_options_and_model_arn(self._model_name, - CustomizationTechnique.DPO.value, - self.training_type, - self.sagemaker_session or TrainDefaults.get_sagemaker_session( - sagemaker_session=self.sagemaker_session - - )) - + self.hyperparameters, self._model_arn, is_gated_model = _get_fine_tuning_options_and_model_arn( + self._model_name, + CustomizationTechnique.DPO.value, + self.training_type, + self.sagemaker_session or TrainDefaults.get_sagemaker_session(sagemaker_session=self.sagemaker_session), + sequence_length=self.sequence_length + ) + # Process hyperparameters self._process_hyperparameters() @@ -227,12 +233,14 @@ def train(self, kms_key_id=self.kms_key_id ) - serverless_config = _create_serverless_config(model_arn=self._model_arn, - customization_technique=CustomizationTechnique.DPO.value, - training_type=self.training_type, - accept_eula=self.accept_eula, - job_type=JOB_TYPE - ) + serverless_config = _create_serverless_config( + model_arn=self._model_arn, + customization_technique=CustomizationTechnique.DPO.value, + training_type=self.training_type, + accept_eula=self.accept_eula, + sequence_length=self.sequence_length, + job_type=JOB_TYPE + ) mlflow_config = _create_mlflow_config( sagemaker_session, diff --git a/sagemaker-train/src/sagemaker/train/rlaif_trainer.py b/sagemaker-train/src/sagemaker/train/rlaif_trainer.py index f2d8460989..5d782d8fa3 100644 --- 
a/sagemaker-train/src/sagemaker/train/rlaif_trainer.py +++ b/sagemaker-train/src/sagemaker/train/rlaif_trainer.py @@ -114,6 +114,10 @@ class RLAIFTrainer(BaseTrainer): stopping_condition (Optional[StoppingCondition]): The stopping condition to override training runtime limit. If not specified, uses SageMaker service default (24 hours for serverless training). + sequence_length (Optional[str]): + The sequence length for the training job. Valid values are + "1K", "2K", "4K", "8K", "16K", "32K", "64K", "128K". + If not specified, the service will use default recipe selection behavior. """ def __init__( @@ -135,6 +139,7 @@ def __init__( networking: Optional[VpcConfig] = None, accept_eula: bool = False, stopping_condition: Optional[StoppingCondition] = None, + sequence_length: Optional[str] = None, **kwargs, ): super().__init__(**kwargs) @@ -156,14 +161,16 @@ def __init__( self.kms_key_id = kms_key_id self.networking = networking self.stopping_condition = stopping_condition + self.sequence_length = sequence_length # Initialize fine-tuning options with beta session fallback - self.hyperparameters, self._model_arn, is_gated_model = _get_fine_tuning_options_and_model_arn(self._model_name, - CustomizationTechnique.RLAIF.value, - self.training_type, - self.sagemaker_session or TrainDefaults.get_sagemaker_session( - sagemaker_session=self.sagemaker_session - )) + self.hyperparameters, self._model_arn, is_gated_model = _get_fine_tuning_options_and_model_arn( + self._model_name, + CustomizationTechnique.RLAIF.value, + self.training_type, + self.sagemaker_session or TrainDefaults.get_sagemaker_session(sagemaker_session=self.sagemaker_session), + sequence_length=self.sequence_length + ) # Validate and set EULA acceptance self.accept_eula = _validate_eula_for_gated_model(model, accept_eula, is_gated_model) @@ -242,13 +249,15 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None, validati ) evaluator_arn = getattr(self, '_evaluator_arn', None) - 
serverless_config = _create_serverless_config(model_arn=self._model_arn, - customization_technique=CustomizationTechnique.RLAIF.value, - training_type=self.training_type, - accept_eula=self.accept_eula, - evaluator_arn=evaluator_arn, - job_type=JOB_TYPE - ) + serverless_config = _create_serverless_config( + model_arn=self._model_arn, + customization_technique=CustomizationTechnique.RLAIF.value, + training_type=self.training_type, + accept_eula=self.accept_eula, + evaluator_arn=evaluator_arn, + sequence_length=self.sequence_length, + job_type=JOB_TYPE + ) mlflow_config = _create_mlflow_config( sagemaker_session, diff --git a/sagemaker-train/src/sagemaker/train/rlvr_trainer.py b/sagemaker-train/src/sagemaker/train/rlvr_trainer.py index 333a93fc55..53029155f2 100644 --- a/sagemaker-train/src/sagemaker/train/rlvr_trainer.py +++ b/sagemaker-train/src/sagemaker/train/rlvr_trainer.py @@ -106,6 +106,10 @@ class RLVRTrainer(BaseTrainer): stopping_condition (Optional[StoppingCondition]): The stopping condition to override training runtime limit. If not specified, uses SageMaker service default (24 hours for serverless training). + sequence_length (Optional[str]): + The sequence length for the training job. Valid values are + "1K", "2K", "4K", "8K", "16K", "32K", "64K", "128K". + If not specified, the service will use default recipe selection behavior. 
""" def __init__( @@ -126,6 +130,7 @@ def __init__( networking: Optional[VpcConfig] = None, accept_eula: bool = False, stopping_condition: Optional[StoppingCondition] = None, + sequence_length: Optional[str] = None, **kwargs, ): super().__init__(**kwargs) @@ -146,15 +151,17 @@ def __init__( self.kms_key_id = kms_key_id self.networking = networking self.stopping_condition = stopping_condition + self.sequence_length = sequence_length # Initialize fine-tuning options with beta session fallback - self.hyperparameters, self._model_arn, is_gated_model = _get_fine_tuning_options_and_model_arn(self._model_name, - CustomizationTechnique.RLVR.value, - self.training_type, - self.sagemaker_session or TrainDefaults.get_sagemaker_session( - sagemaker_session=self.sagemaker_session - )) - + self.hyperparameters, self._model_arn, is_gated_model = _get_fine_tuning_options_and_model_arn( + self._model_name, + CustomizationTechnique.RLVR.value, + self.training_type, + self.sagemaker_session or TrainDefaults.get_sagemaker_session(sagemaker_session=self.sagemaker_session), + sequence_length=self.sequence_length + ) + # Remove constructor-handled hyperparameters self._process_hyperparameters() @@ -233,13 +240,15 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None, # Extract and validate evaluator ARN evaluator_arn = _extract_evaluator_arn(self.custom_reward_function) if self.custom_reward_function else None - serverless_config = _create_serverless_config(model_arn=self._model_arn, - customization_technique=CustomizationTechnique.RLVR.value, - training_type=self.training_type, - accept_eula=self.accept_eula, - evaluator_arn=evaluator_arn, - job_type=JOB_TYPE - ) + serverless_config = _create_serverless_config( + model_arn=self._model_arn, + customization_technique=CustomizationTechnique.RLVR.value, + training_type=self.training_type, + accept_eula=self.accept_eula, + evaluator_arn=evaluator_arn, + sequence_length=self.sequence_length, + job_type=JOB_TYPE + ) 
mlflow_config = _create_mlflow_config( sagemaker_session, mlflow_resource_arn=self.mlflow_resource_arn, diff --git a/sagemaker-train/src/sagemaker/train/sft_trainer.py b/sagemaker-train/src/sagemaker/train/sft_trainer.py index 233f169d0f..e2193f0b9b 100644 --- a/sagemaker-train/src/sagemaker/train/sft_trainer.py +++ b/sagemaker-train/src/sagemaker/train/sft_trainer.py @@ -102,6 +102,10 @@ class SFTTrainer(BaseTrainer): stopping_condition (Optional[StoppingCondition]): The stopping condition to override training runtime limit. If not specified, uses SageMaker service default (24 hours for serverless training). + sequence_length (Optional[str]): + The sequence length for the training job. Valid values are + "1K", "2K", "4K", "8K", "16K", "32K", "64K", "128K". + If not specified, the service will use default recipe selection behavior. """ def __init__( @@ -119,6 +123,7 @@ def __init__( networking: Optional[VpcConfig] = None, accept_eula: Optional[bool] = False, stopping_condition: Optional[StoppingCondition] = None, + sequence_length: Optional[str] = None, **kwargs, ): super().__init__(**kwargs) @@ -138,15 +143,17 @@ def __init__( self.kms_key_id = kms_key_id self.networking = networking self.stopping_condition = stopping_condition + self.sequence_length = sequence_length # Initialize fine-tuning options with beta session fallback - self.hyperparameters, self._model_arn, is_gated_model = _get_fine_tuning_options_and_model_arn(self._model_name, - CustomizationTechnique.SFT.value, - self.training_type, - self.sagemaker_session or TrainDefaults.get_sagemaker_session( - sagemaker_session=self.sagemaker_session - )) - + self.hyperparameters, self._model_arn, is_gated_model = _get_fine_tuning_options_and_model_arn( + self._model_name, + CustomizationTechnique.SFT.value, + self.training_type, + self.sagemaker_session or TrainDefaults.get_sagemaker_session(sagemaker_session=self.sagemaker_session), + sequence_length=self.sequence_length + ) + # Process hyperparameters 
self._process_hyperparameters() @@ -225,12 +232,14 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None, validati kms_key_id=self.kms_key_id ) - serverless_config = _create_serverless_config(model_arn=self._model_arn, - customization_technique=CustomizationTechnique.SFT.value, - training_type=self.training_type, - accept_eula=self.accept_eula, - job_type=JOB_TYPE - ) + serverless_config = _create_serverless_config( + model_arn=self._model_arn, + customization_technique=CustomizationTechnique.SFT.value, + training_type=self.training_type, + accept_eula=self.accept_eula, + sequence_length=self.sequence_length, + job_type=JOB_TYPE + ) mlflow_config = _create_mlflow_config( sagemaker_session, mlflow_resource_arn=self.mlflow_resource_arn, diff --git a/sagemaker-train/tests/unit/train/common_utils/test_finetune_utils.py b/sagemaker-train/tests/unit/train/common_utils/test_finetune_utils.py index c98dea477f..64685a8b54 100644 --- a/sagemaker-train/tests/unit/train/common_utils/test_finetune_utils.py +++ b/sagemaker-train/tests/unit/train/common_utils/test_finetune_utils.py @@ -27,9 +27,11 @@ _create_mlflow_config, _validate_eula_for_gated_model, _validate_model_region_availability, - _validate_s3_path_exists + _validate_s3_path_exists, + _parse_context_length ) -from sagemaker.core.resources import ModelPackage, ModelPackageGroup +from sagemaker.core.resources import ModelPackage +from sagemaker.core.utils.utils import Unassigned, ModelPackageGroup from sagemaker.ai_registry.dataset import DataSet from sagemaker.train.common import TrainingType from sagemaker.train.configs import InputData @@ -465,6 +467,7 @@ def test__convert_input_data_to_channels(self): def test__validate_eula_for_gated_model_with_model_package(self): """Test EULA validation returns True for ModelPackage input""" from sagemaker.core.resources import ModelPackage +from sagemaker.core.utils.utils import Unassigned model_package = Mock(spec=ModelPackage) result = 
_validate_eula_for_gated_model(model_package, False, True) @@ -864,3 +867,101 @@ def test__get_fine_tuning_options_subscription_enabled_but_not_subscribed(self, # Should still have standard params, just not datamix ones assert "max_steps" in options._specs assert "customer_data_percent" not in options._specs + + def test__create_serverless_config_with_sequence_length(self): + config = _create_serverless_config("model-arn", "SFT", TrainingType.LORA, accept_eula=True, sequence_length="8K") + + assert config.sequence_length == "8K" + assert config.base_model_arn == "model-arn" + + def test__create_serverless_config_without_sequence_length(self): + config = _create_serverless_config("model-arn", "SFT", TrainingType.LORA, accept_eula=True) + + # sequence_length should remain Unassigned (not set), not None + assert isinstance(config.sequence_length, Unassigned) + + def test__parse_context_length_with_k_suffix(self): + assert _parse_context_length("8K") == 8192 + assert _parse_context_length("32K") == 32768 + assert _parse_context_length("128K") == 131072 + + def test__parse_context_length_with_lowercase(self): + assert _parse_context_length("8k") == 8192 + + def test__parse_context_length_with_integer(self): + assert _parse_context_length("4096") == 4096 + + def test__parse_context_length_with_none(self): + assert _parse_context_length(None) == 0 + + def test__parse_context_length_with_empty(self): + assert _parse_context_length("") == 0 + + @patch('sagemaker.train.common_utils.finetune_utils._get_hub_content_metadata') + @patch('boto3.client') + def test__get_fine_tuning_options_filters_by_sequence_length(self, mock_boto_client, mock_get_hub_content): + mock_session = Mock() + mock_session.boto_session.region_name = "us-east-1" + + mock_get_hub_content.return_value = { + 'hub_content_arn': "arn:aws:sagemaker:us-east-1:123456789012:model/test-model", + 'hub_content_document': { + "GatedBucket": False, + "RecipeCollection": [ + { + "CustomizationTechnique": "SFT", + 
"SmtjRecipeTemplateS3Uri": "s3://bucket/template-4k.json", + "SmtjOverrideParamsS3Uri": "s3://bucket/params-4k.json", + "Peft": True, + "SequenceLength": "4K" + }, + { + "CustomizationTechnique": "SFT", + "SmtjRecipeTemplateS3Uri": "s3://bucket/template-32k.json", + "SmtjOverrideParamsS3Uri": "s3://bucket/params-32k.json", + "Peft": True, + "SequenceLength": "32K" + } + ] + } + } + + mock_s3_client = Mock() + mock_boto_client.return_value = mock_s3_client + mock_s3_client.get_object.return_value = { + "Body": Mock(read=Mock(return_value=b'{"max_length": {"default": 32768}}')) + } + + result = _get_fine_tuning_options_and_model_arn("test-model", "SFT", "LORA", mock_session, sequence_length="8K") + + if result is not None: + options, model_arn, is_gated_model = result + # Should pick the 32K recipe (smallest >= 8K) + mock_s3_client.get_object.assert_called_once() + call_args = mock_s3_client.get_object.call_args[1] + assert "params-32k" in call_args["Key"] + + @patch('sagemaker.train.common_utils.finetune_utils._get_hub_content_metadata') + def test__get_fine_tuning_options_raises_when_no_sufficient_context_length(self, mock_get_hub_content): + mock_session = Mock() + mock_session.boto_session.region_name = "us-east-1" + + mock_get_hub_content.return_value = { + 'hub_content_arn': "arn:aws:sagemaker:us-east-1:123456789012:model/test-model", + 'hub_content_document': { + "GatedBucket": False, + "RecipeCollection": [ + { + "CustomizationTechnique": "SFT", + "SmtjRecipeTemplateS3Uri": "s3://bucket/template-4k.json", + "SmtjOverrideParamsS3Uri": "s3://bucket/params-4k.json", + "Peft": True, + "SequenceLength": "4K" + } + ] + } + } + + # Requesting 128K but only 4K available — should raise + with pytest.raises(ValueError, match="No recipes found with SequenceLength >= 128K"): + _get_fine_tuning_options_and_model_arn("test-model", "SFT", "LORA", mock_session, sequence_length="128K") diff --git a/sagemaker-train/tests/unit/train/test_dpo_trainer.py 
b/sagemaker-train/tests/unit/train/test_dpo_trainer.py index 1b70e0bf89..7648b46e35 100644 --- a/sagemaker-train/tests/unit/train/test_dpo_trainer.py +++ b/sagemaker-train/tests/unit/train/test_dpo_trainer.py @@ -506,4 +506,67 @@ def test_train_wait_false_skips_wait(self, mock_training_job_create, mock_model_ trainer = DPOTrainer(model="test-model", model_package_group="test-group", training_dataset="s3://bucket/train") trainer.train(wait=False, wait_timeout=600) - mock_wait.assert_not_called() \ No newline at end of file + mock_wait.assert_not_called() + + @patch('sagemaker.train.dpo_trainer._validate_and_resolve_model_package_group') + @patch('sagemaker.train.dpo_trainer._get_fine_tuning_options_and_model_arn') + def test_init_sequence_length_default_none(self, mock_finetuning_options, mock_validate_group): + mock_validate_group.return_value = "test-group" + mock_hyperparams = Mock() + mock_hyperparams.to_dict.return_value = {} + mock_finetuning_options.return_value = (mock_hyperparams, "model-arn", False) + trainer = DPOTrainer(model="test-model", model_package_group="test-group") + assert trainer.sequence_length is None + + @patch('sagemaker.train.dpo_trainer._validate_and_resolve_model_package_group') + @patch('sagemaker.train.dpo_trainer._get_fine_tuning_options_and_model_arn') + def test_init_with_sequence_length(self, mock_finetuning_options, mock_validate_group): + mock_validate_group.return_value = "test-group" + mock_hyperparams = Mock() + mock_hyperparams.to_dict.return_value = {} + mock_finetuning_options.return_value = (mock_hyperparams, "model-arn", False) + trainer = DPOTrainer(model="test-model", model_package_group="test-group", sequence_length="8K") + assert trainer.sequence_length == "8K" + + @patch('sagemaker.train.dpo_trainer._resolve_model_and_name') + @patch('sagemaker.train.dpo_trainer._get_fine_tuning_options_and_model_arn') + @patch('sagemaker.train.dpo_trainer.TrainDefaults.get_role') + 
@patch('sagemaker.train.dpo_trainer.TrainDefaults.get_sagemaker_session') + @patch('sagemaker.train.dpo_trainer._get_unique_name') + @patch('sagemaker.train.dpo_trainer._validate_and_resolve_model_package_group') + @patch('sagemaker.train.dpo_trainer._create_input_data_config') + @patch('sagemaker.train.dpo_trainer._convert_input_data_to_channels') + @patch('sagemaker.train.dpo_trainer._create_output_config') + @patch('sagemaker.train.dpo_trainer._create_serverless_config') + @patch('sagemaker.train.dpo_trainer._create_mlflow_config') + @patch('sagemaker.train.dpo_trainer._create_model_package_config') + @patch('sagemaker.core.resources.TrainingJob.create') + def test_train_passes_sequence_length_to_serverless_config(self, mock_training_job_create, + mock_model_package_config, mock_mlflow_config, mock_serverless_config, + mock_output_config, mock_convert_channels, mock_input_config, + mock_validate_group, mock_unique_name, mock_get_sagemaker_session, + mock_get_role, mock_get_options, mock_resolve_model): + mock_validate_group.return_value = "test-group" + mock_resolve_model.return_value = ("test-model", "test-model") + mock_get_sagemaker_session.return_value = Mock() + mock_fine_tuning_options = Mock() + mock_fine_tuning_options.to_dict.return_value = {} + mock_get_options.return_value = (mock_fine_tuning_options, "model-arn", False) + mock_get_role.return_value = "test-role" + mock_unique_name.return_value = "test-job-name" + mock_input_config.return_value = [Mock()] + mock_convert_channels.return_value = [Mock()] + mock_output_config.return_value = Mock() + mock_serverless_config.return_value = Mock() + mock_mlflow_config.return_value = Mock() + mock_model_package_config.return_value = Mock() + mock_training_job = Mock() + mock_training_job_create.return_value = mock_training_job + + trainer = DPOTrainer(model="test-model", model_package_group="test-group", + training_dataset="s3://bucket/train", sequence_length="16K") + trainer.train(wait=False) + + 
mock_serverless_config.assert_called_once() + call_kwargs = mock_serverless_config.call_args[1] + assert call_kwargs["sequence_length"] == "16K" diff --git a/sagemaker-train/tests/unit/train/test_rlaif_trainer.py b/sagemaker-train/tests/unit/train/test_rlaif_trainer.py index e5666883e8..6811c45540 100644 --- a/sagemaker-train/tests/unit/train/test_rlaif_trainer.py +++ b/sagemaker-train/tests/unit/train/test_rlaif_trainer.py @@ -682,4 +682,70 @@ def test_train_wait_false_skips_wait(self, mock_training_job_create, mock_model_ trainer = RLAIFTrainer(model="test-model", model_package_group="test-group", training_dataset="s3://bucket/train") trainer.train(wait=False, wait_timeout=600) - mock_wait.assert_not_called() \ No newline at end of file + mock_wait.assert_not_called() + + @patch('sagemaker.train.rlaif_trainer._validate_and_resolve_model_package_group') + @patch('sagemaker.train.rlaif_trainer._get_fine_tuning_options_and_model_arn') + def test_init_sequence_length_default_none(self, mock_finetuning_options, mock_validate_group): + mock_validate_group.return_value = "test-group" + mock_hyperparams = Mock() + mock_hyperparams.to_dict.return_value = {} + mock_hyperparams._specs = {} + mock_finetuning_options.return_value = (mock_hyperparams, "model-arn", False) + trainer = RLAIFTrainer(model="test-model", model_package_group="test-group") + assert trainer.sequence_length is None + + @patch('sagemaker.train.rlaif_trainer._validate_and_resolve_model_package_group') + @patch('sagemaker.train.rlaif_trainer._get_fine_tuning_options_and_model_arn') + def test_init_with_sequence_length(self, mock_finetuning_options, mock_validate_group): + mock_validate_group.return_value = "test-group" + mock_hyperparams = Mock() + mock_hyperparams.to_dict.return_value = {} + mock_hyperparams._specs = {} + mock_finetuning_options.return_value = (mock_hyperparams, "model-arn", False) + trainer = RLAIFTrainer(model="test-model", model_package_group="test-group", sequence_length="128K") + 
assert trainer.sequence_length == "128K" + + @patch('sagemaker.train.rlaif_trainer._resolve_model_and_name') + @patch('sagemaker.train.rlaif_trainer._get_fine_tuning_options_and_model_arn') + @patch('sagemaker.train.rlaif_trainer.TrainDefaults.get_role') + @patch('sagemaker.train.rlaif_trainer.TrainDefaults.get_sagemaker_session') + @patch('sagemaker.train.rlaif_trainer._get_unique_name') + @patch('sagemaker.train.rlaif_trainer._validate_and_resolve_model_package_group') + @patch('sagemaker.train.rlaif_trainer._create_input_data_config') + @patch('sagemaker.train.rlaif_trainer._convert_input_data_to_channels') + @patch('sagemaker.train.rlaif_trainer._create_output_config') + @patch('sagemaker.train.rlaif_trainer._create_serverless_config') + @patch('sagemaker.train.rlaif_trainer._create_mlflow_config') + @patch('sagemaker.train.rlaif_trainer._create_model_package_config') + @patch('sagemaker.core.resources.TrainingJob.create') + def test_train_passes_sequence_length_to_serverless_config(self, mock_training_job_create, + mock_model_package_config, mock_mlflow_config, mock_serverless_config, + mock_output_config, mock_convert_channels, mock_input_config, + mock_validate_group, mock_unique_name, mock_get_sagemaker_session, + mock_get_role, mock_get_options, mock_resolve_model): + mock_validate_group.return_value = "test-group" + mock_resolve_model.return_value = ("test-model", "test-model") + mock_get_sagemaker_session.return_value = Mock() + mock_fine_tuning_options = Mock() + mock_fine_tuning_options.to_dict.return_value = {} + mock_fine_tuning_options._specs = {} + mock_get_options.return_value = (mock_fine_tuning_options, "model-arn", False) + mock_get_role.return_value = "test-role" + mock_unique_name.return_value = "test-job-name" + mock_input_config.return_value = [Mock()] + mock_convert_channels.return_value = [Mock()] + mock_output_config.return_value = Mock() + mock_serverless_config.return_value = Mock() + mock_mlflow_config.return_value = Mock() + 
mock_model_package_config.return_value = Mock() + mock_training_job = Mock() + mock_training_job_create.return_value = mock_training_job + + trainer = RLAIFTrainer(model="test-model", model_package_group="test-group", + training_dataset="s3://bucket/train", sequence_length="64K") + trainer.train(wait=False) + + mock_serverless_config.assert_called_once() + call_kwargs = mock_serverless_config.call_args[1] + assert call_kwargs["sequence_length"] == "64K" diff --git a/sagemaker-train/tests/unit/train/test_rlvr_trainer.py b/sagemaker-train/tests/unit/train/test_rlvr_trainer.py index 320b81555d..b4c01385e2 100644 --- a/sagemaker-train/tests/unit/train/test_rlvr_trainer.py +++ b/sagemaker-train/tests/unit/train/test_rlvr_trainer.py @@ -509,4 +509,67 @@ def test_train_wait_false_skips_wait(self, mock_training_job_create, mock_model_ trainer = RLVRTrainer(model="test-model", model_package_group="test-group", training_dataset="s3://bucket/train") trainer.train(wait=False, wait_timeout=600) - mock_wait.assert_not_called() \ No newline at end of file + mock_wait.assert_not_called() + + @patch('sagemaker.train.rlvr_trainer._validate_and_resolve_model_package_group') + @patch('sagemaker.train.rlvr_trainer._get_fine_tuning_options_and_model_arn') + def test_init_sequence_length_default_none(self, mock_finetuning_options, mock_validate_group): + mock_validate_group.return_value = "test-group" + mock_hyperparams = Mock() + mock_hyperparams.to_dict.return_value = {} + mock_finetuning_options.return_value = (mock_hyperparams, "model-arn", False) + trainer = RLVRTrainer(model="test-model", model_package_group="test-group") + assert trainer.sequence_length is None + + @patch('sagemaker.train.rlvr_trainer._validate_and_resolve_model_package_group') + @patch('sagemaker.train.rlvr_trainer._get_fine_tuning_options_and_model_arn') + def test_init_with_sequence_length(self, mock_finetuning_options, mock_validate_group): + mock_validate_group.return_value = "test-group" + mock_hyperparams = 
Mock() + mock_hyperparams.to_dict.return_value = {} + mock_finetuning_options.return_value = (mock_hyperparams, "model-arn", False) + trainer = RLVRTrainer(model="test-model", model_package_group="test-group", sequence_length="32K") + assert trainer.sequence_length == "32K" + + @patch('sagemaker.train.rlvr_trainer._resolve_model_and_name') + @patch('sagemaker.train.rlvr_trainer._get_fine_tuning_options_and_model_arn') + @patch('sagemaker.train.rlvr_trainer.TrainDefaults.get_role') + @patch('sagemaker.train.rlvr_trainer.TrainDefaults.get_sagemaker_session') + @patch('sagemaker.train.rlvr_trainer._get_unique_name') + @patch('sagemaker.train.rlvr_trainer._validate_and_resolve_model_package_group') + @patch('sagemaker.train.rlvr_trainer._create_input_data_config') + @patch('sagemaker.train.rlvr_trainer._convert_input_data_to_channels') + @patch('sagemaker.train.rlvr_trainer._create_output_config') + @patch('sagemaker.train.rlvr_trainer._create_serverless_config') + @patch('sagemaker.train.rlvr_trainer._create_mlflow_config') + @patch('sagemaker.train.rlvr_trainer._create_model_package_config') + @patch('sagemaker.core.resources.TrainingJob.create') + def test_train_passes_sequence_length_to_serverless_config(self, mock_training_job_create, + mock_model_package_config, mock_mlflow_config, mock_serverless_config, + mock_output_config, mock_convert_channels, mock_input_config, + mock_validate_group, mock_unique_name, mock_get_sagemaker_session, + mock_get_role, mock_get_options, mock_resolve_model): + mock_validate_group.return_value = "test-group" + mock_resolve_model.return_value = ("test-model", "test-model") + mock_get_sagemaker_session.return_value = Mock() + mock_fine_tuning_options = Mock() + mock_fine_tuning_options.to_dict.return_value = {} + mock_get_options.return_value = (mock_fine_tuning_options, "model-arn", False) + mock_get_role.return_value = "test-role" + mock_unique_name.return_value = "test-job-name" + mock_input_config.return_value = [Mock()] + 
mock_convert_channels.return_value = [Mock()] + mock_output_config.return_value = Mock() + mock_serverless_config.return_value = Mock() + mock_mlflow_config.return_value = Mock() + mock_model_package_config.return_value = Mock() + mock_training_job = Mock() + mock_training_job_create.return_value = mock_training_job + + trainer = RLVRTrainer(model="test-model", model_package_group="test-group", + training_dataset="s3://bucket/train", sequence_length="4K") + trainer.train(wait=False) + + mock_serverless_config.assert_called_once() + call_kwargs = mock_serverless_config.call_args[1] + assert call_kwargs["sequence_length"] == "4K" diff --git a/sagemaker-train/tests/unit/train/test_sft_trainer.py b/sagemaker-train/tests/unit/train/test_sft_trainer.py index 108990f839..01fc21f4bd 100644 --- a/sagemaker-train/tests/unit/train/test_sft_trainer.py +++ b/sagemaker-train/tests/unit/train/test_sft_trainer.py @@ -520,4 +520,67 @@ def test_train_wait_false_skips_wait(self, mock_training_job_create, mock_model_ trainer = SFTTrainer(model="test-model", model_package_group="test-group", training_dataset="s3://bucket/train") trainer.train(wait=False, wait_timeout=600) - mock_wait.assert_not_called() \ No newline at end of file + mock_wait.assert_not_called() + + @patch('sagemaker.train.sft_trainer._validate_and_resolve_model_package_group') + @patch('sagemaker.train.sft_trainer._get_fine_tuning_options_and_model_arn') + def test_init_sequence_length_default_none(self, mock_finetuning_options, mock_validate_group): + mock_validate_group.return_value = "test-group" + mock_hyperparams = Mock() + mock_hyperparams.to_dict.return_value = {} + mock_finetuning_options.return_value = (mock_hyperparams, "model-arn", False) + trainer = SFTTrainer(model="test-model", model_package_group="test-group") + assert trainer.sequence_length is None + + @patch('sagemaker.train.sft_trainer._validate_and_resolve_model_package_group') + 
@patch('sagemaker.train.sft_trainer._get_fine_tuning_options_and_model_arn') + def test_init_with_sequence_length(self, mock_finetuning_options, mock_validate_group): + mock_validate_group.return_value = "test-group" + mock_hyperparams = Mock() + mock_hyperparams.to_dict.return_value = {} + mock_finetuning_options.return_value = (mock_hyperparams, "model-arn", False) + trainer = SFTTrainer(model="test-model", model_package_group="test-group", sequence_length="8K") + assert trainer.sequence_length == "8K" + + @patch('sagemaker.train.sft_trainer._resolve_model_and_name') + @patch('sagemaker.train.sft_trainer._get_fine_tuning_options_and_model_arn') + @patch('sagemaker.train.sft_trainer.TrainDefaults.get_role') + @patch('sagemaker.train.sft_trainer.TrainDefaults.get_sagemaker_session') + @patch('sagemaker.train.sft_trainer._get_unique_name') + @patch('sagemaker.train.sft_trainer._validate_and_resolve_model_package_group') + @patch('sagemaker.train.sft_trainer._create_input_data_config') + @patch('sagemaker.train.sft_trainer._convert_input_data_to_channels') + @patch('sagemaker.train.sft_trainer._create_output_config') + @patch('sagemaker.train.sft_trainer._create_serverless_config') + @patch('sagemaker.train.sft_trainer._create_mlflow_config') + @patch('sagemaker.train.sft_trainer._create_model_package_config') + @patch('sagemaker.core.resources.TrainingJob.create') + def test_train_passes_sequence_length_to_serverless_config(self, mock_training_job_create, + mock_model_package_config, mock_mlflow_config, mock_serverless_config, + mock_output_config, mock_convert_channels, mock_input_config, + mock_validate_group, mock_unique_name, mock_get_sagemaker_session, + mock_get_role, mock_get_options, mock_resolve_model): + mock_validate_group.return_value = "test-group" + mock_resolve_model.return_value = ("test-model", "test-model") + mock_get_sagemaker_session.return_value = Mock() + mock_fine_tuning_options = Mock() + mock_fine_tuning_options.to_dict.return_value = {} + 
mock_get_options.return_value = (mock_fine_tuning_options, "model-arn", False) + mock_get_role.return_value = "test-role" + mock_unique_name.return_value = "test-job-name" + mock_input_config.return_value = [Mock()] + mock_convert_channels.return_value = [Mock()] + mock_output_config.return_value = Mock() + mock_serverless_config.return_value = Mock() + mock_mlflow_config.return_value = Mock() + mock_model_package_config.return_value = Mock() + mock_training_job = Mock() + mock_training_job_create.return_value = mock_training_job + + trainer = SFTTrainer(model="test-model", model_package_group="test-group", + training_dataset="s3://bucket/train", sequence_length="16K") + trainer.train(wait=False) + + mock_serverless_config.assert_called_once() + call_kwargs = mock_serverless_config.call_args[1] + assert call_kwargs["sequence_length"] == "16K" From d3f3af0af0bba27487915cccce2b16b367590b20 Mon Sep 17 00:00:00 2001 From: guanweim Date: Thu, 28 May 2026 00:45:54 +0000 Subject: [PATCH 2/7] fix: use codegen for SequenceLength shape instead of manual shapes.py edit Add SequenceLength to service-2.json and regenerate shapes.py via codegen (python -m sagemaker.core.tools.codegen) instead of editing shapes.py manually. --- .../sample/sagemaker/2017-07-24/service-2.json | 17 +++++++++++++++++ .../src/sagemaker/core/shapes/shapes.py | 3 ++- .../core/utils/code_injection/shape_dag.py | 1 + 3 files changed, 20 insertions(+), 1 deletion(-) diff --git a/sagemaker-core/sample/sagemaker/2017-07-24/service-2.json b/sagemaker-core/sample/sagemaker/2017-07-24/service-2.json index ceb4f316dc..6b551c5fd6 100644 --- a/sagemaker-core/sample/sagemaker/2017-07-24/service-2.json +++ b/sagemaker-core/sample/sagemaker/2017-07-24/service-2.json @@ -44132,6 +44132,10 @@ "EvaluatorArn":{ "shape":"EvaluatorArn", "documentation":"

The evaluator Amazon Resource Name (ARN) used as reward function or reward prompt.

" + }, + "SequenceLength":{ + "shape":"SequenceLength", + "documentation":"

The sequence length for the training job.

" } }, "documentation":"

The configuration for the serverless training job.

" @@ -44143,6 +44147,19 @@ "Evaluation" ] }, + "SequenceLength":{ + "type":"string", + "enum":[ + "1K", + "2K", + "4K", + "8K", + "16K", + "32K", + "64K", + "128K" + ] + }, "ServerlessMaxConcurrency":{ "type":"integer", "box":true, diff --git a/sagemaker-core/src/sagemaker/core/shapes/shapes.py b/sagemaker-core/src/sagemaker/core/shapes/shapes.py index 5e8217f463..2aa5f2afe8 100644 --- a/sagemaker-core/src/sagemaker/core/shapes/shapes.py +++ b/sagemaker-core/src/sagemaker/core/shapes/shapes.py @@ -9717,7 +9717,7 @@ class ServerlessJobConfig(Base): peft: The parameter-efficient fine-tuning configuration. evaluation_type: The evaluation job type. Required when serverless job type is Evaluation. evaluator_arn: The evaluator Amazon Resource Name (ARN) used as reward function or reward prompt. - sequence_length: The sequence length for the training job. Valid values are "1K", "2K", "4K", "8K", "16K", "32K", "64K", "128K". + sequence_length: The sequence length for the training job. """ base_model_arn: StrPipeVar @@ -9729,6 +9729,7 @@ class ServerlessJobConfig(Base): evaluator_arn: Optional[StrPipeVar] = Unassigned() sequence_length: Optional[StrPipeVar] = Unassigned() + class MlflowConfig(Base): """ MlflowConfig diff --git a/sagemaker-core/src/sagemaker/core/utils/code_injection/shape_dag.py b/sagemaker-core/src/sagemaker/core/utils/code_injection/shape_dag.py index 5d0de63efd..977fe6889f 100644 --- a/sagemaker-core/src/sagemaker/core/utils/code_injection/shape_dag.py +++ b/sagemaker-core/src/sagemaker/core/utils/code_injection/shape_dag.py @@ -16206,6 +16206,7 @@ {"name": "Peft", "shape": "Peft", "type": "string"}, {"name": "EvaluationType", "shape": "EvaluationType", "type": "string"}, {"name": "EvaluatorArn", "shape": "EvaluatorArn", "type": "string"}, + {"name": "SequenceLength", "shape": "SequenceLength", "type": "string"}, ], "type": "structure", }, From 8032756d8ef7bdb1ff8413e5ea2053da9c5e4afd Mon Sep 17 00:00:00 2001 From: guanweim Date: Thu, 28 May 2026 
17:14:21 +0000 Subject: [PATCH 3/7] refactor: preserve original recipe selection path when sequence_length not provided Keep the existing `next(...)` logic untouched for the default case (no sequence_length). Only build the candidates list and filter when sequence_length is explicitly requested, ensuring zero behavioral change for existing callers. --- .../train/common_utils/finetune_utils.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/sagemaker-train/src/sagemaker/train/common_utils/finetune_utils.py b/sagemaker-train/src/sagemaker/train/common_utils/finetune_utils.py index a7f1570e05..e99698cef8 100644 --- a/sagemaker-train/src/sagemaker/train/common_utils/finetune_utils.py +++ b/sagemaker-train/src/sagemaker/train/common_utils/finetune_utils.py @@ -485,14 +485,19 @@ def _get_fine_tuning_options_and_model_arn( # Collect override_params from ALL matching recipes (standard + subscription) recipe = None if (isinstance(training_type, TrainingType) and training_type == TrainingType.LORA) or training_type == "LORA": - candidates = [r for r in recipes_with_template if r.get("Peft") and not r.get("IsSubscriptionModel")] + recipe = next((r for r in recipes_with_template if r.get("Peft") and not r.get("IsSubscriptionModel")), None) elif (isinstance(training_type, TrainingType) and training_type == TrainingType.FULL) or training_type == "FULL": - candidates = [r for r in recipes_with_template if not r.get("Peft") and not r.get("IsSubscriptionModel")] - else: - candidates = [] + recipe = next((r for r in recipes_with_template if not r.get("Peft") and not r.get("IsSubscriptionModel")), None) + + # Override recipe selection when sequence_length is explicitly requested + if sequence_length: + if (isinstance(training_type, TrainingType) and training_type == TrainingType.LORA) or training_type == "LORA": + candidates = [r for r in recipes_with_template if r.get("Peft") and not r.get("IsSubscriptionModel")] + elif 
(isinstance(training_type, TrainingType) and training_type == TrainingType.FULL) or training_type == "FULL": + candidates = [r for r in recipes_with_template if not r.get("Peft") and not r.get("IsSubscriptionModel")] + else: + candidates = [] - # Filter by SequenceLength if sequence_length is provided - if sequence_length and candidates: requested = _parse_context_length(sequence_length) candidates_with_context = [r for r in candidates if r.get("SequenceLength")] if candidates_with_context: @@ -511,8 +516,6 @@ def _get_fine_tuning_options_and_model_arn( f"No recipes found with Smtj for technique: {customization_technique},training_type:{training_type}, " f"and sequence length:{sequence_length}" ) - elif candidates: - recipe = candidates[0] if not recipe: raise ValueError(f"No recipes found with Smtj for technique: {customization_technique},training_type:{training_type}") From 85299bef8aaf73123a82ca26af9f9335d34d3aac Mon Sep 17 00:00:00 2001 From: guanweim Date: Thu, 28 May 2026 17:31:36 +0000 Subject: [PATCH 4/7] fix: correct test imports and mock setup for sequence_length tests --- .../train/common_utils/test_finetune_utils.py | 23 ++++++++----------- 1 file changed, 10 insertions(+), 13 deletions(-) diff --git a/sagemaker-train/tests/unit/train/common_utils/test_finetune_utils.py b/sagemaker-train/tests/unit/train/common_utils/test_finetune_utils.py index 64685a8b54..c2cc5a9bc9 100644 --- a/sagemaker-train/tests/unit/train/common_utils/test_finetune_utils.py +++ b/sagemaker-train/tests/unit/train/common_utils/test_finetune_utils.py @@ -30,8 +30,8 @@ _validate_s3_path_exists, _parse_context_length ) -from sagemaker.core.resources import ModelPackage -from sagemaker.core.utils.utils import Unassigned, ModelPackageGroup +from sagemaker.core.resources import ModelPackage, ModelPackageGroup +from sagemaker.core.utils.utils import Unassigned from sagemaker.ai_registry.dataset import DataSet from sagemaker.train.common import TrainingType from sagemaker.train.configs 
import InputData @@ -467,7 +467,6 @@ def test__convert_input_data_to_channels(self): def test__validate_eula_for_gated_model_with_model_package(self): """Test EULA validation returns True for ModelPackage input""" from sagemaker.core.resources import ModelPackage -from sagemaker.core.utils.utils import Unassigned model_package = Mock(spec=ModelPackage) result = _validate_eula_for_gated_model(model_package, False, True) @@ -898,10 +897,14 @@ def test__parse_context_length_with_empty(self): assert _parse_context_length("") == 0 @patch('sagemaker.train.common_utils.finetune_utils._get_hub_content_metadata') - @patch('boto3.client') - def test__get_fine_tuning_options_filters_by_sequence_length(self, mock_boto_client, mock_get_hub_content): + def test__get_fine_tuning_options_filters_by_sequence_length(self, mock_get_hub_content): mock_session = Mock() mock_session.boto_session.region_name = "us-east-1" + mock_s3 = Mock() + mock_s3.get_object.return_value = { + "Body": Mock(read=Mock(return_value=b'{"max_length": {"default": 32768}}')) + } + mock_session.boto_session.client.return_value = mock_s3 mock_get_hub_content.return_value = { 'hub_content_arn': "arn:aws:sagemaker:us-east-1:123456789012:model/test-model", @@ -926,19 +929,13 @@ def test__get_fine_tuning_options_filters_by_sequence_length(self, mock_boto_cli } } - mock_s3_client = Mock() - mock_boto_client.return_value = mock_s3_client - mock_s3_client.get_object.return_value = { - "Body": Mock(read=Mock(return_value=b'{"max_length": {"default": 32768}}')) - } - result = _get_fine_tuning_options_and_model_arn("test-model", "SFT", "LORA", mock_session, sequence_length="8K") if result is not None: options, model_arn, is_gated_model = result # Should pick the 32K recipe (smallest >= 8K) - mock_s3_client.get_object.assert_called_once() - call_args = mock_s3_client.get_object.call_args[1] + mock_s3.get_object.assert_called_once() + call_args = mock_s3.get_object.call_args[1] assert "params-32k" in call_args["Key"] 
@patch('sagemaker.train.common_utils.finetune_utils._get_hub_content_metadata') From c90ff8e642efb60de2a51fb9fd21ee1694f72dec Mon Sep 17 00:00:00 2001 From: guanweim Date: Thu, 28 May 2026 21:59:20 +0000 Subject: [PATCH 5/7] address PR review: sequence_length as recipe pre-filter and simplify config - Move sequence_length filtering above recipe selection to reduce recipes_with_template before existing logic runs - Always pass sequence_length to ServerlessJobConfig (no None guard) --- .../train/common_utils/finetune_utils.py | 36 +++++++------------ .../train/common_utils/test_finetune_utils.py | 3 +- 2 files changed, 14 insertions(+), 25 deletions(-) diff --git a/sagemaker-train/src/sagemaker/train/common_utils/finetune_utils.py b/sagemaker-train/src/sagemaker/train/common_utils/finetune_utils.py index e99698cef8..70b95a9392 100644 --- a/sagemaker-train/src/sagemaker/train/common_utils/finetune_utils.py +++ b/sagemaker-train/src/sagemaker/train/common_utils/finetune_utils.py @@ -481,30 +481,15 @@ def _get_fine_tuning_options_and_model_arn( if not recipes_with_template: raise ValueError(f"No recipes found with Smtj for technique: {customization_technique}") - # Select recipe based on training type - # Collect override_params from ALL matching recipes (standard + subscription) - recipe = None - if (isinstance(training_type, TrainingType) and training_type == TrainingType.LORA) or training_type == "LORA": - recipe = next((r for r in recipes_with_template if r.get("Peft") and not r.get("IsSubscriptionModel")), None) - elif (isinstance(training_type, TrainingType) and training_type == TrainingType.FULL) or training_type == "FULL": - recipe = next((r for r in recipes_with_template if not r.get("Peft") and not r.get("IsSubscriptionModel")), None) - - # Override recipe selection when sequence_length is explicitly requested + # Filter by SequenceLength before recipe selection if sequence_length is requested if sequence_length: - if (isinstance(training_type, TrainingType) 
and training_type == TrainingType.LORA) or training_type == "LORA": - candidates = [r for r in recipes_with_template if r.get("Peft") and not r.get("IsSubscriptionModel")] - elif (isinstance(training_type, TrainingType) and training_type == TrainingType.FULL) or training_type == "FULL": - candidates = [r for r in recipes_with_template if not r.get("Peft") and not r.get("IsSubscriptionModel")] - else: - candidates = [] - requested = _parse_context_length(sequence_length) - candidates_with_context = [r for r in candidates if r.get("SequenceLength")] + candidates_with_context = [r for r in recipes_with_template if r.get("SequenceLength")] if candidates_with_context: filtered = [r for r in candidates_with_context if _parse_context_length(r.get("SequenceLength")) >= requested] if filtered: filtered.sort(key=lambda r: _parse_context_length(r.get("SequenceLength"))) - recipe = filtered[0] + recipes_with_template = filtered else: available = sorted(set(r.get("SequenceLength") for r in candidates_with_context)) raise ValueError( @@ -517,6 +502,14 @@ def _get_fine_tuning_options_and_model_arn( f"and sequence length:{sequence_length}" ) + # Select recipe based on training type + # Collect override_params from ALL matching recipes (standard + subscription) + recipe = None + if (isinstance(training_type, TrainingType) and training_type == TrainingType.LORA) or training_type == "LORA": + recipe = next((r for r in recipes_with_template if r.get("Peft") and not r.get("IsSubscriptionModel")), None) + elif (isinstance(training_type, TrainingType) and training_type == TrainingType.FULL) or training_type == "FULL": + recipe = next((r for r in recipes_with_template if not r.get("Peft") and not r.get("IsSubscriptionModel")), None) + if not recipe: raise ValueError(f"No recipes found with Smtj for technique: {customization_technique},training_type:{training_type}") @@ -690,18 +683,15 @@ def _create_serverless_config(model_arn, customization_technique, else (training_type.value if 
isinstance(training_type, TrainingType) else training_type) # Create ServerlessJobConfig using shapes - config_kwargs = dict( + serverless_config = ServerlessJobConfig( job_type=job_type, base_model_arn=model_arn, customization_technique=customization_technique, peft=peft, evaluator_arn=evaluator_arn, accept_eula=accept_eula, + sequence_length=sequence_length, ) - if sequence_length is not None: - config_kwargs["sequence_length"] = sequence_length - - serverless_config = ServerlessJobConfig(**config_kwargs) return serverless_config diff --git a/sagemaker-train/tests/unit/train/common_utils/test_finetune_utils.py b/sagemaker-train/tests/unit/train/common_utils/test_finetune_utils.py index c2cc5a9bc9..c6b561ad64 100644 --- a/sagemaker-train/tests/unit/train/common_utils/test_finetune_utils.py +++ b/sagemaker-train/tests/unit/train/common_utils/test_finetune_utils.py @@ -876,8 +876,7 @@ def test__create_serverless_config_with_sequence_length(self): def test__create_serverless_config_without_sequence_length(self): config = _create_serverless_config("model-arn", "SFT", TrainingType.LORA, accept_eula=True) - # sequence_length should remain Unassigned (not set), not None - assert isinstance(config.sequence_length, Unassigned) + assert config.sequence_length is None def test__parse_context_length_with_k_suffix(self): assert _parse_context_length("8K") == 8192 From 74810096f11e92b28982e6e3d163fa78b2c8cd86 Mon Sep 17 00:00:00 2001 From: guanweim Date: Thu, 28 May 2026 22:02:51 +0000 Subject: [PATCH 6/7] fix: change hub_name default to None for consistency Use Optional[str] = None instead of hardcoded "SageMakerPublicHub" default, letting get_sagemaker_hub_name() resolve it at runtime. 
--- .../sagemaker/train/common_utils/finetune_utils.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/sagemaker-train/src/sagemaker/train/common_utils/finetune_utils.py b/sagemaker-train/src/sagemaker/train/common_utils/finetune_utils.py index 70b95a9392..31c938bb30 100644 --- a/sagemaker-train/src/sagemaker/train/common_utils/finetune_utils.py +++ b/sagemaker-train/src/sagemaker/train/common_utils/finetune_utils.py @@ -427,12 +427,12 @@ def _parse_context_length(value) -> int: def _get_fine_tuning_options_and_model_arn( - model_name: str, - customization_technique: str, - training_type, - sagemaker_session, - sequence_length=None, - hub_name: str = "SageMakerPublicHub" + model_name: str, + customization_technique: str, + training_type, + sagemaker_session, + sequence_length=None, + hub_name: Optional[str] = None ) -> tuple: """Get fine-tuning options and model ARN for given customization technique. From fce49db36be24f075c6b161502b8d336a859efcb Mon Sep 17 00:00:00 2001 From: guanweim Date: Thu, 28 May 2026 22:05:18 +0000 Subject: [PATCH 7/7] test: add integration test for SFT trainer with sequence_length --- .../train/test_sft_trainer_integration.py | 36 +++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/sagemaker-train/tests/integ/train/test_sft_trainer_integration.py b/sagemaker-train/tests/integ/train/test_sft_trainer_integration.py index 68446991c4..39bf702025 100644 --- a/sagemaker-train/tests/integ/train/test_sft_trainer_integration.py +++ b/sagemaker-train/tests/integ/train/test_sft_trainer_integration.py @@ -135,3 +135,39 @@ def test_sft_trainer_nova_workflow(sagemaker_session_us_east_1): assert training_job.training_job_status == "Completed" assert hasattr(training_job, 'output_model_package_arn') assert training_job.output_model_package_arn is not None + + +@pytest.mark.gpu_intensive +def test_sft_trainer_lora_with_sequence_length(sagemaker_session): + """Test SFT training workflow with LORA and 
sequence_length specified.""" + unique_id = f"{int(time.time())}-{random.randint(1000, 9999)}" + + sft_trainer = SFTTrainer( + model="meta-textgeneration-llama-3-2-1b-instruct", + training_type=TrainingType.LORA, + model_package_group="arn:aws:sagemaker:us-west-2:729646638167:model-package-group/sdk-test-finetuned-models", + training_dataset="s3://mc-flows-sdk-testing/input_data/sft/sample_data_256_final.jsonl", + s3_output_path="s3://mc-flows-sdk-testing/output/", + accept_eula=True, + sequence_length="8K", + base_job_name=f"sft-seqlen-integ-{unique_id}", + ) + + training_job = sft_trainer.train(wait=False) + + max_wait_time = 3600 + poll_interval = 30 + start_time = time.time() + + while time.time() - start_time < max_wait_time: + training_job.refresh() + status = training_job.training_job_status + + if status in ["Completed", "Failed", "Stopped"]: + break + + time.sleep(poll_interval) + + assert training_job.training_job_status == "Completed" + assert hasattr(training_job, 'output_model_package_arn') + assert training_job.output_model_package_arn is not None