Make AMD fp8 blockscale defaults hit the native fnuz MFMA fast path #56

Description

@xzyaoi

Finding

path="auto" selects the native fp8 MFMA path only when the operands use an AMD-native fnuz fp8 dtype such as float8_e4m3fnuz.

Why this should improve performance

A caller using the public defaults can quantize to float8_e4m3fn and then call mm_fp8_blockscale(..., path="auto"). On AMD this silently falls back to the portable path and misses the 3-9x native fp8 MFMA path documented for MI300A. The optimized path exists, but the default API makes it easy to miss.

Suggested implementation

  • Add a device-aware helper such as preferred_fp8_dtype() that returns float8_e4m3fnuz on AMD/gfx942 when available and float8_e4m3fn elsewhere.
  • Consider a new fp8_dtype="auto" default for quant helpers, while preserving explicit dtype control for portability and parity tests.
  • Add a warning or debug note when path="auto" falls back to portable because the operands are float8_e4m3fn on AMD.
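
The first and third suggestions above could look roughly like the sketch below. It is only an illustration under stated assumptions: `select_fp8_dtype_name` and `warn_if_portable_fallback` are hypothetical names that do not exist in this repository, the gfx942 check relies on the `gcnArchName` device property that ROCm builds of PyTorch expose, and the real dispatch logic behind path="auto" may key on different conditions.

```python
import warnings
from typing import Optional


def select_fp8_dtype_name(hip_version: Optional[str],
                          gcn_arch: Optional[str],
                          has_fnuz: bool) -> str:
    """Pure decision logic (hypothetical): pick an fp8 dtype name.

    hip_version: value of torch.version.hip (None on non-ROCm builds).
    gcn_arch:    e.g. "gfx942:sramecc+:xnack-" (None if unknown).
    has_fnuz:    whether the PyTorch build exposes float8_e4m3fnuz.
    """
    # ROCm reports the arch with feature suffixes; keep only the base name.
    base_arch = gcn_arch.split(":")[0] if gcn_arch else None
    if hip_version is not None and base_arch == "gfx942" and has_fnuz:
        return "float8_e4m3fnuz"
    return "float8_e4m3fn"


def preferred_fp8_dtype():
    """Device-aware wrapper around the decision logic above.

    torch is imported lazily so the pure logic stays testable without it.
    """
    import torch

    hip_version = getattr(torch.version, "hip", None)
    gcn_arch = None
    if hip_version is not None and torch.cuda.is_available():
        props = torch.cuda.get_device_properties(torch.cuda.current_device())
        # gcnArchName is assumed to be present on ROCm builds only.
        gcn_arch = getattr(props, "gcnArchName", None)
    name = select_fp8_dtype_name(
        hip_version, gcn_arch, hasattr(torch, "float8_e4m3fnuz")
    )
    return getattr(torch, name)


def warn_if_portable_fallback(operand_dtype_name: str,
                              on_amd: bool,
                              path: str) -> bool:
    """Warn when path="auto" would fall back because of a non-fnuz dtype.

    Returns True if a warning was emitted. The condition is an assumption
    about when the fallback happens, not the library's actual check.
    """
    if path == "auto" and on_amd and operand_dtype_name == "float8_e4m3fn":
        warnings.warn(
            "path='auto' is using the portable fp8 path: operands are "
            "float8_e4m3fn; quantize to float8_e4m3fnuz to reach the "
            "native fp8 MFMA path on AMD.",
            RuntimeWarning,
            stacklevel=2,
        )
        return True
    return False
```

Keeping the decision in a pure function that takes plain values means it can be exercised on machines without an AMD GPU, while explicit dtype arguments elsewhere remain untouched for portability and parity tests.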

Validation

  • Add tests for dtype selection in environments with and without torch.float8_e4m3fnuz.
  • Benchmark the default quantize + mm_fp8_blockscale(path="auto") flow before and after on MI300A.
  • Keep explicit float8_e4m3fn tests to preserve portable fallback correctness.
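
One way to cover both kinds of environment without needing two PyTorch builds is to pass a stand-in namespace to the selection rule. The sketch below assumes a hypothetical `pick_fp8_dtype` that takes the torch-like module as an argument; the real helper may be structured differently, and the string values are placeholders for the actual dtype objects.

```python
from types import SimpleNamespace


def pick_fp8_dtype(torch_like, on_gfx942):
    """Hypothetical selection rule: fnuz on gfx942 when the build exposes it."""
    fnuz = getattr(torch_like, "float8_e4m3fnuz", None)
    if on_gfx942 and fnuz is not None:
        return fnuz
    return torch_like.float8_e4m3fn


# Stand-ins for a PyTorch build with and without the fnuz dtype.
WITH_FNUZ = SimpleNamespace(float8_e4m3fn="e4m3fn", float8_e4m3fnuz="e4m3fnuz")
WITHOUT_FNUZ = SimpleNamespace(float8_e4m3fn="e4m3fn")


def test_gfx942_prefers_fnuz_when_available():
    assert pick_fp8_dtype(WITH_FNUZ, on_gfx942=True) == "e4m3fnuz"


def test_gfx942_without_fnuz_falls_back():
    assert pick_fp8_dtype(WITHOUT_FNUZ, on_gfx942=True) == "e4m3fn"


def test_non_amd_keeps_portable_dtype():
    assert pick_fp8_dtype(WITH_FNUZ, on_gfx942=False) == "e4m3fn"
```

The functions follow pytest naming conventions, so they would be collected automatically if placed in a test module.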