feat(compile): CustomKernel and GatherQMM implement output_shapes for shapeless compile#3485
Open
dexwritescode wants to merge 7 commits intoml-explore:mainfrom
Open
feat(compile): CustomKernel and GatherQMM implement output_shapes for shapeless compile#3485dexwritescode wants to merge 7 commits intoml-explore:mainfrom
dexwritescode wants to merge 7 commits intoml-explore:mainfrom
Conversation
`mx::compile(shapeless=true)` calls `Primitive::output_shapes()` on
every node when re-tracing a compiled function with changed input
shapes. `CustomKernel` never implemented this override, so any
compiled function containing a `metal_kernel` / `custom_kernel` call
would throw:
[Primitive::output_shapes] CustomKernel cannot infer output shapes
The output shapes are already provided by the caller at creation time
via `metal_kernel()(inputs, output_shapes, ...)` and passed to
`array::make_arrays`. They just weren't stored on the primitive.
Fix: add an optional `output_shapes` parameter to the `CustomKernel`
constructor (default `{}` for backward compatibility), store it in a
new `output_shapes_` member, and override `output_shapes()` to return
it. If the field is empty (legacy construction path), fall through to
the base-class throw as before.
Update both Metal and CUDA call sites to copy the shapes before
`std::move`-ing them into `array::make_arrays` and pass the copy to
the constructor.
output_shapes() is called on every primitive during shapeless=true retracing. GatherQMM was missing this override, causing compile to throw when any graph containing gather_qmm was retraced. The output shape is fully inferrable from inputs and stored fields: out_shape = lhs_indices.shape() + [x.shape(-2), w_outer_dims] where w_outer_dims = transpose ? w.shape(-2) : w.shape(-1)*32/bits. Input layout differs by mode: Affine has biases at index 3 (pushing indices to 4/5); other modes have indices at 3/4.
zcbenz
reviewed
May 6, 2026
Collaborator
zcbenz
left a comment
There was a problem hiding this comment.
The change overall looks good to me, can you add some simple tests?
Address review feedback from zcbenz: - output_shapes is a const& in the lambda parameter, so std::move(output_shapes) compiles but silently copies rather than moves. Remove the misleading std::move in both metal and cuda backends — make_arrays receives a plain copy. - Fix one extra space in the GatherQMM input layout comment to correctly align lhs_idx under the Affine layout line.
…rQMM Verify that mx.compile(shapeless=True) correctly re-traces functions containing mx.fast.metal_kernel (CustomKernel) and mx.gather_qmm (GatherQMM) when input shapes change between calls. Both tests fail before the fix with the respective 'cannot infer output shapes' error and pass after output_shapes() is implemented.
dexwritescode
commented
May 6, 2026
Author
dexwritescode
left a comment
There was a problem hiding this comment.
Added two tests to python/tests/test_compile.py:
test_shapeless_compile_custom_kernel— compiles ametal_kernelpassthrough withshapeless=True, then calls it with a larger shape. Fails before the fix withCustomKernel cannot infer output shapes.test_shapeless_compile_gather_qmm— compiles agather_qmmcall withshapeless=True, then calls it with a differentMdimension. Fails before the fix withGatherQMM cannot infer output shapes.
Also addressed the two inline comments (removed std::move on const ref in both Metal and CUDA backends, fixed comment alignment in primitives.h).
Remove the intermediate output_shapes_copy and pass output_shapes directly to the CustomKernel constructor, which takes it by value.
cklxx
added a commit
to cklxx/arle
that referenced
this pull request
May 7, 2026
…:async_eval encoder Adds INFER_CPP_PHASE_TIMING=1 stderr probes around the two C++ FFI hot paths so we can split "Rust async_kick = 23ms" into its components: - crates/mlx-sys/src/mlx_qwen35_model.cpp:2541 — `forward_build_us` around `m->forward(inputs)` (lazy graph build). - crates/mlx-sys/src/mlx_bridge.cpp:2072 — `async_eval_call_us` around the actual `mx::async_eval(arrs)` call (encoder + commit work). Cached env probe (one atomic read after first call); zero prod cost when env unset. file-static helper in each TU to avoid header churn. Bench (Qwen3.6 35B-A3B-4bit, c=4 + c=8): forward_build_us c=4 p50 = 1509μs ← lazy graph build is FAST forward_build_us c=8 p50 = 1793μs async_eval_call_us count=82 p50 = 24992μs ← here's the 25ms (count=82 = logits + new_sampled + 80 packed_kv_flat slabs) → Hypothesis confirmed (per MLX async_eval research subagent this date): mx::async_eval does graph traversal + Metal command-buffer encoding SYNCHRONOUSLY on the calling thread. Only GPU completion is async. For a 40-layer MoE forward (~600-1000 primitives at c=4-8), the ~25ms is real CPU encoder work — NOT GPU compute. Confirmed by mlx/transforms.cpp eval_impl(... async=true) which only skips the final wait, never offloads encoding. Erratum: AGENTS.md narrows the previous "MLX_MAX_OPS_PER_BUFFER=200 recommended for any Metal bench at c≥8" recommendation. That was Qwen3.5-dense-specific and benched as wash-or-loss on Qwen3.6 MoE per docs/experience/wins/2026-05-07-bench-qwen36-baseline.md. Removed from default guidance; downgraded to "per-workload matched-A/B tunable". Auto-wired-limit (default since 180e48b) is the canonical Metal serving knob. Wins entry: docs/experience/wins/2026-05-07-bench-qwen36-encode-bottleneck.md captures the localization, the implications (mx::compile blocked on ml-explore/mlx#3485, multi-thread encode blocked by #3078), and four viable next levers ranked S/M. 644 infer tests pass; clippy --features metal -- -D warnings clean. Bench: doc + instrumentation only; no hot-path behavior change when env unset (default).
lhs_indices shape (8,) cannot broadcast with the auto-generated rhs_indices arange(num_experts) shape (4,), causing the second shapeless compile call to fail during graph update. Fix by keeping both indices fixed at shape (num_experts,) and varying only the M dimension via x.shape = (num_experts, M, K).
grid=(x.size, 1, 1) is captured as a fixed tuple at trace time. On the second shapeless-compile call (x.size=8) the primitive still holds grid=(4,1,1), so only 4 of 8 output elements are written and array_equal fails. The test goal is to verify output_shapes prevents a throw and returns the correct shape — not value correctness, which would require the grid to be updated dynamically (out of scope).
Author
|
Hey @zcbenz, I've pushed a change to fix a test that failed. Would appreciate if you can trigger the CI run. Thanks! |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Problem
mx::compile(shapeless=true)callsPrimitive::output_shapes()on everynode when re-tracing a compiled function after input shapes change. Two
primitives were missing this override, causing compiled functions that
contain them to throw at runtime:
This makes it impossible to use
mx::compileon models that combine customMetal kernels with gather-quantized-matmul — for example, hybrid SSM+attention
MoE models (like Qwen3 MoE) where the SSM step uses a custom Metal kernel and
the MoE routing uses
gather_qmm.Fix
CustomKernel (
mlx/fast_primitives.h,mlx/backend/metal/custom_kernel.cpp,mlx/backend/cuda/custom_kernel.cpp)The output shapes are already provided by the caller at creation time via
metal_kernel()(inputs, output_shapes, ...)and passed toarray::make_arrays.They just weren't stored on the primitive.
Add an optional
output_shapesconstructor parameter (default{}— backwardcompatible), store in
output_shapes_member, overrideoutput_shapes()toreturn it. Falls through to the base-class throw when empty (legacy path).
GatherQMM (
mlx/primitives.h)The output shape is fully inferrable from the stored fields and input shapes:
where
w_outer_dims = transpose ? w.shape(-2) : w.shape(-1) * 32 / bits.Input layout differs by quantization mode: Affine mode has biases at index 3,
pushing lhs_indices to index 4; other modes have lhs_indices at index 3.
Testing
Verified by enabling
mx::compile(shapeless=true)on a 94-layer hybridSSM+attention MoE model (Qwen3.6-35B-A3B-4bit) where the GatedDeltaNet SSM
step uses a custom Metal kernel and the MoE routing uses
gather_qmm.Previously crashed on re-trace; with this fix the compiled graph is reused
correctly across decode steps.