Flash Attention Implementation

GPU Info

Cell: nv | 0.26s
import subprocess

print(subprocess.run(["nvidia-smi"], capture_output=True, text=True).stdout)
Wed Oct 29 15:50:02 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 570.195.03             Driver Version: 570.195.03     CUDA Version: 12.8     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA L40S                    On  |   00000000:4D:00.0 Off |                    0 |
| N/A   29C    P0            165W /  350W |       0MiB /  46068MiB |     61%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+

+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI              PID   Type   Process name                        GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|  No running processes found                                                             |
+-----------------------------------------------------------------------------------------+
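As a programmatic cross-check of the same device information, here is a minimal sketch reading the properties PyTorch exposes for the visible GPU (the printed formatting is illustrative):

import torch

# Query the first visible CUDA device instead of parsing nvidia-smi output.
if torch.cuda.is_available():
    props = torch.cuda.get_device_properties(0)
    print(f"{props.name} | {props.total_memory / 2**20:.0f} MiB | "
          f"{props.multi_processor_count} SMs | CC {props.major}.{props.minor}")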

Flash Attention Benchmark

Cell: benchmark | 3.82s
# /// script
# requires-python = ">=3.10"
# dependencies = [
#     "numpy",
#     "torch==2.8.0",
#     "kernels-benchmark-tools",
# ]
#
# [tool.uv.sources]
# kernels-benchmark-tools = { path = "../../../../../tools", editable = true }
# ///
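# The block above is PEP 723 inline script metadata: `uv run` reads it to
# resolve the dependencies, including the local editable
# kernels-benchmark-tools checkout pinned via [tool.uv.sources].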
import torch
from kernels_benchmark_tools import KernelTypeEnum, run_benchmark


def torch_flash(q, k, v):
    # Inputs arrive as (batch, seq_len, heads, head_dim); SDPA expects
    # (batch, heads, seq_len, head_dim), so transpose and make contiguous.
    qt, kt, vt = (x.transpose(1, 2).contiguous() for x in (q, k, v))
    # Pin SDPA to the flash-attention backend instead of letting it choose.
    with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.FLASH_ATTENTION):
        o = torch.nn.functional.scaled_dot_product_attention(qt, kt, vt)
    # Transpose back to the original (batch, seq_len, heads, head_dim) layout.
    return o.transpose(1, 2).contiguous()


# NB: the "compile": "max-autotune" tag appears to be metadata recorded with
# the results; the traces below show eager aten::scaled_dot_product_attention
# dispatch rather than a torch.compile graph.
run_benchmark(
    kernel_type=KernelTypeEnum.ATTENTION,
    impl_name="torch_flash_ma",
    impl_tags={"family": "torch-sdpa", "backend": "FLASH", "compile": "max-autotune"},
    impl_func=torch_flash,
)
Running attention benchmark on cuda with 6 workloads.

======================================================================
PROFILE TRACE: torch_flash_ma | cuda_attn_L128_bfloat16
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                         torch_flash_ma         0.00%       0.000us         0.00%       0.000us       0.000us       3.562ms       101.45%       3.562ms       3.562ms             1  
                                         torch_flash_ma         6.38%     328.580us        45.84%       2.360ms       2.360ms       0.000us         0.00%       3.551ms       3.551ms             1  
                     aten::scaled_dot_product_attention         0.79%      40.571us         4.12%     212.315us      70.772us       0.000us         0.00%       2.798ms     932.779us             3  
              aten::_scaled_dot_product_flash_attention         0.52%      26.642us         3.34%     171.744us      57.248us       0.000us         0.00%       2.798ms     932.779us             3  
                         aten::_flash_attention_forward         0.74%      37.939us         2.40%     123.383us      41.128us       2.798ms        79.71%       2.798ms     932.779us             3  
void pytorch_flash::flash_fwd_kernel<Flash_fwd_kerne...         0.00%       0.000us         0.00%       0.000us       0.000us       2.798ms        79.71%       2.798ms     932.779us             3  
                                       aten::contiguous         0.27%      13.720us        34.12%       1.757ms     146.409us       0.000us         0.00%     752.288us      62.691us            12  
                                            aten::clone         0.73%      37.449us        33.85%       1.743ms     145.266us       0.000us         0.00%     752.288us      62.691us            12  
                                            aten::copy_         1.68%      86.484us        31.57%       1.625ms     135.456us     712.095us        20.29%     752.288us      62.691us            12  
void at::native::elementwise_kernel<128, 4, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     712.095us        20.29%     712.095us      59.341us            12  
                                Activity Buffer Request        28.00%       1.442ms        28.00%       1.442ms       1.442ms      40.193us         1.14%      40.193us      40.193us             1  
                                        aten::transpose         1.22%      62.637us         1.64%      84.218us       3.509us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::as_strided         0.42%      21.581us         0.42%      21.581us       0.899us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::empty_like         0.48%      24.619us         1.97%     101.523us       6.768us       0.000us         0.00%       0.000us       0.000us            15  
                                            aten::empty         1.76%      90.465us         1.76%      90.465us       3.769us       0.000us         0.00%       0.000us       0.000us            24  
                                       cudaLaunchKernel         2.36%     121.521us         2.36%     121.521us       8.101us       0.000us         0.00%       0.000us       0.000us            15  
                                    aten::empty_strided         0.31%      15.721us         0.31%      15.721us       5.240us       0.000us         0.00%       0.000us       0.000us             3  
                                 cudaDeviceGetAttribute         0.04%       2.280us         0.04%       2.280us       0.380us       0.000us         0.00%       0.000us       0.000us             6  
                                   cudaFuncSetAttribute         0.16%       8.181us         0.16%       8.181us       2.727us       0.000us         0.00%       0.000us       0.000us             3  
                                  cudaDeviceSynchronize        54.16%       2.789ms        54.16%       2.789ms       2.789ms       0.000us         0.00%       0.000us       0.000us             1  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 5.149ms
Self CUDA time total: 3.510ms



======================================================================
PROFILE TRACE: torch_flash_ma | cuda_attn_L256_bfloat16
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                         torch_flash_ma         4.71%     257.538us        44.52%       2.436ms       2.436ms       0.000us         0.00%       3.763ms       3.763ms             1  
                                         torch_flash_ma         0.00%       0.000us         0.00%       0.000us       0.000us       3.719ms       100.29%       3.719ms       3.719ms             1  
                     aten::scaled_dot_product_attention         0.45%      24.440us         3.30%     180.683us      60.228us       0.000us         0.00%       2.948ms     982.525us             3  
              aten::_scaled_dot_product_flash_attention         0.35%      18.890us         2.86%     156.243us      52.081us       0.000us         0.00%       2.948ms     982.525us             3  
                         aten::_flash_attention_forward         0.68%      37.218us         2.07%     113.133us      37.711us       2.948ms        79.49%       2.948ms     982.525us             3  
void pytorch_flash::flash_fwd_kernel<Flash_fwd_kerne...         0.00%       0.000us         0.00%       0.000us       0.000us       2.948ms        79.49%       2.948ms     982.525us             3  
                                       aten::contiguous         0.16%       8.651us        35.72%       1.955ms     162.890us       0.000us         0.00%     815.678us      67.973us            12  
                                            aten::clone         0.48%      26.452us        35.56%       1.946ms     162.169us       0.000us         0.00%     815.678us      67.973us            12  
                                            aten::copy_         1.81%      99.279us        33.97%       1.859ms     154.885us     760.479us        20.51%     815.678us      67.973us            12  
void at::native::elementwise_kernel<128, 4, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     760.479us        20.51%     760.479us      63.373us            12  
                                Activity Buffer Request        30.60%       1.674ms        30.60%       1.674ms       1.674ms      55.199us         1.49%      55.199us      55.199us             1  
                                        aten::transpose         0.92%      50.270us         1.23%      67.460us       2.811us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::as_strided         0.31%      17.190us         0.31%      17.190us       0.716us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::empty_like         0.34%      18.723us         1.45%      79.503us       5.300us       0.000us         0.00%       0.000us       0.000us            15  
                                            aten::empty         1.39%      75.933us         1.39%      75.933us       3.164us       0.000us         0.00%       0.000us       0.000us            24  
                                       cudaLaunchKernel         1.98%     108.143us         1.98%     108.143us       7.210us       0.000us         0.00%       0.000us       0.000us            15  
                                    aten::empty_strided         0.25%      13.599us         0.25%      13.599us       4.533us       0.000us         0.00%       0.000us       0.000us             3  
                                 cudaDeviceGetAttribute         0.03%       1.831us         0.03%       1.831us       0.305us       0.000us         0.00%       0.000us       0.000us             6  
                                   cudaFuncSetAttribute         0.07%       3.690us         0.07%       3.690us       1.230us       0.000us         0.00%       0.000us       0.000us             3  
                                  cudaDeviceSynchronize        55.48%       3.036ms        55.48%       3.036ms       3.036ms       0.000us         0.00%       0.000us       0.000us             1  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 5.472ms
Self CUDA time total: 3.708ms



======================================================================
PROFILE TRACE: torch_flash_ma | cuda_attn_L320_bfloat16
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                         torch_flash_ma         4.65%     248.558us        40.70%       2.176ms       2.176ms       0.000us         0.00%       3.868ms       3.868ms             1  
                                         torch_flash_ma         0.00%       0.000us         0.00%       0.000us       0.000us       3.819ms       100.29%       3.819ms       3.819ms             1  
                     aten::scaled_dot_product_attention         0.45%      24.181us         3.36%     179.834us      59.945us       0.000us         0.00%       3.027ms       1.009ms             3  
              aten::_scaled_dot_product_flash_attention         0.34%      18.100us         2.91%     155.653us      51.884us       0.000us         0.00%       3.027ms       1.009ms             3  
                         aten::_flash_attention_forward         0.73%      38.760us         2.16%     115.412us      38.471us       3.027ms        79.48%       3.027ms       1.009ms             3  
void pytorch_flash::flash_fwd_kernel<Flash_fwd_kerne...         0.00%       0.000us         0.00%       0.000us       0.000us       3.027ms        79.48%       3.027ms       1.009ms             3  
                                       aten::contiguous         0.16%       8.609us        31.88%       1.704ms     142.018us       0.000us         0.00%     841.280us      70.107us            12  
                                            aten::clone         0.50%      26.820us        31.72%       1.696ms     141.301us       0.000us         0.00%     841.280us      70.107us            12  
                                            aten::copy_         1.47%      78.703us        30.10%       1.609ms     134.076us     781.631us        20.52%     841.280us      70.107us            12  
void at::native::elementwise_kernel<128, 4, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     781.631us        20.52%     781.631us      65.136us            12  
                                Activity Buffer Request        27.11%       1.449ms        27.11%       1.449ms       1.449ms      59.649us         1.57%      59.649us      59.649us             1  
                                        aten::transpose         0.90%      48.151us         1.22%      65.102us       2.713us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::as_strided         0.32%      16.951us         0.32%      16.951us       0.706us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::empty_like         0.35%      18.789us         1.49%      79.862us       5.324us       0.000us         0.00%       0.000us       0.000us            15  
                                            aten::empty         1.38%      73.892us         1.38%      73.892us       3.079us       0.000us         0.00%       0.000us       0.000us            24  
                                       cudaLaunchKernel         1.96%     104.680us         1.96%     104.680us       6.979us       0.000us         0.00%       0.000us       0.000us            15  
                                    aten::empty_strided         0.28%      15.081us         0.28%      15.081us       5.027us       0.000us         0.00%       0.000us       0.000us             3  
                                 cudaDeviceGetAttribute         0.03%       1.791us         0.03%       1.791us       0.299us       0.000us         0.00%       0.000us       0.000us             6  
                                   cudaFuncSetAttribute         0.07%       3.500us         0.07%       3.500us       1.167us       0.000us         0.00%       0.000us       0.000us             3  
                                  cudaDeviceSynchronize        59.30%       3.169ms        59.30%       3.169ms       3.169ms       0.000us         0.00%       0.000us       0.000us             1  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 5.345ms
Self CUDA time total: 3.808ms



======================================================================
PROFILE TRACE: torch_flash_ma | cuda_attn_L384_bfloat16
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                         torch_flash_ma         4.50%     255.237us        42.25%       2.398ms       2.398ms       0.000us         0.00%       3.984ms       3.984ms             1  
                                         torch_flash_ma         0.00%       0.000us         0.00%       0.000us       0.000us       3.936ms       100.28%       3.936ms       3.936ms             1  
                     aten::scaled_dot_product_attention         0.42%      23.840us         3.17%     179.904us      59.968us       0.000us         0.00%       3.135ms       1.045ms             3  
              aten::_scaled_dot_product_flash_attention         0.36%      20.442us         2.75%     156.064us      52.021us       0.000us         0.00%       3.135ms       1.045ms             3  
                         aten::_flash_attention_forward         0.68%      38.721us         1.99%     113.183us      37.728us       3.135ms        79.87%       3.135ms       1.045ms             3  
void pytorch_flash::flash_fwd_kernel<Flash_fwd_kerne...         0.00%       0.000us         0.00%       0.000us       0.000us       3.135ms        79.87%       3.135ms       1.045ms             3  
                                       aten::contiguous         0.17%       9.382us        33.81%       1.919ms     159.915us       0.000us         0.00%     848.416us      70.701us            12  
                                            aten::clone         0.52%      29.639us        33.64%       1.910ms     159.133us       0.000us         0.00%     848.416us      70.701us            12  
                                            aten::copy_         1.40%      79.644us        32.03%       1.818ms     151.492us     790.048us        20.13%     848.416us      70.701us            12  
void at::native::elementwise_kernel<128, 4, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     790.048us        20.13%     790.048us      65.837us            12  
                                Activity Buffer Request        25.14%       1.427ms        25.14%       1.427ms       1.427ms      58.368us         1.49%      58.368us      58.368us             1  
                                        aten::transpose         0.87%      49.289us         1.17%      66.169us       2.757us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::as_strided         0.30%      16.880us         0.30%      16.880us       0.703us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::empty_like         0.35%      19.852us         1.42%      80.662us       5.377us       0.000us         0.00%       0.000us       0.000us            15  
                                            aten::empty         1.32%      74.981us         1.32%      74.981us       3.124us       0.000us         0.00%       0.000us       0.000us            24  
                                       cudaLaunchKernel         5.89%     334.125us         5.89%     334.125us      22.275us       0.000us         0.00%       0.000us       0.000us            15  
                                    aten::empty_strided         0.24%      13.720us         0.24%      13.720us       4.573us       0.000us         0.00%       0.000us       0.000us             3  
                                 cudaDeviceGetAttribute         0.03%       1.760us         0.03%       1.760us       0.293us       0.000us         0.00%       0.000us       0.000us             6  
                                   cudaFuncSetAttribute         0.06%       3.570us         0.06%       3.570us       1.190us       0.000us         0.00%       0.000us       0.000us             3  
                                  cudaDeviceSynchronize        57.75%       3.278ms        57.75%       3.278ms       3.278ms       0.000us         0.00%       0.000us       0.000us             1  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 5.676ms
Self CUDA time total: 3.925ms



======================================================================
PROFILE TRACE: torch_flash_ma | cuda_attn_L448_bfloat16
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                         torch_flash_ma         5.07%     311.056us        40.82%       2.505ms       2.505ms       0.000us         0.00%       4.409ms       4.409ms             1  
                                         torch_flash_ma         0.00%       0.000us         0.00%       0.000us       0.000us       4.359ms       100.26%       4.359ms       4.359ms             1  
                     aten::scaled_dot_product_attention         0.41%      24.931us         3.07%     188.265us      62.755us       0.000us         0.00%       3.539ms       1.180ms             3  
              aten::_scaled_dot_product_flash_attention         0.33%      20.199us         2.66%     163.334us      54.445us       0.000us         0.00%       3.539ms       1.180ms             3  
                         aten::_flash_attention_forward         0.67%      41.371us         1.94%     118.823us      39.608us       3.539ms        81.38%       3.539ms       1.180ms             3  
void pytorch_flash::flash_fwd_kernel<Flash_fwd_kerne...         0.00%       0.000us         0.00%       0.000us       0.000us       3.539ms        81.38%       3.539ms       1.180ms             3  
                                       aten::contiguous         0.16%       9.771us        31.97%       1.962ms     163.526us       0.000us         0.00%     870.819us      72.568us            12  
                                            aten::clone         0.47%      28.779us        31.82%       1.953ms     162.712us       0.000us         0.00%     870.819us      72.568us            12  
                                            aten::copy_         1.27%      77.896us        30.33%       1.862ms     155.132us     809.571us        18.62%     870.819us      72.568us            12  
void at::native::elementwise_kernel<128, 4, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     809.571us        18.62%     809.571us      67.464us            12  
                                Activity Buffer Request        24.14%       1.481ms        24.14%       1.481ms       1.481ms      61.248us         1.41%      61.248us      61.248us             1  
                                        aten::transpose         0.82%      50.583us         1.11%      68.092us       2.837us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::as_strided         0.29%      17.509us         0.29%      17.509us       0.730us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::empty_like         0.32%      19.913us         1.33%      81.883us       5.459us       0.000us         0.00%       0.000us       0.000us            15  
                                            aten::empty         1.23%      75.660us         1.23%      75.660us       3.153us       0.000us         0.00%       0.000us       0.000us            24  
                                       cudaLaunchKernel         5.31%     325.825us         5.31%     325.825us      21.722us       0.000us         0.00%       0.000us       0.000us            15  
                                    aten::empty_strided         0.24%      14.770us         0.24%      14.770us       4.923us       0.000us         0.00%       0.000us       0.000us             3  
                                 cudaDeviceGetAttribute         0.03%       1.990us         0.03%       1.990us       0.332us       0.000us         0.00%       0.000us       0.000us             6  
                                   cudaFuncSetAttribute         0.06%       3.670us         0.06%       3.670us       1.223us       0.000us         0.00%       0.000us       0.000us             3  
                                  cudaDeviceSynchronize        59.18%       3.632ms        59.18%       3.632ms       3.632ms       0.000us         0.00%       0.000us       0.000us             1  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 6.137ms
Self CUDA time total: 4.348ms



======================================================================
PROFILE TRACE: torch_flash_ma | cuda_attn_L512_bfloat16
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                         torch_flash_ma         4.13%     252.675us        38.98%       2.384ms       2.384ms       0.000us         0.00%       4.451ms       4.451ms             1  
                                         torch_flash_ma         0.00%       0.000us         0.00%       0.000us       0.000us       4.400ms       100.24%       4.400ms       4.400ms             1  
                     aten::scaled_dot_product_attention         0.50%      30.480us         3.11%     190.334us      63.445us       0.000us         0.00%       3.566ms       1.189ms             3  
              aten::_scaled_dot_product_flash_attention         0.31%      19.082us         2.61%     159.854us      53.285us       0.000us         0.00%       3.566ms       1.189ms             3  
                         aten::_flash_attention_forward         0.62%      38.112us         1.93%     118.053us      39.351us       3.566ms        81.24%       3.566ms       1.189ms             3  
void pytorch_flash::flash_fwd_kernel<Flash_fwd_kerne...         0.00%       0.000us         0.00%       0.000us       0.000us       3.566ms        81.24%       3.566ms       1.189ms             3  
                                       aten::contiguous         0.16%       9.891us        31.02%       1.897ms     158.059us       0.000us         0.00%     884.831us      73.736us            12  
                                            aten::clone         0.50%      30.290us        30.85%       1.887ms     157.234us       0.000us         0.00%     884.831us      73.736us            12  
                                            aten::copy_         1.28%      78.520us        29.35%       1.795ms     149.550us     823.711us        18.76%     884.831us      73.736us            12  
void at::native::elementwise_kernel<128, 4, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     823.711us        18.76%     823.711us      68.643us            12  
                                Activity Buffer Request        23.29%       1.424ms        23.29%       1.424ms       1.424ms      61.120us         1.39%      61.120us      61.120us             1  
                                        aten::transpose         0.81%      49.593us         1.09%      66.721us       2.780us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::as_strided         0.28%      17.128us         0.28%      17.128us       0.714us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::empty_like         0.33%      20.381us         1.35%      82.362us       5.491us       0.000us         0.00%       0.000us       0.000us            15  
                                            aten::empty         1.23%      74.920us         1.23%      74.920us       3.122us       0.000us         0.00%       0.000us       0.000us            24  
                                       cudaLaunchKernel         5.19%     317.558us         5.19%     317.558us      21.171us       0.000us         0.00%       0.000us       0.000us            15  
                                    aten::empty_strided         0.25%      15.161us         0.25%      15.161us       5.054us       0.000us         0.00%       0.000us       0.000us             3  
                                 cudaDeviceGetAttribute         0.03%       1.791us         0.03%       1.791us       0.299us       0.000us         0.00%       0.000us       0.000us             6  
                                   cudaFuncSetAttribute         0.06%       3.670us         0.06%       3.670us       1.223us       0.000us         0.00%       0.000us       0.000us             3  
                                  cudaDeviceSynchronize        61.02%       3.732ms        61.02%       3.732ms       3.732ms       0.000us         0.00%       0.000us       0.000us             1  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 6.115ms
Self CUDA time total: 4.390ms


impl                     workload            p50(ms)  ok
torch_flash_ma           cuda_attn_L128_bfloat16     1.21  True
torch_flash_ma           cuda_attn_L256_bfloat16     1.27  True
torch_flash_ma           cuda_attn_L320_bfloat16     1.30  True
torch_flash_ma           cuda_attn_L384_bfloat16     1.32  True
torch_flash_ma           cuda_attn_L448_bfloat16     1.48  True
torch_flash_ma           cuda_attn_L512_bfloat16     1.49  True
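
Across all six sequence lengths the flash forward kernel itself accounts for roughly 80% of self CUDA time, while most of the remaining ~20% goes to the aten::copy_ kernels launched by the transpose-and-contiguous round-trip in torch_flash. Below is a minimal sketch of a copy-free variant, assuming callers can hand over tensors already in SDPA's (batch, heads, seq_len, head_dim) layout; torch_flash_nocopy is a hypothetical name, not part of this benchmark run:

import torch

def torch_flash_nocopy(q, k, v):
    # Assumption: q, k, v already use the (batch, heads, seq_len, head_dim)
    # layout SDPA expects, so the transpose/contiguous copies (and their
    # ~20% share of CUDA time in the traces above) disappear entirely.
    with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.FLASH_ATTENTION):
        return torch.nn.functional.scaled_dot_product_attention(q, k, v)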

Artifacts:

attention.jsonl
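
The artifact is newline-delimited JSON, one record per measurement; a minimal sketch for inspecting it (the record schema is not shown above, so fields are printed as-is):

import json

# Print each benchmark record from the JSONL artifact.
with open("attention.jsonl") as f:
    for line in f:
        print(json.loads(line))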