From 579d9f7b0eab22a594d13a9ba76be01f9c6f3336 Mon Sep 17 00:00:00 2001 From: Nitin Gangahar Date: Fri, 29 May 2026 16:04:30 -0700 Subject: [PATCH] Set lazy_load_tensors to true by default. lazy_load_tensors saves on disk space and the path has now been optimized to be quicker with caching. Added appropriate documentation to indicate that this does not work for models with multimodal capabilities. --- docs/guides/checkpointing_solutions/convert_checkpoint.md | 4 ++-- src/maxtext/checkpoint_conversion/to_maxtext.py | 4 ++-- tests/integration/checkpoint_conversion_test.py | 1 - 3 files changed, 4 insertions(+), 5 deletions(-) diff --git a/docs/guides/checkpointing_solutions/convert_checkpoint.md b/docs/guides/checkpointing_solutions/convert_checkpoint.md index 6659857b8c..b9fa6f6cca 100644 --- a/docs/guides/checkpointing_solutions/convert_checkpoint.md +++ b/docs/guides/checkpointing_solutions/convert_checkpoint.md @@ -44,7 +44,7 @@ python3 -m pip install safetensors --no-deps export MODEL= # e.g. 'llama3.1-8b-Instruct' export BASE_OUTPUT_DIRECTORY= # e.g., gs://my-bucket/my-checkpoint-directory export USE_PATHWAYS=0 # Set to 1 if you intend to use Pathways for training, 0 for McJAX -export LAZY_LOAD_TENSORS= # Set to True to save RAM +export LAZY_LOAD_TENSORS=True # Defaults to True. Set to False only for multimodal models. ``` ### Run Conversion @@ -77,7 +77,7 @@ You can find your converted checkpoint files under `${BASE_OUTPUT_DIRECTORY}/0/i - `base_output_directory`: The path where the converted Orbax checkpoint will be stored; it can be Google Cloud Storage (GCS) or local. - `hardware=cpu`: The conversion script runs on a CPU machine. - `checkpoint_storage_use_zarr3` and `checkpoint_storage_use_ocdbt`: These storage flags enable McJAX compatibility when set to True (the default). For Pathways, these should be False. -- `--lazy_load_tensors` (Optional): Enables on-demand loading of weights to prevent OOM (Out of Memory) errors. 
Highly recommended for large models to reduce memory usage during conversion. For example, converting a Llama3.1-70B model with `--lazy_load_tensors=true` uses around 200GB of RAM and completes in ~10 minutes. +- `--lazy_load_tensors` (Optional): Defaults to `True`. Enables on-demand loading of weights to prevent OOM (Out of Memory) errors. Highly recommended for large models to reduce memory usage during conversion. For example, converting a Llama3.1-70B model uses around 200GB of RAM and completes in ~10 minutes. **Note:** This must be set to `False` for multimodal models (e.g., Gemma3) as lazy loading is not yet supported for them. - `--hf_model_path` (Optional): Specifies a customized remote directory or local directory containing the model weights. If unspecified, we use the [default Hugging Face repository ID](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/utils/globals.py) (e.g., openai/gpt-oss-20b). This is necessary for locally dequantized models like GPT-OSS or DeepSeek. - `--save_dtype` (Optional): Specifies the data type of saved model weights. Default to `bfloat16` to save memory. diff --git a/src/maxtext/checkpoint_conversion/to_maxtext.py b/src/maxtext/checkpoint_conversion/to_maxtext.py index 8e73621c69..9c6d6a6f72 100644 --- a/src/maxtext/checkpoint_conversion/to_maxtext.py +++ b/src/maxtext/checkpoint_conversion/to_maxtext.py @@ -816,7 +816,7 @@ def _merged_getter(key): def main( args: Sequence[str], - lazy_load_tensors: bool = False, + lazy_load_tensors: bool = True, eager_load_method: str = "safetensors", hf_model_path: str | None = None, revision: str | None = None, @@ -1077,7 +1077,7 @@ def _eager_getter(key): "--lazy_load_tensors", type=str2bool, required=False, - default=False, + default=True, help="Whether to use lazy loading of HF tensors", ) # Eager load uses `transformers_class.from_pretrained` with auto dtype or `safetensors.safe_open` with pt. 
diff --git a/tests/integration/checkpoint_conversion_test.py b/tests/integration/checkpoint_conversion_test.py index e73b71d076..e40ffaac02 100644 --- a/tests/integration/checkpoint_conversion_test.py +++ b/tests/integration/checkpoint_conversion_test.py @@ -54,7 +54,6 @@ def test_qwen3_30b_a3b_roundtrip_conversion(self): "checkpoint_storage_use_ocdbt=False", "checkpoint_storage_use_zarr3=False", "--save_dtype=bfloat16", - "--lazy_load_tensors=True", ] env = os.environ.copy() env["JAX_PLATFORMS"] = "cpu"