▼ code ▼ output ▶ uv-logs | Cell: benchmark | 80.35s | Raw
# /// script
# dependencies = [
#   "numpy",
#   "torch",
#   "kernels",
#   "pandas",
#   "matplotlib"
# ]
# ///
# Benchmarking common shapes for Flux 1024x1024px image + varying text sequence lengths

import functools
import os
import pathlib

import matplotlib.pyplot as plt
import torch
import torch._dynamo.config
import triton
import triton.language as tl

# Optional attention backends. Every import below is best-effort: when a
# backend cannot be loaded its entry points are bound to ``None`` so the rest of
# the script can skip it. ``except Exception`` is used instead of a bare
# ``except`` so that KeyboardInterrupt / SystemExit are not swallowed while a
# (potentially slow) import or kernel download is in progress.
try:
    from flash_attn import flash_attn_func
except Exception:
    flash_attn_func = None
    print("Flash Attention 2 not found.")

try:
    from flash_attn_interface import flash_attn_func as flash_attn_3_func
except Exception:
    flash_attn_3_func = None
    print("Flash Attention 3 not found.")

try:
    from kernels import get_kernel
    hf_kernels_flash_attn = get_kernel("kernels-community/flash-attn")
    hf_kernels_flash_attn_3 = get_kernel("kernels-community/flash-attn3")
except Exception:
    # If either kernel fails to load, both are disabled together.
    hf_kernels_flash_attn = None
    hf_kernels_flash_attn_3 = None
    print("HF Kernels not found.")

try:
    from sageattention import sageattn_qk_int8_pv_fp16_cuda, sageattn_qk_int8_pv_fp16_triton, sageattn_qk_int8_pv_fp8_cuda_sm90
except Exception:
    sageattn_qk_int8_pv_fp16_cuda = None
    sageattn_qk_int8_pv_fp16_triton = None
    sageattn_qk_int8_pv_fp8_cuda_sm90 = None
    print("SageAttention not found.")

try:
    from transformer_engine.pytorch.attention import DotProductAttention
except Exception:
    DotProductAttention = None
    print("Transformer Engine not found.")

try:
    import xformers.ops as xops
except Exception:
    xops = None
    print("xFormers not found.")


# Global matplotlib styling for the benchmark plot, grouped by rc namespace.
plt.rcParams.update(
    # Figure canvas.
    {"figure.figsize": (12, 10), "figure.dpi": 120}
    # Text sizes.
    | {
        "font.size": 10,
        "axes.titlesize": 12,
        "axes.labelsize": 14,
        "xtick.labelsize": 10,
        "ytick.labelsize": 10,
    }
    # Axes: light dashed grid, no top/right spines.
    | {
        "axes.grid": True,
        "grid.alpha": 0.3,
        "grid.linestyle": "--",
        "axes.spines.top": False,
        "axes.spines.right": False,
    }
    # Line appearance.
    | {"lines.linewidth": 2.0, "lines.markersize": 6}
    # Legend box.
    | {
        "legend.fontsize": 8,
        "legend.frameon": True,
        "legend.framealpha": 0.9,
        "legend.loc": "best",
    }
)


# Each shape is compiled as its own static specialisation (dynamic=False) so the
# best compiled version per shape is compared; allow a large number of cached
# compilations accordingly.
torch._dynamo.config.cache_size_limit = 10000

# Author's note: suppress_errors is required for FA3 to work under
# torch.compile -- it makes FA3 run in eager mode. No other way to get it
# working was found, so suggestions are welcome.
torch._dynamo.config.suppress_errors = True

# Directory that receives the benchmark artifacts; created up front.
output_dir = pathlib.Path("dump_attention_benchmark")
output_dir.mkdir(exist_ok=True, parents=True)

# Problem sizes for the benchmark.
batch_size = 1
num_attention_heads = 24
attention_head_dim = 128
image_sequence_length = 4096  # 1024x1024px
text_sequence_lengths = [128, 256, 320, 384, 448, 512]
# Total sequence length = image tokens + text tokens, one entry per text length.
sequence_lengths = [image_sequence_length + text_len for text_len in text_sequence_lengths]


