Flash Attention Implementation

GPU Info

Cell: nv | 0.67s
import subprocess

print(subprocess.run(["nvidia-smi"], capture_output=True, text=True).stdout)
Thu Oct  2 19:58:23 2025
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 570.172.08             Driver Version: 570.172.08     CUDA Version: 12.8     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA A10G                    On  |   00000000:00:1B.0 Off |                    0 |
|  0%   37C    P0             92W /  300W |       0MiB /  23028MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   1  NVIDIA A10G                    On  |   00000000:00:1C.0 Off |                    0 |
|  0%   29C    P8             24W /  300W |       0MiB /  23028MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   2  NVIDIA A10G                    On  |   00000000:00:1D.0 Off |                    0 |
|  0%   29C    P8             24W /  300W |       0MiB /  23028MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   3  NVIDIA A10G                    On  |   00000000:00:1E.0 Off |                    0 |
|  0%   30C    P8             24W /  300W |       0MiB /  23028MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+

+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI              PID   Type   Process name                        GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|  No running processes found                                                             |
+-----------------------------------------------------------------------------------------+
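
For a quick check from inside Python rather than shelling out to nvidia-smi, the same basic device inventory can be read through standard torch.cuda calls; a minimal sketch:

import torch

# List visible CUDA devices with their names and total memory.
if torch.cuda.is_available():
    for i in range(torch.cuda.device_count()):
        props = torch.cuda.get_device_properties(i)
        print(f"GPU {i}: {props.name}, {props.total_memory / 2**20:.0f} MiB")
else:
    print("No CUDA devices visible")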

Flash Attention Benchmark

Cell: benchmark | 35.41s
# /// script
# requires-python = ">=3.10"
# dependencies = [
#     "numpy",
#     "torch",
#     "kernels-benchmark-tools",
# ]
#
# [tool.uv.sources]
# kernels-benchmark-tools = { git = "https://github.com/drbh/kernels-benchmark-tools.git", branch = "main" }
# ///
import torch
import kernels_benchmark_tools as kbt


def torch_flash(q, k, v):
    # SDPA expects (batch, heads, seq_len, head_dim); the generated q/k/v arrive
    # as (batch, seq_len, heads, head_dim), so transpose in and transpose back out.
    qt, kt, vt = (x.transpose(1, 2).contiguous() for x in (q, k, v))
    with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.FLASH_ATTENTION):
        o = torch.nn.functional.scaled_dot_product_attention(qt, kt, vt)
    return o.transpose(1, 2).contiguous()

kbt.add(
    "torch_flash_ma",
    torch_flash,
    tags={"family": "torch-sdpa", "backend": "FLASH", "compile": "max-autotune"},
)

if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = "float32" if device == "cpu" else "bfloat16"

    # Flux-like workloads scaled down for CPU testing
    base = 1024 if device == "cuda" else 512
    flux_sizes = (
        [128, 256, 320, 384, 448, 512] if device == "cuda" else [64, 128, 192, 256]
    )
    heads = 24 if device == "cuda" else 8
    head_dim = 128 if device == "cuda" else 64

    wl = []
    for L in flux_sizes:
        wl.append(
            {
                "name": f"flux_L{L}",
                "batch": 1,
                "seq_len": base + L,
                "heads": heads,
                "head_dim": head_dim,
                "dtype": dtype,
                "device": device,
                "seed": 0,
            }
        )

    kbt.run(
        wl,
        jsonl="attn.jsonl",
        reps=5,
        warmup=2,
        gen=kbt.attn.gen_qkv,
        ref=kbt.attn.ref_math,
        cmp=kbt.attn.cmp_allclose,
    )
    kbt.summarize(["attn.jsonl"])
impl            wl         p50(ms)  ok
torch_flash_ma  flux_L128  0.49     True
torch_flash_ma  flux_L256  0.52     True
torch_flash_ma  flux_L320  0.65     True
torch_flash_ma  flux_L384  0.68     True
torch_flash_ma  flux_L448  0.71     True
torch_flash_ma  flux_L512  0.74     True
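
Additional baselines can be registered with the same kbt.add pattern used above and will show up as extra rows in the summary. A minimal sketch that adds PyTorch's memory-efficient SDPA backend for comparison (the tag values here are illustrative, not prescribed by kernels-benchmark-tools):

import torch
import kernels_benchmark_tools as kbt

def torch_mem_efficient(q, k, v):
    # Same layout handling as torch_flash above.
    qt, kt, vt = (x.transpose(1, 2).contiguous() for x in (q, k, v))
    with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION):
        o = torch.nn.functional.scaled_dot_product_attention(qt, kt, vt)
    return o.transpose(1, 2).contiguous()

kbt.add(
    "torch_mem_efficient",
    torch_mem_efficient,
    tags={"family": "torch-sdpa", "backend": "EFFICIENT", "compile": "none"},
)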

Artifacts:

attn.jsonl
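
attn.jsonl holds one JSON record per measurement. A minimal sketch for inspecting it offline; the exact field names are defined by kernels-benchmark-tools, so this just pretty-prints each record rather than assuming a schema:

import json

# Dump every benchmark record in the JSONL artifact.
with open("attn.jsonl") as f:
    for line in f:
        record = json.loads(line)
        print(json.dumps(record, indent=2, sort_keys=True))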