Rework fused_ffn beyond activation-only fusion #60

Description

@xzyaoi

Finding

The fused_ffn optimized backends still run the three projection GEMMs as torch matmuls and only fuse the elementwise SwiGLU activation.

The README performance table already shows fused_ffn at ~1.0x because the GEMMs dominate and both paths use torch matmul.

Why this should improve performance

As implemented, fused_ffn adds backend dispatch and a custom activation launch but does not remove the dominant GEMM costs. To make it a real performance kernel, the project needs a different strategy than activation-only fusion.
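Amdahl's law gives a rough ceiling for what activation-only fusion can buy. The sketch below is illustrative only: `max_speedup` is a hypothetical helper, and the 5% activation share is an assumed figure, not a measurement from this repo.

```python
def max_speedup(fraction_accelerated: float, local_speedup: float) -> float:
    """Amdahl's law: overall speedup when only part of the runtime gets faster."""
    if not 0.0 <= fraction_accelerated <= 1.0:
        raise ValueError("fraction_accelerated must be in [0, 1]")
    if local_speedup <= 0.0:
        raise ValueError("local_speedup must be positive")
    # Time left after the accelerated part shrinks, as a fraction of the original.
    remaining = (1.0 - fraction_accelerated) + fraction_accelerated / local_speedup
    return 1.0 / remaining


if __name__ == "__main__":
    # Assumed share: activation is 5% of FFN time (not measured here).
    print(f"activation 2x faster:   {max_speedup(0.05, 2.0):.3f}x overall")
    print(f"activation free (inf):  {max_speedup(0.05, float('inf')):.3f}x overall")
```

Under that assumed share, even a free activation caps the end-to-end gain at roughly 1.05x, which is consistent with the ~1.0x in the README table as long as the GEMMs stay untouched.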

Suggested implementation

  • Evaluate whether this op should delegate to vendor GEMM libraries with fused/epilogue support rather than a custom activation-only kernel.
  • Consider a persistent or grouped GEMM strategy only for shapes where torch/hipBLASLt misses the fast path.
  • If no robust faster path exists, document fused_ffn as a correctness/scaffold API rather than listing it as an optimized kernel.
  • Add benchmarks that separate GEMM time, activation time, and dispatch overhead so future changes show where the win comes from.
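One possible shape for the split benchmark in the last bullet, as a minimal sketch. The names `median_ms` and `ffn_breakdown` are hypothetical, and the three callables (GEMMs only, activation only, full fused_ffn call) are assumed to be supplied by the caller. On a GPU the `sync` argument would need to block on the device, for example `torch.cuda.synchronize`, otherwise only launch cost is measured.

```python
import statistics
import time
from typing import Callable, Dict, Optional


def median_ms(fn: Callable[[], object],
              sync: Optional[Callable[[], None]] = None,
              warmup: int = 3, iters: int = 20) -> float:
    """Median wall-clock time of one call to fn, in milliseconds."""
    for _ in range(warmup):
        fn()
    if sync is not None:
        sync()  # drain warmup work before timing starts
    samples = []
    for _ in range(iters):
        start = time.perf_counter()
        fn()
        if sync is not None:
            sync()  # wait for queued device work so it is included in the sample
        samples.append((time.perf_counter() - start) * 1e3)
    return statistics.median(samples)


def ffn_breakdown(gemms: Callable[[], object],
                  activation: Callable[[], object],
                  full_op: Callable[[], object],
                  sync: Optional[Callable[[], None]] = None,
                  warmup: int = 3, iters: int = 20) -> Dict[str, float]:
    """Time the parts and the whole; report the residual as overhead."""
    t_gemm = median_ms(gemms, sync, warmup, iters)
    t_act = median_ms(activation, sync, warmup, iters)
    t_total = median_ms(full_op, sync, warmup, iters)
    return {
        "gemm_ms": t_gemm,
        "activation_ms": t_act,
        "total_ms": t_total,
        # Residual covers dispatch and launch gaps; noise can make it negative.
        "overhead_ms": t_total - t_gemm - t_act,
    }
```

The residual is a crude proxy for dispatch overhead, since the three timings come from separate runs. A profiler trace would give a tighter attribution, but this is enough to show which component a change actually moved.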

Validation

  • Benchmark fp16 and bf16 shapes currently listed in the README on MI300A and MI250X.
  • Compare against torch.compile/eager PyTorch baselines and vendor-library fast paths.
  • Require a material speedup threshold before making a new backend the default.
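The threshold rule in the last bullet could be expressed as a small gate over per-shape timings. This is a sketch under stated assumptions: `should_promote`, the `(tokens, hidden, intermediate)` key layout, and the 1.10 default are all hypothetical, and the issue does not fix a specific threshold.

```python
from typing import Dict, Tuple

# Hypothetical key layout for a benchmarked shape: (tokens, hidden, intermediate).
Shape = Tuple[int, int, int]


def should_promote(baseline_ms: Dict[Shape, float],
                   candidate_ms: Dict[Shape, float],
                   min_speedup: float = 1.10) -> bool:
    """True only if the candidate meets min_speedup on every benchmarked shape."""
    if not baseline_ms or baseline_ms.keys() != candidate_ms.keys():
        raise ValueError("both runs must cover the same non-empty set of shapes")
    return all(
        baseline_ms[shape] / candidate_ms[shape] >= min_speedup
        for shape in baseline_ms
    )


if __name__ == "__main__":
    # Made-up timings for illustration only, not measurements.
    baseline = {(4096, 4096, 11008): 2.00, (1, 4096, 11008): 0.20}
    candidate = {(4096, 4096, 11008): 1.60, (1, 4096, 11008): 0.19}
    print(should_promote(baseline, candidate))
```

Requiring the threshold on every shape is the conservative choice. A geometric mean across shapes would be more lenient but can hide regressions on small-batch shapes.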