def _attention_torch(query, key, value, *, backend):
    query, key, value = (x.transpose(1, 2).contiguous() for x in (query, key, value))
    with torch.nn.attention.sdpa_kernel(backend):
        out = torch.nn.functional.scaled_dot_product_attention(query, key, value)
    out = out.transpose(1, 2).contiguous()
    return out


# Compiled once at import time; static shapes, whole-graph capture required.
_compiled_attention_torch_default = torch.compile(
    _attention_torch,
    fullgraph=True,
    dynamic=False,
    mode="default",
)


def _attention_torch_compile_default(query, key, value, *, backend):
    """Call the ``mode="default"`` compiled variant of ``_attention_torch``."""
    compiled_fn = _compiled_attention_torch_default
    return compiled_fn(query, key, value, backend=backend)


# Compiled once at import time; static shapes, whole-graph capture required.
_compiled_attention_torch_max_autotune = torch.compile(
    _attention_torch,
    fullgraph=True,
    dynamic=False,
    mode="max-autotune",
)


def _attention_torch_compile_max_autotune(query, key, value, *, backend):
    """Call the ``mode="max-autotune"`` compiled variant of ``_attention_torch``."""
    compiled_fn = _compiled_attention_torch_max_autotune
    return compiled_fn(query, key, value, backend=backend)


def _attention_flash_attn_2(query, key, value):
    """Provider: forward query/key/value unchanged to ``flash_attn.flash_attn_func``.

    Only registered when the ``flash_attn`` import succeeded; otherwise
    ``flash_attn_func`` is ``None`` and calling this would fail.
    """
    return flash_attn_func(query, key, value)


# Compiled once at import time; static shapes, whole-graph capture required.
_compiled_flash_attn_2_default = torch.compile(
    _attention_flash_attn_2,
    fullgraph=True,
    dynamic=False,
    mode="default",
)


def _attention_flash_attn_2_compile_default(query, key, value):
    """Call the ``mode="default"`` compiled variant of ``_attention_flash_attn_2``."""
    compiled_fn = _compiled_flash_attn_2_default
    return compiled_fn(query, key, value)


# Compiled once at import time; static shapes, whole-graph capture required.
_compiled_flash_attn_2_max_autotune = torch.compile(
    _attention_flash_attn_2,
    fullgraph=True,
    dynamic=False,
    mode="max-autotune",
)


def _attention_flash_attn_2_compile_max_autotune(query, key, value):
    """Call the ``mode="max-autotune"`` compiled variant of ``_attention_flash_attn_2``."""
    compiled_fn = _compiled_flash_attn_2_max_autotune
    return compiled_fn(query, key, value)


