Flash Attention Implementation

GPU Info

import subprocess

print(subprocess.run(["nvidia-smi"], capture_output=True, text=True).stdout)
Mon Oct 27 14:45:45 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 570.195.03             Driver Version: 570.195.03     CUDA Version: 12.8     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA L40S                    On  |   00000000:4D:00.0 Off |                    0 |
| N/A   31C    P0            135W /  350W |       0MiB /  46068MiB |    100%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+

+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI              PID   Type   Process name                        GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|  No running processes found                                                             |
+-----------------------------------------------------------------------------------------+
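The full `nvidia-smi` dump above is easy to read but awkward to consume from code. A minimal sketch, assuming you only need the GPU name, driver version and total memory: `nvidia-smi --query-gpu=name,driver_version,memory.total --format=csv,noheader` prints one comma-separated line per GPU, which can be parsed as below. The helper names `parse_gpu_csv` and `query_gpus` are hypothetical, and the sample line is assembled from the table above rather than captured from a real run.

```python
import subprocess


def parse_gpu_csv(text):
    """Parse `nvidia-smi --query-gpu=name,driver_version,memory.total
    --format=csv,noheader` output into one dict per GPU."""
    gpus = []
    for line in text.strip().splitlines():
        name, driver, mem = (field.strip() for field in line.split(","))
        gpus.append({
            "name": name,
            "driver_version": driver,
            # memory.total is reported with units, e.g. "46068 MiB"
            "memory_total_mib": int(mem.split()[0]),
        })
    return gpus


def query_gpus():
    # Requires an NVIDIA driver with nvidia-smi on PATH.
    out = subprocess.run(
        ["nvidia-smi", "--query-gpu=name,driver_version,memory.total",
         "--format=csv,noheader"],
        capture_output=True, text=True, check=True,
    ).stdout
    return parse_gpu_csv(out)


# Sample line reconstructed from the nvidia-smi table above (not a captured run).
sample = "NVIDIA L40S, 570.195.03, 46068 MiB"
print(parse_gpu_csv(sample))
```

Querying only the fields you need also keeps the output stable if the default table layout changes between driver versions.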

Flash Attention Benchmark

# /// script
# requires-python = ">=3.10"
# dependencies = [
#     "numpy",
#     "torch==2.8.0",
#     "kernels-benchmark-tools",
# ]
#
# [tool.uv.sources]
# kernels-benchmark-tools = { path = "../../../../../tools", editable = true }
# ///
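import torch
from torch.nn.attention import SDPBackend, sdpa_kernel
from kernels_benchmark_tools import KernelTypeEnum, run_benchmark


def torch_flash(q, k, v):
    # SDPA expects (batch, heads, seq_len, head_dim); the transpose(1, 2)
    # implies the harness passes (batch, seq_len, heads, head_dim).
    qt, kt, vt = (x.transpose(1, 2).contiguous() for x in (q, k, v))
    # Restrict SDPA to the FlashAttention backend only.
    with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
        o = torch.nn.functional.scaled_dot_product_attention(qt, kt, vt)
    # Return the output in the caller's (batch, seq_len, heads, head_dim) layout.
    return o.transpose(1, 2).contiguous()


# Note: torch_flash is not wrapped in torch.compile in this cell, even though
# impl_tags records "compile": "max-autotune".
run_benchmark(
    kernel_type=KernelTypeEnum.ATTENTION,
    impl_name="torch_flash_ma",
    impl_tags={"family": "torch-sdpa", "backend": "FLASH", "compile": "max-autotune"},
    impl_func=torch_flash,
)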
Running attention benchmark on cuda with 6 workloads.

======================================================================
PROFILE TRACE: torch_flash_ma | cuda_attn_L128_bfloat16
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                         torch_flash_ma         0.00%       0.000us         0.00%       0.000us       0.000us       3.610ms       101.76%       3.610ms       3.610ms             1  
                                         torch_flash_ma         6.54%     340.396us        46.01%       2.394ms       2.394ms       0.000us         0.00%       3.588ms       3.588ms             1  
                     aten::scaled_dot_product_attention         0.84%      43.810us         4.24%     220.593us      73.531us       0.000us         0.00%       2.829ms     943.091us             3  
              aten::_scaled_dot_product_flash_attention         0.51%      26.609us         3.40%     176.783us      58.928us       0.000us         0.00%       2.829ms     943.091us             3  
                         aten::_flash_attention_forward         0.74%      38.381us         2.45%     127.692us      42.564us       2.829ms        79.74%       2.829ms     943.091us             3  
void pytorch_flash::flash_fwd_kernel<Flash_fwd_kerne...         0.00%       0.000us         0.00%       0.000us       0.000us       2.829ms        79.74%       2.829ms     943.091us             3  
                                       aten::contiguous         0.29%      15.001us        33.86%       1.762ms     146.802us       0.000us         0.00%     759.072us      63.256us            12  
                                            aten::clone         0.76%      39.432us        33.57%       1.747ms     145.552us       0.000us         0.00%     759.072us      63.256us            12  
                                            aten::copy_         1.71%      88.801us        31.26%       1.626ms     135.534us     718.688us        20.26%     759.072us      63.256us            12  
void at::native::elementwise_kernel<128, 4, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     718.688us        20.26%     718.688us      59.891us            12  
                                Activity Buffer Request        27.68%       1.440ms        27.68%       1.440ms       1.440ms      40.384us         1.14%      40.384us      40.384us             1  
                                        aten::transpose         1.34%      69.973us         1.80%      93.503us       3.896us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::as_strided         0.45%      23.530us         0.45%      23.530us       0.980us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::empty_like         0.50%      25.908us         1.97%     102.319us       6.821us       0.000us         0.00%       0.000us       0.000us            15  
                                            aten::empty         1.75%      91.041us         1.75%      91.041us       3.793us       0.000us         0.00%       0.000us       0.000us            24  
                                       cudaLaunchKernel         2.36%     123.031us         2.36%     123.031us       8.202us       0.000us         0.00%       0.000us       0.000us            15  
                                    aten::empty_strided         0.31%      16.010us         0.31%      16.010us       5.337us       0.000us         0.00%       0.000us       0.000us             3  
                                 cudaDeviceGetAttribute         0.05%       2.700us         0.05%       2.700us       0.450us       0.000us         0.00%       0.000us       0.000us             6  
                                   cudaFuncSetAttribute         0.17%       8.980us         0.17%       8.980us       2.993us       0.000us         0.00%       0.000us       0.000us             3  
                                  cudaDeviceSynchronize        53.99%       2.809ms        53.99%       2.809ms       2.809ms       0.000us         0.00%       0.000us       0.000us             1  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 5.203ms
Self CUDA time total: 3.548ms



======================================================================
PROFILE TRACE: torch_flash_ma | cuda_attn_L256_bfloat16
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                         torch_flash_ma         5.17%     272.917us        42.06%       2.218ms       2.218ms       0.000us         0.00%       3.821ms       3.821ms             1  
                                         torch_flash_ma         0.00%       0.000us         0.00%       0.000us       0.000us       3.777ms       100.28%       3.777ms       3.777ms             1  
                     aten::scaled_dot_product_attention         0.53%      27.761us         3.55%     187.333us      62.444us       0.000us         0.00%       3.004ms       1.001ms             3  
              aten::_scaled_dot_product_flash_attention         0.37%      19.492us         3.03%     159.572us      53.191us       0.000us         0.00%       3.004ms       1.001ms             3  
                         aten::_flash_attention_forward         0.75%      39.549us         2.23%     117.371us      39.124us       3.004ms        79.75%       3.004ms       1.001ms             3  
void pytorch_flash::flash_fwd_kernel<Flash_fwd_kerne...         0.00%       0.000us         0.00%       0.000us       0.000us       3.004ms        79.75%       3.004ms       1.001ms             3  
                                       aten::contiguous         0.20%      10.320us        32.06%       1.691ms     140.876us       0.000us         0.00%     817.314us      68.110us            12  
                                            aten::clone         0.55%      29.048us        31.86%       1.680ms     140.016us       0.000us         0.00%     817.314us      68.110us            12  
                                            aten::copy_         1.64%      86.662us        30.11%       1.588ms     132.347us     762.658us        20.25%     817.314us      68.110us            12  
void at::native::elementwise_kernel<128, 4, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     762.658us        20.25%     762.658us      63.555us            12  
                                Activity Buffer Request        26.84%       1.415ms        26.84%       1.415ms       1.415ms      54.656us         1.45%      54.656us      54.656us             1  
                                        aten::transpose         1.36%      71.528us         1.71%      90.179us       3.757us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::as_strided         0.35%      18.651us         0.35%      18.651us       0.777us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::empty_like         0.38%      19.801us         1.55%      81.840us       5.456us       0.000us         0.00%       0.000us       0.000us            15  
                                            aten::empty         1.46%      77.040us         1.46%      77.040us       3.210us       0.000us         0.00%       0.000us       0.000us            24  
                                       cudaLaunchKernel         2.07%     108.973us         2.07%     108.973us       7.265us       0.000us         0.00%       0.000us       0.000us            15  
                                    aten::empty_strided         0.26%      13.940us         0.26%      13.940us       4.647us       0.000us         0.00%       0.000us       0.000us             3  
                                 cudaDeviceGetAttribute         0.06%       2.910us         0.06%       2.910us       0.485us       0.000us         0.00%       0.000us       0.000us             6  
                                   cudaFuncSetAttribute         0.08%       4.240us         0.08%       4.240us       1.413us       0.000us         0.00%       0.000us       0.000us             3  
                                  cudaDeviceSynchronize        57.94%       3.056ms        57.94%       3.056ms       3.056ms       0.000us         0.00%       0.000us       0.000us             1  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 5.274ms
Self CUDA time total: 3.767ms



======================================================================
PROFILE TRACE: torch_flash_ma | cuda_attn_L320_bfloat16
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                         torch_flash_ma         4.99%     269.576us        41.89%       2.262ms       2.262ms       0.000us         0.00%       3.875ms       3.875ms             1  
                                         torch_flash_ma         0.00%       0.000us         0.00%       0.000us       0.000us       3.827ms       100.29%       3.827ms       3.827ms             1  
                     aten::scaled_dot_product_attention         0.50%      27.011us         3.47%     187.262us      62.421us       0.000us         0.00%       3.037ms       1.012ms             3  
              aten::_scaled_dot_product_flash_attention         0.35%      18.851us         2.97%     160.251us      53.417us       0.000us         0.00%       3.037ms       1.012ms             3  
                         aten::_flash_attention_forward         0.72%      39.000us         2.20%     118.550us      39.517us       3.037ms        79.57%       3.037ms       1.012ms             3  
void pytorch_flash::flash_fwd_kernel<Flash_fwd_kerne...         0.00%       0.000us         0.00%       0.000us       0.000us       3.037ms        79.57%       3.037ms       1.012ms             3  
                                       aten::contiguous         0.18%       9.780us        32.51%       1.755ms     146.253us       0.000us         0.00%     838.461us      69.872us            12  
                                            aten::clone         0.54%      29.119us        32.32%       1.745ms     145.438us       0.000us         0.00%     838.461us      69.872us            12  
                                            aten::copy_         1.56%      84.200us        30.52%       1.648ms     137.328us     779.741us        20.43%     838.461us      69.872us            12  
void at::native::elementwise_kernel<128, 4, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     779.741us        20.43%     779.741us      64.978us            12  
                                Activity Buffer Request        27.41%       1.480ms        27.41%       1.480ms       1.480ms      58.720us         1.54%      58.720us      58.720us             1  
                                        aten::transpose         1.00%      54.180us         1.34%      72.500us       3.021us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::as_strided         0.34%      18.320us         0.34%      18.320us       0.763us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::empty_like         0.36%      19.560us         1.66%      89.381us       5.959us       0.000us         0.00%       0.000us       0.000us            15  
                                            aten::empty         1.53%      82.821us         1.53%      82.821us       3.451us       0.000us         0.00%       0.000us       0.000us            24  
                                       cudaLaunchKernel         1.99%     107.272us         1.99%     107.272us       7.151us       0.000us         0.00%       0.000us       0.000us            15  
                                    aten::empty_strided         0.30%      16.380us         0.30%      16.380us       5.460us       0.000us         0.00%       0.000us       0.000us             3  
                                 cudaDeviceGetAttribute         0.03%       1.850us         0.03%       1.850us       0.308us       0.000us         0.00%       0.000us       0.000us             6  
                                   cudaFuncSetAttribute         0.07%       3.830us         0.07%       3.830us       1.277us       0.000us         0.00%       0.000us       0.000us             3  
                                  cudaDeviceSynchronize        58.11%       3.138ms        58.11%       3.138ms       3.138ms       0.000us         0.00%       0.000us       0.000us             1  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 5.399ms
Self CUDA time total: 3.817ms



======================================================================
PROFILE TRACE: torch_flash_ma | cuda_attn_L384_bfloat16
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                         torch_flash_ma         4.76%     268.853us        43.13%       2.435ms       2.435ms       0.000us         0.00%       3.964ms       3.964ms             1  
                                         torch_flash_ma         0.00%       0.000us         0.00%       0.000us       0.000us       3.917ms       100.30%       3.917ms       3.917ms             1  
                     aten::scaled_dot_product_attention         0.49%      27.720us         3.46%     195.333us      65.111us       0.000us         0.00%       3.118ms       1.039ms             3  
              aten::_scaled_dot_product_flash_attention         0.34%      19.471us         2.97%     167.613us      55.871us       0.000us         0.00%       3.118ms       1.039ms             3  
                         aten::_flash_attention_forward         0.70%      39.530us         2.23%     125.742us      41.914us       3.118ms        79.84%       3.118ms       1.039ms             3  
void pytorch_flash::flash_fwd_kernel<Flash_fwd_kerne...         0.00%       0.000us         0.00%       0.000us       0.000us       3.118ms        79.84%       3.118ms       1.039ms             3  
                                       aten::contiguous         0.17%       9.719us        34.03%       1.921ms     160.116us       0.000us         0.00%     845.599us      70.467us            12  
                                            aten::clone         0.52%      29.239us        33.85%       1.912ms     159.306us       0.000us         0.00%     845.599us      70.467us            12  
                                            aten::copy_         1.54%      86.910us        32.19%       1.818ms     151.460us     787.167us        20.16%     845.599us      70.467us            12  
void at::native::elementwise_kernel<128, 4, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     787.167us        20.16%     787.167us      65.597us            12  
                                Activity Buffer Request        25.41%       1.435ms        25.41%       1.435ms       1.435ms      58.432us         1.50%      58.432us      58.432us             1  
                                        aten::transpose         0.96%      54.080us         1.28%      72.141us       3.006us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::as_strided         0.32%      18.061us         0.32%      18.061us       0.753us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::empty_like         0.35%      19.512us         1.49%      84.134us       5.609us       0.000us         0.00%       0.000us       0.000us            15  
                                            aten::empty         1.53%      86.581us         1.53%      86.581us       3.608us       0.000us         0.00%       0.000us       0.000us            24  
                                       cudaLaunchKernel         5.66%     319.547us         5.66%     319.547us      21.303us       0.000us         0.00%       0.000us       0.000us            15  
                                    aten::empty_strided         0.26%      14.430us         0.26%      14.430us       4.810us       0.000us         0.00%       0.000us       0.000us             3  
                                 cudaDeviceGetAttribute         0.05%       2.740us         0.05%       2.740us       0.457us       0.000us         0.00%       0.000us       0.000us             6  
                                   cudaFuncSetAttribute         0.07%       4.201us         0.07%       4.201us       1.400us       0.000us         0.00%       0.000us       0.000us             3  
                                  cudaDeviceSynchronize        56.87%       3.211ms        56.87%       3.211ms       3.211ms       0.000us         0.00%       0.000us       0.000us             1  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 5.647ms
Self CUDA time total: 3.906ms



======================================================================
PROFILE TRACE: torch_flash_ma | cuda_attn_L448_bfloat16
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                         torch_flash_ma         5.25%     320.614us        40.80%       2.490ms       2.490ms       0.000us         0.00%       4.428ms       4.428ms             1  
                                         torch_flash_ma         0.00%       0.000us         0.00%       0.000us       0.000us       4.377ms       100.25%       4.377ms       4.377ms             1  
                     aten::scaled_dot_product_attention         0.44%      26.800us         3.27%     199.713us      66.571us       0.000us         0.00%       3.558ms       1.186ms             3  
              aten::_scaled_dot_product_flash_attention         0.32%      19.239us         2.83%     172.913us      57.638us       0.000us         0.00%       3.558ms       1.186ms             3  
                         aten::_flash_attention_forward         0.64%      38.816us         2.13%     129.963us      43.321us       3.558ms        81.48%       3.558ms       1.186ms             3  
void pytorch_flash::flash_fwd_kernel<Flash_fwd_kerne...         0.00%       0.000us         0.00%       0.000us       0.000us       3.558ms        81.48%       3.558ms       1.186ms             3  
                                       aten::contiguous         0.17%      10.568us        31.48%       1.922ms     160.138us       0.000us         0.00%     870.015us      72.501us            12  
                                            aten::clone         0.48%      29.552us        31.31%       1.911ms     159.257us       0.000us         0.00%     870.015us      72.501us            12  
                                            aten::copy_         1.37%      83.622us        29.71%       1.813ms     151.123us     808.479us        18.52%     870.015us      72.501us            12  
void at::native::elementwise_kernel<128, 4, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     808.479us        18.52%     808.479us      67.373us            12  
                                Activity Buffer Request        24.07%       1.469ms        24.07%       1.469ms       1.469ms      61.536us         1.41%      61.536us      61.536us             1  
                                        aten::transpose         0.88%      53.494us         1.18%      71.893us       2.996us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::as_strided         0.30%      18.399us         0.30%      18.399us       0.767us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::empty_like         0.45%      27.388us         1.61%      98.450us       6.563us       0.000us         0.00%       0.000us       0.000us            15  
                                            aten::empty         1.35%      82.243us         1.35%      82.243us       3.427us       0.000us         0.00%       0.000us       0.000us            24  
                                       cudaLaunchKernel         4.68%     285.943us         4.68%     285.943us      19.063us       0.000us         0.00%       0.000us       0.000us            15  
                                    aten::empty_strided         0.29%      17.820us         0.29%      17.820us       5.940us       0.000us         0.00%       0.000us       0.000us             3  
                                 cudaDeviceGetAttribute         0.04%       2.328us         0.04%       2.328us       0.388us       0.000us         0.00%       0.000us       0.000us             6  
                                   cudaFuncSetAttribute         0.07%       4.078us         0.07%       4.078us       1.359us       0.000us         0.00%       0.000us       0.000us             3  
                                  cudaDeviceSynchronize        59.20%       3.614ms        59.20%       3.614ms       3.614ms       0.000us         0.00%       0.000us       0.000us             1  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 6.104ms
Self CUDA time total: 4.366ms



======================================================================
PROFILE TRACE: torch_flash_ma | cuda_attn_L512_bfloat16
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                         torch_flash_ma         4.45%     272.752us        38.96%       2.390ms       2.390ms       0.000us         0.00%       4.517ms       4.517ms             1  
                                         torch_flash_ma         0.00%       0.000us         0.00%       0.000us       0.000us       4.467ms       100.24%       4.467ms       4.467ms             1  
                     aten::scaled_dot_product_attention         0.45%      27.641us         3.22%     197.213us      65.738us       0.000us         0.00%       3.636ms       1.212ms             3  
              aten::_scaled_dot_product_flash_attention         0.32%      19.841us         2.76%     169.572us      56.524us       0.000us         0.00%       3.636ms       1.212ms             3  
                         aten::_flash_attention_forward         0.71%      43.282us         2.06%     126.092us      42.031us       3.636ms        81.58%       3.636ms       1.212ms             3  
void pytorch_flash::flash_fwd_kernel<Flash_fwd_kerne...         0.00%       0.000us         0.00%       0.000us       0.000us       3.636ms        81.58%       3.636ms       1.212ms             3  
                                       aten::contiguous         0.18%      11.069us        30.46%       1.869ms     155.711us       0.000us         0.00%     881.085us      73.424us            12  
                                            aten::clone         0.50%      30.953us        30.28%       1.857ms     154.789us       0.000us         0.00%     881.085us      73.424us            12  
                                            aten::copy_         1.39%      85.529us        28.66%       1.758ms     146.482us     820.670us        18.42%     881.085us      73.424us            12  
void at::native::elementwise_kernel<128, 4, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     820.670us        18.42%     820.670us      68.389us            12  
                                Activity Buffer Request        23.40%       1.435ms        23.40%       1.435ms       1.435ms      60.415us         1.36%      60.415us      60.415us             1  
                                        aten::transpose         0.92%      56.138us         1.22%      75.130us       3.130us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::as_strided         0.31%      18.992us         0.31%      18.992us       0.791us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::empty_like         0.33%      20.287us         1.48%      90.810us       6.054us       0.000us         0.00%       0.000us       0.000us            15  
                                            aten::empty         1.36%      83.613us         1.36%      83.613us       3.484us       0.000us         0.00%       0.000us       0.000us            24  
                                       cudaLaunchKernel         4.26%     261.175us         4.26%     261.175us      17.412us       0.000us         0.00%       0.000us       0.000us            15  
                                    aten::empty_strided         0.28%      17.260us         0.28%      17.260us       5.753us       0.000us         0.00%       0.000us       0.000us             3  
                                 cudaDeviceGetAttribute         0.03%       1.850us         0.03%       1.850us       0.308us       0.000us         0.00%       0.000us       0.000us             6  
                                   cudaFuncSetAttribute         0.07%       4.250us         0.07%       4.250us       1.417us       0.000us         0.00%       0.000us       0.000us             3  
                                  cudaDeviceSynchronize        61.04%       3.744ms        61.04%       3.744ms       3.744ms       0.000us         0.00%       0.000us       0.000us             1  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 6.134ms
Self CUDA time total: 4.456ms


impl            wl                        p50(ms)  ok
torch_flash_ma  cuda_attn_L128_bfloat16      1.22  True
torch_flash_ma  cuda_attn_L256_bfloat16      1.27  True
torch_flash_ma  cuda_attn_L320_bfloat16      1.31  True
torch_flash_ma  cuda_attn_L384_bfloat16      1.34  True
torch_flash_ma  cuda_attn_L448_bfloat16      1.48  True
torch_flash_ma  cuda_attn_L512_bfloat16      1.52  True
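The summary table and the six traces can be reduced to a few ratios. The sketch below transcribes, per workload, the p50 latency and the Self CUDA time of the flash kernel and of the `elementwise_kernel` copies (issued by `aten::copy_` under the `.contiguous()` calls), then computes latency relative to L=128, per-call kernel time, and the copy share of GPU time. Each trace covers 3 calls, per the "# of Calls" column; the variable and function names are illustrative only.

```python
# Numbers transcribed from the profiler traces and summary table above.
workloads = {
    # seq_len: (p50_ms, flash_kernel_self_cuda_us, copy_kernel_self_cuda_us)
    128: (1.22, 2829.0, 718.688),
    256: (1.27, 3004.0, 762.658),
    320: (1.31, 3037.0, 779.741),
    384: (1.34, 3118.0, 787.167),
    448: (1.48, 3558.0, 808.479),
    512: (1.52, 3636.0, 820.670),
}

# Each trace profiles 3 calls of torch_flash_ma ("# of Calls" column).
CALLS_PER_TRACE = 3


def summarize(rows, calls=CALLS_PER_TRACE):
    """Derive simple ratios from (p50, flash_us, copy_us) per sequence length."""
    base_len = min(rows)
    base_p50 = rows[base_len][0]
    summary = {}
    for seq_len, (p50, flash_us, copy_us) in sorted(rows.items()):
        summary[seq_len] = {
            "p50_vs_base": p50 / base_p50,
            "flash_us_per_call": flash_us / calls,
            "copy_us_per_call": copy_us / calls,
            # Flash + copy kernels make up the whole Self CUDA total here,
            # so this matches the profiler's "Self CUDA %" for the copies.
            "copy_share": copy_us / (flash_us + copy_us),
        }
    return summary


for seq_len, s in summarize(workloads).items():
    print(f"L={seq_len:4d}  p50 x{s['p50_vs_base']:.2f}  "
          f"flash {s['flash_us_per_call']:7.1f} us/call  "
          f"copies {s['copy_us_per_call']:6.1f} us/call  "
          f"copy share {s['copy_share']:.1%}")
```

On these numbers, quadrupling the sequence length from 128 to 512 raises p50 by about 1.25x, and the layout copies account for roughly 18% to 20% of GPU time in every workload. Whether those copies can be avoided (for example by handing the transposed views directly to SDPA) is worth measuring rather than assuming.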

Artifacts:

attention.jsonl
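`torch_flash` moves its inputs from (batch, seq_len, heads, head_dim) to the (batch, heads, seq_len, head_dim) layout that SDPA expects, then back. To make that layout handling explicit, here is a minimal NumPy reference for the same unmasked, non-causal scaled dot-product attention, assuming SDPA's default scale of 1/sqrt(head_dim). It is an unoptimized sketch of the math, not the FlashAttention tiling algorithm, and `reference_attention` is a hypothetical name.

```python
import numpy as np


def reference_attention(q, k, v):
    """Unmasked softmax(Q K^T / sqrt(d)) V for inputs shaped
    (batch, seq_len, heads, head_dim), the same layout torch_flash uses."""
    d = q.shape[-1]
    q32, k32, v32 = (x.astype(np.float32) for x in (q, k, v))
    # scores[b, h, i, j] = <q[b, i, h, :], k[b, j, h, :]> / sqrt(d)
    scores = np.einsum("bihd,bjhd->bhij", q32, k32) / np.sqrt(d)
    # Numerically stable softmax over the key axis j.
    scores -= scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    # out[b, i, h, :] = sum_j weights[b, h, i, j] * v[b, j, h, :]
    return np.einsum("bhij,bjhd->bihd", weights, v32)


rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((2, 16, 4, 8)) for _ in range(3))
out = reference_attention(q, k, v)
print(out.shape)  # (2, 16, 4, 8)
```

To compare against `torch_flash`, one could convert its bfloat16 output with `.float().cpu().numpy()` and use a loose tolerance, since bfloat16 keeps only about 2 to 3 significant decimal digits; the tolerance the harness itself uses is not shown here.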