# Good First Issue: In-Model Sampling Head for the MLX Backend (Gumbel Sampling) #20353

@metascroy

Summary

Add a new backends/mlx/llm/sampling.py utility that wraps a model
returning logits and performs token sampling inside the model (on
device), so the exported .pte returns a sampled token id instead of a
full [B, S, vocab] logits tensor. Greedy (argmax) sampling already
works end-to-end on the MLX backend today; this issue adds stochastic
sampling with temperature via the Gumbel-max trick, which requires
exactly one new MLX node: RandomBitsNode, a random source.

Moving sampling on-device avoids copying the (often large) logits tensor
back to the host every decode step and removes the host-side
softmax + multinomial from the hot loop.

The new sampling op should follow the established MLX custom-op pattern in
backends/mlx/custom_kernel_ops/gated_delta_rule.py — a torch.library
custom op + register_fake, plus an op handler (or PatternHandler)
registered into the MLX program builder. The Gumbel-max trick is
almost entirely expressible with nodes that already exist in the
schema (ArgmaxNode, LogNode, NegNode, DivideNode, AddNode); the
only new schema node needed is RandomBitsNode, which maps directly onto
mlx::core::random (already vendored — see
backends/mlx/third-party/mlx/mlx/random.h).

Feasibility: HIGH

The backend already has every piece except a random source:

| Capability | Status | Where |
|---|---|---|
| Greedy / argmax sampling | Already works | ArgmaxNode (schema.fbs:482), _argmax_handler (ops.py:3410), exec_argmax (MLXInterpreter.h:1693) |
| Temperature scaling | Already works | DivideNode / MultiplyNode (schema.fbs:198,204) |
| Softmax / log-softmax | Already works | SoftmaxNode (schema.fbs:364), _softmax_handler (ops.py:1195) |
| top-k building blocks | Already works | SortNode, ArgsortNode, PartitionNode, ArgPartitionNode (schema.fbs:668-692) |
| log, neg (for Gumbel noise) | Already works | LogNode (schema.fbs:359), NegNode (schema.fbs:830) |
| Random uniform / bits | MISSING (this issue) | C++ side already exists: mlx::core::random::uniform/bits/gumbel/categorical in backends/mlx/third-party/mlx/mlx/random.h |
Because greedy sampling already lowers cleanly, the argmax-only
variant of the wrapper is a no-new-op change and can land first (the
incremental first step in Part 1). The new node is only needed for the
stochastic path.

Why the Gumbel-max trick

Sampling a token from softmax(logits / T) is equivalent to

token = argmax( logits / T + g ),    g_i = -log(-log(u_i)),   u_i ~ Uniform(0, 1)

This is attractive for this backend because:

  • It reuses the existing, already-lowered ArgmaxNode for the actual
    selection — no new reduction kernel is required.
  • g = -log(-log(u)) is just NegNode(LogNode(NegNode(LogNode(u)))),
    which uses only existing nodes.
  • The only new primitive needed is u (a uniform sample), which itself
    bottoms out at a single random op — see below.
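
The equivalence above can be checked empirically. The following is a minimal pure-Python sketch, not backend code: the helper names gumbel_max_sample and softmax, the example logits, the temperature, and the sample count are all illustrative assumptions.

```python
import math
import random


def gumbel_max_sample(logits, temperature, rng):
    """Return argmax_i(logits[i] / temperature + g_i) with g_i = -log(-log(u_i))."""
    best_index, best_score = 0, -math.inf
    for i, logit in enumerate(logits):
        u = rng.random()  # u in [0, 1)
        # u == 0 would give log(0); treat its noise as -inf so it never wins.
        g = -math.log(-math.log(u)) if u > 0.0 else -math.inf
        score = logit / temperature + g
        if score > best_score:
            best_index, best_score = i, score
    return best_index


def softmax(logits, temperature):
    """Numerically stable softmax(logits / temperature)."""
    scaled = [x / temperature for x in logits]
    m = max(scaled)
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


logits = [2.0, 1.0, 0.0]   # illustrative values
temperature = 0.8
rng = random.Random(0)     # fixed seed so the check is repeatable
n = 20000

counts = [0] * len(logits)
for _ in range(n):
    counts[gumbel_max_sample(logits, temperature, rng)] += 1

empirical = [c / n for c in counts]
expected = softmax(logits, temperature)
max_abs_error = max(abs(e - p) for e, p in zip(empirical, expected))
```

With enough draws, the empirical frequencies should sit close to the softmax probabilities; the only random ingredient is the uniform sample u.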

How MLX layers its RNG (the only true primitive is RandomBits)

MLX builds its entire RNG stack as pure compositions over one primitive
(backends/mlx/third-party/mlx/mlx/random.cpp):

RandomBits            <- the ONLY true primitive (lazy op, primitives.h:1715)
  bits(shape, width, key)     array backed by RandomBits + a uint32 key
    uniform()  = bits / uint32_max -> minimum(upper) -> astype -> range*x + lo   (random.cpp:95)
      gumbel() = negative(log(negative(log(uniform()))))                         (random.cpp:367)

Every op in that chain already exists as a schema node — DivideNode,
MinimumNode, AsTypeNode, MultiplyNode, AddNode, SubtractNode,
LogNode, NegNode, ArgmaxNode — except RandomBits. So this issue
adds exactly one new node, RandomBitsNode, and composes
uniform → gumbel → argmax from existing nodes, exactly the way MLX does.
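
The layering above (bits, then uniform, then gumbel) can be sketched in plain Python. This is only an illustration of the composition: random_bits here is a hypothetical stand-in built on Python's random module rather than MLX's threefry, and the clamp uses the float64 nextafter instead of the float32 value the backend would use.

```python
import math
import random

UINT32_MAX = 4294967295.0
# Largest float strictly below 1.0; plays the role of the nextafter(1, 0) clamp.
UPPER = math.nextafter(1.0, 0.0)


def random_bits(count, rng):
    """Stand-in for the RandomBits primitive: `count` raw uint32 values."""
    return [rng.getrandbits(32) for _ in range(count)]


def uniform_from_bits(bits):
    """bits / uint32_max, clamped so the result stays strictly below 1."""
    return [min(b / UINT32_MAX, UPPER) for b in bits]


def gumbel_from_uniform(us):
    """g = -log(-log(u)); u == 0 maps to -inf, which argmax never selects."""
    return [-math.log(-math.log(u)) if u > 0.0 else -math.inf for u in us]


rng = random.Random(1)  # illustrative seed
bits = random_bits(8, rng)
noise = gumbel_from_uniform(uniform_from_bits(bits))
```

The clamp matters: without it, an all-ones bit pattern would give u = 1, and log(-log(1)) is undefined.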

Existing Building Blocks

  • Custom-op + builder-handler pattern (canonical):
    backends/mlx/custom_kernel_ops/gated_delta_rule.py, which uses
    torch.library.custom_op + register_fake + PatternHandler +
    REGISTRY.register_pattern. Its test
    backends/mlx/custom_kernel_ops/test/test_gated_delta_rule.py is
    the template for the new op's test.
  • Simple 1:1 op handler pattern: _argmax_handler
    (backends/mlx/ops.py:3410) and _softmax_handler
    (backends/mlx/ops.py:1195) show how to register a handler against a
    torch op target via @REGISTRY.register(target=[...]) and emit schema
    nodes. The new random op handler follows this shape.
  • Schema + codegen workflow: backends/mlx/serialization/schema.fbs
    is the source of truth. After editing it, run
    python backends/mlx/serialization/generate.py, which regenerates:
    • serialization/_generated/ (flatbuffer Python bindings),
    • serialization/mlx_graph_schema.py (Python dataclasses),
    • serialization/_generated_serializers.py,
    • _generated_inspector.py,
    • runtime/schema_generated.h,
    • partial runtime/MLXLoader.h (struct + OpCode enum) and
      runtime/MLXLoader.cpp (deserialize switch).
  • Runtime execution dispatch: backends/mlx/runtime/MLXInterpreter.h
    — each node has an exec_<node>() (e.g. exec_argmax, line 1693) and
    is wired into the eval switch (around line 2058). The new node's
    exec_* calls into mlx::core::random.
  • MLX RNG (already vendored):
    backends/mlx/third-party/mlx/mlx/random.h exposes random::key(seed),
    random::split, random::bits, random::uniform, random::gumbel,
    and random::categorical.
  • Where the wrapper plugs in: model source transforms live in
    backends/mlx/llm/source_transformation.py (e.g.
    replace_et_kv_cache_with_mlx, transform_attention_mha_to_mlx). The
    sampling head is applied as the outermost wrapper around the exported
    module, analogous to those transforms.

Design Decision: deterministic, export-safe RNG

Random ops are stateful, which is a problem for two reasons: torch.export
requires a functional graph, and tests/users need reproducibility. MLX's
RNG is counter/key-based (stateless), which solves both — the runtime
derives a key from a seed rather than mutating global state.

Recommended approach: thread an explicit seed (per decode step) as an
op input. This mirrors how input_pos / start_pos is already threaded
through backends/mlx/llm/cache.py::KVCache.update. The exported graph
stays pure; the runtime computes random::key(seed) →
random::bits(...). Callers pass seed = base_seed + step so each
decode step draws fresh noise while remaining fully reproducible.

This avoids needing a mutable RNG-state buffer (the heavier alternative,
which would require buffer-mutation plumbing like the KV cache).

Optional seed: seeded vs. unseeded is an export-time choice

seed is optional (Optional[Tensor], default None). Whether the
exported program is seeded or not is fixed at export time by whether a
seed tensor is supplied, because torch.export specializes on structure
(None ⇒ no seed input wired in; a tensor ⇒ seed is a graph input). It is
not a per-call runtime toggle within one .pte.

  • Seeded export (seed tensor passed at export): RandomBitsNode
    carries the seed Vid; the runtime uses an explicit
    random::key(seed). Deterministic, reproducible, host-controlled per
    step. Use this for tests and reproducible generation.
  • Unseeded export (seed=None at export): RandomBitsNode is emitted
    with no seed field, so the runtime passes std::nullopt to
    random::bits, which falls back to MLX's global
    KeySequence::default_().next() (random.cpp:41). Convenient "just
    give me randomness," but it reintroduces process-global mutable RNG
    state: the result is non-reproducible and no longer a pure graph
    (with thread-safety/determinism caveats). Acceptable for casual use,
    not for tests.

The seeding-semantics note below applies to the seeded export.

Seeding semantics (important for correct generation)

There is no hidden RNG state in the .pte. RandomBitsNode is pure:
random::key(seed) is deterministic, so the same seed always yields the
same draw. The host loop therefore fully controls cross-call randomness
via the seed input:

  • Within one call: one bits() draw fills the entire [B, vocab]
    noise array. MLX's threefry is counter-based, so each element uses a
    distinct counter off the same key — the gumbel noise across the vocab is
    internally decorrelated even though it comes from a single seed.
  • Across calls — must advance the seed. Passing the same seed
    every decode step reuses identical gumbel noise each step (tokens still
    differ because logits differ, but the noise is perfectly correlated,
    which biases generation and breaks the i.i.d. sampling assumption).
    Passing a distinct seed per step (seed = base_seed + step) draws
    fresh, independent noise each step. Counter-based threefry decorrelates
    consecutive integer seeds well, so base + step is sufficient — no
    random::split / key-state threading (the stateful approach we
    deliberately avoid) is needed.

Net effect: same seed across two separate generations → reproducible
output (useful for tests); same seed across steps within one generation
→ the failure mode to avoid. The host advances seed every step.
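
A host-side loop following these seeding rules might look like the sketch below. It is a toy model under stated assumptions: keyed_gumbel, sample_step, and generate are hypothetical names, and Python's random.Random(seed) stands in for random::key(seed) purely to show the "same seed, same draw" property.

```python
import math
import random


def keyed_gumbel(seed, size):
    """Stand-in for a keyed draw: the same seed always yields the same noise."""
    rng = random.Random(seed)
    noise = []
    for _ in range(size):
        u = rng.random()
        noise.append(-math.log(-math.log(u)) if u > 0.0 else -math.inf)
    return noise


def sample_step(logits, temperature, seed):
    """One decode step: Gumbel-max over a single [vocab] row of logits."""
    noise = keyed_gumbel(seed, len(logits))
    scores = [x / temperature + g for x, g in zip(logits, noise)]
    return max(range(len(scores)), key=scores.__getitem__)


def generate(step_logits, temperature, base_seed):
    """The host advances the seed every step: seed = base_seed + step."""
    return [
        sample_step(logits, temperature, base_seed + step)
        for step, logits in enumerate(step_logits)
    ]


# Illustrative per-step logits (4 steps, vocab of 6).
step_logits = [[0.1 * ((i * 7 + j * 3) % 5) for j in range(6)] for i in range(4)]

run_a = generate(step_logits, 1.0, base_seed=1234)
run_b = generate(step_logits, 1.0, base_seed=1234)  # same base seed, same tokens
```

Two generations with the same base_seed agree token for token, while each step inside one generation sees different noise because its seed differs.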

Proposed Design

Land in this order; each part is independently reviewable and testable.

Part 1: SamplingHead wrapper (runtime temperature + seed)

New file backends/mlx/llm/sampling.py. temperature is a runtime input
to forward (a graph input, settable per-call without re-exporting);
seed is optional and selects the seeded vs. unseeded export (see
"Optional seed" above):

import torch
import torch.nn as nn

class SamplingHead(nn.Module):
    """
    Wraps a model that returns logits and samples a token id on-device.

        forward(*model_args, temperature, seed=None, **model_kwargs) -> token_id

      temperature: scalar float tensor, e.g. torch.tensor(0.8)
      seed:        scalar int tensor (seeded) or None (unseeded export)
    """
    def __init__(self, model: nn.Module):
        super().__init__()
        self.model = model

    def forward(self, *args, temperature, seed=None, **kwargs):
        logits = self.model(*args, **kwargs)      # [B, S, vocab]
        last = logits[:, -1, :]                   # [B, vocab]
        return torch.ops.mlx.sample(last, temperature, seed)  # Part 2

Greedy decoding is temperature → 0: as T shrinks, logits / T
dominates the Gumbel noise and argmax(logits/T + g) → argmax(logits).
Pass a small temperature for near-greedy sampling. (Exact
runtime-selectable greedy, i.e. a true argmax branch when
temperature == 0, could use the schema's IfNode, but that is optional
and out of scope for the first landing.)
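
The near-greedy behaviour at small temperature can be illustrated with a short pure-Python sketch. The sample helper, the logits, and the two temperatures are illustrative assumptions, not part of the proposed op.

```python
import math
import random


def sample(logits, temperature, rng):
    """Gumbel-max draw: argmax(logits / temperature + g)."""
    scores = []
    for logit in logits:
        u = rng.random()
        g = -math.log(-math.log(u)) if u > 0.0 else -math.inf
        scores.append(logit / temperature + g)
    return max(range(len(scores)), key=scores.__getitem__)


logits = [0.3, 2.1, -1.0, 1.9]  # argmax is index 1, with a close runner-up
rng = random.Random(7)

# At T = 1e-4 the gap between the top two scaled logits is 0.2 / 1e-4 = 2000,
# far larger than any Gumbel noise value, so every draw is the argmax.
cold = {sample(logits, 1e-4, rng) for _ in range(1000)}

# At a large temperature the noise dominates and several tokens appear.
hot = {sample(logits, 5.0, rng) for _ in range(1000)}
```

Here cold collapses to the single greedy index, while hot contains more than one token id, which is the same contrast the runtime-temperature test in Part 4 looks for.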

Incremental first step: a greedy-only variant (fixed
torch.argmax(last, dim=-1), no temperature/seed inputs) needs no
new op and already exports + lowers through MLXPartitioner today. Land
it first as a smoke test of the wrapper + partitioner path, then add the
runtime-temperature op below.

Part 2: mlx::sample custom op + reference implementation

Add the op next to the other custom ops (e.g.
backends/mlx/custom_kernel_ops/sample.py), following
gated_delta_rule.py. temperature is a tensor input (runtime-settable);
seed is an optional tensor input — None selects the unseeded
(global-RNG) export (see "Optional seed" above):

from typing import Optional

import torch
from torch import Tensor

@torch.library.custom_op("mlx::sample", mutates_args=())
def sample(
    logits: Tensor, temperature: Tensor, seed: Optional[Tensor] = None
) -> Tensor:
    """
    Gumbel-max sampling from softmax(logits / temperature).
    logits:      [B, vocab]
    temperature: scalar float tensor    (runtime input)
    seed:        scalar int tensor or None
                 - tensor -> deterministic, keyed RNG (random::key(seed))
                 - None   -> MLX global KeySequence (non-deterministic)
    -> token_id: [B] int64
    Reference (CPU) implementation for export + numerical parity.
    """
    if seed is None:
        u = torch.rand(logits.shape)                       # global RNG
    else:
        gen = torch.Generator().manual_seed(int(seed.item()))
        u = torch.rand(logits.shape, generator=gen)
    gumbel = -torch.log(-torch.log(u))
    return torch.argmax(logits / temperature + gumbel, dim=-1)

@torch.library.register_fake("mlx::sample")
def sample_fake(logits, temperature, seed=None):
    return logits.new_empty(logits.shape[:-1], dtype=torch.long)

The reference impl makes the numerical contract testable on CPU before
any runtime work lands.

Part 3: new random schema node + handler + runtime

  1. Schema (backends/mlx/serialization/schema.fbs): add a table and
    append it to the OpNode union (append-only — see the BC rules at
    the top of the file). The new node mirrors MLX's only true random op,
    RandomBits:

    table RandomBitsNode {
        out: Tid (required);
        shape: [IntOrVid] (required);
        seed: IntOrVid;              // OPTIONAL: present -> random::key(seed);
                                     //           absent  -> MLX global KeySequence
        width: int32 = 4;            // bytes per element (4 -> uint32)
    }
    

    seed is not required: when the op is exported with seed=None
    the handler emits the node with no seed field, and the runtime falls
    back to MLX's global default key.

  2. Codegen: run python backends/mlx/serialization/generate.py to
    regenerate the Python dataclass, serializers, inspector, and the
    partial C++ loader bits.

  3. Op handler (backends/mlx/ops.py): register a handler for
    torch.ops.mlx.sample that emits the full Gumbel-max graph from the
    new RandomBitsNode plus existing nodes — i.e. reproduce MLX's
    uniform → gumbel → argmax layering in the IR:

    • uniform (random.cpp:95): RandomBitsNode (uint32) →
      AsTypeNode(float32) → DivideNode by uint32_max
      (a FullNode constant 4294967295.0) → MinimumNode with
      nextafter(1, 0) (a FullNode) → AsTypeNode(target dtype).
    • gumbel (random.cpp:367): LogNode → NegNode → LogNode → NegNode.
    • sample: divide logits by the runtime temperature tensor
      (DivideNode — temperature is a graph input, so it stays
      runtime-settable) → AddNode (gumbel noise) → ArgmaxNode(axis=-1).
      If a seed input is present, read it via .item() to a SymInt and
      thread it into RandomBitsNode.seed (same pattern as input_pos in
      cache.py); if seed is None, emit RandomBitsNode with the seed
      field unset.

    Model the handler shape on _argmax_handler (ops.py:3410).
    Alternatively register it as a PatternHandler like
    GatedDeltaRuleHandler if matching the auto_functionalized wrapper
    is cleaner.

  4. Runtime (backends/mlx/runtime/MLXInterpreter.h): add
    exec_random_bits(...) that calls
    mlx::core::random::bits(shape, n.width, key, s), where key is
    random::key(n.seed) when the seed field is present and std::nullopt
    when absent (the latter making MLX use its global KeySequence — see
    random.cpp:41). Wire it into the eval switch (next to exec_argmax,
    ~line 2058). This is the only new exec_* — every other op in the
    chain already has one.

Part 4: Tests

Add backends/mlx/custom_kernel_ops/test/test_sample.py, modeled on
test_gated_delta_rule.py:

  • Greedy parity: with a small temperature (e.g. 1e-4),
    mlx::sample matches torch.argmax(logits, dim=-1) on random logits.
  • Runtime temperature: a single exported program produces near-greedy
    outputs at small temperature and higher-entropy outputs at large
    temperature, without re-exporting (confirms temperature is a live
    graph input, not a baked constant).
  • Op parity: mlx::sample reference vs. an independent Gumbel-max /
    multinomial reference — same seed → identical token; over many seeds,
    empirical token frequencies match softmax(logits/T) within tolerance
    (chi-square / TV distance).
  • Determinism (seeded export): same seed → identical token across
    runs; different seed → different draws.
  • Unseeded export: exporting with seed=None lowers cleanly
    (RandomBitsNode emitted with no seed field) and produces valid,
    varying tokens across runs (no reproducibility expected).
  • torch.export: SamplingHead exports with strict=True and lowers
    through MLXPartitioner (the new op is assigned to the MLX delegate,
    not left in the host program), with temperature and seed as graph
    inputs.
  • (On Apple Silicon) end-to-end: a tiny model + SamplingHead
    exports, runs via the MLX runtime, and produces a valid token id in
    [0, vocab).
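
For the distribution check in the op-parity test, the two statistics mentioned above can be computed with small helpers like these. This is a sketch only: total_variation and chi_square are hypothetical helper names, the counts are made up for illustration, and the pass/fail tolerance is left to the test author.

```python
def total_variation(counts, probs):
    """TV distance between empirical frequencies and target probabilities."""
    n = sum(counts)
    return 0.5 * sum(abs(c / n - p) for c, p in zip(counts, probs))


def chi_square(counts, probs):
    """Pearson chi-square statistic; assumes every target probability is > 0."""
    n = sum(counts)
    return sum((c - n * p) ** 2 / (n * p) for c, p in zip(counts, probs))


# Illustrative token counts from 100 draws against a target distribution.
exact_match = total_variation([50, 30, 20], [0.5, 0.3, 0.2])
worst_case = total_variation([100, 0], [0.5, 0.5])
statistic = chi_square([60, 40], [0.5, 0.5])
```

In the real test, counts would come from running mlx::sample over many seeds and probs from softmax(logits / T).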

Out of Scope (Follow-up Issues)

  • top-k / top-p (nucleus) filtering. The building blocks
    (SortNode, ArgsortNode, PartitionNode, CumsumNode) already
    exist; adding masked top-k/top-p on top of the Gumbel path is a natural
    follow-up once the random node lands.
  • Repetition / frequency / presence penalties. These need token
    history state and are better handled in a separate issue.
  • Batched / multi-token (speculative) sampling. Start with
    B=1, one token per step (matches the existing MLX KV-cache
    constraint max_batch_size == 1 in cache.py).
  • Mutable RNG-state buffer. The seed-as-input design avoids it; a
    buffer-backed RNG (like the KV-cache mutation path) can be a follow-up
    if a stateful API is later desired.

Acceptance Criteria

  • New backends/mlx/llm/sampling.py::SamplingHead wraps a
    logits-returning model and returns a token id; temperature is a
    runtime input to forward, and seed is an optional input (see
    the seed criterion below).
  • mlx::sample is registered (custom op + register_fake) with a
    CPU reference implementation, takes temperature/seed as tensor
    inputs, and has an op handler in backends/mlx/ops.py that emits
    the Gumbel-max graph.
  • Exactly one new schema node (RandomBitsNode, with an optional
    seed field) is added to schema.fbs (append-only in the OpNode
    union), generate.py has been re-run, and the generated artifacts
    are committed.
  • seed is optional end-to-end: exporting with a seed tensor yields a
    deterministic .pte (explicit random::key(seed)); exporting with
    seed=None yields an unseeded .pte (runtime std::nullopt →
    MLX global KeySequence). The seeded/unseeded distinction is fixed
    at export time.
  • backends/mlx/runtime/MLXInterpreter.h executes the new node via
    mlx::core::random and is wired into the eval dispatch switch.
  • backends/mlx/custom_kernel_ops/test/test_sample.py passes
    (greedy parity, runtime-temperature, op parity / distribution
    check, determinism, torch.export through MLXPartitioner).
  • Changing temperature at runtime alters sampling without
    re-export. For a seeded export, the same seed produces
    identical tokens across runs with no host-side stateful RNG; the
    unseeded export intentionally relies on MLX's global key sequence.

Pointers

  • Canonical custom-op pattern to mirror:
    backends/mlx/custom_kernel_ops/gated_delta_rule.py and its test
    backends/mlx/custom_kernel_ops/test/test_gated_delta_rule.py.
  • Simple op-handler examples: _argmax_handler
    (backends/mlx/ops.py:3410), _softmax_handler
    (backends/mlx/ops.py:1195).
  • Schema (source of truth) + BC rules + codegen:
    backends/mlx/serialization/schema.fbs,
    backends/mlx/serialization/generate.py.
  • Runtime execution dispatch:
    backends/mlx/runtime/MLXInterpreter.h (exec_argmax at line 1693,
    eval switch ~line 2058).
  • Vendored MLX RNG API:
    backends/mlx/third-party/mlx/mlx/random.h
    (key, split, bits, uniform, gumbel, categorical).
  • Where transforms/wrappers are applied:
    backends/mlx/llm/source_transformation.py.
  • Seed-threading precedent (input_pos/start_pos):
    backends/mlx/llm/cache.py::KVCache.update.
