Skip to content

Fix low_memory_mode meta-device crash on fused-MoE models#1781

Open
abatilo wants to merge 1 commit into
NVIDIA:mainfrom
abatilo:fix-low-memory-fused-moe
Open

Fix low_memory_mode meta-device crash on fused-MoE models#1781
abatilo wants to merge 1 commit into
NVIDIA:mainfrom
abatilo:fix-low-memory-fused-moe

Conversation

@abatilo

@abatilo abatilo commented Jun 21, 2026

Copy link
Copy Markdown

In low_memory_mode, init_quantized_weights builds the model on the meta
device, quantizes and compresses it, then loads the weights with
accelerate.load_checkpoint_and_dispatch. accelerate matches checkpoint
keys verbatim and does no weight conversion.

transformers 5.0+ keeps MoE experts as fused 3-D parameters
(experts.gate_up_proj / experts.down_proj), but HF checkpoints store the
unfused per-expert gate_proj / up_proj / down_proj and let from_pretrained
fuse them at load time. Since the low-memory path bypasses from_pretrained,
that fusion never runs: the fused parameters have no matching on-disk key,
stay on the meta device, and dispatch then raises ("Cannot copy out of
meta tensor" or "... is on the meta device"). This hits GLM-5.2
(glm_moe_dsa) and any transformers-5 fused-MoE whose checkpoint stores
unfused experts.

The fix rebuilds those parameters before dispatch: for each fused-expert
module still on meta, read its per-expert weights lazily from the local
safetensors and fuse them (gate_up_proj[e] = cat([gate_proj[e],
up_proj[e]]), down_proj[e] per expert), the inverse of
moe_utils._export_fused_experts. Modules whose gate_up_proj is already on
disk are left alone and load normally.

This holds every routed expert in CPU RAM before dispatch, so it does not
lower peak host memory the way accelerate's streaming load does. For
models that do not fit in CPU RAM under low_memory_mode, the normal
from_pretrained device_map path still works (it streams and fuses
incrementally); fusing per layer into a temporary checkpoint to restore
streaming is a possible follow-up.

Tests: a CPU unit test with synthetic transformers-5 fused experts
reproduces the meta crash, then checks the fix on the real path where
experts are wrapped by _QuantFusedExperts, including that storage-offset
expert-index recovery survives rebuilding the parameter.

transformers references (pinned to v5.6.2):

Signed-off-by: Aaron Batilo AaronBatilo@gmail.com

Summary by CodeRabbit

Release Notes

  • Bug Fixes

    • Fixed loading and materialization of quantized models with fused mixture-of-experts modules from checkpoint files in memory-constrained environments
  • Tests

    • Added regression test for low-memory checkpoint loading with fused expert modules, verifying correct parameter reconstruction and expert index recovery

In low_memory_mode, init_quantized_weights builds the model on the meta
device, quantizes and compresses it, then loads the weights with
accelerate.load_checkpoint_and_dispatch. accelerate matches checkpoint
keys verbatim and does no weight conversion.

transformers 5.0+ keeps MoE experts as fused 3-D parameters
(experts.gate_up_proj / experts.down_proj), but HF checkpoints store the
unfused per-expert gate_proj / up_proj / down_proj and let from_pretrained
fuse them at load time. Since the low-memory path bypasses from_pretrained,
that fusion never runs: the fused parameters have no matching on-disk key,
stay on the meta device, and dispatch then raises ("Cannot copy out of
meta tensor" or "... is on the meta device"). This hits GLM-5.2
(glm_moe_dsa) and any transformers-5 fused-MoE whose checkpoint stores
unfused experts.

The fix rebuilds those parameters before dispatch: for each fused-expert
module still on meta, read its per-expert weights lazily from the local
safetensors and fuse them (gate_up_proj[e] = cat([gate_proj[e],
up_proj[e]]), down_proj[e] per expert), the inverse of
moe_utils._export_fused_experts. Modules whose gate_up_proj is already on
disk are left alone and load normally.

This holds every routed expert in CPU RAM before dispatch, so it does not
lower peak host memory the way accelerate's streaming load does. For
models that do not fit in CPU RAM under low_memory_mode, the normal
from_pretrained device_map path still works (it streams and fuses
incrementally); fusing per layer into a temporary checkpoint to restore
streaming is a possible follow-up.

Tests: a CPU unit test with synthetic transformers-5 fused experts
reproduces the meta crash, then checks the fix on the real path where
experts are wrapped by _QuantFusedExperts, including that storage-offset
expert-index recovery survives rebuilding the parameter.

transformers references (pinned to v5.6.2):
- fused 3-D experts.gate_up_proj/down_proj params and the chunk(2) forward:
  https://github.com/huggingface/transformers/blob/aa935fb53dc38b56f10cc86b3a74354c2e99412f/src/transformers/models/glm_moe_dsa/modeling_glm_moe_dsa.py#L508-L530
- the load-time fusion this path bypasses (glm_moe_dsa aliases the
  qwen2_moe WeightConverter that from_pretrained runs to merge per-expert
  gate_proj/up_proj into gate_up_proj):
  https://github.com/huggingface/transformers/blob/aa935fb53dc38b56f10cc86b3a74354c2e99412f/src/transformers/conversion_mapping.py#L235-L249

Signed-off-by: Aaron Batilo <AaronBatilo@gmail.com>
@abatilo abatilo requested review from a team as code owners June 21, 2026 22:43
@abatilo abatilo requested a review from kinjalpatel27 June 21, 2026 22:43
@copy-pr-bot

copy-pr-bot Bot commented Jun 21, 2026

Copy link
Copy Markdown

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@coderabbitai

coderabbitai Bot commented Jun 21, 2026

Copy link
Copy Markdown
Contributor

Review Change Stack

📝 Walkthrough

Walkthrough

Adds _materialize_fused_experts_from_checkpoint to accelerate.py, which scans fused MoE modules still on meta, reads safetensors shard headers to locate per-expert gate/up/down projection tensors, and reconstructs fused gate_up_proj/down_proj parameters. This function is called in patched_from_pretrained before load_checkpoint_and_dispatch. A regression test verifies both the failure path and the corrected materialization flow.

Changes

Fused Expert Materialization from Checkpoint

Layer / File(s) Summary
Checkpoint materialization helper and call site
modelopt/torch/quantization/plugins/accelerate.py
Adds glob, os, safe_open, and _is_fused_experts_module imports. Implements _materialize_fused_experts_from_checkpoint that finds fused modules with meta-device parameters, maps tensor keys to .safetensors shards by reading headers, loads per-expert gate/up/down tensors, and reassigns reconstructed fused parameters. Invokes this helper in patched_from_pretrained after device-map computation and before dispatch.
Low-memory loading regression test
tests/unit/torch/quantization/plugins/test_fused_experts.py
Adds _write_separate_expert_checkpoint to write a per-expert safetensors checkpoint, and TestLowMemoryFusedExpertsLoading.test_low_memory_load_materializes_fused_experts that asserts raw dispatch fails on meta tensors, then validates that fused-expert conversion plus materialization produces correctly shaped and valued gate_up_proj/down_proj and recovers the correct expert index.

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

Suggested reviewers

  • jingyu-ml
  • kaix-nv
  • kevalmorabia97
  • cjluo-nv
  • realAsma
🚥 Pre-merge checks | ✅ 6
✅ Passed checks (6 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The title accurately describes the main change: fixing a low_memory_mode meta-device crash on fused-MoE models by implementing expert materialization from checkpoints.
Docstring Coverage ✅ Passed Docstring coverage is 87.50% which is sufficient. The required threshold is 80.00%.
Linked Issues check ✅ Passed Check skipped because no linked issues were found for this pull request.
Out of Scope Changes check ✅ Passed Check skipped because no linked issues were found for this pull request.
Security Anti-Patterns ✅ Passed All security coding practices in SECURITY.md are followed. No torch.load/numpy.load with unsafe settings, no hardcoded trust_remote_code, no eval/exec, no nosec comments, no new non-permissive depe...

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests

Comment @coderabbitai help to get the list of available commands and usage tips.

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Warning

CodeRabbit couldn't request changes on this pull request because it doesn't have sufficient GitHub permissions.

Please grant CodeRabbit Pull requests: Read and write permission and re-run the review.

👉 Steps to fix this

Actionable comments posted: 2

🧹 Nitpick comments (1)
modelopt/torch/quantization/plugins/accelerate.py (1)

136-140: 🧹 Nitpick | 🔵 Trivial | 💤 Low value

Missing key raises an unguarded KeyError.

If the checkpoint has an unexpected structure (missing a shard for a key), line 137 raises KeyError with no context. Since the outer loop already guards with f"{name}.0.gate_proj.weight" not in key_to_file, the risk is limited to partial/corrupt checkpoints where some but not all expert keys exist.

💡 Optional: wrap with try/except for a clearer error
     def _tensor(key):
-        shard = key_to_file[key]
+        shard = key_to_file.get(key)
+        if shard is None:
+            raise KeyError(
+                f"Checkpoint key '{key}' not found in any safetensors shard. "
+                "The checkpoint may be corrupt or use an unexpected expert layout."
+            )
         if shard not in handles:
             handles[shard] = safe_open(shard, framework="pt")
         return handles[shard].get_tensor(key)
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@modelopt/torch/quantization/plugins/accelerate.py` around lines 136 - 140,
The _tensor function raises an unguarded KeyError when accessing
key_to_file[key] if the key does not exist in the dictionary. Wrap the line that
accesses key_to_file[key] with a try/except block to catch the KeyError and
re-raise it with additional context information such as the missing key name and
details about the checkpoint structure, making it clearer to users when they
encounter partial or corrupt checkpoints.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@modelopt/torch/quantization/plugins/accelerate.py`:
- Around line 112-157: The file handles stored in the handles dictionary (which
are created by safe_open calls within the _tensor function) are never explicitly
closed after use. After the for loop that processes the fused modules completes,
add cleanup code to explicitly close all the file handles in the handles
dictionary by iterating through handles.values() and calling the appropriate
close method on each handle to prevent file descriptor leaks.

In `@tests/unit/torch/quantization/plugins/test_fused_experts.py`:
- Around line 959-975: The safetensors import statement is currently inside the
_write_separate_expert_checkpoint function body instead of at the module level.
Move the `from safetensors.torch import save_file` import to the top of the file
with the other module-level imports, then remove the import statement from
inside the function body. Since safetensors is a required dependency for this
test, there is no need to guard this import within the function.

---

Nitpick comments:
In `@modelopt/torch/quantization/plugins/accelerate.py`:
- Around line 136-140: The _tensor function raises an unguarded KeyError when
accessing key_to_file[key] if the key does not exist in the dictionary. Wrap the
line that accesses key_to_file[key] with a try/except block to catch the
KeyError and re-raise it with additional context information such as the missing
key name and details about the checkpoint structure, making it clearer to users
when they encounter partial or corrupt checkpoints.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: 9d2e5c4c-495b-4ba4-91d7-f36986b22b06

📥 Commits

Reviewing files that changed from the base of the PR and between 9048d13 and 32bb4e3.

📒 Files selected for processing (2)
  • modelopt/torch/quantization/plugins/accelerate.py
  • tests/unit/torch/quantization/plugins/test_fused_experts.py

Comment on lines +112 to +157
def _materialize_fused_experts_from_checkpoint(model, checkpoint):
"""Fuse separate on-disk per-expert weights into the meta fused-MoE params.

load_checkpoint_and_dispatch matches keys verbatim and does not run transformers'
expert fusion, so the fused gate_up_proj/down_proj have no on-disk key and stay on
meta (dispatch then raises). Rebuild them from the local safetensors before dispatch.
Holds every routed expert in CPU RAM, so it does not stream or offload.
"""
fused = [
(name, m)
for name, m in model.named_modules()
if _is_fused_experts_module(m) and m.gate_up_proj.device.type == "meta"
]
if not fused:
return

# Map each checkpoint key to its shard (lazy header read, no tensor data).
key_to_file = {}
for shard in glob.glob(os.path.join(checkpoint, "*.safetensors")):
with safe_open(shard, framework="pt") as f:
key_to_file.update(dict.fromkeys(f.keys(), shard))

handles = {}

def _tensor(key):
shard = key_to_file[key]
if shard not in handles:
handles[shard] = safe_open(shard, framework="pt")
return handles[shard].get_tensor(key)

for name, module in fused:
if f"{name}.0.gate_proj.weight" not in key_to_file:
continue # already fused on disk, or an unhandled layout
dtype = module.gate_up_proj.dtype
gate_up, down = [], []
for e in range(module.num_experts):
gate = _tensor(f"{name}.{e}.gate_proj.weight")
up = _tensor(f"{name}.{e}.up_proj.weight")
gate_up.append(torch.cat([gate, up], dim=0))
down.append(_tensor(f"{name}.{e}.down_proj.weight"))
module.gate_up_proj = torch.nn.Parameter(
torch.stack(gate_up).to(dtype), requires_grad=False
)
module.down_proj = torch.nn.Parameter(torch.stack(down).to(dtype), requires_grad=False)


Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

Safetensors file handles are never explicitly closed.

The handles dictionary accumulates open safe_open file handles but they're never closed. While Python's GC will eventually close them, this can leak file descriptors when processing checkpoints with many shards, potentially hitting OS limits during large model loads.

🔧 Suggested fix: close handles after materializing
     for name, module in fused:
         if f"{name}.0.gate_proj.weight" not in key_to_file:
             continue  # already fused on disk, or an unhandled layout
         dtype = module.gate_up_proj.dtype
         gate_up, down = [], []
         for e in range(module.num_experts):
             gate = _tensor(f"{name}.{e}.gate_proj.weight")
             up = _tensor(f"{name}.{e}.up_proj.weight")
             gate_up.append(torch.cat([gate, up], dim=0))
             down.append(_tensor(f"{name}.{e}.down_proj.weight"))
         module.gate_up_proj = torch.nn.Parameter(
             torch.stack(gate_up).to(dtype), requires_grad=False
         )
         module.down_proj = torch.nn.Parameter(torch.stack(down).to(dtype), requires_grad=False)
+
+    # Close all safetensors file handles
+    for handle in handles.values():
+        handle.__exit__(None, None, None)
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@modelopt/torch/quantization/plugins/accelerate.py` around lines 112 - 157,
The file handles stored in the handles dictionary (which are created by
safe_open calls within the _tensor function) are never explicitly closed after
use. After the for loop that processes the fused modules completes, add cleanup
code to explicitly close all the file handles in the handles dictionary by
iterating through handles.values() and calling the appropriate close method on
each handle to prevent file descriptor leaks.

Comment on lines +959 to +975
def _write_separate_expert_checkpoint(tmp_path):
"""Write a checkpoint with SEPARATE per-expert weights; return the fused refs."""
from safetensors.torch import save_file

gate_up_ref = torch.randn(NUM_EXPERTS, 2 * INTERMEDIATE_DIM, HIDDEN_DIM) * 0.02
down_ref = torch.randn(NUM_EXPERTS, HIDDEN_DIM, INTERMEDIATE_DIM) * 0.02
state_dict = {"moe.gate.weight": torch.randn(NUM_EXPERTS, HIDDEN_DIM) * 0.02}
for e in range(NUM_EXPERTS):
state_dict[f"moe.experts.{e}.gate_proj.weight"] = gate_up_ref[
e, :INTERMEDIATE_DIM, :
].contiguous()
state_dict[f"moe.experts.{e}.up_proj.weight"] = gate_up_ref[
e, INTERMEDIATE_DIM:, :
].contiguous()
state_dict[f"moe.experts.{e}.down_proj.weight"] = down_ref[e].contiguous()
save_file(state_dict, str(tmp_path / "model.safetensors"))
return gate_up_ref, down_ref

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

Move safetensors import to module level.

Per CONTRIBUTING.md, imports belong at the top of the file so import errors surface at collection time. The safetensors.torch.save_file import on line 961 is inside the helper function without a documented reason (not an optional dependency guard—safetensors is required for this test to run).

🔧 Suggested fix

Add at module level (near other imports):

from safetensors.torch import save_file

Then remove line 961 from the function body.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tests/unit/torch/quantization/plugins/test_fused_experts.py` around lines 959
- 975, The safetensors import statement is currently inside the
_write_separate_expert_checkpoint function body instead of at the module level.
Move the `from safetensors.torch import save_file` import to the top of the file
with the other module-level imports, then remove the import statement from
inside the function body. Since safetensors is a required dependency for this
test, there is no need to guard this import within the function.

Source: Coding guidelines

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant