
Add Triton kernels for fp8 blockscale quantization helpers #57

Description

@xzyaoi

Finding

The fp8 blockscale quantization helpers are implemented as Python loops over quantization blocks:

These helpers are exported from the package and produce the fp8 values and scales consumed by mm_fp8_blockscale.
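For context, here is a rough sketch of how the two scale layouts are consumed. It assumes the common 1x128 activation / 128x128 weight blockscale convention: activation scales of shape (M, ceil(K/128)) and weight scales of shape (ceil(N/128), ceil(K/128)). The function name and argument layout are hypothetical, and the actual mm_fp8_blockscale signature and scale layout may differ. Quantized values are held as float32 so the sketch runs without fp8 support.

```python
import numpy as np

BLOCK = 128  # assumed block size: 1x128 activation groups, 128x128 weight blocks


def blockscale_matmul_reference(a_q, a_scale, b_q, b_scale, block=BLOCK):
    """Hypothetical reference for C = A @ B.T with blockwise scales.

    a_q:     (M, K) quantized activations (stored as float32 here)
    a_scale: (M, ceil(K/block)) one scale per token per K-block
    b_q:     (N, K) quantized weights
    b_scale: (ceil(N/block), ceil(K/block)) one scale per weight block
    """
    M, K = a_q.shape
    N = b_q.shape[0]
    out = np.zeros((M, N), dtype=np.float32)
    for kb in range(a_scale.shape[1]):
        ks = slice(kb * block, min((kb + 1) * block, K))
        # Partial product over one K-block, computed on the raw quantized values.
        partial = a_q[:, ks] @ b_q[:, ks].T  # (M, N)
        # Expand per-N-block weight scales to per-column scales, cropping the N tail.
        b_cols = np.repeat(b_scale[:, kb], block)[:N]  # (N,)
        out += partial * a_scale[:, kb:kb + 1] * b_cols[None, :]
    return out
```

The loop over K-blocks is only for readability. It shows why each per-token activation scale has to line up with one 128-wide K-block of the weight, which is the layout a faster quantizer must reproduce.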

Why this should improve performance

Weight quantization can run offline, but per-token activation quantization may sit on the inference path right before the fp8 blockscale GEMM. The current implementation launches many small torch ops from Python and writes the output slice by slice, one group at a time. A Triton quantizer can compute the scales and the fp8 casts in a single kernel, or in a small number of kernels.
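The per-group math is one reduction (amax) followed by elementwise ops, so it does not need a Python loop over groups. Below is a minimal NumPy sketch of a vectorized per-token group quantizer. It assumes 128-wide groups, an amax clamp of eps=1e-10 (an assumption, not taken from the reference) and round-to-nearest-even fp8 casts. NumPy has no fp8 dtype, so the cast is emulated by rounding onto the e4m3 grid. The helper names are hypothetical; a Triton kernel would fuse the same steps for one tile per program.

```python
import numpy as np

# Largest finite value and smallest normal exponent of each fp8 format.
FP8_MAX = {"float8_e4m3fn": 448.0, "float8_e4m3fnuz": 240.0}
FP8_MIN_NORMAL_EXP = {"float8_e4m3fn": -6, "float8_e4m3fnuz": -7}


def round_to_e4m3(v, min_normal_exp):
    """Emulate an fp8 e4m3 cast of already-clamped values (round to nearest, ties to even)."""
    _, exp = np.frexp(v)                     # v = m * 2**exp with 0.5 <= |m| < 1
    e = np.maximum(exp - 1, min_normal_exp)  # below the normal range the spacing stays fixed
    step = np.ldexp(1.0, e - 3)              # 3 mantissa bits -> grid spacing 2**(e - 3)
    return np.round(v / step) * step         # np.round rounds halves to even


def per_token_group_quant_np(x, group_size=128, dtype="float8_e4m3fn", eps=1e-10):
    """Vectorized sketch: one amax reduction and one elementwise pass, no per-group loop."""
    M, K = x.shape
    fp8_max = FP8_MAX[dtype]
    groups_per_token = -(-K // group_size)   # ceil(K / group_size)
    pad = groups_per_token * group_size - K
    # Zero padding leaves every |x| maximum unchanged, so the K tail needs no special case.
    xp = np.pad(x.astype(np.float32), ((0, 0), (0, pad)))
    g = xp.reshape(M, groups_per_token, group_size)
    amax = np.abs(g).max(axis=-1, keepdims=True)  # (M, groups, 1)
    scale = np.maximum(amax, eps) / fp8_max
    q = round_to_e4m3(np.clip(g / scale, -fp8_max, fp8_max), FP8_MIN_NORMAL_EXP[dtype])
    return q.reshape(M, -1)[:, :K], scale[..., 0]
```

Usage: `q, s = per_token_group_quant_np(x)` returns values on the e4m3 grid with shape (M, K) and scales with shape (M, ceil(K/128)).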

Suggested implementation

  • Add a Triton backend for per_token_group_quant_fp8 that maps one program to a (token, K-block) tile and writes both the fp8 values and the scale for that tile.
  • Add a Triton backend for per_block_quant_fp8 for benchmark/setup paths that need fast on-device weight preparation.
  • Support both float8_e4m3fn and float8_e4m3fnuz, reusing the dtype max table already defined in the reference implementation.
  • Keep the current torch implementation as the CPU/reference path.
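The proposed launch layout can be sketched by emulating a Triton grid in NumPy: grid = (cdiv(M, block_m), cdiv(K, block_k)), and each "program" quantizes one tile with masked loads and stores, as a Triton kernel would with tl.program_id, tl.arange and masked tl.load/tl.store. With block_m=1 this is the per-token (token, K-block) mapping; with block_m=128 the same body covers 128x128 weight blocks for per_block_quant_fp8. Function names are hypothetical, and the final fp8 cast is omitted: values are only clipped to the dtype max.

```python
import numpy as np

# Largest finite values of the two fp8 formats (torch.finfo(dtype).max).
FP8_MAX = {"float8_e4m3fn": 448.0, "float8_e4m3fnuz": 240.0}


def cdiv(a, b):
    """Ceiling division, as used to size a Triton launch grid."""
    return -(-a // b)


def quant_tile_program(pid_m, pid_k, x, q, s, block_m, block_k, fp8_max, eps):
    """Emulates one hypothetical Triton program: quantize one (block_m x block_k) tile."""
    M, K = x.shape
    offs_m = pid_m * block_m + np.arange(block_m)  # like tl.arange(0, BLOCK_M)
    offs_k = pid_k * block_k + np.arange(block_k)
    rm = offs_m[offs_m < M]                         # row mask for the M/N tail
    rk = offs_k[offs_k < K]                         # column mask for the K tail
    # Masked load: out-of-range elements read as 0.0, which cannot raise the amax.
    tile = np.zeros((block_m, block_k), dtype=np.float32)
    tile[: len(rm), : len(rk)] = x[np.ix_(rm, rk)]
    scale = max(float(np.abs(tile).max()), eps) / fp8_max
    # A real kernel would cast this to the fp8 output dtype before storing.
    qt = np.clip(tile / scale, -fp8_max, fp8_max)
    q[np.ix_(rm, rk)] = qt[: len(rm), : len(rk)]    # masked store
    s[pid_m, pid_k] = scale


def launch_blockwise_quant(x, block_m, block_k=128, dtype="float8_e4m3fn", eps=1e-10):
    """block_m=1: per-token groups; block_m=128: 128x128 weight blocks."""
    M, K = x.shape
    grid = (cdiv(M, block_m), cdiv(K, block_k))
    q = np.empty((M, K), dtype=np.float32)
    s = np.empty(grid, dtype=np.float32)
    # On a GPU every (pid_m, pid_k) program runs concurrently; here they run in a loop.
    for pid_m in range(grid[0]):
        for pid_k in range(grid[1]):
            quant_tile_program(pid_m, pid_k, x, q, s, block_m, block_k, FP8_MAX[dtype], eps)
    return q, s
```

Sharing one tile body between the two helpers keeps the tail masking and the dtype max table in a single place; only block_m and the grid shape differ.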

Validation

  • Compare dequantized values from Triton quantization against the current reference helpers.
  • Benchmark quantize + mm_fp8_blockscale end-to-end for V4 projection shapes.
  • Include K/N sizes that are not multiples of 128 (partial tail blocks) if the public helper continues to support them.
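For the first validation bullet, a hypothetical helper could compare two quantizations by their dequantized values rather than by raw fp8 bytes, since a last-ulp difference in a scale can legitimately move a borderline element to a neighboring fp8 grid point. Scales are expanded with np.repeat and cropped, so K/N tails are covered. The tolerance (one e4m3 grid step, at most 2^-3 relative, plus one scaled subnormal step) is an assumption to tune, not a project requirement.

```python
import numpy as np


def expand_scales(scale, shape, block_m, block_k):
    """Broadcast per-block scales to element shape, cropping ragged M/N and K tails."""
    M, K = shape
    return np.repeat(np.repeat(scale, block_m, axis=0), block_k, axis=1)[:M, :K]


def dequantize(q, scale, block_m=1, block_k=128):
    return q.astype(np.float32) * expand_scales(scale, q.shape, block_m, block_k)


def check_dequant_close(q_ref, s_ref, q_new, s_new, block_m=1, block_k=128,
                        scale_rtol=1e-6, min_subnormal=2.0**-9):
    """Hypothetical check: new (e.g. Triton) output vs. the reference helper, after dequantization."""
    # Both scales come from the same amax reduction, so they should agree almost exactly.
    np.testing.assert_allclose(s_new, s_ref, rtol=scale_rtol)
    deq_ref = dequantize(q_ref, s_ref, block_m, block_k)
    deq_new = dequantize(q_new, s_new, block_m, block_k)
    # Allow one e4m3 grid step (at most 2**-3 relative) for ties that land on a
    # neighboring grid point, plus one scaled subnormal step near zero.
    tol = (2.0**-3 * np.abs(deq_ref)
           + min_subnormal * expand_scales(s_ref, q_ref.shape, block_m, block_k))
    bad = np.abs(deq_new - deq_ref) > tol
    assert not bad.any(), f"{int(bad.sum())} of {bad.size} elements differ by more than one fp8 step"
```

For per_block_quant_fp8 outputs pass block_m=128, and for float8_e4m3fnuz the smallest subnormal would be 2**-10 instead of 2**-9.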

Metadata

Assignees

No one assigned

Labels

No labels

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests