Fix low_memory_mode meta-device crash on fused-MoE models#1781
Conversation
In low_memory_mode, init_quantized_weights builds the model on the meta
device, quantizes and compresses it, then loads the weights with
accelerate.load_checkpoint_and_dispatch. accelerate matches checkpoint
keys verbatim and does no weight conversion.
transformers 5.0+ keeps MoE experts as fused 3-D parameters
(experts.gate_up_proj / experts.down_proj), but HF checkpoints store the
unfused per-expert gate_proj / up_proj / down_proj and let from_pretrained
fuse them at load time. Since the low-memory path bypasses from_pretrained,
that fusion never runs: the fused parameters have no matching on-disk key,
stay on the meta device, and dispatch then raises ("Cannot copy out of
meta tensor" or "... is on the meta device"). This hits GLM-5.2
(glm_moe_dsa) and any transformers-5 fused-MoE whose checkpoint stores
unfused experts.
The fix rebuilds those parameters before dispatch: for each fused-expert
module still on meta, read its per-expert weights lazily from the local
safetensors and fuse them (gate_up_proj[e] = cat([gate_proj[e],
up_proj[e]]), down_proj[e] per expert), the inverse of
moe_utils._export_fused_experts. Modules whose gate_up_proj is already on
disk are left alone and load normally.
This holds every routed expert in CPU RAM before dispatch, so it does not
lower peak host memory the way accelerate's streaming load does. For
models that do not fit in CPU RAM under low_memory_mode, the normal
from_pretrained device_map path still works (it streams and fuses
incrementally); fusing per layer into a temporary checkpoint to restore
streaming is a possible follow-up.
Tests: a CPU unit test with synthetic transformers-5 fused experts
reproduces the meta crash, then checks the fix on the real path where
experts are wrapped by _QuantFusedExperts, including that storage-offset
expert-index recovery survives rebuilding the parameter.
transformers references (pinned to v5.6.2):
- fused 3-D experts.gate_up_proj/down_proj params and the chunk(2) forward:
https://github.com/huggingface/transformers/blob/aa935fb53dc38b56f10cc86b3a74354c2e99412f/src/transformers/models/glm_moe_dsa/modeling_glm_moe_dsa.py#L508-L530
- the load-time fusion this path bypasses (glm_moe_dsa aliases the
qwen2_moe WeightConverter that from_pretrained runs to merge per-expert
gate_proj/up_proj into gate_up_proj):
https://github.com/huggingface/transformers/blob/aa935fb53dc38b56f10cc86b3a74354c2e99412f/src/transformers/conversion_mapping.py#L235-L249
Signed-off-by: Aaron Batilo <AaronBatilo@gmail.com>
📝 WalkthroughWalkthroughAdds ChangesFused Expert Materialization from Checkpoint
Estimated code review effort🎯 3 (Moderate) | ⏱️ ~20 minutes Suggested reviewers
🚥 Pre-merge checks | ✅ 6✅ Passed checks (6 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches🧪 Generate unit tests (beta)
Comment |
There was a problem hiding this comment.
Warning
CodeRabbit couldn't request changes on this pull request because it doesn't have sufficient GitHub permissions.
Please grant CodeRabbit Pull requests: Read and write permission and re-run the review.
Actionable comments posted: 2
🧹 Nitpick comments (1)
modelopt/torch/quantization/plugins/accelerate.py (1)
136-140: 🧹 Nitpick | 🔵 Trivial | 💤 Low valueMissing key raises an unguarded
KeyError.If the checkpoint has an unexpected structure (missing a shard for a key), line 137 raises
KeyErrorwith no context. Since the outer loop already guards withf"{name}.0.gate_proj.weight" not in key_to_file, the risk is limited to partial/corrupt checkpoints where some but not all expert keys exist.💡 Optional: wrap with try/except for a clearer error
def _tensor(key): - shard = key_to_file[key] + shard = key_to_file.get(key) + if shard is None: + raise KeyError( + f"Checkpoint key '{key}' not found in any safetensors shard. " + "The checkpoint may be corrupt or use an unexpected expert layout." + ) if shard not in handles: handles[shard] = safe_open(shard, framework="pt") return handles[shard].get_tensor(key)🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@modelopt/torch/quantization/plugins/accelerate.py` around lines 136 - 140, The _tensor function raises an unguarded KeyError when accessing key_to_file[key] if the key does not exist in the dictionary. Wrap the line that accesses key_to_file[key] with a try/except block to catch the KeyError and re-raise it with additional context information such as the missing key name and details about the checkpoint structure, making it clearer to users when they encounter partial or corrupt checkpoints.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@modelopt/torch/quantization/plugins/accelerate.py`:
- Around line 112-157: The file handles stored in the handles dictionary (which
are created by safe_open calls within the _tensor function) are never explicitly
closed after use. After the for loop that processes the fused modules completes,
add cleanup code to explicitly close all the file handles in the handles
dictionary by iterating through handles.values() and calling the appropriate
close method on each handle to prevent file descriptor leaks.
In `@tests/unit/torch/quantization/plugins/test_fused_experts.py`:
- Around line 959-975: The safetensors import statement is currently inside the
_write_separate_expert_checkpoint function body instead of at the module level.
Move the `from safetensors.torch import save_file` import to the top of the file
with the other module-level imports, then remove the import statement from
inside the function body. Since safetensors is a required dependency for this
test, there is no need to guard this import within the function.
---
Nitpick comments:
In `@modelopt/torch/quantization/plugins/accelerate.py`:
- Around line 136-140: The _tensor function raises an unguarded KeyError when
accessing key_to_file[key] if the key does not exist in the dictionary. Wrap the
line that accesses key_to_file[key] with a try/except block to catch the
KeyError and re-raise it with additional context information such as the missing
key name and details about the checkpoint structure, making it clearer to users
when they encounter partial or corrupt checkpoints.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Enterprise
Run ID: 9d2e5c4c-495b-4ba4-91d7-f36986b22b06
📒 Files selected for processing (2)
modelopt/torch/quantization/plugins/accelerate.pytests/unit/torch/quantization/plugins/test_fused_experts.py
| def _materialize_fused_experts_from_checkpoint(model, checkpoint): | ||
| """Fuse separate on-disk per-expert weights into the meta fused-MoE params. | ||
|
|
||
| load_checkpoint_and_dispatch matches keys verbatim and does not run transformers' | ||
| expert fusion, so the fused gate_up_proj/down_proj have no on-disk key and stay on | ||
| meta (dispatch then raises). Rebuild them from the local safetensors before dispatch. | ||
| Holds every routed expert in CPU RAM, so it does not stream or offload. | ||
| """ | ||
| fused = [ | ||
| (name, m) | ||
| for name, m in model.named_modules() | ||
| if _is_fused_experts_module(m) and m.gate_up_proj.device.type == "meta" | ||
| ] | ||
| if not fused: | ||
| return | ||
|
|
||
| # Map each checkpoint key to its shard (lazy header read, no tensor data). | ||
| key_to_file = {} | ||
| for shard in glob.glob(os.path.join(checkpoint, "*.safetensors")): | ||
| with safe_open(shard, framework="pt") as f: | ||
| key_to_file.update(dict.fromkeys(f.keys(), shard)) | ||
|
|
||
| handles = {} | ||
|
|
||
| def _tensor(key): | ||
| shard = key_to_file[key] | ||
| if shard not in handles: | ||
| handles[shard] = safe_open(shard, framework="pt") | ||
| return handles[shard].get_tensor(key) | ||
|
|
||
| for name, module in fused: | ||
| if f"{name}.0.gate_proj.weight" not in key_to_file: | ||
| continue # already fused on disk, or an unhandled layout | ||
| dtype = module.gate_up_proj.dtype | ||
| gate_up, down = [], [] | ||
| for e in range(module.num_experts): | ||
| gate = _tensor(f"{name}.{e}.gate_proj.weight") | ||
| up = _tensor(f"{name}.{e}.up_proj.weight") | ||
| gate_up.append(torch.cat([gate, up], dim=0)) | ||
| down.append(_tensor(f"{name}.{e}.down_proj.weight")) | ||
| module.gate_up_proj = torch.nn.Parameter( | ||
| torch.stack(gate_up).to(dtype), requires_grad=False | ||
| ) | ||
| module.down_proj = torch.nn.Parameter(torch.stack(down).to(dtype), requires_grad=False) | ||
|
|
||
|
|
There was a problem hiding this comment.
Safetensors file handles are never explicitly closed.
The handles dictionary accumulates open safe_open file handles but they're never closed. While Python's GC will eventually close them, this can leak file descriptors when processing checkpoints with many shards, potentially hitting OS limits during large model loads.
🔧 Suggested fix: close handles after materializing
for name, module in fused:
if f"{name}.0.gate_proj.weight" not in key_to_file:
continue # already fused on disk, or an unhandled layout
dtype = module.gate_up_proj.dtype
gate_up, down = [], []
for e in range(module.num_experts):
gate = _tensor(f"{name}.{e}.gate_proj.weight")
up = _tensor(f"{name}.{e}.up_proj.weight")
gate_up.append(torch.cat([gate, up], dim=0))
down.append(_tensor(f"{name}.{e}.down_proj.weight"))
module.gate_up_proj = torch.nn.Parameter(
torch.stack(gate_up).to(dtype), requires_grad=False
)
module.down_proj = torch.nn.Parameter(torch.stack(down).to(dtype), requires_grad=False)
+
+ # Close all safetensors file handles
+ for handle in handles.values():
+ handle.__exit__(None, None, None)🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@modelopt/torch/quantization/plugins/accelerate.py` around lines 112 - 157,
The file handles stored in the handles dictionary (which are created by
safe_open calls within the _tensor function) are never explicitly closed after
use. After the for loop that processes the fused modules completes, add cleanup
code to explicitly close all the file handles in the handles dictionary by
iterating through handles.values() and calling the appropriate close method on
each handle to prevent file descriptor leaks.
| def _write_separate_expert_checkpoint(tmp_path): | ||
| """Write a checkpoint with SEPARATE per-expert weights; return the fused refs.""" | ||
| from safetensors.torch import save_file | ||
|
|
||
| gate_up_ref = torch.randn(NUM_EXPERTS, 2 * INTERMEDIATE_DIM, HIDDEN_DIM) * 0.02 | ||
| down_ref = torch.randn(NUM_EXPERTS, HIDDEN_DIM, INTERMEDIATE_DIM) * 0.02 | ||
| state_dict = {"moe.gate.weight": torch.randn(NUM_EXPERTS, HIDDEN_DIM) * 0.02} | ||
| for e in range(NUM_EXPERTS): | ||
| state_dict[f"moe.experts.{e}.gate_proj.weight"] = gate_up_ref[ | ||
| e, :INTERMEDIATE_DIM, : | ||
| ].contiguous() | ||
| state_dict[f"moe.experts.{e}.up_proj.weight"] = gate_up_ref[ | ||
| e, INTERMEDIATE_DIM:, : | ||
| ].contiguous() | ||
| state_dict[f"moe.experts.{e}.down_proj.weight"] = down_ref[e].contiguous() | ||
| save_file(state_dict, str(tmp_path / "model.safetensors")) | ||
| return gate_up_ref, down_ref |
There was a problem hiding this comment.
Move safetensors import to module level.
Per CONTRIBUTING.md, imports belong at the top of the file so import errors surface at collection time. The safetensors.torch.save_file import on line 961 is inside the helper function without a documented reason (not an optional dependency guard—safetensors is required for this test to run).
🔧 Suggested fix
Add at module level (near other imports):
from safetensors.torch import save_fileThen remove line 961 from the function body.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@tests/unit/torch/quantization/plugins/test_fused_experts.py` around lines 959
- 975, The safetensors import statement is currently inside the
_write_separate_expert_checkpoint function body instead of at the module level.
Move the `from safetensors.torch import save_file` import to the top of the file
with the other module-level imports, then remove the import statement from
inside the function body. Since safetensors is a required dependency for this
test, there is no need to guard this import within the function.
Source: Coding guidelines
In low_memory_mode, init_quantized_weights builds the model on the meta
device, quantizes and compresses it, then loads the weights with
accelerate.load_checkpoint_and_dispatch. accelerate matches checkpoint
keys verbatim and does no weight conversion.
transformers 5.0+ keeps MoE experts as fused 3-D parameters
(experts.gate_up_proj / experts.down_proj), but HF checkpoints store the
unfused per-expert gate_proj / up_proj / down_proj and let from_pretrained
fuse them at load time. Since the low-memory path bypasses from_pretrained,
that fusion never runs: the fused parameters have no matching on-disk key,
stay on the meta device, and dispatch then raises ("Cannot copy out of
meta tensor" or "... is on the meta device"). This hits GLM-5.2
(glm_moe_dsa) and any transformers-5 fused-MoE whose checkpoint stores
unfused experts.
The fix rebuilds those parameters before dispatch: for each fused-expert
module still on meta, read its per-expert weights lazily from the local
safetensors and fuse them (gate_up_proj[e] = cat([gate_proj[e],
up_proj[e]]), down_proj[e] per expert), the inverse of
moe_utils._export_fused_experts. Modules whose gate_up_proj is already on
disk are left alone and load normally.
This holds every routed expert in CPU RAM before dispatch, so it does not
lower peak host memory the way accelerate's streaming load does. For
models that do not fit in CPU RAM under low_memory_mode, the normal
from_pretrained device_map path still works (it streams and fuses
incrementally); fusing per layer into a temporary checkpoint to restore
streaming is a possible follow-up.
Tests: a CPU unit test with synthetic transformers-5 fused experts
reproduces the meta crash, then checks the fix on the real path where
experts are wrapped by _QuantFusedExperts, including that storage-offset
expert-index recovery survives rebuilding the parameter.
transformers references (pinned to v5.6.2):
https://github.com/huggingface/transformers/blob/aa935fb53dc38b56f10cc86b3a74354c2e99412f/src/transformers/models/glm_moe_dsa/modeling_glm_moe_dsa.py#L508-L530
qwen2_moe WeightConverter that from_pretrained runs to merge per-expert
gate_proj/up_proj into gate_up_proj):
https://github.com/huggingface/transformers/blob/aa935fb53dc38b56f10cc86b3a74354c2e99412f/src/transformers/conversion_mapping.py#L235-L249
Signed-off-by: Aaron Batilo AaronBatilo@gmail.com
Summary by CodeRabbit
Release Notes
Bug Fixes
Tests