diff --git a/docs/guides/checkpointing_solutions/convert_checkpoint.md b/docs/guides/checkpointing_solutions/convert_checkpoint.md
index 6659857b8c..b9fa6f6cca 100644
--- a/docs/guides/checkpointing_solutions/convert_checkpoint.md
+++ b/docs/guides/checkpointing_solutions/convert_checkpoint.md
@@ -44,7 +44,7 @@ python3 -m pip install safetensors --no-deps
 export MODEL= # e.g. 'llama3.1-8b-Instruct'
 export BASE_OUTPUT_DIRECTORY= # e.g., gs://my-bucket/my-checkpoint-directory
 export USE_PATHWAYS=0 # Set to 1 if you intend to use Pathways for training, 0 for McJAX
-export LAZY_LOAD_TENSORS= # Set to True to save RAM
+export LAZY_LOAD_TENSORS=True # Defaults to True. Set to False only for multimodal models.
 ```
 ### Run Conversion
@@ -77,7 +77,7 @@ You can find your converted checkpoint files under `${BASE_OUTPUT_DIRECTORY}/0/i
 - `base_output_directory`: The path where the converted Orbax checkpoint will be stored; it can be Google Cloud Storage (GCS) or local.
 - `hardware=cpu`: The conversion script runs on a CPU machine.
 - `checkpoint_storage_use_zarr3` and `checkpoint_storage_use_ocdbt`: These storage flags enable McJAX compatibility when set to True (the default). For Pathways, these should be False.
-- `--lazy_load_tensors` (Optional): Enables on-demand loading of weights to prevent OOM (Out of Memory) errors. Highly recommended for large models to reduce memory usage during conversion. For example, converting a Llama3.1-70B model with `--lazy_load_tensors=true` uses around 200GB of RAM and completes in ~10 minutes.
+- `--lazy_load_tensors` (Optional): Defaults to `True`. Enables on-demand loading of weights to prevent OOM (Out of Memory) errors. Highly recommended for large models to reduce memory usage during conversion. For example, converting a Llama3.1-70B model uses around 200GB of RAM and completes in ~10 minutes. **Note:** This must be overloaded to `False` for multimodal models (e.g., Gemma3) as lazy loading is not yet supported for them.
 - `--hf_model_path` (Optional): Specifies a customized remote directory or local directory containing the model weights. If unspecified, we use the [default Hugging Face repository ID](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/utils/globals.py) (e.g., openai/gpt-oss-20b). This is necessary for locally dequantized models like GPT-OSS or DeepSeek.
 - `--save_dtype` (Optional): Specifies the data type of saved model weights. Default to `bfloat16` to save memory.
diff --git a/src/maxtext/checkpoint_conversion/to_maxtext.py b/src/maxtext/checkpoint_conversion/to_maxtext.py
index 8e73621c69..9c6d6a6f72 100644
--- a/src/maxtext/checkpoint_conversion/to_maxtext.py
+++ b/src/maxtext/checkpoint_conversion/to_maxtext.py
@@ -816,7 +816,7 @@ def _merged_getter(key):
 def main(
     args: Sequence[str],
-    lazy_load_tensors: bool = False,
+    lazy_load_tensors: bool = True,
     eager_load_method: str = "safetensors",
     hf_model_path: str | None = None,
     revision: str | None = None,
@@ -1077,7 +1077,7 @@ def _eager_getter(key):
       "--lazy_load_tensors",
       type=str2bool,
       required=False,
-      default=False,
+      default=True,
       help="Whether to use lazy loading of HF tensors",
   )
   # Eager load uses `transformers_class.from_pretrained` with auto dtype or `safetensors.safe_open` with pt.
diff --git a/tests/integration/checkpoint_conversion_test.py b/tests/integration/checkpoint_conversion_test.py
index e73b71d076..e40ffaac02 100644
--- a/tests/integration/checkpoint_conversion_test.py
+++ b/tests/integration/checkpoint_conversion_test.py
@@ -54,7 +54,6 @@ def test_qwen3_30b_a3b_roundtrip_conversion(self):
         "checkpoint_storage_use_ocdbt=False",
         "checkpoint_storage_use_zarr3=False",
         "--save_dtype=bfloat16",
-        "--lazy_load_tensors=True",
     ]
     env = os.environ.copy()
     env["JAX_PLATFORMS"] = "cpu"