Finding
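The fused_ffn optimized backends still run the three projection GEMMs as torch matmuls and fuse only the elementwise SwiGLU activation:
- src/xkernels/ops/ffn/interface.py#L11-L27
- The reference computes (silu(x @ w_gate) * (x @ w_up)) @ w_down: src/xkernels/ops/ffn/reference.py#L11-L18
- The optimized backends compute g = x @ w_gate and u = x @ w_up, apply the fused activation, then compute h @ w_down: src/xkernels/ops/ffn/triton/ffn_kernel.py#L35-L39, src/xkernels/ops/ffn/cuda/__init__.py#L22-L26
The README performance table already shows fused_ffn at ~1.0x, i.e. no meaningful speedup, because the GEMMs dominate runtime and both paths use torch matmul.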
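For illustration, here is a minimal NumPy stand-in for the computation described above. NumPy replaces torch so the sketch runs without a GPU, and the function names and toy sizes are hypothetical, not taken from src/xkernels. It splits the SwiGLU FFN into the steps the Triton and CUDA backends perform and marks which of them are still plain matmuls:

```python
import numpy as np


def silu(x: np.ndarray) -> np.ndarray:
    # SiLU(x) = x * sigmoid(x) = x / (1 + exp(-x))
    return x / (1.0 + np.exp(-x))


def ffn_reference(x, w_gate, w_up, w_down):
    # Single-expression form: (silu(x @ w_gate) * (x @ w_up)) @ w_down
    return (silu(x @ w_gate) * (x @ w_up)) @ w_down


def ffn_staged(x, w_gate, w_up, w_down):
    # Hypothetical staged form mirroring the optimized backends.
    g = x @ w_gate      # GEMM 1: still a plain matmul
    u = x @ w_up        # GEMM 2: still a plain matmul
    h = silu(g) * u     # the only step the custom activation kernel fuses
    return h @ w_down   # GEMM 3: still a plain matmul


rng = np.random.default_rng(0)
tokens, d_model, d_ff = 8, 16, 32  # toy sizes, not README shapes
x = rng.standard_normal((tokens, d_model))
w_gate = rng.standard_normal((d_model, d_ff))
w_up = rng.standard_normal((d_model, d_ff))
w_down = rng.standard_normal((d_ff, d_model))

y_ref = ffn_reference(x, w_gate, w_up, w_down)
y_staged = ffn_staged(x, w_gate, w_up, w_down)
print(y_ref.shape, np.allclose(y_ref, y_staged))
```

Only the h = silu(g) * u line corresponds to what the custom kernel replaces; all three @ operations are untouched, which is consistent with the ~1.0x figure.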
Why this should improve performance
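As implemented, fused_ffn adds backend dispatch and a custom activation launch but does not remove the dominant GEMM costs: fusing SiLU with the elementwise multiply only avoids materializing one intermediate tensor, which is small next to the three projection GEMMs. To make fused_ffn a real performance kernel, the project needs a strategy that goes beyond activation-only fusion.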
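A back-of-the-envelope check of that claim, as a sketch: it counts the GEMM FLOPs and the activation memory traffic that fusion removes, assuming eager execution where SiLU and the multiply would otherwise run as two separate elementwise kernels. The function name and the example shape (4096 tokens, d_model = 4096, d_ff = 14336, 2-byte fp16/bf16 elements) are hypothetical, not one of the README shapes.

```python
def ffn_costs(tokens: int, d_model: int, d_ff: int, bytes_per_elem: int = 2):
    """Return (GEMM FLOPs, bytes saved by fusing SiLU*mul, FLOPs per saved byte)."""
    # x @ w_gate and x @ w_up: (T x d) @ (d x f); h @ w_down: (T x f) @ (f x d).
    # A dense (M x K) @ (K x N) GEMM costs 2 * M * N * K FLOPs.
    gemm_flops = 3 * 2 * tokens * d_model * d_ff

    # Unfused (two kernels): silu reads g, writes s; mul reads s and u, writes h.
    unfused_elems = 5 * tokens * d_ff
    # Fused (one kernel): read g and u, write h.
    fused_elems = 3 * tokens * d_ff
    saved_bytes = (unfused_elems - fused_elems) * bytes_per_elem

    return gemm_flops, saved_bytes, gemm_flops / saved_bytes


if __name__ == "__main__":
    flops, saved, ratio = ffn_costs(tokens=4096, d_model=4096, d_ff=14336)
    print(f"GEMM TFLOPs: {flops / 1e12:.2f}")
    print(f"bytes saved by fusion: {saved / 2**20:.1f} MiB")
    print(f"GEMM FLOPs per saved byte: {ratio:.0f}")
```

The ratio reduces to 1.5 * d_model FLOPs per saved byte, independent of tokens and d_ff. Dividing it by a device's peak-FLOPs-to-bandwidth ratio gives roughly how many times longer the GEMMs take than the traffic fusion saves, under the idealized assumption that both run at peak.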
Suggested implementation
- Evaluate whether this op should delegate to vendor GEMM libraries with fused/epilogue support rather than a custom activation-only kernel.
- Consider a persistent or grouped GEMM strategy only for shapes where torch/hipBLASLt misses the fast path.
- If no robust faster path exists, document fused_ffn as a correctness/scaffold API rather than listing it as an optimized kernel.
- Add benchmarks that separate GEMM time, activation time, and dispatch overhead so future changes show where the win comes from.
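The benchmark split suggested above could start from a small harness like the following. It is a sketch: time_stages and the stage names are hypothetical, and it assumes each stage can be wrapped in a zero-argument callable. On ROCm or CUDA builds of PyTorch, pass sync=torch.cuda.synchronize so that asynchronous kernel launches finish inside the timed region.

```python
import statistics
import time
from typing import Callable, Dict, Optional


def time_stages(
    stages: Dict[str, Callable[[], object]],
    iters: int = 50,
    warmup: int = 5,
    sync: Optional[Callable[[], None]] = None,
) -> Dict[str, float]:
    """Return the median wall-clock time in milliseconds for each named stage."""
    sync = sync or (lambda: None)
    results: Dict[str, float] = {}
    for name, fn in stages.items():
        # Warm up (compilation, autotuning, allocator caching) outside timing.
        for _ in range(warmup):
            fn()
        sync()
        samples = []
        for _ in range(iters):
            start = time.perf_counter()
            fn()
            sync()  # wait for queued device work before stopping the clock
            samples.append((time.perf_counter() - start) * 1e3)
        results[name] = statistics.median(samples)
    return results


if __name__ == "__main__":
    # CPU-only smoke test; real use would pass GEMM, activation, and
    # end-to-end fused_ffn callables instead.
    demo = {"noop": lambda: None, "sleep_1ms": lambda: time.sleep(0.001)}
    print(time_stages(demo, iters=10, warmup=1))
```

Timing the GEMMs, the activation, and the full fused_ffn call separately lets dispatch overhead be estimated as the end-to-end time minus the sum of the stage times. Synchronizing on every iteration adds some host overhead of its own; for finer-grained GPU timing, torch.cuda.Event(enable_timing=True) is an alternative.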
Validation
- Benchmark the fp16 and bf16 shapes currently listed in the README on MI300A and MI250X.
- Compare against torch.compile/eager PyTorch baselines and vendor-library fast paths.
- Require a material speedup threshold before making a new backend the default.
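One way to make the speedup requirement concrete is a gate that a candidate backend must pass on every benchmarked (device, dtype, shape) case. This is a sketch: should_become_default, the result layout, and the 1.10 default threshold are hypothetical choices, and baseline_ms is assumed to be the fastest of the eager, torch.compile, and vendor-library paths for that case.

```python
from typing import Dict, Tuple

# Key: (device, dtype, shape label); value: (baseline_ms, candidate_ms).
Results = Dict[Tuple[str, str, str], Tuple[float, float]]


def should_become_default(
    results: Results, min_speedup: float = 1.10
) -> Tuple[bool, Dict[Tuple[str, str, str], float]]:
    """Return (ok, failures); ok only if every case reaches min_speedup."""
    failures: Dict[Tuple[str, str, str], float] = {}
    for key, (baseline_ms, candidate_ms) in results.items():
        speedup = baseline_ms / candidate_ms
        if speedup < min_speedup:
            failures[key] = round(speedup, 3)
    # An empty result set must not silently pass the gate.
    return (len(results) > 0 and not failures), failures
```

Requiring the threshold on every case, rather than on an average, keeps a backend that wins on MI300A but regresses on MI250X (or vice versa) from becoming the default.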