Skip to content

Commit d8cd9cc

Browse files
metascroy authored and jpiat committed
[MLX] Support Qwen3.5-35B-A3B (pytorch#18785)
# Qwen 3.5 MoE — MLX Backend Support Adds `--backend mlx` to the existing Qwen 3.5 MoE export script, enabling export and inference on Apple Silicon via the MLX delegate. ## What changed **Unified export** (`examples/models/qwen3_5_moe/export.py`) - Added `--backend mlx` alongside existing CUDA path. CUDA path is unchanged. - Added `--model-id` for automatic HuggingFace download. - Added `--tiny-test` for CI validation with random weights (~30s, no download). **MLX source transformations** (`examples/models/qwen3_5_moe/mlx_source_transformations.py`) - Replaces Triton-dependent modules with MLX equivalents: `FusedMoEExperts` → `SwitchMLP`, `GatedDeltaNet` → `mlx::gated_delta_rule` custom op, `FullAttention` → `mlx::rope`, `KVCache` → MLX KVCache, `GemmaRMSNorm` → `F.rms_norm`, `SparseMoE` → removes unnecessary dtype casts. **SwitchLinear / SwitchMLP** (`backends/mlx/llm/switch.py`) - Per-expert linear using `mlx::gather_mm` / `mlx::gather_qmm` custom ops. - `SwitchMLP`: reusable gated MoE MLP with configurable activation and optional gate+up fusion. **Gated delta rule** (`backends/mlx/model_ops/gated_delta_rule.py`) - Custom op with `mutates_args=("state",)` for recurrent state carry-forward. - Pattern handler emits `MetalKernelNode` (fused GPU kernel) or `ScanNode` (fallback), selected via `use_custom_kernel` kwarg on the op. **New ops / schema** - `mlx::gather_mm`, `mlx::gather_qmm`: fused gather + matmul for MoE expert selection. - `GatherMmNode`, `GatherQmmNode`, `ScanNode`, `MetalKernelNode`, `ScatterAddNode` added to FlatBuffer schema + C++ runtime. **Python runner** (`examples/models/qwen3_5_moe/run.py`) - ExecuTorch pybinding runner with tokenizer support and vocab size auto-detection from `.pte` metadata. **CI** (`.github/workflows/mlx.yml`) - `test-mlx-qwen35-moe`: tiny model export + inference with deterministic output assertion + AsType node count check (≤23). - `test_gated_delta_rule` tests added to `test-mlx` job. 
## Usage Export (downloads model automatically): python export.py --model-id Qwen/Qwen3.5-35B-A3B --backend mlx --qlinear 4w --qlinear-group-size 64 --output-dir ./qwen35_moe_mlx Run: python -m executorch.examples.models.qwen3_5_moe.run --pte ./qwen35_moe_mlx/model.pte --tokenizer Qwen/Qwen3.5-35B-A3B --prompt "What is the capital of France?" CI test (no download): python export.py --tiny-test --backend mlx --qlinear 4w --output-dir /tmp/tiny python -m executorch.examples.models.qwen3_5_moe.run --pte /tmp/tiny/model.pte --prompt-len 4 --max-new-tokens 5 ## Further optimization ideas: * Write a chunked GDN kernel * Turn off expert sorting in decode
1 parent 30a1c83 commit d8cd9cc

19 files changed

Lines changed: 4172 additions & 46 deletions

File tree

.github/workflows/mlx.yml

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ on:
1313
- extension/audio/**
1414
- examples/models/parakeet/**
1515
- examples/models/voxtral_realtime/**
16+
- examples/models/qwen3_5_moe/**
1617
workflow_dispatch:
1718

1819
permissions: {}
@@ -63,6 +64,61 @@ jobs:
6364
./cmake-out/backends/mlx/test/multi_thread_test_runner
6465
echo "::endgroup::"
6566
67+
echo "::group::Run gated_delta_rule op tests"
68+
${CONDA_RUN} python -m executorch.backends.mlx.model_ops.test_gated_delta_rule run -v
69+
echo "::endgroup::"
70+
71+
test-mlx-qwen35-moe:
72+
uses: pytorch/test-infra/.github/workflows/macos_job.yml@main
73+
with:
74+
job-name: test-mlx-qwen35-moe
75+
runner: macos-14-xlarge
76+
python-version: "3.12"
77+
submodules: recursive
78+
ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
79+
timeout: 90
80+
script: |
81+
set -eux
82+
83+
echo "::group::Install ExecuTorch"
84+
${CONDA_RUN} python install_executorch.py > /dev/null
85+
echo "::endgroup::"
86+
87+
${CONDA_RUN} pip list
88+
89+
echo "::group::Export Qwen 3.5 MoE (tiny model)"
90+
${CONDA_RUN} python -m executorch.examples.models.qwen3_5_moe.export \
91+
--tiny-test \
92+
--backend mlx \
93+
--qlinear 4w \
94+
--qlinear-group-size 32 \
95+
--output-dir /tmp/qwen35_moe_mlx_tiny
96+
echo "::endgroup::"
97+
98+
echo "::group::Check AsType node count"
99+
ASTYPE_COUNT=$(${CONDA_RUN} python -m executorch.backends.mlx.pte_inspector \
100+
/tmp/qwen35_moe_mlx_tiny/model.pte --mlx-instructions 2>&1 | grep -c "AsTypeNode" || true)
101+
echo "AsType nodes: ${ASTYPE_COUNT}"
102+
if [ "$ASTYPE_COUNT" -gt 23 ]; then
103+
echo "Failed: expected no more than 23 AsType nodes, got ${ASTYPE_COUNT}"
104+
exit 1
105+
fi
106+
echo "::endgroup::"
107+
108+
echo "::group::Run Qwen 3.5 MoE inference"
109+
OUTPUT=$(${CONDA_RUN} python -m executorch.examples.models.qwen3_5_moe.run \
110+
--pte /tmp/qwen35_moe_mlx_tiny/model.pte \
111+
--prompt-len 4 \
112+
--max-new-tokens 5 2>&1)
113+
echo "$OUTPUT"
114+
if echo "$OUTPUT" | grep -q "Generated token ids: \[167, 167, 81, 167, 81\]"; then
115+
echo "Success: Qwen 3.5 MoE MLX export + inference completed with expected output"
116+
else
117+
echo "Failed: unexpected output (expected [167, 167, 81, 167, 81])"
118+
exit 1
119+
fi
120+
echo "::endgroup::"
121+
66122
backend-tester:
67123
strategy:
68124
fail-fast: false

backends/mlx/CMakeLists.txt

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -247,6 +247,14 @@ add_subdirectory(${MLX_SOURCE_DIR} ${CMAKE_CURRENT_BINARY_DIR}/mlx)
247247
# Op logging option (for debugging) - OFF by default for performance
248248
option(ET_MLX_ENABLE_OP_LOGGING "Enable per-op logging in MLX delegate" OFF)
249249

# Custom kernel execution - OFF by default for security. When enabled,
# MetalKernelNode can execute arbitrary Metal shader code embedded in .pte
# files. Only enable for trusted .pte sources.
# NOTE: the default below must stay in sync with this comment — it was
# previously ON, contradicting the documented secure-by-default intent.
option(ET_MLX_ALLOW_CUSTOM_KERNEL_EXECUTION
       "Allow MetalKernelNode to execute custom Metal shaders from .pte files"
       OFF
)
250258
set(_mlx_backend__srcs ${CMAKE_CURRENT_SOURCE_DIR}/runtime/MLXLoader.cpp
251259
${CMAKE_CURRENT_SOURCE_DIR}/runtime/MLXBackend.cpp
252260
)
@@ -262,6 +270,13 @@ if(ET_MLX_ENABLE_OP_LOGGING)
262270
message(STATUS "MLX delegate op logging ENABLED")
263271
endif()
264272

273+
if(ET_MLX_ALLOW_CUSTOM_KERNEL_EXECUTION)
274+
target_compile_definitions(
275+
mlxdelegate PRIVATE ET_MLX_ALLOW_CUSTOM_KERNEL_EXECUTION
276+
)
277+
message(STATUS "MLX delegate custom kernel execution ENABLED")
278+
endif()
279+
265280
target_include_directories(
266281
mlxdelegate PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/runtime
267282
)

backends/mlx/builder/program_builder.py

Lines changed: 76 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323

2424
import traceback
2525
from collections import defaultdict
26+
from contextlib import contextmanager
2627
from dataclasses import dataclass
2728
from typing import Any, DefaultDict, Dict, List, Optional, Set, Tuple, Union
2829

@@ -172,6 +173,24 @@ def emit_init(self, op: OpNodeUnion) -> None:
172173
self._chains.append([])
173174
self._chains[self.init_chain_idx].append(Instruction(op=op))
174175

@contextmanager
def new_chain(self):
    """Open a fresh instruction chain and temporarily redirect emit() to it.

    Usage:
        with P.new_chain() as chain_idx:
            P.emit(MulNode(...))  # goes to the new chain
        # P.emit() goes back to the previous chain
    """
    # Remember where emissions were going so we can restore on exit.
    saved_chain = self._current_chain
    self._chains.append([])
    opened_idx = len(self._chains) - 1
    self._current_chain = opened_idx
    try:
        yield opened_idx
    finally:
        # Restore even if the body raised.
        self._current_chain = saved_chain
175194
def args(self, node: Node) -> Tuple[Any, ...]:
176195
return self.slot_map(node.args)
177196

@@ -629,9 +648,12 @@ def _verify_build(self):
629648
info.handler in (noop_handler, PatternHandler.deferred_handler)
630649
or n.users == {}
631650
):
632-
assert (
633-
self.slot_manager.get_slot(n) is None
634-
), f"Did not expect node {n} handled by {info.handler} to have a slot"
651+
# Deferred body nodes may or may not have slots — this is fine.
652+
# Pattern handlers absorb nodes into their body and may set
653+
# slots on them (e.g., GatedDeltaRuleHandler sets getitem[0]'s
654+
# slot to the ScanNode output). Dead nodes (no users) also
655+
# skip the slot check.
656+
pass
635657
else:
636658
assert (
637659
self.slot_manager.get_slot(n) is not None
@@ -962,6 +984,11 @@ def get_named_data_store(self) -> NamedDataStore:
962984
``ep.constants`` / ``extra_constants`` (which all use unprefixed
963985
keys). The prefix is applied at the exit boundary — the
964986
``NamedDataStore`` key — so it matches the FlatBuffer ``named_slots``.
987+
988+
To reduce peak memory, each constant is deleted from the EP
989+
immediately after its bytes are added to the NamedDataStore.
990+
This avoids holding two full copies of all constants simultaneously
991+
(important for large models where constants can be 20+ GB).
965992
"""
966993
named_data_store = NamedDataStore()
967994

@@ -971,6 +998,17 @@ def get_named_data_store(self) -> NamedDataStore:
971998
key=lambda x: self._slot_to_final_tid.get(x[1], 0),
972999
)
9731000

1001+
# Free EP constants not used by the MLX graph to reduce peak memory.
1002+
used = set(self._constant_name_to_slot.keys())
1003+
for ispec in self.ep.graph_signature.input_specs:
1004+
if ispec.arg.name in used and ispec.target is not None:
1005+
used.add(ispec.target)
1006+
1007+
for d in (self.ep._state_dict, self.ep._constants):
1008+
for name in list(d.keys()):
1009+
if name not in used and isinstance(d[name], torch.Tensor):
1010+
del d[name]
1011+
9741012
logger.debug(f"Adding {len(entries)} constants to NamedDataStore...")
9751013
for canonical_name, _slot in entries:
9761014
tensor = self._find_constant_tensor(canonical_name)
@@ -983,6 +1021,15 @@ def get_named_data_store(self) -> NamedDataStore:
9831021
data=t,
9841022
alignment=16,
9851023
)
1024+
1025+
# Free the original tensor from the EP immediately.
1026+
# The contiguous copy is now serialized as bytes in the
1027+
# NamedDataStore — the EP reference is no longer needed.
1028+
# (It would be deleted by lowered_backend_module.py after
1029+
# preprocess() returns anyway.)
1030+
self._delete_constant_tensor(canonical_name)
1031+
del tensor, t
1032+
9861033
logger.debug("Done adding constants to NamedDataStore")
9871034

9881035
return named_data_store
@@ -1011,17 +1058,33 @@ def get_mutable_buffer_names(self) -> List[str]:
10111058

10121059
def _find_constant_tensor(self, name: str) -> Optional[torch.Tensor]:
10131060
"""Find a constant tensor by name from various sources."""
1014-
if name in self.ep.state_dict:
1015-
return self.ep.state_dict[name]
1016-
if name in self.ep.constants:
1017-
return self.ep.constants[name]
1061+
result = self._resolve_constant(name)
1062+
if result is None:
1063+
return None
1064+
1065+
d, k = result
1066+
return d[k]
1067+
1068+
def _delete_constant_tensor(self, name: str) -> None:
1069+
"""Delete a constant from the EP to free memory during serialization."""
1070+
1071+
result = self._resolve_constant(name)
1072+
if result:
1073+
d, k = result
1074+
del d[k]
1075+
1076+
def _resolve_constant(self, name):
1077+
"""Returns (dict, key) or None."""
1078+
if name in self.ep._state_dict:
1079+
return self.ep._state_dict, name
1080+
if name in self.ep._constants:
1081+
return self.ep._constants, name
10181082
if name in self.extra_constants:
1019-
return self.extra_constants[name]
1020-
# Look up by target
1083+
return self.extra_constants, name
10211084
for ispec in self.ep.graph_signature.input_specs:
10221085
if ispec.arg.name == name and ispec.target is not None:
1023-
if ispec.target in self.ep.state_dict:
1024-
return self.ep.state_dict[ispec.target]
1025-
if ispec.target in self.ep.constants:
1026-
return self.ep.constants[ispec.target]
1086+
if ispec.target in self.ep._state_dict:
1087+
return self.ep._state_dict, ispec.target
1088+
if ispec.target in self.ep._constants:
1089+
return self.ep._constants, ispec.target
10271090
return None

backends/mlx/builder/slot_manager.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,12 +30,26 @@ class IdSpace(Enum):
3030
Temp = auto()
3131

3232

@dataclass(eq=False, frozen=True)
class Slot:
    """An allocated tensor or symbolic-int slot.

    Equality and hashing are identity-based rather than field-based: the
    delete-as-you-go allocator may recycle an idx, producing two Slots
    with identical (id_type, id_space, idx) that must nevertheless stay
    distinct in sets and dicts during build().
    """

    id_type: IdType
    id_space: IdSpace
    idx: Optional[int] = None

    # Identity semantics: a Slot is only ever equal to itself.
    def __eq__(self, other):
        return other is self

    def __hash__(self):
        return id(self)
3953

4054
class IdManager:
4155
def __init__(self):

backends/mlx/custom_ops.py

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -269,3 +269,117 @@ def rope_fake(
269269
) -> Tensor:
270270
"""Fake implementation for tracing."""
271271
return x.new_empty(x.shape)
272+
273+
274+
@torch.library.custom_op("mlx::gather_mm", mutates_args=())
275+
def gather_mm(
276+
a: Tensor, # [..., M, K]
277+
b: Tensor, # [E, K, N] or [..., K, N]
278+
rhs_indices: Optional[Tensor] = None, # Expert selection indices
279+
lhs_indices: Optional[Tensor] = None, # Optional LHS gather indices
280+
sorted_indices: bool = False,
281+
) -> Tensor:
282+
"""
283+
Gather matrix multiply — matches mlx::core::gather_mm semantics exactly.
284+
285+
Output shape = broadcast(lhs_indices, rhs_indices).shape + [M, N]
286+
where M = a.shape[-2], N = b.shape[-1].
287+
288+
For MoE: a=[N_tokens, 1, K], b=[E, K, out], rhs_indices=[N_tokens]
289+
→ output=[N_tokens, 1, out]. Caller squeezes dim -2.
290+
"""
291+
if rhs_indices is not None:
292+
b_sel = b[rhs_indices]
293+
else:
294+
b_sel = b
295+
return torch.matmul(a, b_sel)
296+
297+
298+
@torch.library.register_fake("mlx::gather_mm")
def gather_mm_fake(
    a: Tensor,
    b: Tensor,
    rhs_indices: Optional[Tensor] = None,
    lhs_indices: Optional[Tensor] = None,
    sorted_indices: bool = False,
) -> Tensor:
    """Fake (meta) implementation for tracing: shape inference only.

    Matches MLX: output = broadcast(lhs_indices, rhs_indices).shape + [M, N].
    """
    M = a.shape[-2]
    N = b.shape[-1]
    if rhs_indices is None and lhs_indices is None:
        batch = tuple(b.shape[:-2])
    else:
        # Bug fix: lhs_indices previously did not contribute to the traced
        # batch shape, so tracing produced wrong shapes whenever only
        # lhs_indices (or both index tensors) were supplied.
        batch = torch.broadcast_shapes(
            rhs_indices.shape if rhs_indices is not None else (),
            lhs_indices.shape if lhs_indices is not None else (),
        )
    return a.new_empty((*batch, M, N))
315+
316+
317+
@torch.library.custom_op("mlx::gather_qmm", mutates_args=())
318+
def gather_qmm(
319+
x: Tensor, # [..., M, K]
320+
w: Tensor, # [E, out, in_packed]
321+
scales: Tensor, # [E, out, in//gs]
322+
biases: Optional[Tensor] = None, # [E, out, in//gs] (affine mode)
323+
rhs_indices: Optional[Tensor] = None, # Expert selection indices
324+
lhs_indices: Optional[Tensor] = None, # Optional LHS gather indices
325+
transpose: bool = True,
326+
group_size: int = 32,
327+
bits: int = 4,
328+
mode: str = "affine",
329+
sorted_indices: bool = False,
330+
) -> Tensor:
331+
"""
332+
Gather quantized matrix multiply — matches mlx::core::gather_qmm semantics.
333+
334+
Output shape = broadcast(lhs_indices, rhs_indices).shape + [M, N]
335+
336+
For MoE: x=[N_tokens, 1, K], w=[E, out, K_packed], rhs_indices=[N_tokens]
337+
→ output=[N_tokens, 1, out]. Caller squeezes dim -2.
338+
"""
339+
# Eager fallback: gather, dequantize, matmul
340+
if rhs_indices is not None:
341+
w_sel = w[rhs_indices]
342+
s_sel = scales[rhs_indices]
343+
b_sel = biases[rhs_indices] if biases is not None else None
344+
else:
345+
w_sel = w
346+
s_sel = scales
347+
b_sel = biases
348+
349+
# Dequantize
350+
w_float = w_sel.to(x.dtype)
351+
s_expanded = s_sel.repeat_interleave(group_size, dim=-1)
352+
if b_sel is not None:
353+
b_expanded = b_sel.repeat_interleave(group_size, dim=-1)
354+
w_dequant = w_float * s_expanded + b_expanded
355+
else:
356+
w_dequant = w_float * s_expanded
357+
358+
if transpose:
359+
w_dequant = w_dequant.transpose(-1, -2)
360+
361+
return torch.matmul(x, w_dequant)
362+
363+
364+
@torch.library.register_fake("mlx::gather_qmm")
def gather_qmm_fake(
    x: Tensor,
    w: Tensor,
    scales: Tensor,
    biases: Optional[Tensor] = None,
    rhs_indices: Optional[Tensor] = None,
    lhs_indices: Optional[Tensor] = None,
    transpose: bool = True,
    group_size: int = 32,
    bits: int = 4,
    mode: str = "affine",
    sorted_indices: bool = False,
) -> Tensor:
    """Fake (meta) implementation for tracing: shape inference only.

    Matches MLX: output = broadcast(lhs_indices, rhs_indices).shape + [M, N],
    where N is w's output dim (w is [E, out, in] when transpose=True).
    """
    M = x.shape[-2]
    N = w.shape[-2] if transpose else w.shape[-1]
    if rhs_indices is None and lhs_indices is None:
        batch = tuple(w.shape[:-2])
    else:
        # Bug fix: lhs_indices previously did not contribute to the traced
        # batch shape; broadcast both index shapes per the mlx contract.
        batch = torch.broadcast_shapes(
            rhs_indices.shape if rhs_indices is not None else (),
            lhs_indices.shape if lhs_indices is not None else (),
        )
    return x.new_empty((*batch, M, N))

0 commit comments

Comments
 (0)