Finding
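- The native fp8 MFMA path is only selected by path="auto" when operands use an AMD-native fnuz fp8 dtype: src/xkernels/ops/gemm/triton/entry.py#L50-L57
- The quant helpers default to torch.float8_e4m3fn, not fnuz: src/xkernels/ops/gemm/reference.py#L137-L163 and src/xkernels/ops/gemm/reference.py#L166-L192
- The native path requires float8_e4m3fnuz operands; fn operands stay on the portable fallback.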
Why this should improve performance
A caller using the public defaults can quantize to float8_e4m3fn and then call mm_fp8_blockscale(..., path="auto"), silently missing the native fp8 MFMA path and the 3-9x speedup documented for it on MI300A. The optimized path exists, but the default API makes it easy to miss on AMD.
Suggested implementation
- Add a device-aware helper such as preferred_fp8_dtype() that returns float8_e4m3fnuz on AMD/gfx942 when available and float8_e4m3fn elsewhere.
- Consider a new fp8_dtype="auto" default for quant helpers, while preserving explicit dtype control for portability and parity tests.
- Add a warning or debug note when path="auto" falls back to portable because the operands are float8_e4m3fn on AMD.
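A minimal sketch of the first and third bullets, not the library's actual implementation. The names choose_fp8_dtype_name and warn_on_portable_fallback are hypothetical, as is the exact signature of preferred_fp8_dtype. The sketch assumes a ROCm build of PyTorch exposes torch.version.hip and a gcnArchName device property; both are read defensively with getattr. The decision rule is kept as a pure function so it can be tested on machines without an AMD GPU.

```python
import warnings
from typing import Optional

FN = "float8_e4m3fn"
FNUZ = "float8_e4m3fnuz"


def choose_fp8_dtype_name(is_rocm: bool, arch: str, has_fnuz: bool) -> str:
    """Pure decision rule (hypothetical): fnuz only on ROCm + gfx942 + dtype present."""
    # gcnArchName may carry feature suffixes such as "gfx942:sramecc+:xnack-".
    base_arch = arch.split(":")[0]
    if is_rocm and has_fnuz and base_arch == "gfx942":
        return FNUZ
    return FN


def preferred_fp8_dtype(device: Optional[int] = None):
    """Hypothetical helper: return the fp8 dtype that suits the current device."""
    import torch  # lazy import so the decision rule above is usable without torch

    # Assumption: torch.version.hip is a string on ROCm builds and None otherwise.
    is_rocm = getattr(torch.version, "hip", None) is not None
    arch = ""
    if is_rocm and torch.cuda.is_available():
        index = torch.cuda.current_device() if device is None else device
        props = torch.cuda.get_device_properties(index)
        # Assumption: ROCm builds expose the architecture as gcnArchName.
        arch = getattr(props, "gcnArchName", "")
    name = choose_fp8_dtype_name(is_rocm, arch, hasattr(torch, FNUZ))
    return getattr(torch, name)


def warn_on_portable_fallback(path: str, operand_dtype_name: str, on_gfx942: bool) -> bool:
    """Hypothetical hook: warn when path="auto" would miss the native path on AMD."""
    if path == "auto" and on_gfx942 and operand_dtype_name == FN:
        warnings.warn(
            'path="auto" is using the portable fallback because operands are '
            f"{FN}; quantize to {FNUZ} to reach the native fp8 MFMA path.",
            RuntimeWarning,
            stacklevel=2,
        )
        return True
    return False
```

Where the real dispatch code would call such a hook, and whether it should be a warning or a debug log, is left open here.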
Validation
- Add tests for dtype selection on environments with and without torch.float8_e4m3fnuz.
- Benchmark the default quantize + mm_fp8_blockscale(path="auto") flow before and after on MI300A.
- Keep explicit float8_e4m3fn tests to preserve portable fallback correctness.
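The dtype-selection tests could be sketched roughly as follows. Here select_fp8_dtype is a hypothetical stand-in for the proposed helper, parameterized on a torch-like module so that both environments (with and without float8_e4m3fnuz) can be simulated without special hardware. The fake module and its placeholder string values are assumptions made purely for illustration.

```python
from types import SimpleNamespace


def select_fp8_dtype(torch_mod, on_gfx942: bool):
    """Hypothetical stand-in for the proposed helper, injectable for testing."""
    if on_gfx942 and hasattr(torch_mod, "float8_e4m3fnuz"):
        return torch_mod.float8_e4m3fnuz
    return torch_mod.float8_e4m3fn


def make_fake_torch(with_fnuz: bool):
    """Build a torch-like namespace; strings stand in for the real dtype objects."""
    fake = SimpleNamespace(float8_e4m3fn="fn")
    if with_fnuz:
        fake.float8_e4m3fnuz = "fnuz"
    return fake


def test_prefers_fnuz_on_gfx942_when_available():
    assert select_fp8_dtype(make_fake_torch(True), on_gfx942=True) == "fnuz"


def test_falls_back_to_fn_when_fnuz_is_missing():
    assert select_fp8_dtype(make_fake_torch(False), on_gfx942=True) == "fn"


def test_keeps_fn_off_amd_even_if_fnuz_exists():
    assert select_fp8_dtype(make_fake_torch(True), on_gfx942=False) == "fn"
```

The functions follow pytest naming so they would be collected automatically. The MI300A benchmark is not sketched because it depends on the real quant helpers and hardware.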