
Fuse DSA indexer logits with top-k selection #54

Description

@xzyaoi

Finding

The DSA indexer path computes and stores the full [T, K] logits tensor, then exposes a separate helper that calls torch.topk on it.

Why this should improve performance

For V4 sparse attention, the indexer only needs top-512/top-1024 indices. Materializing all T * K fp32 scores and then launching a separate torch.topk adds global memory traffic and an additional selection pass over the same data. A fused streaming top-k path can keep tile candidates local and write only the selected indices/scores.
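To put rough numbers on the memory-traffic argument, here is a back-of-the-envelope sketch. The shapes (T = 4096, K = 65536, top-512) are illustrative values, not measurements from this repository, and the helper names are hypothetical. It also assumes the fused path would emit int32 indices; torch.topk itself returns int64 indices.

```python
def logits_bytes(T, K, bytes_per_score=4):
    """Bytes needed to write the full [T, K] fp32 logits tensor once."""
    return T * K * bytes_per_score


def fused_output_bytes(T, top_k, bytes_per_index=4, bytes_per_score=4,
                       with_scores=False):
    """Bytes written by a fused path that emits only the selected entries.

    bytes_per_index=4 assumes int32 indices, which is an assumption of this
    sketch rather than something the issue specifies.
    """
    per_entry = bytes_per_index + (bytes_per_score if with_scores else 0)
    return T * top_k * per_entry


# Illustrative shapes only.
T, K, top_k = 4096, 65536, 512

full = logits_bytes(T, K)             # written by the logits kernel
fused = fused_output_bytes(T, top_k)  # written by a fused top-k kernel

# The unfused path writes the logits once and reads them back for torch.topk,
# so it moves roughly 2 * full bytes before the top-k output is even counted.
unfused_traffic = 2 * full + fused

print(f"full logits : {full / 2**20:.0f} MiB")
print(f"fused output: {fused / 2**20:.0f} MiB")
print(f"write ratio : {full // fused}x")
```

With these assumptions the write ratio reduces to K / top_k, so the saving grows with context length for a fixed top-k.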

Suggested implementation

  • Add a dsa_indexer_topk_triton or public fused API that computes weighted-ReLU MQA scores and emits top-k indices directly.
  • Use tile-local top-k candidates, then a per-row merge/reduction step across KV tiles.
  • Keep dsa_indexer_logits for diagnostics and reference parity, but let serving call the fused top-k path.
  • Consider returning selected scores as an optional output if later stages need them for debugging or validation.
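The tile-local selection and per-row merge described above can be prototyped in plain PyTorch before writing a Triton kernel. The sketch below is a reference for the algorithm only, not the proposed implementation: the function names are hypothetical, masking is omitted, and the score formula (sum over heads of w * ReLU(q · k)) is an assumption about what "weighted-ReLU MQA scores" means here.

```python
import torch


def reference_logits(q, k, w):
    """Assumed score: sum_h w[t, h] * relu(q[t, h] . k[s]).

    q: [T, H, D], k: [K, D], w: [T, H]  ->  logits: [T, K]
    """
    s = torch.einsum("thd,kd->thk", q, k)
    return (torch.relu(s) * w.unsqueeze(-1)).sum(dim=1)


def streaming_topk(q, k, w, top_k, tile=128):
    """Top-k over K without ever holding the full [T, K] logits tensor.

    Each KV tile contributes at most top_k local candidates, which are then
    merged into the running per-row candidates with a second top-k.
    """
    T, K = q.shape[0], k.shape[0]
    best_s = q.new_empty((T, 0))
    best_i = torch.empty((T, 0), dtype=torch.long, device=q.device)
    for start in range(0, K, tile):
        # Scores for this tile only: [T, tile_len].
        s = reference_logits(q, k[start:start + tile], w)
        # Tile-local candidates; the last tile may be shorter than top_k.
        loc_s, loc_i = torch.topk(s, min(top_k, s.shape[1]), dim=1)
        # Merge with the running candidates, converting to global indices.
        cand_s = torch.cat([best_s, loc_s], dim=1)
        cand_i = torch.cat([best_i, loc_i + start], dim=1)
        best_s, pos = torch.topk(cand_s, min(top_k, cand_s.shape[1]), dim=1)
        best_i = torch.gather(cand_i, 1, pos)
    return best_s, best_i


# Small illustrative shapes; K is deliberately not a multiple of the tile size.
torch.manual_seed(0)
q = torch.randn(8, 4, 16)
k = torch.randn(300, 16)
w = torch.randn(8, 4)
scores, idx = streaming_topk(q, k, w, top_k=32, tile=64)
```

Keeping top_k candidates per tile is sufficient for correctness because every element of the global top-k must also be in the top-k of its own tile. A real kernel would keep the candidates in registers or shared memory instead of concatenating tensors.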

Validation

  • Compare fused indices against torch.topk(dsa_indexer_logits(...), k, sorted=False).indices using set equality or sorted-index equality where appropriate.
  • Benchmark logits-only + torch.topk vs fused top-k across large K and V4 top-k values.
  • Include causal-window masks (lengths, row_starts) in correctness and performance tests.
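A correctness check along these lines might look like the sketch below. Both helpers are hypothetical. The mask semantics are an assumption: row t is taken to see keys in [row_starts[t], row_starts[t] + lengths[t]), which may not match the repository's definition. The comparison tries sorted-index equality first and falls back to comparing the selected scores, since tied scores can legitimately yield different index sets.

```python
import torch


def apply_window_mask(logits, row_starts, lengths):
    """Mask keys outside the assumed window with -inf.

    Assumption: row t may select keys in
    [row_starts[t], row_starts[t] + lengths[t]).
    """
    cols = torch.arange(logits.shape[1], device=logits.device).unsqueeze(0)
    start = row_starts.unsqueeze(1)
    valid = (cols >= start) & (cols < start + lengths.unsqueeze(1))
    return logits.masked_fill(~valid, float("-inf"))


def indices_match(ref_logits, fused_idx, top_k):
    """True if fused_idx is a valid top-k selection for every row."""
    ref_idx = torch.topk(ref_logits, top_k, dim=1, sorted=False).indices
    fused_sorted = fused_idx.sort(dim=1).values
    if torch.equal(ref_idx.sort(dim=1).values, fused_sorted):
        return True
    # Reject rows that select the same key twice.
    if not bool((fused_sorted[:, 1:] != fused_sorted[:, :-1]).all()):
        return False
    # Tie fallback: the multiset of selected scores must match the reference.
    ref_s = torch.gather(ref_logits, 1, ref_idx).sort(dim=1).values
    got_s = torch.gather(ref_logits, 1, fused_idx).sort(dim=1).values
    return torch.equal(ref_s, got_s)


# Keys 2 and 3 tie, so {0, 2} and {0, 3} are both valid top-2 selections.
logits = torch.tensor([[3.0, 1.0, 2.0, 2.0, 0.0]])
ok_tie = indices_match(logits, torch.tensor([[0, 3]]), top_k=2)
bad = indices_match(logits, torch.tensor([[0, 1]]), top_k=2)

# Two rows with windows [0, 2) and [1, 3) over four keys.
masked = apply_window_mask(torch.zeros(2, 4),
                           row_starts=torch.tensor([0, 1]),
                           lengths=torch.tensor([2, 2]))
```

The exact score comparison assumes the fused kernel's indices are checked against the reference logits themselves, so no floating-point tolerance is needed. Rows whose window holds fewer than top_k valid keys will contain -inf selections and deserve a dedicated test case.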
