Finding
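The DSA indexer path computes and stores the full `[T, K]` logits tensor, then exposes a separate helper that calls `torch.topk`:
- `out = torch.empty((T, K), dtype=torch.float32, ...)`: src/xkernels/ops/attention/triton/dsa_indexer_kernel.py#L104-L121
- src/xkernels/ops/attention/triton/dsa_indexer_kernel.py#L93-L101
- `dsa_indexer_topk` then calls `torch.topk(..., sorted=False)` and casts indices to int32: src/xkernels/ops/attention/dsa_reference.py#L83-L93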
Why this should improve performance
For V4 sparse attention, the indexer only needs top-512/top-1024 indices. Materializing all `T * K` fp32 scores and then launching a separate `torch.topk` adds global memory traffic and an additional selection pass over the same data. A fused streaming top-k path can keep tile candidates local and write only the selected indices/scores.
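To put rough numbers on the traffic argument, here is a back-of-the-envelope estimate. The shapes (`T = 4096`, `K = 65536`) are illustrative assumptions, not measurements from this codebase, and the helper name `indexer_output_bytes` is hypothetical. It counts only the logits write, one re-read by the selection pass, and the int32 index write; reads of the query/key inputs (common to both paths) and any scratch space inside `torch.topk` are ignored.

```python
def indexer_output_bytes(T: int, K: int, topk: int) -> dict:
    """Rough global-memory footprint of the indexer outputs, in bytes."""
    logits_bytes = T * K * 4      # full [T, K] fp32 logits tensor
    index_bytes = T * topk * 4    # [T, topk] int32 indices
    return {
        "logits_write": logits_bytes,           # written by the logits kernel
        "logits_reread_by_topk": logits_bytes,  # read back by the selection pass
        "index_write": index_bytes,             # final output in both paths
    }


est = indexer_output_bytes(T=4096, K=65536, topk=512)  # assumed shapes
unfused = est["logits_write"] + est["logits_reread_by_topk"] + est["index_write"]
fused = est["index_write"]
print(f"unfused ~{unfused / 2**20:.0f} MiB, fused ~{fused / 2**20:.0f} MiB")
```

Under these assumed shapes the logits tensor alone is 1 GiB, while the index output is 8 MiB, so the intermediate tensor dominates the output-side traffic.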
Suggested implementation
- Add a `dsa_indexer_topk_triton` or public fused API that computes weighted-ReLU MQA scores and emits top-k indices directly.
- Use tile-local top-k candidates, then a per-row merge/reduction step across KV tiles.
- Keep `dsa_indexer_logits` for diagnostics and reference parity, but let serving call the fused top-k path.
- Consider returning selected scores as an optional output if later stages need them for debugging or validation.
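The tile-local candidates plus per-row merge can be modeled in plain Python to pin down the intended dataflow before writing a kernel. This is a minimal sketch, not the proposed Triton implementation: `streaming_topk` and the `score_tile` callback are hypothetical names, the callback stands in for computing one tile of weighted-ReLU MQA scores, and ties are broken arbitrarily (here by larger index, as a side effect of tuple ordering).

```python
import heapq
from typing import Callable, List, Sequence, Tuple


def streaming_topk(
    score_tile: Callable[[int, int], Sequence[float]],
    n_keys: int,
    k: int,
    tile: int,
) -> List[int]:
    """Select top-k key indices for one query row, one KV tile at a time.

    Only k running (score, index) candidates survive between tiles, so the
    full row of scores is never stored by the selection step.
    """
    best: List[Tuple[float, int]] = []
    for start in range(0, n_keys, tile):
        stop = min(start + tile, n_keys)
        tile_scores = score_tile(start, stop)  # scores for keys [start, stop)
        # Tile-local top-k candidates, tagged with their global key index.
        cand = heapq.nlargest(k, zip(tile_scores, range(start, stop)))
        # Merge step: keep the k best of running + tile candidates.
        best = heapq.nlargest(k, best + cand)
    return [j for _, j in best]


row = [0.1, 0.9, 0.0, 0.7, 0.3, 0.8, 0.2, 0.6]
picked = streaming_topk(lambda a, b: row[a:b], n_keys=len(row), k=3, tile=4)
print(sorted(picked))  # [1, 3, 5]
```

Returning the `(score, index)` pairs instead of only the indices would cover the optional selected-scores output mentioned above.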
Validation
- Compare fused indices against `torch.topk(dsa_indexer_logits(...), k, sorted=False)` using set equality or sorted-index equality where appropriate.
- Benchmark logits-only + `torch.topk` vs fused top-k across large `K` and V4 top-k values.
- Include causal-window masks (`lengths`, `row_starts`) in correctness and performance tests.
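One caveat for the index comparison: when several scores tie at the k-th boundary, two correct implementations may select different indices, so strict set equality can fail spuriously. A tie-tolerant alternative is to check the top-k property directly against the reference logits row. The sketch below is an assumption about how such a check could look; `is_valid_topk` is a hypothetical helper, and it does not model the `lengths`/`row_starts` masks.

```python
from typing import Sequence


def is_valid_topk(scores: Sequence[float], picked: Sequence[int], k: int) -> bool:
    """True if `picked` is a correct top-k selection for one row of scores."""
    n = len(scores)
    chosen = set(picked)
    if len(picked) != min(k, n) or len(chosen) != len(picked):
        return False  # wrong count or duplicate indices
    if any(j < 0 or j >= n for j in chosen):
        return False  # out-of-range index
    rest = [scores[j] for j in range(n) if j not in chosen]
    if not rest or not chosen:
        return True  # everything (or nothing) was selected
    # Every selected score must be at least as large as every unselected one.
    return min(scores[j] for j in chosen) >= max(rest)


scores = [0.5, 0.9, 0.5, 0.1]
print(is_valid_topk(scores, [1, 0], k=2))  # True
print(is_valid_topk(scores, [2, 1], k=2))  # True: index 2 ties with index 0
print(is_valid_topk(scores, [1, 3], k=2))  # False: 0.1 is below an unselected 0.5
```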