Memory-Efficient Attention Implementation

Memory-Efficient SDPA Benchmark

# /// script
# requires-python = ">=3.10"
# dependencies = [
#     "numpy",
#     "torch==2.8.0",
#     "kernels-benchmark-tools",
# ]
#
# [tool.uv.sources]
# kernels-benchmark-tools = { path = "../../../../../tools", editable = true }
# ///
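import torch
from kernels_benchmark_tools import KernelTypeEnum, run_benchmark


def torch_mem_eff(q, k, v):
    # The harness appears to pass (batch, seq, heads, head_dim) tensors, while SDPA
    # expects (batch, heads, seq, head_dim): transpose and copy into that layout.
    qt, kt, vt = (x.transpose(1, 2).contiguous() for x in (q, k, v))
    # Restrict SDPA to the memory-efficient (CUTLASS-based) backend only.
    with torch.nn.attention.sdpa_kernel(
        torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION
    ):
        o = torch.nn.functional.scaled_dot_product_attention(qt, kt, vt)
    # Return the output in the caller's (batch, seq, heads, head_dim) layout.
    return o.transpose(1, 2).contiguous()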


run_benchmark(
    kernel_type=KernelTypeEnum.ATTENTION,
    impl_name="torch_mem_eff",
    impl_tags={"family": "torch-sdpa", "backend": "EFFICIENT", "compile": "none"},
    impl_func=torch_mem_eff,
)
Running attention benchmark on cuda with 6 workloads.

======================================================================
PROFILE TRACE: torch_mem_eff | cuda_attn_L128_bfloat16
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                          torch_mem_eff         4.45%     324.566us        35.26%       2.573ms       2.573ms       0.000us         0.00%       5.439ms       5.439ms             1  
                                          torch_mem_eff         0.00%       0.000us         0.00%       0.000us       0.000us       5.406ms       100.38%       5.406ms       5.406ms             1  
                     aten::scaled_dot_product_attention         0.42%      30.389us         2.31%     168.211us      56.070us       0.000us         0.00%       4.771ms       1.590ms             3  
          aten::_scaled_dot_product_efficient_attention         0.30%      21.751us         1.89%     137.822us      45.941us       0.000us         0.00%       4.771ms       1.590ms             3  
                     aten::_efficient_attention_forward         0.46%      33.370us         1.30%      95.011us      31.670us       4.771ms        88.58%       4.771ms       1.590ms             3  
fmha_cutlassF_bf16_aligned_64x128_rf_sm80(PyTorchMem...         0.00%       0.000us         0.00%       0.000us       0.000us       4.771ms        88.58%       4.771ms       1.590ms             3  
                                       aten::contiguous         0.14%      10.493us        27.68%       2.020ms     224.395us       0.000us         0.00%     668.482us      74.276us             9  
                                            aten::clone         0.39%      28.130us        27.53%       2.009ms     223.229us       0.000us         0.00%     668.482us      74.276us             9  
                                            aten::copy_         1.01%      73.701us        26.23%       1.914ms     212.678us     614.946us        11.42%     668.482us      74.276us             9  
void at::native::elementwise_kernel<128, 4, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     614.946us        11.42%     614.946us      68.327us             9  
                                Activity Buffer Request        24.11%       1.759ms        24.11%       1.759ms       1.759ms      53.536us         0.99%      53.536us      53.536us             1  
                                        aten::transpose         0.83%      60.400us         1.12%      81.609us       3.400us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::as_strided         0.29%      21.209us         0.29%      21.209us       0.884us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::empty_like         0.20%      14.439us         0.92%      66.830us       7.426us       0.000us         0.00%       0.000us       0.000us             9  
                                            aten::empty         1.09%      79.191us         1.09%      79.191us       3.771us       0.000us         0.00%       0.000us       0.000us            21  
                                       cudaLaunchKernel         1.43%     104.332us         1.43%     104.332us       8.694us       0.000us         0.00%       0.000us       0.000us            12  
                                  cudaStreamIsCapturing         0.04%       3.220us         0.04%       3.220us       1.073us       0.000us         0.00%       0.000us       0.000us             3  
                                   cudaFuncSetAttribute         0.12%       8.781us         0.12%       8.781us       2.927us       0.000us         0.00%       0.000us       0.000us             3  
                                  cudaDeviceSynchronize        64.74%       4.724ms        64.74%       4.724ms       4.724ms       0.000us         0.00%       0.000us       0.000us             1  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 7.297ms
Self CUDA time total: 5.386ms
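Each trace is a `torch.profiler` summary table. The `torch_mem_eff` rows look like a `record_function` label added by the harness, and the `cudaDeviceSynchronize` row suggests a synchronize inside the profiled region. How `run_benchmark` actually collects these traces isn't shown here, so the following is only a minimal sketch for producing a comparable table yourself. `profile_table` and the label `"my_attn"` are hypothetical names, and SDPA is called without forcing a backend so the sketch also runs on a CPU-only machine.

```python
import torch
import torch.nn.functional as F
from torch.profiler import ProfilerActivity, profile, record_function


def profile_table(fn, *args, iters=3, label="my_attn", row_limit=25):
    """Profile `iters` calls of fn(*args) and return the summary table as a string."""
    use_cuda = torch.cuda.is_available()
    activities = [ProfilerActivity.CPU]
    if use_cuda:
        activities.append(ProfilerActivity.CUDA)
    with profile(activities=activities) as prof:
        # The label becomes its own row, like the "torch_mem_eff" rows above.
        with record_function(label):
            for _ in range(iters):
                fn(*args)
            if use_cuda:
                # Wait for queued kernels so their time lands inside the profile.
                torch.cuda.synchronize()
    return prof.key_averages().table(sort_by="cpu_time_total", row_limit=row_limit)


if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # (batch, heads, seq, head_dim), the layout SDPA expects.
    q, k, v = (torch.randn(2, 8, 128, 64, device=device) for _ in range(3))
    print(profile_table(F.scaled_dot_product_attention, q, k, v))
```

Three iterations mirror the `# of Calls` value of 3 in the traces above; on a GPU machine the table gains the Self CUDA columns.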



======================================================================
PROFILE TRACE: torch_mem_eff | cuda_attn_L256_bfloat16
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                          torch_mem_eff         3.09%     234.954us        31.65%       2.404ms       2.404ms       0.000us         0.00%       5.782ms       5.782ms             1  
                                          torch_mem_eff         0.00%       0.000us         0.00%       0.000us       0.000us       5.735ms       100.14%       5.735ms       5.735ms             1  
                     aten::scaled_dot_product_attention         0.22%      16.961us         1.81%     137.382us      45.794us       0.000us         0.00%       5.091ms       1.697ms             3  
          aten::_scaled_dot_product_efficient_attention         0.25%      19.139us         1.59%     120.421us      40.140us       0.000us         0.00%       5.091ms       1.697ms             3  
                     aten::_efficient_attention_forward         0.36%      27.009us         1.04%      78.740us      26.247us       5.091ms        88.89%       5.091ms       1.697ms             3  
fmha_cutlassF_bf16_aligned_64x128_rf_sm80(PyTorchMem...         0.00%       0.000us         0.00%       0.000us       0.000us       5.091ms        88.89%       5.091ms       1.697ms             3  
                                       aten::contiguous         0.11%       8.479us        26.21%       1.991ms     221.170us       0.000us         0.00%     690.720us      76.747us             9  
                                            aten::clone         0.29%      22.002us        26.10%       1.982ms     220.228us       0.000us         0.00%     690.720us      76.747us             9  
                                            aten::copy_         0.83%      62.671us        25.16%       1.911ms     212.305us     636.032us        11.11%     690.720us      76.747us             9  
void at::native::elementwise_kernel<128, 4, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     636.032us        11.11%     636.032us      70.670us             9  
                                Activity Buffer Request        23.48%       1.783ms        23.48%       1.783ms       1.783ms      54.688us         0.95%      54.688us      54.688us             1  
                                        aten::transpose         0.64%      48.410us         0.84%      63.823us       2.659us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::as_strided         0.20%      15.413us         0.20%      15.413us       0.642us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::empty_like         0.15%      11.729us         0.65%      49.301us       5.478us       0.000us         0.00%       0.000us       0.000us             9  
                                            aten::empty         0.82%      62.552us         0.82%      62.552us       2.979us       0.000us         0.00%       0.000us       0.000us            21  
                                       cudaLaunchKernel         1.14%      86.431us         1.14%      86.431us       7.203us       0.000us         0.00%       0.000us       0.000us            12  
                                  cudaStreamIsCapturing         0.03%       2.280us         0.03%       2.280us       0.760us       0.000us         0.00%       0.000us       0.000us             3  
                                   cudaFuncSetAttribute         0.04%       2.990us         0.04%       2.990us       0.997us       0.000us         0.00%       0.000us       0.000us             3  
                                  cudaDeviceSynchronize        68.35%       5.191ms        68.35%       5.191ms       5.191ms       0.000us         0.00%       0.000us       0.000us             1  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 7.595ms
Self CUDA time total: 5.727ms



======================================================================
PROFILE TRACE: torch_mem_eff | cuda_attn_L320_bfloat16
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                          torch_mem_eff         3.06%     239.384us        30.93%       2.420ms       2.420ms       0.000us         0.00%       5.994ms       5.994ms             1  
                                          torch_mem_eff         0.00%       0.000us         0.00%       0.000us       0.000us       5.947ms       100.14%       5.947ms       5.947ms             1  
                     aten::scaled_dot_product_attention         0.22%      17.549us         1.74%     135.892us      45.297us       0.000us         0.00%       5.295ms       1.765ms             3  
          aten::_scaled_dot_product_efficient_attention         0.23%      18.333us         1.51%     118.343us      39.448us       0.000us         0.00%       5.295ms       1.765ms             3  
                     aten::_efficient_attention_forward         0.35%      27.055us         1.01%      79.012us      26.337us       5.295ms        89.16%       5.295ms       1.765ms             3  
fmha_cutlassF_bf16_aligned_64x128_rf_sm80(PyTorchMem...         0.00%       0.000us         0.00%       0.000us       0.000us       5.295ms        89.16%       5.295ms       1.765ms             3  
                                       aten::contiguous         0.10%       7.948us        25.59%       2.002ms     222.464us       0.000us         0.00%     699.457us      77.717us             9  
                                            aten::clone         0.26%      20.152us        25.49%       1.994ms     221.581us       0.000us         0.00%     699.457us      77.717us             9  
                                            aten::copy_         0.79%      62.172us        24.60%       1.924ms     213.808us     643.713us        10.84%     699.457us      77.717us             9  
void at::native::elementwise_kernel<128, 4, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     643.713us        10.84%     643.713us      71.524us             9  
                                Activity Buffer Request        22.96%       1.796ms        22.96%       1.796ms       1.796ms      55.744us         0.94%      55.744us      55.744us             1  
                                        aten::transpose         0.61%      48.091us         0.81%      63.198us       2.633us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::as_strided         0.19%      15.107us         0.19%      15.107us       0.629us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::empty_like         0.14%      11.152us         0.64%      49.811us       5.535us       0.000us         0.00%       0.000us       0.000us             9  
                                            aten::empty         0.80%      62.567us         0.80%      62.567us       2.979us       0.000us         0.00%       0.000us       0.000us            21  
                                       cudaLaunchKernel         1.12%      87.709us         1.12%      87.709us       7.309us       0.000us         0.00%       0.000us       0.000us            12  
                                  cudaStreamIsCapturing         0.03%       2.429us         0.03%       2.429us       0.810us       0.000us         0.00%       0.000us       0.000us             3  
                                   cudaFuncSetAttribute         0.05%       3.800us         0.05%       3.800us       1.267us       0.000us         0.00%       0.000us       0.000us             3  
                                  cudaDeviceSynchronize        69.07%       5.404ms        69.07%       5.404ms       5.404ms       0.000us         0.00%       0.000us       0.000us             1  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 7.823ms
Self CUDA time total: 5.939ms



======================================================================
PROFILE TRACE: torch_mem_eff | cuda_attn_L384_bfloat16
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                          torch_mem_eff         3.00%     242.264us        30.89%       2.499ms       2.499ms       0.000us         0.00%       6.191ms       6.191ms             1  
                                          torch_mem_eff         0.00%       0.000us         0.00%       0.000us       0.000us       6.141ms       100.14%       6.141ms       6.141ms             1  
                     aten::scaled_dot_product_attention         0.23%      18.320us         1.69%     136.812us      45.604us       0.000us         0.00%       5.471ms       1.824ms             3  
          aten::_scaled_dot_product_efficient_attention         0.23%      18.630us         1.46%     118.492us      39.497us       0.000us         0.00%       5.471ms       1.824ms             3  
                     aten::_efficient_attention_forward         0.33%      26.674us         0.96%      77.952us      25.984us       5.471ms        89.22%       5.471ms       1.824ms             3  
fmha_cutlassF_bf16_aligned_64x128_rf_sm80(PyTorchMem...         0.00%       0.000us         0.00%       0.000us       0.000us       5.471ms        89.22%       5.471ms       1.824ms             3  
                                       aten::contiguous         0.10%       8.440us        25.67%       2.076ms     230.653us       0.000us         0.00%     719.363us      79.929us             9  
                                            aten::clone         0.28%      22.639us        25.56%       2.067ms     229.716us       0.000us         0.00%     719.363us      79.929us             9  
                                            aten::copy_         0.78%      63.183us        24.67%       1.995ms     221.702us     660.931us        10.78%     719.363us      79.929us             9  
void at::native::elementwise_kernel<128, 4, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     660.931us        10.78%     660.931us      73.437us             9  
                                Activity Buffer Request        21.08%       1.705ms        21.08%       1.705ms       1.705ms      58.432us         0.95%      58.432us      58.432us             1  
                                        aten::transpose         0.61%      49.449us         0.81%      65.670us       2.736us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::as_strided         0.20%      16.221us         0.20%      16.221us       0.676us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::empty_like         0.15%      11.742us         0.61%      49.481us       5.498us       0.000us         0.00%       0.000us       0.000us             9  
                                            aten::empty         0.77%      62.526us         0.77%      62.526us       2.977us       0.000us         0.00%       0.000us       0.000us            21  
                                       cudaLaunchKernel         3.07%     248.624us         3.07%     248.624us      20.719us       0.000us         0.00%       0.000us       0.000us            12  
                                  cudaStreamIsCapturing         0.03%       2.250us         0.03%       2.250us       0.750us       0.000us         0.00%       0.000us       0.000us             3  
                                   cudaFuncSetAttribute         0.04%       3.020us         0.04%       3.020us       1.007us       0.000us         0.00%       0.000us       0.000us             3  
                                  cudaDeviceSynchronize        69.11%       5.590ms        69.11%       5.590ms       5.590ms       0.000us         0.00%       0.000us       0.000us             1  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 8.088ms
Self CUDA time total: 6.132ms



======================================================================
PROFILE TRACE: torch_mem_eff | cuda_attn_L448_bfloat16
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                          torch_mem_eff         2.96%     243.644us        31.20%       2.571ms       2.571ms       0.000us         0.00%       6.270ms       6.270ms             1  
                                          torch_mem_eff         0.00%       0.000us         0.00%       0.000us       0.000us       6.220ms       100.13%       6.220ms       6.220ms             1  
                     aten::scaled_dot_product_attention         0.22%      18.340us         1.66%     136.411us      45.470us       0.000us         0.00%       5.544ms       1.848ms             3  
          aten::_scaled_dot_product_efficient_attention         0.23%      18.620us         1.43%     118.071us      39.357us       0.000us         0.00%       5.544ms       1.848ms             3  
                     aten::_efficient_attention_forward         0.33%      26.920us         0.94%      77.841us      25.947us       5.544ms        89.24%       5.544ms       1.848ms             3  
fmha_cutlassF_bf16_aligned_64x128_rf_sm80(PyTorchMem...         0.00%       0.000us         0.00%       0.000us       0.000us       5.544ms        89.24%       5.544ms       1.848ms             3  
                                       aten::contiguous         0.10%       8.441us        26.08%       2.149ms     238.754us       0.000us         0.00%     726.626us      80.736us             9  
                                            aten::clone         0.27%      22.559us        25.98%       2.140ms     237.816us       0.000us         0.00%     726.626us      80.736us             9  
                                            aten::copy_         0.77%      63.181us        25.09%       2.068ms     229.736us     668.130us        10.76%     726.626us      80.736us             9  
void at::native::elementwise_kernel<128, 4, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     668.130us        10.76%     668.130us      74.237us             9  
                                Activity Buffer Request        21.61%       1.780ms        21.61%       1.780ms       1.780ms      58.496us         0.94%      58.496us      58.496us             1  
                                        aten::transpose         0.58%      47.889us         0.77%      63.801us       2.658us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::as_strided         0.19%      15.912us         0.19%      15.912us       0.663us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::empty_like         0.14%      11.871us         0.61%      50.162us       5.574us       0.000us         0.00%       0.000us       0.000us             9  
                                            aten::empty         0.75%      62.051us         0.75%      62.051us       2.955us       0.000us         0.00%       0.000us       0.000us            21  
                                       cudaLaunchKernel         2.98%     245.563us         2.98%     245.563us      20.464us       0.000us         0.00%       0.000us       0.000us            12  
                                  cudaStreamIsCapturing         0.03%       2.280us         0.03%       2.280us       0.760us       0.000us         0.00%       0.000us       0.000us             3  
                                   cudaFuncSetAttribute         0.04%       3.301us         0.04%       3.301us       1.100us       0.000us         0.00%       0.000us       0.000us             3  
                                  cudaDeviceSynchronize        68.80%       5.669ms        68.80%       5.669ms       5.669ms       0.000us         0.00%       0.000us       0.000us             1  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 8.240ms
Self CUDA time total: 6.212ms



======================================================================
PROFILE TRACE: torch_mem_eff | cuda_attn_L512_bfloat16
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                          torch_mem_eff         2.78%     238.352us        29.12%       2.495ms       2.495ms       0.000us         0.00%       6.680ms       6.680ms             1  
                                          torch_mem_eff         0.00%       0.000us         0.00%       0.000us       0.000us       6.628ms       100.13%       6.628ms       6.628ms             1  
                     aten::scaled_dot_product_attention         0.31%      26.242us         1.71%     146.743us      48.914us       0.000us         0.00%       5.945ms       1.982ms             3  
          aten::_scaled_dot_product_efficient_attention         0.23%      19.839us         1.41%     120.501us      40.167us       0.000us         0.00%       5.945ms       1.982ms             3  
                     aten::_efficient_attention_forward         0.31%      26.859us         0.92%      78.900us      26.300us       5.945ms        89.80%       5.945ms       1.982ms             3  
fmha_cutlassF_bf16_aligned_64x128_rf_sm80(PyTorchMem...         0.00%       0.000us         0.00%       0.000us       0.000us       5.945ms        89.80%       5.945ms       1.982ms             3  
                                       aten::contiguous         0.09%       7.528us        24.13%       2.068ms     229.726us       0.000us         0.00%     735.685us      81.743us             9  
                                            aten::clone         0.24%      20.962us        24.04%       2.060ms     228.889us       0.000us         0.00%     735.685us      81.743us             9  
                                            aten::copy_         0.75%      64.071us        23.20%       1.988ms     220.897us     675.044us        10.20%     735.685us      81.743us             9  
void at::native::elementwise_kernel<128, 4, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     675.044us        10.20%     675.044us      75.005us             9  
                                Activity Buffer Request        19.86%       1.702ms        19.86%       1.702ms       1.702ms      60.641us         0.92%      60.641us      60.641us             1  
                                        aten::transpose         0.56%      47.940us         0.74%      63.783us       2.658us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::as_strided         0.18%      15.843us         0.18%      15.843us       0.660us       0.000us         0.00%       0.000us       0.000us            24  
                                       aten::empty_like         0.13%      11.513us         0.59%      50.972us       5.664us       0.000us         0.00%       0.000us       0.000us             9  
                                            aten::empty         0.75%      64.430us         0.75%      64.430us       3.068us       0.000us         0.00%       0.000us       0.000us            21  
                                       cudaLaunchKernel         2.85%     243.883us         2.85%     243.883us      20.324us       0.000us         0.00%       0.000us       0.000us            12  
                                  cudaStreamIsCapturing         0.03%       2.530us         0.03%       2.530us       0.843us       0.000us         0.00%       0.000us       0.000us             3  
                                   cudaFuncSetAttribute         0.04%       3.050us         0.04%       3.050us       1.017us       0.000us         0.00%       0.000us       0.000us             3  
                                  cudaDeviceSynchronize        70.88%       6.073ms        70.88%       6.073ms       6.073ms       0.000us         0.00%       0.000us       0.000us             1  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 8.568ms
Self CUDA time total: 6.620ms
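The six traces tell the same story. Roughly 89% of self CUDA time is the `fmha_cutlassF_bf16_aligned_64x128_rf_sm80` attention kernel, and 10.2% to 11.4% is the `elementwise_kernel` copies issued through `aten::contiguous`. The large `cudaDeviceSynchronize` row is the CPU waiting for that GPU work, not extra work. Each trace covers three SDPA calls but only nine `aten::contiguous` calls, consistent with one copy each of q, k and v per call; the final `.contiguous()` on the output apparently finds the tensor already contiguous and launches no copy. The short calculation below reproduces the copy share from the Self CUDA figures copied out of the tables; the `TRACES` dictionary and `copy_share` helper are illustrative only.

```python
# Self CUDA time in microseconds, copied from the six traces above:
# (elementwise copy kernels behind .contiguous(), fmha_cutlassF attention kernel)
TRACES = {
    128: (614.946, 4771.0),
    256: (636.032, 5091.0),
    320: (643.713, 5295.0),
    384: (660.931, 5471.0),
    448: (668.130, 5544.0),
    512: (675.044, 5945.0),
}
CALLS_PER_TRACE = 3  # the SDPA rows report 3 calls per trace


def copy_share(copy_us: float, attn_us: float) -> float:
    """Fraction of (copy + attention) kernel time spent on layout copies."""
    return copy_us / (copy_us + attn_us)


for seq_len, (copy_us, attn_us) in TRACES.items():
    share = copy_share(copy_us, attn_us)
    per_call = copy_us / CALLS_PER_TRACE
    print(f"L{seq_len}: copies {share:.1%} of kernel time, ~{per_call:.0f} us per call")
```

For L128 this gives 11.4%, matching the table's `Self CUDA %` of 11.42% for the copy kernel.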


impl           workload                 p50(ms)  ok
torch_mem_eff  cuda_attn_L128_bfloat16     1.83  True
torch_mem_eff  cuda_attn_L256_bfloat16     1.94  True
torch_mem_eff  cuda_attn_L320_bfloat16     1.96  True
torch_mem_eff  cuda_attn_L384_bfloat16     2.03  True
torch_mem_eff  cuda_attn_L448_bfloat16     2.02  True
torch_mem_eff  cuda_attn_L512_bfloat16     2.23  True
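Relative to L128, the p50 latency rises only about 22% by L512, even though the sequence length quadruples and a naive attention score matrix would grow 16-fold. The sketch below recomputes those ratios from the table. Batch size, head count and head dimension are not shown in this output, so the L^2 column is only a reference point, not a prediction for these workloads. L448 also reads slightly below L384 (2.02 vs 2.03 ms), so small differences between neighboring rows should not be over-interpreted. `P50_MS` and `scaling_report` are illustrative names.

```python
# p50 latencies in milliseconds, copied from the summary table above.
P50_MS = {128: 1.83, 256: 1.94, 320: 1.96, 384: 2.03, 448: 2.02, 512: 2.23}


def scaling_report(p50_ms, base_len=128):
    """Return {seq_len: (measured ratio, L^2 ratio)} relative to base_len."""
    base = p50_ms[base_len]
    report = {}
    for seq_len, ms in sorted(p50_ms.items()):
        measured = ms / base
        quadratic = (seq_len / base_len) ** 2  # growth if cost were purely O(L^2)
        report[seq_len] = (measured, quadratic)
    return report


for seq_len, (measured, quadratic) in scaling_report(P50_MS).items():
    print(f"L{seq_len}: {measured:.2f}x measured vs {quadratic:.2f}x for L^2 scaling")
```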

Artifacts:

attention.jsonl
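The profiles time the kernels but don't show what they compute. As a sanity check of the wrapper's layout handling, here is a naive reference that materializes the full score matrix, compared against SDPA on the same inputs. This is a sketch under assumptions: the (batch, seq, heads, head_dim) input layout is inferred from the `transpose(1, 2)` calls in `torch_mem_eff`, `reference_attention` and `sdpa_bshd` are hypothetical helpers, and the backend is left to SDPA's default so the check runs on CPU in float32.

```python
import math

import torch
import torch.nn.functional as F


def reference_attention(q, k, v):
    """Naive softmax(Q K^T / sqrt(d)) V on (batch, seq, heads, head_dim) tensors.

    It builds the full seq x seq score matrix per head, which is exactly what
    memory-efficient kernels avoid, so it is only useful as a correctness oracle.
    """
    qt, kt, vt = (x.transpose(1, 2) for x in (q, k, v))  # -> (B, H, S, D)
    scores = qt @ kt.transpose(-2, -1) / math.sqrt(q.shape[-1])
    probs = torch.softmax(scores.float(), dim=-1).to(q.dtype)  # softmax in fp32
    return (probs @ vt).transpose(1, 2)  # -> (B, S, H, D)


def sdpa_bshd(q, k, v):
    # Same layout handling as torch_mem_eff, without forcing a backend.
    qt, kt, vt = (x.transpose(1, 2).contiguous() for x in (q, k, v))
    return F.scaled_dot_product_attention(qt, kt, vt).transpose(1, 2).contiguous()


if __name__ == "__main__":
    torch.manual_seed(0)
    q, k, v = (torch.randn(2, 64, 4, 32) for _ in range(3))
    diff = (sdpa_bshd(q, k, v) - reference_attention(q, k, v)).abs().max().item()
    print(f"max abs difference: {diff:.2e}")
```

On a CUDA machine, `torch_mem_eff` could be passed in place of `sdpa_bshd` with bfloat16 inputs; the comparison tolerance would then need loosening, and any particular value (for example around 1e-2) is a guess to tune rather than a known bound.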