# Registered as a custom op so that fullgraph=True tracing is compatible with it.
@torch.library.custom_op("flash_attn_3::_flash_attn_forward", mutates_args=(), device_types="cuda")
def _wrapped_flash_attn_3(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
    """Call ``flash_attn_3_func`` and return only the first of its two results.

    The call is expected to yield exactly two values; the second one is
    discarded.
    """
    attn_out, _second = flash_attn_3_func(query, key, value)
    return attn_out


@torch.library.register_fake("flash_attn_3::_flash_attn_forward")
def _(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
    """Fake implementation registered for the ``flash_attn_3::_flash_attn_forward`` op.

    Describes the op's output without running the kernel: an uninitialised
    tensor with the same shape, dtype and device as ``query``.
    """
    return torch.empty_like(query)


def _attention_flash_attn_3(query, key, value):
    """Provider: run attention through the ``_wrapped_flash_attn_3`` custom op."""
    return _wrapped_flash_attn_3(query, key, value)


# Compiled once at import time; static shapes, whole-graph capture required.
_compiled_flash_attn_3_default = torch.compile(
    _attention_flash_attn_3,
    fullgraph=True,
    dynamic=False,
    mode="default",
)


def _attention_flash_attn_3_compile_default(query, key, value):
    """Call the ``mode="default"`` compiled variant of ``_attention_flash_attn_3``."""
    compiled_fn = _compiled_flash_attn_3_default
    return compiled_fn(query, key, value)


# Compiled once at import time; static shapes, whole-graph capture required.
_compiled_flash_attn_3_max_autotune = torch.compile(
    _attention_flash_attn_3,
    fullgraph=True,
    dynamic=False,
    mode="max-autotune",
)


def _attention_flash_attn_3_compile_max_autotune(query, key, value):
    """Call the ``mode="max-autotune"`` compiled variant of ``_attention_flash_attn_3``."""
    compiled_fn = _compiled_flash_attn_3_max_autotune
    return compiled_fn(query, key, value)


def _attention_hf_kernels_flash_attn(query, key, value):
    """Provider: non-causal attention via the ``kernels-community/flash-attn`` kernel.

    Calls the kernel's ``fwd`` entry point with ``is_causal=False`` and returns
    only element 0 of its result.
    """
    return hf_kernels_flash_attn.fwd(query, key, value, is_causal=False)[0]


def _attention_hf_kernels_flash_attn3(query, key, value):
    """Provider: non-causal attention via the ``kernels-community/flash-attn3`` kernel.

    Calls the kernel's ``flash_attn_func`` with ``causal=False`` and returns
    only element 0 of its result.
    """
    return hf_kernels_flash_attn_3.flash_attn_func(query, key, value, causal=False)[0]


def _attention_sageattn_qk_int8_pv_fp16_cuda(query, key, value):
    """Provider: SageAttention ``sageattn_qk_int8_pv_fp16_cuda`` with ``tensor_layout="NHD"``."""
    return sageattn_qk_int8_pv_fp16_cuda(query, key, value, tensor_layout="NHD")


def _attention_sageattn_qk_int8_pv_fp16_triton(query, key, value):
    """Provider: SageAttention ``sageattn_qk_int8_pv_fp16_triton`` with ``tensor_layout="NHD"``."""
    return sageattn_qk_int8_pv_fp16_triton(query, key, value, tensor_layout="NHD")


def _attention_sageattn_qk_int8_pv_fp8_cuda_sm90(query, key, value):
    """Provider: SageAttention ``sageattn_qk_int8_pv_fp8_cuda_sm90`` with ``tensor_layout="NHD"``.

    Only registered in this file when the device's major compute capability
    is at least 9.
    """
    return sageattn_qk_int8_pv_fp8_cuda_sm90(query, key, value, tensor_layout="NHD")


if DotProductAttention is None:
    def te_attn_fn(query, key, value):
        """Placeholder used when Transformer Engine could not be imported."""
        raise RuntimeError("Transformer Engine is not available. Please install it for TE-based attention.")
else:
    def set_te_backend(backend):
        """Select a Transformer Engine attention backend via environment variables.

        Each of the three ``NVTE_*_ATTN`` variables is set to '1' when
        ``backend`` names it ('flash', 'fused' or 'unfused') and to '0'
        otherwise; an unrecognised name therefore leaves all three at '0'.
        """
        # Author's note: must be applied before first use of
        # transformer_engine.pytorch.attention
        switches = (
            ("flash", "NVTE_FLASH_ATTN"),
            ("fused", "NVTE_FUSED_ATTN"),
            ("unfused", "NVTE_UNFUSED_ATTN"),
        )
        for backend_name, env_var in switches:
            os.environ[env_var] = '1' if backend == backend_name else '0'

    set_te_backend("fused")
    te_attn_fn = DotProductAttention(
        qkv_format="bshd",
        attn_mask_type="no_mask",
        num_attention_heads=num_attention_heads,
        kv_channels=attention_head_dim,
    )

def _attention_te(query, key, value):
    """Provider: run ``te_attn_fn`` and split dim 2 of its result into (heads, head_dim)."""
    heads_and_dim = (num_attention_heads, attention_head_dim)
    return te_attn_fn(query, key, value).unflatten(2, heads_and_dim)


# TE cannot be compiled with fullgraph=True, so graph breaks are allowed here.
_compiled_te_attn_fn_default = torch.compile(
    _attention_te,
    fullgraph=False,
    dynamic=False,
    mode="default",
)


def _attention_te_compile_default(query, key, value):
    """Call the ``mode="default"`` compiled variant of ``_attention_te``."""
    compiled_fn = _compiled_te_attn_fn_default
    return compiled_fn(query, key, value)


# TE cannot be compiled with fullgraph=True, so graph breaks are allowed here.
_compiled_te_attn_fn_max_autotune = torch.compile(
    _attention_te,
    fullgraph=False,
    dynamic=False,
    mode="max-autotune",
)


def _attention_te_compile_max_autotune(query, key, value):
    """Call the ``mode="max-autotune"`` compiled variant of ``_attention_te``."""
    compiled_fn = _compiled_te_attn_fn_max_autotune
    return compiled_fn(query, key, value)


def _attention_xformers(query, key, value):
    """Provider: forward query/key/value unchanged to ``xformers.ops.memory_efficient_attention``."""
    return xops.memory_efficient_attention(query, key, value)


# Compiled once at import time; static shapes, whole-graph capture required.
_compiled_xformers_default = torch.compile(
    _attention_xformers,
    fullgraph=True,
    dynamic=False,
    mode="default",
)


def _attention_xformers_compile_default(query, key, value):
    """Call the ``mode="default"`` compiled variant of ``_attention_xformers``."""
    compiled_fn = _compiled_xformers_default
    return compiled_fn(query, key, value)


# Compiled once at import time; static shapes, whole-graph capture required.
_compiled_xformers_max_autotune = torch.compile(
    _attention_xformers,
    fullgraph=True,
    dynamic=False,
    mode="max-autotune",
)


def _attention_xformers_compile_max_autotune(query, key, value):
    """Call the ``mode="max-autotune"`` compiled variant of ``_attention_xformers``."""
    compiled_fn = _compiled_xformers_max_autotune
    return compiled_fn(query, key, value)


# Provider registry: name -> callable(query, key, value). Insertion order is
# significant: it determines the provider order used by the correctness report
# and by the benchmark's line_vals / styles.
attention_ops = {
    # torch SDPA pinned to the cuDNN backend (eager, compile-default, compile-max-autotune).
    "torch_cudnn": functools.partial(_attention_torch, backend=torch.nn.attention.SDPBackend.CUDNN_ATTENTION),
    "torch_cudnn_compile_d": functools.partial(_attention_torch_compile_default, backend=torch.nn.attention.SDPBackend.CUDNN_ATTENTION),
    "torch_cudnn_compile_ma": functools.partial(_attention_torch_compile_max_autotune, backend=torch.nn.attention.SDPBackend.CUDNN_ATTENTION),
    # torch SDPA pinned to the flash backend.
    "torch_flash": functools.partial(_attention_torch, backend=torch.nn.attention.SDPBackend.FLASH_ATTENTION),
    "torch_flash_compile_d": functools.partial(_attention_torch_compile_default, backend=torch.nn.attention.SDPBackend.FLASH_ATTENTION),
    "torch_flash_compile_ma": functools.partial(_attention_torch_compile_max_autotune, backend=torch.nn.attention.SDPBackend.FLASH_ATTENTION),
}

# Optional providers are added only when their import succeeded.
if hf_kernels_flash_attn is not None:
    attention_ops.update(
        hf_flash_attn=_attention_hf_kernels_flash_attn,
        hf_flash_attn3=_attention_hf_kernels_flash_attn3,
    )

if flash_attn_func is not None:
    attention_ops.update(
        flash_attn_2=_attention_flash_attn_2,
        flash_attn_2_compile_d=_attention_flash_attn_2_compile_default,
        flash_attn_2_compile_ma=_attention_flash_attn_2_compile_max_autotune,
    )

if flash_attn_3_func is not None:
    attention_ops.update(
        flash_attn_3=_attention_flash_attn_3,
        flash_attn_3_compile_d=_attention_flash_attn_3_compile_default,
        flash_attn_3_compile_ma=_attention_flash_attn_3_compile_max_autotune,
    )

if sageattn_qk_int8_pv_fp16_cuda is not None:
    attention_ops.update(
        sageattn_qk_int8_pv_fp16_cuda=_attention_sageattn_qk_int8_pv_fp16_cuda,
        sageattn_qk_int8_pv_fp16_triton=_attention_sageattn_qk_int8_pv_fp16_triton,
    )
    # The sm90 variant is registered only when the major compute capability is >= 9.
    if torch.cuda.get_device_capability()[0] >= 9:
        attention_ops["sageattn_qk_int8_pv_fp8_cuda_sm90"] = _attention_sageattn_qk_int8_pv_fp8_cuda_sm90

if DotProductAttention is not None:
    attention_ops.update(
        te_fused=_attention_te,
        te_fused_compile_d=_attention_te_compile_default,
        te_fused_compile_ma=_attention_te_compile_max_autotune,
    )

if xops is not None:
    attention_ops.update(
        xformers=_attention_xformers,
        xformers_compile_d=_attention_xformers_compile_default,
        xformers_compile_ma=_attention_xformers_compile_max_autotune,
    )


def get_color_and_linestyle(n: int) -> list[tuple[str, str]]:
    """Return ``n`` distinct ``(color, linestyle)`` pairs for plotting.

    Colors cycle fastest; the line style advances each time the whole color
    palette has been used, so every returned pair is unique.

    Args:
        n: number of styles required.

    Returns:
        A list of ``n`` ``(hex_color, linestyle)`` tuples.

    Raises:
        ValueError: if ``n`` exceeds the number of distinct combinations
            (``len(colors) * len(line_styles)``).
    """
    colors = ["#e41a1c", "#377eb8", "#4daf4a", "#984ea3", "#ff7f00", "#a65628", "#f781bf", "#999999"]
    line_styles = ["-", ":", "-.", "--"]
    max_styles = len(colors) * len(line_styles)
    if n > max_styles:
        raise ValueError(f"Required {n=} styles but maximum is {max_styles}")
    return [(colors[i % len(colors)], line_styles[i // len(colors)]) for i in range(n)]


def correctness():
    """Print per-provider error statistics against an fp32 reference.

    For every benchmarked sequence length, random fp32 CUDA inputs are drawn
    and a reference output is computed with ``_attention_torch`` on the MATH
    SDPA backend. The inputs are then cast to bfloat16 and every registered
    provider is run on them; its max absolute error, mean absolute error and
    mean squared error against the reference are printed.
    """
    reference_backend = torch.nn.attention.SDPBackend.MATH
    for seq_len in sequence_lengths:
        shape = (batch_size, seq_len, num_attention_heads, attention_head_dim)
        print(f"\n\n===== Testing shape: {shape} =====")

        # Drawn in query, key, value order.
        query = torch.randn(shape, device="cuda", dtype=torch.float32)
        key = torch.randn(shape, device="cuda", dtype=torch.float32)
        value = torch.randn(shape, device="cuda", dtype=torch.float32)

        # The reference is computed on the fp32 inputs, before the cast below.
        golden_truth = _attention_torch(query, key, value, backend=reference_backend)
        query, key, value = query.bfloat16(), key.bfloat16(), value.bfloat16()

        for name, fn in attention_ops.items():
            out = fn(query, key, value)
            absdiff = torch.abs(out - golden_truth)
            absmax = absdiff.max()
            mae = absdiff.mean()
            mse = (golden_truth - out).pow(2).mean()
            print(f"{name:<30}: absmax={absmax:.6f}, mae={mae:.6f}, mse={mse:.6f}")


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["seq_len"],
        x_vals=sequence_lengths,
        x_log=False,
        line_arg="provider",
        line_vals=list(attention_ops.keys()),
        line_names=[x.removeprefix("solution_") for x in attention_ops.keys()],
        ylabel="Time (ms)",
        styles=get_color_and_linestyle(len(attention_ops)),
        plot_name="Attention Benchmark",
        args={},
    )
)
def benchmark_fn(seq_len: int, provider: str):
    """Time one attention provider at one sequence length.

    Args:
        seq_len: total sequence length of the inputs.
        provider: key into ``attention_ops`` selecting the implementation.

    Returns:
        ``(median_ms, low_ms, high_ms)``: the 0.5, 0.2 and 0.8 quantiles of the
        measured runtime in milliseconds. ``perf_report`` unpacks the result
        as (value, min, max) and the reported metric is raw time, so the lower
        quantile must come second and the upper quantile third.
    """
    # Same seed for every (seq_len, provider) pair so inputs are reproducible.
    torch.manual_seed(0)

    shape = (batch_size, seq_len, num_attention_heads, attention_head_dim)
    # Gaussian values scaled element-wise by random integers drawn from [1, 5).
    query = torch.randn(shape, device="cuda", dtype=torch.bfloat16) * torch.randint(1, 5, shape, device="cuda", dtype=torch.bfloat16)
    key = torch.randn(shape, device="cuda", dtype=torch.bfloat16) * torch.randint(1, 5, shape, device="cuda", dtype=torch.bfloat16)
    value = torch.randn(shape, device="cuda", dtype=torch.bfloat16) * torch.randint(1, 5, shape, device="cuda", dtype=torch.bfloat16)

    fn = attention_ops[provider]
    # Results come back in the same order as ``quantiles``.
    median_ms, low_ms, high_ms = triton.testing.do_bench(
        lambda: fn(query, key, value),
        warmup=3,
        rep=10,
        quantiles=[0.5, 0.2, 0.8],
    )
    # Fixed ordering: previously returned (median, high, low), which swapped
    # the min/max values reported for a time-based metric.
    return median_ms, low_ms, high_ms


# Script entry: run both stages under inference mode. The correctness report is
# printed first; the benchmark then prints its result table and saves its
# outputs under ``output_dir``.
with torch.inference_mode():
    correctness()
    benchmark_fn.run(print_data=True, save_path=output_dir.as_posix())
Flash Attention 2 not found. Flash Attention 3 not found. SageAttention not found. Transformer Engine not found. xFormers not found. ===== Testing shape: (1, 4224, 24, 128) ===== torch_cudnn : absmax=0.001554, mae=0.000075, mse=0.000000 torch_cudnn_compile_d : absmax=0.001554, mae=0.000075, mse=0.000000 torch_cudnn_compile_ma : absmax=0.001554, mae=0.000075, mse=0.000000 torch_flash : absmax=0.001554, mae=0.000075, mse=0.000000 torch_flash_compile_d : absmax=0.001554, mae=0.000075, mse=0.000000 torch_flash_compile_ma : absmax=0.001554, mae=0.000075, mse=0.000000 hf_flash_attn : absmax=0.001554, mae=0.000075, mse=0.000000 hf_flash_attn3 : absmax=0.001554, mae=0.000075, mse=0.000000 ===== Testing shape: (1, 4352, 24, 128) ===== torch_cudnn : absmax=0.001499, mae=0.000074, mse=0.000000 torch_cudnn_compile_d : absmax=0.001499, mae=0.000074, mse=0.000000 torch_cudnn_compile_ma : absmax=0.001499, mae=0.000074, mse=0.000000 torch_flash : absmax=0.001499, mae=0.000074, mse=0.000000 torch_flash_compile_d : absmax=0.001499, mae=0.000074, mse=0.000000 torch_flash_compile_ma : absmax=0.001499, mae=0.000074, mse=0.000000 hf_flash_attn : absmax=0.001499, mae=0.000074, mse=0.000000 hf_flash_attn3 : absmax=0.001499, mae=0.000074, mse=0.000000 ===== Testing shape: (1, 4416, 24, 128) ===== torch_cudnn : absmax=0.001278, mae=0.000073, mse=0.000000 torch_cudnn_compile_d : absmax=0.001278, mae=0.000073, mse=0.000000 torch_cudnn_compile_ma : absmax=0.001278, mae=0.000073, mse=0.000000 torch_flash : absmax=0.001278, mae=0.000073, mse=0.000000 torch_flash_compile_d : absmax=0.001278, mae=0.000073, mse=0.000000 torch_flash_compile_ma : absmax=0.001278, mae=0.000073, mse=0.000000 hf_flash_attn : absmax=0.001278, mae=0.000073, mse=0.000000 hf_flash_attn3 : absmax=0.001278, mae=0.000073, mse=0.000000 ===== Testing shape: (1, 4480, 24, 128) ===== torch_cudnn : absmax=0.001270, mae=0.000073, mse=0.000000 torch_cudnn_compile_d : absmax=0.001270, mae=0.000073, mse=0.000000 torch_cudnn_compile_ma 
: absmax=0.001270, mae=0.000073, mse=0.000000 torch_flash : absmax=0.001270, mae=0.000073, mse=0.000000 torch_flash_compile_d : absmax=0.001270, mae=0.000073, mse=0.000000 torch_flash_compile_ma : absmax=0.001270, mae=0.000073, mse=0.000000 hf_flash_attn : absmax=0.001270, mae=0.000073, mse=0.000000 hf_flash_attn3 : absmax=0.001270, mae=0.000073, mse=0.000000 ===== Testing shape: (1, 4544, 24, 128) ===== torch_cudnn : absmax=0.001696, mae=0.000072, mse=0.000000 torch_cudnn_compile_d : absmax=0.001696, mae=0.000072, mse=0.000000 torch_cudnn_compile_ma : absmax=0.001696, mae=0.000072, mse=0.000000 torch_flash : absmax=0.001696, mae=0.000072, mse=0.000000 torch_flash_compile_d : absmax=0.001696, mae=0.000072, mse=0.000000 torch_flash_compile_ma : absmax=0.001696, mae=0.000072, mse=0.000000 hf_flash_attn : absmax=0.001696, mae=0.000072, mse=0.000000 hf_flash_attn3 : absmax=0.001696, mae=0.000072, mse=0.000000 ===== Testing shape: (1, 4608, 24, 128) ===== torch_cudnn : absmax=0.001111, mae=0.000071, mse=0.000000 torch_cudnn_compile_d : absmax=0.001111, mae=0.000071, mse=0.000000 torch_cudnn_compile_ma : absmax=0.001111, mae=0.000071, mse=0.000000 torch_flash : absmax=0.001111, mae=0.000071, mse=0.000000 torch_flash_compile_d : absmax=0.001111, mae=0.000071, mse=0.000000 torch_flash_compile_ma : absmax=0.001111, mae=0.000071, mse=0.000000 hf_flash_attn : absmax=0.001111, mae=0.000071, mse=0.000000 hf_flash_attn3 : absmax=0.001111, mae=0.000071, mse=0.000000 Attention Benchmark: seq_len torch_cudnn torch_cudnn_compile_d torch_cudnn_compile_ma torch_flash torch_flash_compile_d torch_flash_compile_ma hf_flash_attn hf_flash_attn3 0 4224.0 3.801536 3.793184 4.180320 3.967200 3.980720 4.318736 3.397664 3.333776 1 4352.0 4.081888 4.075536 4.421904 4.400096 4.393040 4.733888 3.836736 3.760320 2 4416.0 4.144176 4.139104 4.485120 4.453280 4.447408 4.796512 3.895568 3.863744 3 4480.0 4.208256 4.204688 4.554240 4.529152 4.521568 4.874592 3.952256 3.870208 4 4544.0 4.437280 4.431168 
4.790432 4.584704 4.579008 4.938656 4.010512 3.978752 5 4608.0 4.504352 4.497600 4.870880 4.661536 4.652544 5.031328 4.065504 3.985712
▶ UV Install Logs
Fetching 20 files: 0%| | 0/20 [00:00<?, ?it/s] Fetching 20 files: 5%|▌ | 1/20 [00:00<00:04, 4.75it/s] Fetching 20 files: 10%|█ | 2/20 [00:01<00:16, 1.06it/s] Fetching 20 files: 100%|██████████| 20/20 [00:01<00:00, 12.01it/s] Fetching 4 files: 0%| | 0/4 [00:00<?, ?it/s] Fetching 4 files: 25%|██▌ | 1/4 [00:00<00:01, 1.98it/s] Fetching 4 files: 50%|█████ | 2/4 [00:01<00:01, 1.38it/s] Fetching 4 files: 100%|██████████| 4/4 [00:01<00:00, 2.89it/s]