# /// script
# dependencies = [
#   "numpy",
#   "torch",
#   "kernels",
#   "pandas",
#   "matplotlib"
# ]
# ///
# Benchmarking common shapes for a Flux 1024x1024px image + varying text sequence lengths

import functools
import os
import pathlib

import matplotlib.pyplot as plt
import torch
import torch._dynamo.config
import triton
import triton.language as tl

try:
    from flash_attn import flash_attn_func
except ImportError:
    flash_attn_func = None
    print("Flash Attention 2 not found.")

try:
    from flash_attn_interface import flash_attn_func as flash_attn_3_func
except ImportError:
    flash_attn_3_func = None
    print("Flash Attention 3 not found.")

try:
    from kernels import get_kernel

    hf_kernels_flash_attn = get_kernel("kernels-community/flash-attn")
    hf_kernels_flash_attn_3 = get_kernel("kernels-community/flash-attn3")
except Exception:
    hf_kernels_flash_attn = None
    hf_kernels_flash_attn_3 = None
    print("HF Kernels not found.")

try:
    from sageattention import (
        sageattn_qk_int8_pv_fp16_cuda,
        sageattn_qk_int8_pv_fp16_triton,
        sageattn_qk_int8_pv_fp8_cuda_sm90,
    )
except ImportError:
    sageattn_qk_int8_pv_fp16_cuda = None
    sageattn_qk_int8_pv_fp16_triton = None
    sageattn_qk_int8_pv_fp8_cuda_sm90 = None
    print("SageAttention not found.")

try:
    from transformer_engine.pytorch.attention import DotProductAttention
except ImportError:
    DotProductAttention = None
    print("Transformer Engine not found.")

try:
    import xformers.ops as xops
except ImportError:
    xops = None
    print("xFormers not found.")

plt.rcParams.update({
    "figure.figsize": (12, 10),
    "figure.dpi": 120,
    "font.size": 10,
    "axes.titlesize": 12,
    "axes.labelsize": 14,
    "xtick.labelsize": 10,
    "ytick.labelsize": 10,
    "legend.fontsize": 8,
    "axes.grid": True,
    "grid.alpha": 0.3,
    "grid.linestyle": "--",
    "lines.linewidth": 2.0,
    "lines.markersize": 6,
    "legend.frameon": True,
    "legend.framealpha": 0.9,
    "legend.loc": "best",
    "axes.spines.top": False,
    "axes.spines.right": False,
})

# We want to compare the best compiled version for each specific shape (dynamic=False).
torch._dynamo.config.cache_size_limit = 10000

# We need suppress_errors for FA3 to work under torch.compile; it makes the failing frames
# run in eager mode. I can't seem to get it to work any other way, so any suggestions are welcome!
torch._dynamo.config.suppress_errors = True
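# Assumed invocation (not part of the original script): thanks to the PEP 723 inline
# metadata at the top of this file, a runner that understands it can resolve the
# dependencies automatically, e.g.
#
#   uv run attention_benchmark.py
#
# where `attention_benchmark.py` is whatever name this file is saved under.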
output_dir = pathlib.Path("dump_attention_benchmark")
output_dir.mkdir(parents=True, exist_ok=True)

batch_size = 1
num_attention_heads = 24
attention_head_dim = 128
image_sequence_length = 4096  # 1024x1024px
text_sequence_lengths = [128, 256, 320, 384, 448, 512]
sequence_lengths = [image_sequence_length + i for i in text_sequence_lengths]


def _attention_torch(query, key, value, *, backend):
    query, key, value = (x.transpose(1, 2).contiguous() for x in (query, key, value))
    with torch.nn.attention.sdpa_kernel(backend):
        out = torch.nn.functional.scaled_dot_product_attention(query, key, value)
    out = out.transpose(1, 2).contiguous()
    return out


_compiled_attention_torch_default = torch.compile(_attention_torch, mode="default", fullgraph=True, dynamic=False)


def _attention_torch_compile_default(query, key, value, *, backend):
    return _compiled_attention_torch_default(query, key, value, backend=backend)


_compiled_attention_torch_max_autotune = torch.compile(_attention_torch, mode="max-autotune", fullgraph=True, dynamic=False)


def _attention_torch_compile_max_autotune(query, key, value, *, backend):
    return _compiled_attention_torch_max_autotune(query, key, value, backend=backend)


def _attention_flash_attn_2(query, key, value):
    return flash_attn_func(query, key, value)


_compiled_flash_attn_2_default = torch.compile(_attention_flash_attn_2, mode="default", fullgraph=True, dynamic=False)


def _attention_flash_attn_2_compile_default(query, key, value):
    return _compiled_flash_attn_2_default(query, key, value)


_compiled_flash_attn_2_max_autotune = torch.compile(_attention_flash_attn_2, mode="max-autotune", fullgraph=True, dynamic=False)


def _attention_flash_attn_2_compile_max_autotune(query, key, value):
    return _compiled_flash_attn_2_max_autotune(query, key, value)


# For fullgraph=True tracing to be compatible
@torch.library.custom_op("flash_attn_3::_flash_attn_forward", mutates_args=(), device_types="cuda")
def _wrapped_flash_attn_3(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
    out, lse = flash_attn_3_func(query, key, value)
    return out


@torch.library.register_fake("flash_attn_3::_flash_attn_forward")
def _(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
    return torch.empty_like(query)


def _attention_flash_attn_3(query, key, value):
    out = _wrapped_flash_attn_3(query, key, value)
    return out


_compiled_flash_attn_3_default = torch.compile(_attention_flash_attn_3, mode="default", fullgraph=True, dynamic=False)


def _attention_flash_attn_3_compile_default(query, key, value):
    return _compiled_flash_attn_3_default(query, key, value)


_compiled_flash_attn_3_max_autotune = torch.compile(_attention_flash_attn_3, mode="max-autotune", fullgraph=True, dynamic=False)


def _attention_flash_attn_3_compile_max_autotune(query, key, value):
    return _compiled_flash_attn_3_max_autotune(query, key, value)


def _attention_hf_kernels_flash_attn(query, key, value):
    return hf_kernels_flash_attn.fwd(query, key, value, is_causal=False)[0]


def _attention_hf_kernels_flash_attn3(query, key, value):
    return hf_kernels_flash_attn_3.flash_attn_func(query, key, value, causal=False)[0]


def _attention_sageattn_qk_int8_pv_fp16_cuda(query, key, value):
    return sageattn_qk_int8_pv_fp16_cuda(query, key, value, tensor_layout="NHD")


def _attention_sageattn_qk_int8_pv_fp16_triton(query, key, value):
    return sageattn_qk_int8_pv_fp16_triton(query, key, value, tensor_layout="NHD")


def _attention_sageattn_qk_int8_pv_fp8_cuda_sm90(query, key, value):
    return sageattn_qk_int8_pv_fp8_cuda_sm90(query, key, value, tensor_layout="NHD")
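# Note (added commentary): every wrapper above and below takes query/key/value in
# (batch, seq_len, num_heads, head_dim) layout ("NHD"/"bshd") and returns the output
# in the same layout; the torch SDPA path transposes to (batch, num_heads, seq_len,
# head_dim) internally and transposes back. A new backend only needs to respect this
# contract to be benchmarked alongside the rest.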
if DotProductAttention is not None:
    def set_te_backend(backend):
        # must be applied before first use of
        # transformer_engine.pytorch.attention
        os.environ["NVTE_FLASH_ATTN"] = '0'
        os.environ["NVTE_FUSED_ATTN"] = '0'
        os.environ["NVTE_UNFUSED_ATTN"] = '0'
        if backend == 'flash':
            os.environ["NVTE_FLASH_ATTN"] = '1'
        if backend == 'fused':
            os.environ["NVTE_FUSED_ATTN"] = '1'
        if backend == 'unfused':
            os.environ["NVTE_UNFUSED_ATTN"] = '1'

    set_te_backend("fused")
    te_attn_fn = DotProductAttention(
        num_attention_heads=num_attention_heads,
        kv_channels=attention_head_dim,
        qkv_format="bshd",
        attn_mask_type="no_mask",
    )
else:
    def te_attn_fn(query, key, value):
        raise RuntimeError("Transformer Engine is not available. Please install it for TE-based attention.")


def _attention_te(query, key, value):
    out = te_attn_fn(query, key, value)
    out = out.unflatten(2, (num_attention_heads, attention_head_dim))
    return out


# Cannot fullgraph compile TE
_compiled_te_attn_fn_default = torch.compile(_attention_te, mode="default", fullgraph=False, dynamic=False)


def _attention_te_compile_default(query, key, value):
    return _compiled_te_attn_fn_default(query, key, value)


# Cannot fullgraph compile TE
_compiled_te_attn_fn_max_autotune = torch.compile(_attention_te, mode="max-autotune", fullgraph=False, dynamic=False)


def _attention_te_compile_max_autotune(query, key, value):
    return _compiled_te_attn_fn_max_autotune(query, key, value)


def _attention_xformers(query, key, value):
    return xops.memory_efficient_attention(query, key, value)


_compiled_xformers_default = torch.compile(_attention_xformers, mode="default", fullgraph=True, dynamic=False)


def _attention_xformers_compile_default(query, key, value):
    return _compiled_xformers_default(query, key, value)


_compiled_xformers_max_autotune = torch.compile(_attention_xformers, mode="max-autotune", fullgraph=True, dynamic=False)


def _attention_xformers_compile_max_autotune(query, key, value):
    return _compiled_xformers_max_autotune(query, key, value)


attention_ops = {}
attention_ops["torch_cudnn"] = functools.partial(_attention_torch, backend=torch.nn.attention.SDPBackend.CUDNN_ATTENTION)
attention_ops["torch_cudnn_compile_d"] = functools.partial(_attention_torch_compile_default, backend=torch.nn.attention.SDPBackend.CUDNN_ATTENTION)
attention_ops["torch_cudnn_compile_ma"] = functools.partial(_attention_torch_compile_max_autotune, backend=torch.nn.attention.SDPBackend.CUDNN_ATTENTION)
attention_ops["torch_flash"] = functools.partial(_attention_torch, backend=torch.nn.attention.SDPBackend.FLASH_ATTENTION)
attention_ops["torch_flash_compile_d"] = functools.partial(_attention_torch_compile_default, backend=torch.nn.attention.SDPBackend.FLASH_ATTENTION)
attention_ops["torch_flash_compile_ma"] = functools.partial(_attention_torch_compile_max_autotune, backend=torch.nn.attention.SDPBackend.FLASH_ATTENTION)

if hf_kernels_flash_attn is not None:
    attention_ops["hf_flash_attn"] = _attention_hf_kernels_flash_attn
    attention_ops["hf_flash_attn3"] = _attention_hf_kernels_flash_attn3

if flash_attn_func is not None:
    attention_ops["flash_attn_2"] = _attention_flash_attn_2
    attention_ops["flash_attn_2_compile_d"] = _attention_flash_attn_2_compile_default
    attention_ops["flash_attn_2_compile_ma"] = _attention_flash_attn_2_compile_max_autotune

if flash_attn_3_func is not None:
    attention_ops["flash_attn_3"] = _attention_flash_attn_3
    attention_ops["flash_attn_3_compile_d"] = _attention_flash_attn_3_compile_default
    attention_ops["flash_attn_3_compile_ma"] = _attention_flash_attn_3_compile_max_autotune
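# Hypothetical extension point (added commentary, not in the original script): any extra
# backend can be benchmarked by registering a wrapper with the same signature next to
# the entries above and below, e.g.
#
#   def _attention_my_backend(query, key, value):
#       # `my_backend_attention` is a placeholder for whatever kernel you want to test
#       return my_backend_attention(query, key, value)
#
#   attention_ops["my_backend"] = _attention_my_backend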
attention_ops["flash_attn_3_compile_ma"] = _attention_flash_attn_3_compile_max_autotune if sageattn_qk_int8_pv_fp16_cuda is not None: attention_ops["sageattn_qk_int8_pv_fp16_cuda"] = _attention_sageattn_qk_int8_pv_fp16_cuda attention_ops["sageattn_qk_int8_pv_fp16_triton"] = _attention_sageattn_qk_int8_pv_fp16_triton if torch.cuda.get_device_capability()[0] >= 9: attention_ops["sageattn_qk_int8_pv_fp8_cuda_sm90"] = _attention_sageattn_qk_int8_pv_fp8_cuda_sm90 if DotProductAttention is not None: attention_ops["te_fused"] = _attention_te attention_ops["te_fused_compile_d"] = _attention_te_compile_default attention_ops["te_fused_compile_ma"] = _attention_te_compile_max_autotune if xops is not None: attention_ops["xformers"] = _attention_xformers attention_ops["xformers_compile_d"] = _attention_xformers_compile_default attention_ops["xformers_compile_ma"] = _attention_xformers_compile_max_autotune def get_color_and_linestyle(n: int) -> tuple[str, str]: colors = ["#e41a1c", "#377eb8", "#4daf4a", "#984ea3", "#ff7f00", "#a65628", "#f781bf", "#999999"] line_styles = ["-", ":", "-.", "--"] if n > len(colors) * len(line_styles): raise ValueError(f"Required {n=} styles but maximum is {len(colors) * len(line_styles)}") styles = [] for i in range(n): color = colors[i % len(colors)] linestyle = line_styles[i // len(colors)] styles.append((color, linestyle)) return styles def correctness(): for seq_len in sequence_lengths: shape = (batch_size, seq_len, num_attention_heads, attention_head_dim) print(f"\n\n===== Testing shape: {shape} =====") query = torch.randn(shape, device="cuda", dtype=torch.float32) key = torch.randn(shape, device="cuda", dtype=torch.float32) value = torch.randn(shape, device="cuda", dtype=torch.float32) golden_truth = _attention_torch(query, key, value, backend=torch.nn.attention.SDPBackend.MATH) query, key, value = (x.bfloat16() for x in (query, key, value)) for name, fn in attention_ops.items(): out = fn(query, key, value) absdiff = (out - golden_truth).abs() absmax = torch.max(absdiff) mae = torch.mean(absdiff) mse = torch.mean((golden_truth - out) ** 2) print(f"{name:<30}: absmax={absmax:.6f}, mae={mae:.6f}, mse={mse:.6f}") @triton.testing.perf_report( triton.testing.Benchmark( x_names=["seq_len"], x_vals=sequence_lengths, x_log=False, line_arg="provider", line_vals=list(attention_ops.keys()), line_names=[x.removeprefix("solution_") for x in attention_ops.keys()], ylabel="Time (ms)", styles=get_color_and_linestyle(len(attention_ops)), plot_name="Attention Benchmark", args={}, ) ) def benchmark_fn(seq_len: int, provider: str): torch.manual_seed(0) shape = (batch_size, seq_len, num_attention_heads, attention_head_dim) query = torch.randn(shape, device="cuda", dtype=torch.bfloat16) * torch.randint(1, 5, shape, device="cuda", dtype=torch.bfloat16) key = torch.randn(shape, device="cuda", dtype=torch.bfloat16) * torch.randint(1, 5, shape, device="cuda", dtype=torch.bfloat16) value = torch.randn(shape, device="cuda", dtype=torch.bfloat16) * torch.randint(1, 5, shape, device="cuda", dtype=torch.bfloat16) fn = attention_ops[provider] ms, min_ms, max_ms = triton.testing.do_bench( lambda: fn(query, key, value), warmup=3, rep=10, quantiles=[0.5, 0.2, 0.8], ) return ms, max_ms, min_ms with torch.inference_mode(): correctness() benchmark_fn.run(print_data=True, save_path=output_dir.as_posix())