diff --git a/flash_attn/impls/artifacts/benchmark/attn.jsonl b/flash_attn/impls/artifacts/benchmark/attn.jsonl
new file mode 100644
index 0000000000000000000000000000000000000000..f1695213932feff6321b80cc6dac9317e83037c9
--- /dev/null
+++ b/flash_attn/impls/artifacts/benchmark/attn.jsonl
@@ -0,0 +1,6 @@
+{"ts": "2025-10-02T16:08:21Z", "run": "4862bb56aac04f66908d2a97924104e2", "impl": "xformers_meff", "tags": {"family": "xformers", "backend": "memory_efficient", "compile": "none"}, "wl": {"name": "flux_L128", "batch": 1, "seq_len": 1152, "heads": 24, "head_dim": 128, "dtype": "bfloat16", "device": "cuda", "seed": 0}, "env": {"torch": "2.8.0+cu128", "cuda": "12.8", "gpu": "NVIDIA L4", "sm": "8.9", "py": "3.12.7", "plat": "Linux-5.15.0-1084-aws-x86_64-with-glibc2.31"}, "lat_ms": {"p10": 0.3389439880847931, "p50": 0.3461120128631592, "p90": 0.3461120128631592, "mean": 0.3452928066253662, "reps": 5, "warmup": 2}, "compile_ms": 0.9463679790496826, "peak_bytes": 87425024, "ok": true, "absmax": 0.0625, "corr": {"ok": true, "rtol": 0.02, "atol": 0.02, "absmax": 0.0625, "mae": 0.000362396240234375, "mse": 2.9206275939941406e-06, "ref": "sdpa_math_fp32"}, "err": null}
+{"ts": "2025-10-02T16:08:21Z", "run": "4862bb56aac04f66908d2a97924104e2", "impl": "xformers_meff", "tags": {"family": "xformers", "backend": "memory_efficient", "compile": "none"}, "wl": {"name": "flux_L256", "batch": 1, "seq_len": 1280, "heads": 24, "head_dim": 128, "dtype": "bfloat16", "device": "cuda", "seed": 0}, "env": {"torch": "2.8.0+cu128", "cuda": "12.8", "gpu": "NVIDIA L4", "sm": "8.9", "py": "3.12.7", "plat": "Linux-5.15.0-1084-aws-x86_64-with-glibc2.31"}, "lat_ms": {"p10": 0.40959998965263367, "p50": 0.41280001401901245, "p90": 0.41286399960517883, "mean": 0.41234560012817384, "reps": 5, "warmup": 2}, "compile_ms": 0.34329599142074585, "peak_bytes": 95027200, "ok": true, "absmax": 0.0625, "corr": {"ok": true, "rtol": 0.02, "atol": 0.02, "absmax": 0.0625, "mae": 0.00035858154296875, "mse": 2.8908252716064453e-06, "ref": "sdpa_math_fp32"}, "err": null}
+{"ts": "2025-10-02T16:08:21Z", "run": "4862bb56aac04f66908d2a97924104e2", "impl": "xformers_meff", "tags": {"family": "xformers", "backend": "memory_efficient", "compile": "none"}, "wl": {"name": "flux_L320", "batch": 1, "seq_len": 1344, "heads": 24, "head_dim": 128, "dtype": "bfloat16", "device": "cuda", "seed": 0}, "env": {"torch": "2.8.0+cu128", "cuda": "12.8", "gpu": "NVIDIA L4", "sm": "8.9", "py": "3.12.7", "plat": "Linux-5.15.0-1084-aws-x86_64-with-glibc2.31"}, "lat_ms": {"p10": 0.4310399889945984, "p50": 0.4331519901752472, "p90": 0.4362240135669708, "mean": 0.4366208016872406, "reps": 5, "warmup": 2}, "compile_ms": 0.35942399501800537, "peak_bytes": 99680256, "ok": true, "absmax": 0.0625, "corr": {"ok": true, "rtol": 0.02, "atol": 0.02, "absmax": 0.0625, "mae": 0.0003566741943359375, "mse": 2.8759241104125977e-06, "ref": "sdpa_math_fp32"}, "err": null}
+{"ts": "2025-10-02T16:08:21Z", "run": "4862bb56aac04f66908d2a97924104e2", "impl": "xformers_meff", "tags": {"family": "xformers", "backend": "memory_efficient", "compile": "none"}, "wl": {"name": "flux_L384", "batch": 1, "seq_len": 1408, "heads": 24, "head_dim": 128, "dtype": "bfloat16", "device": "cuda", "seed": 0}, "env": {"torch": "2.8.0+cu128", "cuda": "12.8", "gpu": "NVIDIA L4", "sm": "8.9", "py": "3.12.7", "plat": "Linux-5.15.0-1084-aws-x86_64-with-glibc2.31"}, "lat_ms": {"p10": 0.4359680116176605, "p50": 0.44361600279808044, "p90": 0.447488009929657, "mean": 0.4450624048709869, "reps": 5, "warmup": 2}, "compile_ms": 0.3678080141544342, "peak_bytes": 104726528, "ok": true, "absmax": 0.0625, "corr": {"ok": true, "rtol": 0.02, "atol": 0.02, "absmax": 0.0625, "mae": 0.0003604888916015625, "mse": 2.8759241104125977e-06, "ref": "sdpa_math_fp32"}, "err": null}
+{"ts": "2025-10-02T16:08:21Z", "run": "4862bb56aac04f66908d2a97924104e2", "impl": "xformers_meff", "tags": {"family": "xformers", "backend": "memory_efficient", "compile": "none"}, "wl": {"name": "flux_L448", "batch": 1, "seq_len": 1472, "heads": 24, "head_dim": 128, "dtype": "bfloat16", "device": "cuda", "seed": 0}, "env": {"torch": "2.8.0+cu128", "cuda": "12.8", "gpu": "NVIDIA L4", "sm": "8.9", "py": "3.12.7", "plat": "Linux-5.15.0-1084-aws-x86_64-with-glibc2.31"}, "lat_ms": {"p10": 0.4711039960384369, "p50": 0.47513601183891296, "p90": 0.4763199985027313, "mean": 0.4750400006771088, "reps": 5, "warmup": 2}, "compile_ms": 0.40857601165771484, "peak_bytes": 108855296, "ok": true, "absmax": 0.0625, "corr": {"ok": true, "rtol": 0.02, "atol": 0.02, "absmax": 0.0625, "mae": 0.00035858154296875, "mse": 2.86102294921875e-06, "ref": "sdpa_math_fp32"}, "err": null}
+{"ts": "2025-10-02T16:08:21Z", "run": "4862bb56aac04f66908d2a97924104e2", "impl": "xformers_meff", "tags": {"family": "xformers", "backend": "memory_efficient", "compile": "none"}, "wl": {"name": "flux_L512", "batch": 1, "seq_len": 1536, "heads": 24, "head_dim": 128, "dtype": "bfloat16", "device": "cuda", "seed": 0}, "env": {"torch": "2.8.0+cu128", "cuda": "12.8", "gpu": "NVIDIA L4", "sm": "8.9", "py": "3.12.7", "plat": "Linux-5.15.0-1084-aws-x86_64-with-glibc2.31"}, "lat_ms": {"p10": 0.49663999676704407, "p50": 0.4997119903564453, "p90": 0.5038080215454102, "mean": 0.5009407997131348, "reps": 5, "warmup": 2}, "compile_ms": 0.43724799156188965, "peak_bytes": 114425856, "ok": true, "absmax": 0.0625, "corr": {"ok": true, "rtol": 0.02, "atol": 0.02, "absmax": 0.0625, "mae": 0.0003566741943359375, "mse": 2.8908252716064453e-06, "ref": "sdpa_math_fp32"}, "err": null}
diff --git a/flash_attn/impls/artifacts/benchmark_default/attn_default.jsonl b/flash_attn/impls/artifacts/benchmark_default/attn_default.jsonl
new file mode 100644
index 0000000000000000000000000000000000000000..5ba2e631889b69237014219e513ecfcbae800a20
--- /dev/null
+++ b/flash_attn/impls/artifacts/benchmark_default/attn_default.jsonl
@@ -0,0 +1,6 @@
+{"ts": "2025-10-02T16:11:55Z", "run": "e9857dd2d39d4b40a6c91c0fdad82b00", "impl": "torch_flash_compiled_default", "tags": {"family": "torch-sdpa", "backend": "FLASH", "compile": "default"}, "wl": {"name": "flux_L128", "batch": 1, "seq_len": 1152, "heads": 24, "head_dim": 128, "dtype": "bfloat16", "device": "cuda", "seed": 0}, "env": {"torch": "2.8.0+cu128", "cuda": "12.8", "gpu": "NVIDIA L4", "sm": "8.9", "py": "3.12.7", "plat": "Linux-5.15.0-1084-aws-x86_64-with-glibc2.31"}, "lat_ms": {"p10": 0.3563520014286041, "p50": 0.35942399501800537, "p90": 0.3624959886074066, "mean": 0.3856383919715881, "reps": 5, "warmup": 2}, "compile_ms": 2383.33544921875, "peak_bytes": 87425024, "ok": true, "absmax": 0.0625, "corr": {"ok": true, "rtol": 0.02, "atol": 0.02, "absmax": 0.0625, "mae": 0.0003452301025390625, "mse": 2.771615982055664e-06, "ref": "sdpa_math_fp32"}, "err": null}
+{"ts": "2025-10-02T16:11:55Z", "run": "e9857dd2d39d4b40a6c91c0fdad82b00", "impl": "torch_flash_compiled_default", "tags": {"family": "torch-sdpa", "backend": "FLASH", "compile": "default"}, "wl": {"name": "flux_L256", "batch": 1, "seq_len": 1280, "heads": 24, "head_dim": 128, "dtype": "bfloat16", "device": "cuda", "seed": 0}, "env": {"torch": "2.8.0+cu128", "cuda": "12.8", "gpu": "NVIDIA L4", "sm": "8.9", "py": "3.12.7", "plat": "Linux-5.15.0-1084-aws-x86_64-with-glibc2.31"}, "lat_ms": {"p10": 0.4926080107688904, "p50": 0.49663999676704407, "p90": 0.5017600059509277, "mean": 0.4982912003993988, "reps": 5, "warmup": 2}, "compile_ms": 76.60860443115234, "peak_bytes": 95027200, "ok": true, "absmax": 0.0625, "corr": {"ok": true, "rtol": 0.02, "atol": 0.02, "absmax": 0.0625, "mae": 0.0003414154052734375, "mse": 2.726912498474121e-06, "ref": "sdpa_math_fp32"}, "err": null}
+{"ts": "2025-10-02T16:11:55Z", "run": "e9857dd2d39d4b40a6c91c0fdad82b00", "impl": "torch_flash_compiled_default", "tags": {"family": "torch-sdpa", "backend": "FLASH", "compile": "default"}, "wl": {"name": "flux_L320", "batch": 1, "seq_len": 1344, "heads": 24, "head_dim": 128, "dtype": "bfloat16", "device": "cuda", "seed": 0}, "env": {"torch": "2.8.0+cu128", "cuda": "12.8", "gpu": "NVIDIA L4", "sm": "8.9", "py": "3.12.7", "plat": "Linux-5.15.0-1084-aws-x86_64-with-glibc2.31"}, "lat_ms": {"p10": 0.5335040092468262, "p50": 0.5366079807281494, "p90": 0.5386239886283875, "mean": 0.5369919896125793, "reps": 5, "warmup": 2}, "compile_ms": 74.49088287353516, "peak_bytes": 99876864, "ok": true, "absmax": 0.0625, "corr": {"ok": true, "rtol": 0.02, "atol": 0.02, "absmax": 0.0625, "mae": 0.0003414154052734375, "mse": 2.7418136596679688e-06, "ref": "sdpa_math_fp32"}, "err": null}
+{"ts": "2025-10-02T16:11:55Z", "run": "e9857dd2d39d4b40a6c91c0fdad82b00", "impl": "torch_flash_compiled_default", "tags": {"family": "torch-sdpa", "backend": "FLASH", "compile": "default"}, "wl": {"name": "flux_L384", "batch": 1, "seq_len": 1408, "heads": 24, "head_dim": 128, "dtype": "bfloat16", "device": "cuda", "seed": 0}, "env": {"torch": "2.8.0+cu128", "cuda": "12.8", "gpu": "NVIDIA L4", "sm": "8.9", "py": "3.12.7", "plat": "Linux-5.15.0-1084-aws-x86_64-with-glibc2.31"}, "lat_ms": {"p10": 0.5775359869003296, "p50": 0.5868800282478333, "p90": 0.5877760052680969, "mean": 0.5841408014297486, "reps": 5, "warmup": 2}, "compile_ms": 72.97433471679688, "peak_bytes": 104726528, "ok": true, "absmax": 0.0625, "corr": {"ok": true, "rtol": 0.02, "atol": 0.02, "absmax": 0.0625, "mae": 0.0003452301025390625, "mse": 2.7418136596679688e-06, "ref": "sdpa_math_fp32"}, "err": null}
+{"ts": "2025-10-02T16:11:56Z", "run": "e9857dd2d39d4b40a6c91c0fdad82b00", "impl": "torch_flash_compiled_default", "tags": {"family": "torch-sdpa", "backend": "FLASH", "compile": "default"}, "wl": {"name": "flux_L448", "batch": 1, "seq_len": 1472, "heads": 24, "head_dim": 128, "dtype": "bfloat16", "device": "cuda", "seed": 0}, "env": {"torch": "2.8.0+cu128", "cuda": "12.8", "gpu": "NVIDIA L4", "sm": "8.9", "py": "3.12.7", "plat": "Linux-5.15.0-1084-aws-x86_64-with-glibc2.31"}, "lat_ms": {"p10": 0.6072319746017456, "p50": 0.6113280057907104, "p90": 0.6144000291824341, "mean": 0.6184704065322876, "reps": 5, "warmup": 2}, "compile_ms": 215.12498474121094, "peak_bytes": 108855296, "ok": true, "absmax": 0.0625, "corr": {"ok": true, "rtol": 0.02, "atol": 0.02, "absmax": 0.0625, "mae": 0.00034332275390625, "mse": 2.7418136596679688e-06, "ref": "sdpa_math_fp32"}, "err": null}
+{"ts": "2025-10-02T16:11:56Z", "run": "e9857dd2d39d4b40a6c91c0fdad82b00", "impl": "torch_flash_compiled_default", "tags": {"family": "torch-sdpa", "backend": "FLASH", "compile": "default"}, "wl": {"name": "flux_L512", "batch": 1, "seq_len": 1536, "heads": 24, "head_dim": 128, "dtype": "bfloat16", "device": "cuda", "seed": 0}, "env": {"torch": "2.8.0+cu128", "cuda": "12.8", "gpu": "NVIDIA L4", "sm": "8.9", "py": "3.12.7", "plat": "Linux-5.15.0-1084-aws-x86_64-with-glibc2.31"}, "lat_ms": {"p10": 0.6399999856948853, "p50": 0.6430720090866089, "p90": 0.6430720090866089, "mean": 0.6428672075271606, "reps": 5, "warmup": 2}, "compile_ms": 71.8028793334961, "peak_bytes": 114425856, "ok": true, "absmax": 0.0625, "corr": {"ok": true, "rtol": 0.02, "atol": 0.02, "absmax": 0.0625, "mae": 0.00034332275390625, "mse": 2.771615982055664e-06, "ref": "sdpa_math_fp32"}, "err": null}
diff --git a/flash_attn/impls/artifacts/benchmark_max_autotune/attn_max_autotune.jsonl b/flash_attn/impls/artifacts/benchmark_max_autotune/attn_max_autotune.jsonl
new file mode 100644
index 0000000000000000000000000000000000000000..8a1e8bb5317f89c8a6cab3ccb82bd9e167c6db90
--- /dev/null
+++ b/flash_attn/impls/artifacts/benchmark_max_autotune/attn_max_autotune.jsonl
@@ -0,0 +1,6 @@
+{"ts": "2025-10-02T16:11:08Z", "run": "06f2face3c924e1b89a35a0fb568d4b1", "impl": "torch_flash_compiled_max_autotune", "tags": {"family": "torch-sdpa", "backend": "FLASH", "compile": "max-autotune"}, "wl": {"name": "flux_L128", "batch": 1, "seq_len": 1152, "heads": 24, "head_dim": 128, "dtype": "bfloat16", "device": "cuda", "seed": 0}, "env": {"torch": "2.8.0+cu128", "cuda": "12.8", "gpu": "NVIDIA L4", "sm": "8.9", "py": "3.12.7", "plat": "Linux-5.15.0-1084-aws-x86_64-with-glibc2.31"}, "lat_ms": {"p10": 0.3665919899940491, "p50": 0.3768320083618164, "p90": 0.41171199083328247, "mean": 0.40020479559898375, "reps": 5, "warmup": 2}, "compile_ms": 2910.97705078125, "peak_bytes": 85722112, "ok": true, "absmax": 0.0625, "corr": {"ok": true, "rtol": 0.02, "atol": 0.02, "absmax": 0.0625, "mae": 0.0003452301025390625, "mse": 2.771615982055664e-06, "ref": "sdpa_math_fp32"}, "err": null}
+{"ts": "2025-10-02T16:11:08Z", "run": "06f2face3c924e1b89a35a0fb568d4b1", "impl": "torch_flash_compiled_max_autotune", "tags": {"family": "torch-sdpa", "backend": "FLASH", "compile": "max-autotune"}, "wl": {"name": "flux_L256", "batch": 1, "seq_len": 1280, "heads": 24, "head_dim": 128, "dtype": "bfloat16", "device": "cuda", "seed": 0}, "env": {"torch": "2.8.0+cu128", "cuda": "12.8", "gpu": "NVIDIA L4", "sm": "8.9", "py": "3.12.7", "plat": "Linux-5.15.0-1084-aws-x86_64-with-glibc2.31"}, "lat_ms": {"p10": 0.5160959959030151, "p50": 0.5489599704742432, "p90": 0.5631359815597534, "mean": 0.5535807967185974, "reps": 5, "warmup": 2}, "compile_ms": 85.84806060791016, "peak_bytes": 97387520, "ok": true, "absmax": 0.0625, "corr": {"ok": true, "rtol": 0.02, "atol": 0.02, "absmax": 0.0625, "mae": 0.0003414154052734375, "mse": 2.726912498474121e-06, "ref": "sdpa_math_fp32"}, "err": null}
+{"ts": "2025-10-02T16:11:09Z", "run": "06f2face3c924e1b89a35a0fb568d4b1", "impl": "torch_flash_compiled_max_autotune", "tags": {"family": "torch-sdpa", "backend": "FLASH", "compile": "max-autotune"}, "wl": {"name": "flux_L320", "batch": 1, "seq_len": 1344, "heads": 24, "head_dim": 128, "dtype": "bfloat16", "device": "cuda", "seed": 0}, "env": {"torch": "2.8.0+cu128", "cuda": "12.8", "gpu": "NVIDIA L4", "sm": "8.9", "py": "3.12.7", "plat": "Linux-5.15.0-1084-aws-x86_64-with-glibc2.31"}, "lat_ms": {"p10": 0.562175989151001, "p50": 0.6144000291824341, "p90": 0.6318079829216003, "mean": 0.6143999934196472, "reps": 5, "warmup": 2}, "compile_ms": 82.77401733398438, "peak_bytes": 99746816, "ok": true, "absmax": 0.0625, "corr": {"ok": true, "rtol": 0.02, "atol": 0.02, "absmax": 0.0625, "mae": 0.0003414154052734375, "mse": 2.7418136596679688e-06, "ref": "sdpa_math_fp32"}, "err": null}
+{"ts": "2025-10-02T16:11:09Z", "run": "06f2face3c924e1b89a35a0fb568d4b1", "impl": "torch_flash_compiled_max_autotune", "tags": {"family": "torch-sdpa", "backend": "FLASH", "compile": "max-autotune"}, "wl": {"name": "flux_L384", "batch": 1, "seq_len": 1408, "heads": 24, "head_dim": 128, "dtype": "bfloat16", "device": "cuda", "seed": 0}, "env": {"torch": "2.8.0+cu128", "cuda": "12.8", "gpu": "NVIDIA L4", "sm": "8.9", "py": "3.12.7", "plat": "Linux-5.15.0-1084-aws-x86_64-with-glibc2.31"}, "lat_ms": {"p10": 0.6512640118598938, "p50": 0.6584320068359375, "p90": 0.6799359917640686, "mean": 0.6754495978355408, "reps": 5, "warmup": 2}, "compile_ms": 81.94969940185547, "peak_bytes": 101843968, "ok": true, "absmax": 0.0625, "corr": {"ok": true, "rtol": 0.02, "atol": 0.02, "absmax": 0.0625, "mae": 0.0003452301025390625, "mse": 2.7418136596679688e-06, "ref": "sdpa_math_fp32"}, "err": null}
+{"ts": "2025-10-02T16:11:09Z", "run": "06f2face3c924e1b89a35a0fb568d4b1", "impl": "torch_flash_compiled_max_autotune", "tags": {"family": "torch-sdpa", "backend": "FLASH", "compile": "max-autotune"}, "wl": {"name": "flux_L448", "batch": 1, "seq_len": 1472, "heads": 24, "head_dim": 128, "dtype": "bfloat16", "device": "cuda", "seed": 0}, "env": {"torch": "2.8.0+cu128", "cuda": "12.8", "gpu": "NVIDIA L4", "sm": "8.9", "py": "3.12.7", "plat": "Linux-5.15.0-1084-aws-x86_64-with-glibc2.31"}, "lat_ms": {"p10": 0.6973119974136353, "p50": 0.7014080286026001, "p90": 0.7229440212249756, "mean": 0.7210752129554748, "reps": 5, "warmup": 2}, "compile_ms": 81.1141128540039, "peak_bytes": 103810048, "ok": true, "absmax": 0.0625, "corr": {"ok": true, "rtol": 0.02, "atol": 0.02, "absmax": 0.0625, "mae": 0.00034332275390625, "mse": 2.7418136596679688e-06, "ref": "sdpa_math_fp32"}, "err": null}
+{"ts": "2025-10-02T16:11:10Z", "run": "06f2face3c924e1b89a35a0fb568d4b1", "impl": "torch_flash_compiled_max_autotune", "tags": {"family": "torch-sdpa", "backend": "FLASH", "compile": "max-autotune"}, "wl": {"name": "flux_L512", "batch": 1, "seq_len": 1536, "heads": 24, "head_dim": 128, "dtype": "bfloat16", "device": "cuda", "seed": 0}, "env": {"torch": "2.8.0+cu128", "cuda": "12.8", "gpu": "NVIDIA L4", "sm": "8.9", "py": "3.12.7", "plat": "Linux-5.15.0-1084-aws-x86_64-with-glibc2.31"}, "lat_ms": {"p10": 0.7485439777374268, "p50": 0.7557439804077148, "p90": 0.7710719704627991, "mean": 0.7735359907150269, "reps": 5, "warmup": 2}, "compile_ms": 767.1397094726562, "peak_bytes": 106562560, "ok": true, "absmax": 0.0625, "corr": {"ok": true, "rtol": 0.02, "atol": 0.02, "absmax": 0.0625, "mae": 0.00034332275390625, "mse": 2.771615982055664e-06, "ref": "sdpa_math_fp32"}, "err": null}
diff --git a/flash_attn/impls/cells/benchmark.py b/flash_attn/impls/cells/benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..c558501ad2f4abcbaec8236272134bb8e8b0cfc4
--- /dev/null
+++ b/flash_attn/impls/cells/benchmark.py
@@ -0,0 +1,72 @@
+# /// script
+# requires-python = ">=3.10"
+# dependencies = [
+# "numpy",
+# "torch",
+# "kernels-benchmark-tools",
+# "kernels",
+# ]
+#
+# [tool.uv.sources]
+# kernels-benchmark-tools = { git = "https://github.com/drbh/kernels-benchmark-tools.git", branch = "main" }
+# ///
+import torch
+import sys
+import os
+import kernels_benchmark_tools as kbt
+from kernels import get_kernel
+
+hf_kernels_flash_attn = get_kernel("kernels-community/flash-attn", revision="v0.0.2")
+
+
+def hf_flash_attention(query, key, value):
+ """HuggingFace Kernels Flash Attention"""
+ return hf_kernels_flash_attn.fwd(query, key, value, is_causal=False)[0]
+
+
+kbt.add(
+ "hf_kernels_flash_attn",
+ hf_flash_attention,
+ tags={"family": "hf-kernels", "backend": "flash-attn", "compile": "none"},
+)
+
+if __name__ == "__main__":
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ if device == "cpu":
+ print("HF Kernels Flash Attention requires CUDA - skipping benchmark")
+ sys.exit(0)
+
+ dtype = "bfloat16"
+
+ # Flux-like workloads
+ base = 1024
+ flux_sizes = [128, 256, 320, 384, 448, 512]
+ heads = 24
+ head_dim = 128
+
+ wl = []
+ for L in flux_sizes:
+ wl.append(
+ {
+ "name": f"flux_L{L}",
+ "batch": 1,
+ "seq_len": base + L,
+ "heads": heads,
+ "head_dim": head_dim,
+ "dtype": dtype,
+ "device": device,
+ "seed": 0,
+ }
+ )
+
+ kbt.run(
+ wl,
+ jsonl="attn.jsonl",
+ reps=5,
+ warmup=2,
+ gen=kbt.attn.gen_qkv,
+ ref=kbt.attn.ref_math,
+ cmp=kbt.attn.cmp_allclose,
+ )
+ kbt.summarize(["attn.jsonl"])
\ No newline at end of file
diff --git a/flash_attn/impls/cells/benchmark_default.py b/flash_attn/impls/cells/benchmark_default.py
new file mode 100644
index 0000000000000000000000000000000000000000..cc2fd06ac69ffe1f5bc88d1821b17447dc90c846
--- /dev/null
+++ b/flash_attn/impls/cells/benchmark_default.py
@@ -0,0 +1,70 @@
+# /// script
+# requires-python = ">=3.10"
+# dependencies = [
+# "numpy",
+# "torch",
+# "kernels-benchmark-tools",
+# ]
+#
+# [tool.uv.sources]
+# kernels-benchmark-tools = { git = "https://github.com/drbh/kernels-benchmark-tools.git", branch = "main" }
+# ///
+import torch
+import sys
+import os
+import kernels_benchmark_tools as kbt
+
+
+def torch_flash_base(q, k, v):
+ qt, kt, vt = (x.transpose(1, 2).contiguous() for x in (q, k, v))
+ with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.FLASH_ATTENTION):
+ o = torch.nn.functional.scaled_dot_product_attention(qt, kt, vt)
+ return o.transpose(1, 2).contiguous()
+
+
+# Compile with default mode
+compiled_flash_default = torch.compile(torch_flash_base, mode="default", fullgraph=True, dynamic=False)
+
+kbt.add(
+ "torch_flash_compiled_default",
+ compiled_flash_default,
+ tags={"family": "torch-sdpa", "backend": "FLASH", "compile": "default"},
+)
+
+if __name__ == "__main__":
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ dtype = "float32" if device == "cpu" else "bfloat16"
+
+ # Flux-like workloads
+ base = 1024 if device == "cuda" else 512
+ flux_sizes = (
+ [128, 256, 320, 384, 448, 512] if device == "cuda" else [64, 128, 192, 256]
+ )
+ heads = 24 if device == "cuda" else 8
+ head_dim = 128 if device == "cuda" else 64
+
+ wl = []
+ for L in flux_sizes:
+ wl.append(
+ {
+ "name": f"flux_L{L}",
+ "batch": 1,
+ "seq_len": base + L,
+ "heads": heads,
+ "head_dim": head_dim,
+ "dtype": dtype,
+ "device": device,
+ "seed": 0,
+ }
+ )
+
+ kbt.run(
+ wl,
+ jsonl="attn_default.jsonl",
+ reps=5,
+ warmup=2,
+ gen=kbt.attn.gen_qkv,
+ ref=kbt.attn.ref_math,
+ cmp=kbt.attn.cmp_allclose,
+ )
+ kbt.summarize(["attn_default.jsonl"])
\ No newline at end of file
diff --git a/flash_attn/impls/cells/benchmark_max_autotune.py b/flash_attn/impls/cells/benchmark_max_autotune.py
new file mode 100644
index 0000000000000000000000000000000000000000..bd96e676c4d9ebdf709b701a7b9a71b9d51774fd
--- /dev/null
+++ b/flash_attn/impls/cells/benchmark_max_autotune.py
@@ -0,0 +1,70 @@
+# /// script
+# requires-python = ">=3.10"
+# dependencies = [
+# "numpy",
+# "torch",
+# "kernels-benchmark-tools",
+# ]
+#
+# [tool.uv.sources]
+# kernels-benchmark-tools = { git = "https://github.com/drbh/kernels-benchmark-tools.git", branch = "main" }
+# ///
+import torch
+import sys
+import os
+import kernels_benchmark_tools as kbt
+
+
+def torch_flash_base(q, k, v):
+ qt, kt, vt = (x.transpose(1, 2).contiguous() for x in (q, k, v))
+ with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.FLASH_ATTENTION):
+ o = torch.nn.functional.scaled_dot_product_attention(qt, kt, vt)
+ return o.transpose(1, 2).contiguous()
+
+
+# Compile with max-autotune mode
+compiled_flash_max_autotune = torch.compile(torch_flash_base, mode="max-autotune", fullgraph=True, dynamic=False)
+
+kbt.add(
+ "torch_flash_compiled_max_autotune",
+ compiled_flash_max_autotune,
+ tags={"family": "torch-sdpa", "backend": "FLASH", "compile": "max-autotune"},
+)
+
+if __name__ == "__main__":
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ dtype = "float32" if device == "cpu" else "bfloat16"
+
+ # Flux-like workloads
+ base = 1024 if device == "cuda" else 512
+ flux_sizes = (
+ [128, 256, 320, 384, 448, 512] if device == "cuda" else [64, 128, 192, 256]
+ )
+ heads = 24 if device == "cuda" else 8
+ head_dim = 128 if device == "cuda" else 64
+
+ wl = []
+ for L in flux_sizes:
+ wl.append(
+ {
+ "name": f"flux_L{L}",
+ "batch": 1,
+ "seq_len": base + L,
+ "heads": heads,
+ "head_dim": head_dim,
+ "dtype": dtype,
+ "device": device,
+ "seed": 0,
+ }
+ )
+
+ kbt.run(
+ wl,
+ jsonl="attn_max_autotune.jsonl",
+ reps=5,
+ warmup=2,
+ gen=kbt.attn.gen_qkv,
+ ref=kbt.attn.ref_math,
+ cmp=kbt.attn.cmp_allclose,
+ )
+ kbt.summarize(["attn_max_autotune.jsonl"])
\ No newline at end of file
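Note: benchmark_default.py and benchmark_max_autotune.py write attn_default.jsonl and attn_max_autotune.jsonl with identical workload names, so the two compile modes can be compared record by record. A minimal sketch, assuming both files sit in the working directory:

# Minimal sketch: per-workload p50 comparison of the two torch.compile modes benchmarked above.
import json
from pathlib import Path


def p50_by_workload(path: str) -> dict:
    """Map workload name -> p50 latency (ms) for one JSONL artifact."""
    out = {}
    for line in Path(path).read_text().splitlines():
        rec = json.loads(line)
        out[rec["wl"]["name"]] = rec["lat_ms"]["p50"]
    return out


default = p50_by_workload("attn_default.jsonl")
autotune = p50_by_workload("attn_max_autotune.jsonl")

for name in sorted(default):
    d, a = default[name], autotune[name]
    print(f"{name:10s} default={d:.3f} ms  max-autotune={a:.3f} ms  ratio={a / d:.2f}x")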
diff --git a/flash_attn/impls/cells/nv.py b/flash_attn/impls/cells/nv.py
new file mode 100644
index 0000000000000000000000000000000000000000..80eef60a7536ed875fb21731ab2d059458bd20b4
--- /dev/null
+++ b/flash_attn/impls/cells/nv.py
@@ -0,0 +1,3 @@
+import subprocess
+
+print(subprocess.run(["nvidia-smi"], capture_output=True, text=True).stdout)
\ No newline at end of file
diff --git a/flash_attn/impls/compiled_variants.html b/flash_attn/impls/compiled_variants.html
new file mode 100644
index 0000000000000000000000000000000000000000..6a8ad1c87fe3de178f48216a9a342d6d09eda581
--- /dev/null
+++ b/flash_attn/impls/compiled_variants.html
@@ -0,0 +1,4163 @@
+
+
+
+
+
+ compiled_variants
+
+
+
+
+
+
+
+
+
+
+
+
+ Linux x86_64 | Linux-5.15.0-1084-aws-x86_64-with-glibc2.31
+
+
+
+
+
Torch Compile Variants!
+
This file benchmarks Flash Attention with different torch.compile modes.
+
Flash Attention with torch.compile(mode="default")
+
+
+
+
+
# /// script
+# requires-python = ">=3.10"
+# dependencies = [
+# "numpy",
+# "torch",
+# "kernels-benchmark-tools",
+# ]
+#
+# [tool.uv.sources]
+# kernels-benchmark-tools = { git = "https://github.com/drbh/kernels-benchmark-tools.git", branch = "main" }
+# ///
+import torch
+import sys
+import os
+import kernels_benchmark_tools as kbt
+
+
+def torch_flash_base(q, k, v):
+    qt, kt, vt = (x.transpose(1, 2).contiguous() for x in (q, k, v))
+    with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.FLASH_ATTENTION):
+        o = torch.nn.functional.scaled_dot_product_attention(qt, kt, vt)
+    return o.transpose(1, 2).contiguous()
+
+
+# Compile with default mode
+compiled_flash_default = torch.compile(torch_flash_base, mode="default", fullgraph=True, dynamic=False)
+
+kbt.add(
+    "torch_flash_compiled_default",
+    compiled_flash_default,
+    tags={"family": "torch-sdpa", "backend": "FLASH", "compile": "default"},
+)
+
+if __name__ == "__main__":
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    dtype = "float32" if device == "cpu" else "bfloat16"
+
+    # Flux-like workloads
+    base = 1024 if device == "cuda" else 512
+    flux_sizes = (
+        [128, 256, 320, 384, 448, 512] if device == "cuda" else [64, 128, 192, 256]
+    )
+    heads = 24 if device == "cuda" else 8
+    head_dim = 128 if device == "cuda" else 64
+
+    wl = []
+    for L in flux_sizes:
+        wl.append(
+            {
+                "name": f"flux_L{L}",
+                "batch": 1,
+                "seq_len": base + L,
+                "heads": heads,
+                "head_dim": head_dim,
+                "dtype": dtype,
+                "device": device,
+                "seed": 0,
+            }
+        )
+
+    kbt.run(
+        wl,
+        jsonl="attn_default.jsonl",
+        reps=5,
+        warmup=2,
+        gen=kbt.attn.gen_qkv,
+        ref=kbt.attn.ref_math,
+        cmp=kbt.attn.cmp_allclose,
+    )
+    kbt.summarize(["attn_default.jsonl"])
+
+
+
+
+
+
+
impl wl p50(ms) ok
+torch_flash_compiled_default flux_L128 0.36 True
+torch_flash_compiled_default flux_L256 0.50 True
+torch_flash_compiled_default flux_L320 0.54 True
+torch_flash_compiled_default flux_L384 0.59 True
+torch_flash_compiled_default flux_L448 0.61 True
+torch_flash_compiled_default flux_L512 0.64 True
+
+
+
+
+ Updating https://github.com/drbh/kernels-benchmark-tools.git (main)
+Downloading nvidia-cusolver-cu12 (255.1MiB)
+Downloading matplotlib (8.3MiB)
+Downloading nvidia-cudnn-cu12 (674.0MiB)
+Downloading nvidia-cufile-cu12 (1.1MiB)
+Downloading kiwisolver (1.4MiB)
+Downloading nvidia-cusparselt-cu12 (273.9MiB)
+Downloading nvidia-cuda-nvrtc-cu12 (84.0MiB)
+Downloading fonttools (4.7MiB)
+Downloading triton (148.4MiB)
+Downloading numpy (15.9MiB)
+Downloading nvidia-cusparse-cu12 (274.9MiB)
+Downloading torch (846.8MiB)
+Downloading nvidia-cublas-cu12 (566.8MiB)
+Downloading nvidia-nccl-cu12 (307.4MiB)
+Downloading setuptools (1.1MiB)
+Downloading nvidia-curand-cu12 (60.7MiB)
+Downloading nvidia-cuda-cupti-cu12 (9.8MiB)
+Downloading networkx (1.9MiB)
+Downloading nvidia-nvjitlink-cu12 (37.4MiB)
+Downloading nvidia-cufft-cu12 (184.2MiB)
+Downloading sympy (6.0MiB)
+Downloading pillow (6.3MiB)
+ Updated https://github.com/drbh/kernels-benchmark-tools.git (f457279bca6573cd2fa54a74e67118f5e6b7a31c)
+ Building kernels-benchmark-tools @ git+https://github.com/drbh/kernels-benchmark-tools.git@f457279bca6573cd2fa54a74e67118f5e6b7a31c
+ Downloading nvidia-cufile-cu12
+ Downloading kiwisolver
+ Downloading setuptools
+ Downloading fonttools
+ Downloading networkx
+ Downloading pillow
+ Built kernels-benchmark-tools @ git+https://github.com/drbh/kernels-benchmark-tools.git@f457279bca6573cd2fa54a74e67118f5e6b7a31c
+ Downloading matplotlib
+ Downloading nvidia-cuda-cupti-cu12
+ Downloading numpy
+ Downloading sympy
+ Downloading nvidia-nvjitlink-cu12
+ Downloading nvidia-curand-cu12
+ Downloading nvidia-cuda-nvrtc-cu12
+ Downloading triton
+ Downloading nvidia-cufft-cu12
+ Downloading nvidia-cusolver-cu12
+ Downloading nvidia-cusparse-cu12
+ Downloading nvidia-cusparselt-cu12
+ Downloading nvidia-nccl-cu12
+ Downloading nvidia-cublas-cu12
+ Downloading nvidia-cudnn-cu12
+ Downloading torch
+Installed 37 packages in 203ms
+
+
+
+
+
+
+
Flash Attention with torch.compile(mode="max-autotune")
+
+
+
+
+
# /// script
+# requires-python = ">=3.10"
+# dependencies = [
+# "numpy",
+# "torch",
+# "kernels-benchmark-tools",
+# ]
+#
+# [tool.uv.sources]
+# kernels-benchmark-tools = { git = "https://github.com/drbh/kernels-benchmark-tools.git", branch = "main" }
+# ///
+import torch
+import sys
+import os
+import kernels_benchmark_tools as kbt
+
+
+def torch_flash_base(q, k, v):
+    qt, kt, vt = (x.transpose(1, 2).contiguous() for x in (q, k, v))
+    with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.FLASH_ATTENTION):
+        o = torch.nn.functional.scaled_dot_product_attention(qt, kt, vt)
+    return o.transpose(1, 2).contiguous()
+
+
+# Compile with max-autotune mode
+compiled_flash_max_autotune = torch.compile(torch_flash_base, mode="max-autotune", fullgraph=True, dynamic=False)
+
+kbt.add(
+    "torch_flash_compiled_max_autotune",
+    compiled_flash_max_autotune,
+    tags={"family": "torch-sdpa", "backend": "FLASH", "compile": "max-autotune"},
+)
+
+if __name__ == "__main__":
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    dtype = "float32" if device == "cpu" else "bfloat16"
+
+    # Flux-like workloads
+    base = 1024 if device == "cuda" else 512
+    flux_sizes = (
+        [128, 256, 320, 384, 448, 512] if device == "cuda" else [64, 128, 192, 256]
+    )
+    heads = 24 if device == "cuda" else 8
+    head_dim = 128 if device == "cuda" else 64
+
+    wl = []
+    for L in flux_sizes:
+        wl.append(
+            {
+                "name": f"flux_L{L}",
+                "batch": 1,
+                "seq_len": base + L,
+                "heads": heads,
+                "head_dim": head_dim,
+                "dtype": dtype,
+                "device": device,
+                "seed": 0,
+            }
+        )
+
+    kbt.run(
+        wl,
+        jsonl="attn_max_autotune.jsonl",
+        reps=5,
+        warmup=2,
+        gen=kbt.attn.gen_qkv,
+        ref=kbt.attn.ref_math,
+        cmp=kbt.attn.cmp_allclose,
+    )
+    kbt.summarize(["attn_max_autotune.jsonl"])
+
+
+
+
+
+
+
impl wl p50(ms) ok
+torch_flash_compiled_max_autotune flux_L128 0.38 True
+torch_flash_compiled_max_autotune flux_L256 0.55 True
+torch_flash_compiled_max_autotune flux_L320 0.61 True
+torch_flash_compiled_max_autotune flux_L384 0.66 True
+torch_flash_compiled_max_autotune flux_L448 0.70 True
+torch_flash_compiled_max_autotune flux_L512 0.76 True
+
+
+
+
+ Updating https://github.com/drbh/kernels-benchmark-tools.git (main)
+Downloading setuptools (1.1MiB)
+Downloading pillow (6.3MiB)
+Downloading nvidia-cusolver-cu12 (255.1MiB)
+Downloading nvidia-nccl-cu12 (307.4MiB)
+Downloading nvidia-curand-cu12 (60.7MiB)
+Downloading networkx (1.9MiB)
+Downloading nvidia-cuda-nvrtc-cu12 (84.0MiB)
+Downloading nvidia-cusparselt-cu12 (273.9MiB)
+Downloading kiwisolver (1.4MiB)
+Downloading fonttools (4.7MiB)
+Downloading numpy (15.9MiB)
+Downloading nvidia-cuda-cupti-cu12 (9.8MiB)
+Downloading nvidia-nvjitlink-cu12 (37.4MiB)
+Downloading nvidia-cudnn-cu12 (674.0MiB)
+Downloading nvidia-cufft-cu12 (184.2MiB)
+Downloading nvidia-cufile-cu12 (1.1MiB)
+Downloading matplotlib (8.3MiB)
+Downloading torch (846.8MiB)
+Downloading nvidia-cublas-cu12 (566.8MiB)
+Downloading nvidia-cusparse-cu12 (274.9MiB)
+Downloading triton (148.4MiB)
+Downloading sympy (6.0MiB)
+ Updated https://github.com/drbh/kernels-benchmark-tools.git (f457279bca6573cd2fa54a74e67118f5e6b7a31c)
+ Building kernels-benchmark-tools @ git+https://github.com/drbh/kernels-benchmark-tools.git@f457279bca6573cd2fa54a74e67118f5e6b7a31c
+ Downloading nvidia-cufile-cu12
+ Downloading kiwisolver
+ Downloading setuptools
+ Downloading fonttools
+ Downloading networkx
+ Downloading pillow
+ Built kernels-benchmark-tools @ git+https://github.com/drbh/kernels-benchmark-tools.git@f457279bca6573cd2fa54a74e67118f5e6b7a31c
+ Downloading nvidia-cuda-cupti-cu12
+ Downloading matplotlib
+ Downloading numpy
+ Downloading sympy
+ Downloading nvidia-nvjitlink-cu12
+ Downloading nvidia-curand-cu12
+ Downloading nvidia-cuda-nvrtc-cu12
+ Downloading triton
+ Downloading nvidia-cufft-cu12
+ Downloading nvidia-cusolver-cu12
+ Downloading nvidia-cusparse-cu12
+ Downloading nvidia-cusparselt-cu12
+ Downloading nvidia-nccl-cu12
+ Downloading nvidia-cudnn-cu12
+ Downloading nvidia-cublas-cu12
+ Downloading torch
+Installed 37 packages in 208ms
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/flash_attn/impls/flash_attention.html b/flash_attn/impls/flash_attention.html
new file mode 100644
index 0000000000000000000000000000000000000000..69fae8a00674f2b96a097313a070476b23f6617a
--- /dev/null
+++ b/flash_attn/impls/flash_attention.html
@@ -0,0 +1,4057 @@
+
+
+
+
+
+ flash_attention
+
+
+
+
+
+
+
+
+
+
+
+
+ Linux x86_64 | Linux-5.15.0-1084-aws-x86_64-with-glibc2.31
+
+
+
+
+
Flash Attention Implementation
+
GPU Info
+
+
+
+
+
+import subprocess
+
+print(subprocess.run(["nvidia-smi"], capture_output=True, text=True).stdout)
+
+
+
+
+
+
+
Thu Oct 2 16:12:42 2025
++-----------------------------------------------------------------------------------------+
+| NVIDIA-SMI 560.35.05 Driver Version: 560.35.05 CUDA Version: 12.6 |
+|-----------------------------------------+------------------------+----------------------+
+| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
+| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
+| | | MIG M. |
+|=========================================+========================+======================|
+| 0 NVIDIA L4 Off | 00000000:38:00.0 Off | 0 |
+| N/A 41C P0 27W / 72W | 1MiB / 23034MiB | 0% Default |
+| | | N/A |
++-----------------------------------------+------------------------+----------------------+
+| 1 NVIDIA L4 Off | 00000000:3A:00.0 Off | 0 |
+| N/A 41C P0 27W / 72W | 1MiB / 23034MiB | 2% Default |
+| | | N/A |
++-----------------------------------------+------------------------+----------------------+
+| 2 NVIDIA L4 Off | 00000000:3C:00.0 Off | 0 |
+| N/A 44C P0 29W / 72W | 1MiB / 23034MiB | 2% Default |
+| | | N/A |
++-----------------------------------------+------------------------+----------------------+
+| 3 NVIDIA L4 Off | 00000000:3E:00.0 Off | 0 |
+| N/A 42C P0 29W / 72W | 1MiB / 23034MiB | 2% Default |
+| | | N/A |
++-----------------------------------------+------------------------+----------------------+
+
++-----------------------------------------------------------------------------------------+
+| Processes: |
+| GPU GI CI PID Type Process name GPU Memory |
+| ID ID Usage |
+|=========================================================================================|
+| No running processes found |
++-----------------------------------------------------------------------------------------+
+
+
+
+
+
+
Flash Attention Benchmark
+
+
+
+
+
# /// script
+# requires-python = ">=3.10"
+# dependencies = [
+# "numpy",
+# "torch",
+# "kernels-benchmark-tools",
+# ]
+#
+# [tool.uv.sources]
+# kernels-benchmark-tools = { git = "https://github.com/drbh/kernels-benchmark-tools.git", branch = "main" }
+# ///
+import torch
+import sys
+import os
+import kernels_benchmark_tools as kbt
+
+
+def torch_flash(q, k, v):
+    qt, kt, vt = (x.transpose(1, 2).contiguous() for x in (q, k, v))
+    with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.FLASH_ATTENTION):
+        o = torch.nn.functional.scaled_dot_product_attention(qt, kt, vt)
+    return o.transpose(1, 2).contiguous()
+
+kbt.add(
+    "torch_flash_ma",
+    torch_flash,
+    tags={"family": "torch-sdpa", "backend": "FLASH", "compile": "max-autotune"},
+)
+
+if __name__ == "__main__":
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    dtype = "float32" if device == "cpu" else "bfloat16"
+
+    # Flux-like workloads scaled down for CPU testing
+    base = 1024 if device == "cuda" else 512
+    flux_sizes = (
+        [128, 256, 320, 384, 448, 512] if device == "cuda" else [64, 128, 192, 256]
+    )
+    heads = 24 if device == "cuda" else 8
+    head_dim = 128 if device == "cuda" else 64
+
+    wl = []
+    for L in flux_sizes:
+        wl.append(
+            {
+                "name": f"flux_L{L}",
+                "batch": 1,
+                "seq_len": base + L,
+                "heads": heads,
+                "head_dim": head_dim,
+                "dtype": dtype,
+                "device": device,
+                "seed": 0,
+            }
+        )
+
+    kbt.run(
+        wl,
+        jsonl="attn.jsonl",
+        reps=5,
+        warmup=2,
+        gen=kbt.attn.gen_qkv,
+        ref=kbt.attn.ref_math,
+        cmp=kbt.attn.cmp_allclose,
+    )
+    kbt.summarize(["attn.jsonl"])
+
+
+
+
+
+
+
impl wl p50(ms) ok
+torch_flash_ma flux_L128 0.41 True
+torch_flash_ma flux_L256 0.52 True
+torch_flash_ma flux_L320 0.55 True
+torch_flash_ma flux_L384 0.59 True
+torch_flash_ma flux_L448 0.64 True
+torch_flash_ma flux_L512 0.68 True
+
+
+
+
+ Updating https://github.com/drbh/kernels-benchmark-tools.git (main)
+Downloading nvidia-cusolver-cu12 (255.1MiB)
+Downloading nvidia-cufft-cu12 (184.2MiB)
+Downloading nvidia-curand-cu12 (60.7MiB)
+Downloading fonttools (4.7MiB)
+Downloading matplotlib (8.3MiB)
+Downloading nvidia-cufile-cu12 (1.1MiB)
+Downloading nvidia-cublas-cu12 (566.8MiB)
+Downloading nvidia-cudnn-cu12 (674.0MiB)
+Downloading nvidia-nccl-cu12 (307.4MiB)
+Downloading nvidia-cuda-cupti-cu12 (9.8MiB)
+Downloading torch (846.8MiB)
+Downloading nvidia-cuda-nvrtc-cu12 (84.0MiB)
+Downloading pillow (6.3MiB)
+Downloading nvidia-nvjitlink-cu12 (37.4MiB)
+Downloading nvidia-cusparselt-cu12 (273.9MiB)
+Downloading sympy (6.0MiB)
+Downloading setuptools (1.1MiB)
+Downloading nvidia-cusparse-cu12 (274.9MiB)
+Downloading networkx (1.9MiB)
+Downloading triton (148.4MiB)
+Downloading kiwisolver (1.4MiB)
+Downloading numpy (15.9MiB)
+ Updated https://github.com/drbh/kernels-benchmark-tools.git (f457279bca6573cd2fa54a74e67118f5e6b7a31c)
+ Building kernels-benchmark-tools @ git+https://github.com/drbh/kernels-benchmark-tools.git@f457279bca6573cd2fa54a74e67118f5e6b7a31c
+ Downloading nvidia-cufile-cu12
+ Downloading kiwisolver
+ Downloading setuptools
+ Downloading fonttools
+ Downloading networkx
+ Downloading pillow
+ Built kernels-benchmark-tools @ git+https://github.com/drbh/kernels-benchmark-tools.git@f457279bca6573cd2fa54a74e67118f5e6b7a31c
+ Downloading nvidia-cuda-cupti-cu12
+ Downloading matplotlib
+ Downloading numpy
+ Downloading sympy
+ Downloading nvidia-nvjitlink-cu12
+ Downloading nvidia-curand-cu12
+ Downloading nvidia-cuda-nvrtc-cu12
+ Downloading triton
+ Downloading nvidia-cufft-cu12
+ Downloading nvidia-cusolver-cu12
+ Downloading nvidia-cusparse-cu12
+ Downloading nvidia-cusparselt-cu12
+ Downloading nvidia-nccl-cu12
+ Downloading nvidia-cublas-cu12
+ Downloading nvidia-cudnn-cu12
+ Downloading torch
+Installed 37 packages in 224ms
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/flash_attn/impls/hf_kernels_flash_attn.html b/flash_attn/impls/hf_kernels_flash_attn.html
new file mode 100644
index 0000000000000000000000000000000000000000..a73274cdbcce6cc4d5e87b9056f24bf359798266
--- /dev/null
+++ b/flash_attn/impls/hf_kernels_flash_attn.html
@@ -0,0 +1,4010 @@
+
+
+
+
+
+ hf_kernels_flash_attn
+
+
+
+
+
+
+
+
+
+
+
+
+ Linux x86_64 | Linux-5.15.0-1084-aws-x86_64-with-glibc2.31
+
+
+
+
+
HF Kernels - Flash Attention
+
HuggingFace Kernels Flash Attention Benchmark
+
+
+
+
+
# /// script
+# requires-python = ">=3.10"
+# dependencies = [
+# "numpy",
+# "torch",
+# "kernels-benchmark-tools",
+# "kernels",
+# ]
+#
+# [tool.uv.sources]
+# kernels-benchmark-tools = { git = "https://github.com/drbh/kernels-benchmark-tools.git", branch = "main" }
+# ///
+import torch
+import sys
+import os
+import kernels_benchmark_tools as kbt
+from kernels import get_kernel
+
+hf_kernels_flash_attn = get_kernel("kernels-community/flash-attn", revision="v0.0.2")
+
+
+def hf_flash_attention(query, key, value):
+    """HuggingFace Kernels Flash Attention"""
+    return hf_kernels_flash_attn.fwd(query, key, value, is_causal=False)[0]
+
+
+kbt.add(
+    "hf_kernels_flash_attn",
+    hf_flash_attention,
+    tags={"family": "hf-kernels", "backend": "flash-attn", "compile": "none"},
+)
+
+if __name__ == "__main__":
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+
+    if device == "cpu":
+        print("HF Kernels Flash Attention requires CUDA - skipping benchmark")
+        sys.exit(0)
+
+    dtype = "bfloat16"
+
+    # Flux-like workloads
+    base = 1024
+    flux_sizes = [128, 256, 320, 384, 448, 512]
+    heads = 24
+    head_dim = 128
+
+    wl = []
+    for L in flux_sizes:
+        wl.append(
+            {
+                "name": f"flux_L{L}",
+                "batch": 1,
+                "seq_len": base + L,
+                "heads": heads,
+                "head_dim": head_dim,
+                "dtype": dtype,
+                "device": device,
+                "seed": 0,
+            }
+        )
+
+    kbt.run(
+        wl,
+        jsonl="attn.jsonl",
+        reps=5,
+        warmup=2,
+        gen=kbt.attn.gen_qkv,
+        ref=kbt.attn.ref_math,
+        cmp=kbt.attn.cmp_allclose,
+    )
+    kbt.summarize(["attn.jsonl"])
+
+
+
+
+
+
+
impl wl p50(ms) ok
+hf_kernels_flash_attn flux_L128 0.25 True
+hf_kernels_flash_attn flux_L256 0.32 True
+hf_kernels_flash_attn flux_L320 0.34 True
+hf_kernels_flash_attn flux_L384 0.35 True
+hf_kernels_flash_attn flux_L448 0.38 True
+hf_kernels_flash_attn flux_L512 0.42 True
+
+
+
+
+ Updating https://github.com/drbh/kernels-benchmark-tools.git (main)
+Downloading nvidia-cusparselt-cu12 (273.9MiB)
+Downloading kiwisolver (1.4MiB)
+Downloading nvidia-nvjitlink-cu12 (37.4MiB)
+Downloading matplotlib (8.3MiB)
+Downloading fonttools (4.7MiB)
+Downloading setuptools (1.1MiB)
+Downloading nvidia-nccl-cu12 (307.4MiB)
+Downloading nvidia-cusolver-cu12 (255.1MiB)
+Downloading sympy (6.0MiB)
+Downloading hf-xet (3.0MiB)
+Downloading nvidia-cuda-nvrtc-cu12 (84.0MiB)
+Downloading pillow (6.3MiB)
+Downloading networkx (1.9MiB)
+Downloading numpy (15.9MiB)
+Downloading nvidia-cudnn-cu12 (674.0MiB)
+Downloading nvidia-cuda-cupti-cu12 (9.8MiB)
+Downloading nvidia-cufile-cu12 (1.1MiB)
+Downloading nvidia-curand-cu12 (60.7MiB)
+Downloading nvidia-cusparse-cu12 (274.9MiB)
+Downloading nvidia-cufft-cu12 (184.2MiB)
+Downloading nvidia-cublas-cu12 (566.8MiB)
+Downloading torch (846.8MiB)
+Downloading triton (148.4MiB)
+ Updated https://github.com/drbh/kernels-benchmark-tools.git (f457279bca6573cd2fa54a74e67118f5e6b7a31c)
+ Building kernels-benchmark-tools @ git+https://github.com/drbh/kernels-benchmark-tools.git@f457279bca6573cd2fa54a74e67118f5e6b7a31c
+ Downloading nvidia-cufile-cu12
+ Downloading kiwisolver
+ Downloading hf-xet
+ Downloading setuptools
+ Downloading networkx
+ Downloading fonttools
+ Downloading pillow
+ Built kernels-benchmark-tools @ git+https://github.com/drbh/kernels-benchmark-tools.git@f457279bca6573cd2fa54a74e67118f5e6b7a31c
+ Downloading matplotlib
+ Downloading nvidia-cuda-cupti-cu12
+ Downloading numpy
+ Downloading sympy
+ Downloading nvidia-nvjitlink-cu12
+ Downloading nvidia-curand-cu12
+ Downloading nvidia-cuda-nvrtc-cu12
+ Downloading triton
+ Downloading nvidia-cufft-cu12
+ Downloading nvidia-cusolver-cu12
+ Downloading nvidia-cusparselt-cu12
+ Downloading nvidia-cusparse-cu12
+ Downloading nvidia-nccl-cu12
+ Downloading nvidia-cublas-cu12
+ Downloading nvidia-cudnn-cu12
+ Downloading torch
+Installed 47 packages in 255ms
+
+
+
+Fetching 20 files: 100%|██████████| 20/20 [00:02<00:00, 9.14it/s]
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/flash_attn/impls/hf_kernels_flash_attn3.html b/flash_attn/impls/hf_kernels_flash_attn3.html
new file mode 100644
index 0000000000000000000000000000000000000000..f745e3998546d5effda35036ae155237f9caf603
--- /dev/null
+++ b/flash_attn/impls/hf_kernels_flash_attn3.html
@@ -0,0 +1,4007 @@
+
+
+
+
+
+ hf_kernels_flash_attn3
+
+
+
+
+
+
+
+
+
+
+
+
+ Linux x86_64 | Linux-5.15.0-1084-aws-x86_64-with-glibc2.31
+
+
+
+
+
HF Kernels - Flash Attention 3
+
HuggingFace Kernels Flash Attention 3 Benchmark
+
+
+
+
+
# /// script
+# requires-python = ">=3.10"
+# dependencies = [
+# "numpy",
+# "torch",
+# "kernels-benchmark-tools",
+# "kernels",
+# ]
+#
+# [tool.uv.sources]
+# kernels-benchmark-tools = { git = "https://github.com/drbh/kernels-benchmark-tools.git", branch = "main" }
+# ///
+import torch
+import sys
+import os
+import kernels_benchmark_tools as kbt
+from kernels import get_kernel
+
+hf_kernels_flash_attn3 = get_kernel("kernels-community/flash-attn3")
+
+
+def hf_flash_attention3(query, key, value):
+    return hf_kernels_flash_attn3.flash_attn_func(query, key, value, causal=False)[0]
+
+
+kbt.add(
+    "hf_kernels_flash_attn3",
+    hf_flash_attention3,
+    tags={"family": "hf-kernels", "backend": "flash-attn3", "compile": "none"},
+)
+
+if __name__ == "__main__":
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+
+    if device == "cpu":
+        print("HF Kernels Flash Attention 3 requires CUDA - skipping benchmark")
+        sys.exit(0)
+
+    dtype = "bfloat16"
+
+    # Flux-like workloads
+    base = 1024
+    flux_sizes = [128, 256, 320, 384, 448, 512]
+    heads = 24
+    head_dim = 128
+
+    wl = []
+    for L in flux_sizes:
+        wl.append(
+            {
+                "name": f"flux_L{L}",
+                "batch": 1,
+                "seq_len": base + L,
+                "heads": heads,
+                "head_dim": head_dim,
+                "dtype": dtype,
+                "device": device,
+                "seed": 0,
+            }
+        )
+
+    kbt.run(
+        wl,
+        jsonl="attn.jsonl",
+        reps=5,
+        warmup=2,
+        gen=kbt.attn.gen_qkv,
+        ref=kbt.attn.ref_math,
+        cmp=kbt.attn.cmp_allclose,
+    )
+    kbt.summarize(["attn.jsonl"])
+
+
+
+
+
+
+
impl wl p50(ms) ok
+hf_kernels_flash_attn3 flux_L128 0.28 True
+hf_kernels_flash_attn3 flux_L256 0.34 True
+hf_kernels_flash_attn3 flux_L320 0.36 True
+hf_kernels_flash_attn3 flux_L384 0.37 True
+hf_kernels_flash_attn3 flux_L448 0.40 True
+hf_kernels_flash_attn3 flux_L512 0.43 True
+
+
+
+
+ Updating https://github.com/drbh/kernels-benchmark-tools.git (main)
+Downloading pillow (6.3MiB)
+Downloading hf-xet (3.0MiB)
+Downloading nvidia-cudnn-cu12 (674.0MiB)
+Downloading kiwisolver (1.4MiB)
+Downloading fonttools (4.7MiB)
+Downloading matplotlib (8.3MiB)
+Downloading networkx (1.9MiB)
+Downloading nvidia-cusparse-cu12 (274.9MiB)
+Downloading nvidia-nvjitlink-cu12 (37.4MiB)
+Downloading numpy (15.9MiB)
+Downloading nvidia-cufile-cu12 (1.1MiB)
+Downloading nvidia-nccl-cu12 (307.4MiB)
+Downloading nvidia-cublas-cu12 (566.8MiB)
+Downloading nvidia-cufft-cu12 (184.2MiB)
+Downloading torch (846.8MiB)
+Downloading triton (148.4MiB)
+Downloading setuptools (1.1MiB)
+Downloading nvidia-cusparselt-cu12 (273.9MiB)
+Downloading nvidia-cuda-cupti-cu12 (9.8MiB)
+Downloading nvidia-cusolver-cu12 (255.1MiB)
+Downloading nvidia-curand-cu12 (60.7MiB)
+Downloading sympy (6.0MiB)
+Downloading nvidia-cuda-nvrtc-cu12 (84.0MiB)
+ Updated https://github.com/drbh/kernels-benchmark-tools.git (f457279bca6573cd2fa54a74e67118f5e6b7a31c)
+ Building kernels-benchmark-tools @ git+https://github.com/drbh/kernels-benchmark-tools.git@f457279bca6573cd2fa54a74e67118f5e6b7a31c
+ Downloading nvidia-cufile-cu12
+ Downloading kiwisolver
+ Downloading hf-xet
+ Downloading setuptools
+ Downloading networkx
+ Downloading fonttools
+ Downloading pillow
+ Built kernels-benchmark-tools @ git+https://github.com/drbh/kernels-benchmark-tools.git@f457279bca6573cd2fa54a74e67118f5e6b7a31c
+ Downloading matplotlib
+ Downloading nvidia-cuda-cupti-cu12
+ Downloading numpy
+ Downloading sympy
+ Downloading nvidia-nvjitlink-cu12
+ Downloading nvidia-curand-cu12
+ Downloading nvidia-cuda-nvrtc-cu12
+ Downloading triton
+ Downloading nvidia-cufft-cu12
+ Downloading nvidia-cusolver-cu12
+ Downloading nvidia-cusparse-cu12
+ Downloading nvidia-cusparselt-cu12
+ Downloading nvidia-nccl-cu12
+ Downloading nvidia-cublas-cu12
+ Downloading nvidia-cudnn-cu12
+ Downloading torch
+Installed 47 packages in 229ms
+
+
+
+Fetching 4 files: 100%|██████████| 4/4 [00:02<00:00, 1.72it/s]
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/flash_attn/impls/index.html b/flash_attn/impls/index.html
new file mode 100644
index 0000000000000000000000000000000000000000..fbf5a24033c4d68d64ce4f1be0ba13266dd4e89d
--- /dev/null
+++ b/flash_attn/impls/index.html
@@ -0,0 +1,94 @@
+
+
+
+
+
+ Index of /flash_attn/impls
+
+
+
+
+ Index of /flash_attn/impls
+
+
+
\ No newline at end of file
diff --git a/flash_attn/impls/mem_efficient_attention.html b/flash_attn/impls/mem_efficient_attention.html
new file mode 100644
index 0000000000000000000000000000000000000000..3400970cc242db031fdf36b6c82329d15913de75
--- /dev/null
+++ b/flash_attn/impls/mem_efficient_attention.html
@@ -0,0 +1,3998 @@
+
+
+
+
+
+ mem_efficient_attention
+
+
+
+
+
+
+
+
+
+
+
+
+ Linux x86_64 | Linux-5.15.0-1084-aws-x86_64-with-glibc2.31
+
+
+
+
+
Memory Efficient Attention Implementation
+
Memory Efficient SDPA Benchmark
+
+
+
+
+
# /// script
+# requires-python = ">=3.10"
+# dependencies = [
+# "numpy",
+# "torch",
+# "kernels-benchmark-tools",
+# ]
+#
+# [tool.uv.sources]
+# kernels-benchmark-tools = { git = "https://github.com/drbh/kernels-benchmark-tools.git", branch = "main" }
+# ///
+import torch
+import sys
+import os
+import kernels_benchmark_tools as kbt
+
+
+def torch_mem_eff(q, k, v):
+    qt, kt, vt = (x.transpose(1, 2).contiguous() for x in (q, k, v))
+    with torch.nn.attention.sdpa_kernel(
+        torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION
+    ):
+        o = torch.nn.functional.scaled_dot_product_attention(qt, kt, vt)
+    return o.transpose(1, 2).contiguous()
+
+kbt.add(
+    "torch_mem_eff",
+    torch_mem_eff,
+    tags={"family": "torch-sdpa", "backend": "EFFICIENT", "compile": "none"},
+)
+
+if __name__ == "__main__":
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    dtype = "float32" if device == "cpu" else "bfloat16"
+
+    # Flux-like workloads scaled down for CPU testing
+    base = 1024 if device == "cuda" else 512
+    flux_sizes = (
+        [128, 256, 320, 384, 448, 512] if device == "cuda" else [64, 128, 192, 256]
+    )
+    heads = 24 if device == "cuda" else 8
+    head_dim = 128 if device == "cuda" else 64
+
+    wl = []
+    for L in flux_sizes:
+        wl.append(
+            {
+                "name": f"flux_L{L}",
+                "batch": 1,
+                "seq_len": base + L,
+                "heads": heads,
+                "head_dim": head_dim,
+                "dtype": dtype,
+                "device": device,
+                "seed": 0,
+            }
+        )
+
+    kbt.run(
+        wl,
+        jsonl="attn.jsonl",
+        reps=5,
+        warmup=2,
+        gen=kbt.attn.gen_qkv,
+        ref=kbt.attn.ref_math,
+        cmp=kbt.attn.cmp_allclose,
+    )
+    kbt.summarize(["attn.jsonl"])
+
+
+
+
+
+
+
impl wl p50(ms) ok
+torch_mem_eff flux_L128 0.48 True
+torch_mem_eff flux_L256 0.63 True
+torch_mem_eff flux_L320 0.70 True
+torch_mem_eff flux_L384 0.83 True
+torch_mem_eff flux_L448 0.95 True
+torch_mem_eff flux_L512 1.00 True
+
+
+
+
+ Updating https://github.com/drbh/kernels-benchmark-tools.git (main)
+Downloading nvidia-cuda-nvrtc-cu12 (84.0MiB)
+Downloading networkx (1.9MiB)
+Downloading sympy (6.0MiB)
+Downloading fonttools (4.7MiB)
+Downloading nvidia-cufft-cu12 (184.2MiB)
+Downloading nvidia-cusparselt-cu12 (273.9MiB)
+Downloading triton (148.4MiB)
+Downloading nvidia-cusolver-cu12 (255.1MiB)
+Downloading nvidia-cublas-cu12 (566.8MiB)
+Downloading nvidia-curand-cu12 (60.7MiB)
+Downloading nvidia-cusparse-cu12 (274.9MiB)
+Downloading nvidia-nccl-cu12 (307.4MiB)
+Downloading matplotlib (8.3MiB)
+Downloading pillow (6.3MiB)
+Downloading nvidia-nvjitlink-cu12 (37.4MiB)
+Downloading kiwisolver (1.4MiB)
+Downloading nvidia-cuda-cupti-cu12 (9.8MiB)
+Downloading torch (846.8MiB)
+Downloading setuptools (1.1MiB)
+Downloading numpy (15.9MiB)
+Downloading nvidia-cufile-cu12 (1.1MiB)
+Downloading nvidia-cudnn-cu12 (674.0MiB)
+ Updated https://github.com/drbh/kernels-benchmark-tools.git (f457279bca6573cd2fa54a74e67118f5e6b7a31c)
+ Building kernels-benchmark-tools @ git+https://github.com/drbh/kernels-benchmark-tools.git@f457279bca6573cd2fa54a74e67118f5e6b7a31c
+ Downloading nvidia-cufile-cu12
+ Downloading kiwisolver
+ Downloading setuptools
+ Downloading fonttools
+ Downloading networkx
+ Downloading pillow
+ Built kernels-benchmark-tools @ git+https://github.com/drbh/kernels-benchmark-tools.git@f457279bca6573cd2fa54a74e67118f5e6b7a31c
+ Downloading nvidia-cuda-cupti-cu12
+ Downloading matplotlib
+ Downloading numpy
+ Downloading sympy
+ Downloading nvidia-nvjitlink-cu12
+ Downloading nvidia-curand-cu12
+ Downloading nvidia-cuda-nvrtc-cu12
+ Downloading triton
+ Downloading nvidia-cufft-cu12
+ Downloading nvidia-cusolver-cu12
+ Downloading nvidia-cusparse-cu12
+ Downloading nvidia-cusparselt-cu12
+ Downloading nvidia-nccl-cu12
+ Downloading nvidia-cublas-cu12
+ Downloading nvidia-cudnn-cu12
+ Downloading torch
+Installed 37 packages in 248ms
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/flash_attn/impls/sage_attention.html b/flash_attn/impls/sage_attention.html
new file mode 100644
index 0000000000000000000000000000000000000000..20d60b76924d107347f498604c5ef8072e658fda
--- /dev/null
+++ b/flash_attn/impls/sage_attention.html
@@ -0,0 +1,4022 @@
+
+
+
+
+
+ sage_attention
+
+
+
+
+
+
+
+
+
+
+
+
+ Linux x86_64 | Linux-5.15.0-1084-aws-x86_64-with-glibc2.31
+
+
+
+
+
SageAttention Implementation
+
SageAttention Benchmark (INT8 Quantized)
+
+
+
+
+
# /// script
+# requires-python = ">=3.10"
+# dependencies = [
+# "numpy",
+# "torch",
+# "kernels",
+# "kernels-benchmark-tools",
+# "sageattention",
+# ]
+#
+# [tool.uv.sources]
+# kernels-benchmark-tools = { git = "https://github.com/drbh/kernels-benchmark-tools.git", branch = "main" }
+# ///
+import torch
+import sys
+import os
+import kernels_benchmark_tools as kbt
+# from sageattention import sageattn_qk_int8_pv_fp16_cuda
+
+
+# def sage_attention(q, k, v):
+# """SageAttention with INT8 Q/K quantization and FP16 P/V"""
+# return sageattn_qk_int8_pv_fp16_cuda(q, k, v, tensor_layout="NHD")
+
+from kernels import get_kernel
+
+hf_kernels_sage_attn = get_kernel("kernels-community/sage_attention")
+
+
+def sage_attention(query, key, value):
+    """HuggingFace Kernels SageAttention"""
+    return hf_kernels_sage_attn.fwd(query, key, value, is_causal=False)[0]
+
+kbt.add(
+    "sage_int8_fp16",
+    sage_attention,
+    tags={"family": "sageattention", "backend": "int8_fp16_cuda", "compile": "none"},
+)
+
+if __name__ == "__main__":
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+
+    if device == "cpu":
+        print("SageAttention requires CUDA - skipping benchmark")
+        sys.exit(0)
+
+    dtype = "bfloat16"
+
+    # Flux-like workloads
+    base = 1024
+    flux_sizes = [128, 256, 320, 384, 448, 512]
+    heads = 24
+    head_dim = 128
+
+    wl = []
+    for L in flux_sizes:
+        wl.append(
+            {
+                "name": f"flux_L{L}",
+                "batch": 1,
+                "seq_len": base + L,
+                "heads": heads,
+                "head_dim": head_dim,
+                "dtype": dtype,
+                "device": device,
+                "seed": 0,
+            }
+        )
+
+    kbt.run(
+        wl,
+        jsonl="attn.jsonl",
+        reps=5,
+        warmup=2,
+        gen=kbt.attn.gen_qkv,
+        ref=kbt.attn.ref_math,
+        cmp=kbt.attn.cmp_allclose,
+    )
+    kbt.summarize(["attn.jsonl"])
+
+
+
+
+
+
+
impl wl p50(ms) ok
+sage_int8_fp16 flux_L128 FAIL False
+ Error: module 'sage_attention_a8eb63760f50ebd' has no attribute 'fwd'
+sage_int8_fp16 flux_L256 FAIL False
+ Error: module 'sage_attention_a8eb63760f50ebd' has no attribute 'fwd'
+sage_int8_fp16 flux_L320 FAIL False
+ Error: module 'sage_attention_a8eb63760f50ebd' has no attribute 'fwd'
+sage_int8_fp16 flux_L384 FAIL False
+ Error: module 'sage_attention_a8eb63760f50ebd' has no attribute 'fwd'
+sage_int8_fp16 flux_L448 FAIL False
+ Error: module 'sage_attention_a8eb63760f50ebd' has no attribute 'fwd'
+sage_int8_fp16 flux_L512 FAIL False
+ Error: module 'sage_attention_a8eb63760f50ebd' has no attribute 'fwd'
+
+
+
+
+ Updating https://github.com/drbh/kernels-benchmark-tools.git (main)
+Downloading nvidia-nvjitlink-cu12 (37.4MiB)
+Downloading nvidia-cufile-cu12 (1.1MiB)
+Downloading hf-xet (3.0MiB)
+Downloading sympy (6.0MiB)
+Downloading nvidia-cusolver-cu12 (255.1MiB)
+Downloading nvidia-nccl-cu12 (307.4MiB)
+Downloading nvidia-cusparse-cu12 (274.9MiB)
+Downloading nvidia-cuda-cupti-cu12 (9.8MiB)
+Downloading nvidia-cuda-nvrtc-cu12 (84.0MiB)
+Downloading kiwisolver (1.4MiB)
+Downloading pillow (6.3MiB)
+Downloading nvidia-cublas-cu12 (566.8MiB)
+Downloading numpy (15.9MiB)
+Downloading nvidia-curand-cu12 (60.7MiB)
+Downloading matplotlib (8.3MiB)
+Downloading nvidia-cusparselt-cu12 (273.9MiB)
+Downloading fonttools (4.7MiB)
+Downloading nvidia-cudnn-cu12 (674.0MiB)
+Downloading nvidia-cufft-cu12 (184.2MiB)
+Downloading setuptools (1.1MiB)
+Downloading networkx (1.9MiB)
+Downloading triton (148.4MiB)
+Downloading torch (846.8MiB)
+ Updated https://github.com/drbh/kernels-benchmark-tools.git (f457279bca6573cd2fa54a74e67118f5e6b7a31c)
+ Building kernels-benchmark-tools @ git+https://github.com/drbh/kernels-benchmark-tools.git@f457279bca6573cd2fa54a74e67118f5e6b7a31c
+ Downloading nvidia-cufile-cu12
+ Downloading kiwisolver
+ Downloading hf-xet
+ Downloading setuptools
+ Downloading networkx
+ Downloading fonttools
+ Downloading pillow
+ Built kernels-benchmark-tools @ git+https://github.com/drbh/kernels-benchmark-tools.git@f457279bca6573cd2fa54a74e67118f5e6b7a31c
+ Downloading matplotlib
+ Downloading nvidia-cuda-cupti-cu12
+ Downloading numpy
+ Downloading sympy
+ Downloading nvidia-nvjitlink-cu12
+ Downloading nvidia-curand-cu12
+ Downloading nvidia-cuda-nvrtc-cu12
+ Downloading triton
+ Downloading nvidia-cufft-cu12
+ Downloading nvidia-cusolver-cu12
+ Downloading nvidia-cusparselt-cu12
+ Downloading nvidia-cusparse-cu12
+ Downloading nvidia-nccl-cu12
+ Downloading nvidia-cublas-cu12
+ Downloading nvidia-cudnn-cu12
+ Downloading torch
+Installed 48 packages in 239ms
+
+
+
+Fetching 11 files: 100%|██████████| 11/11 [00:01<00:00, 10.94it/s]
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/flash_attn/impls/xformers.html b/flash_attn/impls/xformers.html
new file mode 100644
index 0000000000000000000000000000000000000000..1b612e671c94677dfc3925bd110cbceb74640fed
--- /dev/null
+++ b/flash_attn/impls/xformers.html
@@ -0,0 +1,4000 @@
+
+
+
+
+
+ xformers
+
+
+
+
+
+
+
+
+
+
+
+
+ Linux x86_64 | Linux-5.15.0-1084-aws-x86_64-with-glibc2.31
+
+
+
+
+
xFormers Memory Efficient Attention
+
xFormers Benchmark
+
+
+
+
+
# /// script
+# requires-python = ">=3.10"
+# dependencies = [
+# "numpy",
+# "torch",
+# "kernels-benchmark-tools",
+# "xformers",
+# ]
+#
+# [tool.uv.sources]
+# kernels-benchmark-tools = { git = "https://github.com/drbh/kernels-benchmark-tools.git", branch = "main" }
+# ///
+import torch
+import sys
+import os
+import kernels_benchmark_tools as kbt
+import xformers.ops as xops
+
+
+def xformers_attention(q, k, v):
+    """xFormers memory efficient attention"""
+    # xFormers expects [batch, seq_len, heads, head_dim]
+    return xops.memory_efficient_attention(q, k, v)
+
+
+kbt.add(
+    "xformers_meff",
+    xformers_attention,
+    tags={"family": "xformers", "backend": "memory_efficient", "compile": "none"},
+)
+
+if __name__ == "__main__":
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    dtype = "float32" if device == "cpu" else "bfloat16"
+
+    # Flux-like workloads
+    base = 1024 if device == "cuda" else 512
+    flux_sizes = (
+        [128, 256, 320, 384, 448, 512] if device == "cuda" else [64, 128, 192, 256]
+    )
+    heads = 24 if device == "cuda" else 8
+    head_dim = 128 if device == "cuda" else 64
+
+    wl = []
+    for L in flux_sizes:
+        wl.append(
+            {
+                "name": f"flux_L{L}",
+                "batch": 1,
+                "seq_len": base + L,
+                "heads": heads,
+                "head_dim": head_dim,
+                "dtype": dtype,
+                "device": device,
+                "seed": 0,
+            }
+        )
+
+    kbt.run(
+        wl,
+        jsonl="attn.jsonl",
+        reps=5,
+        warmup=2,
+        gen=kbt.attn.gen_qkv,
+        ref=kbt.attn.ref_math,
+        cmp=kbt.attn.cmp_allclose,
+    )
+    kbt.summarize(["attn.jsonl"])
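+For reference, a minimal standalone call (a sketch, assuming a CUDA device and the same flux_L128 shape used above): xFormers takes q, k, v in [batch, seq_len, heads, head_dim] layout and returns the output in the same layout.
+
+import torch
+import xformers.ops as xops
+
+q = torch.randn(1, 1152, 24, 128, device="cuda", dtype=torch.bfloat16)
+k = torch.randn_like(q)
+v = torch.randn_like(q)
+out = xops.memory_efficient_attention(q, k, v)  # shape [1, 1152, 24, 128]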
+
+
+
+
+
+
+
impl wl p50(ms) ok
+xformers_meff flux_L128 0.35 True
+xformers_meff flux_L256 0.41 True
+xformers_meff flux_L320 0.43 True
+xformers_meff flux_L384 0.44 True
+xformers_meff flux_L448 0.48 True
+xformers_meff flux_L512 0.50 True
+
+
+
+
+ Updating https://github.com/drbh/kernels-benchmark-tools.git (main)
+Downloading kiwisolver (1.4MiB)
+Downloading setuptools (1.1MiB)
+Downloading nvidia-curand-cu12 (60.7MiB)
+Downloading nvidia-cufft-cu12 (184.2MiB)
+Downloading nvidia-cusolver-cu12 (255.1MiB)
+Downloading nvidia-cufile-cu12 (1.1MiB)
+Downloading pillow (6.3MiB)
+Downloading numpy (15.9MiB)
+Downloading matplotlib (8.3MiB)
+Downloading fonttools (4.7MiB)
+Downloading nvidia-cuda-cupti-cu12 (9.8MiB)
+Downloading nvidia-nvjitlink-cu12 (37.4MiB)
+Downloading xformers (111.8MiB)
+Downloading networkx (1.9MiB)
+Downloading nvidia-nccl-cu12 (307.4MiB)
+Downloading nvidia-cublas-cu12 (566.8MiB)
+Downloading nvidia-cusparse-cu12 (274.9MiB)
+Downloading nvidia-cusparselt-cu12 (273.9MiB)
+Downloading nvidia-cudnn-cu12 (674.0MiB)
+Downloading sympy (6.0MiB)
+Downloading triton (148.4MiB)
+Downloading torch (846.8MiB)
+Downloading nvidia-cuda-nvrtc-cu12 (84.0MiB)
+ Updated https://github.com/drbh/kernels-benchmark-tools.git (f457279bca6573cd2fa54a74e67118f5e6b7a31c)
+ Building kernels-benchmark-tools @ git+https://github.com/drbh/kernels-benchmark-tools.git@f457279bca6573cd2fa54a74e67118f5e6b7a31c
+ Downloading nvidia-cufile-cu12
+ Downloading kiwisolver
+ Downloading setuptools
+ Downloading networkx
+ Downloading fonttools
+ Downloading pillow
+ Built kernels-benchmark-tools @ git+https://github.com/drbh/kernels-benchmark-tools.git@f457279bca6573cd2fa54a74e67118f5e6b7a31c
+ Downloading matplotlib
+ Downloading nvidia-cuda-cupti-cu12
+ Downloading numpy
+ Downloading sympy
+ Downloading nvidia-nvjitlink-cu12
+ Downloading nvidia-curand-cu12
+ Downloading nvidia-cuda-nvrtc-cu12
+ Downloading xformers
+ Downloading triton
+ Downloading nvidia-cufft-cu12
+ Downloading nvidia-cusolver-cu12
+ Downloading nvidia-cusparse-cu12
+ Downloading nvidia-cusparselt-cu12
+ Downloading nvidia-nccl-cu12
+ Downloading nvidia-cublas-cu12
+ Downloading nvidia-cudnn-cu12
+ Downloading torch
+Installed 38 packages in 250ms
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/flash_attn/index.html b/flash_attn/index.html
new file mode 100644
index 0000000000000000000000000000000000000000..eea7df846d9f2d44c6c6e03a5ac30d00cecd90cf
--- /dev/null
+++ b/flash_attn/index.html
@@ -0,0 +1,89 @@
+
+
+
+
+
+ Index of /flash_attn
+
+
+
+
+ Index of /flash_attn
+
+
+
\ No newline at end of file
diff --git a/flash_attn/results/artifacts/combine/latency.csv b/flash_attn/results/artifacts/combine/latency.csv
new file mode 100644
index 0000000000000000000000000000000000000000..fe7bd55163eff0be13ab95a18cadede44a50adaf
--- /dev/null
+++ b/flash_attn/results/artifacts/combine/latency.csv
@@ -0,0 +1,43 @@
+Implementation,Impl ID,Workload,Batch,Seq Length,Heads,Head Dim,Dtype,Mean (ms),P10 (ms),P50 (ms),P90 (ms),Reps,Peak Mem (MB),Backend,Family
+Flash (PyTorch SDPA),torch_flash_ma,flux_L128,1,1152,24,128,bfloat16,0.407123202085495,0.40537598729133606,0.40755200386047363,0.407584011554718,5,83.38,FLASH,torch-sdpa
+Flash (PyTorch SDPA),torch_flash_ma,flux_L256,1,1280,24,128,bfloat16,0.5235007882118226,0.5212159752845764,0.5232639908790588,0.523360013961792,5,90.62,FLASH,torch-sdpa
+Flash (PyTorch SDPA),torch_flash_ma,flux_L320,1,1344,24,128,bfloat16,0.545849597454071,0.5418559908866882,0.5468159914016724,0.5469120144844055,5,95.06,FLASH,torch-sdpa
+Flash (PyTorch SDPA),torch_flash_ma,flux_L384,1,1408,24,128,bfloat16,0.5892416119575501,0.5867519974708557,0.5888000130653381,0.5888000130653381,5,99.88,FLASH,torch-sdpa
+Flash (PyTorch SDPA),torch_flash_ma,flux_L448,1,1472,24,128,bfloat16,0.6449280023574829,0.6430720090866089,0.6442239880561829,0.6450240015983582,5,103.81,FLASH,torch-sdpa
+Flash (PyTorch SDPA),torch_flash_ma,flux_L512,1,1536,24,128,bfloat16,0.6823423862457275,0.6777600049972534,0.6809599995613098,0.6818559765815735,5,109.12,FLASH,torch-sdpa
+MemEff (PyTorch SDPA),torch_mem_eff,flux_L128,1,1152,24,128,bfloat16,0.48371200561523436,0.4821760058403015,0.4833280146121979,0.4853760004043579,5,83.38,EFFICIENT,torch-sdpa
+MemEff (PyTorch SDPA),torch_mem_eff,flux_L256,1,1280,24,128,bfloat16,0.6268800020217895,0.6246399879455566,0.6266880035400391,0.6286720037460327,5,90.62,EFFICIENT,torch-sdpa
+MemEff (PyTorch SDPA),torch_mem_eff,flux_L320,1,1344,24,128,bfloat16,0.699776005744934,0.6973440051078796,0.7004160284996033,0.7004479765892029,5,95.94,EFFICIENT,torch-sdpa
+MemEff (PyTorch SDPA),torch_mem_eff,flux_L384,1,1408,24,128,bfloat16,0.8333312034606933,0.8284159898757935,0.8325120210647583,0.8376320004463196,5,100.0,EFFICIENT,torch-sdpa
+MemEff (PyTorch SDPA),torch_mem_eff,flux_L448,1,1472,24,128,bfloat16,0.9533439993858337,0.9502720236778259,0.9512959718704224,0.9572479724884033,5,103.81,EFFICIENT,torch-sdpa
+MemEff (PyTorch SDPA),torch_mem_eff,flux_L512,1,1536,24,128,bfloat16,1.0066367864608765,1.0024960041046143,1.0045440196990967,1.0097919702529907,5,109.12,EFFICIENT,torch-sdpa
+xFormers,xformers_meff,flux_L128,1,1152,24,128,bfloat16,0.3452928066253662,0.3389439880847931,0.3461120128631592,0.3461120128631592,5,83.38,memory_efficient,xformers
+xFormers,xformers_meff,flux_L256,1,1280,24,128,bfloat16,0.41234560012817384,0.40959998965263367,0.41280001401901245,0.41286399960517883,5,90.62,memory_efficient,xformers
+xFormers,xformers_meff,flux_L320,1,1344,24,128,bfloat16,0.4366208016872406,0.4310399889945984,0.4331519901752472,0.4362240135669708,5,95.06,memory_efficient,xformers
+xFormers,xformers_meff,flux_L384,1,1408,24,128,bfloat16,0.4450624048709869,0.4359680116176605,0.44361600279808044,0.447488009929657,5,99.88,memory_efficient,xformers
+xFormers,xformers_meff,flux_L448,1,1472,24,128,bfloat16,0.4750400006771088,0.4711039960384369,0.47513601183891296,0.4763199985027313,5,103.81,memory_efficient,xformers
+xFormers,xformers_meff,flux_L512,1,1536,24,128,bfloat16,0.5009407997131348,0.49663999676704407,0.4997119903564453,0.5038080215454102,5,109.12,memory_efficient,xformers
+Compiled (default),torch_flash_compiled_default,flux_L128,1,1152,24,128,bfloat16,0.3856383919715881,0.3563520014286041,0.35942399501800537,0.3624959886074066,5,83.38,FLASH,torch-sdpa
+Compiled (default),torch_flash_compiled_default,flux_L256,1,1280,24,128,bfloat16,0.4982912003993988,0.4926080107688904,0.49663999676704407,0.5017600059509277,5,90.62,FLASH,torch-sdpa
+Compiled (default),torch_flash_compiled_default,flux_L320,1,1344,24,128,bfloat16,0.5369919896125793,0.5335040092468262,0.5366079807281494,0.5386239886283875,5,95.25,FLASH,torch-sdpa
+Compiled (default),torch_flash_compiled_default,flux_L384,1,1408,24,128,bfloat16,0.5841408014297486,0.5775359869003296,0.5868800282478333,0.5877760052680969,5,99.88,FLASH,torch-sdpa
+Compiled (default),torch_flash_compiled_default,flux_L448,1,1472,24,128,bfloat16,0.6184704065322876,0.6072319746017456,0.6113280057907104,0.6144000291824341,5,103.81,FLASH,torch-sdpa
+Compiled (default),torch_flash_compiled_default,flux_L512,1,1536,24,128,bfloat16,0.6428672075271606,0.6399999856948853,0.6430720090866089,0.6430720090866089,5,109.12,FLASH,torch-sdpa
+Compiled (max-autotune),torch_flash_compiled_max_autotune,flux_L128,1,1152,24,128,bfloat16,0.40020479559898375,0.3665919899940491,0.3768320083618164,0.41171199083328247,5,81.75,FLASH,torch-sdpa
+Compiled (max-autotune),torch_flash_compiled_max_autotune,flux_L256,1,1280,24,128,bfloat16,0.5535807967185974,0.5160959959030151,0.5489599704742432,0.5631359815597534,5,92.88,FLASH,torch-sdpa
+Compiled (max-autotune),torch_flash_compiled_max_autotune,flux_L320,1,1344,24,128,bfloat16,0.6143999934196472,0.562175989151001,0.6144000291824341,0.6318079829216003,5,95.13,FLASH,torch-sdpa
+Compiled (max-autotune),torch_flash_compiled_max_autotune,flux_L384,1,1408,24,128,bfloat16,0.6754495978355408,0.6512640118598938,0.6584320068359375,0.6799359917640686,5,97.13,FLASH,torch-sdpa
+Compiled (max-autotune),torch_flash_compiled_max_autotune,flux_L448,1,1472,24,128,bfloat16,0.7210752129554748,0.6973119974136353,0.7014080286026001,0.7229440212249756,5,99.0,FLASH,torch-sdpa
+Compiled (max-autotune),torch_flash_compiled_max_autotune,flux_L512,1,1536,24,128,bfloat16,0.7735359907150269,0.7485439777374268,0.7557439804077148,0.7710719704627991,5,101.63,FLASH,torch-sdpa
+HF Kernels Flash Attn,hf_kernels_flash_attn,flux_L128,1,1152,24,128,bfloat16,0.2456959992647171,0.24371199309825897,0.24566400051116943,0.2457599937915802,5,83.38,flash-attn,hf-kernels
+HF Kernels Flash Attn,hf_kernels_flash_attn,flux_L256,1,1280,24,128,bfloat16,0.3215551972389221,0.3164159953594208,0.319487988948822,0.32051199674606323,5,90.62,flash-attn,hf-kernels
+HF Kernels Flash Attn,hf_kernels_flash_attn,flux_L320,1,1344,24,128,bfloat16,0.3384703993797302,0.33670398592948914,0.33792001008987427,0.33983999490737915,5,95.06,flash-attn,hf-kernels
+HF Kernels Flash Attn,hf_kernels_flash_attn,flux_L384,1,1408,24,128,bfloat16,0.3510208010673523,0.3481599986553192,0.3491840064525604,0.35225600004196167,5,99.88,flash-attn,hf-kernels
+HF Kernels Flash Attn,hf_kernels_flash_attn,flux_L448,1,1472,24,128,bfloat16,0.3829823970794678,0.38095998764038086,0.3829759955406189,0.3840000033378601,5,103.81,flash-attn,hf-kernels
+HF Kernels Flash Attn,hf_kernels_flash_attn,flux_L512,1,1536,24,128,bfloat16,0.4259391903877258,0.4227519929409027,0.4249599874019623,0.4259839951992035,5,109.12,flash-attn,hf-kernels
+HF Kernels Flash Attn3,hf_kernels_flash_attn3,flux_L128,1,1152,24,128,bfloat16,0.2755008041858673,0.26736000180244446,0.27561599016189575,0.27955201268196106,5,83.38,flash-attn3,hf-kernels
+HF Kernels Flash Attn3,hf_kernels_flash_attn3,flux_L256,1,1280,24,128,bfloat16,0.3397440016269684,0.3368000090122223,0.3399679958820343,0.34191998839378357,5,90.62,flash-attn3,hf-kernels
+HF Kernels Flash Attn3,hf_kernels_flash_attn3,flux_L320,1,1344,24,128,bfloat16,0.36019839644432067,0.3563520014286041,0.3604480028152466,0.36137598752975464,5,95.06,flash-attn3,hf-kernels
+HF Kernels Flash Attn3,hf_kernels_flash_attn3,flux_L384,1,1408,24,128,bfloat16,0.37342079877853396,0.3718400001525879,0.37379199266433716,0.3746879994869232,5,99.88,flash-attn3,hf-kernels
+HF Kernels Flash Attn3,hf_kernels_flash_attn3,flux_L448,1,1472,24,128,bfloat16,0.4024448037147522,0.3993600010871887,0.4014720022678375,0.4034560024738312,5,103.81,flash-attn3,hf-kernels
+HF Kernels Flash Attn3,hf_kernels_flash_attn3,flux_L512,1,1536,24,128,bfloat16,0.4305088043212891,0.4270080029964447,0.4291520118713379,0.4331519901752472,5,109.12,flash-attn3,hf-kernels
diff --git a/flash_attn/results/artifacts/combine/latency.svg b/flash_attn/results/artifacts/combine/latency.svg
new file mode 100644
index 0000000000000000000000000000000000000000..ea5c5865f5ff0718bc77d9c64aa425433567eef1
--- /dev/null
+++ b/flash_attn/results/artifacts/combine/latency.svg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7bb668989495a1f179dd65f7426ee2a611f8dc193219b9a7385c79f6701d161a
+size 29810
diff --git a/flash_attn/results/cells/combine.py b/flash_attn/results/cells/combine.py
new file mode 100644
index 0000000000000000000000000000000000000000..f703ae3d1403d602560c9d3b36d51fab69b7f3a5
--- /dev/null
+++ b/flash_attn/results/cells/combine.py
@@ -0,0 +1,319 @@
+# /// script
+# requires-python = ">=3.10"
+# dependencies = [
+# "numpy",
+# "torch",
+# "kernels-benchmark-tools",
+# "matplotlib",
+# ]
+#
+# [tool.uv.sources]
+# kernels-benchmark-tools = { git = "https://github.com/drbh/kernels-benchmark-tools.git", branch = "main" }
+# ///
+import os
+import sys
+from pathlib import Path
+import json
+import torch # noqa: F401 # imported because upstream may expect torch to be importable
+import kernels_benchmark_tools as kbt
+
+# --- Matplotlib setup and helpers ------------------------------------------------
+import matplotlib as mpl
+import matplotlib.pyplot as plt
+import csv
+
+
+# Keep text as text (not paths) so CSS can style fonts, size, etc.
+mpl.rcParams["svg.fonttype"] = "none"
+# Make ids deterministic across builds
+mpl.rcParams["svg.hashsalt"] = "latency-benchmark-combined"
+# Avoid auto-closed figures interfering with our tagging
+mpl.rcParams["figure.autolayout"] = True
+# Make background transparent
+mpl.rcParams["figure.facecolor"] = "none"
+mpl.rcParams["axes.facecolor"] = "none"
+mpl.rcParams["savefig.facecolor"] = "none"
+mpl.rcParams["savefig.edgecolor"] = "none"
+
+def _slugify(s: str) -> str:
+ s = (s or "").strip().lower()
+ keep = []
+ for ch in s:
+ if ch.isalnum():
+ keep.append(ch)
+ elif ch in (" ", "-", "_", "/", ".", ":"):
+ keep.append("-")
+ else:
+ keep.append("")
+ out = "".join(keep)
+ while "--" in out:
+ out = out.replace("--", "-")
+ return out.strip("-") or "unnamed"
+
+def _tag_current_figure(default_series_prefix="series"):
+ """Attach SVG ids (gid) to key artists so they can be targeted from CSS."""
+ fig = plt.gcf()
+ if fig is None:
+ return
+
+ # Tag the figure itself
+ fig.set_gid("figure--latency")
+
+ for ax_idx, ax in enumerate(fig.get_axes(), start=1):
+ ax.set_gid(f"axes--{ax_idx}")
+
+ # Axis labels & title
+ if ax.get_title():
+ for t in ax.texts:
+ if t.get_text() == ax.get_title():
+ t.set_gid("title--main")
+ if ax.xaxis and ax.xaxis.get_label():
+ ax.xaxis.label.set_gid("label--x")
+ if ax.yaxis and ax.yaxis.get_label():
+ ax.yaxis.label.set_gid("label--y")
+
+ # Gridlines
+ for i, gl in enumerate(ax.get_xgridlines(), start=1):
+ gl.set_gid(f"grid-x--{i}")
+ for i, gl in enumerate(ax.get_ygridlines(), start=1):
+ gl.set_gid(f"grid-y--{i}")
+
+ # Legend block & entries
+ leg = ax.get_legend()
+ if leg is not None:
+ leg.set_gid("legend")
+ for i, txt in enumerate(leg.get_texts(), start=1):
+ label_slug = _slugify(txt.get_text())
+ txt.set_gid(f"legend-label--{label_slug or i}")
+
+ # Series (lines, patches)
+ # Lines
+ line_seen = {}
+ for ln in getattr(ax, "lines", []):
+ raw_label = ln.get_label() or ""
+ # Matplotlib uses labels beginning with "_" for non-legendable items
+ label = raw_label if not raw_label.startswith("_") else f"{default_series_prefix}"
+ slug = _slugify(label)
+ line_seen[slug] = line_seen.get(slug, 0) + 1
+ suffix = "" if line_seen[slug] == 1 else f"-{line_seen[slug]}"
+ ln.set_gid(f"series--{slug}{suffix}")
+
+ # Patches (bars, areas)
+ patch_seen = {}
+ for pt in getattr(ax, "patches", []):
+ label = getattr(pt, "get_label", lambda: "")() or f"{default_series_prefix}"
+ if isinstance(label, str) and label.startswith("_"):
+ label = default_series_prefix
+ slug = _slugify(label)
+ patch_seen[slug] = patch_seen.get(slug, 0) + 1
+ suffix = "" if patch_seen[slug] == 1 else f"-{patch_seen[slug]}"
+ pt.set_gid(f"series--{slug}{suffix}")
+
+def _postprocess_svg_add_classes(svg_path: Path):
+ """Add convenient CSS classes alongside ids (e.g., class='series grid grid-x')."""
+ try:
+ import xml.etree.ElementTree as ET
+ ET.register_namespace("", "http://www.w3.org/2000/svg")
+ tree = ET.parse(svg_path)
+ root = tree.getroot()
+ for el in root.iter():
+ el_id = el.attrib.get("id", "")
+ if not el_id:
+ continue
+ cls = []
+ if el_id.startswith("figure--"):
+ cls.append("figure")
+ elif el_id.startswith("axes--"):
+ cls.append("axes")
+ elif el_id.startswith("grid-x--"):
+ cls += ["grid", "grid-x"]
+ elif el_id.startswith("grid-y--"):
+ cls += ["grid", "grid-y"]
+ elif el_id.startswith("legend"):
+ cls.append("legend")
+ elif el_id.startswith("label--x"):
+ cls.append("xlabel")
+ elif el_id.startswith("label--y"):
+ cls.append("ylabel")
+ elif el_id.startswith("title--"):
+ cls.append("title")
+ elif el_id.startswith("series--"):
+ cls.append("series")
+ if cls:
+ # Preserve any existing class (unlikely from Matplotlib)
+ existing = el.attrib.get("class", "")
+ el.set("class", (existing + " " + " ".join(cls)).strip())
+ tree.write(svg_path, encoding="utf-8", xml_declaration=True)
+ except Exception as e:
+ print(f"✗ SVG postprocess (classes) skipped: {e}")
+
+# Monkey-patch savefig to force SVG & ensure tagging occurs even if kbt.viz saves internally.
+_orig_savefig = plt.savefig
+def _savefig_svg(fname, *args, **kwargs):
+ # Always save as SVG at a stable path for the artifact system
+ out = Path("latency.svg")
+ kwargs["format"] = "svg"
+ # Ensure everything we care about has ids before export
+ _tag_current_figure()
+ res = _orig_savefig(out, *args, **kwargs)
+ # Add helpful CSS classes on top of ids
+ _postprocess_svg_add_classes(out)
+ print(f"✓ Combined visualization saved as {out}")
+ return res
+
+plt.savefig = _savefig_svg # apply patch
+
+# Capture close calls in case kbt.viz() closes figures before we re-save
+_orig_close = plt.close
+_last_closed = {"fig": None}
+def _capture_close(arg=None):
+ try:
+ if hasattr(arg, "savefig"): # looks like a Figure
+ _last_closed["fig"] = arg
+ else:
+ _last_closed["fig"] = plt.gcf()
+ finally:
+ return _orig_close(arg)
+plt.close = _capture_close
+
+# --- Locate benchmark artifacts --------------------------------------------------
+cache_dirs = {
+ "Flash (PyTorch SDPA)": os.environ.get('UVNOTE_FILE_FLASH_ATTENTION_BENCHMARK'),
+ "MemEff (PyTorch SDPA)": os.environ.get('UVNOTE_FILE_MEM_EFFICIENT_ATTENTION_BENCHMARK'),
+ "Flash Attn 2": os.environ.get('UVNOTE_FILE_FLASH_ATTN2_BENCHMARK'),
+ "xFormers": os.environ.get('UVNOTE_FILE_XFORMERS_BENCHMARK'),
+ "SageAttention": os.environ.get('UVNOTE_FILE_SAGE_ATTENTION_BENCHMARK'),
+ "Compiled (default)": os.environ.get('UVNOTE_FILE_COMPILED_VARIANTS_BENCHMARK_DEFAULT'),
+ "Compiled (max-autotune)": os.environ.get('UVNOTE_FILE_COMPILED_VARIANTS_BENCHMARK_MAX_AUTOTUNE'),
+ "HF Kernels Flash Attn": os.environ.get('UVNOTE_FILE_HF_KERNELS_FLASH_ATTN_BENCHMARK'),
+ "HF Kernels Flash Attn3": os.environ.get('UVNOTE_FILE_HF_KERNELS_FLASH_ATTN3_BENCHMARK'),
+}
+
+print("LOADING BENCHMARK DATA")
+for name, cache_dir in cache_dirs.items():
+ print(f"{name:30s}: {cache_dir}")
+print()
+
+file_mapping = {
+ "Flash (PyTorch SDPA)": "attn.jsonl",
+ "MemEff (PyTorch SDPA)": "attn.jsonl",
+ "Flash Attn 2": "attn.jsonl",
+ "xFormers": "attn.jsonl",
+ "SageAttention": "attn.jsonl",
+ "Compiled (default)": "attn_default.jsonl",
+ "Compiled (max-autotune)": "attn_max_autotune.jsonl",
+ "HF Kernels Flash Attn": "attn.jsonl",
+ "HF Kernels Flash Attn3": "attn.jsonl",
+}
+
+all_paths = []
+for name, cache_dir in cache_dirs.items():
+ if cache_dir:
+ path = Path(cache_dir) / file_mapping[name]
+ if path.exists() and path.stat().st_size > 0:
+ all_paths.append(str(path))
+ print(f"✓ Found {name}: {path}")
+ else:
+ print(f"⊘ Empty/Missing {name}: {path}")
+ else:
+ print(f"✗ No cache dir for {name}")
+print()
+
+if not all_paths:
+ print("ERROR: No benchmark data files found!")
+ # restore patched functions before exiting
+ plt.savefig = _orig_savefig
+ plt.close = _orig_close
+ sys.exit(1)
+
+# --- Summary + Visualization -----------------------------------------------------
+print("COMBINED BENCHMARK SUMMARY\n")
+kbt.summarize(all_paths)
+print("\nGENERATING COMBINED VISUALIZATION\n")
+
+try:
+ # If kbt.viz saves internally, our patched savefig ensures SVG gets written,
+ # and it will carry ids/classes for CSS styling.
+ kbt.viz(all_paths)
+ # Safety net: if kbt.viz didn't save, save now.
+ # if not Path("latency.svg").exists():
+ # _tag_current_figure()
+ # plt.savefig("latency.svg")
+
+ plt.savefig("latency.svg") # ensure saved with tagging
+
+ print("✓ SVG visualization ready: latency.svg!")
+except ImportError as e:
+ print(f"✗ Visualization requires matplotlib: {e}")
+except Exception as e:
+ print(f"✗ Visualization failed: {e}")
+finally:
+ # Clean up patches to avoid side effects in later cells
+ plt.savefig = _orig_savefig
+ plt.close = _orig_close
+
+print()
+print("ANALYSIS COMPLETE")
+print(f"Total implementations analyzed: {len(all_paths)}")
+print(f"\nImplementations included:")
+for name, cache_dir in cache_dirs.items():
+ if cache_dir:
+ path = Path(cache_dir) / file_mapping[name]
+ if path.exists() and path.stat().st_size > 0:
+ print(f" ✓ {name}")
+
+
+
+# Collect all benchmark data and export to CSV
+all_data = {}
+for name, cache_dir in cache_dirs.items():
+ if cache_dir:
+ path = Path(cache_dir) / file_mapping[name]
+ if path.exists() and path.stat().st_size > 0:
+ with open(path, 'r') as f:
+ records = [json.loads(line) for line in f]
+ all_data[name] = records
+
+# Export to CSV
+csv_path = Path("latency.csv")
+with open(csv_path, 'w', newline='') as csvfile:
+ writer = csv.writer(csvfile)
+
+ # Write header
+ header = ["Implementation", "Impl ID", "Workload", "Batch", "Seq Length", "Heads", "Head Dim", "Dtype",
+ "Mean (ms)", "P10 (ms)", "P50 (ms)", "P90 (ms)", "Reps",
+ # "Compile (ms)",
+ "Peak Mem (MB)", "Backend", "Family"]
+ writer.writerow(header)
+
+ # Write data rows
+ for impl_name, records in all_data.items():
+ for record in records:
+ wl = record.get('wl', {})
+ lat = record.get('lat_ms', {})
+ tags = record.get('tags', {})
+
+ row = [
+ impl_name,
+ record.get('impl', ''),
+ wl.get('name', ''),
+ wl.get('batch', ''),
+ wl.get('seq_len', ''),
+ wl.get('heads', ''),
+ wl.get('head_dim', ''),
+ wl.get('dtype', ''),
+ lat.get('mean', ''),
+ lat.get('p10', ''),
+ lat.get('p50', ''),
+ lat.get('p90', ''),
+ lat.get('reps', ''),
+ # record.get('compile_ms', ''),
+ round(record.get('peak_bytes', 0) / 1024 / 1024, 2) if record.get('peak_bytes') else '',
+ tags.get('backend', ''),
+ tags.get('family', ''),
+ ]
+ writer.writerow(row)
+
+print(f"✓ CSV export complete: {csv_path}")
+print(f"Total implementations: {len(all_data)}")
+print(f"Total records: {sum(len(records) for records in all_data.values())}")
diff --git a/flash_attn/results/cells/csv_export.py b/flash_attn/results/cells/csv_export.py
new file mode 100644
index 0000000000000000000000000000000000000000..403d9439ac641761a863de39cb66f18d4e0e924e
--- /dev/null
+++ b/flash_attn/results/cells/csv_export.py
@@ -0,0 +1,76 @@
+# /// script
+# requires-python = ">=3.10"
+# dependencies = [
+# "numpy",
+# "torch",
+# "kernels-benchmark-tools",
+# ]
+#
+# [tool.uv.sources]
+# kernels-benchmark-tools = { git = "https://github.com/drbh/kernels-benchmark-tools.git", branch = "main" }
+# ///
+import os
+import csv
+from pathlib import Path
+import json
+
+# --- Locate benchmark artifacts --------------------------------------------------
+cache_dirs = {
+ "Flash (PyTorch SDPA)": os.environ.get('UVNOTE_FILE_FLASH_ATTENTION_BENCHMARK'),
+ "MemEff (PyTorch SDPA)": os.environ.get('UVNOTE_FILE_MEM_EFFICIENT_ATTENTION_BENCHMARK'),
+ "Flash Attn 2": os.environ.get('UVNOTE_FILE_FLASH_ATTN2_BENCHMARK'),
+ "xFormers": os.environ.get('UVNOTE_FILE_XFORMERS_BENCHMARK'),
+ "SageAttention": os.environ.get('UVNOTE_FILE_SAGE_ATTENTION_BENCHMARK'),
+ "Compiled (default)": os.environ.get('UVNOTE_FILE_COMPILED_VARIANTS_BENCHMARK_DEFAULT'),
+ "Compiled (max-autotune)": os.environ.get('UVNOTE_FILE_COMPILED_VARIANTS_BENCHMARK_MAX_AUTOTUNE'),
+ "HF Kernels Flash Attn": os.environ.get('UVNOTE_FILE_HF_KERNELS_FLASH_ATTN_BENCHMARK'),
+ "HF Kernels Flash Attn3": os.environ.get('UVNOTE_FILE_HF_KERNELS_FLASH_ATTN3_BENCHMARK'),
+}
+
+file_mapping = {
+ "Flash (PyTorch SDPA)": "attn.jsonl",
+ "MemEff (PyTorch SDPA)": "attn.jsonl",
+ "Flash Attn 2": "attn.jsonl",
+ "xFormers": "attn.jsonl",
+ "SageAttention": "attn.jsonl",
+ "Compiled (default)": "attn_default.jsonl",
+ "Compiled (max-autotune)": "attn_max_autotune.jsonl",
+ "HF Kernels Flash Attn": "attn.jsonl",
+ "HF Kernels Flash Attn3": "attn.jsonl",
+}
+
+# Collect all benchmark data
+all_data = {}
+for name, cache_dir in cache_dirs.items():
+ if cache_dir:
+ path = Path(cache_dir) / file_mapping[name]
+ if path.exists() and path.stat().st_size > 0:
+ with open(path, 'r') as f:
+ records = [json.loads(line) for line in f]
+ all_data[name] = records
+
+# Export to CSV
+csv_path = Path("latency.csv")
+with open(csv_path, 'w', newline='') as csvfile:
+ writer = csv.writer(csvfile)
+
+    # Write header
+    header = ["Implementation", "Sequence Length", "Mean (ms)", "P10 (ms)", "P90 (ms)", "P50 (ms)"]
+    writer.writerow(header)
+
+    # Write data rows.
+    # Each JSONL record nests the workload under 'wl' and the latency stats under 'lat_ms'
+    # (see combine.py above); the records carry percentiles and mean, not min/max/median.
+    for impl_name, records in all_data.items():
+        for record in records:
+            wl = record.get('wl', {})
+            lat = record.get('lat_ms', {})
+            row = [
+                impl_name,
+                wl.get('seq_len', ''),
+                lat.get('mean', ''),
+                lat.get('p10', ''),
+                lat.get('p90', ''),
+                lat.get('p50', ''),
+            ]
+            writer.writerow(row)
+
+print(f"✓ CSV export complete: {csv_path}")
+print(f"Total implementations: {len(all_data)}")
+print(f"Total records: {sum(len(records) for records in all_data.values())}")
\ No newline at end of file
diff --git a/flash_attn/results/combined_results.html b/flash_attn/results/combined_results.html
new file mode 100644
index 0000000000000000000000000000000000000000..6af794548eaa21127d41882494b0c1ec263c9bd6
--- /dev/null
+++ b/flash_attn/results/combined_results.html
@@ -0,0 +1,7372 @@
+
+
+
+
+
+ Flash Attention Benchmark - Combined Results
+
+
+
+
+
+
+
+
+
+
+
+
+ Linux x86_64 | Linux-5.15.0-1084-aws-x86_64-with-glibc2.31
+
+
+
+
+
Flash Attention Benchmarks - Aggregated Results
+
This document combines benchmark results from multiple attention implementations
+using cross-file dependencies.
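+As a rough sketch of that mechanism (using the same environment variable names as the code below), each upstream benchmark cell publishes its cache directory through a UVNOTE_FILE_* environment variable, and this page reads the per-implementation JSONL from that directory:
+
+import os, json
+from pathlib import Path
+
+cache_dir = os.environ.get("UVNOTE_FILE_XFORMERS_BENCHMARK")  # set by uvnote for the xFormers cell
+if cache_dir:
+    records = [json.loads(line) for line in open(Path(cache_dir) / "attn.jsonl")]
+    print(f"xFormers: {len(records)} benchmark records")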
+
Combined Summary and Visualization
+
+
+
+
+
+
+
+[Figure: "Attention Implementation Latency". Latency P50 (ms) vs. workload (flux_L128 through flux_L512) for torch_flash_ma, torch_mem_eff, xformers_meff, torch_flash_compiled_default, torch_flash_compiled_max_autotune, hf_kernels_flash_attn, and hf_kernels_flash_attn3. Inline SVG rendered with Matplotlib v3.10.7; also exported as latency.svg.]
+
+[Combined results table: identical rows to flash_attn/results/artifacts/combine/latency.csv above.]
+
+
+
+
+
+
+
+
+
+
+
+[Code cell: source listing of flash_attn/results/cells/combine.py, shown earlier in this diff.]
+
+
+
+
+
+
+
+
LOADING BENCHMARK DATA
+Flash (PyTorch SDPA) : /home/ubuntu/Projects/kernels-uvnotes/flash_attn/impls/.uvnote/cache/dfabb76be980c54bed5516e2c57aa0a8d1d29c88b8b1d32ce8f8eb1b96260e90
+MemEff (PyTorch SDPA) : /home/ubuntu/Projects/kernels-uvnotes/flash_attn/impls/.uvnote/cache/b55d0f8b9e1a42ce197ed67716743fb2bd3bd76a1c4eef86ae36338351e6458d
+Flash Attn 2 : None
+xFormers : /home/ubuntu/Projects/kernels-uvnotes/flash_attn/impls/.uvnote/cache/be950c5cf60d4cdaea76a6c38ab4fb7e6da509ca9833fced05c39f642981839a
+SageAttention : None
+Compiled (default) : /home/ubuntu/Projects/kernels-uvnotes/flash_attn/impls/.uvnote/cache/50695b32047eb68d75addc0415164aa59c26dd8b180acf56a0f2fc92ca88f9fe
+Compiled (max-autotune) : /home/ubuntu/Projects/kernels-uvnotes/flash_attn/impls/.uvnote/cache/e22033ba3f39f87c6b9bba3f26af59c09999dab741717146939479dd3e140834
+HF Kernels Flash Attn : /home/ubuntu/Projects/kernels-uvnotes/flash_attn/impls/.uvnote/cache/f5347da8e6f046c3e9a96cb7843f885425fe84602e9e6c10758773df819982c9
+HF Kernels Flash Attn3 : /home/ubuntu/Projects/kernels-uvnotes/flash_attn/impls/.uvnote/cache/482931c458dfb5bb4d51247d9e6348ae4d4932319d902903f6f858005c8f75f7
+
+✓ Found Flash (PyTorch SDPA): /home/ubuntu/Projects/kernels-uvnotes/flash_attn/impls/.uvnote/cache/dfabb76be980c54bed5516e2c57aa0a8d1d29c88b8b1d32ce8f8eb1b96260e90/attn.jsonl
+✓ Found MemEff (PyTorch SDPA): /home/ubuntu/Projects/kernels-uvnotes/flash_attn/impls/.uvnote/cache/b55d0f8b9e1a42ce197ed67716743fb2bd3bd76a1c4eef86ae36338351e6458d/attn.jsonl
+✗ No cache dir for Flash Attn 2
+✓ Found xFormers: /home/ubuntu/Projects/kernels-uvnotes/flash_attn/impls/.uvnote/cache/be950c5cf60d4cdaea76a6c38ab4fb7e6da509ca9833fced05c39f642981839a/attn.jsonl
+✗ No cache dir for SageAttention
+✓ Found Compiled (default): /home/ubuntu/Projects/kernels-uvnotes/flash_attn/impls/.uvnote/cache/50695b32047eb68d75addc0415164aa59c26dd8b180acf56a0f2fc92ca88f9fe/attn_default.jsonl
+✓ Found Compiled (max-autotune): /home/ubuntu/Projects/kernels-uvnotes/flash_attn/impls/.uvnote/cache/e22033ba3f39f87c6b9bba3f26af59c09999dab741717146939479dd3e140834/attn_max_autotune.jsonl
+✓ Found HF Kernels Flash Attn: /home/ubuntu/Projects/kernels-uvnotes/flash_attn/impls/.uvnote/cache/f5347da8e6f046c3e9a96cb7843f885425fe84602e9e6c10758773df819982c9/attn.jsonl
+✓ Found HF Kernels Flash Attn3: /home/ubuntu/Projects/kernels-uvnotes/flash_attn/impls/.uvnote/cache/482931c458dfb5bb4d51247d9e6348ae4d4932319d902903f6f858005c8f75f7/attn.jsonl
+
+COMBINED BENCHMARK SUMMARY
+
+impl wl p50(ms) ok
+hf_kernels_flash_attn flux_L128 0.25 True
+hf_kernels_flash_attn flux_L256 0.32 True
+hf_kernels_flash_attn flux_L320 0.34 True
+hf_kernels_flash_attn flux_L384 0.35 True
+hf_kernels_flash_attn flux_L448 0.38 True
+hf_kernels_flash_attn flux_L512 0.42 True
+hf_kernels_flash_attn3 flux_L128 0.28 True
+hf_kernels_flash_attn3 flux_L256 0.34 True
+hf_kernels_flash_attn3 flux_L320 0.36 True
+hf_kernels_flash_attn3 flux_L384 0.37 True
+hf_kernels_flash_attn3 flux_L448 0.40 True
+hf_kernels_flash_attn3 flux_L512 0.43 True
+torch_flash_compiled_default flux_L128 0.36 True
+torch_flash_compiled_default flux_L256 0.50 True
+torch_flash_compiled_default flux_L320 0.54 True
+torch_flash_compiled_default flux_L384 0.59 True
+torch_flash_compiled_default flux_L448 0.61 True
+torch_flash_compiled_default flux_L512 0.64 True
+torch_flash_compiled_max_autotune flux_L128 0.38 True
+torch_flash_compiled_max_autotune flux_L256 0.55 True
+torch_flash_compiled_max_autotune flux_L320 0.61 True
+torch_flash_compiled_max_autotune flux_L384 0.66 True
+torch_flash_compiled_max_autotune flux_L448 0.70 True
+torch_flash_compiled_max_autotune flux_L512 0.76 True
+torch_flash_ma flux_L128 0.41 True
+torch_flash_ma flux_L256 0.52 True
+torch_flash_ma flux_L320 0.55 True
+torch_flash_ma flux_L384 0.59 True
+torch_flash_ma flux_L448 0.64 True
+torch_flash_ma flux_L512 0.68 True
+torch_mem_eff flux_L128 0.48 True
+torch_mem_eff flux_L256 0.63 True
+torch_mem_eff flux_L320 0.70 True
+torch_mem_eff flux_L384 0.83 True
+torch_mem_eff flux_L448 0.95 True
+torch_mem_eff flux_L512 1.00 True
+xformers_meff flux_L128 0.35 True
+xformers_meff flux_L256 0.41 True
+xformers_meff flux_L320 0.43 True
+xformers_meff flux_L384 0.44 True
+xformers_meff flux_L448 0.48 True
+xformers_meff flux_L512 0.50 True
+
+GENERATING COMBINED VISUALIZATION
+
+Loaded 42 records
+✓ Combined visualization saved as latency.svg
+Saved latency.png
+✓ Combined visualization saved as latency.svg
+✓ SVG visualization ready: latency.svg!
+
+ANALYSIS COMPLETE
+Total implementations analyzed: 7
+
+Implementations included:
+ ✓ Flash (PyTorch SDPA)
+ ✓ MemEff (PyTorch SDPA)
+ ✓ xFormers
+ ✓ Compiled (default)
+ ✓ Compiled (max-autotune)
+ ✓ HF Kernels Flash Attn
+ ✓ HF Kernels Flash Attn3
+✓ CSV export complete: latency.csv
+Total implementations: 7
+Total records: 42
+
+
+
+
+ Updating https://github.com/drbh/kernels-benchmark-tools.git (main)
+Downloading nvidia-cufile-cu12 (1.1MiB)
+Downloading nvidia-cufft-cu12 (184.2MiB)
+Downloading nvidia-cublas-cu12 (566.8MiB)
+Downloading networkx (1.9MiB)
+Downloading sympy (6.0MiB)
+Downloading nvidia-cuda-nvrtc-cu12 (84.0MiB)
+Downloading nvidia-cusparselt-cu12 (273.9MiB)
+Downloading nvidia-cudnn-cu12 (674.0MiB)
+Downloading nvidia-nvjitlink-cu12 (37.4MiB)
+Downloading numpy (15.9MiB)
+Downloading nvidia-cusolver-cu12 (255.1MiB)
+Downloading nvidia-curand-cu12 (60.7MiB)
+Downloading setuptools (1.1MiB)
+Downloading matplotlib (8.3MiB)
+Downloading fonttools (4.7MiB)
+Downloading nvidia-cusparse-cu12 (274.9MiB)
+Downloading pillow (6.3MiB)
+Downloading nvidia-nccl-cu12 (307.4MiB)
+Downloading nvidia-cuda-cupti-cu12 (9.8MiB)
+Downloading kiwisolver (1.4MiB)
+Downloading triton (148.4MiB)
+Downloading torch (846.8MiB)
+ Updated https://github.com/drbh/kernels-benchmark-tools.git (f457279bca6573cd2fa54a74e67118f5e6b7a31c)
+ Building kernels-benchmark-tools @ git+https://github.com/drbh/kernels-benchmark-tools.git@f457279bca6573cd2fa54a74e67118f5e6b7a31c
+ Downloading nvidia-cufile-cu12
+ Downloading kiwisolver
+ Downloading setuptools
+ Downloading networkx
+ Downloading fonttools
+ Downloading pillow
+ Built kernels-benchmark-tools @ git+https://github.com/drbh/kernels-benchmark-tools.git@f457279bca6573cd2fa54a74e67118f5e6b7a31c
+ Downloading matplotlib
+ Downloading nvidia-cuda-cupti-cu12
+ Downloading numpy
+ Downloading sympy
+ Downloading nvidia-nvjitlink-cu12
+ Downloading nvidia-curand-cu12
+ Downloading nvidia-cuda-nvrtc-cu12
+ Downloading triton
+ Downloading nvidia-cufft-cu12
+ Downloading nvidia-cusolver-cu12
+ Downloading nvidia-cusparselt-cu12
+ Downloading nvidia-cusparse-cu12
+ Downloading nvidia-nccl-cu12
+ Downloading nvidia-cublas-cu12
+ Downloading nvidia-cudnn-cu12
+ Downloading torch
+Installed 37 packages in 254ms
+
+
+
+
Artifacts:
+
latency.csv
+
latency.svg
+
+
+
+Full per-workload results (latency values rounded to 4 decimal places):
+
+Implementation | Impl ID | Workload | Batch | Seq Len | Heads | Head Dim | Dtype | Mean (ms) | P10 (ms) | P50 (ms) | P90 (ms) | Reps | Peak Mem (MB) | Backend | Family
+Flash (PyTorch SDPA) | torch_flash_ma | flux_L128 | 1 | 1152 | 24 | 128 | bfloat16 | 0.4071 | 0.4054 | 0.4076 | 0.4076 | 5 | 83.38 | FLASH | torch-sdpa
+Flash (PyTorch SDPA) | torch_flash_ma | flux_L256 | 1 | 1280 | 24 | 128 | bfloat16 | 0.5235 | 0.5212 | 0.5233 | 0.5234 | 5 | 90.62 | FLASH | torch-sdpa
+Flash (PyTorch SDPA) | torch_flash_ma | flux_L320 | 1 | 1344 | 24 | 128 | bfloat16 | 0.5458 | 0.5419 | 0.5468 | 0.5469 | 5 | 95.06 | FLASH | torch-sdpa
+Flash (PyTorch SDPA) | torch_flash_ma | flux_L384 | 1 | 1408 | 24 | 128 | bfloat16 | 0.5892 | 0.5868 | 0.5888 | 0.5888 | 5 | 99.88 | FLASH | torch-sdpa
+Flash (PyTorch SDPA) | torch_flash_ma | flux_L448 | 1 | 1472 | 24 | 128 | bfloat16 | 0.6449 | 0.6431 | 0.6442 | 0.6450 | 5 | 103.81 | FLASH | torch-sdpa
+Flash (PyTorch SDPA) | torch_flash_ma | flux_L512 | 1 | 1536 | 24 | 128 | bfloat16 | 0.6823 | 0.6778 | 0.6810 | 0.6819 | 5 | 109.12 | FLASH | torch-sdpa
+MemEff (PyTorch SDPA) | torch_mem_eff | flux_L128 | 1 | 1152 | 24 | 128 | bfloat16 | 0.4837 | 0.4822 | 0.4833 | 0.4854 | 5 | 83.38 | EFFICIENT | torch-sdpa
+MemEff (PyTorch SDPA) | torch_mem_eff | flux_L256 | 1 | 1280 | 24 | 128 | bfloat16 | 0.6269 | 0.6246 | 0.6267 | 0.6287 | 5 | 90.62 | EFFICIENT | torch-sdpa
+MemEff (PyTorch SDPA) | torch_mem_eff | flux_L320 | 1 | 1344 | 24 | 128 | bfloat16 | 0.6998 | 0.6973 | 0.7004 | 0.7004 | 5 | 95.94 | EFFICIENT | torch-sdpa
+MemEff (PyTorch SDPA) | torch_mem_eff | flux_L384 | 1 | 1408 | 24 | 128 | bfloat16 | 0.8333 | 0.8284 | 0.8325 | 0.8376 | 5 | 100.00 | EFFICIENT | torch-sdpa
+MemEff (PyTorch SDPA) | torch_mem_eff | flux_L448 | 1 | 1472 | 24 | 128 | bfloat16 | 0.9533 | 0.9503 | 0.9513 | 0.9572 | 5 | 103.81 | EFFICIENT | torch-sdpa
+MemEff (PyTorch SDPA) | torch_mem_eff | flux_L512 | 1 | 1536 | 24 | 128 | bfloat16 | 1.0066 | 1.0025 | 1.0045 | 1.0098 | 5 | 109.12 | EFFICIENT | torch-sdpa
+xFormers | xformers_meff | flux_L128 | 1 | 1152 | 24 | 128 | bfloat16 | 0.3453 | 0.3389 | 0.3461 | 0.3461 | 5 | 83.38 | memory_efficient | xformers
+xFormers | xformers_meff | flux_L256 | 1 | 1280 | 24 | 128 | bfloat16 | 0.4123 | 0.4096 | 0.4128 | 0.4129 | 5 | 90.62 | memory_efficient | xformers
+xFormers | xformers_meff | flux_L320 | 1 | 1344 | 24 | 128 | bfloat16 | 0.4366 | 0.4310 | 0.4332 | 0.4362 | 5 | 95.06 | memory_efficient | xformers
+xFormers | xformers_meff | flux_L384 | 1 | 1408 | 24 | 128 | bfloat16 | 0.4451 | 0.4360 | 0.4436 | 0.4475 | 5 | 99.88 | memory_efficient | xformers
+xFormers | xformers_meff | flux_L448 | 1 | 1472 | 24 | 128 | bfloat16 | 0.4750 | 0.4711 | 0.4751 | 0.4763 | 5 | 103.81 | memory_efficient | xformers
+xFormers | xformers_meff | flux_L512 | 1 | 1536 | 24 | 128 | bfloat16 | 0.5009 | 0.4966 | 0.4997 | 0.5038 | 5 | 109.12 | memory_efficient | xformers
+Compiled (default) | torch_flash_compiled_default | flux_L128 | 1 | 1152 | 24 | 128 | bfloat16 | 0.3856 | 0.3564 | 0.3594 | 0.3625 | 5 | 83.38 | FLASH | torch-sdpa
+Compiled (default) | torch_flash_compiled_default | flux_L256 | 1 | 1280 | 24 | 128 | bfloat16 | 0.4983 | 0.4926 | 0.4966 | 0.5018 | 5 | 90.62 | FLASH | torch-sdpa
+Compiled (default) | torch_flash_compiled_default | flux_L320 | 1 | 1344 | 24 | 128 | bfloat16 | 0.5370 | 0.5335 | 0.5366 | 0.5386 | 5 | 95.25 | FLASH | torch-sdpa
+Compiled (default) | torch_flash_compiled_default | flux_L384 | 1 | 1408 | 24 | 128 | bfloat16 | 0.5841 | 0.5775 | 0.5869 | 0.5878 | 5 | 99.88 | FLASH | torch-sdpa
+Compiled (default) | torch_flash_compiled_default | flux_L448 | 1 | 1472 | 24 | 128 | bfloat16 | 0.6185 | 0.6072 | 0.6113 | 0.6144 | 5 | 103.81 | FLASH | torch-sdpa
+Compiled (default) | torch_flash_compiled_default | flux_L512 | 1 | 1536 | 24 | 128 | bfloat16 | 0.6429 | 0.6400 | 0.6431 | 0.6431 | 5 | 109.12 | FLASH | torch-sdpa
+Compiled (max-autotune) | torch_flash_compiled_max_autotune | flux_L128 | 1 | 1152 | 24 | 128 | bfloat16 | 0.4002 | 0.3666 | 0.3768 | 0.4117 | 5 | 81.75 | FLASH | torch-sdpa
+Compiled (max-autotune) | torch_flash_compiled_max_autotune | flux_L256 | 1 | 1280 | 24 | 128 | bfloat16 | 0.5536 | 0.5161 | 0.5490 | 0.5631 | 5 | 92.88 | FLASH | torch-sdpa
+Compiled (max-autotune) | torch_flash_compiled_max_autotune | flux_L320 | 1 | 1344 | 24 | 128 | bfloat16 | 0.6144 | 0.5622 | 0.6144 | 0.6318 | 5 | 95.13 | FLASH | torch-sdpa
+Compiled (max-autotune) | torch_flash_compiled_max_autotune | flux_L384 | 1 | 1408 | 24 | 128 | bfloat16 | 0.6754 | 0.6513 | 0.6584 | 0.6799 | 5 | 97.13 | FLASH | torch-sdpa
+Compiled (max-autotune) | torch_flash_compiled_max_autotune | flux_L448 | 1 | 1472 | 24 | 128 | bfloat16 | 0.7211 | 0.6973 | 0.7014 | 0.7229 | 5 | 99.00 | FLASH | torch-sdpa
+Compiled (max-autotune) | torch_flash_compiled_max_autotune | flux_L512 | 1 | 1536 | 24 | 128 | bfloat16 | 0.7735 | 0.7485 | 0.7557 | 0.7711 | 5 | 101.63 | FLASH | torch-sdpa
+HF Kernels Flash Attn | hf_kernels_flash_attn | flux_L128 | 1 | 1152 | 24 | 128 | bfloat16 | 0.2457 | 0.2437 | 0.2457 | 0.2458 | 5 | 83.38 | flash-attn | hf-kernels
+HF Kernels Flash Attn | hf_kernels_flash_attn | flux_L256 | 1 | 1280 | 24 | 128 | bfloat16 | 0.3216 | 0.3164 | 0.3195 | 0.3205 | 5 | 90.62 | flash-attn | hf-kernels
+HF Kernels Flash Attn | hf_kernels_flash_attn | flux_L320 | 1 | 1344 | 24 | 128 | bfloat16 | 0.3385 | 0.3367 | 0.3379 | 0.3398 | 5 | 95.06 | flash-attn | hf-kernels
+HF Kernels Flash Attn | hf_kernels_flash_attn | flux_L384 | 1 | 1408 | 24 | 128 | bfloat16 | 0.3510 | 0.3482 | 0.3492 | 0.3523 | 5 | 99.88 | flash-attn | hf-kernels
+HF Kernels Flash Attn | hf_kernels_flash_attn | flux_L448 | 1 | 1472 | 24 | 128 | bfloat16 | 0.3830 | 0.3810 | 0.3830 | 0.3840 | 5 | 103.81 | flash-attn | hf-kernels
+HF Kernels Flash Attn | hf_kernels_flash_attn | flux_L512 | 1 | 1536 | 24 | 128 | bfloat16 | 0.4259 | 0.4228 | 0.4250 | 0.4260 | 5 | 109.12 | flash-attn | hf-kernels
+HF Kernels Flash Attn3 | hf_kernels_flash_attn3 | flux_L128 | 1 | 1152 | 24 | 128 | bfloat16 | 0.2755 | 0.2674 | 0.2756 | 0.2796 | 5 | 83.38 | flash-attn3 | hf-kernels
+HF Kernels Flash Attn3 | hf_kernels_flash_attn3 | flux_L256 | 1 | 1280 | 24 | 128 | bfloat16 | 0.3397 | 0.3368 | 0.3400 | 0.3419 | 5 | 90.62 | flash-attn3 | hf-kernels
+HF Kernels Flash Attn3 | hf_kernels_flash_attn3 | flux_L320 | 1 | 1344 | 24 | 128 | bfloat16 | 0.3602 | 0.3564 | 0.3604 | 0.3614 | 5 | 95.06 | flash-attn3 | hf-kernels
+HF Kernels Flash Attn3 | hf_kernels_flash_attn3 | flux_L384 | 1 | 1408 | 24 | 128 | bfloat16 | 0.3734 | 0.3718 | 0.3738 | 0.3747 | 5 | 99.88 | flash-attn3 | hf-kernels
+HF Kernels Flash Attn3 | hf_kernels_flash_attn3 | flux_L448 | 1 | 1472 | 24 | 128 | bfloat16 | 0.4024 | 0.3994 | 0.4015 | 0.4035 | 5 | 103.81 | flash-attn3 | hf-kernels
+HF Kernels Flash Attn3 | hf_kernels_flash_attn3 | flux_L512 | 1 | 1536 | 24 | 128 | bfloat16 | 0.4305 | 0.4270 | 0.4292 | 0.4332 | 5 | 109.12 | flash-attn3 | hf-kernels
+
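+One convenient way to read the table above is as per-workload speedups over the torch_flash_ma baseline, computed from the exported latency.csv. A minimal sketch, assuming the CSV uses the same column names as the table ("Impl ID", "Workload", "P50 (ms)"):
+
+import csv
+from collections import defaultdict
+
+p50 = defaultdict(dict)  # impl id -> workload -> p50 latency (ms)
+with open("latency.csv") as f:
+    for row in csv.DictReader(f):
+        p50[row["Impl ID"]][row["Workload"]] = float(row["P50 (ms)"])
+
+baseline = p50["torch_flash_ma"]  # PyTorch SDPA flash backend as the reference
+for impl, by_wl in sorted(p50.items()):
+    for wl in sorted(by_wl):
+        print(f"{impl:35s} {wl:10s} {baseline[wl] / by_wl[wl]:.2f}x vs torch_flash_ma")
+
+On these numbers, for example, hf_kernels_flash_attn comes out at roughly 1.6x the torch_flash_ma p50 at flux_L512 (0.6810 / 0.4250 ≈ 1.60).
+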
+
+[Figure latency.svg: "Attention Implementation Latency". Y-axis: Latency P50 (ms), roughly 0.3 to 1.0; X-axis: Workload (flux_L128 through flux_L512); one series per implementation: torch_flash_ma, torch_mem_eff, xformers_meff, torch_flash_compiled_default, torch_flash_compiled_max_autotune, hf_kernels_flash_attn, hf_kernels_flash_attn3. Rendered 2025-10-14T20:33:46 as image/svg+xml with Matplotlib v3.10.7.]
+
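+For reference, a minimal matplotlib sketch that reproduces the general shape of this figure (P50 latency per workload, one line per implementation) from latency.csv; it is an illustration under the same column-name assumption as above, not the plotting code in kernels-benchmark-tools:
+
+import csv
+from collections import defaultdict
+
+import matplotlib.pyplot as plt
+
+p50 = defaultdict(dict)  # impl id -> workload -> p50 latency (ms)
+with open("latency.csv") as f:
+    for row in csv.DictReader(f):
+        p50[row["Impl ID"]][row["Workload"]] = float(row["P50 (ms)"])
+
+workloads = sorted({wl for by_wl in p50.values() for wl in by_wl})
+fig, ax = plt.subplots(figsize=(10, 5))
+for impl, by_wl in sorted(p50.items()):
+    ax.plot(workloads, [by_wl[wl] for wl in workloads], marker="o", label=impl)
+
+ax.set_xlabel("Workload")
+ax.set_ylabel("Latency P50 (ms)")
+ax.set_title("Attention Implementation Latency")
+ax.legend(fontsize="small")
+fig.tight_layout()
+fig.savefig("latency.svg")
+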
\ No newline at end of file
diff --git a/flash_attn/results/index.html b/flash_attn/results/index.html
new file mode 100644
index 0000000000000000000000000000000000000000..b87b6002f4b781572dbb50f91850e50ee98130ab
--- /dev/null
+++ b/flash_attn/results/index.html
@@ -0,0 +1,88 @@
+[Directory-listing page titled "Index of /flash_attn/results"; HTML markup stripped during extraction.]
\ No newline at end of file
diff --git a/index.html b/index.html
new file mode 100644
index 0000000000000000000000000000000000000000..df44040e2dd9e1e4a0fc2d5ee08453d4b9953f11
--- /dev/null
+++ b/index.html
@@ -0,0 +1,85 @@
+[Directory-listing page titled "Index of /"; HTML markup stripped during extraction.]
\ No newline at end of file