Memory Efficient Attention Implementation

Memory Efficient SDPA Benchmark

Cell: benchmark | 4.02s
# /// script
# requires-python = ">=3.10"
# dependencies = [
#     "numpy",
#     "torch==2.8.0",
#     "kernels-benchmark-tools",
# ]
#
# [tool.uv.sources]
# kernels-benchmark-tools = { path = "../../../../../tools", editable = true }
# ///
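import torch
from kernels_benchmark_tools import KernelTypeEnum, run_benchmark


def torch_mem_eff(q, k, v):
    # Swap dims 1 and 2 to get the (batch, heads, seq_len, head_dim) layout
    # that SDPA expects; .contiguous() materializes each transposed tensor.
    qt, kt, vt = (x.transpose(1, 2).contiguous() for x in (q, k, v))
    # Restrict SDPA dispatch to the memory-efficient backend.
    with torch.nn.attention.sdpa_kernel(
        torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION
    ):
        o = torch.nn.functional.scaled_dot_product_attention(qt, kt, vt)
    # Swap dims 1 and 2 back to the caller's layout.
    return o.transpose(1, 2).contiguous()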


run_benchmark(
    kernel_type=KernelTypeEnum.ATTENTION,
    impl_name="torch_mem_eff",
    impl_tags={"family": "torch-sdpa", "backend": "EFFICIENT", "compile": "none"},
    impl_func=torch_mem_eff,
)
Running attention benchmark on cuda with 6 workloads.

======================================================================
PROFILE TRACE: torch_mem_eff | cuda_attn_L128_bfloat16
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                          torch_mem_eff         4.61%     329.029us        32.49%       2.320ms       2.320ms       0.000us         0.00%       5.545ms       5.545ms             1  
                                          torch_mem_eff         0.00%       0.000us         0.00%       0.000us       0.000us       5.524ms       100.54%       5.524ms       5.524ms             1  
                     aten::scaled_dot_product_attention         0.42%      29.860us         2.75%     196.242us      65.414us       0.000us         0.00%       4.878ms       1.626ms             3  
          aten::_scaled_dot_product_efficient_attention         0.35%      25.230us         2.33%     166.382us      55.461us       0.000us         0.00%       4.878ms       1.626ms             3  
                     aten::_efficient_attention_forward         0.73%      52.049us         1.68%     119.861us      39.954us       4.878ms        88.79%       4.878ms       1.626ms             3  
fmha_cutlassF_bf16_aligned_64x128_rf_sm80(PyTorchMem...         0.00%       0.000us         0.00%       0.000us       0.000us       4.878ms        88.79%       4.878ms       1.626ms             3  
                                       aten::contiguous         0.18%      13.143us        24.28%       1.734ms     192.643us       0.000us         0.00%     666.300us      74.033us             9  
                                            aten::clone         0.50%      35.608us        24.09%       1.721ms     191.183us       0.000us         0.00%     666.300us      74.033us             9  
                                            aten::copy_         1.01%      71.952us        22.59%       1.613ms     179.214us     615.708us        11.21%     666.300us      74.033us             9  
void at::native::elementwise_kernel<128, 4, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     615.708us        11.21%     615.708us      68.412us             9  
                                Activity Buffer Request        20.33%       1.452ms        20.33%       1.452ms       1.452ms      50.592us         0.92%      50.592us      50.592us             1  
                                        aten::transpose         0.87%      61.994us         1.16%      82.494us       3.437us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::as_strided         0.29%      20.500us         0.29%      20.500us       0.854us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::empty_like         0.25%      17.742us         1.01%      72.112us       8.012us       0.000us         0.00%       0.000us       0.000us             9  
                                            aten::empty         1.17%      83.610us         1.17%      83.610us       3.981us       0.000us         0.00%       0.000us       0.000us            21  
                                       cudaLaunchKernel         1.60%     114.582us         1.60%     114.582us       9.548us       0.000us         0.00%       0.000us       0.000us            12  
                                  cudaStreamIsCapturing         0.04%       3.180us         0.04%       3.180us       1.060us       0.000us         0.00%       0.000us       0.000us             3  
                                   cudaFuncSetAttribute         0.14%      10.280us         0.14%      10.280us       3.427us       0.000us         0.00%       0.000us       0.000us             3  
                                  cudaDeviceSynchronize        67.51%       4.821ms        67.51%       4.821ms       4.821ms       0.000us         0.00%       0.000us       0.000us             1  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 7.141ms
Self CUDA time total: 5.494ms



======================================================================
PROFILE TRACE: torch_mem_eff | cuda_attn_L256_bfloat16
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                          torch_mem_eff         3.39%     253.102us        28.13%       2.097ms       2.097ms       0.000us         0.00%       5.972ms       5.972ms             1  
                                          torch_mem_eff         0.00%       0.000us         0.00%       0.000us       0.000us       5.926ms       100.15%       5.926ms       5.926ms             1  
                     aten::scaled_dot_product_attention         0.26%      19.190us         1.92%     143.113us      47.704us       0.000us         0.00%       5.278ms       1.759ms             3  
          aten::_scaled_dot_product_efficient_attention         0.26%      19.540us         1.66%     123.923us      41.308us       0.000us         0.00%       5.278ms       1.759ms             3  
                     aten::_efficient_attention_forward         0.37%      27.385us         1.10%      81.652us      27.217us       5.278ms        89.20%       5.278ms       1.759ms             3  
fmha_cutlassF_bf16_aligned_64x128_rf_sm80(PyTorchMem...         0.00%       0.000us         0.00%       0.000us       0.000us       5.278ms        89.20%       5.278ms       1.759ms             3  
                                       aten::contiguous         0.09%       6.999us        22.26%       1.660ms     184.423us       0.000us         0.00%     693.503us      77.056us             9  
                                            aten::clone         0.31%      23.031us        22.17%       1.653ms     183.645us       0.000us         0.00%     693.503us      77.056us             9  
                                            aten::copy_         0.83%      61.989us        21.18%       1.579ms     175.477us     638.911us        10.80%     693.503us      77.056us             9  
void at::native::elementwise_kernel<128, 4, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     638.911us        10.80%     638.911us      70.990us             9  
                                Activity Buffer Request        19.45%       1.450ms        19.45%       1.450ms       1.450ms      54.592us         0.92%      54.592us      54.592us             1  
                                        aten::transpose         0.64%      47.641us         0.86%      64.101us       2.671us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::as_strided         0.22%      16.460us         0.22%      16.460us       0.686us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::empty_like         0.16%      11.730us         0.68%      50.483us       5.609us       0.000us         0.00%       0.000us       0.000us             9  
                                            aten::empty         0.86%      64.470us         0.86%      64.470us       3.070us       0.000us         0.00%       0.000us       0.000us            21  
                                       cudaLaunchKernel         1.21%      90.240us         1.21%      90.240us       7.520us       0.000us         0.00%       0.000us       0.000us            12  
                                  cudaStreamIsCapturing         0.03%       2.290us         0.03%       2.290us       0.763us       0.000us         0.00%       0.000us       0.000us             3  
                                   cudaFuncSetAttribute         0.04%       3.130us         0.04%       3.130us       1.043us       0.000us         0.00%       0.000us       0.000us             3  
                                  cudaDeviceSynchronize        71.87%       5.359ms        71.87%       5.359ms       5.359ms       0.000us         0.00%       0.000us       0.000us             1  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 7.456ms
Self CUDA time total: 5.917ms



======================================================================
PROFILE TRACE: torch_mem_eff | cuda_attn_L320_bfloat16
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                          torch_mem_eff         3.16%     240.823us        26.89%       2.051ms       2.051ms       0.000us         0.00%       6.167ms       6.167ms             1  
                                          torch_mem_eff         0.00%       0.000us         0.00%       0.000us       0.000us       6.117ms       100.14%       6.117ms       6.117ms             1  
                     aten::scaled_dot_product_attention         0.24%      18.220us         1.81%     137.732us      45.911us       0.000us         0.00%       5.453ms       1.818ms             3  
          aten::_scaled_dot_product_efficient_attention         0.24%      18.402us         1.57%     119.512us      39.837us       0.000us         0.00%       5.453ms       1.818ms             3  
                     aten::_efficient_attention_forward         0.35%      26.389us         1.04%      79.670us      26.557us       5.453ms        89.28%       5.453ms       1.818ms             3  
fmha_cutlassF_bf16_aligned_64x128_rf_sm80(PyTorchMem...         0.00%       0.000us         0.00%       0.000us       0.000us       5.453ms        89.28%       5.453ms       1.818ms             3  
                                       aten::contiguous         0.09%       6.950us        21.38%       1.630ms     181.132us       0.000us         0.00%     713.534us      79.282us             9  
                                            aten::clone         0.28%      21.189us        21.28%       1.623ms     180.360us       0.000us         0.00%     713.534us      79.282us             9  
                                            aten::copy_         0.81%      62.032us        20.34%       1.551ms     172.330us     655.038us        10.72%     713.534us      79.282us             9  
void at::native::elementwise_kernel<128, 4, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     655.038us        10.72%     655.038us      72.782us             9  
                                Activity Buffer Request        18.63%       1.421ms        18.63%       1.421ms       1.421ms      58.496us         0.96%      58.496us      58.496us             1  
                                        aten::transpose         0.62%      47.348us         0.84%      63.699us       2.654us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::as_strided         0.21%      16.351us         0.21%      16.351us       0.681us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::empty_like         0.15%      11.091us         0.67%      51.081us       5.676us       0.000us         0.00%       0.000us       0.000us             9  
                                            aten::empty         0.86%      65.760us         0.86%      65.760us       3.131us       0.000us         0.00%       0.000us       0.000us            21  
                                       cudaLaunchKernel         1.18%      89.982us         1.18%      89.982us       7.498us       0.000us         0.00%       0.000us       0.000us            12  
                                  cudaStreamIsCapturing         0.03%       2.210us         0.03%       2.210us       0.737us       0.000us         0.00%       0.000us       0.000us             3  
                                   cudaFuncSetAttribute         0.04%       3.100us         0.04%       3.100us       1.033us       0.000us         0.00%       0.000us       0.000us             3  
                                  cudaDeviceSynchronize        73.11%       5.575ms        73.11%       5.575ms       5.575ms       0.000us         0.00%       0.000us       0.000us             1  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 7.626ms
Self CUDA time total: 6.108ms



======================================================================
PROFILE TRACE: torch_mem_eff | cuda_attn_L384_bfloat16
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                          torch_mem_eff         4.44%     356.182us        33.00%       2.648ms       2.648ms       0.000us         0.00%       6.210ms       6.210ms             1  
                                          torch_mem_eff         0.00%       0.000us         0.00%       0.000us       0.000us       6.165ms       100.21%       6.165ms       6.165ms             1  
                     aten::scaled_dot_product_attention         0.29%      23.400us         2.31%     185.263us      61.754us       0.000us         0.00%       5.497ms       1.832ms             3  
          aten::_scaled_dot_product_efficient_attention         0.29%      23.202us         2.02%     161.863us      53.954us       0.000us         0.00%       5.497ms       1.832ms             3  
                     aten::_efficient_attention_forward         0.44%      35.239us         1.36%     108.811us      36.270us       5.497ms        89.36%       5.497ms       1.832ms             3  
fmha_cutlassF_bf16_aligned_64x128_rf_sm80(PyTorchMem...         0.00%       0.000us         0.00%       0.000us       0.000us       5.497ms        89.36%       5.497ms       1.832ms             3  
                                       aten::contiguous         0.11%       9.040us        25.54%       2.050ms     227.726us       0.000us         0.00%     712.735us      79.193us             9  
                                            aten::clone         0.35%      28.461us        25.43%       2.040ms     226.722us       0.000us         0.00%     712.735us      79.193us             9  
                                            aten::copy_         1.02%      82.020us        24.22%       1.944ms     215.993us     654.527us        10.64%     712.735us      79.193us             9  
void at::native::elementwise_kernel<128, 4, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     654.527us        10.64%     654.527us      72.725us             9  
                                Activity Buffer Request        19.35%       1.553ms        19.35%       1.553ms       1.553ms      58.208us         0.95%      58.208us      58.208us             1  
                                        aten::transpose         0.81%      64.960us         1.09%      87.330us       3.639us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::as_strided         0.28%      22.370us         0.28%      22.370us       0.932us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::empty_like         0.19%      15.081us         0.85%      68.092us       7.566us       0.000us         0.00%       0.000us       0.000us             9  
                                            aten::empty         1.09%      87.522us         1.09%      87.522us       4.168us       0.000us         0.00%       0.000us       0.000us            21  
                                       cudaLaunchKernel         4.25%     341.154us         4.25%     341.154us      28.429us       0.000us         0.00%       0.000us       0.000us            12  
                                  cudaStreamIsCapturing         0.04%       2.841us         0.04%       2.841us       0.947us       0.000us         0.00%       0.000us       0.000us             3  
                                   cudaFuncSetAttribute         0.05%       4.120us         0.05%       4.120us       1.373us       0.000us         0.00%       0.000us       0.000us             3  
                                  cudaDeviceSynchronize        67.00%       5.376ms        67.00%       5.376ms       5.376ms       0.000us         0.00%       0.000us       0.000us             1  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 8.025ms
Self CUDA time total: 6.152ms



======================================================================
PROFILE TRACE: torch_mem_eff | cuda_attn_L448_bfloat16
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                          torch_mem_eff         3.33%     272.217us        28.45%       2.323ms       2.323ms       0.000us         0.00%       6.452ms       6.452ms             1  
                                          torch_mem_eff         0.00%       0.000us         0.00%       0.000us       0.000us       6.401ms       100.14%       6.401ms       6.401ms             1  
                     aten::scaled_dot_product_attention         0.25%      20.040us         1.74%     141.700us      47.233us       0.000us         0.00%       5.729ms       1.910ms             3  
          aten::_scaled_dot_product_efficient_attention         0.23%      18.560us         1.49%     121.660us      40.553us       0.000us         0.00%       5.729ms       1.910ms             3  
                     aten::_efficient_attention_forward         0.34%      27.420us         1.00%      81.440us      27.147us       5.729ms        89.62%       5.729ms       1.910ms             3  
fmha_cutlassF_bf16_aligned_64x128_rf_sm80(PyTorchMem...         0.00%       0.000us         0.00%       0.000us       0.000us       5.729ms        89.62%       5.729ms       1.910ms             3  
                                       aten::contiguous         0.09%       7.310us        22.83%       1.865ms     207.177us       0.000us         0.00%     723.614us      80.402us             9  
                                            aten::clone         0.27%      22.438us        22.75%       1.857ms     206.364us       0.000us         0.00%     723.614us      80.402us             9  
                                            aten::copy_         0.75%      61.292us        21.84%       1.783ms     198.108us     663.806us        10.38%     723.614us      80.402us             9  
void at::native::elementwise_kernel<128, 4, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     663.806us        10.38%     663.806us      73.756us             9  
                                Activity Buffer Request        18.13%       1.481ms        18.13%       1.481ms       1.481ms      59.808us         0.94%      59.808us      59.808us             1  
                                        aten::transpose         0.61%      49.591us         0.81%      66.019us       2.751us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::as_strided         0.20%      16.428us         0.20%      16.428us       0.684us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::empty_like         0.14%      11.501us         0.64%      51.871us       5.763us       0.000us         0.00%       0.000us       0.000us             9  
                                            aten::empty         0.80%      65.620us         0.80%      65.620us       3.125us       0.000us         0.00%       0.000us       0.000us            21  
                                       cudaLaunchKernel         3.24%     264.473us         3.24%     264.473us      22.039us       0.000us         0.00%       0.000us       0.000us            12  
                                  cudaStreamIsCapturing         0.03%       2.310us         0.03%       2.310us       0.770us       0.000us         0.00%       0.000us       0.000us             3  
                                   cudaFuncSetAttribute         0.04%       3.060us         0.04%       3.060us       1.020us       0.000us         0.00%       0.000us       0.000us             3  
                                  cudaDeviceSynchronize        71.55%       5.843ms        71.55%       5.843ms       5.843ms       0.000us         0.00%       0.000us       0.000us             1  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 8.166ms
Self CUDA time total: 6.392ms



======================================================================
PROFILE TRACE: torch_mem_eff | cuda_attn_L512_bfloat16
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                          torch_mem_eff         2.84%     238.921us        26.25%       2.206ms       2.206ms       0.000us         0.00%       6.803ms       6.803ms             1  
                                          torch_mem_eff         0.00%       0.000us         0.00%       0.000us       0.000us       6.751ms       100.13%       6.751ms       6.751ms             1  
                     aten::scaled_dot_product_attention         0.23%      19.080us         1.67%     140.122us      46.707us       0.000us         0.00%       6.072ms       2.024ms             3  
          aten::_scaled_dot_product_efficient_attention         0.22%      18.680us         1.44%     121.042us      40.347us       0.000us         0.00%       6.072ms       2.024ms             3  
                     aten::_efficient_attention_forward         0.32%      27.009us         0.95%      79.840us      26.613us       6.072ms        90.07%       6.072ms       2.024ms             3  
fmha_cutlassF_bf16_aligned_64x128_rf_sm80(PyTorchMem...         0.00%       0.000us         0.00%       0.000us       0.000us       6.072ms        90.07%       6.072ms       2.024ms             3  
                                       aten::contiguous         0.09%       7.439us        21.24%       1.785ms     198.324us       0.000us         0.00%     731.099us      81.233us             9  
                                            aten::clone         0.26%      21.852us        21.15%       1.777ms     197.498us       0.000us         0.00%     731.099us      81.233us             9  
                                            aten::copy_         0.77%      64.769us        20.27%       1.703ms     189.239us     669.820us         9.93%     731.099us      81.233us             9  
void at::native::elementwise_kernel<128, 4, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     669.820us         9.93%     669.820us      74.424us             9  
                                Activity Buffer Request        16.92%       1.422ms        16.92%       1.422ms       1.422ms      61.279us         0.91%      61.279us      61.279us             1  
                                        aten::transpose         0.57%      48.271us         0.77%      64.334us       2.681us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::as_strided         0.19%      16.063us         0.19%      16.063us       0.669us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::empty_like         0.14%      11.440us         0.62%      52.480us       5.831us       0.000us         0.00%       0.000us       0.000us             9  
                                            aten::empty         0.79%      66.661us         0.79%      66.661us       3.174us       0.000us         0.00%       0.000us       0.000us            21  
                                       cudaLaunchKernel         2.84%     238.383us         2.84%     238.383us      19.865us       0.000us         0.00%       0.000us       0.000us            12  
                                  cudaStreamIsCapturing         0.03%       2.270us         0.03%       2.270us       0.757us       0.000us         0.00%       0.000us       0.000us             3  
                                   cudaFuncSetAttribute         0.04%       3.090us         0.04%       3.090us       1.030us       0.000us         0.00%       0.000us       0.000us             3  
                                  cudaDeviceSynchronize        73.75%       6.196ms        73.75%       6.196ms       6.196ms       0.000us         0.00%       0.000us       0.000us             1  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 8.402ms
Self CUDA time total: 6.742ms


impl           wl                       p50(ms)  ok
torch_mem_eff  cuda_attn_L128_bfloat16     1.89  True
torch_mem_eff  cuda_attn_L256_bfloat16     1.95  True
torch_mem_eff  cuda_attn_L320_bfloat16     2.05  True
torch_mem_eff  cuda_attn_L384_bfloat16     2.08  True
torch_mem_eff  cuda_attn_L448_bfloat16     2.13  True
torch_mem_eff  cuda_attn_L512_bfloat16     2.27  True
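
To see how latency scales across the workloads, the p50 values in the table can be normalized against the smallest one. This is a minimal sketch using only the numbers reported above. It assumes the `L…` suffix in each workload name is the sequence length, and the names `p50_ms` and `relative_cost` are illustrative.

```python
# p50 latencies (ms) copied from the summary table above,
# keyed by the assumed sequence length taken from the workload name.
p50_ms = {128: 1.89, 256: 1.95, 320: 2.05, 384: 2.08, 448: 2.13, 512: 2.27}

BASE_LEN = 128  # smallest workload, used as the reference point


def relative_cost(seq_len):
    """p50 latency of a workload relative to the BASE_LEN workload."""
    return p50_ms[seq_len] / p50_ms[BASE_LEN]


for seq_len in p50_ms:
    length_ratio = seq_len / BASE_LEN
    print(f"L{seq_len}: {length_ratio:.2f}x length -> {relative_cost(seq_len):.2f}x p50")
```

On these figures, a 4x increase in `L` raises p50 latency by only about 20%.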

Artifacts:

attention.jsonl