import functools
import os
import pathlib

import matplotlib.pyplot as plt
import torch
import torch._dynamo.config
import triton
import triton.language as tl

# Optional attention backends: a missing dependency just disables that set of benchmarks.
try:
    from flash_attn import flash_attn_func
except ImportError:
    flash_attn_func = None
    print("Flash Attention 2 not found.")

try:
    from flash_attn_interface import flash_attn_func as flash_attn_3_func
except ImportError:
    flash_attn_3_func = None
    print("Flash Attention 3 not found.")

try:
    from kernels import get_kernel

    hf_kernels_flash_attn = get_kernel("kernels-community/flash-attn")
    hf_kernels_flash_attn_3 = get_kernel("kernels-community/flash-attn3")
except Exception:
    hf_kernels_flash_attn = None
    hf_kernels_flash_attn_3 = None
    print("HF Kernels not found.")

try:
    from sageattention import (
        sageattn_qk_int8_pv_fp16_cuda,
        sageattn_qk_int8_pv_fp16_triton,
        sageattn_qk_int8_pv_fp8_cuda_sm90,
    )
except ImportError:
    sageattn_qk_int8_pv_fp16_cuda = None
    sageattn_qk_int8_pv_fp16_triton = None
    sageattn_qk_int8_pv_fp8_cuda_sm90 = None
    print("SageAttention not found.")

try:
    from transformer_engine.pytorch.attention import DotProductAttention
except ImportError:
    DotProductAttention = None
    print("Transformer Engine not found.")

try:
    import xformers.ops as xops
except ImportError:
    xops = None
    print("xFormers not found.")

plt.rcParams.update({
    "figure.figsize": (12, 10),
    "figure.dpi": 120,
    "font.size": 10,
    "axes.titlesize": 12,
    "axes.labelsize": 14,
    "xtick.labelsize": 10,
    "ytick.labelsize": 10,
    "legend.fontsize": 8,
    "axes.grid": True,
    "grid.alpha": 0.3,
    "grid.linestyle": "--",
    "lines.linewidth": 2.0,
    "lines.markersize": 6,
    "legend.frameon": True,
    "legend.framealpha": 0.9,
    "legend.loc": "best",
    "axes.spines.top": False,
    "axes.spines.right": False,
})

# Many attention variants below are compiled separately across several shapes;
# raise the recompile cache limit so torch.compile keeps compiling instead of
# skipping frames, and suppress Dynamo errors so a failing compile degrades to
# eager execution rather than aborting the whole benchmark.
torch._dynamo.config.cache_size_limit = 10000
torch._dynamo.config.suppress_errors = True

output_dir = pathlib.Path("dump_attention_benchmark")
output_dir.mkdir(parents=True, exist_ok=True)
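
# Workload: a fixed-length image token stream plus a variable-length text
# prompt; the attention sequence length is their sum.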
batch_size = 1
num_attention_heads = 24
attention_head_dim = 128
image_sequence_length = 4096
text_sequence_lengths = [128, 256, 320, 384, 448, 512]
sequence_lengths = [image_sequence_length + i for i in text_sequence_lengths]
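
# All wrappers take query/key/value in (batch, seq_len, heads, head_dim) layout.
# PyTorch SDPA expects (batch, heads, seq_len, head_dim), so the torch paths
# transpose on the way in and back out; the flash-attn-style kernels consume
# the (batch, seq_len, heads, head_dim) layout directly.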


def _attention_torch(query, key, value, *, backend):
    query, key, value = (x.transpose(1, 2).contiguous() for x in (query, key, value))
    with torch.nn.attention.sdpa_kernel(backend):
        out = torch.nn.functional.scaled_dot_product_attention(query, key, value)
    out = out.transpose(1, 2).contiguous()
    return out


_compiled_attention_torch_default = torch.compile(_attention_torch, mode="default", fullgraph=True, dynamic=False)


def _attention_torch_compile_default(query, key, value, *, backend):
    return _compiled_attention_torch_default(query, key, value, backend=backend)


_compiled_attention_torch_max_autotune = torch.compile(_attention_torch, mode="max-autotune", fullgraph=True, dynamic=False)


def _attention_torch_compile_max_autotune(query, key, value, *, backend):
    return _compiled_attention_torch_max_autotune(query, key, value, backend=backend)


def _attention_flash_attn_2(query, key, value):
    return flash_attn_func(query, key, value)


_compiled_flash_attn_2_default = torch.compile(_attention_flash_attn_2, mode="default", fullgraph=True, dynamic=False)


def _attention_flash_attn_2_compile_default(query, key, value):
    return _compiled_flash_attn_2_default(query, key, value)


_compiled_flash_attn_2_max_autotune = torch.compile(_attention_flash_attn_2, mode="max-autotune", fullgraph=True, dynamic=False)


def _attention_flash_attn_2_compile_max_autotune(query, key, value):
    return _compiled_flash_attn_2_max_autotune(query, key, value)
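
# Wrapping flash_attn_3_func in a torch.library custom op lets torch.compile
# treat it as an opaque operator: the CUDA implementation runs the real kernel,
# while the registered fake implementation only reports the output shape during
# tracing.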


@torch.library.custom_op("flash_attn_3::_flash_attn_forward", mutates_args=(), device_types="cuda")
def _wrapped_flash_attn_3(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
    out, lse = flash_attn_3_func(query, key, value)
    return out


@torch.library.register_fake("flash_attn_3::_flash_attn_forward")
def _(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
    return torch.empty_like(query)


def _attention_flash_attn_3(query, key, value):
    out = _wrapped_flash_attn_3(query, key, value)
    return out


_compiled_flash_attn_3_default = torch.compile(_attention_flash_attn_3, mode="default", fullgraph=True, dynamic=False)


def _attention_flash_attn_3_compile_default(query, key, value):
    return _compiled_flash_attn_3_default(query, key, value)


_compiled_flash_attn_3_max_autotune = torch.compile(_attention_flash_attn_3, mode="max-autotune", fullgraph=True, dynamic=False)


def _attention_flash_attn_3_compile_max_autotune(query, key, value):
    return _compiled_flash_attn_3_max_autotune(query, key, value)


def _attention_hf_kernels_flash_attn(query, key, value):
    return hf_kernels_flash_attn.fwd(query, key, value, is_causal=False)[0]


def _attention_hf_kernels_flash_attn3(query, key, value):
    return hf_kernels_flash_attn_3.flash_attn_func(query, key, value, causal=False)[0]


def _attention_sageattn_qk_int8_pv_fp16_cuda(query, key, value):
    return sageattn_qk_int8_pv_fp16_cuda(query, key, value, tensor_layout="NHD")


def _attention_sageattn_qk_int8_pv_fp16_triton(query, key, value):
    return sageattn_qk_int8_pv_fp16_triton(query, key, value, tensor_layout="NHD")


def _attention_sageattn_qk_int8_pv_fp8_cuda_sm90(query, key, value):
    return sageattn_qk_int8_pv_fp8_cuda_sm90(query, key, value, tensor_layout="NHD")
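
# Transformer Engine picks its attention backend via NVTE_* environment
# variables, and DotProductAttention returns (batch, seq, heads * head_dim),
# so the wrapper below unflattens the last dimension back to (heads, head_dim).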


if DotProductAttention is not None:
    def set_te_backend(backend):
        # Disable every backend first, then enable only the requested one.
        os.environ["NVTE_FLASH_ATTN"] = "0"
        os.environ["NVTE_FUSED_ATTN"] = "0"
        os.environ["NVTE_UNFUSED_ATTN"] = "0"
        if backend == "flash":
            os.environ["NVTE_FLASH_ATTN"] = "1"
        if backend == "fused":
            os.environ["NVTE_FUSED_ATTN"] = "1"
        if backend == "unfused":
            os.environ["NVTE_UNFUSED_ATTN"] = "1"

    set_te_backend("fused")
    te_attn_fn = DotProductAttention(
        num_attention_heads=num_attention_heads,
        kv_channels=attention_head_dim,
        qkv_format="bshd",
        attn_mask_type="no_mask",
    )
else:
    def te_attn_fn(query, key, value):
        raise RuntimeError("Transformer Engine is not available. Please install it for TE-based attention.")


def _attention_te(query, key, value):
    out = te_attn_fn(query, key, value)
    out = out.unflatten(2, (num_attention_heads, attention_head_dim))
    return out


# fullgraph=False: allow graph breaks inside Transformer Engine's attention module.
_compiled_te_attn_fn_default = torch.compile(_attention_te, mode="default", fullgraph=False, dynamic=False)


def _attention_te_compile_default(query, key, value):
    return _compiled_te_attn_fn_default(query, key, value)


_compiled_te_attn_fn_max_autotune = torch.compile(_attention_te, mode="max-autotune", fullgraph=False, dynamic=False)


def _attention_te_compile_max_autotune(query, key, value):
    return _compiled_te_attn_fn_max_autotune(query, key, value)


def _attention_xformers(query, key, value):
    return xops.memory_efficient_attention(query, key, value)


_compiled_xformers_default = torch.compile(_attention_xformers, mode="default", fullgraph=True, dynamic=False)


def _attention_xformers_compile_default(query, key, value):
    return _compiled_xformers_default(query, key, value)


_compiled_xformers_max_autotune = torch.compile(_attention_xformers, mode="max-autotune", fullgraph=True, dynamic=False)


def _attention_xformers_compile_max_autotune(query, key, value):
    return _compiled_xformers_max_autotune(query, key, value)
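
# Registry of provider name -> callable. Optional backends are registered only
# if their import above succeeded.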
attention_ops = {}
attention_ops["torch_cudnn"] = functools.partial(_attention_torch, backend=torch.nn.attention.SDPBackend.CUDNN_ATTENTION)
attention_ops["torch_cudnn_compile_d"] = functools.partial(_attention_torch_compile_default, backend=torch.nn.attention.SDPBackend.CUDNN_ATTENTION)
attention_ops["torch_cudnn_compile_ma"] = functools.partial(_attention_torch_compile_max_autotune, backend=torch.nn.attention.SDPBackend.CUDNN_ATTENTION)
attention_ops["torch_flash"] = functools.partial(_attention_torch, backend=torch.nn.attention.SDPBackend.FLASH_ATTENTION)
attention_ops["torch_flash_compile_d"] = functools.partial(_attention_torch_compile_default, backend=torch.nn.attention.SDPBackend.FLASH_ATTENTION)
attention_ops["torch_flash_compile_ma"] = functools.partial(_attention_torch_compile_max_autotune, backend=torch.nn.attention.SDPBackend.FLASH_ATTENTION)
if hf_kernels_flash_attn is not None:
    attention_ops["hf_flash_attn"] = _attention_hf_kernels_flash_attn
    attention_ops["hf_flash_attn3"] = _attention_hf_kernels_flash_attn3
if flash_attn_func is not None:
    attention_ops["flash_attn_2"] = _attention_flash_attn_2
    attention_ops["flash_attn_2_compile_d"] = _attention_flash_attn_2_compile_default
    attention_ops["flash_attn_2_compile_ma"] = _attention_flash_attn_2_compile_max_autotune
if flash_attn_3_func is not None:
    attention_ops["flash_attn_3"] = _attention_flash_attn_3
    attention_ops["flash_attn_3_compile_d"] = _attention_flash_attn_3_compile_default
    attention_ops["flash_attn_3_compile_ma"] = _attention_flash_attn_3_compile_max_autotune
if sageattn_qk_int8_pv_fp16_cuda is not None:
    attention_ops["sageattn_qk_int8_pv_fp16_cuda"] = _attention_sageattn_qk_int8_pv_fp16_cuda
    attention_ops["sageattn_qk_int8_pv_fp16_triton"] = _attention_sageattn_qk_int8_pv_fp16_triton
    if torch.cuda.get_device_capability()[0] >= 9:
        attention_ops["sageattn_qk_int8_pv_fp8_cuda_sm90"] = _attention_sageattn_qk_int8_pv_fp8_cuda_sm90
if DotProductAttention is not None:
    attention_ops["te_fused"] = _attention_te
    attention_ops["te_fused_compile_d"] = _attention_te_compile_default
    attention_ops["te_fused_compile_ma"] = _attention_te_compile_max_autotune
if xops is not None:
    attention_ops["xformers"] = _attention_xformers
    attention_ops["xformers_compile_d"] = _attention_xformers_compile_default
    attention_ops["xformers_compile_ma"] = _attention_xformers_compile_max_autotune


def get_color_and_linestyle(n: int) -> list[tuple[str, str]]:
    colors = ["#e41a1c", "#377eb8", "#4daf4a", "#984ea3", "#ff7f00", "#a65628", "#f781bf", "#999999"]
    line_styles = ["-", ":", "-.", "--"]
    if n > len(colors) * len(line_styles):
        raise ValueError(f"Requested {n=} styles but the maximum is {len(colors) * len(line_styles)}")
    styles = []
    for i in range(n):
        color = colors[i % len(colors)]
        linestyle = line_styles[i // len(colors)]
        styles.append((color, linestyle))
    return styles
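
# Correctness check: the float32 math-backend SDPA output is the reference;
# each backend then runs on bfloat16 inputs and is compared via max absolute
# error, MAE, and MSE.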


def correctness():
    for seq_len in sequence_lengths:
        shape = (batch_size, seq_len, num_attention_heads, attention_head_dim)
        print(f"\n\n===== Testing shape: {shape} =====")

        query = torch.randn(shape, device="cuda", dtype=torch.float32)
        key = torch.randn(shape, device="cuda", dtype=torch.float32)
        value = torch.randn(shape, device="cuda", dtype=torch.float32)

        golden_truth = _attention_torch(query, key, value, backend=torch.nn.attention.SDPBackend.MATH)
        query, key, value = (x.bfloat16() for x in (query, key, value))

        for name, fn in attention_ops.items():
            out = fn(query, key, value)
            absdiff = (out - golden_truth).abs()
            absmax = torch.max(absdiff)
            mae = torch.mean(absdiff)
            mse = torch.mean((golden_truth - out) ** 2)
            print(f"{name:<30}: absmax={absmax:.6f}, mae={mae:.6f}, mse={mse:.6f}")
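
# Timing sweep: triton.testing.do_bench reports the median (and 20th/80th
# percentile) runtime in milliseconds for each provider at each sequence length.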


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["seq_len"],
        x_vals=sequence_lengths,
        x_log=False,
        line_arg="provider",
        line_vals=list(attention_ops.keys()),
        line_names=list(attention_ops.keys()),
        ylabel="Time (ms)",
        styles=get_color_and_linestyle(len(attention_ops)),
        plot_name="Attention Benchmark",
        args={},
    )
)
def benchmark_fn(seq_len: int, provider: str):
    torch.manual_seed(0)

    shape = (batch_size, seq_len, num_attention_heads, attention_head_dim)
    query = torch.randn(shape, device="cuda", dtype=torch.bfloat16) * torch.randint(1, 5, shape, device="cuda", dtype=torch.bfloat16)
    key = torch.randn(shape, device="cuda", dtype=torch.bfloat16) * torch.randint(1, 5, shape, device="cuda", dtype=torch.bfloat16)
    value = torch.randn(shape, device="cuda", dtype=torch.bfloat16) * torch.randint(1, 5, shape, device="cuda", dtype=torch.bfloat16)

    fn = attention_ops[provider]
    ms, min_ms, max_ms = triton.testing.do_bench(
        lambda: fn(query, key, value),
        warmup=3,
        rep=10,
        quantiles=[0.5, 0.2, 0.8],
    )
    return ms, max_ms, min_ms
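
# Run the numerical checks first, then the benchmark; the results table is
# printed and the plot is saved under `output_dir`.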
with torch.inference_mode():
    correctness()
    benchmark_fn.run(print_data=True, save_path=output_dir.as_posix())