diff --git a/flash_attn/impls/artifacts/benchmark/attn.jsonl b/flash_attn/impls/artifacts/benchmark/attn.jsonl new file mode 100644 index 0000000000000000000000000000000000000000..f1695213932feff6321b80cc6dac9317e83037c9 --- /dev/null +++ b/flash_attn/impls/artifacts/benchmark/attn.jsonl @@ -0,0 +1,6 @@ +{"ts": "2025-10-02T16:08:21Z", "run": "4862bb56aac04f66908d2a97924104e2", "impl": "xformers_meff", "tags": {"family": "xformers", "backend": "memory_efficient", "compile": "none"}, "wl": {"name": "flux_L128", "batch": 1, "seq_len": 1152, "heads": 24, "head_dim": 128, "dtype": "bfloat16", "device": "cuda", "seed": 0}, "env": {"torch": "2.8.0+cu128", "cuda": "12.8", "gpu": "NVIDIA L4", "sm": "8.9", "py": "3.12.7", "plat": "Linux-5.15.0-1084-aws-x86_64-with-glibc2.31"}, "lat_ms": {"p10": 0.3389439880847931, "p50": 0.3461120128631592, "p90": 0.3461120128631592, "mean": 0.3452928066253662, "reps": 5, "warmup": 2}, "compile_ms": 0.9463679790496826, "peak_bytes": 87425024, "ok": true, "absmax": 0.0625, "corr": {"ok": true, "rtol": 0.02, "atol": 0.02, "absmax": 0.0625, "mae": 0.000362396240234375, "mse": 2.9206275939941406e-06, "ref": "sdpa_math_fp32"}, "err": null} +{"ts": "2025-10-02T16:08:21Z", "run": "4862bb56aac04f66908d2a97924104e2", "impl": "xformers_meff", "tags": {"family": "xformers", "backend": "memory_efficient", "compile": "none"}, "wl": {"name": "flux_L256", "batch": 1, "seq_len": 1280, "heads": 24, "head_dim": 128, "dtype": "bfloat16", "device": "cuda", "seed": 0}, "env": {"torch": "2.8.0+cu128", "cuda": "12.8", "gpu": "NVIDIA L4", "sm": "8.9", "py": "3.12.7", "plat": "Linux-5.15.0-1084-aws-x86_64-with-glibc2.31"}, "lat_ms": {"p10": 0.40959998965263367, "p50": 0.41280001401901245, "p90": 0.41286399960517883, "mean": 0.41234560012817384, "reps": 5, "warmup": 2}, "compile_ms": 0.34329599142074585, "peak_bytes": 95027200, "ok": true, "absmax": 0.0625, "corr": {"ok": true, "rtol": 0.02, "atol": 0.02, "absmax": 0.0625, "mae": 0.00035858154296875, "mse": 2.8908252716064453e-06, "ref": "sdpa_math_fp32"}, "err": null} +{"ts": "2025-10-02T16:08:21Z", "run": "4862bb56aac04f66908d2a97924104e2", "impl": "xformers_meff", "tags": {"family": "xformers", "backend": "memory_efficient", "compile": "none"}, "wl": {"name": "flux_L320", "batch": 1, "seq_len": 1344, "heads": 24, "head_dim": 128, "dtype": "bfloat16", "device": "cuda", "seed": 0}, "env": {"torch": "2.8.0+cu128", "cuda": "12.8", "gpu": "NVIDIA L4", "sm": "8.9", "py": "3.12.7", "plat": "Linux-5.15.0-1084-aws-x86_64-with-glibc2.31"}, "lat_ms": {"p10": 0.4310399889945984, "p50": 0.4331519901752472, "p90": 0.4362240135669708, "mean": 0.4366208016872406, "reps": 5, "warmup": 2}, "compile_ms": 0.35942399501800537, "peak_bytes": 99680256, "ok": true, "absmax": 0.0625, "corr": {"ok": true, "rtol": 0.02, "atol": 0.02, "absmax": 0.0625, "mae": 0.0003566741943359375, "mse": 2.8759241104125977e-06, "ref": "sdpa_math_fp32"}, "err": null} +{"ts": "2025-10-02T16:08:21Z", "run": "4862bb56aac04f66908d2a97924104e2", "impl": "xformers_meff", "tags": {"family": "xformers", "backend": "memory_efficient", "compile": "none"}, "wl": {"name": "flux_L384", "batch": 1, "seq_len": 1408, "heads": 24, "head_dim": 128, "dtype": "bfloat16", "device": "cuda", "seed": 0}, "env": {"torch": "2.8.0+cu128", "cuda": "12.8", "gpu": "NVIDIA L4", "sm": "8.9", "py": "3.12.7", "plat": "Linux-5.15.0-1084-aws-x86_64-with-glibc2.31"}, "lat_ms": {"p10": 0.4359680116176605, "p50": 0.44361600279808044, "p90": 0.447488009929657, "mean": 0.4450624048709869, "reps": 5, "warmup": 2}, 
"compile_ms": 0.3678080141544342, "peak_bytes": 104726528, "ok": true, "absmax": 0.0625, "corr": {"ok": true, "rtol": 0.02, "atol": 0.02, "absmax": 0.0625, "mae": 0.0003604888916015625, "mse": 2.8759241104125977e-06, "ref": "sdpa_math_fp32"}, "err": null} +{"ts": "2025-10-02T16:08:21Z", "run": "4862bb56aac04f66908d2a97924104e2", "impl": "xformers_meff", "tags": {"family": "xformers", "backend": "memory_efficient", "compile": "none"}, "wl": {"name": "flux_L448", "batch": 1, "seq_len": 1472, "heads": 24, "head_dim": 128, "dtype": "bfloat16", "device": "cuda", "seed": 0}, "env": {"torch": "2.8.0+cu128", "cuda": "12.8", "gpu": "NVIDIA L4", "sm": "8.9", "py": "3.12.7", "plat": "Linux-5.15.0-1084-aws-x86_64-with-glibc2.31"}, "lat_ms": {"p10": 0.4711039960384369, "p50": 0.47513601183891296, "p90": 0.4763199985027313, "mean": 0.4750400006771088, "reps": 5, "warmup": 2}, "compile_ms": 0.40857601165771484, "peak_bytes": 108855296, "ok": true, "absmax": 0.0625, "corr": {"ok": true, "rtol": 0.02, "atol": 0.02, "absmax": 0.0625, "mae": 0.00035858154296875, "mse": 2.86102294921875e-06, "ref": "sdpa_math_fp32"}, "err": null} +{"ts": "2025-10-02T16:08:21Z", "run": "4862bb56aac04f66908d2a97924104e2", "impl": "xformers_meff", "tags": {"family": "xformers", "backend": "memory_efficient", "compile": "none"}, "wl": {"name": "flux_L512", "batch": 1, "seq_len": 1536, "heads": 24, "head_dim": 128, "dtype": "bfloat16", "device": "cuda", "seed": 0}, "env": {"torch": "2.8.0+cu128", "cuda": "12.8", "gpu": "NVIDIA L4", "sm": "8.9", "py": "3.12.7", "plat": "Linux-5.15.0-1084-aws-x86_64-with-glibc2.31"}, "lat_ms": {"p10": 0.49663999676704407, "p50": 0.4997119903564453, "p90": 0.5038080215454102, "mean": 0.5009407997131348, "reps": 5, "warmup": 2}, "compile_ms": 0.43724799156188965, "peak_bytes": 114425856, "ok": true, "absmax": 0.0625, "corr": {"ok": true, "rtol": 0.02, "atol": 0.02, "absmax": 0.0625, "mae": 0.0003566741943359375, "mse": 2.8908252716064453e-06, "ref": "sdpa_math_fp32"}, "err": null} diff --git a/flash_attn/impls/artifacts/benchmark_default/attn_default.jsonl b/flash_attn/impls/artifacts/benchmark_default/attn_default.jsonl new file mode 100644 index 0000000000000000000000000000000000000000..5ba2e631889b69237014219e513ecfcbae800a20 --- /dev/null +++ b/flash_attn/impls/artifacts/benchmark_default/attn_default.jsonl @@ -0,0 +1,6 @@ +{"ts": "2025-10-02T16:11:55Z", "run": "e9857dd2d39d4b40a6c91c0fdad82b00", "impl": "torch_flash_compiled_default", "tags": {"family": "torch-sdpa", "backend": "FLASH", "compile": "default"}, "wl": {"name": "flux_L128", "batch": 1, "seq_len": 1152, "heads": 24, "head_dim": 128, "dtype": "bfloat16", "device": "cuda", "seed": 0}, "env": {"torch": "2.8.0+cu128", "cuda": "12.8", "gpu": "NVIDIA L4", "sm": "8.9", "py": "3.12.7", "plat": "Linux-5.15.0-1084-aws-x86_64-with-glibc2.31"}, "lat_ms": {"p10": 0.3563520014286041, "p50": 0.35942399501800537, "p90": 0.3624959886074066, "mean": 0.3856383919715881, "reps": 5, "warmup": 2}, "compile_ms": 2383.33544921875, "peak_bytes": 87425024, "ok": true, "absmax": 0.0625, "corr": {"ok": true, "rtol": 0.02, "atol": 0.02, "absmax": 0.0625, "mae": 0.0003452301025390625, "mse": 2.771615982055664e-06, "ref": "sdpa_math_fp32"}, "err": null} +{"ts": "2025-10-02T16:11:55Z", "run": "e9857dd2d39d4b40a6c91c0fdad82b00", "impl": "torch_flash_compiled_default", "tags": {"family": "torch-sdpa", "backend": "FLASH", "compile": "default"}, "wl": {"name": "flux_L256", "batch": 1, "seq_len": 1280, "heads": 24, "head_dim": 128, "dtype": "bfloat16", "device": "cuda", 
"seed": 0}, "env": {"torch": "2.8.0+cu128", "cuda": "12.8", "gpu": "NVIDIA L4", "sm": "8.9", "py": "3.12.7", "plat": "Linux-5.15.0-1084-aws-x86_64-with-glibc2.31"}, "lat_ms": {"p10": 0.4926080107688904, "p50": 0.49663999676704407, "p90": 0.5017600059509277, "mean": 0.4982912003993988, "reps": 5, "warmup": 2}, "compile_ms": 76.60860443115234, "peak_bytes": 95027200, "ok": true, "absmax": 0.0625, "corr": {"ok": true, "rtol": 0.02, "atol": 0.02, "absmax": 0.0625, "mae": 0.0003414154052734375, "mse": 2.726912498474121e-06, "ref": "sdpa_math_fp32"}, "err": null} +{"ts": "2025-10-02T16:11:55Z", "run": "e9857dd2d39d4b40a6c91c0fdad82b00", "impl": "torch_flash_compiled_default", "tags": {"family": "torch-sdpa", "backend": "FLASH", "compile": "default"}, "wl": {"name": "flux_L320", "batch": 1, "seq_len": 1344, "heads": 24, "head_dim": 128, "dtype": "bfloat16", "device": "cuda", "seed": 0}, "env": {"torch": "2.8.0+cu128", "cuda": "12.8", "gpu": "NVIDIA L4", "sm": "8.9", "py": "3.12.7", "plat": "Linux-5.15.0-1084-aws-x86_64-with-glibc2.31"}, "lat_ms": {"p10": 0.5335040092468262, "p50": 0.5366079807281494, "p90": 0.5386239886283875, "mean": 0.5369919896125793, "reps": 5, "warmup": 2}, "compile_ms": 74.49088287353516, "peak_bytes": 99876864, "ok": true, "absmax": 0.0625, "corr": {"ok": true, "rtol": 0.02, "atol": 0.02, "absmax": 0.0625, "mae": 0.0003414154052734375, "mse": 2.7418136596679688e-06, "ref": "sdpa_math_fp32"}, "err": null} +{"ts": "2025-10-02T16:11:55Z", "run": "e9857dd2d39d4b40a6c91c0fdad82b00", "impl": "torch_flash_compiled_default", "tags": {"family": "torch-sdpa", "backend": "FLASH", "compile": "default"}, "wl": {"name": "flux_L384", "batch": 1, "seq_len": 1408, "heads": 24, "head_dim": 128, "dtype": "bfloat16", "device": "cuda", "seed": 0}, "env": {"torch": "2.8.0+cu128", "cuda": "12.8", "gpu": "NVIDIA L4", "sm": "8.9", "py": "3.12.7", "plat": "Linux-5.15.0-1084-aws-x86_64-with-glibc2.31"}, "lat_ms": {"p10": 0.5775359869003296, "p50": 0.5868800282478333, "p90": 0.5877760052680969, "mean": 0.5841408014297486, "reps": 5, "warmup": 2}, "compile_ms": 72.97433471679688, "peak_bytes": 104726528, "ok": true, "absmax": 0.0625, "corr": {"ok": true, "rtol": 0.02, "atol": 0.02, "absmax": 0.0625, "mae": 0.0003452301025390625, "mse": 2.7418136596679688e-06, "ref": "sdpa_math_fp32"}, "err": null} +{"ts": "2025-10-02T16:11:56Z", "run": "e9857dd2d39d4b40a6c91c0fdad82b00", "impl": "torch_flash_compiled_default", "tags": {"family": "torch-sdpa", "backend": "FLASH", "compile": "default"}, "wl": {"name": "flux_L448", "batch": 1, "seq_len": 1472, "heads": 24, "head_dim": 128, "dtype": "bfloat16", "device": "cuda", "seed": 0}, "env": {"torch": "2.8.0+cu128", "cuda": "12.8", "gpu": "NVIDIA L4", "sm": "8.9", "py": "3.12.7", "plat": "Linux-5.15.0-1084-aws-x86_64-with-glibc2.31"}, "lat_ms": {"p10": 0.6072319746017456, "p50": 0.6113280057907104, "p90": 0.6144000291824341, "mean": 0.6184704065322876, "reps": 5, "warmup": 2}, "compile_ms": 215.12498474121094, "peak_bytes": 108855296, "ok": true, "absmax": 0.0625, "corr": {"ok": true, "rtol": 0.02, "atol": 0.02, "absmax": 0.0625, "mae": 0.00034332275390625, "mse": 2.7418136596679688e-06, "ref": "sdpa_math_fp32"}, "err": null} +{"ts": "2025-10-02T16:11:56Z", "run": "e9857dd2d39d4b40a6c91c0fdad82b00", "impl": "torch_flash_compiled_default", "tags": {"family": "torch-sdpa", "backend": "FLASH", "compile": "default"}, "wl": {"name": "flux_L512", "batch": 1, "seq_len": 1536, "heads": 24, "head_dim": 128, "dtype": "bfloat16", "device": "cuda", "seed": 0}, "env": {"torch": 
"2.8.0+cu128", "cuda": "12.8", "gpu": "NVIDIA L4", "sm": "8.9", "py": "3.12.7", "plat": "Linux-5.15.0-1084-aws-x86_64-with-glibc2.31"}, "lat_ms": {"p10": 0.6399999856948853, "p50": 0.6430720090866089, "p90": 0.6430720090866089, "mean": 0.6428672075271606, "reps": 5, "warmup": 2}, "compile_ms": 71.8028793334961, "peak_bytes": 114425856, "ok": true, "absmax": 0.0625, "corr": {"ok": true, "rtol": 0.02, "atol": 0.02, "absmax": 0.0625, "mae": 0.00034332275390625, "mse": 2.771615982055664e-06, "ref": "sdpa_math_fp32"}, "err": null} diff --git a/flash_attn/impls/artifacts/benchmark_max_autotune/attn_max_autotune.jsonl b/flash_attn/impls/artifacts/benchmark_max_autotune/attn_max_autotune.jsonl new file mode 100644 index 0000000000000000000000000000000000000000..8a1e8bb5317f89c8a6cab3ccb82bd9e167c6db90 --- /dev/null +++ b/flash_attn/impls/artifacts/benchmark_max_autotune/attn_max_autotune.jsonl @@ -0,0 +1,6 @@ +{"ts": "2025-10-02T16:11:08Z", "run": "06f2face3c924e1b89a35a0fb568d4b1", "impl": "torch_flash_compiled_max_autotune", "tags": {"family": "torch-sdpa", "backend": "FLASH", "compile": "max-autotune"}, "wl": {"name": "flux_L128", "batch": 1, "seq_len": 1152, "heads": 24, "head_dim": 128, "dtype": "bfloat16", "device": "cuda", "seed": 0}, "env": {"torch": "2.8.0+cu128", "cuda": "12.8", "gpu": "NVIDIA L4", "sm": "8.9", "py": "3.12.7", "plat": "Linux-5.15.0-1084-aws-x86_64-with-glibc2.31"}, "lat_ms": {"p10": 0.3665919899940491, "p50": 0.3768320083618164, "p90": 0.41171199083328247, "mean": 0.40020479559898375, "reps": 5, "warmup": 2}, "compile_ms": 2910.97705078125, "peak_bytes": 85722112, "ok": true, "absmax": 0.0625, "corr": {"ok": true, "rtol": 0.02, "atol": 0.02, "absmax": 0.0625, "mae": 0.0003452301025390625, "mse": 2.771615982055664e-06, "ref": "sdpa_math_fp32"}, "err": null} +{"ts": "2025-10-02T16:11:08Z", "run": "06f2face3c924e1b89a35a0fb568d4b1", "impl": "torch_flash_compiled_max_autotune", "tags": {"family": "torch-sdpa", "backend": "FLASH", "compile": "max-autotune"}, "wl": {"name": "flux_L256", "batch": 1, "seq_len": 1280, "heads": 24, "head_dim": 128, "dtype": "bfloat16", "device": "cuda", "seed": 0}, "env": {"torch": "2.8.0+cu128", "cuda": "12.8", "gpu": "NVIDIA L4", "sm": "8.9", "py": "3.12.7", "plat": "Linux-5.15.0-1084-aws-x86_64-with-glibc2.31"}, "lat_ms": {"p10": 0.5160959959030151, "p50": 0.5489599704742432, "p90": 0.5631359815597534, "mean": 0.5535807967185974, "reps": 5, "warmup": 2}, "compile_ms": 85.84806060791016, "peak_bytes": 97387520, "ok": true, "absmax": 0.0625, "corr": {"ok": true, "rtol": 0.02, "atol": 0.02, "absmax": 0.0625, "mae": 0.0003414154052734375, "mse": 2.726912498474121e-06, "ref": "sdpa_math_fp32"}, "err": null} +{"ts": "2025-10-02T16:11:09Z", "run": "06f2face3c924e1b89a35a0fb568d4b1", "impl": "torch_flash_compiled_max_autotune", "tags": {"family": "torch-sdpa", "backend": "FLASH", "compile": "max-autotune"}, "wl": {"name": "flux_L320", "batch": 1, "seq_len": 1344, "heads": 24, "head_dim": 128, "dtype": "bfloat16", "device": "cuda", "seed": 0}, "env": {"torch": "2.8.0+cu128", "cuda": "12.8", "gpu": "NVIDIA L4", "sm": "8.9", "py": "3.12.7", "plat": "Linux-5.15.0-1084-aws-x86_64-with-glibc2.31"}, "lat_ms": {"p10": 0.562175989151001, "p50": 0.6144000291824341, "p90": 0.6318079829216003, "mean": 0.6143999934196472, "reps": 5, "warmup": 2}, "compile_ms": 82.77401733398438, "peak_bytes": 99746816, "ok": true, "absmax": 0.0625, "corr": {"ok": true, "rtol": 0.02, "atol": 0.02, "absmax": 0.0625, "mae": 0.0003414154052734375, "mse": 2.7418136596679688e-06, "ref": 
"sdpa_math_fp32"}, "err": null} +{"ts": "2025-10-02T16:11:09Z", "run": "06f2face3c924e1b89a35a0fb568d4b1", "impl": "torch_flash_compiled_max_autotune", "tags": {"family": "torch-sdpa", "backend": "FLASH", "compile": "max-autotune"}, "wl": {"name": "flux_L384", "batch": 1, "seq_len": 1408, "heads": 24, "head_dim": 128, "dtype": "bfloat16", "device": "cuda", "seed": 0}, "env": {"torch": "2.8.0+cu128", "cuda": "12.8", "gpu": "NVIDIA L4", "sm": "8.9", "py": "3.12.7", "plat": "Linux-5.15.0-1084-aws-x86_64-with-glibc2.31"}, "lat_ms": {"p10": 0.6512640118598938, "p50": 0.6584320068359375, "p90": 0.6799359917640686, "mean": 0.6754495978355408, "reps": 5, "warmup": 2}, "compile_ms": 81.94969940185547, "peak_bytes": 101843968, "ok": true, "absmax": 0.0625, "corr": {"ok": true, "rtol": 0.02, "atol": 0.02, "absmax": 0.0625, "mae": 0.0003452301025390625, "mse": 2.7418136596679688e-06, "ref": "sdpa_math_fp32"}, "err": null} +{"ts": "2025-10-02T16:11:09Z", "run": "06f2face3c924e1b89a35a0fb568d4b1", "impl": "torch_flash_compiled_max_autotune", "tags": {"family": "torch-sdpa", "backend": "FLASH", "compile": "max-autotune"}, "wl": {"name": "flux_L448", "batch": 1, "seq_len": 1472, "heads": 24, "head_dim": 128, "dtype": "bfloat16", "device": "cuda", "seed": 0}, "env": {"torch": "2.8.0+cu128", "cuda": "12.8", "gpu": "NVIDIA L4", "sm": "8.9", "py": "3.12.7", "plat": "Linux-5.15.0-1084-aws-x86_64-with-glibc2.31"}, "lat_ms": {"p10": 0.6973119974136353, "p50": 0.7014080286026001, "p90": 0.7229440212249756, "mean": 0.7210752129554748, "reps": 5, "warmup": 2}, "compile_ms": 81.1141128540039, "peak_bytes": 103810048, "ok": true, "absmax": 0.0625, "corr": {"ok": true, "rtol": 0.02, "atol": 0.02, "absmax": 0.0625, "mae": 0.00034332275390625, "mse": 2.7418136596679688e-06, "ref": "sdpa_math_fp32"}, "err": null} +{"ts": "2025-10-02T16:11:10Z", "run": "06f2face3c924e1b89a35a0fb568d4b1", "impl": "torch_flash_compiled_max_autotune", "tags": {"family": "torch-sdpa", "backend": "FLASH", "compile": "max-autotune"}, "wl": {"name": "flux_L512", "batch": 1, "seq_len": 1536, "heads": 24, "head_dim": 128, "dtype": "bfloat16", "device": "cuda", "seed": 0}, "env": {"torch": "2.8.0+cu128", "cuda": "12.8", "gpu": "NVIDIA L4", "sm": "8.9", "py": "3.12.7", "plat": "Linux-5.15.0-1084-aws-x86_64-with-glibc2.31"}, "lat_ms": {"p10": 0.7485439777374268, "p50": 0.7557439804077148, "p90": 0.7710719704627991, "mean": 0.7735359907150269, "reps": 5, "warmup": 2}, "compile_ms": 767.1397094726562, "peak_bytes": 106562560, "ok": true, "absmax": 0.0625, "corr": {"ok": true, "rtol": 0.02, "atol": 0.02, "absmax": 0.0625, "mae": 0.00034332275390625, "mse": 2.771615982055664e-06, "ref": "sdpa_math_fp32"}, "err": null} diff --git a/flash_attn/impls/cells/benchmark.py b/flash_attn/impls/cells/benchmark.py new file mode 100644 index 0000000000000000000000000000000000000000..c558501ad2f4abcbaec8236272134bb8e8b0cfc4 --- /dev/null +++ b/flash_attn/impls/cells/benchmark.py @@ -0,0 +1,72 @@ +# /// script +# requires-python = ">=3.10" +# dependencies = [ +# "numpy", +# "torch", +# "kernels-benchmark-tools", +# "kernels", +# ] +# +# [tool.uv.sources] +# kernels-benchmark-tools = { git = "https://github.com/drbh/kernels-benchmark-tools.git", branch = "main" } +# /// +import torch +import sys +import os +import kernels_benchmark_tools as kbt +from kernels import get_kernel + +hf_kernels_flash_attn = get_kernel("kernels-community/flash-attn", revision="v0.0.2") + + +def hf_flash_attention(query, key, value): + """HuggingFace Kernels Flash Attention""" + return 
hf_kernels_flash_attn.fwd(query, key, value, is_causal=False)[0] + + +kbt.add( + "hf_kernels_flash_attn", + hf_flash_attention, + tags={"family": "hf-kernels", "backend": "flash-attn", "compile": "none"}, +) + +if __name__ == "__main__": + device = "cuda" if torch.cuda.is_available() else "cpu" + + if device == "cpu": + print("HF Kernels Flash Attention requires CUDA - skipping benchmark") + sys.exit(0) + + dtype = "bfloat16" + + # Flux-like workloads + base = 1024 + flux_sizes = [128, 256, 320, 384, 448, 512] + heads = 24 + head_dim = 128 + + wl = [] + for L in flux_sizes: + wl.append( + { + "name": f"flux_L{L}", + "batch": 1, + "seq_len": base + L, + "heads": heads, + "head_dim": head_dim, + "dtype": dtype, + "device": device, + "seed": 0, + } + ) + + kbt.run( + wl, + jsonl="attn.jsonl", + reps=5, + warmup=2, + gen=kbt.attn.gen_qkv, + ref=kbt.attn.ref_math, + cmp=kbt.attn.cmp_allclose, + ) + kbt.summarize(["attn.jsonl"]) \ No newline at end of file diff --git a/flash_attn/impls/cells/benchmark_default.py b/flash_attn/impls/cells/benchmark_default.py new file mode 100644 index 0000000000000000000000000000000000000000..cc2fd06ac69ffe1f5bc88d1821b17447dc90c846 --- /dev/null +++ b/flash_attn/impls/cells/benchmark_default.py @@ -0,0 +1,70 @@ +# /// script +# requires-python = ">=3.10" +# dependencies = [ +# "numpy", +# "torch", +# "kernels-benchmark-tools", +# ] +# +# [tool.uv.sources] +# kernels-benchmark-tools = { git = "https://github.com/drbh/kernels-benchmark-tools.git", branch = "main" } +# /// +import torch +import sys +import os +import kernels_benchmark_tools as kbt + + +def torch_flash_base(q, k, v): + qt, kt, vt = (x.transpose(1, 2).contiguous() for x in (q, k, v)) + with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.FLASH_ATTENTION): + o = torch.nn.functional.scaled_dot_product_attention(qt, kt, vt) + return o.transpose(1, 2).contiguous() + + +# Compile with default mode +compiled_flash_default = torch.compile(torch_flash_base, mode="default", fullgraph=True, dynamic=False) + +kbt.add( + "torch_flash_compiled_default", + compiled_flash_default, + tags={"family": "torch-sdpa", "backend": "FLASH", "compile": "default"}, +) + +if __name__ == "__main__": + device = "cuda" if torch.cuda.is_available() else "cpu" + dtype = "float32" if device == "cpu" else "bfloat16" + + # Flux-like workloads + base = 1024 if device == "cuda" else 512 + flux_sizes = ( + [128, 256, 320, 384, 448, 512] if device == "cuda" else [64, 128, 192, 256] + ) + heads = 24 if device == "cuda" else 8 + head_dim = 128 if device == "cuda" else 64 + + wl = [] + for L in flux_sizes: + wl.append( + { + "name": f"flux_L{L}", + "batch": 1, + "seq_len": base + L, + "heads": heads, + "head_dim": head_dim, + "dtype": dtype, + "device": device, + "seed": 0, + } + ) + + kbt.run( + wl, + jsonl="attn_default.jsonl", + reps=5, + warmup=2, + gen=kbt.attn.gen_qkv, + ref=kbt.attn.ref_math, + cmp=kbt.attn.cmp_allclose, + ) + kbt.summarize(["attn_default.jsonl"]) \ No newline at end of file diff --git a/flash_attn/impls/cells/benchmark_max_autotune.py b/flash_attn/impls/cells/benchmark_max_autotune.py new file mode 100644 index 0000000000000000000000000000000000000000..bd96e676c4d9ebdf709b701a7b9a71b9d51774fd --- /dev/null +++ b/flash_attn/impls/cells/benchmark_max_autotune.py @@ -0,0 +1,70 @@ +# /// script +# requires-python = ">=3.10" +# dependencies = [ +# "numpy", +# "torch", +# "kernels-benchmark-tools", +# ] +# +# [tool.uv.sources] +# kernels-benchmark-tools = { git = 
"https://github.com/drbh/kernels-benchmark-tools.git", branch = "main" } +# /// +import torch +import sys +import os +import kernels_benchmark_tools as kbt + + +def torch_flash_base(q, k, v): + qt, kt, vt = (x.transpose(1, 2).contiguous() for x in (q, k, v)) + with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.FLASH_ATTENTION): + o = torch.nn.functional.scaled_dot_product_attention(qt, kt, vt) + return o.transpose(1, 2).contiguous() + + +# Compile with max-autotune mode +compiled_flash_max_autotune = torch.compile(torch_flash_base, mode="max-autotune", fullgraph=True, dynamic=False) + +kbt.add( + "torch_flash_compiled_max_autotune", + compiled_flash_max_autotune, + tags={"family": "torch-sdpa", "backend": "FLASH", "compile": "max-autotune"}, +) + +if __name__ == "__main__": + device = "cuda" if torch.cuda.is_available() else "cpu" + dtype = "float32" if device == "cpu" else "bfloat16" + + # Flux-like workloads + base = 1024 if device == "cuda" else 512 + flux_sizes = ( + [128, 256, 320, 384, 448, 512] if device == "cuda" else [64, 128, 192, 256] + ) + heads = 24 if device == "cuda" else 8 + head_dim = 128 if device == "cuda" else 64 + + wl = [] + for L in flux_sizes: + wl.append( + { + "name": f"flux_L{L}", + "batch": 1, + "seq_len": base + L, + "heads": heads, + "head_dim": head_dim, + "dtype": dtype, + "device": device, + "seed": 0, + } + ) + + kbt.run( + wl, + jsonl="attn_max_autotune.jsonl", + reps=5, + warmup=2, + gen=kbt.attn.gen_qkv, + ref=kbt.attn.ref_math, + cmp=kbt.attn.cmp_allclose, + ) + kbt.summarize(["attn_max_autotune.jsonl"]) \ No newline at end of file diff --git a/flash_attn/impls/cells/nv.py b/flash_attn/impls/cells/nv.py new file mode 100644 index 0000000000000000000000000000000000000000..80eef60a7536ed875fb21731ab2d059458bd20b4 --- /dev/null +++ b/flash_attn/impls/cells/nv.py @@ -0,0 +1,3 @@ +import subprocess + +print(subprocess.run(["nvidia-smi"], capture_output=True, text=True).stdout) \ No newline at end of file diff --git a/flash_attn/impls/compiled_variants.html b/flash_attn/impls/compiled_variants.html new file mode 100644 index 0000000000000000000000000000000000000000..6a8ad1c87fe3de178f48216a9a342d6d09eda581 --- /dev/null +++ b/flash_attn/impls/compiled_variants.html @@ -0,0 +1,4163 @@ + + + + + + compiled_variants + + + + + + + +
+
+ + ← back + +
light
+
reset
+ +
+
+ +
+
Generated on:
+
+ Linux x86_64 | Linux-5.15.0-1084-aws-x86_64-with-glibc2.31 +
+
+ +
+

Torch Compile Variants!

+

This file benchmarks Flash Attention with different torch.compile modes.

+

Flash Attention with torch.compile(mode="default")

+
+
+ +▼ code +▼ output + ▶ uv-logs + | +Cell: benchmark_default | 45.83s + | + +Raw +GitHub +
+
+
+
# /// script
+# requires-python = ">=3.10"
+# dependencies = [
+#     "numpy",
+#     "torch",
+#     "kernels-benchmark-tools",
+# ]
+#
+# [tool.uv.sources]
+# kernels-benchmark-tools = { git = "https://github.com/drbh/kernels-benchmark-tools.git", branch = "main" }
+# ///
+import torch
+import sys
+import os
+import kernels_benchmark_tools as kbt
+
+
+def torch_flash_base(q, k, v):
+    qt, kt, vt = (x.transpose(1, 2).contiguous() for x in (q, k, v))
+    with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.FLASH_ATTENTION):
+        o = torch.nn.functional.scaled_dot_product_attention(qt, kt, vt)
+    return o.transpose(1, 2).contiguous()
+
+
+# Compile with default mode
+compiled_flash_default = torch.compile(torch_flash_base, mode="default", fullgraph=True, dynamic=False)
+
+kbt.add(
+    "torch_flash_compiled_default",
+    compiled_flash_default,
+    tags={"family": "torch-sdpa", "backend": "FLASH", "compile": "default"},
+)
+
+if __name__ == "__main__":
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    dtype = "float32" if device == "cpu" else "bfloat16"
+
+    # Flux-like workloads
+    base = 1024 if device == "cuda" else 512
+    flux_sizes = (
+        [128, 256, 320, 384, 448, 512] if device == "cuda" else [64, 128, 192, 256]
+    )
+    heads = 24 if device == "cuda" else 8
+    head_dim = 128 if device == "cuda" else 64
+
+    wl = []
+    for L in flux_sizes:
+        wl.append(
+            {
+                "name": f"flux_L{L}",
+                "batch": 1,
+                "seq_len": base + L,
+                "heads": heads,
+                "head_dim": head_dim,
+                "dtype": dtype,
+                "device": device,
+                "seed": 0,
+            }
+        )
+
+    kbt.run(
+        wl,
+        jsonl="attn_default.jsonl",
+        reps=5,
+        warmup=2,
+        gen=kbt.attn.gen_qkv,
+        ref=kbt.attn.ref_math,
+        cmp=kbt.attn.cmp_allclose,
+    )
+    kbt.summarize(["attn_default.jsonl"])
+
+ +
+
+
+
+
impl wl p50(ms) ok
+torch_flash_compiled_default flux_L128 0.36 True
+torch_flash_compiled_default flux_L256 0.50 True
+torch_flash_compiled_default flux_L320 0.54 True
+torch_flash_compiled_default flux_L384 0.59 True
+torch_flash_compiled_default flux_L448 0.61 True
+torch_flash_compiled_default flux_L512 0.64 True
+
+
+
▶ UV Install Logs
+ +
+
+

Artifacts:

+attn_default.jsonl +
+
+
+ +

Flash Attention with torch.compile(mode="max-autotune")

+
+
+ +▼ code +▼ output + ▶ uv-logs + | +Cell: benchmark_max_autotune | 48.72s + | + +Raw +GitHub +
+
+
+
# /// script
+# requires-python = ">=3.10"
+# dependencies = [
+#     "numpy",
+#     "torch",
+#     "kernels-benchmark-tools",
+# ]
+#
+# [tool.uv.sources]
+# kernels-benchmark-tools = { git = "https://github.com/drbh/kernels-benchmark-tools.git", branch = "main" }
+# ///
+import torch
+import sys
+import os
+import kernels_benchmark_tools as kbt
+
+
+def torch_flash_base(q, k, v):
+    qt, kt, vt = (x.transpose(1, 2).contiguous() for x in (q, k, v))
+    with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.FLASH_ATTENTION):
+        o = torch.nn.functional.scaled_dot_product_attention(qt, kt, vt)
+    return o.transpose(1, 2).contiguous()
+
+
+# Compile with max-autotune mode
+compiled_flash_max_autotune = torch.compile(torch_flash_base, mode="max-autotune", fullgraph=True, dynamic=False)
+
+kbt.add(
+    "torch_flash_compiled_max_autotune",
+    compiled_flash_max_autotune,
+    tags={"family": "torch-sdpa", "backend": "FLASH", "compile": "max-autotune"},
+)
+
+if __name__ == "__main__":
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    dtype = "float32" if device == "cpu" else "bfloat16"
+
+    # Flux-like workloads
+    base = 1024 if device == "cuda" else 512
+    flux_sizes = (
+        [128, 256, 320, 384, 448, 512] if device == "cuda" else [64, 128, 192, 256]
+    )
+    heads = 24 if device == "cuda" else 8
+    head_dim = 128 if device == "cuda" else 64
+
+    wl = []
+    for L in flux_sizes:
+        wl.append(
+            {
+                "name": f"flux_L{L}",
+                "batch": 1,
+                "seq_len": base + L,
+                "heads": heads,
+                "head_dim": head_dim,
+                "dtype": dtype,
+                "device": device,
+                "seed": 0,
+            }
+        )
+
+    kbt.run(
+        wl,
+        jsonl="attn_max_autotune.jsonl",
+        reps=5,
+        warmup=2,
+        gen=kbt.attn.gen_qkv,
+        ref=kbt.attn.ref_math,
+        cmp=kbt.attn.cmp_allclose,
+    )
+    kbt.summarize(["attn_max_autotune.jsonl"])
+
+ +
+
+
+
+
impl wl p50(ms) ok
+torch_flash_compiled_max_autotune flux_L128 0.38 True
+torch_flash_compiled_max_autotune flux_L256 0.55 True
+torch_flash_compiled_max_autotune flux_L320 0.61 True
+torch_flash_compiled_max_autotune flux_L384 0.66 True
+torch_flash_compiled_max_autotune flux_L448 0.70 True
+torch_flash_compiled_max_autotune flux_L512 0.76 True
+
+
+
▶ UV Install Logs
+ +
+
+

Artifacts:

+attn_max_autotune.jsonl +
+
+
+
+ + + \ No newline at end of file diff --git a/flash_attn/impls/flash_attention.html b/flash_attn/impls/flash_attention.html new file mode 100644 index 0000000000000000000000000000000000000000..69fae8a00674f2b96a097313a070476b23f6617a --- /dev/null +++ b/flash_attn/impls/flash_attention.html @@ -0,0 +1,4057 @@ + + + + + + flash_attention + + + + + + + +
+
+ + ← back + +
light
+
reset
+ +
+
+ +
+
Generated on:
+
+ Linux x86_64 | Linux-5.15.0-1084-aws-x86_64-with-glibc2.31 +
+
+ +
+

Flash Attention Implementation

+

GPU Info

+
+
+ +▼ code +▼ output + ▶ uv-logs + | +Cell: nv | 4.06s + | + +Raw +
+
+
+
import subprocess
+
+print(subprocess.run(["nvidia-smi"], capture_output=True, text=True).stdout)
+
+ +
+
+
+
+
Thu Oct 2 16:12:42 2025
++-----------------------------------------------------------------------------------------+
+| NVIDIA-SMI 560.35.05 Driver Version: 560.35.05 CUDA Version: 12.6 |
+|-----------------------------------------+------------------------+----------------------+
+| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
+| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
+| | | MIG M. |
+|=========================================+========================+======================|
+| 0 NVIDIA L4 Off | 00000000:38:00.0 Off | 0 |
+| N/A 41C P0 27W / 72W | 1MiB / 23034MiB | 0% Default |
+| | | N/A |
++-----------------------------------------+------------------------+----------------------+
+| 1 NVIDIA L4 Off | 00000000:3A:00.0 Off | 0 |
+| N/A 41C P0 27W / 72W | 1MiB / 23034MiB | 2% Default |
+| | | N/A |
++-----------------------------------------+------------------------+----------------------+
+| 2 NVIDIA L4 Off | 00000000:3C:00.0 Off | 0 |
+| N/A 44C P0 29W / 72W | 1MiB / 23034MiB | 2% Default |
+| | | N/A |
++-----------------------------------------+------------------------+----------------------+
+| 3 NVIDIA L4 Off | 00000000:3E:00.0 Off | 0 |
+| N/A 42C P0 29W / 72W | 1MiB / 23034MiB | 2% Default |
+| | | N/A |
++-----------------------------------------+------------------------+----------------------+
+
++-----------------------------------------------------------------------------------------+
+| Processes: |
+| GPU GI CI PID Type Process name GPU Memory |
+| ID ID Usage |
+|=========================================================================================|
+| No running processes found |
++-----------------------------------------------------------------------------------------+
+
+
+
+
+ +

Flash Attention Benchmark

+
+
+ +▼ code +▼ output + ▶ uv-logs + | +Cell: benchmark | 38.14s + | + +Raw +
+
+
+
# /// script
+# requires-python = ">=3.10"
+# dependencies = [
+#     "numpy",
+#     "torch",
+#     "kernels-benchmark-tools",
+# ]
+#
+# [tool.uv.sources]
+# kernels-benchmark-tools = { git = "https://github.com/drbh/kernels-benchmark-tools.git", branch = "main" }
+# ///
+import torch
+import sys
+import os
+import kernels_benchmark_tools as kbt
+
+
+def torch_flash(q, k, v):
+    qt, kt, vt = (x.transpose(1, 2).contiguous() for x in (q, k, v))
+    with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.FLASH_ATTENTION):
+        o = torch.nn.functional.scaled_dot_product_attention(qt, kt, vt)
+    return o.transpose(1, 2).contiguous()
+
+kbt.add(
+    "torch_flash_ma",
+    torch_flash,
+    tags={"family": "torch-sdpa", "backend": "FLASH", "compile": "max-autotune"},
+)
+
+if __name__ == "__main__":
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    dtype = "float32" if device == "cpu" else "bfloat16"
+
+    # Flux-like workloads scaled down for CPU testing
+    base = 1024 if device == "cuda" else 512
+    flux_sizes = (
+        [128, 256, 320, 384, 448, 512] if device == "cuda" else [64, 128, 192, 256]
+    )
+    heads = 24 if device == "cuda" else 8
+    head_dim = 128 if device == "cuda" else 64
+
+    wl = []
+    for L in flux_sizes:
+        wl.append(
+            {
+                "name": f"flux_L{L}",
+                "batch": 1,
+                "seq_len": base + L,
+                "heads": heads,
+                "head_dim": head_dim,
+                "dtype": dtype,
+                "device": device,
+                "seed": 0,
+            }
+        )
+
+    kbt.run(
+        wl,
+        jsonl="attn.jsonl",
+        reps=5,
+        warmup=2,
+        gen=kbt.attn.gen_qkv,
+        ref=kbt.attn.ref_math,
+        cmp=kbt.attn.cmp_allclose,
+    )
+    kbt.summarize(["attn.jsonl"])
+
+ +
+
+
+
+
impl wl p50(ms) ok
+torch_flash_ma flux_L128 0.41 True
+torch_flash_ma flux_L256 0.52 True
+torch_flash_ma flux_L320 0.55 True
+torch_flash_ma flux_L384 0.59 True
+torch_flash_ma flux_L448 0.64 True
+torch_flash_ma flux_L512 0.68 True
+
+
+
▶ UV Install Logs
+ +
+
+

Artifacts:

+attn.jsonl +
+
+
+
+ + + \ No newline at end of file diff --git a/flash_attn/impls/hf_kernels_flash_attn.html b/flash_attn/impls/hf_kernels_flash_attn.html new file mode 100644 index 0000000000000000000000000000000000000000..a73274cdbcce6cc4d5e87b9056f24bf359798266 --- /dev/null +++ b/flash_attn/impls/hf_kernels_flash_attn.html @@ -0,0 +1,4010 @@ + + + + + + hf_kernels_flash_attn + + + + + + + +
+
+ + ← back + +
light
+
reset
+ +
+
+ +
+
Generated on:
+
+ Linux x86_64 | Linux-5.15.0-1084-aws-x86_64-with-glibc2.31 +
+
+ +
+

HF Kernels - Flash Attention

+

HuggingFace Kernels Flash Attention Benchmark

+
+
+ +▼ code +▼ output + ▶ uv-logs + | +Cell: benchmark | 40.14s + | + +Raw +GitHub +🤗 HF +
+
+
+
# /// script
+# requires-python = ">=3.10"
+# dependencies = [
+#     "numpy",
+#     "torch",
+#     "kernels-benchmark-tools",
+#     "kernels",
+# ]
+#
+# [tool.uv.sources]
+# kernels-benchmark-tools = { git = "https://github.com/drbh/kernels-benchmark-tools.git", branch = "main" }
+# ///
+import torch
+import sys
+import os
+import kernels_benchmark_tools as kbt
+from kernels import get_kernel
+
+hf_kernels_flash_attn = get_kernel("kernels-community/flash-attn", revision="v0.0.2")
+
+
+def hf_flash_attention(query, key, value):
+    """HuggingFace Kernels Flash Attention"""
+    return hf_kernels_flash_attn.fwd(query, key, value, is_causal=False)[0]
+
+
+kbt.add(
+    "hf_kernels_flash_attn",
+    hf_flash_attention,
+    tags={"family": "hf-kernels", "backend": "flash-attn", "compile": "none"},
+)
+
+if __name__ == "__main__":
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+
+    if device == "cpu":
+        print("HF Kernels Flash Attention requires CUDA - skipping benchmark")
+        sys.exit(0)
+
+    dtype = "bfloat16"
+
+    # Flux-like workloads
+    base = 1024
+    flux_sizes = [128, 256, 320, 384, 448, 512]
+    heads = 24
+    head_dim = 128
+
+    wl = []
+    for L in flux_sizes:
+        wl.append(
+            {
+                "name": f"flux_L{L}",
+                "batch": 1,
+                "seq_len": base + L,
+                "heads": heads,
+                "head_dim": head_dim,
+                "dtype": dtype,
+                "device": device,
+                "seed": 0,
+            }
+        )
+
+    kbt.run(
+        wl,
+        jsonl="attn.jsonl",
+        reps=5,
+        warmup=2,
+        gen=kbt.attn.gen_qkv,
+        ref=kbt.attn.ref_math,
+        cmp=kbt.attn.cmp_allclose,
+    )
+    kbt.summarize(["attn.jsonl"])
+
+ +
+
+
+
+
impl wl p50(ms) ok
+hf_kernels_flash_attn flux_L128 0.25 True
+hf_kernels_flash_attn flux_L256 0.32 True
+hf_kernels_flash_attn flux_L320 0.34 True
+hf_kernels_flash_attn flux_L384 0.35 True
+hf_kernels_flash_attn flux_L448 0.38 True
+hf_kernels_flash_attn flux_L512 0.42 True
+
+
+
▶ UV Install Logs
+ +
+
Fetching 20 files: 0%| | 0/20 [00:00<?, ?it/s]
+Fetching 20 files: 5%|▌ | 1/20 [00:00<00:05, 3.64it/s]
+Fetching 20 files: 10%|█ | 2/20 [00:02<00:22, 1.24s/it]
+Fetching 20 files: 100%|██████████| 20/20 [00:02<00:00, 9.14it/s]
+
+

Artifacts:

+attn.jsonl +
+
+
+
+ + + \ No newline at end of file diff --git a/flash_attn/impls/hf_kernels_flash_attn3.html b/flash_attn/impls/hf_kernels_flash_attn3.html new file mode 100644 index 0000000000000000000000000000000000000000..f745e3998546d5effda35036ae155237f9caf603 --- /dev/null +++ b/flash_attn/impls/hf_kernels_flash_attn3.html @@ -0,0 +1,4007 @@ + + + + + + hf_kernels_flash_attn3 + + + + + + + +
+
+ + ← back + +
light
+
reset
+ +
+
+ +
+
Generated on:
+
+ Linux x86_64 | Linux-5.15.0-1084-aws-x86_64-with-glibc2.31 +
+
+ +
+

HF Kernels - Flash Attention 3

+

HuggingFace Kernels Flash Attention 3 Benchmark

+
+
+ +▼ code +▼ output + ▶ uv-logs + | +Cell: benchmark | 40.68s + | + +Raw +
+
+
+
# /// script
+# requires-python = ">=3.10"
+# dependencies = [
+#     "numpy",
+#     "torch",
+#     "kernels-benchmark-tools",
+#     "kernels",
+# ]
+#
+# [tool.uv.sources]
+# kernels-benchmark-tools = { git = "https://github.com/drbh/kernels-benchmark-tools.git", branch = "main" }
+# ///
+import torch
+import sys
+import os
+import kernels_benchmark_tools as kbt
+from kernels import get_kernel
+
+hf_kernels_flash_attn3 = get_kernel("kernels-community/flash-attn3")
+
+
+def hf_flash_attention3(query, key, value):
+    return hf_kernels_flash_attn3.flash_attn_func(query, key, value, causal=False)[0]
+
+
+kbt.add(
+    "hf_kernels_flash_attn3",
+    hf_flash_attention3,
+    tags={"family": "hf-kernels", "backend": "flash-attn3", "compile": "none"},
+)
+
+if __name__ == "__main__":
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+
+    if device == "cpu":
+        print("HF Kernels Flash Attention 3 requires CUDA - skipping benchmark")
+        sys.exit(0)
+
+    dtype = "bfloat16"
+
+    # Flux-like workloads
+    base = 1024
+    flux_sizes = [128, 256, 320, 384, 448, 512]
+    heads = 24
+    head_dim = 128
+
+    wl = []
+    for L in flux_sizes:
+        wl.append(
+            {
+                "name": f"flux_L{L}",
+                "batch": 1,
+                "seq_len": base + L,
+                "heads": heads,
+                "head_dim": head_dim,
+                "dtype": dtype,
+                "device": device,
+                "seed": 0,
+            }
+        )
+
+    kbt.run(
+        wl,
+        jsonl="attn.jsonl",
+        reps=5,
+        warmup=2,
+        gen=kbt.attn.gen_qkv,
+        ref=kbt.attn.ref_math,
+        cmp=kbt.attn.cmp_allclose,
+    )
+    kbt.summarize(["attn.jsonl"])
+
+ +
+
+
+
+
impl wl p50(ms) ok
+hf_kernels_flash_attn3 flux_L128 0.28 True
+hf_kernels_flash_attn3 flux_L256 0.34 True
+hf_kernels_flash_attn3 flux_L320 0.36 True
+hf_kernels_flash_attn3 flux_L384 0.37 True
+hf_kernels_flash_attn3 flux_L448 0.40 True
+hf_kernels_flash_attn3 flux_L512 0.43 True
+
+
+
▶ UV Install Logs
+ +
+
Fetching 4 files: 0%| | 0/4 [00:00<?, ?it/s]
+Fetching 4 files: 25%|██▌ | 1/4 [00:00<00:00, 3.56it/s]
+Fetching 4 files: 50%|█████ | 2/4 [00:02<00:02, 1.32s/it]
+Fetching 4 files: 100%|██████████| 4/4 [00:02<00:00, 1.72it/s]
+
+

Artifacts:

+attn.jsonl +
+
+
+
+ + + \ No newline at end of file diff --git a/flash_attn/impls/index.html b/flash_attn/impls/index.html new file mode 100644 index 0000000000000000000000000000000000000000..fbf5a24033c4d68d64ce4f1be0ba13266dd4e89d --- /dev/null +++ b/flash_attn/impls/index.html @@ -0,0 +1,94 @@ + + + + + + Index of /flash_attn/impls + + + +
+ ← back +
+

Index of /flash_attn/impls

+ + + \ No newline at end of file diff --git a/flash_attn/impls/mem_efficient_attention.html b/flash_attn/impls/mem_efficient_attention.html new file mode 100644 index 0000000000000000000000000000000000000000..3400970cc242db031fdf36b6c82329d15913de75 --- /dev/null +++ b/flash_attn/impls/mem_efficient_attention.html @@ -0,0 +1,3998 @@ + + + + + + mem_efficient_attention + + + + + + + +
+
+ + ← back + +
light
+
reset
+ +
+
+ +
+
Generated on:
+
+ Linux x86_64 | Linux-5.15.0-1084-aws-x86_64-with-glibc2.31 +
+
+ +
+

Memory Efficient Attention Implementation

+

Memory Efficient SDPA Benchmark

+
+
+ +▼ code +▼ output + ▶ uv-logs + | +Cell: benchmark | 39.23s + | + +Raw +
+
+
+
# /// script
+# requires-python = ">=3.10"
+# dependencies = [
+#     "numpy",
+#     "torch",
+#     "kernels-benchmark-tools",
+# ]
+#
+# [tool.uv.sources]
+# kernels-benchmark-tools = { git = "https://github.com/drbh/kernels-benchmark-tools.git", branch = "main" }
+# ///
+import torch
+import sys
+import os
+import kernels_benchmark_tools as kbt
+
+
+def torch_mem_eff(q, k, v):
+    qt, kt, vt = (x.transpose(1, 2).contiguous() for x in (q, k, v))
+    with torch.nn.attention.sdpa_kernel(
+        torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION
+    ):
+        o = torch.nn.functional.scaled_dot_product_attention(qt, kt, vt)
+    return o.transpose(1, 2).contiguous()
+
+kbt.add(
+    "torch_mem_eff",
+    torch_mem_eff,
+    tags={"family": "torch-sdpa", "backend": "EFFICIENT", "compile": "none"},
+)
+
+if __name__ == "__main__":
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    dtype = "float32" if device == "cpu" else "bfloat16"
+
+    # Flux-like workloads scaled down for CPU testing
+    base = 1024 if device == "cuda" else 512
+    flux_sizes = (
+        [128, 256, 320, 384, 448, 512] if device == "cuda" else [64, 128, 192, 256]
+    )
+    heads = 24 if device == "cuda" else 8
+    head_dim = 128 if device == "cuda" else 64
+
+    wl = []
+    for L in flux_sizes:
+        wl.append(
+            {
+                "name": f"flux_L{L}",
+                "batch": 1,
+                "seq_len": base + L,
+                "heads": heads,
+                "head_dim": head_dim,
+                "dtype": dtype,
+                "device": device,
+                "seed": 0,
+            }
+        )
+
+    kbt.run(
+        wl,
+        jsonl="attn.jsonl",
+        reps=5,
+        warmup=2,
+        gen=kbt.attn.gen_qkv,
+        ref=kbt.attn.ref_math,
+        cmp=kbt.attn.cmp_allclose,
+    )
+    kbt.summarize(["attn.jsonl"])
+
+ +
+
+
+
+
impl wl p50(ms) ok
+torch_mem_eff flux_L128 0.48 True
+torch_mem_eff flux_L256 0.63 True
+torch_mem_eff flux_L320 0.70 True
+torch_mem_eff flux_L384 0.83 True
+torch_mem_eff flux_L448 0.95 True
+torch_mem_eff flux_L512 1.00 True
+
+
+
▶ UV Install Logs
+ +
+
+

Artifacts:

+attn.jsonl +
+
+
+
+ + + \ No newline at end of file diff --git a/flash_attn/impls/sage_attention.html b/flash_attn/impls/sage_attention.html new file mode 100644 index 0000000000000000000000000000000000000000..20d60b76924d107347f498604c5ef8072e658fda --- /dev/null +++ b/flash_attn/impls/sage_attention.html @@ -0,0 +1,4022 @@ + + + + + + sage_attention + + + + + + + +
+
+ + ← back + +
light
+
reset
+ +
+
+ +
+
Generated on:
+
+ Linux x86_64 | Linux-5.15.0-1084-aws-x86_64-with-glibc2.31 +
+
+ +
+

SageAttention Implementation

+

SageAttention Benchmark (INT8 Quantized)

+
+
+ +▼ code +▼ output + ▶ uv-logs + | +Cell: benchmark | 41.27s + | + +Raw +
+
+
+
# /// script
+# requires-python = ">=3.10"
+# dependencies = [
+#     "numpy",
+#     "torch",
+#     "kernels",
+#     "kernels-benchmark-tools",
+#     "sageattention",
+# ]
+#
+# [tool.uv.sources]
+# kernels-benchmark-tools = { git = "https://github.com/drbh/kernels-benchmark-tools.git", branch = "main" }
+# ///
+import torch
+import sys
+import os
+import kernels_benchmark_tools as kbt
+# from sageattention import sageattn_qk_int8_pv_fp16_cuda
+
+
+# def sage_attention(q, k, v):
+#     """SageAttention with INT8 Q/K quantization and FP16 P/V"""
+#     return sageattn_qk_int8_pv_fp16_cuda(q, k, v, tensor_layout="NHD")
+
+from kernels import get_kernel
+
+hf_kernels_sage_attn = get_kernel("kernels-community/sage_attention")
+
+
+def sage_attention(query, key, value):
+    """HuggingFace Kernels Flash Attention"""
+    return hf_kernels_sage_attn.fwd(query, key, value, is_causal=False)[0]
+
+kbt.add(
+    "sage_int8_fp16",
+    sage_attention,
+    tags={"family": "sageattention", "backend": "int8_fp16_cuda", "compile": "none"},
+)
+
+if __name__ == "__main__":
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+
+    if device == "cpu":
+        print("SageAttention requires CUDA - skipping benchmark")
+        sys.exit(0)
+
+    dtype = "bfloat16"
+
+    # Flux-like workloads
+    base = 1024
+    flux_sizes = [128, 256, 320, 384, 448, 512]
+    heads = 24
+    head_dim = 128
+
+    wl = []
+    for L in flux_sizes:
+        wl.append(
+            {
+                "name": f"flux_L{L}",
+                "batch": 1,
+                "seq_len": base + L,
+                "heads": heads,
+                "head_dim": head_dim,
+                "dtype": dtype,
+                "device": device,
+                "seed": 0,
+            }
+        )
+
+    kbt.run(
+        wl,
+        jsonl="attn.jsonl",
+        reps=5,
+        warmup=2,
+        gen=kbt.attn.gen_qkv,
+        ref=kbt.attn.ref_math,
+        cmp=kbt.attn.cmp_allclose,
+    )
+    kbt.summarize(["attn.jsonl"])
+
+ +
+
+
+
+
impl wl p50(ms) ok
+sage_int8_fp16 flux_L128 FAIL False
+ Error: module 'sage_attention_a8eb63760f50ebd' has no attribute 'fwd'
+sage_int8_fp16 flux_L256 FAIL False
+ Error: module 'sage_attention_a8eb63760f50ebd' has no attribute 'fwd'
+sage_int8_fp16 flux_L320 FAIL False
+ Error: module 'sage_attention_a8eb63760f50ebd' has no attribute 'fwd'
+sage_int8_fp16 flux_L384 FAIL False
+ Error: module 'sage_attention_a8eb63760f50ebd' has no attribute 'fwd'
+sage_int8_fp16 flux_L448 FAIL False
+ Error: module 'sage_attention_a8eb63760f50ebd' has no attribute 'fwd'
+sage_int8_fp16 flux_L512 FAIL False
+ Error: module 'sage_attention_a8eb63760f50ebd' has no attribute 'fwd'
+
+
+
▶ UV Install Logs
+ +
+
Fetching 11 files: 0%| | 0/11 [00:00<?, ?it/s]
+Fetching 11 files: 9%|▉ | 1/11 [00:00<00:05, 1.85it/s]
+Fetching 11 files: 45%|████▌ | 5/11 [00:00<00:00, 6.46it/s]
+Fetching 11 files: 73%|███████▎ | 8/11 [00:01<00:00, 10.07it/s]
+Fetching 11 files: 100%|██████████| 11/11 [00:01<00:00, 10.94it/s]
+
+

Artifacts:

+attn.jsonl +
+
+
+
+ + + \ No newline at end of file diff --git a/flash_attn/impls/xformers.html b/flash_attn/impls/xformers.html new file mode 100644 index 0000000000000000000000000000000000000000..1b612e671c94677dfc3925bd110cbceb74640fed --- /dev/null +++ b/flash_attn/impls/xformers.html @@ -0,0 +1,4000 @@ + + + + + + xformers + + + + + + + +
+
+ + ← back + +
light
+
reset
+ +
+
+ +
+
Generated on:
+
+ Linux x86_64 | Linux-5.15.0-1084-aws-x86_64-with-glibc2.31 +
+
+ +
+

xFormers Memory Efficient Attention

+

xFormers Benchmark

+
+
+ +▼ code +▼ output + ▶ uv-logs + | +Cell: benchmark | 41.87s + | + +Raw +
+
+
+
# /// script
+# requires-python = ">=3.10"
+# dependencies = [
+#     "numpy",
+#     "torch",
+#     "kernels-benchmark-tools",
+#     "xformers",
+# ]
+#
+# [tool.uv.sources]
+# kernels-benchmark-tools = { git = "https://github.com/drbh/kernels-benchmark-tools.git", branch = "main" }
+# ///
+import torch
+import sys
+import os
+import kernels_benchmark_tools as kbt
+import xformers.ops as xops
+
+
+def xformers_attention(q, k, v):
+    """xFormers memory efficient attention"""
+    # xFormers expects [batch, seq_len, heads, head_dim]
+    return xops.memory_efficient_attention(q, k, v)
+
+
+kbt.add(
+    "xformers_meff",
+    xformers_attention,
+    tags={"family": "xformers", "backend": "memory_efficient", "compile": "none"},
+)
+
+if __name__ == "__main__":
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    dtype = "float32" if device == "cpu" else "bfloat16"
+
+    # Flux-like workloads
+    base = 1024 if device == "cuda" else 512
+    flux_sizes = (
+        [128, 256, 320, 384, 448, 512] if device == "cuda" else [64, 128, 192, 256]
+    )
+    heads = 24 if device == "cuda" else 8
+    head_dim = 128 if device == "cuda" else 64
+
+    wl = []
+    for L in flux_sizes:
+        wl.append(
+            {
+                "name": f"flux_L{L}",
+                "batch": 1,
+                "seq_len": base + L,
+                "heads": heads,
+                "head_dim": head_dim,
+                "dtype": dtype,
+                "device": device,
+                "seed": 0,
+            }
+        )
+
+    kbt.run(
+        wl,
+        jsonl="attn.jsonl",
+        reps=5,
+        warmup=2,
+        gen=kbt.attn.gen_qkv,
+        ref=kbt.attn.ref_math,
+        cmp=kbt.attn.cmp_allclose,
+    )
+    kbt.summarize(["attn.jsonl"])
+
+ +
+
+
+
+
impl wl p50(ms) ok
+xformers_meff flux_L128 0.35 True
+xformers_meff flux_L256 0.41 True
+xformers_meff flux_L320 0.43 True
+xformers_meff flux_L384 0.44 True
+xformers_meff flux_L448 0.48 True
+xformers_meff flux_L512 0.50 True
+
+
+
▶ UV Install Logs
+ +
+
+

Artifacts:

+attn.jsonl +
+
+
+
+ + + \ No newline at end of file diff --git a/flash_attn/index.html b/flash_attn/index.html new file mode 100644 index 0000000000000000000000000000000000000000..eea7df846d9f2d44c6c6e03a5ac30d00cecd90cf --- /dev/null +++ b/flash_attn/index.html @@ -0,0 +1,89 @@ + + + + + + Index of /flash_attn + + + +
+ ← back +
+

Index of /flash_attn

+ + + \ No newline at end of file diff --git a/flash_attn/results/artifacts/combine/latency.csv b/flash_attn/results/artifacts/combine/latency.csv new file mode 100644 index 0000000000000000000000000000000000000000..fe7bd55163eff0be13ab95a18cadede44a50adaf --- /dev/null +++ b/flash_attn/results/artifacts/combine/latency.csv @@ -0,0 +1,43 @@ +Implementation,Impl ID,Workload,Batch,Seq Length,Heads,Head Dim,Dtype,Mean (ms),P10 (ms),P50 (ms),P90 (ms),Reps,Peak Mem (MB),Backend,Family +Flash (PyTorch SDPA),torch_flash_ma,flux_L128,1,1152,24,128,bfloat16,0.407123202085495,0.40537598729133606,0.40755200386047363,0.407584011554718,5,83.38,FLASH,torch-sdpa +Flash (PyTorch SDPA),torch_flash_ma,flux_L256,1,1280,24,128,bfloat16,0.5235007882118226,0.5212159752845764,0.5232639908790588,0.523360013961792,5,90.62,FLASH,torch-sdpa +Flash (PyTorch SDPA),torch_flash_ma,flux_L320,1,1344,24,128,bfloat16,0.545849597454071,0.5418559908866882,0.5468159914016724,0.5469120144844055,5,95.06,FLASH,torch-sdpa +Flash (PyTorch SDPA),torch_flash_ma,flux_L384,1,1408,24,128,bfloat16,0.5892416119575501,0.5867519974708557,0.5888000130653381,0.5888000130653381,5,99.88,FLASH,torch-sdpa +Flash (PyTorch SDPA),torch_flash_ma,flux_L448,1,1472,24,128,bfloat16,0.6449280023574829,0.6430720090866089,0.6442239880561829,0.6450240015983582,5,103.81,FLASH,torch-sdpa +Flash (PyTorch SDPA),torch_flash_ma,flux_L512,1,1536,24,128,bfloat16,0.6823423862457275,0.6777600049972534,0.6809599995613098,0.6818559765815735,5,109.12,FLASH,torch-sdpa +MemEff (PyTorch SDPA),torch_mem_eff,flux_L128,1,1152,24,128,bfloat16,0.48371200561523436,0.4821760058403015,0.4833280146121979,0.4853760004043579,5,83.38,EFFICIENT,torch-sdpa +MemEff (PyTorch SDPA),torch_mem_eff,flux_L256,1,1280,24,128,bfloat16,0.6268800020217895,0.6246399879455566,0.6266880035400391,0.6286720037460327,5,90.62,EFFICIENT,torch-sdpa +MemEff (PyTorch SDPA),torch_mem_eff,flux_L320,1,1344,24,128,bfloat16,0.699776005744934,0.6973440051078796,0.7004160284996033,0.7004479765892029,5,95.94,EFFICIENT,torch-sdpa +MemEff (PyTorch SDPA),torch_mem_eff,flux_L384,1,1408,24,128,bfloat16,0.8333312034606933,0.8284159898757935,0.8325120210647583,0.8376320004463196,5,100.0,EFFICIENT,torch-sdpa +MemEff (PyTorch SDPA),torch_mem_eff,flux_L448,1,1472,24,128,bfloat16,0.9533439993858337,0.9502720236778259,0.9512959718704224,0.9572479724884033,5,103.81,EFFICIENT,torch-sdpa +MemEff (PyTorch SDPA),torch_mem_eff,flux_L512,1,1536,24,128,bfloat16,1.0066367864608765,1.0024960041046143,1.0045440196990967,1.0097919702529907,5,109.12,EFFICIENT,torch-sdpa +xFormers,xformers_meff,flux_L128,1,1152,24,128,bfloat16,0.3452928066253662,0.3389439880847931,0.3461120128631592,0.3461120128631592,5,83.38,memory_efficient,xformers +xFormers,xformers_meff,flux_L256,1,1280,24,128,bfloat16,0.41234560012817384,0.40959998965263367,0.41280001401901245,0.41286399960517883,5,90.62,memory_efficient,xformers +xFormers,xformers_meff,flux_L320,1,1344,24,128,bfloat16,0.4366208016872406,0.4310399889945984,0.4331519901752472,0.4362240135669708,5,95.06,memory_efficient,xformers +xFormers,xformers_meff,flux_L384,1,1408,24,128,bfloat16,0.4450624048709869,0.4359680116176605,0.44361600279808044,0.447488009929657,5,99.88,memory_efficient,xformers +xFormers,xformers_meff,flux_L448,1,1472,24,128,bfloat16,0.4750400006771088,0.4711039960384369,0.47513601183891296,0.4763199985027313,5,103.81,memory_efficient,xformers 
+xFormers,xformers_meff,flux_L512,1,1536,24,128,bfloat16,0.5009407997131348,0.49663999676704407,0.4997119903564453,0.5038080215454102,5,109.12,memory_efficient,xformers +Compiled (default),torch_flash_compiled_default,flux_L128,1,1152,24,128,bfloat16,0.3856383919715881,0.3563520014286041,0.35942399501800537,0.3624959886074066,5,83.38,FLASH,torch-sdpa +Compiled (default),torch_flash_compiled_default,flux_L256,1,1280,24,128,bfloat16,0.4982912003993988,0.4926080107688904,0.49663999676704407,0.5017600059509277,5,90.62,FLASH,torch-sdpa +Compiled (default),torch_flash_compiled_default,flux_L320,1,1344,24,128,bfloat16,0.5369919896125793,0.5335040092468262,0.5366079807281494,0.5386239886283875,5,95.25,FLASH,torch-sdpa +Compiled (default),torch_flash_compiled_default,flux_L384,1,1408,24,128,bfloat16,0.5841408014297486,0.5775359869003296,0.5868800282478333,0.5877760052680969,5,99.88,FLASH,torch-sdpa +Compiled (default),torch_flash_compiled_default,flux_L448,1,1472,24,128,bfloat16,0.6184704065322876,0.6072319746017456,0.6113280057907104,0.6144000291824341,5,103.81,FLASH,torch-sdpa +Compiled (default),torch_flash_compiled_default,flux_L512,1,1536,24,128,bfloat16,0.6428672075271606,0.6399999856948853,0.6430720090866089,0.6430720090866089,5,109.12,FLASH,torch-sdpa +Compiled (max-autotune),torch_flash_compiled_max_autotune,flux_L128,1,1152,24,128,bfloat16,0.40020479559898375,0.3665919899940491,0.3768320083618164,0.41171199083328247,5,81.75,FLASH,torch-sdpa +Compiled (max-autotune),torch_flash_compiled_max_autotune,flux_L256,1,1280,24,128,bfloat16,0.5535807967185974,0.5160959959030151,0.5489599704742432,0.5631359815597534,5,92.88,FLASH,torch-sdpa +Compiled (max-autotune),torch_flash_compiled_max_autotune,flux_L320,1,1344,24,128,bfloat16,0.6143999934196472,0.562175989151001,0.6144000291824341,0.6318079829216003,5,95.13,FLASH,torch-sdpa +Compiled (max-autotune),torch_flash_compiled_max_autotune,flux_L384,1,1408,24,128,bfloat16,0.6754495978355408,0.6512640118598938,0.6584320068359375,0.6799359917640686,5,97.13,FLASH,torch-sdpa +Compiled (max-autotune),torch_flash_compiled_max_autotune,flux_L448,1,1472,24,128,bfloat16,0.7210752129554748,0.6973119974136353,0.7014080286026001,0.7229440212249756,5,99.0,FLASH,torch-sdpa +Compiled (max-autotune),torch_flash_compiled_max_autotune,flux_L512,1,1536,24,128,bfloat16,0.7735359907150269,0.7485439777374268,0.7557439804077148,0.7710719704627991,5,101.63,FLASH,torch-sdpa +HF Kernels Flash Attn,hf_kernels_flash_attn,flux_L128,1,1152,24,128,bfloat16,0.2456959992647171,0.24371199309825897,0.24566400051116943,0.2457599937915802,5,83.38,flash-attn,hf-kernels +HF Kernels Flash Attn,hf_kernels_flash_attn,flux_L256,1,1280,24,128,bfloat16,0.3215551972389221,0.3164159953594208,0.319487988948822,0.32051199674606323,5,90.62,flash-attn,hf-kernels +HF Kernels Flash Attn,hf_kernels_flash_attn,flux_L320,1,1344,24,128,bfloat16,0.3384703993797302,0.33670398592948914,0.33792001008987427,0.33983999490737915,5,95.06,flash-attn,hf-kernels +HF Kernels Flash Attn,hf_kernels_flash_attn,flux_L384,1,1408,24,128,bfloat16,0.3510208010673523,0.3481599986553192,0.3491840064525604,0.35225600004196167,5,99.88,flash-attn,hf-kernels +HF Kernels Flash Attn,hf_kernels_flash_attn,flux_L448,1,1472,24,128,bfloat16,0.3829823970794678,0.38095998764038086,0.3829759955406189,0.3840000033378601,5,103.81,flash-attn,hf-kernels +HF Kernels Flash Attn,hf_kernels_flash_attn,flux_L512,1,1536,24,128,bfloat16,0.4259391903877258,0.4227519929409027,0.4249599874019623,0.4259839951992035,5,109.12,flash-attn,hf-kernels +HF Kernels 
Flash Attn3,hf_kernels_flash_attn3,flux_L128,1,1152,24,128,bfloat16,0.2755008041858673,0.26736000180244446,0.27561599016189575,0.27955201268196106,5,83.38,flash-attn3,hf-kernels +HF Kernels Flash Attn3,hf_kernels_flash_attn3,flux_L256,1,1280,24,128,bfloat16,0.3397440016269684,0.3368000090122223,0.3399679958820343,0.34191998839378357,5,90.62,flash-attn3,hf-kernels +HF Kernels Flash Attn3,hf_kernels_flash_attn3,flux_L320,1,1344,24,128,bfloat16,0.36019839644432067,0.3563520014286041,0.3604480028152466,0.36137598752975464,5,95.06,flash-attn3,hf-kernels +HF Kernels Flash Attn3,hf_kernels_flash_attn3,flux_L384,1,1408,24,128,bfloat16,0.37342079877853396,0.3718400001525879,0.37379199266433716,0.3746879994869232,5,99.88,flash-attn3,hf-kernels +HF Kernels Flash Attn3,hf_kernels_flash_attn3,flux_L448,1,1472,24,128,bfloat16,0.4024448037147522,0.3993600010871887,0.4014720022678375,0.4034560024738312,5,103.81,flash-attn3,hf-kernels +HF Kernels Flash Attn3,hf_kernels_flash_attn3,flux_L512,1,1536,24,128,bfloat16,0.4305088043212891,0.4270080029964447,0.4291520118713379,0.4331519901752472,5,109.12,flash-attn3,hf-kernels diff --git a/flash_attn/results/artifacts/combine/latency.svg b/flash_attn/results/artifacts/combine/latency.svg new file mode 100644 index 0000000000000000000000000000000000000000..ea5c5865f5ff0718bc77d9c64aa425433567eef1 --- /dev/null +++ b/flash_attn/results/artifacts/combine/latency.svg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7bb668989495a1f179dd65f7426ee2a611f8dc193219b9a7385c79f6701d161a +size 29810 diff --git a/flash_attn/results/cells/combine.py b/flash_attn/results/cells/combine.py new file mode 100644 index 0000000000000000000000000000000000000000..f703ae3d1403d602560c9d3b36d51fab69b7f3a5 --- /dev/null +++ b/flash_attn/results/cells/combine.py @@ -0,0 +1,319 @@ +# /// script +# requires-python = ">=3.10" +# dependencies = [ +# "numpy", +# "torch", +# "kernels-benchmark-tools", +# "matplotlib", +# ] +# +# [tool.uv.sources] +# kernels-benchmark-tools = { git = "https://github.com/drbh/kernels-benchmark-tools.git", branch = "main" } +# /// +import os +import sys +from pathlib import Path +import json +import torch # noqa: F401 # imported because upstream may expect torch to be importable +import kernels_benchmark_tools as kbt + +# --- Matplotlib setup and helpers ------------------------------------------------ +import matplotlib as mpl +import matplotlib.pyplot as plt +import csv + + +# Keep text as text (not paths) so CSS can style fonts, size, etc. 
+mpl.rcParams["svg.fonttype"] = "none" +# Make ids deterministic across builds +mpl.rcParams["svg.hashsalt"] = "latency-benchmark-combined" +# Avoid auto-closed figures interfering with our tagging +mpl.rcParams["figure.autolayout"] = True +# Make background transparent +mpl.rcParams["figure.facecolor"] = "none" +mpl.rcParams["axes.facecolor"] = "none" +mpl.rcParams["savefig.facecolor"] = "none" +mpl.rcParams["savefig.edgecolor"] = "none" + +def _slugify(s: str) -> str: + s = (s or "").strip().lower() + keep = [] + for ch in s: + if ch.isalnum(): + keep.append(ch) + elif ch in (" ", "-", "_", "/", ".", ":"): + keep.append("-") + else: + keep.append("") + out = "".join(keep) + while "--" in out: + out = out.replace("--", "-") + return out.strip("-") or "unnamed" + +def _tag_current_figure(default_series_prefix="series"): + """Attach SVG ids (gid) to key artists so they can be targeted from CSS.""" + fig = plt.gcf() + if fig is None: + return + + # Tag the figure itself + fig.set_gid("figure--latency") + + for ax_idx, ax in enumerate(fig.get_axes(), start=1): + ax.set_gid(f"axes--{ax_idx}") + + # Axis labels & title + if ax.get_title(): + for t in ax.texts: + if t.get_text() == ax.get_title(): + t.set_gid("title--main") + if ax.xaxis and ax.xaxis.get_label(): + ax.xaxis.label.set_gid("label--x") + if ax.yaxis and ax.yaxis.get_label(): + ax.yaxis.label.set_gid("label--y") + + # Gridlines + for i, gl in enumerate(ax.get_xgridlines(), start=1): + gl.set_gid(f"grid-x--{i}") + for i, gl in enumerate(ax.get_ygridlines(), start=1): + gl.set_gid(f"grid-y--{i}") + + # Legend block & entries + leg = ax.get_legend() + if leg is not None: + leg.set_gid("legend") + for i, txt in enumerate(leg.get_texts(), start=1): + label_slug = _slugify(txt.get_text()) + txt.set_gid(f"legend-label--{label_slug or i}") + + # Series (lines, patches) + # Lines + line_seen = {} + for ln in getattr(ax, "lines", []): + raw_label = ln.get_label() or "" + # Matplotlib uses labels beginning with "_" for non-legendable items + label = raw_label if not raw_label.startswith("_") else f"{default_series_prefix}" + slug = _slugify(label) + line_seen[slug] = line_seen.get(slug, 0) + 1 + suffix = "" if line_seen[slug] == 1 else f"-{line_seen[slug]}" + ln.set_gid(f"series--{slug}{suffix}") + + # Patches (bars, areas) + patch_seen = {} + for pt in getattr(ax, "patches", []): + label = getattr(pt, "get_label", lambda: "")() or f"{default_series_prefix}" + if isinstance(label, str) and label.startswith("_"): + label = default_series_prefix + slug = _slugify(label) + patch_seen[slug] = patch_seen.get(slug, 0) + 1 + suffix = "" if patch_seen[slug] == 1 else f"-{patch_seen[slug]}" + pt.set_gid(f"series--{slug}{suffix}") + +def _postprocess_svg_add_classes(svg_path: Path): + """Add convenient CSS classes alongside ids (e.g., class='series grid grid-x').""" + try: + import xml.etree.ElementTree as ET + ET.register_namespace("", "http://www.w3.org/2000/svg") + tree = ET.parse(svg_path) + root = tree.getroot() + for el in root.iter(): + el_id = el.attrib.get("id", "") + if not el_id: + continue + cls = [] + if el_id.startswith("figure--"): + cls.append("figure") + elif el_id.startswith("axes--"): + cls.append("axes") + elif el_id.startswith("grid-x--"): + cls += ["grid", "grid-x"] + elif el_id.startswith("grid-y--"): + cls += ["grid", "grid-y"] + elif el_id.startswith("legend"): + cls.append("legend") + elif el_id.startswith("label--x"): + cls.append("xlabel") + elif el_id.startswith("label--y"): + cls.append("ylabel") + elif 
el_id.startswith("title--"): + cls.append("title") + elif el_id.startswith("series--"): + cls.append("series") + if cls: + # Preserve any existing class (unlikely from Matplotlib) + existing = el.attrib.get("class", "") + el.set("class", (existing + " " + " ".join(cls)).strip()) + tree.write(svg_path, encoding="utf-8", xml_declaration=True) + except Exception as e: + print(f"✗ SVG postprocess (classes) skipped: {e}") + +# Monkey-patch savefig to force SVG & ensure tagging occurs even if kbt.viz saves internally. +_orig_savefig = plt.savefig +def _savefig_svg(fname, *args, **kwargs): + # Always save as SVG at a stable path for the artifact system + out = Path("latency.svg") + kwargs["format"] = "svg" + # Ensure everything we care about has ids before export + _tag_current_figure() + res = _orig_savefig(out, *args, **kwargs) + # Add helpful CSS classes on top of ids + _postprocess_svg_add_classes(out) + print(f"✓ Combined visualization saved as {out}") + return res + +plt.savefig = _savefig_svg # apply patch + +# Capture close calls in case kbt.viz() closes figures before we re-save +_orig_close = plt.close +_last_closed = {"fig": None} +def _capture_close(arg=None): + try: + if hasattr(arg, "savefig"): # looks like a Figure + _last_closed["fig"] = arg + else: + _last_closed["fig"] = plt.gcf() + finally: + return _orig_close(arg) +plt.close = _capture_close + +# --- Locate benchmark artifacts -------------------------------------------------- +cache_dirs = { + "Flash (PyTorch SDPA)": os.environ.get('UVNOTE_FILE_FLASH_ATTENTION_BENCHMARK'), + "MemEff (PyTorch SDPA)": os.environ.get('UVNOTE_FILE_MEM_EFFICIENT_ATTENTION_BENCHMARK'), + "Flash Attn 2": os.environ.get('UVNOTE_FILE_FLASH_ATTN2_BENCHMARK'), + "xFormers": os.environ.get('UVNOTE_FILE_XFORMERS_BENCHMARK'), + "SageAttention": os.environ.get('UVNOTE_FILE_SAGE_ATTENTION_BENCHMARK'), + "Compiled (default)": os.environ.get('UVNOTE_FILE_COMPILED_VARIANTS_BENCHMARK_DEFAULT'), + "Compiled (max-autotune)": os.environ.get('UVNOTE_FILE_COMPILED_VARIANTS_BENCHMARK_MAX_AUTOTUNE'), + "HF Kernels Flash Attn": os.environ.get('UVNOTE_FILE_HF_KERNELS_FLASH_ATTN_BENCHMARK'), + "HF Kernels Flash Attn3": os.environ.get('UVNOTE_FILE_HF_KERNELS_FLASH_ATTN3_BENCHMARK'), +} + +print("LOADING BENCHMARK DATA") +for name, cache_dir in cache_dirs.items(): + print(f"{name:30s}: {cache_dir}") +print() + +file_mapping = { + "Flash (PyTorch SDPA)": "attn.jsonl", + "MemEff (PyTorch SDPA)": "attn.jsonl", + "Flash Attn 2": "attn.jsonl", + "xFormers": "attn.jsonl", + "SageAttention": "attn.jsonl", + "Compiled (default)": "attn_default.jsonl", + "Compiled (max-autotune)": "attn_max_autotune.jsonl", + "HF Kernels Flash Attn": "attn.jsonl", + "HF Kernels Flash Attn3": "attn.jsonl", +} + +all_paths = [] +for name, cache_dir in cache_dirs.items(): + if cache_dir: + path = Path(cache_dir) / file_mapping[name] + if path.exists() and path.stat().st_size > 0: + all_paths.append(str(path)) + print(f"✓ Found {name}: {path}") + else: + print(f"⊘ Empty/Missing {name}: {path}") + else: + print(f"✗ No cache dir for {name}") +print() + +if not all_paths: + print("ERROR: No benchmark data files found!") + # restore patched functions before exiting + plt.savefig = _orig_savefig + plt.close = _orig_close + sys.exit(1) + +# --- Summary + Visualization ----------------------------------------------------- +print("COMBINED BENCHMARK SUMMARY\n") +kbt.summarize(all_paths) +print("\nGENERATING COMBINED VISUALIZATION\n") + +try: + # If kbt.viz saves internally, our patched savefig ensures SVG gets 
written, + # and it will carry ids/classes for CSS styling. + kbt.viz(all_paths) + # Safety net: if kbt.viz didn't save, save now. + # if not Path("latency.svg").exists(): + # _tag_current_figure() + # plt.savefig("latency.svg") + + plt.savefig("latency.svg") # ensure saved with tagging + + print("✓ SVG visualization ready: latency.svg!") +except ImportError as e: + print(f"✗ Visualization requires matplotlib: {e}") +except Exception as e: + print(f"✗ Visualization failed: {e}") +finally: + # Clean up patches to avoid side effects in later cells + plt.savefig = _orig_savefig + plt.close = _orig_close + +print() +print("ANALYSIS COMPLETE") +print(f"Total implementations analyzed: {len(all_paths)}") +print(f"\nImplementations included:") +for name, cache_dir in cache_dirs.items(): + if cache_dir: + path = Path(cache_dir) / file_mapping[name] + if path.exists() and path.stat().st_size > 0: + print(f" ✓ {name}") + + + +# Collect all benchmark data and export to CSV +all_data = {} +for name, cache_dir in cache_dirs.items(): + if cache_dir: + path = Path(cache_dir) / file_mapping[name] + if path.exists() and path.stat().st_size > 0: + with open(path, 'r') as f: + records = [json.loads(line) for line in f] + all_data[name] = records + +# Export to CSV +csv_path = Path("latency.csv") +with open(csv_path, 'w', newline='') as csvfile: + writer = csv.writer(csvfile) + + # Write header + header = ["Implementation", "Impl ID", "Workload", "Batch", "Seq Length", "Heads", "Head Dim", "Dtype", + "Mean (ms)", "P10 (ms)", "P50 (ms)", "P90 (ms)", "Reps", + # "Compile (ms)", + "Peak Mem (MB)", "Backend", "Family"] + writer.writerow(header) + + # Write data rows + for impl_name, records in all_data.items(): + for record in records: + wl = record.get('wl', {}) + lat = record.get('lat_ms', {}) + tags = record.get('tags', {}) + + row = [ + impl_name, + record.get('impl', ''), + wl.get('name', ''), + wl.get('batch', ''), + wl.get('seq_len', ''), + wl.get('heads', ''), + wl.get('head_dim', ''), + wl.get('dtype', ''), + lat.get('mean', ''), + lat.get('p10', ''), + lat.get('p50', ''), + lat.get('p90', ''), + lat.get('reps', ''), + # record.get('compile_ms', ''), + round(record.get('peak_bytes', 0) / 1024 / 1024, 2) if record.get('peak_bytes') else '', + tags.get('backend', ''), + tags.get('family', ''), + ] + writer.writerow(row) + +print(f"✓ CSV export complete: {csv_path}") +print(f"Total implementations: {len(all_data)}") +print(f"Total records: {sum(len(records) for records in all_data.values())}") diff --git a/flash_attn/results/cells/csv_export.py b/flash_attn/results/cells/csv_export.py new file mode 100644 index 0000000000000000000000000000000000000000..403d9439ac641761a863de39cb66f18d4e0e924e --- /dev/null +++ b/flash_attn/results/cells/csv_export.py @@ -0,0 +1,76 @@ +# /// script +# requires-python = ">=3.10" +# dependencies = [ +# "numpy", +# "torch", +# "kernels-benchmark-tools", +# ] +# +# [tool.uv.sources] +# kernels-benchmark-tools = { git = "https://github.com/drbh/kernels-benchmark-tools.git", branch = "main" } +# /// +import os +import csv +from pathlib import Path +import json + +# --- Locate benchmark artifacts -------------------------------------------------- +cache_dirs = { + "Flash (PyTorch SDPA)": os.environ.get('UVNOTE_FILE_FLASH_ATTENTION_BENCHMARK'), + "MemEff (PyTorch SDPA)": os.environ.get('UVNOTE_FILE_MEM_EFFICIENT_ATTENTION_BENCHMARK'), + "Flash Attn 2": os.environ.get('UVNOTE_FILE_FLASH_ATTN2_BENCHMARK'), + "xFormers": os.environ.get('UVNOTE_FILE_XFORMERS_BENCHMARK'), + 
"SageAttention": os.environ.get('UVNOTE_FILE_SAGE_ATTENTION_BENCHMARK'), + "Compiled (default)": os.environ.get('UVNOTE_FILE_COMPILED_VARIANTS_BENCHMARK_DEFAULT'), + "Compiled (max-autotune)": os.environ.get('UVNOTE_FILE_COMPILED_VARIANTS_BENCHMARK_MAX_AUTOTUNE'), + "HF Kernels Flash Attn": os.environ.get('UVNOTE_FILE_HF_KERNELS_FLASH_ATTN_BENCHMARK'), + "HF Kernels Flash Attn3": os.environ.get('UVNOTE_FILE_HF_KERNELS_FLASH_ATTN3_BENCHMARK'), +} + +file_mapping = { + "Flash (PyTorch SDPA)": "attn.jsonl", + "MemEff (PyTorch SDPA)": "attn.jsonl", + "Flash Attn 2": "attn.jsonl", + "xFormers": "attn.jsonl", + "SageAttention": "attn.jsonl", + "Compiled (default)": "attn_default.jsonl", + "Compiled (max-autotune)": "attn_max_autotune.jsonl", + "HF Kernels Flash Attn": "attn.jsonl", + "HF Kernels Flash Attn3": "attn.jsonl", +} + +# Collect all benchmark data +all_data = {} +for name, cache_dir in cache_dirs.items(): + if cache_dir: + path = Path(cache_dir) / file_mapping[name] + if path.exists() and path.stat().st_size > 0: + with open(path, 'r') as f: + records = [json.loads(line) for line in f] + all_data[name] = records + +# Export to CSV +csv_path = Path("latency.csv") +with open(csv_path, 'w', newline='') as csvfile: + writer = csv.writer(csvfile) + + # Write header + header = ["Implementation", "Sequence Length", "Latency (ms)", "Min (ms)", "Max (ms)", "Median (ms)"] + writer.writerow(header) + + # Write data rows + for impl_name, records in all_data.items(): + for record in records: + row = [ + impl_name, + record.get('seqlen', ''), + record.get('latency', ''), + record.get('min', ''), + record.get('max', ''), + record.get('median', ''), + ] + writer.writerow(row) + +print(f"✓ CSV export complete: {csv_path}") +print(f"Total implementations: {len(all_data)}") +print(f"Total records: {sum(len(records) for records in all_data.values())}") \ No newline at end of file diff --git a/flash_attn/results/combined_results.html b/flash_attn/results/combined_results.html new file mode 100644 index 0000000000000000000000000000000000000000..6af794548eaa21127d41882494b0c1ec263c9bd6 --- /dev/null +++ b/flash_attn/results/combined_results.html @@ -0,0 +1,7372 @@ + + + + + + Flash Attention Benchmark - Combined Results + + + + + + + +
+Generated on: Linux x86_64 | Linux-5.15.0-1084-aws-x86_64-with-glibc2.31

Flash Attention Benchmarks - Aggregated Results

+

This document combines benchmark results from multiple attention implementations using cross-file dependencies.
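These cross-file dependencies are resolved by the combine.py cell earlier in this diff: every upstream benchmark cell publishes its artifact directory through a UVNOTE_FILE_* environment variable, and the JSONL file inside that directory holds one record per workload. A minimal sketch of that lookup, reusing the variable and file names from combine.py (illustrative only, not part of the generated page):

```python
import json
import os
from pathlib import Path

# Subset of the sources wired up in combine.py; each env var points at the
# cache directory exported by the corresponding benchmark cell.
sources = {
    "xFormers": ("UVNOTE_FILE_XFORMERS_BENCHMARK", "attn.jsonl"),
    "Compiled (default)": ("UVNOTE_FILE_COMPILED_VARIANTS_BENCHMARK_DEFAULT", "attn_default.jsonl"),
}

for name, (env_var, filename) in sources.items():
    cache_dir = os.environ.get(env_var)
    if not cache_dir:
        continue  # dependency not produced in this run
    path = Path(cache_dir) / filename
    for line in path.read_text().splitlines():
        rec = json.loads(line)  # one benchmark record per line (wl, lat_ms, tags, ...)
        print(name, rec["wl"]["name"], f"p50={rec['lat_ms']['p50']:.3f} ms")
```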

+

Combined Summary and Visualization

+
+[Figure: latency.svg, "Attention Implementation Latency". X-axis: Workload (flux_L128, flux_L256, flux_L320, flux_L384, flux_L448, flux_L512); Y-axis: Latency P50 (ms), roughly 0.3 to 1.0. Series: torch_flash_ma, torch_mem_eff, xformers_meff, torch_flash_compiled_default, torch_flash_compiled_max_autotune, hf_kernels_flash_attn, hf_kernels_flash_attn3. Generated 2025-10-14T20:33:46 with Matplotlib v3.10.7.]
Implementation | Impl ID | Workload | Batch | Seq Length | Heads | Head Dim | Dtype | Mean (ms) | P10 (ms) | P50 (ms) | P90 (ms) | Reps | Peak Mem (MB) | Backend | Family
Flash (PyTorch SDPA) | torch_flash_ma | flux_L128 | 1 | 1152 | 24 | 128 | bfloat16 | 0.407123202085495 | 0.40537598729133606 | 0.40755200386047363 | 0.407584011554718 | 5 | 83.38 | FLASH | torch-sdpa
Flash (PyTorch SDPA) | torch_flash_ma | flux_L256 | 1 | 1280 | 24 | 128 | bfloat16 | 0.5235007882118226 | 0.5212159752845764 | 0.5232639908790588 | 0.523360013961792 | 5 | 90.62 | FLASH | torch-sdpa
Flash (PyTorch SDPA) | torch_flash_ma | flux_L320 | 1 | 1344 | 24 | 128 | bfloat16 | 0.545849597454071 | 0.5418559908866882 | 0.5468159914016724 | 0.5469120144844055 | 5 | 95.06 | FLASH | torch-sdpa
Flash (PyTorch SDPA) | torch_flash_ma | flux_L384 | 1 | 1408 | 24 | 128 | bfloat16 | 0.5892416119575501 | 0.5867519974708557 | 0.5888000130653381 | 0.5888000130653381 | 5 | 99.88 | FLASH | torch-sdpa
Flash (PyTorch SDPA) | torch_flash_ma | flux_L448 | 1 | 1472 | 24 | 128 | bfloat16 | 0.6449280023574829 | 0.6430720090866089 | 0.6442239880561829 | 0.6450240015983582 | 5 | 103.81 | FLASH | torch-sdpa
Flash (PyTorch SDPA) | torch_flash_ma | flux_L512 | 1 | 1536 | 24 | 128 | bfloat16 | 0.6823423862457275 | 0.6777600049972534 | 0.6809599995613098 | 0.6818559765815735 | 5 | 109.12 | FLASH | torch-sdpa
MemEff (PyTorch SDPA) | torch_mem_eff | flux_L128 | 1 | 1152 | 24 | 128 | bfloat16 | 0.48371200561523436 | 0.4821760058403015 | 0.4833280146121979 | 0.4853760004043579 | 5 | 83.38 | EFFICIENT | torch-sdpa
MemEff (PyTorch SDPA) | torch_mem_eff | flux_L256 | 1 | 1280 | 24 | 128 | bfloat16 | 0.6268800020217895 | 0.6246399879455566 | 0.6266880035400391 | 0.6286720037460327 | 5 | 90.62 | EFFICIENT | torch-sdpa
MemEff (PyTorch SDPA) | torch_mem_eff | flux_L320 | 1 | 1344 | 24 | 128 | bfloat16 | 0.699776005744934 | 0.6973440051078796 | 0.7004160284996033 | 0.7004479765892029 | 5 | 95.94 | EFFICIENT | torch-sdpa
MemEff (PyTorch SDPA) | torch_mem_eff | flux_L384 | 1 | 1408 | 24 | 128 | bfloat16 | 0.8333312034606933 | 0.8284159898757935 | 0.8325120210647583 | 0.8376320004463196 | 5 | 100.0 | EFFICIENT | torch-sdpa
MemEff (PyTorch SDPA) | torch_mem_eff | flux_L448 | 1 | 1472 | 24 | 128 | bfloat16 | 0.9533439993858337 | 0.9502720236778259 | 0.9512959718704224 | 0.9572479724884033 | 5 | 103.81 | EFFICIENT | torch-sdpa
MemEff (PyTorch SDPA) | torch_mem_eff | flux_L512 | 1 | 1536 | 24 | 128 | bfloat16 | 1.0066367864608765 | 1.0024960041046143 | 1.0045440196990967 | 1.0097919702529907 | 5 | 109.12 | EFFICIENT | torch-sdpa
xFormers | xformers_meff | flux_L128 | 1 | 1152 | 24 | 128 | bfloat16 | 0.3452928066253662 | 0.3389439880847931 | 0.3461120128631592 | 0.3461120128631592 | 5 | 83.38 | memory_efficient | xformers
xFormers | xformers_meff | flux_L256 | 1 | 1280 | 24 | 128 | bfloat16 | 0.41234560012817384 | 0.40959998965263367 | 0.41280001401901245 | 0.41286399960517883 | 5 | 90.62 | memory_efficient | xformers
xFormers | xformers_meff | flux_L320 | 1 | 1344 | 24 | 128 | bfloat16 | 0.4366208016872406 | 0.4310399889945984 | 0.4331519901752472 | 0.4362240135669708 | 5 | 95.06 | memory_efficient | xformers
xFormers | xformers_meff | flux_L384 | 1 | 1408 | 24 | 128 | bfloat16 | 0.4450624048709869 | 0.4359680116176605 | 0.44361600279808044 | 0.447488009929657 | 5 | 99.88 | memory_efficient | xformers
xFormers | xformers_meff | flux_L448 | 1 | 1472 | 24 | 128 | bfloat16 | 0.4750400006771088 | 0.4711039960384369 | 0.47513601183891296 | 0.4763199985027313 | 5 | 103.81 | memory_efficient | xformers
xFormers | xformers_meff | flux_L512 | 1 | 1536 | 24 | 128 | bfloat16 | 0.5009407997131348 | 0.49663999676704407 | 0.4997119903564453 | 0.5038080215454102 | 5 | 109.12 | memory_efficient | xformers
Compiled (default) | torch_flash_compiled_default | flux_L128 | 1 | 1152 | 24 | 128 | bfloat16 | 0.3856383919715881 | 0.3563520014286041 | 0.35942399501800537 | 0.3624959886074066 | 5 | 83.38 | FLASH | torch-sdpa
Compiled (default) | torch_flash_compiled_default | flux_L256 | 1 | 1280 | 24 | 128 | bfloat16 | 0.4982912003993988 | 0.4926080107688904 | 0.49663999676704407 | 0.5017600059509277 | 5 | 90.62 | FLASH | torch-sdpa
Compiled (default) | torch_flash_compiled_default | flux_L320 | 1 | 1344 | 24 | 128 | bfloat16 | 0.5369919896125793 | 0.5335040092468262 | 0.5366079807281494 | 0.5386239886283875 | 5 | 95.25 | FLASH | torch-sdpa
Compiled (default) | torch_flash_compiled_default | flux_L384 | 1 | 1408 | 24 | 128 | bfloat16 | 0.5841408014297486 | 0.5775359869003296 | 0.5868800282478333 | 0.5877760052680969 | 5 | 99.88 | FLASH | torch-sdpa
Compiled (default) | torch_flash_compiled_default | flux_L448 | 1 | 1472 | 24 | 128 | bfloat16 | 0.6184704065322876 | 0.6072319746017456 | 0.6113280057907104 | 0.6144000291824341 | 5 | 103.81 | FLASH | torch-sdpa
Compiled (default) | torch_flash_compiled_default | flux_L512 | 1 | 1536 | 24 | 128 | bfloat16 | 0.6428672075271606 | 0.6399999856948853 | 0.6430720090866089 | 0.6430720090866089 | 5 | 109.12 | FLASH | torch-sdpa
Compiled (max-autotune) | torch_flash_compiled_max_autotune | flux_L128 | 1 | 1152 | 24 | 128 | bfloat16 | 0.40020479559898375 | 0.3665919899940491 | 0.3768320083618164 | 0.41171199083328247 | 5 | 81.75 | FLASH | torch-sdpa
Compiled (max-autotune) | torch_flash_compiled_max_autotune | flux_L256 | 1 | 1280 | 24 | 128 | bfloat16 | 0.5535807967185974 | 0.5160959959030151 | 0.5489599704742432 | 0.5631359815597534 | 5 | 92.88 | FLASH | torch-sdpa
Compiled (max-autotune) | torch_flash_compiled_max_autotune | flux_L320 | 1 | 1344 | 24 | 128 | bfloat16 | 0.6143999934196472 | 0.562175989151001 | 0.6144000291824341 | 0.6318079829216003 | 5 | 95.13 | FLASH | torch-sdpa
Compiled (max-autotune) | torch_flash_compiled_max_autotune | flux_L384 | 1 | 1408 | 24 | 128 | bfloat16 | 0.6754495978355408 | 0.6512640118598938 | 0.6584320068359375 | 0.6799359917640686 | 5 | 97.13 | FLASH | torch-sdpa
Compiled (max-autotune) | torch_flash_compiled_max_autotune | flux_L448 | 1 | 1472 | 24 | 128 | bfloat16 | 0.7210752129554748 | 0.6973119974136353 | 0.7014080286026001 | 0.7229440212249756 | 5 | 99.0 | FLASH | torch-sdpa
Compiled (max-autotune) | torch_flash_compiled_max_autotune | flux_L512 | 1 | 1536 | 24 | 128 | bfloat16 | 0.7735359907150269 | 0.7485439777374268 | 0.7557439804077148 | 0.7710719704627991 | 5 | 101.63 | FLASH | torch-sdpa
HF Kernels Flash Attn | hf_kernels_flash_attn | flux_L128 | 1 | 1152 | 24 | 128 | bfloat16 | 0.2456959992647171 | 0.24371199309825897 | 0.24566400051116943 | 0.2457599937915802 | 5 | 83.38 | flash-attn | hf-kernels
HF Kernels Flash Attn | hf_kernels_flash_attn | flux_L256 | 1 | 1280 | 24 | 128 | bfloat16 | 0.3215551972389221 | 0.3164159953594208 | 0.319487988948822 | 0.32051199674606323 | 5 | 90.62 | flash-attn | hf-kernels
HF Kernels Flash Attn | hf_kernels_flash_attn | flux_L320 | 1 | 1344 | 24 | 128 | bfloat16 | 0.3384703993797302 | 0.33670398592948914 | 0.33792001008987427 | 0.33983999490737915 | 5 | 95.06 | flash-attn | hf-kernels
HF Kernels Flash Attn | hf_kernels_flash_attn | flux_L384 | 1 | 1408 | 24 | 128 | bfloat16 | 0.3510208010673523 | 0.3481599986553192 | 0.3491840064525604 | 0.35225600004196167 | 5 | 99.88 | flash-attn | hf-kernels
HF Kernels Flash Attn | hf_kernels_flash_attn | flux_L448 | 1 | 1472 | 24 | 128 | bfloat16 | 0.3829823970794678 | 0.38095998764038086 | 0.3829759955406189 | 0.3840000033378601 | 5 | 103.81 | flash-attn | hf-kernels
HF Kernels Flash Attn | hf_kernels_flash_attn | flux_L512 | 1 | 1536 | 24 | 128 | bfloat16 | 0.4259391903877258 | 0.4227519929409027 | 0.4249599874019623 | 0.4259839951992035 | 5 | 109.12 | flash-attn | hf-kernels
HF Kernels Flash Attn3 | hf_kernels_flash_attn3 | flux_L128 | 1 | 1152 | 24 | 128 | bfloat16 | 0.2755008041858673 | 0.26736000180244446 | 0.27561599016189575 | 0.27955201268196106 | 5 | 83.38 | flash-attn3 | hf-kernels
HF Kernels Flash Attn3 | hf_kernels_flash_attn3 | flux_L256 | 1 | 1280 | 24 | 128 | bfloat16 | 0.3397440016269684 | 0.3368000090122223 | 0.3399679958820343 | 0.34191998839378357 | 5 | 90.62 | flash-attn3 | hf-kernels
HF Kernels Flash Attn3 | hf_kernels_flash_attn3 | flux_L320 | 1 | 1344 | 24 | 128 | bfloat16 | 0.36019839644432067 | 0.3563520014286041 | 0.3604480028152466 | 0.36137598752975464 | 5 | 95.06 | flash-attn3 | hf-kernels
HF Kernels Flash Attn3 | hf_kernels_flash_attn3 | flux_L384 | 1 | 1408 | 24 | 128 | bfloat16 | 0.37342079877853396 | 0.3718400001525879 | 0.37379199266433716 | 0.3746879994869232 | 5 | 99.88 | flash-attn3 | hf-kernels
HF Kernels Flash Attn3 | hf_kernels_flash_attn3 | flux_L448 | 1 | 1472 | 24 | 128 | bfloat16 | 0.4024448037147522 | 0.3993600010871887 | 0.4014720022678375 | 0.4034560024738312 | 5 | 103.81 | flash-attn3 | hf-kernels
HF Kernels Flash Attn3 | hf_kernels_flash_attn3 | flux_L512 | 1 | 1536 | 24 | 128 | bfloat16 | 0.4305088043212891 | 0.4270080029964447 | 0.4291520118713379 | 0.4331519901752472 | 5 | 109.12 | flash-attn3 | hf-kernels
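The table repeats the contents of latency.csv, which the combine cell writes with the header shown above. A quick way to rank implementations for a single workload straight from that CSV, assuming latency.csv is in the working directory (a sketch, not part of the generated page):

```python
import csv

# Rank implementations by P50 latency for the flux_L512 workload, using the
# column names that combine.py writes into latency.csv.
with open("latency.csv", newline="") as f:
    rows = [r for r in csv.DictReader(f) if r["Workload"] == "flux_L512"]

for r in sorted(rows, key=lambda r: float(r["P50 (ms)"])):
    print(f"{r['Implementation']:<25} {float(r['P50 (ms)']):.3f} ms ({r['Backend']})")
```

On these numbers, hf_kernels_flash_attn comes out first at flux_L512 with a p50 of about 0.42 ms, against roughly 0.64 to 0.76 ms for the compiled SDPA variants.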
+Cell: combine | 32.84s
\ No newline at end of file diff --git a/flash_attn/results/index.html b/flash_attn/results/index.html new file mode 100644 index 0000000000000000000000000000000000000000..b87b6002f4b781572dbb50f91850e50ee98130ab --- /dev/null +++ b/flash_attn/results/index.html @@ -0,0 +1,88 @@
+Index of /flash_attn/results
\ No newline at end of file diff --git a/index.html b/index.html new file mode 100644 index 0000000000000000000000000000000000000000..df44040e2dd9e1e4a0fc2d5ee08453d4b9953f11 --- /dev/null +++ b/index.html @@ -0,0 +1,85 @@

Index of /
\ No newline at end of file