Memory-Efficient Attention Implementation

Memory-Efficient SDPA Benchmark

Cell: benchmark | 3.92s
# /// script
# requires-python = ">=3.10"
# dependencies = [
#     "numpy",
#     "torch==2.8.0",
#     "kernels-benchmark-tools",
# ]
#
# [tool.uv.sources]
# kernels-benchmark-tools = { path = "../../../../../tools", editable = true }
# ///
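import torch
from kernels_benchmark_tools import KernelTypeEnum, run_benchmark


def torch_mem_eff(q, k, v):
    # Swap dims 1 and 2 so the tensors match the (batch, heads, seq_len, head_dim)
    # layout that SDPA expects. This presumes the harness passes inputs as
    # (batch, seq_len, heads, head_dim). .contiguous() materializes the copies.
    qt, kt, vt = (x.transpose(1, 2).contiguous() for x in (q, k, v))
    # Restrict SDPA dispatch to the memory-efficient backend.
    with torch.nn.attention.sdpa_kernel(
        torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION
    ):
        o = torch.nn.functional.scaled_dot_product_attention(qt, kt, vt)
    # Swap dims 1 and 2 back to restore the caller's layout.
    return o.transpose(1, 2).contiguous()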


run_benchmark(
    kernel_type=KernelTypeEnum.ATTENTION,
    impl_name="torch_mem_eff",
    impl_tags={"family": "torch-sdpa", "backend": "EFFICIENT", "compile": "none"},
    impl_func=torch_mem_eff,
)
Running attention benchmark on cuda with 6 workloads.

======================================================================
PROFILE TRACE: torch_mem_eff | cuda_attn_L128_bfloat16
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                          torch_mem_eff         4.77%     333.269us        32.71%       2.284ms       2.284ms       0.000us         0.00%       5.420ms       5.420ms             1  
                                          torch_mem_eff         0.00%       0.000us         0.00%       0.000us       0.000us       5.402ms       100.61%       5.402ms       5.402ms             1  
                     aten::scaled_dot_product_attention         0.44%      30.450us         2.54%     177.435us      59.145us       0.000us         0.00%       4.753ms       1.584ms             3  
          aten::_scaled_dot_product_efficient_attention         0.33%      22.722us         2.10%     146.985us      48.995us       0.000us         0.00%       4.753ms       1.584ms             3  
                     aten::_efficient_attention_forward         0.51%      35.382us         1.42%      99.273us      33.091us       4.753ms        88.51%       4.753ms       1.584ms             3  
fmha_cutlassF_bf16_aligned_64x128_rf_sm80(PyTorchMem...         0.00%       0.000us         0.00%       0.000us       0.000us       4.753ms        88.51%       4.753ms       1.584ms             3  
                                       aten::contiguous         0.17%      11.660us        24.51%       1.712ms     190.185us       0.000us         0.00%     667.266us      74.141us             9  
                                            aten::clone         0.46%      31.810us        24.34%       1.700ms     188.889us       0.000us         0.00%     667.266us      74.141us             9  
                                            aten::copy_         1.01%      70.871us        22.86%       1.597ms     177.404us     616.738us        11.49%     667.266us      74.141us             9  
void at::native::elementwise_kernel<128, 4, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     616.738us        11.49%     616.738us      68.526us             9  
                                Activity Buffer Request        20.64%       1.441ms        20.64%       1.441ms       1.441ms      50.528us         0.94%      50.528us      50.528us             1  
                                        aten::transpose         0.91%      63.619us         1.25%      87.011us       3.625us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::as_strided         0.33%      23.392us         0.33%      23.392us       0.975us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::empty_like         0.24%      16.972us         1.02%      71.553us       7.950us       0.000us         0.00%       0.000us       0.000us             9  
                                            aten::empty         1.18%      82.691us         1.18%      82.691us       3.938us       0.000us         0.00%       0.000us       0.000us            21  
                                       cudaLaunchKernel         1.55%     108.383us         1.55%     108.383us       9.032us       0.000us         0.00%       0.000us       0.000us            12  
                                  cudaStreamIsCapturing         0.05%       3.260us         0.05%       3.260us       1.087us       0.000us         0.00%       0.000us       0.000us             3  
                                   cudaFuncSetAttribute         0.12%       8.450us         0.12%       8.450us       2.817us       0.000us         0.00%       0.000us       0.000us             3  
                                  cudaDeviceSynchronize        67.29%       4.700ms        67.29%       4.700ms       4.700ms       0.000us         0.00%       0.000us       0.000us             1  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 6.984ms
Self CUDA time total: 5.369ms



======================================================================
PROFILE TRACE: torch_mem_eff | cuda_attn_L256_bfloat16
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                          torch_mem_eff         3.53%     251.015us        29.52%       2.098ms       2.098ms       0.000us         0.00%       5.633ms       5.633ms             1  
                                          torch_mem_eff         0.00%       0.000us         0.00%       0.000us       0.000us       5.587ms       100.15%       5.587ms       5.587ms             1  
                     aten::scaled_dot_product_attention         0.25%      17.630us         2.05%     145.594us      48.531us       0.000us         0.00%       4.943ms       1.648ms             3  
          aten::_scaled_dot_product_efficient_attention         0.28%      19.810us         1.80%     127.964us      42.655us       0.000us         0.00%       4.943ms       1.648ms             3  
                     aten::_efficient_attention_forward         0.42%      29.862us         1.18%      83.512us      27.837us       4.943ms        88.61%       4.943ms       1.648ms             3  
fmha_cutlassF_bf16_aligned_64x128_rf_sm80(PyTorchMem...         0.00%       0.000us         0.00%       0.000us       0.000us       4.943ms        88.61%       4.943ms       1.648ms             3  
                                       aten::contiguous         0.10%       7.191us        23.30%       1.656ms     184.002us       0.000us         0.00%     689.540us      76.616us             9  
                                            aten::clone         0.33%      23.318us        23.20%       1.649ms     183.203us       0.000us         0.00%     689.540us      76.616us             9  
                                            aten::copy_         0.92%      65.725us        22.12%       1.572ms     174.717us     635.140us        11.39%     689.540us      76.616us             9  
void at::native::elementwise_kernel<128, 4, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     635.140us        11.39%     635.140us      70.571us             9  
                                Activity Buffer Request        20.24%       1.439ms        20.24%       1.439ms       1.439ms      54.400us         0.98%      54.400us      54.400us             1  
                                        aten::transpose         0.71%      50.494us         0.99%      70.123us       2.922us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::as_strided         0.28%      19.629us         0.28%      19.629us       0.818us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::empty_like         0.18%      12.608us         0.75%      53.061us       5.896us       0.000us         0.00%       0.000us       0.000us             9  
                                            aten::empty         0.94%      66.903us         0.94%      66.903us       3.186us       0.000us         0.00%       0.000us       0.000us            21  
                                       cudaLaunchKernel         1.25%      89.012us         1.25%      89.012us       7.418us       0.000us         0.00%       0.000us       0.000us            12  
                                  cudaStreamIsCapturing         0.03%       2.220us         0.03%       2.220us       0.740us       0.000us         0.00%       0.000us       0.000us             3  
                                   cudaFuncSetAttribute         0.05%       3.880us         0.05%       3.880us       1.293us       0.000us         0.00%       0.000us       0.000us             3  
                                  cudaDeviceSynchronize        70.48%       5.009ms        70.48%       5.009ms       5.009ms       0.000us         0.00%       0.000us       0.000us             1  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 7.107ms
Self CUDA time total: 5.578ms



======================================================================
PROFILE TRACE: torch_mem_eff | cuda_attn_L320_bfloat16
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                          torch_mem_eff         3.28%     246.598us        28.54%       2.146ms       2.146ms       0.000us         0.00%       6.014ms       6.014ms             1  
                                          torch_mem_eff         0.00%       0.000us         0.00%       0.000us       0.000us       5.967ms       100.18%       5.967ms       5.967ms             1  
                     aten::scaled_dot_product_attention         0.24%      18.181us         1.92%     144.583us      48.194us       0.000us         0.00%       5.302ms       1.767ms             3  
          aten::_scaled_dot_product_efficient_attention         0.27%      19.980us         1.68%     126.402us      42.134us       0.000us         0.00%       5.302ms       1.767ms             3  
                     aten::_efficient_attention_forward         0.38%      28.571us         1.10%      82.521us      27.507us       5.302ms        89.01%       5.302ms       1.767ms             3  
fmha_cutlassF_bf16_aligned_64x128_rf_sm80(PyTorchMem...         0.00%       0.000us         0.00%       0.000us       0.000us       5.302ms        89.01%       5.302ms       1.767ms             3  
                                       aten::contiguous         0.09%       6.930us        22.70%       1.707ms     189.666us       0.000us         0.00%     712.547us      79.172us             9  
                                            aten::clone         0.30%      22.691us        22.61%       1.700ms     188.896us       0.000us         0.00%     712.547us      79.172us             9  
                                            aten::copy_         1.08%      81.024us        21.57%       1.622ms     180.228us     654.403us        10.99%     712.547us      79.172us             9  
void at::native::elementwise_kernel<128, 4, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     654.403us        10.99%     654.403us      72.711us             9  
                                Activity Buffer Request        19.57%       1.471ms        19.57%       1.471ms       1.471ms      58.144us         0.98%      58.144us      58.144us             1  
                                        aten::transpose         0.68%      51.431us         0.95%      71.351us       2.973us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::as_strided         0.26%      19.920us         0.26%      19.920us       0.830us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::empty_like         0.16%      11.979us         0.74%      55.320us       6.147us       0.000us         0.00%       0.000us       0.000us             9  
                                            aten::empty         0.93%      69.561us         0.93%      69.561us       3.312us       0.000us         0.00%       0.000us       0.000us            21  
                                       cudaLaunchKernel         1.22%      91.652us         1.22%      91.652us       7.638us       0.000us         0.00%       0.000us       0.000us            12  
                                  cudaStreamIsCapturing         0.03%       2.359us         0.03%       2.359us       0.786us       0.000us         0.00%       0.000us       0.000us             3  
                                   cudaFuncSetAttribute         0.05%       3.430us         0.05%       3.430us       1.143us       0.000us         0.00%       0.000us       0.000us             3  
                                  cudaDeviceSynchronize        71.46%       5.373ms        71.46%       5.373ms       5.373ms       0.000us         0.00%       0.000us       0.000us             1  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 7.519ms
Self CUDA time total: 5.956ms



======================================================================
PROFILE TRACE: torch_mem_eff | cuda_attn_L384_bfloat16
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                          torch_mem_eff         3.21%     251.576us        29.97%       2.347ms       2.347ms       0.000us         0.00%       6.116ms       6.116ms             1  
                                          torch_mem_eff         0.00%       0.000us         0.00%       0.000us       0.000us       6.068ms       100.14%       6.068ms       6.068ms             1  
                     aten::scaled_dot_product_attention         0.24%      18.800us         1.87%     146.693us      48.898us       0.000us         0.00%       5.408ms       1.803ms             3  
          aten::_scaled_dot_product_efficient_attention         0.25%      19.900us         1.63%     127.893us      42.631us       0.000us         0.00%       5.408ms       1.803ms             3  
                     aten::_efficient_attention_forward         0.38%      29.372us         1.07%      83.903us      27.968us       5.408ms        89.25%       5.408ms       1.803ms             3  
fmha_cutlassF_bf16_aligned_64x128_rf_sm80(PyTorchMem...         0.00%       0.000us         0.00%       0.000us       0.000us       5.408ms        89.25%       5.408ms       1.803ms             3  
                                       aten::contiguous         0.10%       7.511us        24.29%       1.902ms     211.340us       0.000us         0.00%     708.735us      78.748us             9  
                                            aten::clone         0.28%      21.872us        24.19%       1.895ms     210.505us       0.000us         0.00%     708.735us      78.748us             9  
                                            aten::copy_         0.85%      66.540us        23.20%       1.817ms     201.834us     651.551us        10.75%     708.735us      78.748us             9  
void at::native::elementwise_kernel<128, 4, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     651.551us        10.75%     651.551us      72.395us             9  
                                Activity Buffer Request        18.68%       1.462ms        18.68%       1.462ms       1.462ms      57.184us         0.94%      57.184us      57.184us             1  
                                        aten::transpose         0.65%      50.781us         0.90%      70.402us       2.933us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::as_strided         0.25%      19.621us         0.25%      19.621us       0.818us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::empty_like         0.15%      11.809us         0.72%      56.170us       6.241us       0.000us         0.00%       0.000us       0.000us             9  
                                            aten::empty         0.90%      70.242us         0.90%      70.242us       3.345us       0.000us         0.00%       0.000us       0.000us            21  
                                       cudaLaunchKernel         3.97%     310.797us         3.97%     310.797us      25.900us       0.000us         0.00%       0.000us       0.000us            12  
                                  cudaStreamIsCapturing         0.03%       2.250us         0.03%       2.250us       0.750us       0.000us         0.00%       0.000us       0.000us             3  
                                   cudaFuncSetAttribute         0.04%       3.220us         0.04%       3.220us       1.073us       0.000us         0.00%       0.000us       0.000us             3  
                                  cudaDeviceSynchronize        70.03%       5.484ms        70.03%       5.484ms       5.484ms       0.000us         0.00%       0.000us       0.000us             1  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 7.830ms
Self CUDA time total: 6.059ms



======================================================================
PROFILE TRACE: torch_mem_eff | cuda_attn_L448_bfloat16
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                          torch_mem_eff         3.15%     250.575us        28.50%       2.270ms       2.270ms       0.000us         0.00%       6.322ms       6.322ms             1  
                                          torch_mem_eff         0.00%       0.000us         0.00%       0.000us       0.000us       6.270ms       100.14%       6.270ms       6.270ms             1  
                     aten::scaled_dot_product_attention         0.22%      17.572us         1.82%     145.084us      48.361us       0.000us         0.00%       5.598ms       1.866ms             3  
          aten::_scaled_dot_product_efficient_attention         0.24%      19.250us         1.60%     127.512us      42.504us       0.000us         0.00%       5.598ms       1.866ms             3  
                     aten::_efficient_attention_forward         0.36%      28.812us         1.05%      83.962us      27.987us       5.598ms        89.40%       5.598ms       1.866ms             3  
fmha_cutlassF_bf16_aligned_64x128_rf_sm80(PyTorchMem...         0.00%       0.000us         0.00%       0.000us       0.000us       5.598ms        89.40%       5.598ms       1.866ms             3  
                                       aten::contiguous         0.09%       6.912us        22.94%       1.827ms     203.045us       0.000us         0.00%     724.000us      80.444us             9  
                                            aten::clone         0.28%      21.949us        22.86%       1.820ms     202.277us       0.000us         0.00%     724.000us      80.444us             9  
                                            aten::copy_         0.82%      65.091us        21.89%       1.744ms     193.745us     664.032us        10.60%     724.000us      80.444us             9  
void at::native::elementwise_kernel<128, 4, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     664.032us        10.60%     664.032us      73.781us             9  
                                Activity Buffer Request        18.02%       1.435ms        18.02%       1.435ms       1.435ms      59.968us         0.96%      59.968us      59.968us             1  
                                        aten::transpose         0.64%      50.930us         0.89%      70.859us       2.952us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::as_strided         0.25%      19.929us         0.25%      19.929us       0.830us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::empty_like         0.15%      12.022us         0.69%      54.843us       6.094us       0.000us         0.00%       0.000us       0.000us             9  
                                            aten::empty         0.87%      69.430us         0.87%      69.430us       3.306us       0.000us         0.00%       0.000us       0.000us            21  
                                       cudaLaunchKernel         3.34%     266.388us         3.34%     266.388us      22.199us       0.000us         0.00%       0.000us       0.000us            12  
                                  cudaStreamIsCapturing         0.03%       2.320us         0.03%       2.320us       0.773us       0.000us         0.00%       0.000us       0.000us             3  
                                   cudaFuncSetAttribute         0.04%       3.120us         0.04%       3.120us       1.040us       0.000us         0.00%       0.000us       0.000us             3  
                                  cudaDeviceSynchronize        71.50%       5.695ms        71.50%       5.695ms       5.695ms       0.000us         0.00%       0.000us       0.000us             1  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 7.965ms
Self CUDA time total: 6.262ms



======================================================================
PROFILE TRACE: torch_mem_eff | cuda_attn_L512_bfloat16
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                          torch_mem_eff         3.00%     248.403us        26.98%       2.232ms       2.232ms       0.000us         0.00%       6.668ms       6.668ms             1  
                                          torch_mem_eff         0.00%       0.000us         0.00%       0.000us       0.000us       6.616ms       100.13%       6.616ms       6.616ms             1  
                     aten::scaled_dot_product_attention         0.21%      17.221us         1.72%     142.654us      47.551us       0.000us         0.00%       5.939ms       1.980ms             3  
          aten::_scaled_dot_product_efficient_attention         0.23%      18.779us         1.52%     125.433us      41.811us       0.000us         0.00%       5.939ms       1.980ms             3  
                     aten::_efficient_attention_forward         0.34%      28.440us         0.99%      81.712us      27.237us       5.939ms        89.88%       5.939ms       1.980ms             3  
fmha_cutlassF_bf16_aligned_64x128_rf_sm80(PyTorchMem...         0.00%       0.000us         0.00%       0.000us       0.000us       5.939ms        89.88%       5.939ms       1.980ms             3  
                                       aten::contiguous         0.08%       6.861us        21.66%       1.792ms     199.142us       0.000us         0.00%     729.440us      81.049us             9  
                                            aten::clone         0.26%      21.352us        21.58%       1.785ms     198.379us       0.000us         0.00%     729.440us      81.049us             9  
                                            aten::copy_         0.83%      69.012us        20.65%       1.709ms     189.858us     668.928us        10.12%     729.440us      81.049us             9  
void at::native::elementwise_kernel<128, 4, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     668.928us        10.12%     668.928us      74.325us             9  
                                Activity Buffer Request        17.29%       1.430ms        17.29%       1.430ms       1.430ms      60.512us         0.92%      60.512us      60.512us             1  
                                        aten::transpose         0.63%      51.780us         0.89%      73.784us       3.074us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::as_strided         0.27%      22.004us         0.27%      22.004us       0.917us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::empty_like         0.14%      11.870us         0.67%      55.340us       6.149us       0.000us         0.00%       0.000us       0.000us             9  
                                            aten::empty         0.84%      69.312us         0.84%      69.312us       3.301us       0.000us         0.00%       0.000us       0.000us            21  
                                       cudaLaunchKernel         2.79%     231.145us         2.79%     231.145us      19.262us       0.000us         0.00%       0.000us       0.000us            12  
                                  cudaStreamIsCapturing         0.03%       2.280us         0.03%       2.280us       0.760us       0.000us         0.00%       0.000us       0.000us             3  
                                   cudaFuncSetAttribute         0.04%       3.570us         0.04%       3.570us       1.190us       0.000us         0.00%       0.000us       0.000us             3  
                                  cudaDeviceSynchronize        73.02%       6.041ms        73.02%       6.041ms       6.041ms       0.000us         0.00%       0.000us       0.000us             1  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 8.273ms
Self CUDA time total: 6.608ms
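
Each trace above contains two GPU kernels: the CUTLASS attention kernel (fmha_cutlassF_bf16_aligned_64x128_rf_sm80) and an elementwise copy kernel reached through aten::contiguous -> aten::clone -> aten::copy_. The sketch below is not part of the benchmark. It copies the Self CUDA figures from the traces by hand and recomputes how much of the GPU time goes to those copies. The names TRACES and copy_share are illustrative only, and the totals are the rounded values printed by the profiler, so the results are approximate.

```python
# Values copied by hand from the profiler tables above, converted to microseconds:
# (attention kernel Self CUDA, copy kernel Self CUDA, Self CUDA time total).
TRACES = {
    "L128": (4753.0, 616.738, 5369.0),
    "L256": (4943.0, 635.140, 5578.0),
    "L320": (5302.0, 654.403, 5956.0),
    "L384": (5408.0, 651.551, 6059.0),
    "L448": (5598.0, 664.032, 6262.0),
    "L512": (5939.0, 668.928, 6608.0),
}


def copy_share(traces):
    """Return the fraction of total Self CUDA time spent in the copy kernel."""
    return {
        name: copy_us / total_us
        for name, (_attn_us, copy_us, total_us) in traces.items()
    }


shares = copy_share(TRACES)
for name, share in shares.items():
    print(f"{name}: copies take {share:.1%} of GPU time")
```

On these numbers the copies account for roughly 10 to 11.5 percent of GPU time, and the share shrinks slightly as the workload grows because the attention kernel time increases faster than the copy time.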


impl           wl                       p50(ms)  ok
torch_mem_eff  cuda_attn_L128_bfloat16     1.83  True
torch_mem_eff  cuda_attn_L256_bfloat16     1.89  True
torch_mem_eff  cuda_attn_L320_bfloat16     2.00  True
torch_mem_eff  cuda_attn_L384_bfloat16     1.97  True
torch_mem_eff  cuda_attn_L448_bfloat16     2.06  True
torch_mem_eff  cuda_attn_L512_bfloat16     2.19  True
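
To put the p50 column in context, the following sketch compares latency growth with workload growth. It is not part of the benchmark. It assumes that the number after "L" in each workload name is the sequence length, which the output does not state explicitly. The names P50_MS and relative_scaling are illustrative only.

```python
# p50 latencies in milliseconds, copied by hand from the summary table above.
# Keys assume the "L<n>" part of the workload name is the sequence length.
P50_MS = {128: 1.83, 256: 1.89, 320: 2.00, 384: 1.97, 448: 2.06, 512: 2.19}


def relative_scaling(p50_by_len):
    """Map each length to (length ratio, latency ratio) versus the shortest workload."""
    base_len = min(p50_by_len)
    base_ms = p50_by_len[base_len]
    return {
        length: (length / base_len, ms / base_ms)
        for length, ms in sorted(p50_by_len.items())
    }


scaling = relative_scaling(P50_MS)
for length, (len_ratio, ms_ratio) in scaling.items():
    print(f"L={length}: {len_ratio:.2f}x length, {ms_ratio:.2f}x p50 latency")
```

Under that assumption, going from L128 to L512 is a 4x increase in length for about a 1.2x increase in p50 latency. The L384 row (1.97 ms) is slightly faster than L320 (2.00 ms), so the small differences between adjacent rows should be read with some caution.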

Artifacts:

attention.jsonl