Cell: benchmark | 40.58s
# /// script
# requires-python = ">=3.10"
# dependencies = [
#     "numpy",
#     "torch",
#     "kernels",
#     "kernels-benchmark-tools",
#     "sageattention",
# ]
#
# [tool.uv.sources]
# kernels-benchmark-tools = { git = "https://github.com/drbh/kernels-benchmark-tools.git", branch = "main" }
# ///
import torch
import sys
import os
import kernels_benchmark_tools as kbt

# Native SageAttention path, kept for reference:
# from sageattention import sageattn_qk_int8_pv_fp16_cuda
#
# def sage_attention(q, k, v):
#     """SageAttention with INT8 Q/K quantization and FP16 P/V"""
#     return sageattn_qk_int8_pv_fp16_cuda(q, k, v, tensor_layout="NHD")

from kernels import get_kernel

hf_kernels_sage_attn = get_kernel("kernels-community/sage_attention")


def sage_attention(query, key, value):
    """SageAttention via the HuggingFace kernels hub"""
    return hf_kernels_sage_attn.fwd(query, key, value, is_causal=False)[0]
kbt.add(
    "sage_int8_fp16",
    sage_attention,
    tags={"family": "sageattention", "backend": "int8_fp16_cuda", "compile": "none"},
)

if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    if device == "cpu":
        print("SageAttention requires CUDA - skipping benchmark")
        sys.exit(0)

    dtype = "bfloat16"

    # Flux-like workloads
    base = 1024
    flux_sizes = [128, 256, 320, 384, 448, 512]
    heads = 24
    head_dim = 128

    wl = []
    for L in flux_sizes:
        wl.append(
            {
                "name": f"flux_L{L}",
                "batch": 1,
                "seq_len": base + L,
                "heads": heads,
                "head_dim": head_dim,
                "dtype": dtype,
                "device": device,
                "seed": 0,
            }
        )

    kbt.run(
        wl,
        jsonl="attn.jsonl",
        reps=5,
        warmup=2,
        gen=kbt.attn.gen_qkv,
        ref=kbt.attn.ref_math,
        cmp=kbt.attn.cmp_allclose,
    )
    kbt.summarize(["attn.jsonl"])
impl            wl         p50(ms)  ok
sage_int8_fp16  flux_L128  FAIL     False
  Error: module 'sage_attention_1863f4c92418f0f6' has no attribute 'fwd'
sage_int8_fp16  flux_L256  FAIL     False
  Error: module 'sage_attention_1863f4c92418f0f6' has no attribute 'fwd'
sage_int8_fp16  flux_L320  FAIL     False
  Error: module 'sage_attention_1863f4c92418f0f6' has no attribute 'fwd'
sage_int8_fp16  flux_L384  FAIL     False
  Error: module 'sage_attention_1863f4c92418f0f6' has no attribute 'fwd'
sage_int8_fp16  flux_L448  FAIL     False
  Error: module 'sage_attention_1863f4c92418f0f6' has no attribute 'fwd'
sage_int8_fp16  flux_L512  FAIL     False
  Error: module 'sage_attention_1863f4c92418f0f6' has no attribute 'fwd'
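All six workloads fail identically: the loaded hub module exposes no `fwd` attribute, so the wrapper raises before any kernel runs. A minimal diagnosis sketch, assuming only that get_kernel() returns an ordinary Python module object, is to list its public attributes and pick out the real entry point:

# List the public attributes of the loaded kernel module to find its
# actual entry point; the name `fwd` used above evidently is not among them.
from kernels import get_kernel

mod = get_kernel("kernels-community/sage_attention")
print([name for name in dir(mod) if not name.startswith("_")])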
UV Install Logs
Fetching 11 files: 100%|██████████| 11/11 [00:00<00:00, 16.77it/s]
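Until the hub kernel's entry point is confirmed, the commented-out native path at the top of the script is a plausible fallback. A sketch of restoring it, assuming the installed sageattention build provides the function with the signature shown in the original comment:

# Fallback sketch: call the native sageattention package directly,
# reusing the exact invocation commented out in the script above.
from sageattention import sageattn_qk_int8_pv_fp16_cuda

def sage_attention(q, k, v):
    """SageAttention with INT8 Q/K quantization and FP16 P/V"""
    return sageattn_qk_int8_pv_fp16_cuda(q, k, v, tensor_layout="NHD")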