Finding
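flash_mla_with_kvcache currently gathers and dequantizes the selected fp8_ds_mla cache rows with torch indexing, materializes a dense [T, total_topk, D] KV tensor, concatenates the primary and optional extra-cache selections, flattens the result, builds an index tensor, and only then calls the Triton sparse MLA compute kernel: src/xkernels/ops/attention/sparse_mla_decode.py#L26-L61
The torch.cat, reshape, to(q2.dtype), torch.arange, and torch.where staging all happens before compute: src/xkernels/ops/attention/sparse_mla_decode.py#L110-L150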
The downstream Triton sparse MLA kernel already streams selected KV by indices: src/xkernels/ops/attention/triton/sparse_mla_kernel.py#L71-L93.
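To make the staged data flow concrete, here is a minimal pure-Python sketch of what the current wrapper does for a single query vector: resolve each selected index through the block table, dequantize the row, append it to a dense buffer (primary selection first, then the optional extra cache), and only then run masked softmax attention over that buffer. Everything about the layout is a hypothetical stand-in: a stored code times one scale per group of elements replaces real fp8_ds_mla decoding, indices that are negative or at/after a per-cache valid_len are padding, and the names gather_dense and staged_decode are illustrative, not the real torch code. Heads and batching are omitted, and keys and values are read from the same latent row, as in MLA decode.

```python
import math
from typing import List, Sequence, Tuple

Row = List[float]
# Hypothetical paged cache: (value_cache[block][slot][D], scale_cache[block][slot][D // group], block_table)
Cache = Tuple[list, list, Sequence[int]]


def gather_dense(cache: Cache, indices: Sequence[int], valid_len: int,
                 block_size: int, group: int) -> Tuple[List[Row], List[bool]]:
    """Materialize a dense [len(indices), D] dequantized buffer plus a validity mask."""
    value_cache, scale_cache, block_table = cache
    dim = len(value_cache[0][0])
    rows, valid = [], []
    for idx in indices:
        if idx < 0 or idx >= valid_len:
            rows.append([0.0] * dim)  # padding row, masked out below
            valid.append(False)
            continue
        # Logical position -> physical (block, offset) through the block table.
        block, offset = block_table[idx // block_size], idx % block_size
        codes, scales = value_cache[block][offset], scale_cache[block][offset]
        rows.append([c * scales[i // group] for i, c in enumerate(codes)])
        valid.append(True)
    return rows, valid


def staged_decode(q: Row, selections: Sequence[Tuple[Cache, Sequence[int], int]],
                  block_size: int, group: int) -> Row:
    # Steps 1-2: gather/dequantize each selection and concatenate into one buffer.
    kv: List[Row] = []
    valid: List[bool] = []
    for cache, indices, valid_len in selections:
        rows, mask = gather_dense(cache, indices, valid_len, block_size, group)
        kv += rows
        valid += mask
    if not any(valid):  # nothing selected: return a zero output
        return [0.0] * len(q)
    # Step 3: ordinary masked softmax attention over the materialized buffer.
    scale = 1.0 / math.sqrt(len(q))
    logits = [sum(a * b for a, b in zip(q, row)) * scale if ok else -math.inf
              for row, ok in zip(kv, valid)]
    peak = max(logits)
    weights = [math.exp(x - peak) for x in logits]  # exp(-inf) == 0.0 for masked rows
    total = sum(weights)
    return [sum(w * row[j] for w, row in zip(weights, kv)) / total for j in range(len(q))]
```

The point of the sketch is the intermediate kv list: it holds every selected row in dequantized form before any attention math runs, which is the temporary the proposal wants to eliminate.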
Why this should improve performance
The decode wrapper pays for several extra GPU kernel launches and materializes a top-k-sized dequantized KV buffer before attention runs. For DeepSeek-V4 sparse decode, topk is large and D=512, so writing and then re-reading this temporary can be a large fraction of the total work. Fusing paged-cache address resolution and fp8 dequantization into the attention kernel should reduce memory traffic and avoid the staging allocations.
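A back-of-the-envelope estimate helps size the memory-traffic claim. This sketch assumes 1 byte per fp8 code, a 2-byte dequantized element (e.g. bf16; the real dtype follows q2.dtype), and T=1 for B=1 decode. Because the [T, total_topk, D] buffer has no head dimension, H=128 does not multiply it. Scales, the index tensor, cache effects, and any extra copies from torch.cat or dtype casts are deliberately ignored, so this is a first-order lower bound on the staged path's overhead, not a measurement. The function name kv_traffic_bytes is hypothetical.

```python
def kv_traffic_bytes(tokens: int, total_topk: int, head_dim: int, staged: bool,
                     fp8_bytes: int = 1, staged_bytes: int = 2) -> int:
    """First-order KV byte traffic for one decode step.

    Fused: read the fp8 codes once inside attention.
    Staged: read the fp8 codes, write a dequantized [T, total_topk, D] buffer,
    then read that buffer back in attention.
    """
    elems = tokens * total_topk * head_dim
    fused = elems * fp8_bytes
    if not staged:
        return fused
    return fused + 2 * elems * staged_bytes  # one write + one read of the staging buffer


for topk in (512, 1024):
    staged = kv_traffic_bytes(1, topk, 512, staged=True)
    fused = kv_traffic_bytes(1, topk, 512, staged=False)
    print(f"topk={topk}: staged {staged / 2**20:.2f} MiB, fused {fused / 2**20:.2f} MiB")
# topk=512: staged 1.25 MiB, fused 0.25 MiB
# topk=1024: staged 2.50 MiB, fused 0.50 MiB
```

Under these assumptions the staged path moves about 5x the KV bytes of a fused kernel, independent of topk, before counting the extra launches.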
Suggested implementation
- Add a Triton decode kernel variant that accepts value_cache, scale_cache, block_table, primary indices, optional extra cache/indices, and validity lengths directly.
- Inside the attention streaming loop, resolve logical or physical cache positions, dequant fp8_ds_mla rows on the fly, and feed them to the online softmax/value accumulator.
- Preserve the existing materialized sparse_mla_attention(q, kv, indices) API for tests and non-paged callers.
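The per-row work inside the proposed streaming loop can be sketched in pure Python: for each selected index (primary cache first, then the optional extra cache), skip padding, resolve the physical slot through the block table, dequantize the row on the fly, and fold it into a running max, denominator, and value accumulator using the standard online-softmax update. No [total_topk, D] buffer is built. The cache layout (a stored code times one scale per group of elements, standing in for fp8_ds_mla decoding), the valid_len padding convention, and the name fused_decode are hypothetical; a real Triton kernel would load tiles of indices per program with masked loads rather than iterate one row at a time.

```python
import math
from typing import List, Sequence, Tuple

Row = List[float]
# Hypothetical paged cache: (value_cache[block][slot][D], scale_cache[block][slot][D // group], block_table)
Cache = Tuple[list, list, Sequence[int]]


def fused_decode(q: Row, selections: Sequence[Tuple[Cache, Sequence[int], int]],
                 block_size: int, group: int) -> Row:
    """Streaming sparse attention for one query.

    `selections` holds (cache, indices, valid_len) for the primary cache and,
    optionally, an extra cache. Indices < 0 or >= valid_len are padding.
    """
    dim = len(q)
    scale = 1.0 / math.sqrt(dim)
    running_max = -math.inf
    denom = 0.0
    acc = [0.0] * dim
    for (value_cache, scale_cache, block_table), indices, valid_len in selections:
        for idx in indices:
            if idx < 0 or idx >= valid_len:
                continue  # padding slot: skipped before any cache access
            # Resolve the logical position to a physical (block, offset) pair.
            block = block_table[idx // block_size]
            offset = idx % block_size
            codes, scales = value_cache[block][offset], scale_cache[block][offset]
            # Dequantize on the fly; the row is used immediately, never staged.
            row = [c * scales[i // group] for i, c in enumerate(codes)]
            logit = sum(a * b for a, b in zip(q, row)) * scale
            # Online softmax: rescale earlier partial sums when the max grows.
            new_max = max(running_max, logit)
            correction = math.exp(running_max - new_max)  # 0.0 on the first valid row
            weight = math.exp(logit - new_max)
            denom = denom * correction + weight
            acc = [a * correction + weight * v for a, v in zip(acc, row)]
            running_max = new_max
    if denom == 0.0:  # no valid index in any selection
        return acc
    return [a / denom for a in acc]
```

Because softmax is order-independent, streaming the primary and extra selections back to back gives the same result (up to rounding) as concatenating them first, which is what makes dropping the torch.cat staging safe.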
Validation
- Compare fused decode against the current staged decode for primary-only and primary+extra-cache cases.
- Benchmark peak temporary memory, launch count, and latency for B=1, H=128, D=512, topk=512/1024.
- Include graph-capture validation for fixed decode shapes.