Fuse flash_mla_with_kvcache paged fp8 gather/dequant with sparse MLA attention #53

Description

@xzyaoi

Finding

flash_mla_with_kvcache currently stages its inputs before attention runs. It gathers and dequantizes the selected fp8_ds_mla cache rows with torch indexing, materializes a dense [T, total_topk, D] KV tensor, concatenates the primary and optional extra cache selections, flattens the result, and builds an index tensor. Only then does it call the Triton sparse MLA compute kernel.

The downstream Triton sparse MLA kernel already streams selected KV by indices: src/xkernels/ops/attention/triton/sparse_mla_kernel.py#L71-L93.
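
To make the staged path concrete, here is a minimal plain-Python sketch of the gather/dequant/concat steps for a single query token. It is an illustration, not the wrapper's code: the nested-list cache layout, the single scale per row, and the helper names `dequant_row` and `staged_gather` are all assumptions, and the real fp8_ds_mla layout may use a different scale granularity.

```python
def dequant_row(value_cache, scale_cache, block_table, pos, block_size):
    """Resolve a logical token position through the block table and dequantize one row.

    Hypothetical layout: value_cache[block][slot] is a row of quantized values and
    scale_cache[block][slot] is a single scale for that row.
    """
    phys_block = block_table[pos // block_size]  # logical block -> physical block
    slot = pos % block_size                      # offset inside the block
    scale = scale_cache[phys_block][slot]
    return [scale * x for x in value_cache[phys_block][slot]]


def staged_gather(value_cache, scale_cache, block_table, indices, block_size,
                  extra_rows=None):
    """Materialize the dense KV buffer before attention, as the staged path does."""
    kv = [dequant_row(value_cache, scale_cache, block_table, pos, block_size)
          for pos in indices]
    if extra_rows:
        # Optional extra-cache selection, concatenated after the primary rows.
        kv = kv + [list(row) for row in extra_rows]
    # Index list handed to the attention kernel alongside the flattened buffer.
    flat_indices = list(range(len(kv)))
    return kv, flat_indices
```

Every row in `kv` is written once by the gather and read again by the attention kernel, which is the round trip the fusion is meant to remove.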

Why this should improve performance

The decode wrapper pays for several extra GPU kernel launches and materializes a dequantized KV buffer whose size scales with top-k before attention runs. For DeepSeek-V4 sparse decode, topk is large and D=512, so writing and then re-reading this temporary can be a large fraction of the work. Fusing paged-cache address resolution and fp8 dequant into the attention kernel should reduce memory traffic and avoid the staging allocations.
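
As a rough back-of-the-envelope check on the size of that temporary, the sketch below computes the staged buffer for one decode token. The 2-byte element size assumes the dequantized buffer is bf16 or fp16, which is not stated above, and `staging_bytes` is a hypothetical helper. The count covers only the dense [T, total_topk, D] tensor, not the concatenation or index copies.

```python
def staging_bytes(num_tokens, topk, head_dim, bytes_per_elem=2):
    """Size in bytes of a dense [num_tokens, topk, head_dim] staging buffer.

    bytes_per_elem=2 assumes a bf16/fp16 dequantized buffer (an assumption).
    """
    return num_tokens * topk * head_dim * bytes_per_elem


for topk in (512, 1024):
    staged = staging_bytes(1, topk, 512)
    fp8_payload = staging_bytes(1, topk, 512, bytes_per_elem=1)
    print(f"topk={topk}: staged={staged / 2**20:.2f} MiB, "
          f"fp8 payload={fp8_payload / 2**20:.2f} MiB")
# topk=512: staged=0.50 MiB, fp8 payload=0.25 MiB
# topk=1024: staged=1.00 MiB, fp8 payload=0.50 MiB
```

Under these assumptions the staged buffer is twice the size of the fp8 payload it was built from, and it is both written and read back, whereas a fused kernel would read the fp8 payload once.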

Suggested implementation

  • Add a Triton decode kernel variant that accepts value_cache, scale_cache, block_table, primary indices, optional extra cache/indices, and validity lengths directly.
  • Inside the attention streaming loop, resolve logical or physical cache positions, dequant fp8_ds_mla rows on the fly, and feed them to the online softmax/value accumulator.
  • Preserve the existing materialized sparse_mla_attention(q, kv, indices) API for tests and non-paged callers.
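
The streaming loop described in the bullets above can be sketched as a plain-Python reference for one query vector. It shows the intended semantics, not a Triton kernel: the nested-list cache layout, the single scale per row, the use of the same dequantized row as both key and value, and the name `fused_sparse_decode` are simplifying assumptions.

```python
import math


def fused_sparse_decode(q, value_cache, scale_cache, block_table, indices,
                        valid_len, block_size, sm_scale):
    """Reference semantics for fused paged gather + dequant + online-softmax attention.

    Hypothetical layout: value_cache[block][slot] is a quantized row and
    scale_cache[block][slot] is its scale. The same row acts as key and value.
    """
    running_max = -math.inf      # running maximum of the scores seen so far
    denom = 0.0                  # running softmax denominator
    acc = [0.0] * len(q)         # running weighted sum of value rows
    for pos in indices[:valid_len]:          # entries past valid_len are padding
        phys_block = block_table[pos // block_size]
        slot = pos % block_size
        scale = scale_cache[phys_block][slot]
        row = [scale * x for x in value_cache[phys_block][slot]]  # dequant on the fly
        score = sm_scale * sum(a * b for a, b in zip(q, row))
        new_max = max(running_max, score)
        alpha = math.exp(running_max - new_max)  # rescales old state; 0.0 on first row
        p = math.exp(score - new_max)
        denom = denom * alpha + p
        acc = [a * alpha + p * v for a, v in zip(acc, row)]
        running_max = new_max
    if denom == 0.0:             # no valid rows selected
        return acc
    return [a / denom for a in acc]
```

No dense KV buffer is built at any point: each selected row is resolved, dequantized, and consumed inside the loop. A reference like this could also serve as the oracle when comparing a fused kernel against the staged path.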

Validation

  • Compare fused decode against the current staged decode for primary-only and primary+extra-cache cases.
  • Benchmark peak temporary memory, launch count, and latency for B=1, H=128, D=512, topk=512/1024.
  • Include graph-capture validation for fixed decode shapes.
