Finding
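Several hot-path wrappers allocate and zero temporary tensors on every call:
- sorted_ids, expert_ids, num_post, tokens_cnts, and cumsum every call: src/xkernels/ops/moe/triton/align_kernel.py#L189-L193
- [M, N] fp32 combine output or [M * top_k, N] scratch: src/xkernels/ops/moe/triton/moe_int4_kernel.py#L441-L464, src/xkernels/ops/moe/triton/moe_mxfp4_kernel.py#L436-L452
- out, lse, and maxl every call: src/xkernels/ops/attention/triton/sparse_mla_kernel.py#L134-L136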
Why this should improve performance
In decode serving, many of these kernels run once per layer per token. Per-call allocation and zeroing overhead can dominate small-batch latency and complicate HIP/CUDA graph capture. The code already has some fixed-shape modes (truncate=False in MoE align), so the next step is to let callers reuse the fixed buffers.
Suggested implementation
- Add optional workspace= / out= arguments for the hot public APIs, or small workspace dataclasses per op family.
- Size workspaces from existing deterministic bounds (max_pad, max_blocks, M, top_k, N, hidden size).
- Keep allocating by default for ergonomic use, but let serving stacks preallocate once per max decode shape.
- Avoid zeroing buffers when the kernel overwrites all live elements; keep explicit zeroing only for atomic-add combine outputs or EP partial outputs, where skipped experts must contribute zeros.
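The bullets above could be sketched roughly as follows. This is a minimal CPU-only illustration, not the library's API: `MoeWorkspace`, `align_bounds`, `combine`, and the `zero_out` flag are hypothetical names, the padding bound is an assumed worst case (each expert padded by up to `block_size - 1` slots), and `torch.bmm` stands in for the real Triton combine kernel.

```python
from dataclasses import dataclass

import torch


def cdiv(a: int, b: int) -> int:
    return (a + b - 1) // b


def align_bounds(max_m: int, top_k: int, num_experts: int, block_size: int):
    # Assumed worst case: every expert's token list is padded by up to
    # block_size - 1 slots. The real bound should come from the align kernel.
    max_pad = max_m * top_k + num_experts * (block_size - 1)
    return max_pad, cdiv(max_pad, block_size)


@dataclass
class MoeWorkspace:
    """Hypothetical bundle of buffers allocated once per max decode shape."""
    sorted_ids: torch.Tensor   # [max_pad]
    expert_ids: torch.Tensor   # [max_blocks]
    combine_out: torch.Tensor  # [max_m, n], fp32

    @classmethod
    def allocate(cls, max_m, top_k, num_experts, block_size, n, device="cpu"):
        max_pad, max_blocks = align_bounds(max_m, top_k, num_experts, block_size)
        # torch.empty: no memset, the kernels are expected to overwrite live data.
        return cls(
            sorted_ids=torch.empty(max_pad, dtype=torch.int32, device=device),
            expert_ids=torch.empty(max_blocks, dtype=torch.int32, device=device),
            combine_out=torch.empty(max_m, n, dtype=torch.float32, device=device),
        )


def combine(expert_out, weights, out=None, zero_out=False):
    """Stand-in for a combine op: result[m] = sum_k weights[m, k] * expert_out[m, k]."""
    m, _, n = expert_out.shape
    if out is None:
        # Default path keeps the ergonomic allocate-per-call behaviour.
        out = torch.empty(m, n, dtype=expert_out.dtype, device=expert_out.device)
    elif out.shape[0] < m or out.shape[1] != n:
        raise ValueError("out buffer is too small for this call")
    live = out[:m]  # view of the rows this call owns
    if zero_out:
        # Only needed for accumulate-style kernels (e.g. atomic-add combine).
        live.zero_()
    # bmm writes every live element, so no zeroing is required on this path.
    torch.bmm(weights.unsqueeze(1), expert_out, out=live.unsqueeze(1))
    return live
```

A serving stack would call `MoeWorkspace.allocate(...)` once for its largest decode bucket and pass `out=ws.combine_out` on every step. Callers that omit `out` keep today's behaviour.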
Validation
- Add tests that pass preallocated buffers and compare results with allocation-owning calls.
- Benchmark small decode buckets with allocation-owning eager, preallocated eager, and graph-captured execution.
- Include a stress test that reuses a larger workspace for smaller M buckets without stale-data leakage.
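One way the equivalence and stale-data checks above might look is sketched below. `combine` is again a hypothetical CPU stand-in for a real kernel wrapper, and `check_workspace_reuse` is an illustrative helper, not an existing test. The workspace is poisoned with NaN before each call, so any element the op fails to overwrite shows up in the comparison.

```python
import torch


def combine(expert_out, weights, out=None):
    """CPU stand-in for a kernel wrapper with an optional preallocated output."""
    m, _, n = expert_out.shape
    if out is None:
        out = torch.empty(m, n, dtype=expert_out.dtype)
    live = out[:m]
    torch.bmm(weights.unsqueeze(1), expert_out, out=live.unsqueeze(1))
    return live


def check_workspace_reuse(max_m=8, top_k=2, n=16, buckets=(8, 4, 1), seed=0):
    """Reuse one max-sized buffer across shrinking M buckets and compare
    against the allocation-owning call for each bucket."""
    gen = torch.Generator().manual_seed(seed)
    workspace = torch.empty(max_m, n, dtype=torch.float32)
    for m in buckets:
        x = torch.randn(m, top_k, n, generator=gen)
        w = torch.rand(m, top_k, generator=gen)
        # Poison the whole buffer so leftovers from a larger bucket are visible.
        workspace.fill_(float("nan"))
        expected = combine(x, w)                # allocation-owning path
        actual = combine(x, w, out=workspace)   # preallocated path
        if tuple(actual.shape) != (m, n):
            raise AssertionError(f"unexpected output shape for M={m}")
        if torch.isnan(actual).any():
            raise AssertionError(f"stale data leaked into output for M={m}")
        if not torch.equal(actual, expected):
            raise AssertionError(f"preallocated result differs for M={m}")
    return True
```

For real kernels the same loop would run on the GPU and, where reduction order can vary between paths, compare with `torch.allclose` under a tolerance rather than exact equality.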