4 changes: 2 additions & 2 deletions docs/guides/checkpointing_solutions/convert_checkpoint.md
@@ -44,7 +44,7 @@ python3 -m pip install safetensors --no-deps
export MODEL=<HF_MODEL> # e.g. 'llama3.1-8b-Instruct'
export BASE_OUTPUT_DIRECTORY=<CKPT_PATH> # e.g., gs://my-bucket/my-checkpoint-directory
export USE_PATHWAYS=0 # Set to 1 if you intend to use Pathways for training, 0 for McJAX
-export LAZY_LOAD_TENSORS=<LAZY_LOAD> # Set to True to save RAM
+export LAZY_LOAD_TENSORS=True # Defaults to True. Set to False only for multimodal models.
```

### Run Conversion
@@ -77,7 +77,7 @@ You can find your converted checkpoint files under `${BASE_OUTPUT_DIRECTORY}/0/i
- `base_output_directory`: The path where the converted Orbax checkpoint will be stored; it can be Google Cloud Storage (GCS) or local.
- `hardware=cpu`: The conversion script runs on a CPU machine.
- `checkpoint_storage_use_zarr3` and `checkpoint_storage_use_ocdbt`: These storage flags enable McJAX compatibility when set to True (the default). For Pathways, these should be False.
-- `--lazy_load_tensors` (Optional): Enables on-demand loading of weights to prevent OOM (Out of Memory) errors. Highly recommended for large models to reduce memory usage during conversion. For example, converting a Llama3.1-70B model with `--lazy_load_tensors=true` uses around 200GB of RAM and completes in ~10 minutes.
+- `--lazy_load_tensors` (Optional): Defaults to `True`. Enables on-demand loading of weights to prevent OOM (Out of Memory) errors. Highly recommended for large models to reduce memory usage during conversion. For example, converting a Llama3.1-70B model uses around 200GB of RAM and completes in ~10 minutes. **Note:** This must be overloaded to `False` for multimodal models (e.g., Gemma3) as lazy loading is not yet supported for them.
- `--hf_model_path` (Optional): Specifies a customized remote directory or local directory containing the model weights. If unspecified, we use the [default Hugging Face repository ID](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/utils/globals.py) (e.g., openai/gpt-oss-20b). This is necessary for locally dequantized models like GPT-OSS or DeepSeek.
- `--save_dtype` (Optional): Specifies the data type of saved model weights. Default to `bfloat16` to save memory.
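
The optional flags described above can be assembled programmatically. The following is a minimal sketch, not part of this PR: `build_optional_flags` and its `is_multimodal` parameter are hypothetical names introduced for illustration, and the flag spellings are taken from the list above. It only builds the argument strings; it does not invoke the conversion script.

```python
def build_optional_flags(is_multimodal, hf_model_path=None, save_dtype="bfloat16"):
    """Return the optional CLI flags as a list of strings (hypothetical helper)."""
    flags = [f"--save_dtype={save_dtype}"]
    # Lazy loading is on by default after this change, so the flag is only
    # needed to turn it off for multimodal models.
    if is_multimodal:
        flags.append("--lazy_load_tensors=False")
    # Only pass a custom weights location when one is given.
    if hf_model_path is not None:
        flags.append(f"--hf_model_path={hf_model_path}")
    return flags


text_only_flags = build_optional_flags(is_multimodal=False)
multimodal_flags = build_optional_flags(is_multimodal=True)
```

With the new default, a text-only model needs no `--lazy_load_tensors` flag at all, which is why the integration test in this PR drops it.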

4 changes: 2 additions & 2 deletions src/maxtext/checkpoint_conversion/to_maxtext.py
@@ -816,7 +816,7 @@ def _merged_getter(key):

def main(
args: Sequence[str],
-lazy_load_tensors: bool = False,
+lazy_load_tensors: bool = True,
eager_load_method: str = "safetensors",
hf_model_path: str | None = None,
revision: str | None = None,
@@ -1077,7 +1077,7 @@ def _eager_getter(key):
"--lazy_load_tensors",
type=str2bool,
required=False,
-default=False,
+default=True,
help="Whether to use lazy loading of HF tensors",
)
# Eager load uses `transformers_class.from_pretrained` with auto dtype or `safetensors.safe_open` with pt.
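
To illustrate how a boolean flag with `default=True` behaves under `argparse`, here is a small standalone sketch. The `str2bool` below is an assumed stand-in written for this example; the repository's own `str2bool` may accept different spellings. `build_parser` is likewise a hypothetical name.

```python
import argparse


def str2bool(value):
    """Parse common textual booleans (assumed stand-in for the repo's helper)."""
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("true", "t", "yes", "y", "1"):
        return True
    if lowered in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean, got {value!r}")


def build_parser():
    """Build a parser with the flag defined as in the hunk above."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--lazy_load_tensors",
        type=str2bool,
        required=False,
        default=True,
        help="Whether to use lazy loading of HF tensors",
    )
    return parser


parser = build_parser()
default_args = parser.parse_args([])  # flag omitted: falls back to the default
override_args = parser.parse_args(["--lazy_load_tensors=False"])  # explicit opt-out
```

Omitting the flag now yields `True`, so callers that relied on the old eager default would need to pass `--lazy_load_tensors=False` explicitly.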
1 change: 0 additions & 1 deletion tests/integration/checkpoint_conversion_test.py
@@ -54,7 +54,6 @@ def test_qwen3_30b_a3b_roundtrip_conversion(self):
"checkpoint_storage_use_ocdbt=False",
"checkpoint_storage_use_zarr3=False",
"--save_dtype=bfloat16",
-"--lazy_load_tensors=True",
]
env = os.environ.copy()
env["JAX_PLATFORMS"] = "cpu"
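
The test's pattern of copying the environment and pinning `JAX_PLATFORMS` to `cpu` before launching a subprocess can be sketched in isolation. This is an illustrative example only: `run_with_cpu_backend` is a hypothetical helper, and the child command here just echoes the variable rather than running the conversion script.

```python
import os
import subprocess
import sys


def run_with_cpu_backend(command):
    """Run `command` with JAX_PLATFORMS=cpu set only in the child's environment."""
    env = os.environ.copy()  # copy so the parent process environment is untouched
    env["JAX_PLATFORMS"] = "cpu"
    return subprocess.run(command, env=env, capture_output=True, text=True, check=True)


# Stand-in child process that prints the variable it received.
result = run_with_cpu_backend(
    [sys.executable, "-c", "import os; print(os.environ['JAX_PLATFORMS'])"]
)
```

Copying `os.environ` rather than mutating it keeps the setting from leaking into other tests in the same process.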