# /// script
# requires-python = ">=3.10"
# dependencies = [
#     "numpy",
#     "torch==2.8.0",
#     "kernels-benchmark-tools",
# ]
#
# [tool.uv.sources]
# kernels-benchmark-tools = { path = "../../../../../tools", editable = true }
# ///
import torch
import sys
import os

import kernels_benchmark_tools as kbt


def torch_flash_base(q, k, v):
    qt, kt, vt = (x.transpose(1, 2).contiguous() for x in (q, k, v))
    with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.FLASH_ATTENTION):
        o = torch.nn.functional.scaled_dot_product_attention(qt, kt, vt)
    return o.transpose(1, 2).contiguous()


# Compile with default mode
compiled_flash_default = torch.compile(
    torch_flash_base, mode="default", fullgraph=True, dynamic=False
)

kbt.add(
    "torch_flash_compiled_default",
    compiled_flash_default,
    tags={"family": "torch-sdpa", "backend": "FLASH", "compile": "default"},
)

if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = "float32" if device == "cpu" else "bfloat16"

    # Flux-like workloads
    base = 1024 if device == "cuda" else 512
    flux_sizes = (
        [128, 256, 320, 384, 448, 512] if device == "cuda" else [64, 128, 192, 256]
    )
    heads = 24 if device == "cuda" else 8
    head_dim = 128 if device == "cuda" else 64

    wl = []
    for L in flux_sizes:
        wl.append(
            {
                "name": f"flux_L{L}",
                "batch": 1,
                "seq_len": base + L,
                "heads": heads,
                "head_dim": head_dim,
                "dtype": dtype,
                "device": device,
                "seed": 0,
            }
        )

    kbt.run(
        wl,
        jsonl="attn_default.jsonl",
        reps=5,
        warmup=2,
        gen=kbt.attn.gen_qkv,
        ref=kbt.attn.ref_math,
        cmp=kbt.attn.cmp_allclose,
        profile_trace=True,
    )
    kbt.summarize(["attn_default.jsonl"])
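As a sanity check outside the harness, the compiled callable can also be exercised directly and compared against PyTorch's MATH SDPA backend. This is only a sketch: the (batch, seq_len, heads, head_dim) input layout is inferred from the transposes in torch_flash_base, the helper name _sketch_check is made up here, and the bf16 tolerances are illustrative rather than whatever kbt.attn.cmp_allclose actually uses.

import torch

def _sketch_check(seq_len=1024 + 128, heads=24, head_dim=128):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.bfloat16 if device == "cuda" else torch.float32
    torch.manual_seed(0)
    # Assumed layout: (batch, seq_len, heads, head_dim), matching the workload dicts above.
    q, k, v = (torch.randn(1, seq_len, heads, head_dim, device=device, dtype=dtype)
               for _ in range(3))
    out = compiled_flash_default(q, k, v)

    # Same math through the MATH backend, in the (B, H, S, D) layout SDPA expects.
    qt, kt, vt = (x.transpose(1, 2) for x in (q, k, v))
    with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.MATH):
        ref = torch.nn.functional.scaled_dot_product_attention(qt, kt, vt)
    torch.testing.assert_close(out, ref.transpose(1, 2), rtol=2e-2, atol=2e-2)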
======================================================================
PROFILE TRACE: torch_flash_compiled_default | flux_L128
======================================================================
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Name Self CPU % Self CPU CPU total % CPU total CPU time avg Self CUDA Self CUDA % CUDA total CUDA time avg # of Calls
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
torch_flash_compiled_default 0.00% 0.000us 0.00% 0.000us 0.000us 3.491ms 306.07% 3.491ms 3.491ms 1
torch_flash_compiled_default 3.51% 207.593us 60.49% 3.579ms 3.579ms 0.000us 0.00% 1.140ms 1.140ms 1
Torch-Compiled Region: 0/1 15.99% 946.480us 55.71% 3.297ms 1.099ms 0.000us 0.00% 1.140ms 380.147us 3
aten::_scaled_dot_product_flash_attention 1.10% 65.372us 6.61% 390.969us 130.323us 0.000us 0.00% 659.932us 219.977us 3
aten::_flash_attention_forward 1.41% 83.209us 4.38% 259.016us 86.339us 659.932us 57.87% 659.932us 219.977us 3
void pytorch_flash::flash_fwd_kernel<Flash_fwd_kerne... 0.00% 0.000us 0.00% 0.000us 0.000us 659.932us 57.87% 659.932us 219.977us 3
triton_poi_fused__scaled_dot_product_flash_attention... 2.76% 163.384us 4.58% 270.907us 30.101us 420.798us 36.90% 420.798us 46.755us 9
triton_poi_fused__scaled_dot_product_flash_attention... 0.00% 0.000us 0.00% 0.000us 0.000us 420.798us 36.90% 420.798us 46.755us 9
triton_poi_fused_clone_1 0.89% 52.622us 1.49% 87.972us 29.324us 59.711us 5.24% 59.711us 19.904us 3
triton_poi_fused_clone_1 0.00% 0.000us 0.00% 0.000us 0.000us 59.711us 5.24% 59.711us 19.904us 3
TorchDynamo Cache Lookup 1.27% 75.092us 1.27% 75.092us 25.031us 0.000us 0.00% 0.000us 0.000us 3
Pregraph bytecode 0.21% 12.520us 0.21% 12.520us 4.173us 0.000us 0.00% 0.000us 0.000us 3
AOTDispatcher Runtime Wrapper Prologue 0.44% 25.851us 0.44% 25.851us 8.617us 0.000us 0.00% 0.000us 0.000us 3
Activity Buffer Request 26.40% 1.562ms 26.40% 1.562ms 1.562ms 0.000us 0.00% 0.000us 0.000us 1
cuLaunchKernel 2.41% 142.873us 2.41% 142.873us 11.906us 0.000us 0.00% 0.000us 0.000us 12
aten::transpose 0.83% 48.920us 1.13% 66.581us 5.548us 0.000us 0.00% 0.000us 0.000us 12
aten::as_strided 0.30% 17.661us 0.30% 17.661us 1.472us 0.000us 0.00% 0.000us 0.000us 12
aten::empty_like 0.31% 18.491us 0.86% 50.671us 16.890us 0.000us 0.00% 0.000us 0.000us 3
aten::empty_strided 0.54% 32.180us 0.54% 32.180us 10.727us 0.000us 0.00% 0.000us 0.000us 3
aten::empty 1.09% 64.664us 1.09% 64.664us 5.389us 0.000us 0.00% 0.000us 0.000us 12
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Self CPU time total: 5.918ms
Self CUDA time total: 1.140ms
======================================================================
PROFILE TRACE: torch_flash_compiled_default | flux_L256
======================================================================
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Name Self CPU % Self CPU CPU total % CPU total CPU time avg Self CUDA Self CUDA % CUDA total CUDA time avg # of Calls
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
torch_flash_compiled_default 0.00% 0.000us 0.00% 0.000us 0.000us 1.396ms 100.73% 1.396ms 1.396ms 1
torch_flash_compiled_default 2.25% 108.483us 51.59% 2.486ms 2.486ms 0.000us 0.00% 1.386ms 1.386ms 1
Torch-Compiled Region: 0/3 10.69% 515.082us 48.73% 2.348ms 782.580us 0.000us 0.00% 1.386ms 461.970us 3
aten::_scaled_dot_product_flash_attention 0.56% 27.139us 3.70% 178.184us 59.395us 0.000us 0.00% 923.002us 307.667us 3
aten::_flash_attention_forward 0.74% 35.625us 2.48% 119.623us 39.874us 923.002us 66.60% 923.002us 307.667us 3
void pytorch_flash::flash_fwd_kernel<Flash_fwd_kerne... 0.00% 0.000us 0.00% 0.000us 0.000us 923.002us 66.60% 923.002us 307.667us 3
triton_poi_fused__scaled_dot_product_flash_attention... 1.63% 78.772us 3.01% 144.853us 16.095us 436.509us 31.50% 436.509us 48.501us 9
triton_poi_fused__scaled_dot_product_flash_attention... 0.00% 0.000us 0.00% 0.000us 0.000us 436.509us 31.50% 436.509us 48.501us 9
triton_poi_fused_clone_1 0.63% 30.250us 1.08% 51.990us 17.330us 26.400us 1.90% 26.400us 8.800us 3
triton_poi_fused_clone_1 0.00% 0.000us 0.00% 0.000us 0.000us 26.400us 1.90% 26.400us 8.800us 3
TorchDynamo Cache Lookup 0.61% 29.610us 0.61% 29.610us 9.870us 0.000us 0.00% 0.000us 0.000us 3
Pregraph bytecode 0.15% 7.290us 0.15% 7.290us 2.430us 0.000us 0.00% 0.000us 0.000us 3
AOTDispatcher Runtime Wrapper Prologue 0.25% 11.921us 0.25% 11.921us 3.974us 0.000us 0.00% 0.000us 0.000us 3
Activity Buffer Request 29.85% 1.438ms 29.85% 1.438ms 1.438ms 0.000us 0.00% 0.000us 0.000us 1
cuLaunchKernel 1.82% 87.821us 1.82% 87.821us 7.318us 0.000us 0.00% 0.000us 0.000us 12
aten::transpose 0.46% 22.012us 0.65% 31.422us 2.619us 0.000us 0.00% 0.000us 0.000us 12
aten::as_strided 0.20% 9.410us 0.20% 9.410us 0.784us 0.000us 0.00% 0.000us 0.000us 12
aten::empty_like 0.13% 6.300us 0.51% 24.540us 8.180us 0.000us 0.00% 0.000us 0.000us 3
aten::empty_strided 0.38% 18.240us 0.38% 18.240us 6.080us 0.000us 0.00% 0.000us 0.000us 3
aten::empty 0.66% 31.688us 0.66% 31.688us 2.641us 0.000us 0.00% 0.000us 0.000us 12
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Self CPU time total: 4.818ms
Self CUDA time total: 1.386ms
======================================================================
PROFILE TRACE: torch_flash_compiled_default | flux_L320
======================================================================
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Name Self CPU % Self CPU CPU total % CPU total CPU time avg Self CUDA Self CUDA % CUDA total CUDA time avg # of Calls
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
torch_flash_compiled_default 0.00% 0.000us 0.00% 0.000us 0.000us 1.466ms 100.67% 1.466ms 1.466ms 1
torch_flash_compiled_default 3.49% 107.241us 80.45% 2.474ms 2.474ms 0.000us 0.00% 1.456ms 1.456ms 1
Torch-Compiled Region: 0/5 17.70% 544.140us 75.94% 2.335ms 778.374us 0.000us 0.00% 1.456ms 485.471us 3
aten::_scaled_dot_product_flash_attention 0.90% 27.542us 5.95% 182.995us 60.998us 0.000us 0.00% 949.469us 316.490us 3
aten::_flash_attention_forward 1.21% 37.133us 4.02% 123.584us 41.195us 949.469us 65.19% 949.469us 316.490us 3
void pytorch_flash::flash_fwd_kernel<Flash_fwd_kerne... 0.00% 0.000us 0.00% 0.000us 0.000us 949.469us 65.19% 949.469us 316.490us 3
triton_poi_fused__scaled_dot_product_flash_attention... 2.64% 81.332us 4.93% 151.573us 16.841us 467.936us 32.13% 467.936us 51.993us 9
triton_poi_fused__scaled_dot_product_flash_attention... 0.00% 0.000us 0.00% 0.000us 0.000us 467.936us 32.13% 467.936us 51.993us 9
triton_poi_fused_clone_1 1.03% 31.672us 1.81% 55.723us 18.574us 39.008us 2.68% 39.008us 13.003us 3
triton_poi_fused_clone_1 0.00% 0.000us 0.00% 0.000us 0.000us 39.008us 2.68% 39.008us 13.003us 3
TorchDynamo Cache Lookup 1.02% 31.392us 1.02% 31.392us 10.464us 0.000us 0.00% 0.000us 0.000us 3
Pregraph bytecode 0.22% 6.840us 0.22% 6.840us 2.280us 0.000us 0.00% 0.000us 0.000us 3
AOTDispatcher Runtime Wrapper Prologue 0.39% 11.950us 0.39% 11.950us 3.983us 0.000us 0.00% 0.000us 0.000us 3
Activity Buffer Request 44.94% 1.382ms 44.94% 1.382ms 1.382ms 0.000us 0.00% 0.000us 0.000us 1
cuLaunchKernel 3.07% 94.292us 3.07% 94.292us 7.858us 0.000us 0.00% 0.000us 0.000us 12
aten::transpose 0.72% 22.170us 1.04% 31.869us 2.656us 0.000us 0.00% 0.000us 0.000us 12
aten::as_strided 0.32% 9.699us 0.32% 9.699us 0.808us 0.000us 0.00% 0.000us 0.000us 12
aten::empty_like 0.23% 6.960us 0.81% 24.931us 8.310us 0.000us 0.00% 0.000us 0.000us 3
aten::empty_strided 0.58% 17.971us 0.58% 17.971us 5.990us 0.000us 0.00% 0.000us 0.000us 3
aten::empty 1.00% 30.650us 1.00% 30.650us 2.554us 0.000us 0.00% 0.000us 0.000us 12
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Self CPU time total: 3.075ms
Self CUDA time total: 1.456ms
======================================================================
PROFILE TRACE: torch_flash_compiled_default | flux_L384
======================================================================
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Name Self CPU % Self CPU CPU total % CPU total CPU time avg Self CUDA Self CUDA % CUDA total CUDA time avg # of Calls
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
torch_flash_compiled_default 0.00% 0.000us 0.00% 0.000us 0.000us 1.549ms 100.86% 1.549ms 1.549ms 1
torch_flash_compiled_default 3.20% 106.202us 81.86% 2.714ms 2.714ms 0.000us 0.00% 1.535ms 1.535ms 1
Torch-Compiled Region: 0/7 16.60% 550.513us 77.71% 2.577ms 858.965us 0.000us 0.00% 1.535ms 511.773us 3
aten::_scaled_dot_product_flash_attention 0.83% 27.650us 5.45% 180.683us 60.228us 0.000us 0.00% 1.012ms 337.470us 3
aten::_flash_attention_forward 1.09% 36.231us 3.66% 121.432us 40.477us 1.012ms 65.94% 1.012ms 337.470us 3
void pytorch_flash::flash_fwd_kernel<Flash_fwd_kerne... 0.00% 0.000us 0.00% 0.000us 0.000us 1.012ms 65.94% 1.012ms 337.470us 3
triton_poi_fused__scaled_dot_product_flash_attention... 2.54% 84.320us 12.09% 401.069us 44.563us 481.950us 31.39% 481.950us 53.550us 9
triton_poi_fused__scaled_dot_product_flash_attention... 0.00% 0.000us 0.00% 0.000us 0.000us 481.950us 31.39% 481.950us 53.550us 9
triton_poi_fused_clone_1 1.04% 34.411us 1.74% 57.771us 19.257us 40.959us 2.67% 40.959us 13.653us 3
triton_poi_fused_clone_1 0.00% 0.000us 0.00% 0.000us 0.000us 40.959us 2.67% 40.959us 13.653us 3
TorchDynamo Cache Lookup 0.94% 31.282us 0.94% 31.282us 10.427us 0.000us 0.00% 0.000us 0.000us 3
Pregraph bytecode 0.23% 7.580us 0.23% 7.580us 2.527us 0.000us 0.00% 0.000us 0.000us 3
AOTDispatcher Runtime Wrapper Prologue 0.40% 13.160us 0.40% 13.160us 4.387us 0.000us 0.00% 0.000us 0.000us 3
Activity Buffer Request 41.20% 1.366ms 41.20% 1.366ms 1.366ms 0.000us 0.00% 0.000us 0.000us 1
cuLaunchKernel 10.26% 340.109us 10.26% 340.109us 28.342us 0.000us 0.00% 0.000us 0.000us 12
aten::transpose 0.66% 21.760us 0.95% 31.601us 2.633us 0.000us 0.00% 0.000us 0.000us 12
aten::as_strided 0.30% 9.841us 0.30% 9.841us 0.820us 0.000us 0.00% 0.000us 0.000us 12
aten::empty_like 0.21% 7.070us 0.74% 24.381us 8.127us 0.000us 0.00% 0.000us 0.000us 3
aten::empty_strided 0.52% 17.311us 0.52% 17.311us 5.770us 0.000us 0.00% 0.000us 0.000us 3
aten::empty 0.92% 30.600us 0.92% 30.600us 2.550us 0.000us 0.00% 0.000us 0.000us 12
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Self CPU time total: 3.316ms
Self CUDA time total: 1.535ms
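The tables above follow the format of torch.profiler's key_averages().table() output. How kbt wires this up behind profile_trace=True is not shown in this file, so the following is only an approximation of how a similar trace could be captured directly for one registered implementation; _sketch_profile is a hypothetical helper, not part of the tool.

import torch
from torch.profiler import profile, ProfilerActivity

def _sketch_profile(fn, q, k, v, reps=3):
    # Collect CPU activity always, and CUDA activity when a GPU is present.
    activities = [ProfilerActivity.CPU]
    if torch.cuda.is_available():
        activities.append(ProfilerActivity.CUDA)
    with profile(activities=activities) as prof:
        for _ in range(reps):
            fn(q, k, v)
        if torch.cuda.is_available():
            torch.cuda.synchronize()
    # Roughly the same column layout as the traces above.
    sort_key = "cuda_time_total" if torch.cuda.is_available() else "cpu_time_total"
    print(prof.key_averages().table(sort_by=sort_key, row_limit=20))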
impl                          wl         p50(ms)  ok
torch_flash_compiled_default  flux_L128     2.73  True
torch_flash_compiled_default  flux_L256     1.02  True
torch_flash_compiled_default  flux_L320     2.81  True
torch_flash_compiled_default  flux_L384     0.59  True
torch_flash_compiled_default  flux_L448     FAIL  False
  Error: recompile_limit reached with one_graph=True. Excessive recompilations can degrade performance due to the compilation overhead of each recompilation. To monitor recompilations, enable TORCH_LOGS=recompiles. If recompilations are expected, consider increasing torch._dynamo.config.cache_size_limit to an appropriate value.
torch_flash_compiled_default  flux_L512     FAIL  False
  Error: recompile_limit reached with one_graph=True. Excessive recompilations can degrade performance due to the compilation overhead of each recompilation. To monitor recompilations, enable TORCH_LOGS=recompiles. If recompilations are expected, consider increasing torch._dynamo.config.cache_size_limit to an appropriate value.
W1023 16:39:20.458000 2827609 torch/_dynamo/convert_frame.py:1016] [0/8] torch._dynamo hit config.recompile_limit (8)
W1023 16:39:20.458000 2827609 torch/_dynamo/convert_frame.py:1016] [0/8] function: 'torch_flash_base' (/home/ubuntu/Projects/kernels-benchmarks-consolidated/benches/flash_attn/impls/.uvnote/cells/benchmark_default.py:18)
W1023 16:39:20.458000 2827609 torch/_dynamo/convert_frame.py:1016] [0/8] last reason: 0/7: GLOBAL_STATE changed: num_threads
W1023 16:39:20.458000 2827609 torch/_dynamo/convert_frame.py:1016] [0/8] To log all recompilation reasons, use TORCH_LOGS="recompiles".
W1023 16:39:20.458000 2827609 torch/_dynamo/convert_frame.py:1016] [0/8] To diagnose recompilation issues, see https://pytorch.org/docs/main/torch.compiler_troubleshooting.html.
W1023 16:39:20.480000 2827609 torch/_dynamo/convert_frame.py:1016] [0/9] torch._dynamo hit config.recompile_limit (8)
W1023 16:39:20.480000 2827609 torch/_dynamo/convert_frame.py:1016] [0/9] function: 'torch_flash_base' (/home/ubuntu/Projects/kernels-benchmarks-consolidated/benches/flash_attn/impls/.uvnote/cells/benchmark_default.py:18)
W1023 16:39:20.480000 2827609 torch/_dynamo/convert_frame.py:1016] [0/9] last reason: 0/7: GLOBAL_STATE changed: num_threads
W1023 16:39:20.480000 2827609 torch/_dynamo/convert_frame.py:1016] [0/9] To log all recompilation reasons, use TORCH_LOGS="recompiles".
W1023 16:39:20.480000 2827609 torch/_dynamo/convert_frame.py:1016] [0/9] To diagnose recompilation issues, see https://pytorch.org/docs/main/torch.compiler_troubleshooting.html.
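The two FAIL rows and the warnings above are the same issue: Dynamo hit its recompile limit (8) for torch_flash_base, and the last logged reason is a num_threads global-state change rather than a new shape. The error text itself points at torch._dynamo.config.cache_size_limit; compiling with dynamic=True is the usual lever when recompiles are shape-driven. Both are sketched below as untested options for this benchmark, not verified fixes for the flux_L448/flux_L512 runs.

import torch
import torch._dynamo

# Option 1 (suggested by the error text): allow more cached graphs before Dynamo
# gives up. The default limit is 8.
torch._dynamo.config.cache_size_limit = 32

# Option 2: trade per-shape specialization for fewer recompiles by letting the
# compiled graph accept varying sequence lengths.
compiled_flash_dynamic = torch.compile(
    torch_flash_base, mode="default", fullgraph=True, dynamic=True
)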