Good First Issue: Skip MoE expert sorting during decode (MLX backend, Qwen 3.5 MoE)
Area: backends/mlx · MoE lowering
Skill level: Intermediate (PyTorch + MLX builder; no Metal kernels touched)
Impact: Faster decode for Qwen 3.5 MoE on Apple Silicon. Prefill and numerics unchanged.
Problem
The MLX MoE sorts tokens by expert so each expert's tokens are contiguous for a grouped matmul. This helps prefill (many tokens to batch) but is pure overhead during decode (1 token → nothing to batch): the argsort, activation gather, and output scatter all run for nothing.
The sort is decided by a static export-time flag (sort_experts), so a single dynamic-shape .pte runs it in both phases.
Fix: decide sort vs. no-sort at runtime, based on token count M, in the backend lowering — using the existing emit_if_else helper.
Where the code is
| What | Where |
|---|---|
| MoE forward + the sort | backends/mlx/llm/switch.py → SwitchMLP.forward (~L220) |
| Per-expert grouped linear | backends/mlx/llm/switch.py → SwitchLinear (~L130) |
| Grouped-matmul ops | backends/mlx/custom_ops.py → gather_mm (~L282), gather_qmm (~L325) |
| Grouped-matmul handlers (carry sorted_indices → GatherMmNode/GatherQmmNode) | backends/mlx/ops.py → _gather_mm_handler (~L1540), _gather_qmm_handler (~L1568) |
| Branch helpers | backends/mlx/builder/op_helpers.py → emit_if_else (~L355), emit_sub_int (~L307), emit_product (~L232), emit_shape (~L177) |
| Pattern to copy | backends/mlx/custom_kernel_ops/gguf/q4k/linear.py → emit_linear/_emit_linear_fused/_emit_q4k_matmul/_emit_q4k_matvec |
| Temp-slot scoping | backends/mlx/builder/program_builder.py → tmp_scope |
| Export wiring of sort_experts | examples/models/qwen3_5_moe/export.py (~L33), .../mlx_source_transformations.py → _sparse_moe_forward (~L55), _swap_sparse_moe (~L328) |
| CI E2E gate | .github/workflows/mlx.yml → test-mlx-qwen35-moe (~L108) |
Line numbers are approximate — search by symbol if they have drifted.
The sort today (SwitchMLP.forward, ~L243):
if sort_experts:
    flat_indices = expert_indices.flatten()
    order = flat_indices.argsort().to(torch.int32)
    inv_order = order.argsort().to(torch.int32)
    sorted_idx = flat_indices[order].to(torch.int32)
    x_sorted = x[(order // top_k).to(torch.int64)]
...
# 2-3 gather_mm/gather_qmm calls (sorted_indices=sort_experts), activation, combine
if sort_experts:
    down = down[inv_order].reshape(N, top_k, -1)  # scatter back
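For readers new to this permutation trick, here is a minimal pure-Python sketch of what the block above computes. It uses plain lists instead of torch tensors (an assumption for portability), and argsort and sort_round_trip are hypothetical helper names. It shows that gathering by order // top_k and scattering back with inv_order = argsort(order) restores the original (token, slot) layout, and that for a single decode token the sort only reorders that token's own top_k slots.

```python
from typing import List, Sequence, Tuple


def argsort(values: Sequence[int]) -> List[int]:
    """Stable argsort: indices that put `values` in ascending order."""
    return sorted(range(len(values)), key=lambda i: values[i])


def sort_round_trip(
    expert_indices: List[List[int]], x: List[str]
) -> Tuple[List[int], List[Tuple[str, int]]]:
    """Hypothetical mirror of the sort block: (sorted expert ids, outputs scattered back)."""
    top_k = len(expert_indices[0])
    flat = [e for row in expert_indices for e in row]  # flatten(): N * top_k slots
    order = argsort(flat)                              # slots grouped by expert id
    inv_order = argsort(order)                         # argsort of a permutation is its inverse
    sorted_idx = [flat[i] for i in order]              # non-decreasing expert ids
    x_sorted = [x[i // top_k] for i in order]          # token that feeds each sorted slot
    # Stand-in for the expert compute: tag each sorted slot with (token, expert).
    out_sorted = list(zip(x_sorted, sorted_idx))
    # Scatter back: unsorted slot j lives at sorted position inv_order[j].
    out = [out_sorted[inv_order[j]] for j in range(len(flat))]
    return sorted_idx, out


# Prefill-like: 3 tokens, top_k = 2.
print(sort_round_trip([[2, 0], [1, 2], [0, 1]], ["t0", "t1", "t2"]))
# Decode: 1 token, so the "sort" merely reorders that token's top_k slots.
print(sort_round_trip([[3, 1]], ["t0"]))
```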
The pattern to copy (q4k/linear.py, ~L457):
m_iov = emit_product(P, emit_shape(P, x_node, x_slot, end_dim=-1)) # M = token count
emit_if_else(
    P,
    emit_sub_int(P, m_iov, IntOrVid.from_literal(1)),  # cond = M - 1 (M == 1 → else)
    emit_then=lambda: _emit_q4k_matmul(...),  # M > 1 → prefill
    emit_else=lambda: _emit_q4k_matvec(...),  # M == 1 → decode
)
emit_if_else emits an IfNode when M is dynamic, and folds to a single path (zero overhead) when M is a compile-time literal. Both branches write the same output slot.
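The folding rule can be illustrated with a toy model. This is a sketch only: toy_token_count, toy_sub_int, toy_emit_if_else and lower_moe are hypothetical stand-ins, not the op_helpers API, and op names are plain strings. A token count is an int when it is a compile-time literal and a symbolic string when it is dynamic; a nonzero condition selects the then branch, matching cond = M - 1 above.

```python
from math import prod
from typing import Callable, List, Optional, Sequence, Union

# Hypothetical stand-ins for the builder helpers; NOT the real op_helpers API.
# An int is a compile-time literal; a str is a symbolic (runtime) value.
IntOrSym = Union[int, str]


def toy_token_count(shape: Sequence[Optional[int]]) -> IntOrSym:
    """M = product of every dim except the last (hidden) one; None marks a dynamic dim."""
    lead = shape[:-1]
    if any(d is None for d in lead):
        return "M"
    return prod(lead)


def toy_sub_int(a: IntOrSym, b: int) -> IntOrSym:
    return a - b if isinstance(a, int) else f"({a} - {b})"


def toy_emit_if_else(ops: List, cond: IntOrSym,
                     emit_then: Callable[[List], None],
                     emit_else: Callable[[List], None]) -> None:
    if isinstance(cond, int):
        # Literal condition: fold to a single path; no IfNode is emitted.
        (emit_then if cond != 0 else emit_else)(ops)
        return
    # Dynamic condition: both branches live under one IfNode (nonzero -> then).
    then_ops: List = []
    else_ops: List = []
    emit_then(then_ops)
    emit_else(else_ops)
    ops.append(("IfNode", cond, then_ops, else_ops))


def lower_moe(shape: Sequence[Optional[int]]) -> List:
    ops: List = []
    m = toy_token_count(shape)
    toy_emit_if_else(
        ops,
        toy_sub_int(m, 1),  # cond = M - 1, so M == 1 takes the else branch
        emit_then=lambda o: o.extend(["argsort", "gather", "gather_qmm(sorted=True)", "scatter"]),
        emit_else=lambda o: o.append("gather_qmm(sorted=False)"),
    )
    return ops


print(lower_moe((1, 1, 64)))     # static M == 1: folded decode path only
print(lower_moe((1, None, 64)))  # dynamic M: a single IfNode holding both paths
```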
The change
Introduce a single fused MoE op so the lowering has one node to branch on, then branch it on M.
1. New custom op — backends/mlx/custom_ops.py:
mlx::moe_experts(x, expert_weights, expert_indices, top_k, <stacked gate/up/down weights (+ scale/zero_point if quantized)>) -> Tensor. Its reference (eager) implementation is the unsorted experts MLP — grouped projections (gather_mm/gather_qmm with sorted_indices=False) → activation → weighted combine. Add a register_fake for shape/dtype. Sorting is a layout-only optimization with identical numerics, so the reference never sorts.
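To make the reference semantics concrete, here is a minimal pure-Python sketch of an unsorted experts MLP. It assumes a SwiGLU activation (silu(gate) * up), unquantized per-expert weights stored as nested lists, and the hypothetical name moe_experts_reference; the real op works on stacked (and possibly quantized) tensors through gather_mm/gather_qmm. Under these assumptions the output has the same shape as x, which is what the fake implementation would report.

```python
import math
from typing import List

Vec = List[float]
Matrix = List[Vec]  # rows x cols


def matvec(w: Matrix, v: Vec) -> Vec:
    return [sum(wi * vi for wi, vi in zip(row, v)) for row in w]


def silu(z: float) -> float:
    return z / (1.0 + math.exp(-z))


def moe_experts_reference(
    x: List[Vec],                       # N tokens x D
    expert_weights: List[List[float]],  # N x top_k routing weights
    expert_indices: List[List[int]],    # N x top_k expert ids
    gate_w: List[Matrix],               # per expert: I x D
    up_w: List[Matrix],                 # per expert: I x D
    down_w: List[Matrix],               # per expert: D x I
) -> List[Vec]:
    """Hypothetical unsorted experts MLP: each (token, k) slot reads its expert's weights directly."""
    out: List[Vec] = []
    for row, ws, es in zip(x, expert_weights, expert_indices):
        acc = [0.0] * len(row)
        for w, e in zip(ws, es):
            gate = matvec(gate_w[e], row)
            up = matvec(up_w[e], row)
            hidden = [silu(g) * u for g, u in zip(gate, up)]  # assumed SwiGLU activation
            down = matvec(down_w[e], hidden)
            acc = [a + w * d for a, d in zip(acc, down)]      # weighted combine over top_k
        out.append(acc)
    return out  # same shape as x (N x D)
```

Because each slot indexes its expert directly, no ordering of the slots is assumed, which is why the reference never needs to sort.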
2. SwitchMLP.forward — call the op; delete the sort. Replace the body with: collect the stacked weights from the SwitchLinear children and call torch.ops.mlx.moe_experts(...). Delete the sort_experts parameter and the argsort/gather/inv_order/scatter block entirely, plus the sort_experts plumbing in _sparse_moe_forward, _swap_sparse_moe, and export.py.
3. SwitchLinear — unchanged. Its pack() still produces the stacked per-expert weight tensors the op consumes. (SwitchLinear.forward is no longer on the execution path.)
4. Handler _moe_experts_handler — backends/mlx/ops.py:
m_iov = emit_product(P, emit_shape(P, x_node, x_slot, end_dim=-1))
out = P.make_or_get_slot(n)
emit_if_else(
    P,
    emit_sub_int(P, m_iov, IntOrVid.from_literal(1)),
    emit_then=lambda: _emit_moe_sorted(...),  # M > 1 → prefill
    emit_else=lambda: _emit_moe_unsorted(...),  # M == 1 → decode
)
5. Two emit paths (both write the same out slot; wrap their temporaries in a with P.tmp_scope(): block):
- _emit_moe_unsorted (decode): GatherMmNode/GatherQmmNode(sorted_indices=False) on the original activations + expert_indices, then the weighted combine. No argsort/gather/scatter.
- _emit_moe_sorted (prefill): argsort → activation gather → GatherMmNode/GatherQmmNode(sorted_indices=True) → inverse-permutation scatter → combine.
sorted_indices stays on gather_mm/gather_qmm and GatherMmNode/GatherQmmNode; it is the per-branch knob (True in _emit_moe_sorted, False in _emit_moe_unsorted).
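The claim that the two paths differ only in layout can be checked with a small pure-Python sketch. Experts are modeled as a hypothetical expert_fn(e, rows) that processes a contiguous batch of rows, standing in for one grouped-matmul segment; argsort, combine, unsorted_path and sorted_path are illustrative names, not functions in the codebase. The sorted path makes one call per expert segment, which is what pays off at prefill; with a single decode token each segment holds one row, so the sort buys nothing.

```python
from itertools import groupby
from typing import Callable, List, Sequence

Vec = List[float]
ExpertFn = Callable[[int, List[Vec]], List[Vec]]  # (expert id, rows) -> output rows


def argsort(values: Sequence[int]) -> List[int]:
    return sorted(range(len(values)), key=lambda i: values[i])


def combine(x: List[Vec], weights: List[List[float]], ys: List[Vec]) -> List[Vec]:
    """Weighted sum over each token's top_k slot outputs (slot j = token * top_k + k)."""
    top_k = len(weights[0])
    out = []
    for t, ws in enumerate(weights):
        acc = [0.0] * len(x[t])
        for k, w in enumerate(ws):
            acc = [a + w * v for a, v in zip(acc, ys[t * top_k + k])]
        out.append(acc)
    return out


def unsorted_path(x, expert_indices, expert_weights, expert_fn: ExpertFn):
    """Decode-style: run each slot's expert on its own token, no permutation."""
    ys = [expert_fn(e, [row])[0] for row, es in zip(x, expert_indices) for e in es]
    return combine(x, expert_weights, ys)


def sorted_path(x, expert_indices, expert_weights, expert_fn: ExpertFn):
    """Prefill-style: sort slots by expert, one batched call per segment, scatter back."""
    top_k = len(expert_indices[0])
    flat = [e for es in expert_indices for e in es]
    order = argsort(flat)
    inv_order = argsort(order)
    x_sorted = [x[i // top_k] for i in order]
    y_sorted: List[Vec] = []
    pos = 0
    for e, seg in groupby(flat[i] for i in order):  # each expert's slots are now contiguous
        n = len(list(seg))
        y_sorted.extend(expert_fn(e, x_sorted[pos:pos + n]))
        pos += n
    ys = [y_sorted[inv_order[j]] for j in range(len(flat))]  # inverse-permutation scatter
    return combine(x, expert_weights, ys)


def scale_by_id(e: int, rows: List[Vec]) -> List[Vec]:
    """Toy expert: scale every row by (expert id + 1)."""
    return [[(e + 1) * v for v in r] for r in rows]


x = [[1.0, 2.0], [3.0, -1.0], [0.5, 4.0]]
idx = [[2, 0], [1, 2], [0, 1]]
w = [[0.75, 0.25], [0.5, 0.5], [0.25, 0.75]]
print(sorted_path(x, idx, w, scale_by_id) == unsorted_path(x, idx, w, scale_by_id))  # prints True
```

Both paths feed identical per-slot outputs into the same combine, in the same order, so the results match exactly rather than just within a tolerance.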
Acceptance criteria
- SwitchMLP.forward calls mlx::moe_experts (whose reference implementation is the unsorted experts MLP); the sort_experts flag and the source-level sort are removed, including from _sparse_moe_forward, _swap_sparse_moe, and export.py.
- Decode (M == 1) emits no argsort/gather/scatter and runs the gather with sorted_indices=False.
- Prefill (M > 1) is unchanged (sorted path, sorted_indices=True).
- When M is statically known, no IfNode is emitted (the helper folds the branch).
Testing
- Unit tests: backends/mlx/test/test_ops.py (see GatherMmTest ~L6465 and GatherQmmTest ~L6532 for the existing pattern; each already has a batch_size=1 decode config). Add a moe_experts op test asserting that the lowering matches the unsorted reference for both batch_size=1 (decode) and batch_size>1 (prefill), which exercises both emit_if_else branches.
- E2E CI gate: .github/workflows/mlx.yml → test-mlx-qwen35-moe:
python -m executorch.examples.models.qwen3_5_moe.export \
  --tiny-test --backend mlx --qlinear 4w --qlinear-group-size 32 \
  --output-dir /tmp/qwen35_moe_mlx_tiny
python -m executorch.examples.models.qwen3_5_moe.run \
  --pte /tmp/qwen35_moe_mlx_tiny/model.pte --prompt-len 4 --max-new-tokens 5
The run must still produce exactly Generated token ids: [167, 94, 253, 88, 227]. This covers prefill and decode, i.e. both branches.
- Perf check (optional): report decode_token_per_sec before and after.