Memory-Efficient Attention Implementation

Memory-Efficient SDPA Benchmark

Cell: benchmark | 3.94s
# /// script
# requires-python = ">=3.10"
# dependencies = [
#     "numpy",
#     "torch==2.8.0",
#     "kernels-benchmark-tools",
# ]
#
# [tool.uv.sources]
# kernels-benchmark-tools = { path = "../../../../../tools", editable = true }
# ///
import torch
from kernels_benchmark_tools import KernelTypeEnum, run_benchmark


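def torch_mem_eff(q, k, v):
    # The transpose(1, 2) suggests inputs arrive as (batch, seq_len, heads, head_dim);
    # SDPA expects (batch, heads, seq_len, head_dim), so swap the two middle axes.
    qt, kt, vt = (x.transpose(1, 2).contiguous() for x in (q, k, v))
    # Restrict SDPA dispatch to the memory-efficient backend for this call only.
    with torch.nn.attention.sdpa_kernel(
        torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION
    ):
        o = torch.nn.functional.scaled_dot_product_attention(qt, kt, vt)
    # Swap the axes back so the output layout matches the inputs.
    return o.transpose(1, 2).contiguous()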


run_benchmark(
    kernel_type=KernelTypeEnum.ATTENTION,
    impl_name="torch_mem_eff",
    impl_tags={"family": "torch-sdpa", "backend": "EFFICIENT", "compile": "none"},
    impl_func=torch_mem_eff,
)
Running attention benchmark on cuda with 6 workloads.

======================================================================
PROFILE TRACE: torch_mem_eff | cuda_attn_L128_bfloat16
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                          torch_mem_eff         5.14%     365.276us        32.53%       2.313ms       2.313ms       0.000us         0.00%       5.511ms       5.511ms             1  
                                          torch_mem_eff         0.00%       0.000us         0.00%       0.000us       0.000us       5.492ms       100.58%       5.492ms       5.492ms             1  
                     aten::scaled_dot_product_attention         0.43%      30.401us         2.47%     175.534us      58.511us       0.000us         0.00%       4.841ms       1.614ms             3  
          aten::_scaled_dot_product_efficient_attention         0.33%      23.489us         2.04%     145.133us      48.378us       0.000us         0.00%       4.841ms       1.614ms             3  
                     aten::_efficient_attention_forward         0.51%      36.572us         1.40%      99.733us      33.244us       4.841ms        88.65%       4.841ms       1.614ms             3  
fmha_cutlassF_bf16_aligned_64x128_rf_sm80(PyTorchMem...         0.00%       0.000us         0.00%       0.000us       0.000us       4.841ms        88.65%       4.841ms       1.614ms             3  
                                       aten::contiguous         0.18%      12.851us        23.99%       1.706ms     189.523us       0.000us         0.00%     670.241us      74.471us             9  
                                            aten::clone         0.46%      32.742us        23.80%       1.693ms     188.095us       0.000us         0.00%     670.241us      74.471us             9  
                                            aten::copy_         1.05%      74.801us        22.33%       1.588ms     176.415us     619.776us        11.35%     670.241us      74.471us             9  
void at::native::elementwise_kernel<128, 4, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     619.776us        11.35%     619.776us      68.864us             9  
                                Activity Buffer Request        20.17%       1.434ms        20.17%       1.434ms       1.434ms      50.465us         0.92%      50.465us      50.465us             1  
                                        aten::transpose         0.93%      66.224us         1.25%      88.644us       3.693us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::as_strided         0.32%      22.420us         0.32%      22.420us       0.934us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::empty_like         0.25%      17.919us         1.02%      72.382us       8.042us       0.000us         0.00%       0.000us       0.000us             9  
                                            aten::empty         1.14%      81.114us         1.14%      81.114us       3.863us       0.000us         0.00%       0.000us       0.000us            21  
                                       cudaLaunchKernel         1.46%     103.973us         1.46%     103.973us       8.664us       0.000us         0.00%       0.000us       0.000us            12  
                                  cudaStreamIsCapturing         0.04%       2.960us         0.04%       2.960us       0.987us       0.000us         0.00%       0.000us       0.000us             3  
                                   cudaFuncSetAttribute         0.12%       8.310us         0.12%       8.310us       2.770us       0.000us         0.00%       0.000us       0.000us             3  
                                  cudaDeviceSynchronize        67.47%       4.798ms        67.47%       4.798ms       4.798ms       0.000us         0.00%       0.000us       0.000us             1  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 7.111ms
Self CUDA time total: 5.460ms



======================================================================
PROFILE TRACE: torch_mem_eff | cuda_attn_L256_bfloat16
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                          torch_mem_eff         3.28%     242.746us        28.00%       2.075ms       2.075ms       0.000us         0.00%       5.933ms       5.933ms             1  
                                          torch_mem_eff         0.00%       0.000us         0.00%       0.000us       0.000us       5.886ms       100.14%       5.886ms       5.886ms             1  
                     aten::scaled_dot_product_attention         0.25%      18.240us         1.89%     140.073us      46.691us       0.000us         0.00%       5.241ms       1.747ms             3  
          aten::_scaled_dot_product_efficient_attention         0.25%      18.689us         1.64%     121.833us      40.611us       0.000us         0.00%       5.241ms       1.747ms             3  
                     aten::_efficient_attention_forward         0.38%      28.462us         1.09%      81.063us      27.021us       5.241ms        89.17%       5.241ms       1.747ms             3  
fmha_cutlassF_bf16_aligned_64x128_rf_sm80(PyTorchMem...         0.00%       0.000us         0.00%       0.000us       0.000us       5.241ms        89.17%       5.241ms       1.747ms             3  
                                       aten::contiguous         0.10%       7.041us        22.26%       1.650ms     183.285us       0.000us         0.00%     691.103us      76.789us             9  
                                            aten::clone         0.29%      21.342us        22.17%       1.643ms     182.503us       0.000us         0.00%     691.103us      76.789us             9  
                                            aten::copy_         0.86%      63.451us        21.24%       1.574ms     174.872us     636.671us        10.83%     691.103us      76.789us             9  
void at::native::elementwise_kernel<128, 4, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     636.671us        10.83%     636.671us      70.741us             9  
                                Activity Buffer Request        19.50%       1.445ms        19.50%       1.445ms       1.445ms      54.432us         0.93%      54.432us      54.432us             1  
                                        aten::transpose         0.64%      47.650us         0.87%      64.701us       2.696us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::as_strided         0.23%      17.051us         0.23%      17.051us       0.710us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::empty_like         0.16%      11.589us         0.64%      47.330us       5.259us       0.000us         0.00%       0.000us       0.000us             9  
                                            aten::empty         0.82%      60.521us         0.82%      60.521us       2.882us       0.000us         0.00%       0.000us       0.000us            21  
                                       cudaLaunchKernel         1.19%      88.044us         1.19%      88.044us       7.337us       0.000us         0.00%       0.000us       0.000us            12  
                                  cudaStreamIsCapturing         0.03%       2.420us         0.03%       2.420us       0.807us       0.000us         0.00%       0.000us       0.000us             3  
                                   cudaFuncSetAttribute         0.04%       3.030us         0.04%       3.030us       1.010us       0.000us         0.00%       0.000us       0.000us             3  
                                  cudaDeviceSynchronize        72.00%       5.335ms        72.00%       5.335ms       5.335ms       0.000us         0.00%       0.000us       0.000us             1  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 7.410ms
Self CUDA time total: 5.878ms



======================================================================
PROFILE TRACE: torch_mem_eff | cuda_attn_L320_bfloat16
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                          torch_mem_eff         3.21%     244.055us        27.47%       2.092ms       2.092ms       0.000us         0.00%       6.130ms       6.130ms             1  
                                          torch_mem_eff         0.00%       0.000us         0.00%       0.000us       0.000us       6.080ms       100.14%       6.080ms       6.080ms             1  
                     aten::scaled_dot_product_attention         0.23%      17.641us         1.86%     141.944us      47.315us       0.000us         0.00%       5.414ms       1.805ms             3  
          aten::_scaled_dot_product_efficient_attention         0.25%      19.359us         1.63%     124.303us      41.434us       0.000us         0.00%       5.414ms       1.805ms             3  
                     aten::_efficient_attention_forward         0.37%      28.219us         1.06%      80.592us      26.864us       5.414ms        89.17%       5.414ms       1.805ms             3  
fmha_cutlassF_bf16_aligned_64x128_rf_sm80(PyTorchMem...         0.00%       0.000us         0.00%       0.000us       0.000us       5.414ms        89.17%       5.414ms       1.805ms             3  
                                       aten::contiguous         0.11%       8.060us        21.81%       1.661ms     184.510us       0.000us         0.00%     716.192us      79.577us             9  
                                            aten::clone         0.29%      22.431us        21.70%       1.653ms     183.615us       0.000us         0.00%     716.192us      79.577us             9  
                                            aten::copy_         0.81%      61.641us        20.75%       1.580ms     175.564us     657.728us        10.83%     716.192us      79.577us             9  
void at::native::elementwise_kernel<128, 4, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     657.728us        10.83%     657.728us      73.081us             9  
                                Activity Buffer Request        19.08%       1.453ms        19.08%       1.453ms       1.453ms      58.464us         0.96%      58.464us      58.464us             1  
                                        aten::transpose         0.69%      52.203us         0.92%      69.763us       2.907us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::as_strided         0.23%      17.560us         0.23%      17.560us       0.732us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::empty_like         0.15%      11.581us         0.66%      50.023us       5.558us       0.000us         0.00%       0.000us       0.000us             9  
                                            aten::empty         0.84%      63.785us         0.84%      63.785us       3.037us       0.000us         0.00%       0.000us       0.000us            21  
                                       cudaLaunchKernel         1.14%      86.832us         1.14%      86.832us       7.236us       0.000us         0.00%       0.000us       0.000us            12  
                                  cudaStreamIsCapturing         0.03%       2.250us         0.03%       2.250us       0.750us       0.000us         0.00%       0.000us       0.000us             3  
                                   cudaFuncSetAttribute         0.04%       3.260us         0.04%       3.260us       1.087us       0.000us         0.00%       0.000us       0.000us             3  
                                  cudaDeviceSynchronize        72.53%       5.522ms        72.53%       5.522ms       5.522ms       0.000us         0.00%       0.000us       0.000us             1  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 7.614ms
Self CUDA time total: 6.072ms



======================================================================
PROFILE TRACE: torch_mem_eff | cuda_attn_L384_bfloat16
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                          torch_mem_eff         3.16%     248.365us        29.29%       2.300ms       2.300ms       0.000us         0.00%       6.163ms       6.163ms             1  
                                          torch_mem_eff         0.00%       0.000us         0.00%       0.000us       0.000us       6.114ms       100.14%       6.114ms       6.114ms             1  
                     aten::scaled_dot_product_attention         0.24%      19.232us         1.82%     142.774us      47.591us       0.000us         0.00%       5.452ms       1.817ms             3  
          aten::_scaled_dot_product_efficient_attention         0.25%      19.461us         1.57%     123.542us      41.181us       0.000us         0.00%       5.452ms       1.817ms             3  
                     aten::_efficient_attention_forward         0.37%      29.029us         1.03%      80.672us      26.891us       5.452ms        89.29%       5.452ms       1.817ms             3  
fmha_cutlassF_bf16_aligned_64x128_rf_sm80(PyTorchMem...         0.00%       0.000us         0.00%       0.000us       0.000us       5.452ms        89.29%       5.452ms       1.817ms             3  
                                       aten::contiguous         0.10%       7.931us        23.78%       1.867ms     207.435us       0.000us         0.00%     711.072us      79.008us             9  
                                            aten::clone         0.30%      23.532us        23.68%       1.859ms     206.554us       0.000us         0.00%     711.072us      79.008us             9  
                                            aten::copy_         0.81%      63.779us        22.73%       1.785ms     198.306us     653.792us        10.71%     711.072us      79.008us             9  
void at::native::elementwise_kernel<128, 4, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     653.792us        10.71%     653.792us      72.644us             9  
                                Activity Buffer Request        18.59%       1.459ms        18.59%       1.459ms       1.459ms      57.280us         0.94%      57.280us      57.280us             1  
                                        aten::transpose         0.62%      48.610us         0.83%      65.130us       2.714us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::as_strided         0.21%      16.520us         0.21%      16.520us       0.688us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::empty_like         0.16%      12.281us         0.65%      50.702us       5.634us       0.000us         0.00%       0.000us       0.000us             9  
                                            aten::empty         0.80%      62.502us         0.80%      62.502us       2.976us       0.000us         0.00%       0.000us       0.000us            21  
                                       cudaLaunchKernel         3.60%     282.729us         3.60%     282.729us      23.561us       0.000us         0.00%       0.000us       0.000us            12  
                                  cudaStreamIsCapturing         0.03%       2.471us         0.03%       2.471us       0.824us       0.000us         0.00%       0.000us       0.000us             3  
                                   cudaFuncSetAttribute         0.05%       4.120us         0.05%       4.120us       1.373us       0.000us         0.00%       0.000us       0.000us             3  
                                  cudaDeviceSynchronize        70.71%       5.551ms        70.71%       5.551ms       5.551ms       0.000us         0.00%       0.000us       0.000us             1  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 7.851ms
Self CUDA time total: 6.106ms



======================================================================
PROFILE TRACE: torch_mem_eff | cuda_attn_L448_bfloat16
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                          torch_mem_eff         3.01%     243.675us        28.03%       2.272ms       2.272ms       0.000us         0.00%       6.451ms       6.451ms             1  
                                          torch_mem_eff         0.00%       0.000us         0.00%       0.000us       0.000us       6.399ms       100.13%       6.399ms       6.399ms             1  
                     aten::scaled_dot_product_attention         0.23%      18.671us         1.77%     143.224us      47.741us       0.000us         0.00%       5.726ms       1.909ms             3  
          aten::_scaled_dot_product_efficient_attention         0.24%      19.652us         1.54%     124.553us      41.518us       0.000us         0.00%       5.726ms       1.909ms             3  
                     aten::_efficient_attention_forward         0.35%      28.317us         0.99%      80.642us      26.881us       5.726ms        89.60%       5.726ms       1.909ms             3  
fmha_cutlassF_bf16_aligned_64x128_rf_sm80(PyTorchMem...         0.00%       0.000us         0.00%       0.000us       0.000us       5.726ms        89.60%       5.726ms       1.909ms             3  
                                       aten::contiguous         0.10%       7.791us        22.70%       1.840ms     204.460us       0.000us         0.00%     725.025us      80.558us             9  
                                            aten::clone         0.29%      23.489us        22.61%       1.832ms     203.594us       0.000us         0.00%     725.025us      80.558us             9  
                                            aten::copy_         0.81%      65.293us        21.68%       1.757ms     195.223us     664.641us        10.40%     725.025us      80.558us             9  
void at::native::elementwise_kernel<128, 4, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     664.641us        10.40%     664.641us      73.849us             9  
                                Activity Buffer Request        17.77%       1.440ms        17.77%       1.440ms       1.440ms      60.384us         0.94%      60.384us      60.384us             1  
                                        aten::transpose         0.63%      51.151us         0.85%      69.251us       2.885us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::as_strided         0.22%      18.100us         0.22%      18.100us       0.754us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::empty_like         0.15%      11.960us         0.64%      51.852us       5.761us       0.000us         0.00%       0.000us       0.000us             9  
                                            aten::empty         0.79%      64.314us         0.79%      64.314us       3.063us       0.000us         0.00%       0.000us       0.000us            21  
                                       cudaLaunchKernel         3.36%     272.117us         3.36%     272.117us      22.676us       0.000us         0.00%       0.000us       0.000us            12  
                                  cudaStreamIsCapturing         0.03%       2.500us         0.03%       2.500us       0.833us       0.000us         0.00%       0.000us       0.000us             3  
                                   cudaFuncSetAttribute         0.06%       4.532us         0.06%       4.532us       1.511us       0.000us         0.00%       0.000us       0.000us             3  
                                  cudaDeviceSynchronize        71.97%       5.833ms        71.97%       5.833ms       5.833ms       0.000us         0.00%       0.000us       0.000us             1  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 8.105ms
Self CUDA time total: 6.391ms



======================================================================
PROFILE TRACE: torch_mem_eff | cuda_attn_L512_bfloat16
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                          torch_mem_eff         2.88%     242.135us        27.00%       2.269ms       2.269ms       0.000us         0.00%       6.759ms       6.759ms             1  
                                          torch_mem_eff         0.00%       0.000us         0.00%       0.000us       0.000us       6.705ms       100.12%       6.705ms       6.705ms             1  
                     aten::scaled_dot_product_attention         0.21%      17.851us         1.72%     144.884us      48.295us       0.000us         0.00%       6.024ms       2.008ms             3  
          aten::_scaled_dot_product_efficient_attention         0.23%      19.591us         1.51%     127.033us      42.344us       0.000us         0.00%       6.024ms       2.008ms             3  
                     aten::_efficient_attention_forward         0.34%      28.520us         0.97%      81.532us      27.177us       6.024ms        89.96%       6.024ms       2.008ms             3  
fmha_cutlassF_bf16_aligned_64x128_rf_sm80(PyTorchMem...         0.00%       0.000us         0.00%       0.000us       0.000us       6.024ms        89.96%       6.024ms       2.008ms             3  
                                       aten::contiguous         0.10%       8.099us        21.87%       1.838ms     204.242us       0.000us         0.00%     734.178us      81.575us             9  
                                            aten::clone         0.28%      23.122us        21.78%       1.830ms     203.342us       0.000us         0.00%     734.178us      81.575us             9  
                                            aten::copy_         0.74%      62.180us        20.86%       1.753ms     194.799us     672.322us        10.04%     734.178us      81.575us             9  
void at::native::elementwise_kernel<128, 4, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     672.322us        10.04%     672.322us      74.702us             9  
                                Activity Buffer Request        17.19%       1.445ms        17.19%       1.445ms       1.445ms      61.856us         0.92%      61.856us      61.856us             1  
                                        aten::transpose         0.62%      52.351us         0.83%      70.022us       2.918us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::as_strided         0.21%      17.671us         0.21%      17.671us       0.736us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::empty_like         0.15%      12.653us         0.64%      53.763us       5.974us       0.000us         0.00%       0.000us       0.000us             9  
                                            aten::empty         0.79%      66.761us         0.79%      66.761us       3.179us       0.000us         0.00%       0.000us       0.000us            21  
                                       cudaLaunchKernel         3.19%     267.907us         3.19%     267.907us      22.326us       0.000us         0.00%       0.000us       0.000us            12  
                                  cudaStreamIsCapturing         0.03%       2.430us         0.03%       2.430us       0.810us       0.000us         0.00%       0.000us       0.000us             3  
                                   cudaFuncSetAttribute         0.04%       3.350us         0.04%       3.350us       1.117us       0.000us         0.00%       0.000us       0.000us             3  
                                  cudaDeviceSynchronize        73.00%       6.134ms        73.00%       6.134ms       6.134ms       0.000us         0.00%       0.000us       0.000us             1  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 8.404ms
Self CUDA time total: 6.697ms
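
The traces above attribute self CUDA time to two kernels: the CUTLASS attention kernel and the elementwise copy kernel, which appears to be launched by the `.contiguous()` calls. The following is a minimal sketch, not part of the benchmark tooling, that recomputes each kernel's share for the L512 trace. The three constants are copied from that trace; the helper name `share` is hypothetical.

```python
# Self CUDA times copied from the cuda_attn_L512_bfloat16 trace, in milliseconds.
ATTENTION_KERNEL_MS = 6.024   # fmha_cutlassF_bf16_aligned_64x128_rf_sm80 row, 3 calls
LAYOUT_COPY_MS = 0.672322     # at::native::elementwise_kernel row (672.322us), 9 calls
SELF_CUDA_TOTAL_MS = 6.697    # "Self CUDA time total" line


def share(part_ms: float, total_ms: float) -> float:
    """Return part_ms as a percentage of total_ms (hypothetical helper)."""
    return 100.0 * part_ms / total_ms


kernel_share = share(ATTENTION_KERNEL_MS, SELF_CUDA_TOTAL_MS)
copy_share = share(LAYOUT_COPY_MS, SELF_CUDA_TOTAL_MS)

# Per-call averages, using the "# of Calls" column (3 and 9 respectively).
per_call_kernel_ms = ATTENTION_KERNEL_MS / 3
per_call_copy_us = LAYOUT_COPY_MS * 1000 / 9

print(f"attention kernel: {kernel_share:.2f}% of self CUDA time")
print(f"layout copies:    {copy_share:.2f}% of self CUDA time")
print(f"per call: {per_call_kernel_ms:.3f} ms kernel, {per_call_copy_us:.3f} us copy")
```

The recomputed shares agree with the profiler's Self CUDA % column to within rounding, so roughly a tenth of the GPU time in this trace goes to layout copies, not to attention itself.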


impl           workload                 p50(ms)  ok
torch_mem_eff  cuda_attn_L128_bfloat16     1.85  True
torch_mem_eff  cuda_attn_L256_bfloat16     1.95  True
torch_mem_eff  cuda_attn_L320_bfloat16     1.99  True
torch_mem_eff  cuda_attn_L384_bfloat16     2.07  True
torch_mem_eff  cuda_attn_L448_bfloat16     2.06  True
torch_mem_eff  cuda_attn_L512_bfloat16     2.25  True
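
To read the scaling trend off the table above, a short sketch can compare the p50 latencies at the shortest and longest workloads. It assumes that the `L<n>` token in each workload name is the sequence length; the p50 values are copied from the table, and the variable names are illustrative only.

```python
# p50 latencies (ms) copied from the summary table, keyed by the number in
# the workload name (assumed here to be the sequence length).
P50_MS = {128: 1.85, 256: 1.95, 320: 1.99, 384: 2.07, 448: 2.06, 512: 2.25}

shortest, longest = min(P50_MS), max(P50_MS)
length_ratio = longest / shortest                    # growth in sequence length
latency_ratio = P50_MS[longest] / P50_MS[shortest]   # growth in p50 latency

# Collect adjacent workload pairs where latency went down as length went up.
lengths = sorted(P50_MS)
non_monotonic = [
    (a, b) for a, b in zip(lengths, lengths[1:]) if P50_MS[b] < P50_MS[a]
]

print(f"length x{length_ratio:.1f} -> p50 latency x{latency_ratio:.2f}")
print(f"non-monotonic steps: {non_monotonic}")
```

Under that assumption, the sequence length grows fourfold across these workloads while p50 latency grows by about 22%. The only step where latency drops is L384 to L448, by 0.01 ms.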

Artifacts:

attention.jsonl