Flash Attention Implementation

GPU Info

Cell: nv | 0.21s
import subprocess

print(subprocess.run(["nvidia-smi"], capture_output=True, text=True).stdout)
Fri Oct 31 20:13:43 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 570.195.03             Driver Version: 570.195.03     CUDA Version: 12.8     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA L40S                    On  |   00000000:4D:00.0 Off |                    0 |
| N/A   43C    P0             83W /  350W |       0MiB /  46068MiB |     11%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+

+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI              PID   Type   Process name                        GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|  No running processes found                                                             |
+-----------------------------------------------------------------------------------------+
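
As a sanity check that PyTorch sees the same device and can reach the flash SDPA backend, a short sketch (not one of the original cells):

import torch

# Confirm the device and that the flash SDPA backend is enabled.
print(torch.cuda.get_device_name(0))            # expected: NVIDIA L40S
print(torch.backends.cuda.flash_sdp_enabled())  # True when flash SDPA may be selected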

Flash Attention Benchmark

Cell: benchmark | 3.87s
# /// script
# requires-python = ">=3.10"
# dependencies = [
#     "numpy",
#     "torch==2.8.0",
#     "kernels-benchmark-tools",
# ]
#
# [tool.uv.sources]
# kernels-benchmark-tools = { path = "../../../../../tools", editable = true }
# ///
import torch

from kernels_benchmark_tools import KernelTypeEnum, run_benchmark


def torch_flash(q, k, v):
    # Convert (batch, seq, heads, dim) -> (batch, heads, seq, dim), the layout SDPA expects.
    qt, kt, vt = (x.transpose(1, 2).contiguous() for x in (q, k, v))
    # Pin SDPA to the flash-attention backend so no fallback kernel gets measured.
    with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.FLASH_ATTENTION):
        o = torch.nn.functional.scaled_dot_product_attention(qt, kt, vt)
    # Transpose back and repack into the original layout.
    return o.transpose(1, 2).contiguous()


run_benchmark(
    kernel_type=KernelTypeEnum.ATTENTION,
    impl_name="torch_flash_ma",
    impl_tags={"family": "torch-sdpa", "backend": "FLASH", "compile": "max-autotune"},
    impl_func=torch_flash,
)
Running attention benchmark on cuda with 6 workloads.

======================================================================
PROFILE TRACE: torch_flash_ma | cuda_attn_L128_bfloat16
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                         torch_flash_ma         0.00%       0.000us         0.00%       0.000us       0.000us       3.600ms       101.99%       3.600ms       3.600ms             1  
                                         torch_flash_ma         6.70%     350.157us        46.68%       2.439ms       2.439ms       0.000us         0.00%       3.570ms       3.570ms             1  
                     aten::scaled_dot_product_attention         0.81%      42.281us         4.26%     222.626us      74.209us       0.000us         0.00%       2.816ms     938.781us             3  
              aten::_scaled_dot_product_flash_attention         0.52%      27.002us         3.45%     180.345us      60.115us       0.000us         0.00%       2.816ms     938.781us             3  
                         aten::_flash_attention_forward         0.79%      41.210us         2.54%     132.453us      44.151us       2.816ms        79.78%       2.816ms     938.781us             3  
void pytorch_flash::flash_fwd_kernel<Flash_fwd_kerne...         0.00%       0.000us         0.00%       0.000us       0.000us       2.816ms        79.78%       2.816ms     938.781us             3  
                                       aten::contiguous         0.29%      15.041us        34.44%       1.800ms     149.962us       0.000us         0.00%     753.884us      62.824us            12  
                                            aten::clone         0.75%      38.969us        34.15%       1.785ms     148.709us       0.000us         0.00%     753.884us      62.824us            12  
                                            aten::copy_         1.73%      90.324us        31.78%       1.661ms     138.388us     713.788us        20.22%     753.884us      62.824us            12  
void at::native::elementwise_kernel<128, 4, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     713.788us        20.22%     713.788us      59.482us            12  
                                Activity Buffer Request        28.08%       1.467ms        28.08%       1.467ms       1.467ms      40.096us         1.14%      40.096us      40.096us             1  
                                        aten::transpose         1.25%      65.371us         1.68%      87.543us       3.648us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::as_strided         0.42%      22.172us         0.42%      22.172us       0.924us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::empty_like         0.53%      27.463us         2.06%     107.524us       7.168us       0.000us         0.00%       0.000us       0.000us            15  
                                            aten::empty         1.78%      93.220us         1.78%      93.220us       3.884us       0.000us         0.00%       0.000us       0.000us            24  
                                       cudaLaunchKernel         2.49%     130.035us         2.49%     130.035us       8.669us       0.000us         0.00%       0.000us       0.000us            15  
                                    aten::empty_strided         0.32%      16.730us         0.32%      16.730us       5.577us       0.000us         0.00%       0.000us       0.000us             3  
                                 cudaDeviceGetAttribute         0.05%       2.690us         0.05%       2.690us       0.448us       0.000us         0.00%       0.000us       0.000us             6  
                                   cudaFuncSetAttribute         0.17%       9.000us         0.17%       9.000us       3.000us       0.000us         0.00%       0.000us       0.000us             3  
                                  cudaDeviceSynchronize        53.32%       2.786ms        53.32%       2.786ms       2.786ms       0.000us         0.00%       0.000us       0.000us             1  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 5.225ms
Self CUDA time total: 3.530ms
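
Reading the trace: the actual attention kernel (flash_fwd_kernel, reached via aten::_flash_attention_forward) accounts for roughly 80% of self CUDA time, and the elementwise copy kernel behind the aten::contiguous round-trip accounts for the remaining ~20%. A quick check of the L128 numbers, using values copied from the table above:

flash_us = 2816.0    # self CUDA of flash_fwd_kernel, in microseconds
copy_us  = 713.788   # self CUDA of the elementwise copy kernel
total_us = 3530.0    # self CUDA time total
print(f"flash: {flash_us/total_us:.1%}, copies: {copy_us/total_us:.1%}")
# flash: 79.8%, copies: 20.2% -- matching the 79.78% / 20.22% columns

The same split holds, within a couple of points, for every sequence length below.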



======================================================================
PROFILE TRACE: torch_flash_ma | cuda_attn_L256_bfloat16
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                         torch_flash_ma         4.88%     260.255us        42.26%       2.252ms       2.252ms       0.000us         0.00%       3.798ms       3.798ms             1  
                                         torch_flash_ma         0.00%       0.000us         0.00%       0.000us       0.000us       3.753ms       100.28%       3.753ms       3.753ms             1  
                     aten::scaled_dot_product_attention         0.49%      25.890us         3.50%     186.735us      62.245us       0.000us         0.00%       2.976ms     991.858us             3  
              aten::_scaled_dot_product_flash_attention         0.33%      17.842us         3.02%     160.845us      53.615us       0.000us         0.00%       2.976ms     991.858us             3  
                         aten::_flash_attention_forward         0.74%      39.289us         2.26%     120.363us      40.121us       2.976ms        79.51%       2.976ms     991.858us             3  
void pytorch_flash::flash_fwd_kernel<Flash_fwd_kerne...         0.00%       0.000us         0.00%       0.000us       0.000us       2.976ms        79.51%       2.976ms     991.858us             3  
                                       aten::contiguous         0.20%      10.403us        33.03%       1.760ms     146.680us       0.000us         0.00%     822.042us      68.504us            12  
                                            aten::clone         0.53%      28.238us        32.84%       1.750ms     145.813us       0.000us         0.00%     822.042us      68.504us            12  
                                            aten::copy_         1.51%      80.312us        31.12%       1.659ms     138.210us     766.874us        20.49%     822.042us      68.504us            12  
void at::native::elementwise_kernel<128, 4, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     766.874us        20.49%     766.874us      63.906us            12  
                                Activity Buffer Request        28.02%       1.493ms        28.02%       1.493ms       1.493ms      55.168us         1.47%      55.168us      55.168us             1  
                                        aten::transpose         0.94%      50.313us         1.27%      67.673us       2.820us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::as_strided         0.33%      17.360us         0.33%      17.360us       0.723us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::empty_like         0.40%      21.528us         1.56%      83.370us       5.558us       0.000us         0.00%       0.000us       0.000us            15  
                                            aten::empty         1.43%      76.263us         1.43%      76.263us       3.178us       0.000us         0.00%       0.000us       0.000us            24  
                                       cudaLaunchKernel         2.08%     110.943us         2.08%     110.943us       7.396us       0.000us         0.00%       0.000us       0.000us            15  
                                    aten::empty_strided         0.27%      14.621us         0.27%      14.621us       4.874us       0.000us         0.00%       0.000us       0.000us             3  
                                 cudaDeviceGetAttribute         0.03%       1.781us         0.03%       1.781us       0.297us       0.000us         0.00%       0.000us       0.000us             6  
                                   cudaFuncSetAttribute         0.08%       4.011us         0.08%       4.011us       1.337us       0.000us         0.00%       0.000us       0.000us             3  
                                  cudaDeviceSynchronize        57.74%       3.077ms        57.74%       3.077ms       3.077ms       0.000us         0.00%       0.000us       0.000us             1  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 5.329ms
Self CUDA time total: 3.742ms



======================================================================
PROFILE TRACE: torch_flash_ma | cuda_attn_L320_bfloat16
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                         torch_flash_ma         4.87%     262.676us        41.62%       2.245ms       2.245ms       0.000us         0.00%       3.882ms       3.882ms             1  
                                         torch_flash_ma         0.00%       0.000us         0.00%       0.000us       0.000us       3.834ms       100.29%       3.834ms       3.834ms             1  
                     aten::scaled_dot_product_attention         0.50%      26.770us         3.49%     188.015us      62.672us       0.000us         0.00%       3.044ms       1.015ms             3  
              aten::_scaled_dot_product_flash_attention         0.35%      18.803us         2.99%     161.245us      53.748us       0.000us         0.00%       3.044ms       1.015ms             3  
                         aten::_flash_attention_forward         0.74%      39.829us         2.21%     119.102us      39.701us       3.044ms        79.61%       3.044ms       1.015ms             3  
void pytorch_flash::flash_fwd_kernel<Flash_fwd_kerne...         0.00%       0.000us         0.00%       0.000us       0.000us       3.044ms        79.61%       3.044ms       1.015ms             3  
                                       aten::contiguous         0.18%       9.451us        32.36%       1.746ms     145.465us       0.000us         0.00%     838.367us      69.864us            12  
                                            aten::clone         0.54%      28.881us        32.18%       1.736ms     144.678us       0.000us         0.00%     838.367us      69.864us            12  
                                            aten::copy_         1.51%      81.201us        30.48%       1.644ms     137.016us     779.615us        20.39%     838.367us      69.864us            12  
void at::native::elementwise_kernel<128, 4, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     779.615us        20.39%     779.615us      64.968us            12  
                                Activity Buffer Request        27.31%       1.473ms        27.31%       1.473ms       1.473ms      58.752us         1.54%      58.752us      58.752us             1  
                                        aten::transpose         1.01%      54.592us         1.34%      72.471us       3.020us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::as_strided         0.33%      17.879us         0.33%      17.879us       0.745us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::empty_like         0.37%      20.117us         1.53%      82.751us       5.517us       0.000us         0.00%       0.000us       0.000us            15  
                                            aten::empty         1.41%      76.295us         1.41%      76.295us       3.179us       0.000us         0.00%       0.000us       0.000us            24  
                                       cudaLaunchKernel         2.13%     114.795us         2.13%     114.795us       7.653us       0.000us         0.00%       0.000us       0.000us            15  
                                    aten::empty_strided         0.27%      14.801us         0.27%      14.801us       4.934us       0.000us         0.00%       0.000us       0.000us             3  
                                 cudaDeviceGetAttribute         0.04%       2.110us         0.04%       2.110us       0.352us       0.000us         0.00%       0.000us       0.000us             6  
                                   cudaFuncSetAttribute         0.07%       3.990us         0.07%       3.990us       1.330us       0.000us         0.00%       0.000us       0.000us             3  
                                  cudaDeviceSynchronize        58.38%       3.149ms        58.38%       3.149ms       3.149ms       0.000us         0.00%       0.000us       0.000us             1  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 5.395ms
Self CUDA time total: 3.823ms



======================================================================
PROFILE TRACE: torch_flash_ma | cuda_attn_L384_bfloat16
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                         torch_flash_ma         4.61%     261.106us        43.54%       2.469ms       2.469ms       0.000us         0.00%       3.945ms       3.945ms             1  
                                         torch_flash_ma         0.00%       0.000us         0.00%       0.000us       0.000us       3.898ms       100.28%       3.898ms       3.898ms             1  
                     aten::scaled_dot_product_attention         0.46%      26.241us         3.40%     192.654us      64.218us       0.000us         0.00%       3.100ms       1.033ms             3  
              aten::_scaled_dot_product_flash_attention         0.34%      19.509us         2.94%     166.413us      55.471us       0.000us         0.00%       3.100ms       1.033ms             3  
                         aten::_flash_attention_forward         0.74%      42.081us         2.16%     122.633us      40.878us       3.100ms        79.76%       3.100ms       1.033ms             3  
void pytorch_flash::flash_fwd_kernel<Flash_fwd_kerne...         0.00%       0.000us         0.00%       0.000us       0.000us       3.100ms        79.76%       3.100ms       1.033ms             3  
                                       aten::contiguous         0.20%      11.161us        34.71%       1.968ms     163.994us       0.000us         0.00%     844.704us      70.392us            12  
                                            aten::clone         0.52%      29.682us        34.51%       1.957ms     163.064us       0.000us         0.00%     844.704us      70.392us            12  
                                            aten::copy_         1.45%      82.261us        32.81%       1.860ms     155.026us     786.784us        20.24%     844.704us      70.392us            12  
void at::native::elementwise_kernel<128, 4, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     786.784us        20.24%     786.784us      65.565us            12  
                                Activity Buffer Request        26.26%       1.489ms        26.26%       1.489ms       1.489ms      57.920us         1.49%      57.920us      57.920us             1  
                                        aten::transpose         0.95%      53.820us         1.26%      71.322us       2.972us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::as_strided         0.31%      17.502us         0.31%      17.502us       0.729us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::empty_like         0.39%      21.943us         1.53%      86.983us       5.799us       0.000us         0.00%       0.000us       0.000us            15  
                                            aten::empty         1.40%      79.202us         1.40%      79.202us       3.300us       0.000us         0.00%       0.000us       0.000us            24  
                                       cudaLaunchKernel         5.55%     314.487us         5.55%     314.487us      20.966us       0.000us         0.00%       0.000us       0.000us            15  
                                    aten::empty_strided         0.26%      14.830us         0.26%      14.830us       4.943us       0.000us         0.00%       0.000us       0.000us             3  
                                 cudaDeviceGetAttribute         0.04%       2.010us         0.04%       2.010us       0.335us       0.000us         0.00%       0.000us       0.000us             6  
                                   cudaFuncSetAttribute         0.07%       4.040us         0.07%       4.040us       1.347us       0.000us         0.00%       0.000us       0.000us             3  
                                  cudaDeviceSynchronize        56.46%       3.201ms        56.46%       3.201ms       3.201ms       0.000us         0.00%       0.000us       0.000us             1  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 5.670ms
Self CUDA time total: 3.887ms



======================================================================
PROFILE TRACE: torch_flash_ma | cuda_attn_L448_bfloat16
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                         torch_flash_ma         5.12%     312.519us        40.82%       2.493ms       2.493ms       0.000us         0.00%       4.416ms       4.416ms             1  
                                         torch_flash_ma         0.00%       0.000us         0.00%       0.000us       0.000us       4.365ms       100.24%       4.365ms       4.365ms             1  
                     aten::scaled_dot_product_attention         0.42%      25.922us         3.20%     195.246us      65.082us       0.000us         0.00%       3.547ms       1.182ms             3  
              aten::_scaled_dot_product_flash_attention         0.34%      20.847us         2.77%     169.324us      56.441us       0.000us         0.00%       3.547ms       1.182ms             3  
                         aten::_flash_attention_forward         0.72%      44.243us         2.07%     126.303us      42.101us       3.547ms        81.45%       3.547ms       1.182ms             3  
void pytorch_flash::flash_fwd_kernel<Flash_fwd_kerne...         0.00%       0.000us         0.00%       0.000us       0.000us       3.547ms        81.45%       3.547ms       1.182ms             3  
                                       aten::contiguous         0.17%      10.559us        31.73%       1.938ms     161.473us       0.000us         0.00%     869.122us      72.427us            12  
                                            aten::clone         0.47%      28.763us        31.56%       1.927ms     160.593us       0.000us         0.00%     869.122us      72.427us            12  
                                            aten::copy_         1.36%      83.033us        30.01%       1.832ms     152.707us     807.906us        18.55%     869.122us      72.427us            12  
void at::native::elementwise_kernel<128, 4, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     807.906us        18.55%     807.906us      67.326us            12  
                                Activity Buffer Request        24.51%       1.497ms        24.51%       1.497ms       1.497ms      61.216us         1.41%      61.216us      61.216us             1  
                                        aten::transpose         0.85%      52.195us         1.14%      69.864us       2.911us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::as_strided         0.29%      17.669us         0.29%      17.669us       0.736us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::empty_like         0.34%      20.921us         1.44%      87.791us       5.853us       0.000us         0.00%       0.000us       0.000us            15  
                                            aten::empty         1.30%      79.270us         1.30%      79.270us       3.303us       0.000us         0.00%       0.000us       0.000us            24  
                                       cudaLaunchKernel         4.55%     277.575us         4.55%     277.575us      18.505us       0.000us         0.00%       0.000us       0.000us            15  
                                    aten::empty_strided         0.27%      16.520us         0.27%      16.520us       5.507us       0.000us         0.00%       0.000us       0.000us             3  
                                 cudaDeviceGetAttribute         0.03%       1.960us         0.03%       1.960us       0.327us       0.000us         0.00%       0.000us       0.000us             6  
                                   cudaFuncSetAttribute         0.07%       4.040us         0.07%       4.040us       1.347us       0.000us         0.00%       0.000us       0.000us             3  
                                  cudaDeviceSynchronize        59.18%       3.614ms        59.18%       3.614ms       3.614ms       0.000us         0.00%       0.000us       0.000us             1  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 6.107ms
Self CUDA time total: 4.355ms



======================================================================
PROFILE TRACE: torch_flash_ma | cuda_attn_L512_bfloat16
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                         torch_flash_ma         3.85%     236.256us        38.02%       2.335ms       2.335ms       0.000us         0.00%       4.535ms       4.535ms             1  
                                         torch_flash_ma         0.00%       0.000us         0.00%       0.000us       0.000us       4.485ms       100.25%       4.485ms       4.485ms             1  
                     aten::scaled_dot_product_attention         0.43%      26.452us         2.98%     183.275us      61.092us       0.000us         0.00%       3.655ms       1.218ms             3  
              aten::_scaled_dot_product_flash_attention         0.30%      18.620us         2.55%     156.823us      52.274us       0.000us         0.00%       3.655ms       1.218ms             3  
                         aten::_flash_attention_forward         0.59%      36.060us         1.88%     115.323us      38.441us       3.655ms        81.69%       3.655ms       1.218ms             3  
void pytorch_flash::flash_fwd_kernel<Flash_fwd_kerne...         0.00%       0.000us         0.00%       0.000us       0.000us       3.655ms        81.69%       3.655ms       1.218ms             3  
                                       aten::contiguous         0.16%       9.770us        30.40%       1.867ms     155.567us       0.000us         0.00%     880.065us      73.339us            12  
                                            aten::clone         0.46%      28.179us        30.24%       1.857ms     154.753us       0.000us         0.00%     880.065us      73.339us            12  
                                            aten::copy_         1.36%      83.563us        28.74%       1.765ms     147.054us     819.137us        18.31%     880.065us      73.339us            12  
void at::native::elementwise_kernel<128, 4, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     819.137us        18.31%     819.137us      68.261us            12  
                                Activity Buffer Request        23.24%       1.427ms        23.24%       1.427ms       1.427ms      60.928us         1.36%      60.928us      60.928us             1  
                                        aten::transpose         0.86%      52.980us         1.16%      71.060us       2.961us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::as_strided         0.29%      18.080us         0.29%      18.080us       0.753us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::empty_like         0.34%      20.930us         1.37%      83.913us       5.594us       0.000us         0.00%       0.000us       0.000us            15  
                                            aten::empty         1.25%      77.043us         1.25%      77.043us       3.210us       0.000us         0.00%       0.000us       0.000us            24  
                                       cudaLaunchKernel         4.54%     278.990us         4.54%     278.990us      18.599us       0.000us         0.00%       0.000us       0.000us            15  
                                    aten::empty_strided         0.24%      14.661us         0.24%      14.661us       4.887us       0.000us         0.00%       0.000us       0.000us             3  
                                 cudaDeviceGetAttribute         0.03%       1.978us         0.03%       1.978us       0.330us       0.000us         0.00%       0.000us       0.000us             6  
                                   cudaFuncSetAttribute         0.06%       3.901us         0.06%       3.901us       1.300us       0.000us         0.00%       0.000us       0.000us             3  
                                  cudaDeviceSynchronize        61.98%       3.806ms        61.98%       3.806ms       3.806ms       0.000us         0.00%       0.000us       0.000us             1  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 6.141ms
Self CUDA time total: 4.474ms


impl                     workload            p50 (ms)  ok
torch_flash_ma           cuda_attn_L128_bfloat16     1.22  True
torch_flash_ma           cuda_attn_L256_bfloat16     1.28  True
torch_flash_ma           cuda_attn_L320_bfloat16     1.30  True
torch_flash_ma           cuda_attn_L384_bfloat16     1.33  True
torch_flash_ma           cuda_attn_L448_bfloat16     1.50  True
torch_flash_ma           cuda_attn_L512_bfloat16     1.51  True
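
One pattern is consistent across all six workloads: the .contiguous() calls in torch_flash cost roughly 20% of self CUDA time per run. Since SDPA accepts strided views as long as the head dimension stays contiguous, a copy-free variant is possible; a minimal sketch (hypothetical, not benchmarked here):

def torch_flash_no_copy(q, k, v):
    # transpose(1, 2) keeps the last-dim stride at 1, which the flash kernel accepts
    qt, kt, vt = (x.transpose(1, 2) for x in (q, k, v))
    with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.FLASH_ATTENTION):
        o = torch.nn.functional.scaled_dot_product_attention(qt, kt, vt)
    # return a strided view; call .contiguous() only if a packed layout is required
    return o.transpose(1, 2)

Whether the end-to-end win materializes depends on whether downstream ops tolerate the strided output.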

Artifacts:

attention.jsonl
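
attention.jsonl holds the per-workload records behind the summary table. A hypothetical reader, assuming one JSON object per line (the field names depend on kernels-benchmark-tools, so inspect a record before relying on them):

import json

with open("attention.jsonl") as f:
    for line in f:
        print(json.loads(line))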