diff --git a/sagemaker-core/sample/sagemaker/2017-07-24/service-2.json b/sagemaker-core/sample/sagemaker/2017-07-24/service-2.json index ceb4f316dc..6b551c5fd6 100644 --- a/sagemaker-core/sample/sagemaker/2017-07-24/service-2.json +++ b/sagemaker-core/sample/sagemaker/2017-07-24/service-2.json @@ -44132,6 +44132,10 @@ "EvaluatorArn":{ "shape":"EvaluatorArn", "documentation":"

The evaluator Amazon Resource Name (ARN) used as reward function or reward prompt.

" + }, + "SequenceLength":{ + "shape":"SequenceLength", + "documentation":"

The sequence length for the training job.

" } }, "documentation":"

The configuration for the serverless training job.

" @@ -44143,6 +44147,19 @@ "Evaluation" ] }, + "SequenceLength":{ + "type":"string", + "enum":[ + "1K", + "2K", + "4K", + "8K", + "16K", + "32K", + "64K", + "128K" + ] + }, "ServerlessMaxConcurrency":{ "type":"integer", "box":true, diff --git a/sagemaker-core/src/sagemaker/core/shapes/shapes.py b/sagemaker-core/src/sagemaker/core/shapes/shapes.py index ce25c890dd..2aa5f2afe8 100644 --- a/sagemaker-core/src/sagemaker/core/shapes/shapes.py +++ b/sagemaker-core/src/sagemaker/core/shapes/shapes.py @@ -9717,6 +9717,7 @@ class ServerlessJobConfig(Base): peft: The parameter-efficient fine-tuning configuration. evaluation_type: The evaluation job type. Required when serverless job type is Evaluation. evaluator_arn: The evaluator Amazon Resource Name (ARN) used as reward function or reward prompt. + sequence_length: The sequence length for the training job. """ base_model_arn: StrPipeVar @@ -9726,6 +9727,7 @@ class ServerlessJobConfig(Base): peft: Optional[StrPipeVar] = Unassigned() evaluation_type: Optional[StrPipeVar] = Unassigned() evaluator_arn: Optional[StrPipeVar] = Unassigned() + sequence_length: Optional[StrPipeVar] = Unassigned() class MlflowConfig(Base): diff --git a/sagemaker-core/src/sagemaker/core/utils/code_injection/shape_dag.py b/sagemaker-core/src/sagemaker/core/utils/code_injection/shape_dag.py index 5d0de63efd..977fe6889f 100644 --- a/sagemaker-core/src/sagemaker/core/utils/code_injection/shape_dag.py +++ b/sagemaker-core/src/sagemaker/core/utils/code_injection/shape_dag.py @@ -16206,6 +16206,7 @@ {"name": "Peft", "shape": "Peft", "type": "string"}, {"name": "EvaluationType", "shape": "EvaluationType", "type": "string"}, {"name": "EvaluatorArn", "shape": "EvaluatorArn", "type": "string"}, + {"name": "SequenceLength", "shape": "SequenceLength", "type": "string"}, ], "type": "structure", }, diff --git a/sagemaker-train/src/sagemaker/train/common_utils/finetune_utils.py b/sagemaker-train/src/sagemaker/train/common_utils/finetune_utils.py index 6479e803bd..31c938bb30 100644 --- a/sagemaker-train/src/sagemaker/train/common_utils/finetune_utils.py +++ b/sagemaker-train/src/sagemaker/train/common_utils/finetune_utils.py @@ -407,10 +407,44 @@ def _resolve_model_package_arn(model_package) -> Optional[str]: return None -def _get_fine_tuning_options_and_model_arn(model_name: str, customization_technique: str, training_type, sagemaker_session, - hub_name: Optional[str] = None) -> tuple: +def _parse_context_length(value) -> int: + """Parse a context length value like '8K', '32K', '128K' into an integer (e.g., 8192). + + Returns 0 if value is None or unparseable. + """ + if not value: + return 0 + value = str(value).strip().upper() + if value.endswith("K"): + try: + return int(value[:-1]) * 1024 + except ValueError: + return 0 + try: + return int(value) + except ValueError: + return 0 + + +def _get_fine_tuning_options_and_model_arn( + model_name: str, + customization_technique: str, + training_type, + sagemaker_session, + sequence_length=None, + hub_name: Optional[str] = None +) -> tuple: """Get fine-tuning options and model ARN for given customization technique. + Args: + model_name: Name of the model in the hub. + customization_technique: Technique (e.g., "SFT", "DPO", "RLVR", "RLAIF"). + training_type: TrainingType enum or string ("LORA", "FULL"). + sagemaker_session: SageMaker session for API calls. + sequence_length: Optional sequence length (e.g., "8K"). When provided, filters + recipes by MaxContextLength >= the requested value. + hub_name: Hub name (default: "SageMakerPublicHub"). + Returns: tuple: (FineTuningOptions, model_arn, is_gated_model) """ @@ -447,6 +481,27 @@ def _get_fine_tuning_options_and_model_arn(model_name: str, customization_techni if not recipes_with_template: raise ValueError(f"No recipes found with Smtj for technique: {customization_technique}") + # Filter by SequenceLength before recipe selection if sequence_length is requested + if sequence_length: + requested = _parse_context_length(sequence_length) + candidates_with_context = [r for r in recipes_with_template if r.get("SequenceLength")] + if candidates_with_context: + filtered = [r for r in candidates_with_context if _parse_context_length(r.get("SequenceLength")) >= requested] + if filtered: + filtered.sort(key=lambda r: _parse_context_length(r.get("SequenceLength"))) + recipes_with_template = filtered + else: + available = sorted(set(r.get("SequenceLength") for r in candidates_with_context)) + raise ValueError( + f"No recipes found with SequenceLength >= {sequence_length}. " + f"Available sequence lengths: {available}" + ) + else: + raise ValueError( + f"No recipes found with Smtj for technique: {customization_technique},training_type:{training_type}, " + f"and sequence length:{sequence_length}" + ) + # Select recipe based on training type # Collect override_params from ALL matching recipes (standard + subscription) recipe = None @@ -608,7 +663,8 @@ def _resolve_model_and_name(model, sagemaker_session=None): def _create_serverless_config(model_arn, customization_technique, - training_type, accept_eula, evaluator_arn=None, job_type=JOB_TYPE) -> Optional['ServerlessJobConfig']: + training_type, accept_eula, evaluator_arn=None, + sequence_length=None, job_type=JOB_TYPE) -> Optional['ServerlessJobConfig']: """Create serverless job configuration for fine-tuning. Args: @@ -617,6 +673,7 @@ def _create_serverless_config(model_arn, customization_technique, training_type: Training type (TrainingType enum or string) accept_eula: Boolean indicating if EULA is accepted evaluator_arn: Optional evaluator ARN for RLVR/RLAIF + sequence_length: Optional sequence length enum value (e.g., "1K", "2K", "4K", "8K", "16K", "32K", "64K", "128K") job_type: Type of job (default: "FineTuning") Returns: @@ -632,7 +689,8 @@ def _create_serverless_config(model_arn, customization_technique, customization_technique=customization_technique, peft=peft, evaluator_arn=evaluator_arn, - accept_eula=accept_eula + accept_eula=accept_eula, + sequence_length=sequence_length, ) return serverless_config diff --git a/sagemaker-train/src/sagemaker/train/dpo_trainer.py b/sagemaker-train/src/sagemaker/train/dpo_trainer.py index bd5d9a11bd..8e3bc17d5e 100644 --- a/sagemaker-train/src/sagemaker/train/dpo_trainer.py +++ b/sagemaker-train/src/sagemaker/train/dpo_trainer.py @@ -100,6 +100,10 @@ class DPOTrainer(BaseTrainer): stopping_condition (Optional[StoppingCondition]): The stopping condition to override training runtime limit. If not specified, uses SageMaker service default (24 hours for serverless training). + sequence_length (Optional[str]): + The sequence length for the training job. Valid values are + "1K", "2K", "4K", "8K", "16K", "32K", "64K", "128K". + If not specified, the service will use default recipe selection behavior. """ def __init__( self, @@ -116,6 +120,7 @@ def __init__( networking: Optional[VpcConfig] = None, accept_eula: bool = False, stopping_condition: Optional[StoppingCondition] = None, + sequence_length: Optional[str] = None, **kwargs, ): super().__init__(**kwargs) @@ -134,16 +139,17 @@ def __init__( self.kms_key_id = kms_key_id self.networking = networking self.stopping_condition = stopping_condition + self.sequence_length = sequence_length # Initialize fine-tuning options with beta session fallback - self.hyperparameters, self._model_arn, is_gated_model = _get_fine_tuning_options_and_model_arn(self._model_name, - CustomizationTechnique.DPO.value, - self.training_type, - self.sagemaker_session or TrainDefaults.get_sagemaker_session( - sagemaker_session=self.sagemaker_session - - )) - + self.hyperparameters, self._model_arn, is_gated_model = _get_fine_tuning_options_and_model_arn( + self._model_name, + CustomizationTechnique.DPO.value, + self.training_type, + self.sagemaker_session or TrainDefaults.get_sagemaker_session(sagemaker_session=self.sagemaker_session), + sequence_length=self.sequence_length + ) + # Process hyperparameters self._process_hyperparameters() @@ -227,12 +233,14 @@ def train(self, kms_key_id=self.kms_key_id ) - serverless_config = _create_serverless_config(model_arn=self._model_arn, - customization_technique=CustomizationTechnique.DPO.value, - training_type=self.training_type, - accept_eula=self.accept_eula, - job_type=JOB_TYPE - ) + serverless_config = _create_serverless_config( + model_arn=self._model_arn, + customization_technique=CustomizationTechnique.DPO.value, + training_type=self.training_type, + accept_eula=self.accept_eula, + sequence_length=self.sequence_length, + job_type=JOB_TYPE + ) mlflow_config = _create_mlflow_config( sagemaker_session, diff --git a/sagemaker-train/src/sagemaker/train/rlaif_trainer.py b/sagemaker-train/src/sagemaker/train/rlaif_trainer.py index f2d8460989..5d782d8fa3 100644 --- a/sagemaker-train/src/sagemaker/train/rlaif_trainer.py +++ b/sagemaker-train/src/sagemaker/train/rlaif_trainer.py @@ -114,6 +114,10 @@ class RLAIFTrainer(BaseTrainer): stopping_condition (Optional[StoppingCondition]): The stopping condition to override training runtime limit. If not specified, uses SageMaker service default (24 hours for serverless training). + sequence_length (Optional[str]): + The sequence length for the training job. Valid values are + "1K", "2K", "4K", "8K", "16K", "32K", "64K", "128K". + If not specified, the service will use default recipe selection behavior. """ def __init__( @@ -135,6 +139,7 @@ def __init__( networking: Optional[VpcConfig] = None, accept_eula: bool = False, stopping_condition: Optional[StoppingCondition] = None, + sequence_length: Optional[str] = None, **kwargs, ): super().__init__(**kwargs) @@ -156,14 +161,16 @@ def __init__( self.kms_key_id = kms_key_id self.networking = networking self.stopping_condition = stopping_condition + self.sequence_length = sequence_length # Initialize fine-tuning options with beta session fallback - self.hyperparameters, self._model_arn, is_gated_model = _get_fine_tuning_options_and_model_arn(self._model_name, - CustomizationTechnique.RLAIF.value, - self.training_type, - self.sagemaker_session or TrainDefaults.get_sagemaker_session( - sagemaker_session=self.sagemaker_session - )) + self.hyperparameters, self._model_arn, is_gated_model = _get_fine_tuning_options_and_model_arn( + self._model_name, + CustomizationTechnique.RLAIF.value, + self.training_type, + self.sagemaker_session or TrainDefaults.get_sagemaker_session(sagemaker_session=self.sagemaker_session), + sequence_length=self.sequence_length + ) # Validate and set EULA acceptance self.accept_eula = _validate_eula_for_gated_model(model, accept_eula, is_gated_model) @@ -242,13 +249,15 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None, validati ) evaluator_arn = getattr(self, '_evaluator_arn', None) - serverless_config = _create_serverless_config(model_arn=self._model_arn, - customization_technique=CustomizationTechnique.RLAIF.value, - training_type=self.training_type, - accept_eula=self.accept_eula, - evaluator_arn=evaluator_arn, - job_type=JOB_TYPE - ) + serverless_config = _create_serverless_config( + model_arn=self._model_arn, + customization_technique=CustomizationTechnique.RLAIF.value, + training_type=self.training_type, + accept_eula=self.accept_eula, + evaluator_arn=evaluator_arn, + sequence_length=self.sequence_length, + job_type=JOB_TYPE + ) mlflow_config = _create_mlflow_config( sagemaker_session, diff --git a/sagemaker-train/src/sagemaker/train/rlvr_trainer.py b/sagemaker-train/src/sagemaker/train/rlvr_trainer.py index 333a93fc55..53029155f2 100644 --- a/sagemaker-train/src/sagemaker/train/rlvr_trainer.py +++ b/sagemaker-train/src/sagemaker/train/rlvr_trainer.py @@ -106,6 +106,10 @@ class RLVRTrainer(BaseTrainer): stopping_condition (Optional[StoppingCondition]): The stopping condition to override training runtime limit. If not specified, uses SageMaker service default (24 hours for serverless training). + sequence_length (Optional[str]): + The sequence length for the training job. Valid values are + "1K", "2K", "4K", "8K", "16K", "32K", "64K", "128K". + If not specified, the service will use default recipe selection behavior. """ def __init__( @@ -126,6 +130,7 @@ def __init__( networking: Optional[VpcConfig] = None, accept_eula: bool = False, stopping_condition: Optional[StoppingCondition] = None, + sequence_length: Optional[str] = None, **kwargs, ): super().__init__(**kwargs) @@ -146,15 +151,17 @@ def __init__( self.kms_key_id = kms_key_id self.networking = networking self.stopping_condition = stopping_condition + self.sequence_length = sequence_length # Initialize fine-tuning options with beta session fallback - self.hyperparameters, self._model_arn, is_gated_model = _get_fine_tuning_options_and_model_arn(self._model_name, - CustomizationTechnique.RLVR.value, - self.training_type, - self.sagemaker_session or TrainDefaults.get_sagemaker_session( - sagemaker_session=self.sagemaker_session - )) - + self.hyperparameters, self._model_arn, is_gated_model = _get_fine_tuning_options_and_model_arn( + self._model_name, + CustomizationTechnique.RLVR.value, + self.training_type, + self.sagemaker_session or TrainDefaults.get_sagemaker_session(sagemaker_session=self.sagemaker_session), + sequence_length=self.sequence_length + ) + # Remove constructor-handled hyperparameters self._process_hyperparameters() @@ -233,13 +240,15 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None, # Extract and validate evaluator ARN evaluator_arn = _extract_evaluator_arn(self.custom_reward_function) if self.custom_reward_function else None - serverless_config = _create_serverless_config(model_arn=self._model_arn, - customization_technique=CustomizationTechnique.RLVR.value, - training_type=self.training_type, - accept_eula=self.accept_eula, - evaluator_arn=evaluator_arn, - job_type=JOB_TYPE - ) + serverless_config = _create_serverless_config( + model_arn=self._model_arn, + customization_technique=CustomizationTechnique.RLVR.value, + training_type=self.training_type, + accept_eula=self.accept_eula, + evaluator_arn=evaluator_arn, + sequence_length=self.sequence_length, + job_type=JOB_TYPE + ) mlflow_config = _create_mlflow_config( sagemaker_session, mlflow_resource_arn=self.mlflow_resource_arn, diff --git a/sagemaker-train/src/sagemaker/train/sft_trainer.py b/sagemaker-train/src/sagemaker/train/sft_trainer.py index 233f169d0f..e2193f0b9b 100644 --- a/sagemaker-train/src/sagemaker/train/sft_trainer.py +++ b/sagemaker-train/src/sagemaker/train/sft_trainer.py @@ -102,6 +102,10 @@ class SFTTrainer(BaseTrainer): stopping_condition (Optional[StoppingCondition]): The stopping condition to override training runtime limit. If not specified, uses SageMaker service default (24 hours for serverless training). + sequence_length (Optional[str]): + The sequence length for the training job. Valid values are + "1K", "2K", "4K", "8K", "16K", "32K", "64K", "128K". + If not specified, the service will use default recipe selection behavior. """ def __init__( @@ -119,6 +123,7 @@ def __init__( networking: Optional[VpcConfig] = None, accept_eula: Optional[bool] = False, stopping_condition: Optional[StoppingCondition] = None, + sequence_length: Optional[str] = None, **kwargs, ): super().__init__(**kwargs) @@ -138,15 +143,17 @@ def __init__( self.kms_key_id = kms_key_id self.networking = networking self.stopping_condition = stopping_condition + self.sequence_length = sequence_length # Initialize fine-tuning options with beta session fallback - self.hyperparameters, self._model_arn, is_gated_model = _get_fine_tuning_options_and_model_arn(self._model_name, - CustomizationTechnique.SFT.value, - self.training_type, - self.sagemaker_session or TrainDefaults.get_sagemaker_session( - sagemaker_session=self.sagemaker_session - )) - + self.hyperparameters, self._model_arn, is_gated_model = _get_fine_tuning_options_and_model_arn( + self._model_name, + CustomizationTechnique.SFT.value, + self.training_type, + self.sagemaker_session or TrainDefaults.get_sagemaker_session(sagemaker_session=self.sagemaker_session), + sequence_length=self.sequence_length + ) + # Process hyperparameters self._process_hyperparameters() @@ -225,12 +232,14 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None, validati kms_key_id=self.kms_key_id ) - serverless_config = _create_serverless_config(model_arn=self._model_arn, - customization_technique=CustomizationTechnique.SFT.value, - training_type=self.training_type, - accept_eula=self.accept_eula, - job_type=JOB_TYPE - ) + serverless_config = _create_serverless_config( + model_arn=self._model_arn, + customization_technique=CustomizationTechnique.SFT.value, + training_type=self.training_type, + accept_eula=self.accept_eula, + sequence_length=self.sequence_length, + job_type=JOB_TYPE + ) mlflow_config = _create_mlflow_config( sagemaker_session, mlflow_resource_arn=self.mlflow_resource_arn, diff --git a/sagemaker-train/tests/integ/train/test_sft_trainer_integration.py b/sagemaker-train/tests/integ/train/test_sft_trainer_integration.py index 68446991c4..39bf702025 100644 --- a/sagemaker-train/tests/integ/train/test_sft_trainer_integration.py +++ b/sagemaker-train/tests/integ/train/test_sft_trainer_integration.py @@ -135,3 +135,39 @@ def test_sft_trainer_nova_workflow(sagemaker_session_us_east_1): assert training_job.training_job_status == "Completed" assert hasattr(training_job, 'output_model_package_arn') assert training_job.output_model_package_arn is not None + + +@pytest.mark.gpu_intensive +def test_sft_trainer_lora_with_sequence_length(sagemaker_session): + """Test SFT training workflow with LORA and sequence_length specified.""" + unique_id = f"{int(time.time())}-{random.randint(1000, 9999)}" + + sft_trainer = SFTTrainer( + model="meta-textgeneration-llama-3-2-1b-instruct", + training_type=TrainingType.LORA, + model_package_group="arn:aws:sagemaker:us-west-2:729646638167:model-package-group/sdk-test-finetuned-models", + training_dataset="s3://mc-flows-sdk-testing/input_data/sft/sample_data_256_final.jsonl", + s3_output_path="s3://mc-flows-sdk-testing/output/", + accept_eula=True, + sequence_length="8K", + base_job_name=f"sft-seqlen-integ-{unique_id}", + ) + + training_job = sft_trainer.train(wait=False) + + max_wait_time = 3600 + poll_interval = 30 + start_time = time.time() + + while time.time() - start_time < max_wait_time: + training_job.refresh() + status = training_job.training_job_status + + if status in ["Completed", "Failed", "Stopped"]: + break + + time.sleep(poll_interval) + + assert training_job.training_job_status == "Completed" + assert hasattr(training_job, 'output_model_package_arn') + assert training_job.output_model_package_arn is not None diff --git a/sagemaker-train/tests/unit/train/common_utils/test_finetune_utils.py b/sagemaker-train/tests/unit/train/common_utils/test_finetune_utils.py index c98dea477f..c6b561ad64 100644 --- a/sagemaker-train/tests/unit/train/common_utils/test_finetune_utils.py +++ b/sagemaker-train/tests/unit/train/common_utils/test_finetune_utils.py @@ -27,9 +27,11 @@ _create_mlflow_config, _validate_eula_for_gated_model, _validate_model_region_availability, - _validate_s3_path_exists + _validate_s3_path_exists, + _parse_context_length ) from sagemaker.core.resources import ModelPackage, ModelPackageGroup +from sagemaker.core.utils.utils import Unassigned from sagemaker.ai_registry.dataset import DataSet from sagemaker.train.common import TrainingType from sagemaker.train.configs import InputData @@ -864,3 +866,98 @@ def test__get_fine_tuning_options_subscription_enabled_but_not_subscribed(self, # Should still have standard params, just not datamix ones assert "max_steps" in options._specs assert "customer_data_percent" not in options._specs + + def test__create_serverless_config_with_sequence_length(self): + config = _create_serverless_config("model-arn", "SFT", TrainingType.LORA, accept_eula=True, sequence_length="8K") + + assert config.sequence_length == "8K" + assert config.base_model_arn == "model-arn" + + def test__create_serverless_config_without_sequence_length(self): + config = _create_serverless_config("model-arn", "SFT", TrainingType.LORA, accept_eula=True) + + assert config.sequence_length is None + + def test__parse_context_length_with_k_suffix(self): + assert _parse_context_length("8K") == 8192 + assert _parse_context_length("32K") == 32768 + assert _parse_context_length("128K") == 131072 + + def test__parse_context_length_with_lowercase(self): + assert _parse_context_length("8k") == 8192 + + def test__parse_context_length_with_integer(self): + assert _parse_context_length("4096") == 4096 + + def test__parse_context_length_with_none(self): + assert _parse_context_length(None) == 0 + + def test__parse_context_length_with_empty(self): + assert _parse_context_length("") == 0 + + @patch('sagemaker.train.common_utils.finetune_utils._get_hub_content_metadata') + def test__get_fine_tuning_options_filters_by_sequence_length(self, mock_get_hub_content): + mock_session = Mock() + mock_session.boto_session.region_name = "us-east-1" + mock_s3 = Mock() + mock_s3.get_object.return_value = { + "Body": Mock(read=Mock(return_value=b'{"max_length": {"default": 32768}}')) + } + mock_session.boto_session.client.return_value = mock_s3 + + mock_get_hub_content.return_value = { + 'hub_content_arn': "arn:aws:sagemaker:us-east-1:123456789012:model/test-model", + 'hub_content_document': { + "GatedBucket": False, + "RecipeCollection": [ + { + "CustomizationTechnique": "SFT", + "SmtjRecipeTemplateS3Uri": "s3://bucket/template-4k.json", + "SmtjOverrideParamsS3Uri": "s3://bucket/params-4k.json", + "Peft": True, + "SequenceLength": "4K" + }, + { + "CustomizationTechnique": "SFT", + "SmtjRecipeTemplateS3Uri": "s3://bucket/template-32k.json", + "SmtjOverrideParamsS3Uri": "s3://bucket/params-32k.json", + "Peft": True, + "SequenceLength": "32K" + } + ] + } + } + + result = _get_fine_tuning_options_and_model_arn("test-model", "SFT", "LORA", mock_session, sequence_length="8K") + + if result is not None: + options, model_arn, is_gated_model = result + # Should pick the 32K recipe (smallest >= 8K) + mock_s3.get_object.assert_called_once() + call_args = mock_s3.get_object.call_args[1] + assert "params-32k" in call_args["Key"] + + @patch('sagemaker.train.common_utils.finetune_utils._get_hub_content_metadata') + def test__get_fine_tuning_options_raises_when_no_sufficient_context_length(self, mock_get_hub_content): + mock_session = Mock() + mock_session.boto_session.region_name = "us-east-1" + + mock_get_hub_content.return_value = { + 'hub_content_arn': "arn:aws:sagemaker:us-east-1:123456789012:model/test-model", + 'hub_content_document': { + "GatedBucket": False, + "RecipeCollection": [ + { + "CustomizationTechnique": "SFT", + "SmtjRecipeTemplateS3Uri": "s3://bucket/template-4k.json", + "SmtjOverrideParamsS3Uri": "s3://bucket/params-4k.json", + "Peft": True, + "SequenceLength": "4K" + } + ] + } + } + + # Requesting 128K but only 4K available — should raise + with pytest.raises(ValueError, match="No recipes found with SequenceLength >= 128K"): + _get_fine_tuning_options_and_model_arn("test-model", "SFT", "LORA", mock_session, sequence_length="128K") diff --git a/sagemaker-train/tests/unit/train/test_dpo_trainer.py b/sagemaker-train/tests/unit/train/test_dpo_trainer.py index 1b70e0bf89..7648b46e35 100644 --- a/sagemaker-train/tests/unit/train/test_dpo_trainer.py +++ b/sagemaker-train/tests/unit/train/test_dpo_trainer.py @@ -506,4 +506,67 @@ def test_train_wait_false_skips_wait(self, mock_training_job_create, mock_model_ trainer = DPOTrainer(model="test-model", model_package_group="test-group", training_dataset="s3://bucket/train") trainer.train(wait=False, wait_timeout=600) - mock_wait.assert_not_called() \ No newline at end of file + mock_wait.assert_not_called() + + @patch('sagemaker.train.dpo_trainer._validate_and_resolve_model_package_group') + @patch('sagemaker.train.dpo_trainer._get_fine_tuning_options_and_model_arn') + def test_init_sequence_length_default_none(self, mock_finetuning_options, mock_validate_group): + mock_validate_group.return_value = "test-group" + mock_hyperparams = Mock() + mock_hyperparams.to_dict.return_value = {} + mock_finetuning_options.return_value = (mock_hyperparams, "model-arn", False) + trainer = DPOTrainer(model="test-model", model_package_group="test-group") + assert trainer.sequence_length is None + + @patch('sagemaker.train.dpo_trainer._validate_and_resolve_model_package_group') + @patch('sagemaker.train.dpo_trainer._get_fine_tuning_options_and_model_arn') + def test_init_with_sequence_length(self, mock_finetuning_options, mock_validate_group): + mock_validate_group.return_value = "test-group" + mock_hyperparams = Mock() + mock_hyperparams.to_dict.return_value = {} + mock_finetuning_options.return_value = (mock_hyperparams, "model-arn", False) + trainer = DPOTrainer(model="test-model", model_package_group="test-group", sequence_length="8K") + assert trainer.sequence_length == "8K" + + @patch('sagemaker.train.dpo_trainer._resolve_model_and_name') + @patch('sagemaker.train.dpo_trainer._get_fine_tuning_options_and_model_arn') + @patch('sagemaker.train.dpo_trainer.TrainDefaults.get_role') + @patch('sagemaker.train.dpo_trainer.TrainDefaults.get_sagemaker_session') + @patch('sagemaker.train.dpo_trainer._get_unique_name') + @patch('sagemaker.train.dpo_trainer._validate_and_resolve_model_package_group') + @patch('sagemaker.train.dpo_trainer._create_input_data_config') + @patch('sagemaker.train.dpo_trainer._convert_input_data_to_channels') + @patch('sagemaker.train.dpo_trainer._create_output_config') + @patch('sagemaker.train.dpo_trainer._create_serverless_config') + @patch('sagemaker.train.dpo_trainer._create_mlflow_config') + @patch('sagemaker.train.dpo_trainer._create_model_package_config') + @patch('sagemaker.core.resources.TrainingJob.create') + def test_train_passes_sequence_length_to_serverless_config(self, mock_training_job_create, + mock_model_package_config, mock_mlflow_config, mock_serverless_config, + mock_output_config, mock_convert_channels, mock_input_config, + mock_validate_group, mock_unique_name, mock_get_sagemaker_session, + mock_get_role, mock_get_options, mock_resolve_model): + mock_validate_group.return_value = "test-group" + mock_resolve_model.return_value = ("test-model", "test-model") + mock_get_sagemaker_session.return_value = Mock() + mock_fine_tuning_options = Mock() + mock_fine_tuning_options.to_dict.return_value = {} + mock_get_options.return_value = (mock_fine_tuning_options, "model-arn", False) + mock_get_role.return_value = "test-role" + mock_unique_name.return_value = "test-job-name" + mock_input_config.return_value = [Mock()] + mock_convert_channels.return_value = [Mock()] + mock_output_config.return_value = Mock() + mock_serverless_config.return_value = Mock() + mock_mlflow_config.return_value = Mock() + mock_model_package_config.return_value = Mock() + mock_training_job = Mock() + mock_training_job_create.return_value = mock_training_job + + trainer = DPOTrainer(model="test-model", model_package_group="test-group", + training_dataset="s3://bucket/train", sequence_length="16K") + trainer.train(wait=False) + + mock_serverless_config.assert_called_once() + call_kwargs = mock_serverless_config.call_args[1] + assert call_kwargs["sequence_length"] == "16K" diff --git a/sagemaker-train/tests/unit/train/test_rlaif_trainer.py b/sagemaker-train/tests/unit/train/test_rlaif_trainer.py index e5666883e8..6811c45540 100644 --- a/sagemaker-train/tests/unit/train/test_rlaif_trainer.py +++ b/sagemaker-train/tests/unit/train/test_rlaif_trainer.py @@ -682,4 +682,70 @@ def test_train_wait_false_skips_wait(self, mock_training_job_create, mock_model_ trainer = RLAIFTrainer(model="test-model", model_package_group="test-group", training_dataset="s3://bucket/train") trainer.train(wait=False, wait_timeout=600) - mock_wait.assert_not_called() \ No newline at end of file + mock_wait.assert_not_called() + + @patch('sagemaker.train.rlaif_trainer._validate_and_resolve_model_package_group') + @patch('sagemaker.train.rlaif_trainer._get_fine_tuning_options_and_model_arn') + def test_init_sequence_length_default_none(self, mock_finetuning_options, mock_validate_group): + mock_validate_group.return_value = "test-group" + mock_hyperparams = Mock() + mock_hyperparams.to_dict.return_value = {} + mock_hyperparams._specs = {} + mock_finetuning_options.return_value = (mock_hyperparams, "model-arn", False) + trainer = RLAIFTrainer(model="test-model", model_package_group="test-group") + assert trainer.sequence_length is None + + @patch('sagemaker.train.rlaif_trainer._validate_and_resolve_model_package_group') + @patch('sagemaker.train.rlaif_trainer._get_fine_tuning_options_and_model_arn') + def test_init_with_sequence_length(self, mock_finetuning_options, mock_validate_group): + mock_validate_group.return_value = "test-group" + mock_hyperparams = Mock() + mock_hyperparams.to_dict.return_value = {} + mock_hyperparams._specs = {} + mock_finetuning_options.return_value = (mock_hyperparams, "model-arn", False) + trainer = RLAIFTrainer(model="test-model", model_package_group="test-group", sequence_length="128K") + assert trainer.sequence_length == "128K" + + @patch('sagemaker.train.rlaif_trainer._resolve_model_and_name') + @patch('sagemaker.train.rlaif_trainer._get_fine_tuning_options_and_model_arn') + @patch('sagemaker.train.rlaif_trainer.TrainDefaults.get_role') + @patch('sagemaker.train.rlaif_trainer.TrainDefaults.get_sagemaker_session') + @patch('sagemaker.train.rlaif_trainer._get_unique_name') + @patch('sagemaker.train.rlaif_trainer._validate_and_resolve_model_package_group') + @patch('sagemaker.train.rlaif_trainer._create_input_data_config') + @patch('sagemaker.train.rlaif_trainer._convert_input_data_to_channels') + @patch('sagemaker.train.rlaif_trainer._create_output_config') + @patch('sagemaker.train.rlaif_trainer._create_serverless_config') + @patch('sagemaker.train.rlaif_trainer._create_mlflow_config') + @patch('sagemaker.train.rlaif_trainer._create_model_package_config') + @patch('sagemaker.core.resources.TrainingJob.create') + def test_train_passes_sequence_length_to_serverless_config(self, mock_training_job_create, + mock_model_package_config, mock_mlflow_config, mock_serverless_config, + mock_output_config, mock_convert_channels, mock_input_config, + mock_validate_group, mock_unique_name, mock_get_sagemaker_session, + mock_get_role, mock_get_options, mock_resolve_model): + mock_validate_group.return_value = "test-group" + mock_resolve_model.return_value = ("test-model", "test-model") + mock_get_sagemaker_session.return_value = Mock() + mock_fine_tuning_options = Mock() + mock_fine_tuning_options.to_dict.return_value = {} + mock_fine_tuning_options._specs = {} + mock_get_options.return_value = (mock_fine_tuning_options, "model-arn", False) + mock_get_role.return_value = "test-role" + mock_unique_name.return_value = "test-job-name" + mock_input_config.return_value = [Mock()] + mock_convert_channels.return_value = [Mock()] + mock_output_config.return_value = Mock() + mock_serverless_config.return_value = Mock() + mock_mlflow_config.return_value = Mock() + mock_model_package_config.return_value = Mock() + mock_training_job = Mock() + mock_training_job_create.return_value = mock_training_job + + trainer = RLAIFTrainer(model="test-model", model_package_group="test-group", + training_dataset="s3://bucket/train", sequence_length="64K") + trainer.train(wait=False) + + mock_serverless_config.assert_called_once() + call_kwargs = mock_serverless_config.call_args[1] + assert call_kwargs["sequence_length"] == "64K" diff --git a/sagemaker-train/tests/unit/train/test_rlvr_trainer.py b/sagemaker-train/tests/unit/train/test_rlvr_trainer.py index 320b81555d..b4c01385e2 100644 --- a/sagemaker-train/tests/unit/train/test_rlvr_trainer.py +++ b/sagemaker-train/tests/unit/train/test_rlvr_trainer.py @@ -509,4 +509,67 @@ def test_train_wait_false_skips_wait(self, mock_training_job_create, mock_model_ trainer = RLVRTrainer(model="test-model", model_package_group="test-group", training_dataset="s3://bucket/train") trainer.train(wait=False, wait_timeout=600) - mock_wait.assert_not_called() \ No newline at end of file + mock_wait.assert_not_called() + + @patch('sagemaker.train.rlvr_trainer._validate_and_resolve_model_package_group') + @patch('sagemaker.train.rlvr_trainer._get_fine_tuning_options_and_model_arn') + def test_init_sequence_length_default_none(self, mock_finetuning_options, mock_validate_group): + mock_validate_group.return_value = "test-group" + mock_hyperparams = Mock() + mock_hyperparams.to_dict.return_value = {} + mock_finetuning_options.return_value = (mock_hyperparams, "model-arn", False) + trainer = RLVRTrainer(model="test-model", model_package_group="test-group") + assert trainer.sequence_length is None + + @patch('sagemaker.train.rlvr_trainer._validate_and_resolve_model_package_group') + @patch('sagemaker.train.rlvr_trainer._get_fine_tuning_options_and_model_arn') + def test_init_with_sequence_length(self, mock_finetuning_options, mock_validate_group): + mock_validate_group.return_value = "test-group" + mock_hyperparams = Mock() + mock_hyperparams.to_dict.return_value = {} + mock_finetuning_options.return_value = (mock_hyperparams, "model-arn", False) + trainer = RLVRTrainer(model="test-model", model_package_group="test-group", sequence_length="32K") + assert trainer.sequence_length == "32K" + + @patch('sagemaker.train.rlvr_trainer._resolve_model_and_name') + @patch('sagemaker.train.rlvr_trainer._get_fine_tuning_options_and_model_arn') + @patch('sagemaker.train.rlvr_trainer.TrainDefaults.get_role') + @patch('sagemaker.train.rlvr_trainer.TrainDefaults.get_sagemaker_session') + @patch('sagemaker.train.rlvr_trainer._get_unique_name') + @patch('sagemaker.train.rlvr_trainer._validate_and_resolve_model_package_group') + @patch('sagemaker.train.rlvr_trainer._create_input_data_config') + @patch('sagemaker.train.rlvr_trainer._convert_input_data_to_channels') + @patch('sagemaker.train.rlvr_trainer._create_output_config') + @patch('sagemaker.train.rlvr_trainer._create_serverless_config') + @patch('sagemaker.train.rlvr_trainer._create_mlflow_config') + @patch('sagemaker.train.rlvr_trainer._create_model_package_config') + @patch('sagemaker.core.resources.TrainingJob.create') + def test_train_passes_sequence_length_to_serverless_config(self, mock_training_job_create, + mock_model_package_config, mock_mlflow_config, mock_serverless_config, + mock_output_config, mock_convert_channels, mock_input_config, + mock_validate_group, mock_unique_name, mock_get_sagemaker_session, + mock_get_role, mock_get_options, mock_resolve_model): + mock_validate_group.return_value = "test-group" + mock_resolve_model.return_value = ("test-model", "test-model") + mock_get_sagemaker_session.return_value = Mock() + mock_fine_tuning_options = Mock() + mock_fine_tuning_options.to_dict.return_value = {} + mock_get_options.return_value = (mock_fine_tuning_options, "model-arn", False) + mock_get_role.return_value = "test-role" + mock_unique_name.return_value = "test-job-name" + mock_input_config.return_value = [Mock()] + mock_convert_channels.return_value = [Mock()] + mock_output_config.return_value = Mock() + mock_serverless_config.return_value = Mock() + mock_mlflow_config.return_value = Mock() + mock_model_package_config.return_value = Mock() + mock_training_job = Mock() + mock_training_job_create.return_value = mock_training_job + + trainer = RLVRTrainer(model="test-model", model_package_group="test-group", + training_dataset="s3://bucket/train", sequence_length="4K") + trainer.train(wait=False) + + mock_serverless_config.assert_called_once() + call_kwargs = mock_serverless_config.call_args[1] + assert call_kwargs["sequence_length"] == "4K" diff --git a/sagemaker-train/tests/unit/train/test_sft_trainer.py b/sagemaker-train/tests/unit/train/test_sft_trainer.py index 108990f839..01fc21f4bd 100644 --- a/sagemaker-train/tests/unit/train/test_sft_trainer.py +++ b/sagemaker-train/tests/unit/train/test_sft_trainer.py @@ -520,4 +520,67 @@ def test_train_wait_false_skips_wait(self, mock_training_job_create, mock_model_ trainer = SFTTrainer(model="test-model", model_package_group="test-group", training_dataset="s3://bucket/train") trainer.train(wait=False, wait_timeout=600) - mock_wait.assert_not_called() \ No newline at end of file + mock_wait.assert_not_called() + + @patch('sagemaker.train.sft_trainer._validate_and_resolve_model_package_group') + @patch('sagemaker.train.sft_trainer._get_fine_tuning_options_and_model_arn') + def test_init_sequence_length_default_none(self, mock_finetuning_options, mock_validate_group): + mock_validate_group.return_value = "test-group" + mock_hyperparams = Mock() + mock_hyperparams.to_dict.return_value = {} + mock_finetuning_options.return_value = (mock_hyperparams, "model-arn", False) + trainer = SFTTrainer(model="test-model", model_package_group="test-group") + assert trainer.sequence_length is None + + @patch('sagemaker.train.sft_trainer._validate_and_resolve_model_package_group') + @patch('sagemaker.train.sft_trainer._get_fine_tuning_options_and_model_arn') + def test_init_with_sequence_length(self, mock_finetuning_options, mock_validate_group): + mock_validate_group.return_value = "test-group" + mock_hyperparams = Mock() + mock_hyperparams.to_dict.return_value = {} + mock_finetuning_options.return_value = (mock_hyperparams, "model-arn", False) + trainer = SFTTrainer(model="test-model", model_package_group="test-group", sequence_length="8K") + assert trainer.sequence_length == "8K" + + @patch('sagemaker.train.sft_trainer._resolve_model_and_name') + @patch('sagemaker.train.sft_trainer._get_fine_tuning_options_and_model_arn') + @patch('sagemaker.train.sft_trainer.TrainDefaults.get_role') + @patch('sagemaker.train.sft_trainer.TrainDefaults.get_sagemaker_session') + @patch('sagemaker.train.sft_trainer._get_unique_name') + @patch('sagemaker.train.sft_trainer._validate_and_resolve_model_package_group') + @patch('sagemaker.train.sft_trainer._create_input_data_config') + @patch('sagemaker.train.sft_trainer._convert_input_data_to_channels') + @patch('sagemaker.train.sft_trainer._create_output_config') + @patch('sagemaker.train.sft_trainer._create_serverless_config') + @patch('sagemaker.train.sft_trainer._create_mlflow_config') + @patch('sagemaker.train.sft_trainer._create_model_package_config') + @patch('sagemaker.core.resources.TrainingJob.create') + def test_train_passes_sequence_length_to_serverless_config(self, mock_training_job_create, + mock_model_package_config, mock_mlflow_config, mock_serverless_config, + mock_output_config, mock_convert_channels, mock_input_config, + mock_validate_group, mock_unique_name, mock_get_sagemaker_session, + mock_get_role, mock_get_options, mock_resolve_model): + mock_validate_group.return_value = "test-group" + mock_resolve_model.return_value = ("test-model", "test-model") + mock_get_sagemaker_session.return_value = Mock() + mock_fine_tuning_options = Mock() + mock_fine_tuning_options.to_dict.return_value = {} + mock_get_options.return_value = (mock_fine_tuning_options, "model-arn", False) + mock_get_role.return_value = "test-role" + mock_unique_name.return_value = "test-job-name" + mock_input_config.return_value = [Mock()] + mock_convert_channels.return_value = [Mock()] + mock_output_config.return_value = Mock() + mock_serverless_config.return_value = Mock() + mock_mlflow_config.return_value = Mock() + mock_model_package_config.return_value = Mock() + mock_training_job = Mock() + mock_training_job_create.return_value = mock_training_job + + trainer = SFTTrainer(model="test-model", model_package_group="test-group", + training_dataset="s3://bucket/train", sequence_length="16K") + trainer.train(wait=False) + + mock_serverless_config.assert_called_once() + call_kwargs = mock_serverless_config.call_args[1] + assert call_kwargs["sequence_length"] == "16K"