Skip to content

Flash Attention 3 varlen wrapper assumes output is a tuple #13999

@bghira

Description

@bghira

Describe the bug

`_flash_attention_3_varlen_hub` assumes `flash_attn_varlen_func` returns a tuple and unpacks its result:

out, lse, *_ = func(...)

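The hub FA3 varlen function, however, returns a single tensor of shape `[total_q, heads, dim]`. Python's unpacking iterates that tensor along its first dimension, so `out` becomes only the first packed row (shape `[heads, dim]`) and `lse` the second row, instead of the full attention output.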

Reproduction

import argparse
import sys

import torch

from diffusers.models.attention_dispatch import (
    AttentionBackendName,
    _HUB_KERNELS_REGISTRY,
    _flash_attention_3_varlen_hub,
)


class SkipRepro(RuntimeError):
    """Raised when the current environment cannot run the real FA3 repro.

    ``main`` catches it, prints a ``SKIPPED:`` message to stderr and exits
    with status 77 instead of reporting a failure.
    """


def _fake_flash_attn_3_varlen_func(
    q,
    k,
    v,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    softmax_scale,
    causal,
):
    """Stand-in for the hub FA3 ``flash_attn_varlen_func`` that returns a bare tensor.

    Mirrors the reported contract: the result is ONE tensor shaped like the
    packed ``q`` (``[total_q, heads, dim]``), not a ``(out, lse, ...)`` tuple.
    The value ``q + 1000`` is arbitrary but easy to recognise in the output.

    Only ``q`` is used; the remaining parameters exist so the signature
    accepts the same positional/keyword arguments as the real kernel call.
    """
    del k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, softmax_scale, causal
    shifted = q + 1000
    return shifted


def _make_inputs(device, *, batch_size=2, seq_len=4, heads=2, dim=64, dtype=torch.bfloat16):
    """Build random ``(query, key, value)`` tensors for the repro.

    All three tensors have shape ``(batch_size, seq_len, heads, dim)``.
    ``key`` and ``value`` are independent copies (``clone``) of ``query``, so
    they hold identical values but do not share storage.

    Args:
        device: Device to allocate on (e.g. ``"cpu"`` or ``"cuda"``).
        batch_size: Number of sequences (default 2).
        seq_len: Tokens per sequence; every sequence has the same length
            (default 4).
        heads: Number of attention heads (default 2).
        dim: Per-head feature size (default 64).
        dtype: Element dtype (default ``torch.bfloat16``).

    Returns:
        Tuple ``(query, key, value)``.
    """
    query = torch.randn(batch_size, seq_len, heads, dim, device=device, dtype=dtype)
    key = query.clone()
    value = query.clone()
    return query, key, value


def _install_kernel(kernel_fn):
    """Swap ``kernel_fn`` into the FA3 varlen hub entry of ``_HUB_KERNELS_REGISTRY``.

    This mutates the shared registry entry in place. The caller is
    responsible for restoring the previous kernel afterwards.

    Returns:
        ``(entry, previous_kernel_fn)``: the mutated registry entry and the
        kernel function it held before the swap.
    """
    backend = AttentionBackendName._FLASH_3_VARLEN_HUB
    entry = _HUB_KERNELS_REGISTRY[backend]
    # Right-hand side is evaluated first, so previous_fn captures the old kernel.
    previous_fn, entry.kernel_fn = entry.kernel_fn, kernel_fn
    return entry, previous_fn


def _run_diffusers_wrapper(query, key, value, kernel_fn):
    """Call Diffusers' ``_flash_attention_3_varlen_hub`` with ``kernel_fn`` temporarily installed.

    The registry's original kernel is restored even if the wrapper raises.

    Returns:
        Whatever ``_flash_attention_3_varlen_hub(query, key, value)`` returns.
    """
    entry, saved_fn = _install_kernel(kernel_fn)
    try:
        return _flash_attention_3_varlen_hub(query, key, value)
    finally:
        entry.kernel_fn = saved_fn


def _run_mock_repro():
    """Reproduce the bug on CPU using a fake tensor-returning FA3 varlen kernel.

    The fake kernel returns ``q + 1000``. A wrapper that handled a tensor
    return correctly would give back a tensor shaped like ``query``. A shape
    mismatch therefore means the tensor was unpacked as if it were a tuple.

    Raises:
        RuntimeError: if the wrapper output already has the expected shape,
            i.e. the bug did not reproduce.
    """
    query, key, value = _make_inputs("cpu")
    result = _run_diffusers_wrapper(query, key, value, _fake_flash_attn_3_varlen_func)
    reference = query + 1000

    print("mode: mock tensor-returning FA3 varlen function")
    print(f"expected shape: {tuple(reference.shape)}")
    print(f"actual shape:   {tuple(result.shape)}")
    print(f"actual tensor equals expected: {torch.equal(result, reference)}")

    if tuple(result.shape) == tuple(reference.shape):
        raise RuntimeError("Bug did not reproduce; _flash_attention_3_varlen_hub preserved the full tensor output.")

    print("BUG REPRODUCED: tensor return was destructured as if it were a tuple.")


def _require_sm90():
    """Ensure a CUDA device with compute capability major version 9 is present.

    Returns:
        A description string of the form ``"<device name> sm<major><minor>"``.

    Raises:
        SkipRepro: if CUDA is unavailable or the device's major capability
            is not 9.
    """
    if not torch.cuda.is_available():
        raise SkipRepro("This real FA3 repro requires CUDA, but torch.cuda.is_available() is False.")

    major, minor = torch.cuda.get_device_capability()
    name = torch.cuda.get_device_name()
    if major == 9:
        return f"{name} sm{major}{minor}"

    raise SkipRepro(
        f"This real FA3 repro requires an NVIDIA SM90 GPU, but got {name} with capability {major}.{minor}."
    )


def _run_real_repro():
    """Reproduce the bug against the real ``kernels-community/flash-attn3`` hub kernel.

    Steps:
      1. Require a capability-9.x CUDA device.
      2. Call the kernel's ``flash_attn_varlen_func`` directly on packed inputs.
         If it returns a tuple, this environment cannot show the mismatch.
      3. Reshape the direct result back to ``(batch, seq, heads, dim)``.
      4. Compare its shape with what the Diffusers wrapper returns when given
         the same kernel.

    Raises:
        SkipRepro: if the hardware is unsuitable or the kernel returns a tuple.
        RuntimeError: if the wrapper output has the expected shape, i.e. the
            bug did not reproduce.
    """
    device_description = _require_sm90()

    # Imported here rather than at module top, so the --mock path never
    # touches the `kernels` package.
    from kernels import get_kernel

    fa3_module = get_kernel("kernels-community/flash-attn3", version=1)
    varlen_fn = fa3_module.flash_attn_varlen_func

    query, key, value = _make_inputs("cuda")
    batch_size, seq_len, _, _ = query.shape

    # All sequences share the same length, so the cumulative offsets are
    # 0, seq_len, 2 * seq_len, ..., batch_size * seq_len.
    cu_seqlens = torch.arange(
        0,
        (batch_size + 1) * seq_len,
        step=seq_len,
        device=query.device,
        dtype=torch.int32,
    )
    # Merge the batch and sequence dims into one packed token dim.
    q_packed, k_packed, v_packed = (t.flatten(0, 1) for t in (query, key, value))

    kernel_kwargs = dict(
        q=q_packed,
        k=k_packed,
        v=v_packed,
        cu_seqlens_q=cu_seqlens,
        cu_seqlens_k=cu_seqlens,
        max_seqlen_q=seq_len,
        max_seqlen_k=seq_len,
        softmax_scale=None,
        causal=False,
    )
    direct = varlen_fn(**kernel_kwargs)

    if isinstance(direct, tuple):
        raise SkipRepro(
            "The loaded FA3 varlen hub kernel returned a tuple in this environment, so this environment does not "
            "reproduce the reported tensor-return contract mismatch."
        )

    expected = direct.unflatten(0, (batch_size, seq_len))
    wrapped = _run_diffusers_wrapper(query, key, value, varlen_fn)

    print("mode: real kernels-community/flash-attn3 flash_attn_varlen_func")
    print(f"device: {device_description}")
    print(f"direct kernel return type: {type(direct).__name__}")
    print(f"direct kernel shape:      {tuple(direct.shape)}")
    print(f"expected wrapper shape:   {tuple(expected.shape)}")
    print(f"actual wrapper shape:     {tuple(wrapped.shape)}")
    print(f"actual tensor equals expected: {torch.equal(wrapped, expected)}")

    if tuple(wrapped.shape) == tuple(expected.shape):
        raise RuntimeError("Bug did not reproduce; _flash_attention_3_varlen_hub preserved the full tensor output.")

    print("BUG REPRODUCED: Diffusers destructured the tensor return as if it were a tuple.")


def main(argv=None):
    """Command-line entry point for the FA3 varlen tuple-unpacking repro.

    By default runs the real SM90 repro. ``--mock`` runs the CPU-only mock
    instead.

    Args:
        argv: Argument list to parse, excluding the program name. ``None``
            means ``sys.argv[1:]`` (argparse's default), so ``main()`` with no
            arguments behaves as before.

    Exits with status 77 when the repro raises ``SkipRepro`` (no suitable
    GPU, or the kernel already returns a tuple). 77 is presumably chosen to
    match the common "test skipped" convention of harnesses such as
    Automake. NOTE(review): confirm with the CI that runs this.
    """
    parser = argparse.ArgumentParser(
        description="Reproduce the FA3 varlen hub wrapper unpacking a tensor return as if it were a tuple."
    )
    parser.add_argument(
        "--mock",
        action="store_true",
        help="Run a CPU-only mock of the tensor-return contract. By default, this runs the real SM90 FA3 repro.",
    )
    args = parser.parse_args(argv)

    # Both modes share the same skip handling so a SkipRepro never surfaces
    # as a raw traceback.
    run_repro = _run_mock_repro if args.mock else _run_real_repro
    try:
        run_repro()
    except SkipRepro as error:
        print(f"SKIPPED: {error}", file=sys.stderr)
        sys.exit(77)


if __name__ == "__main__":
    # Run the repro only when executed as a script, not when this module is imported.
    main()

Logs

System Info

  • 🤗 Diffusers version: 0.39.0.dev0
  • Platform: Linux-6.11.0-1016-nvidia-x86_64-with-glibc2.39
  • Running on Google Colab?: No
  • Python version: 3.12.3
  • PyTorch version (GPU?): 2.12.1+cpu (False)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Huggingface_hub version: 1.20.1
  • Transformers version: not installed
  • Accelerate version: not installed
  • PEFT version: not installed
  • Safetensors version: 0.8.0
  • xFormers version: not installed
  • Accelerator: NVIDIA H100 80GB HBM3, 81559 MiB (x8)
  • Using GPU in script?: Yes, but not required
  • Using distributed or parallel set-up in script?: Yes, but not required

Who can help?

No response

Metadata

Metadata

Assignees

No one assigned

    Labels

    bug — Something isn't working

    Type

    No type
    No fields configured for issues without a type.

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions