Finding
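The fp8 blockscale quantization helpers are implemented as Python loops over quantization blocks, performing amax, scale, cast, and stores per group:
- src/xkernels/ops/gemm/reference.py#L149-L163
- src/xkernels/ops/gemm/reference.py#L175-L192
These helpers are exported from the package and feed `mm_fp8_blockscale`.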
Why this should improve performance
The weight helper may be offline, but per-token activation quantization can be on the inference path before fp8 blockscale GEMM. The current implementation launches many small torch ops from Python and writes slices group by group. A Triton quantizer can compute scales and fp8 casts in one or a small number of kernels.
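To make the launch overhead concrete, the sketch below estimates a lower bound on the number of small ops a loop over K groups issues per call. It assumes one op each for amax, scale, cast, and store per group; the function name and the example hidden size are hypothetical and not taken from the codebase.

```python
import math


def min_op_launches(k, group_size=128, ops_per_group=4):
    """Lower bound on small ops issued by a Python loop over K groups.

    ops_per_group=4 assumes one op each for amax, scale, cast, and store.
    """
    num_groups = math.ceil(k / group_size)  # a tail group still costs a full iteration
    return num_groups * ops_per_group


# Hypothetical hidden size, chosen only for illustration.
print(min_op_launches(7168))  # 56 groups * 4 ops = 224
```

A fused kernel replaces all of these with a single launch whose grid covers every group.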
Suggested implementation
- Add a Triton backend for `per_token_group_quant_fp8` that maps one program to a (token, K-block) tile and writes both fp8 values and scale.
- Add a Triton backend for `per_block_quant_fp8` for benchmark/setup paths that need fast on-device weight preparation.
- Support both `float8_e4m3fn` and `float8_e4m3fnuz`, sharing the dtype max table already in the reference.
- Keep the current torch implementation as the CPU/reference path.
Validation
- Compare dequantized values from Triton quantization against the current reference helpers.
- Benchmark quantize + `mm_fp8_blockscale` end-to-end for V4 projection shapes.
- Include non-multiple-of-128 K/N tails if the public helper continues to support them.
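The dequantize-and-compare check, including a K tail, can be sketched at the index level as follows. Plain Python lists stand in for tensors so the group indexing is explicit, and the floats in `q_row` stand in for already-decoded fp8 values. The helper names are hypothetical, and a real test would run on torch tensors with a tolerance suited to fp8.

```python
import math


def dequantize_row(q_row, scales, group_size=128):
    """Dequantize one row: element j uses the scale of group j // group_size."""
    expected = math.ceil(len(q_row) / group_size)
    if len(scales) != expected:
        raise ValueError(f"expected {expected} scales, got {len(scales)}")
    return [q * scales[j // group_size] for j, q in enumerate(q_row)]


def max_abs_diff(a, b):
    """Largest elementwise absolute difference between two equal-length rows."""
    return max(abs(x - y) for x, y in zip(a, b))


# K = 130 is not a multiple of 128, so the second group holds only 2 elements.
q_row = [1.0] * 128 + [2.0, -4.0]
reference = dequantize_row(q_row, scales=[0.5, 0.25])
candidate = dequantize_row(q_row, scales=[0.5, 0.25])
assert max_abs_diff(reference, candidate) == 0.0
```

Comparing dequantized values rather than raw fp8 bytes tolerates the two backends rounding a value differently, as long as the reconstructed values stay within the chosen tolerance.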