Flash Attention Implementation

GPU Info

Cell: nv | 0.26s
import subprocess

print(subprocess.run(["nvidia-smi"], capture_output=True, text=True).stdout)
Mon Nov 10 21:58:51 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 580.95.05              Driver Version: 580.95.05      CUDA Version: 13.0     |
+-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA L40S                    On  |   00000000:4D:00.0 Off |                    0 |
| N/A   32C    P0            139W /  350W |       0MiB /  46068MiB |     83%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+

+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI              PID   Type   Process name                        GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|  No running processes found                                                             |
+-----------------------------------------------------------------------------------------+
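
For completeness, the same device details can be cross-checked from PyTorch itself. A minimal sketch, assuming at least one CUDA device is visible to the process:

import torch

# Report the device as PyTorch sees it (assumes CUDA is available).
props = torch.cuda.get_device_properties(0)
print(f"{props.name}: {props.total_memory / 2**20:.0f} MiB, "
      f"SM {props.major}.{props.minor}, {props.multi_processor_count} SMs")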

Flash Attention Benchmark

Cell: benchmark | 4.03s
# /// script
# requires-python = ">=3.10"
# dependencies = [
#     "numpy",
#     "torch==2.8.0",
#     "kernels-benchmark-tools",
# ]
#
# [tool.uv.sources]
# kernels-benchmark-tools = { path = "../../../../../tools", editable = true }
# ///
import torch
from kernels_benchmark_tools import KernelTypeEnum, run_benchmark


def torch_flash(q, k, v):
    # Inputs arrive as (batch, seq, heads, head_dim); SDPA expects (batch, heads, seq, head_dim).
    qt, kt, vt = (x.transpose(1, 2).contiguous() for x in (q, k, v))
    # Pin SDPA to the flash-attention backend instead of letting it auto-select.
    with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.FLASH_ATTENTION):
        o = torch.nn.functional.scaled_dot_product_attention(qt, kt, vt)
    # Transpose back to the original layout.
    return o.transpose(1, 2).contiguous()


run_benchmark(
    kernel_type=KernelTypeEnum.ATTENTION,
    impl_name="torch_flash_ma",
    impl_tags={"family": "torch-sdpa", "backend": "FLASH", "compile": "max-autotune"},
    impl_func=torch_flash,
)
Running attention benchmark on cuda with 6 workloads.

======================================================================
PROFILE TRACE: torch_flash_ma | cuda_attn_L128_bfloat16
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                         torch_flash_ma         0.00%       0.000us         0.00%       0.000us       0.000us       3.628ms       101.57%       3.628ms       3.628ms             1  
                                         torch_flash_ma         5.67%     314.697us        48.49%       2.689ms       2.689ms       0.000us         0.00%       3.612ms       3.612ms             1  
                     aten::scaled_dot_product_attention         0.72%      39.870us         3.84%     213.234us      71.078us       0.000us         0.00%       2.845ms     948.416us             3  
              aten::_scaled_dot_product_flash_attention         0.43%      24.020us         3.13%     173.364us      57.788us       0.000us         0.00%       2.845ms     948.416us             3  
                         aten::_flash_attention_forward         0.70%      39.034us         2.33%     129.042us      43.014us       2.845ms        79.65%       2.845ms     948.416us             3  
void pytorch_flash::flash_fwd_kernel<Flash_fwd_kerne...         0.00%       0.000us         0.00%       0.000us       0.000us       2.845ms        79.65%       2.845ms     948.416us             3  
                                       aten::contiguous         0.22%      12.191us        37.88%       2.101ms     175.086us       0.000us         0.00%     766.879us      63.907us            12  
                                            aten::clone         0.59%      32.480us        37.66%       2.089ms     174.070us       0.000us         0.00%     766.879us      63.907us            12  
                                            aten::copy_         1.56%      86.776us        35.66%       1.978ms     164.799us     726.879us        20.35%     766.879us      63.907us            12  
void at::native::elementwise_kernel<128, 4, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     726.879us        20.35%     726.879us      60.573us            12  
                                Activity Buffer Request        32.26%       1.789ms        32.26%       1.789ms       1.789ms      40.000us         1.12%      40.000us      40.000us             1  
                                        aten::transpose         1.07%      59.612us         1.46%      80.772us       3.365us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::as_strided         0.38%      21.160us         0.38%      21.160us       0.882us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::empty_like         0.40%      22.459us         1.80%      99.659us       6.644us       0.000us         0.00%       0.000us       0.000us            15  
                                            aten::empty         1.66%      92.037us         1.66%      92.037us       3.835us       0.000us         0.00%       0.000us       0.000us            24  
                                       cudaLaunchKernel         2.29%     126.900us         2.29%     126.900us       8.460us       0.000us         0.00%       0.000us       0.000us            15  
                                    aten::empty_strided         0.28%      15.620us         0.28%      15.620us       5.207us       0.000us         0.00%       0.000us       0.000us             3  
                                 cudaDeviceGetAttribute         0.04%       2.280us         0.04%       2.280us       0.380us       0.000us         0.00%       0.000us       0.000us             6  
                                   cudaFuncSetAttribute         0.20%      11.200us         0.20%      11.200us       3.733us       0.000us         0.00%       0.000us       0.000us             3  
                                  cudaDeviceSynchronize        51.51%       2.857ms        51.51%       2.857ms       2.857ms       0.000us         0.00%       0.000us       0.000us             1  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 5.546ms
Self CUDA time total: 3.572ms



======================================================================
PROFILE TRACE: torch_flash_ma | cuda_attn_L256_bfloat16
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                         torch_flash_ma         4.57%     259.472us        46.25%       2.626ms       2.626ms       0.000us         0.00%       3.786ms       3.786ms             1  
                                         torch_flash_ma         0.00%       0.000us         0.00%       0.000us       0.000us       3.742ms       100.27%       3.742ms       3.742ms             1  
                     aten::scaled_dot_product_attention         0.42%      24.011us         3.41%     193.713us      64.571us       0.000us         0.00%       2.968ms     989.492us             3  
              aten::_scaled_dot_product_flash_attention         0.33%      18.660us         2.99%     169.702us      56.567us       0.000us         0.00%       2.968ms     989.492us             3  
                         aten::_flash_attention_forward         0.83%      47.240us         2.21%     125.672us      41.891us       2.968ms        79.55%       2.968ms     989.492us             3  
void pytorch_flash::flash_fwd_kernel<Flash_fwd_kerne...         0.00%       0.000us         0.00%       0.000us       0.000us       2.968ms        79.55%       2.968ms     989.492us             3  
                                       aten::contiguous         0.19%      10.613us        37.48%       2.128ms     177.333us       0.000us         0.00%     817.342us      68.112us            12  
                                            aten::clone         0.52%      29.369us        37.29%       2.117ms     176.448us       0.000us         0.00%     817.342us      68.112us            12  
                                            aten::copy_         1.41%      80.272us        35.64%       2.023ms     168.619us     762.942us        20.45%     817.342us      68.112us            12  
void at::native::elementwise_kernel<128, 4, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     762.942us        20.45%     762.942us      63.579us            12  
                                Activity Buffer Request        32.67%       1.855ms        32.67%       1.855ms       1.855ms      54.400us         1.46%      54.400us      54.400us             1  
                                        aten::transpose         0.90%      51.353us         1.23%      69.912us       2.913us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::as_strided         0.33%      18.559us         0.33%      18.559us       0.773us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::empty_like         0.37%      20.909us         1.47%      83.391us       5.559us       0.000us         0.00%       0.000us       0.000us            15  
                                            aten::empty         1.39%      78.982us         1.39%      78.982us       3.291us       0.000us         0.00%       0.000us       0.000us            24  
                                       cudaLaunchKernel         1.94%     110.382us         1.94%     110.382us       7.359us       0.000us         0.00%       0.000us       0.000us            15  
                                    aten::empty_strided         0.24%      13.461us         0.24%      13.461us       4.487us       0.000us         0.00%       0.000us       0.000us             3  
                                 cudaDeviceGetAttribute         0.05%       2.710us         0.05%       2.710us       0.452us       0.000us         0.00%       0.000us       0.000us             6  
                                   cudaFuncSetAttribute         0.09%       4.940us         0.09%       4.940us       1.647us       0.000us         0.00%       0.000us       0.000us             3  
                                  cudaDeviceSynchronize        53.75%       3.052ms        53.75%       3.052ms       3.052ms       0.000us         0.00%       0.000us       0.000us             1  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 5.678ms
Self CUDA time total: 3.731ms



======================================================================
PROFILE TRACE: torch_flash_ma | cuda_attn_L320_bfloat16
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                         torch_flash_ma         4.60%     260.065us        44.20%       2.500ms       2.500ms       0.000us         0.00%       3.871ms       3.871ms             1  
                                         torch_flash_ma         0.00%       0.000us         0.00%       0.000us       0.000us       3.823ms       100.27%       3.823ms       3.823ms             1  
                     aten::scaled_dot_product_attention         0.46%      25.840us         3.28%     185.632us      61.877us       0.000us         0.00%       3.035ms       1.012ms             3  
              aten::_scaled_dot_product_flash_attention         0.32%      17.999us         2.82%     159.792us      53.264us       0.000us         0.00%       3.035ms       1.012ms             3  
                         aten::_flash_attention_forward         0.73%      41.121us         2.09%     118.472us      39.491us       3.035ms        79.59%       3.035ms       1.012ms             3  
void pytorch_flash::flash_fwd_kernel<Flash_fwd_kerne...         0.00%       0.000us         0.00%       0.000us       0.000us       3.035ms        79.59%       3.035ms       1.012ms             3  
                                       aten::contiguous         0.19%      10.499us        35.53%       2.010ms     167.521us       0.000us         0.00%     836.093us      69.674us            12  
                                            aten::clone         0.50%      28.109us        35.35%       2.000ms     166.646us       0.000us         0.00%     836.093us      69.674us            12  
                                            aten::copy_         1.42%      80.472us        33.72%       1.908ms     158.959us     778.333us        20.41%     836.093us      69.674us            12  
void at::native::elementwise_kernel<128, 4, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     778.333us        20.41%     778.333us      64.861us            12  
                                Activity Buffer Request        30.89%       1.747ms        30.89%       1.747ms       1.747ms      57.760us         1.51%      57.760us      57.760us             1  
                                        aten::transpose         0.88%      49.936us         1.20%      67.813us       2.826us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::as_strided         0.32%      17.877us         0.32%      17.877us       0.745us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::empty_like         0.36%      20.321us         1.47%      83.262us       5.551us       0.000us         0.00%       0.000us       0.000us            15  
                                            aten::empty         1.37%      77.333us         1.37%      77.333us       3.222us       0.000us         0.00%       0.000us       0.000us            24  
                                       cudaLaunchKernel         1.81%     102.481us         1.81%     102.481us       6.832us       0.000us         0.00%       0.000us       0.000us            15  
                                    aten::empty_strided         0.25%      14.120us         0.25%      14.120us       4.707us       0.000us         0.00%       0.000us       0.000us             3  
                                 cudaDeviceGetAttribute         0.03%       1.688us         0.03%       1.688us       0.281us       0.000us         0.00%       0.000us       0.000us             6  
                                   cudaFuncSetAttribute         0.09%       5.331us         0.09%       5.331us       1.777us       0.000us         0.00%       0.000us       0.000us             3  
                                  cudaDeviceSynchronize        55.80%       3.157ms        55.80%       3.157ms       3.157ms       0.000us         0.00%       0.000us       0.000us             1  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 5.657ms
Self CUDA time total: 3.813ms



======================================================================
PROFILE TRACE: torch_flash_ma | cuda_attn_L384_bfloat16
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                         torch_flash_ma         4.36%     258.876us        46.43%       2.758ms       2.758ms       0.000us         0.00%       3.960ms       3.960ms             1  
                                         torch_flash_ma         0.00%       0.000us         0.00%       0.000us       0.000us       3.911ms       100.27%       3.911ms       3.911ms             1  
                     aten::scaled_dot_product_attention         0.42%      24.860us         4.02%     238.593us      79.531us       0.000us         0.00%       3.109ms       1.036ms             3  
              aten::_scaled_dot_product_flash_attention         0.32%      19.211us         3.60%     213.733us      71.244us       0.000us         0.00%       3.109ms       1.036ms             3  
                         aten::_flash_attention_forward         0.74%      43.768us         2.88%     170.772us      56.924us       3.109ms        79.70%       3.109ms       1.036ms             3  
void pytorch_flash::flash_fwd_kernel<Flash_fwd_kerne...         0.00%       0.000us         0.00%       0.000us       0.000us       3.109ms        79.70%       3.109ms       1.036ms             3  
                                       aten::contiguous         0.17%      10.099us        37.27%       2.213ms     184.454us       0.000us         0.00%     850.560us      70.880us            12  
                                            aten::clone         0.48%      28.250us        37.10%       2.203ms     183.613us       0.000us         0.00%     850.560us      70.880us            12  
                                            aten::copy_         1.36%      80.903us        35.54%       2.111ms     175.896us     791.680us        20.30%     850.560us      70.880us            12  
void at::native::elementwise_kernel<128, 4, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     791.680us        20.30%     791.680us      65.973us            12  
                                Activity Buffer Request        29.13%       1.730ms        29.13%       1.730ms       1.730ms      58.880us         1.51%      58.880us      58.880us             1  
                                        aten::transpose         0.86%      50.781us         1.18%      70.362us       2.932us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::as_strided         0.33%      19.581us         0.33%      19.581us       0.816us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::empty_like         0.35%      20.589us         1.40%      83.331us       5.555us       0.000us         0.00%       0.000us       0.000us            15  
                                            aten::empty         1.32%      78.663us         1.32%      78.663us       3.278us       0.000us         0.00%       0.000us       0.000us            24  
                                       cudaLaunchKernel         5.47%     324.743us         5.47%     324.743us      21.650us       0.000us         0.00%       0.000us       0.000us            15  
                                    aten::empty_strided         0.23%      13.800us         0.23%      13.800us       4.600us       0.000us         0.00%       0.000us       0.000us             3  
                                 cudaDeviceGetAttribute         0.80%      47.662us         0.80%      47.662us       7.944us       0.000us         0.00%       0.000us       0.000us             6  
                                   cudaFuncSetAttribute         0.10%       5.930us         0.10%       5.930us       1.977us       0.000us         0.00%       0.000us       0.000us             3  
                                  cudaDeviceSynchronize        53.57%       3.181ms        53.57%       3.181ms       3.181ms       0.000us         0.00%       0.000us       0.000us             1  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 5.939ms
Self CUDA time total: 3.901ms



======================================================================
PROFILE TRACE: torch_flash_ma | cuda_attn_L448_bfloat16
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                         torch_flash_ma         4.85%     313.852us        44.01%       2.846ms       2.846ms       0.000us         0.00%       4.405ms       4.405ms             1  
                                         torch_flash_ma         0.00%       0.000us         0.00%       0.000us       0.000us       4.356ms       100.24%       4.356ms       4.356ms             1  
                     aten::scaled_dot_product_attention         0.40%      25.602us         2.92%     188.673us      62.891us       0.000us         0.00%       3.542ms       1.181ms             3  
              aten::_scaled_dot_product_flash_attention         0.29%      18.450us         2.52%     163.071us      54.357us       0.000us         0.00%       3.542ms       1.181ms             3  
                         aten::_flash_attention_forward         0.66%      42.791us         1.88%     121.422us      40.474us       3.542ms        81.52%       3.542ms       1.181ms             3  
void pytorch_flash::flash_fwd_kernel<Flash_fwd_kerne...         0.00%       0.000us         0.00%       0.000us       0.000us       3.542ms        81.52%       3.542ms       1.181ms             3  
                                       aten::contiguous         0.15%       9.702us        35.55%       2.299ms     191.596us       0.000us         0.00%     862.461us      71.872us            12  
                                            aten::clone         0.45%      28.857us        35.40%       2.289ms     190.788us       0.000us         0.00%     862.461us      71.872us            12  
                                            aten::copy_         1.23%      79.423us        33.92%       2.194ms     182.809us     803.166us        18.48%     862.461us      71.872us            12  
void at::native::elementwise_kernel<128, 4, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     803.166us        18.48%     803.166us      66.930us            12  
                                Activity Buffer Request        28.18%       1.822ms        28.18%       1.822ms       1.822ms      59.295us         1.36%      59.295us      59.295us             1  
                                        aten::transpose         0.77%      49.902us         1.04%      67.461us       2.811us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::as_strided         0.27%      17.559us         0.27%      17.559us       0.732us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::empty_like         0.33%      21.611us         1.34%      86.704us       5.780us       0.000us         0.00%       0.000us       0.000us            15  
                                            aten::empty         1.24%      80.042us         1.24%      80.042us       3.335us       0.000us         0.00%       0.000us       0.000us            24  
                                       cudaLaunchKernel         4.86%     314.554us         4.86%     314.554us      20.970us       0.000us         0.00%       0.000us       0.000us            15  
                                    aten::empty_strided         0.23%      14.691us         0.23%      14.691us       4.897us       0.000us         0.00%       0.000us       0.000us             3  
                                 cudaDeviceGetAttribute         0.03%       1.700us         0.03%       1.700us       0.283us       0.000us         0.00%       0.000us       0.000us             6  
                                   cudaFuncSetAttribute         0.08%       4.940us         0.08%       4.940us       1.647us       0.000us         0.00%       0.000us       0.000us             3  
                                  cudaDeviceSynchronize        55.99%       3.621ms        55.99%       3.621ms       3.621ms       0.000us         0.00%       0.000us       0.000us             1  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 6.467ms
Self CUDA time total: 4.345ms



======================================================================
PROFILE TRACE: torch_flash_ma | cuda_attn_L512_bfloat16
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                         torch_flash_ma         3.49%     226.744us        41.30%       2.682ms       2.682ms       0.000us         0.00%       4.507ms       4.507ms             1  
                                         torch_flash_ma         0.00%       0.000us         0.00%       0.000us       0.000us       4.456ms       100.23%       4.456ms       4.456ms             1  
                     aten::scaled_dot_product_attention         0.39%      25.000us         2.68%     173.753us      57.918us       0.000us         0.00%       3.635ms       1.212ms             3  
              aten::_scaled_dot_product_flash_attention         0.28%      18.340us         2.29%     148.753us      49.584us       0.000us         0.00%       3.635ms       1.212ms             3  
                         aten::_flash_attention_forward         0.53%      34.164us         1.68%     109.263us      36.421us       3.635ms        81.77%       3.635ms       1.212ms             3  
void pytorch_flash::flash_fwd_kernel<Flash_fwd_kerne...         0.00%       0.000us         0.00%       0.000us       0.000us       3.635ms        81.77%       3.635ms       1.212ms             3  
                                       aten::contiguous         0.14%       8.821us        34.49%       2.240ms     186.626us       0.000us         0.00%     871.422us      72.619us            12  
                                            aten::clone         0.41%      26.612us        34.36%       2.231ms     185.890us       0.000us         0.00%     871.422us      72.619us            12  
                                            aten::copy_         1.18%      76.909us        32.95%       2.140ms     178.308us     810.270us        18.23%     871.422us      72.619us            12  
void at::native::elementwise_kernel<128, 4, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     810.270us        18.23%     810.270us      67.523us            12  
                                Activity Buffer Request        27.48%       1.784ms        27.48%       1.784ms       1.784ms      61.152us         1.38%      61.152us      61.152us             1  
                                        aten::transpose         0.71%      45.940us         0.97%      63.019us       2.626us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::as_strided         0.26%      17.079us         0.26%      17.079us       0.712us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::empty_like         0.30%      19.781us         1.27%      82.742us       5.516us       0.000us         0.00%       0.000us       0.000us            15  
                                            aten::empty         1.21%      78.423us         1.21%      78.423us       3.268us       0.000us         0.00%       0.000us       0.000us            24  
                                       cudaLaunchKernel         4.62%     300.294us         4.62%     300.294us      20.020us       0.000us         0.00%       0.000us       0.000us            15  
                                    aten::empty_strided         0.21%      13.430us         0.21%      13.430us       4.477us       0.000us         0.00%       0.000us       0.000us             3  
                                 cudaDeviceGetAttribute         0.02%       1.610us         0.02%       1.610us       0.268us       0.000us         0.00%       0.000us       0.000us             6  
                                   cudaFuncSetAttribute         0.07%       4.648us         0.07%       4.648us       1.549us       0.000us         0.00%       0.000us       0.000us             3  
                                  cudaDeviceSynchronize        58.70%       3.811ms        58.70%       3.811ms       3.811ms       0.000us         0.00%       0.000us       0.000us             1  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 6.493ms
Self CUDA time total: 4.445ms


impl                     wl                  p50(ms)  ok
torch_flash_ma           cuda_attn_L128_bfloat16     1.23  True
torch_flash_ma           cuda_attn_L256_bfloat16     1.28  True
torch_flash_ma           cuda_attn_L320_bfloat16     1.30  True
torch_flash_ma           cuda_attn_L384_bfloat16     1.33  True
torch_flash_ma           cuda_attn_L448_bfloat16     1.48  True
torch_flash_ma           cuda_attn_L512_bfloat16     1.52  True
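
Across the six workloads the flash forward kernel itself accounts for roughly 80% of self CUDA time, while the remaining ~18–20% goes to aten::copy_ / elementwise kernels triggered by the transpose + .contiguous() round trips around the SDPA call. Below is a hedged sketch of a variant that drops the explicit copies: transposed views of (batch, seq, heads, head_dim)-contiguous inputs keep stride 1 in the head dimension, which is what the flash backend requires. Whether it actually helps would need to be confirmed by re-running the benchmark.

def torch_flash_views(q, k, v):
    # Pass transposed views straight to SDPA instead of materializing copies.
    qt, kt, vt = (x.transpose(1, 2) for x in (q, k, v))
    with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.FLASH_ATTENTION):
        o = torch.nn.functional.scaled_dot_product_attention(qt, kt, vt)
    # Restore the original layout; keep .contiguous() so downstream checks see
    # the same memory format as torch_flash.
    return o.transpose(1, 2).contiguous()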

Artifacts:

attention.jsonl
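
The per-run records behind the summary table land in attention.jsonl. A minimal sketch for inspecting it; the exact field names are an assumption, since the schema is defined by kernels-benchmark-tools:

import json

# Load one JSON record per line and print it; adapt field access once the
# actual schema is known.
with open("attention.jsonl") as f:
    records = [json.loads(line) for line in f if line.strip()]

for rec in records:
    print(rec)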