PyTorch Native - Deformable DETR

GPU Info

Cell: nv | 0.23s
import subprocess
print(subprocess.run(["nvidia-smi"], capture_output=True, text=True).stdout)
Fri Oct 31 20:13:34 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 570.195.03             Driver Version: 570.195.03     CUDA Version: 12.8     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA L40S                    On  |   00000000:4D:00.0 Off |                    0 |
| N/A   43C    P0             83W /  350W |       0MiB /  46068MiB |     60%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+

+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI              PID   Type   Process name                        GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|  No running processes found                                                             |
+-----------------------------------------------------------------------------------------+

Deformable DETR Multi-Scale Deformable Attention Benchmark (PyTorch Native)
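The kernel under test is the sampling core of multi-scale deformable attention from Deformable DETR: for each query q and attention head h, the output is an attention-weighted sum of bilinearly sampled values taken from L feature levels with P sampling points per level, roughly

    output[q, h] = sum over l, p of  attention_weights[q, h, l, p] * bilinear_sample(value_l[h], sampling_locations[q, h, l, p])

with out-of-bounds samples treated as zero. The eager reference below computes exactly this, one level at a time.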

Cell: benchmark | 5.33s
# /// script
# requires-python = ">=3.10"
# dependencies = [
#     "numpy",
#     "torch==2.8.0",
#     "kernels-benchmark-tools",
# ]
#
# [tool.uv.sources]
# kernels-benchmark-tools = { path = "../../../../../tools", editable = true }
# ///
import torch
from kernels_benchmark_tools import KernelTypeEnum, run_benchmark


def torch_deformable_detr(
    value, spatial_shapes, level_start_index, sampling_locations, attention_weights, im2col_step=64
):
    """
    PyTorch native reference implementation of multi-scale deformable attention.
    Uses vectorized bilinear interpolation for reasonable performance.
    """
    bs, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape
    _, _, _, channels = value.shape

    output = torch.zeros(bs, num_queries, num_heads, channels, device=value.device, dtype=value.dtype)

    # Split value tensor by levels
    value_list = value.split([int(h * w) for h, w in spatial_shapes.tolist()], dim=1)

    # Iterate through each level (can't avoid this loop easily)
    for level_idx in range(num_levels):
        h, w = spatial_shapes[level_idx].tolist()
        value_level = value_list[level_idx]  # (bs, h*w, num_heads, channels)

        # Reshape to spatial grid: (bs, num_heads, channels, h, w)
        value_spatial = value_level.reshape(bs, h, w, num_heads, channels).permute(0, 3, 4, 1, 2)

        # Get sampling locations and weights for this level
        # loc: (bs, num_queries, num_heads, num_points, 2)
        loc = sampling_locations[:, :, :, level_idx, :, :]
        # weight: (bs, num_queries, num_heads, num_points)
        weight = attention_weights[:, :, :, level_idx, :]

        # Convert normalized coordinates to pixel coordinates
        # loc[..., 0] is x (width), loc[..., 1] is y (height)
        x = loc[..., 0] * w - 0.5  # (bs, num_queries, num_heads, num_points)
        y = loc[..., 1] * h - 0.5

        # Get integer coordinates for bilinear interpolation
        x0 = torch.floor(x).long()
        y0 = torch.floor(y).long()
        x1 = x0 + 1
        y1 = y0 + 1

        # Compute fractional interpolation weights BEFORE clamping, so they are
        # measured from the true (unclamped) neighbor coordinates
        lw = x - x0.float()  # fractional offset in x (weight toward x1)
        lh = y - y0.float()  # fractional offset in y (weight toward y1)
        hw = 1 - lw          # weight toward x0
        hh = 1 - lh          # weight toward y0

        # Create mask for valid sample locations
        valid = (y > -1) & (x > -1) & (y < h) & (x < w)

        # Create masks for each corner being in bounds
        mask_tl = ((y0 >= 0) & (x0 >= 0)).unsqueeze(-1).float()
        mask_tr = ((y0 >= 0) & (x1 <= w - 1)).unsqueeze(-1).float()
        mask_bl = ((y1 <= h - 1) & (x0 >= 0)).unsqueeze(-1).float()
        mask_br = ((y1 <= h - 1) & (x1 <= w - 1)).unsqueeze(-1).float()

        # Clamp coordinates for safe indexing
        x0_clamped = torch.clamp(x0, 0, w - 1)
        x1_clamped = torch.clamp(x1, 0, w - 1)
        y0_clamped = torch.clamp(y0, 0, h - 1)
        y1_clamped = torch.clamp(y1, 0, h - 1)

        # Bilinear interpolation weights for all 4 corners
        w_tl = (hh * hw).unsqueeze(-1)  # top-left: (bs, num_queries, num_heads, num_points, 1)
        w_tr = (hh * lw).unsqueeze(-1)  # top-right
        w_bl = (lh * hw).unsqueeze(-1)  # bottom-left
        w_br = (lh * lw).unsqueeze(-1)  # bottom-right

        # Gather values from the 4 corners using advanced indexing
        batch_idx = torch.arange(bs, device=value.device).view(bs, 1, 1, 1).expand(bs, num_queries, num_heads, num_points)
        head_idx = torch.arange(num_heads, device=value.device).view(1, 1, num_heads, 1).expand(bs, num_queries, num_heads, num_points)

        # Gather corner values with clamped indices, then apply corner masks
        v_tl = value_spatial[batch_idx, head_idx, :, y0_clamped, x0_clamped] * mask_tl
        v_tr = value_spatial[batch_idx, head_idx, :, y0_clamped, x1_clamped] * mask_tr
        v_bl = value_spatial[batch_idx, head_idx, :, y1_clamped, x0_clamped] * mask_bl
        v_br = value_spatial[batch_idx, head_idx, :, y1_clamped, x1_clamped] * mask_br

        # Bilinear interpolation
        sampled = w_tl * v_tl + w_tr * v_tr + w_bl * v_bl + w_br * v_br

        # Apply valid mask (only accumulate if entire sample location is valid)
        sampled = sampled * valid.unsqueeze(-1).float()

        # Apply attention weights and sum over points
        # weight: (bs, num_queries, num_heads, num_points)
        # Expand weight: (bs, num_queries, num_heads, num_points, 1)
        weighted_sampled = sampled * weight.unsqueeze(-1)

        # Sum over points: (bs, num_queries, num_heads, channels)
        output += weighted_sampled.sum(dim=3)

    # Flatten last two dimensions to match kernel output
    return output.reshape(bs, num_queries, num_heads * channels)


run_benchmark(
    kernel_type=KernelTypeEnum.DEFORMABLE_DETR,
    impl_name="torch_eager",
    impl_tags={"family": "pytorch", "backend": "eager"},
    impl_func=torch_deformable_detr,
    dtype="float32",
)
Running deformable_detr benchmark on cuda with 4 workloads.

======================================================================
PROFILE TRACE: torch_eager | cuda_B1_Q100_H8_E256_L4_P4
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                            torch_eager         0.00%       0.000us         0.00%       0.000us       0.000us      20.095ms      1353.99%      20.095ms      20.095ms             1  
                                            torch_eager        21.57%       4.703ms        99.97%      21.796ms      21.796ms       0.000us         0.00%       1.485ms       1.485ms             1  
                                            aten::index         4.62%       1.006ms        16.78%       3.660ms      76.241us     237.342us        15.99%     371.712us       7.744us            48  
                                            aten::copy_         4.87%       1.061ms        11.32%       2.469ms      11.275us     365.385us        24.62%     365.385us       1.668us           219  
                                              aten::mul         5.80%       1.265ms         9.92%       2.163ms      11.267us     294.264us        19.83%     294.264us       1.533us           192  
void at::native::index_elementwise_kernel<128, 4, at...         0.00%       0.000us         0.00%       0.000us       0.000us     237.342us        15.99%     237.342us       4.945us            48  
                                               aten::to         0.67%     145.268us        11.20%       2.441ms      14.275us       0.000us         0.00%     231.015us       1.351us           171  
                                         aten::_to_copy         2.25%     489.538us        10.53%       2.296ms      18.665us       0.000us         0.00%     231.015us       1.878us           123  
void at::native::elementwise_kernel<128, 2, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     202.558us        13.65%     202.558us       1.688us           120  
void at::native::unrolled_elementwise_kernel<at::nat...         0.00%       0.000us         0.00%       0.000us       0.000us     167.074us        11.26%     167.074us       1.989us            84  
                                       aten::contiguous         0.40%      86.639us         8.70%       1.898ms      19.769us       0.000us         0.00%     134.370us       1.400us            96  
                                            aten::clone         0.85%     185.683us         8.31%       1.811ms      18.866us       0.000us         0.00%     134.370us       1.400us            96  
void at::native::elementwise_kernel<128, 2, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     134.370us         9.05%     134.370us       1.400us            96  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us     115.390us         7.77%     115.390us       1.202us            96  
                                          aten::__and__         0.63%     137.184us         4.49%     979.904us      11.666us       0.000us         0.00%     100.670us       1.198us            84  
                                      aten::bitwise_and         2.39%     521.552us         3.87%     842.720us      10.032us     100.670us         6.78%     100.670us       1.198us            84  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us     100.670us         6.78%     100.670us       1.198us            84  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us      85.858us         5.78%      85.858us       1.192us            72  
                                              aten::sub         2.24%     488.685us         3.68%     801.476us      11.132us      78.884us         5.32%      78.884us       1.096us            72  
                                              aten::add         1.55%     338.597us         2.59%     564.753us       9.413us      74.082us         4.99%      74.082us       1.235us            60  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 21.803ms
Self CUDA time total: 1.484ms



======================================================================
PROFILE TRACE: torch_eager | cuda_B1_Q300_H8_E256_L4_P4
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                            torch_eager         0.00%       0.000us         0.00%       0.000us       0.000us      18.852ms      1182.31%      18.852ms      18.852ms             1  
                                            torch_eager        20.99%       4.304ms        99.97%      20.495ms      20.495ms       0.000us         0.00%       1.595ms       1.595ms             1  
                                            aten::index         4.61%     945.020us        16.80%       3.444ms      71.750us     251.167us        15.75%     382.850us       7.976us            48  
                                            aten::copy_         5.04%       1.033ms        11.78%       2.414ms      11.023us     364.991us        22.89%     364.991us       1.667us           219  
                                              aten::mul         5.94%       1.218ms        10.22%       2.095ms      10.911us     359.138us        22.52%     359.138us       1.871us           192  
void at::native::elementwise_kernel<128, 2, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     267.618us        16.78%     267.618us       2.230us           120  
void at::native::index_elementwise_kernel<128, 4, at...         0.00%       0.000us         0.00%       0.000us       0.000us     251.167us        15.75%     251.167us       5.233us            48  
                                               aten::to         0.59%     120.975us        11.17%       2.290ms      13.390us       0.000us         0.00%     233.308us       1.364us           171  
                                         aten::_to_copy         2.01%     411.895us        10.58%       2.169ms      17.632us       0.000us         0.00%     233.308us       1.897us           123  
void at::native::unrolled_elementwise_kernel<at::nat...         0.00%       0.000us         0.00%       0.000us       0.000us     168.797us        10.59%     168.797us       2.009us            84  
                                       aten::contiguous         0.41%      84.261us         8.87%       1.818ms      18.936us       0.000us         0.00%     131.683us       1.372us            96  
                                            aten::clone         0.84%     172.318us         8.46%       1.734ms      18.058us       0.000us         0.00%     131.683us       1.372us            96  
void at::native::elementwise_kernel<128, 2, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     131.683us         8.26%     131.683us       1.372us            96  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us     118.123us         7.41%     118.123us       1.230us            96  
                                          aten::__and__         0.40%      81.276us         4.41%     903.196us      10.752us       0.000us         0.00%     104.833us       1.248us            84  
                                      aten::bitwise_and         2.46%     504.088us         4.01%     821.920us       9.785us     104.833us         6.57%     104.833us       1.248us            84  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us     104.833us         6.57%     104.833us       1.248us            84  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us     104.190us         6.53%     104.190us       1.447us            72  
                                              aten::add         1.62%     331.582us         2.72%     557.857us       9.298us      91.491us         5.74%      91.491us       1.525us            60  
                                              aten::sub         2.17%     445.533us         3.70%     758.959us      10.541us      80.509us         5.05%      80.509us       1.118us            72  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 20.501ms
Self CUDA time total: 1.595ms



======================================================================
PROFILE TRACE: torch_eager | cuda_B2_Q100_H8_E256_L4_P4
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                            torch_eager         0.00%       0.000us         0.00%       0.000us       0.000us      18.792ms      1222.95%      18.792ms      18.792ms             1  
                                            torch_eager        21.02%       4.299ms        99.97%      20.449ms      20.449ms       0.000us         0.00%       1.538ms       1.538ms             1  
                                            aten::index         4.62%     944.347us        16.78%       3.432ms      71.497us     243.904us        15.87%     378.785us       7.891us            48  
                                            aten::copy_         5.14%       1.051ms        11.72%       2.396ms      10.942us     368.961us        24.01%     368.961us       1.685us           219  
                                              aten::mul         5.96%       1.219ms        10.23%       2.092ms      10.898us     325.334us        21.17%     325.334us       1.694us           192  
void at::native::index_elementwise_kernel<128, 4, at...         0.00%       0.000us         0.00%       0.000us       0.000us     243.904us        15.87%     243.904us       5.081us            48  
void at::native::elementwise_kernel<128, 2, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     234.457us        15.26%     234.457us       1.954us           120  
                                               aten::to         0.61%     125.558us        11.02%       2.255ms      13.184us       0.000us         0.00%     234.080us       1.369us           171  
                                         aten::_to_copy         1.92%     392.900us        10.41%       2.129ms      17.309us       0.000us         0.00%     234.080us       1.903us           123  
void at::native::unrolled_elementwise_kernel<at::nat...         0.00%       0.000us         0.00%       0.000us       0.000us     169.246us        11.01%     169.246us       2.015us            84  
                                       aten::contiguous         0.42%      85.559us         8.81%       1.802ms      18.772us       0.000us         0.00%     134.881us       1.405us            96  
                                            aten::clone         0.80%     164.449us         8.39%       1.717ms      17.880us       0.000us         0.00%     134.881us       1.405us            96  
void at::native::elementwise_kernel<128, 2, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     134.881us         8.78%     134.881us       1.405us            96  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us     115.650us         7.53%     115.650us       1.205us            96  
                                          aten::__and__         0.39%      78.814us         4.36%     891.116us      10.609us       0.000us         0.00%     101.539us       1.209us            84  
                                      aten::bitwise_and         2.44%     499.687us         3.97%     812.302us       9.670us     101.539us         6.61%     101.539us       1.209us            84  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us     101.539us         6.61%     101.539us       1.209us            84  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us      96.065us         6.25%      96.065us       1.334us            72  
                                              aten::add         1.62%     331.717us         2.71%     554.333us       9.239us      83.900us         5.46%      83.900us       1.398us            60  
                                              aten::sub         2.21%     451.413us         3.69%     755.537us      10.494us      79.361us         5.16%      79.361us       1.102us            72  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 20.454ms
Self CUDA time total: 1.537ms



======================================================================
PROFILE TRACE: torch_eager | cuda_B2_Q300_H8_E256_L4_P4
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                            torch_eager         0.00%       0.000us         0.00%       0.000us       0.000us      19.115ms      1086.36%      19.115ms      19.115ms             1  
                                            torch_eager        21.90%       4.346ms        99.98%      19.842ms      19.842ms       0.000us         0.00%       1.761ms       1.761ms             1  
                                              aten::mul         6.18%       1.226ms        10.60%       2.104ms      10.960us     450.887us        25.63%     450.887us       2.348us           192  
                                            aten::index         4.92%     977.403us        17.78%       3.530ms      73.537us     282.433us        16.05%     420.451us       8.759us            48  
                                            aten::copy_         5.20%       1.031ms        12.05%       2.392ms      10.922us     372.637us        21.18%     372.637us       1.702us           219  
void at::native::elementwise_kernel<128, 2, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     357.955us        20.34%     357.955us       2.983us           120  
void at::native::index_elementwise_kernel<128, 4, at...         0.00%       0.000us         0.00%       0.000us       0.000us     282.433us        16.05%     282.433us       5.884us            48  
                                               aten::to         0.65%     128.684us        11.66%       2.315ms      13.536us       0.000us         0.00%     234.619us       1.372us           171  
                                         aten::_to_copy         2.23%     442.466us        11.01%       2.186ms      17.772us       0.000us         0.00%     234.619us       1.907us           123  
void at::native::unrolled_elementwise_kernel<at::nat...         0.00%       0.000us         0.00%       0.000us       0.000us     170.397us         9.68%     170.397us       2.029us            84  
                                       aten::contiguous         0.44%      87.582us         9.26%       1.837ms      19.140us       0.000us         0.00%     138.018us       1.438us            96  
                                            aten::clone         0.85%     168.452us         8.82%       1.750ms      18.228us       0.000us         0.00%     138.018us       1.438us            96  
void at::native::elementwise_kernel<128, 2, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     138.018us         7.84%     138.018us       1.438us            96  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us     129.055us         7.33%     129.055us       1.792us            72  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us     117.244us         6.66%     117.244us       1.221us            96  
                                              aten::add         1.68%     334.180us         2.81%     557.305us       9.288us     113.660us         6.46%     113.660us       1.894us            60  
                                          aten::__and__         0.41%      80.800us         4.55%     902.601us      10.745us       0.000us         0.00%     105.726us       1.259us            84  
                                      aten::bitwise_and         2.56%     508.561us         4.14%     821.801us       9.783us     105.726us         6.01%     105.726us       1.259us            84  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us     105.726us         6.01%     105.726us       1.259us            84  
                                              aten::sub         2.25%     446.108us         3.80%     754.277us      10.476us      82.273us         4.68%      82.273us       1.143us            72  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 19.847ms
Self CUDA time total: 1.760ms


impl                     wl                  p50(ms)  ok
torch_eager              cuda_B1_Q100_H8_E256_L4_P4     3.39  True
torch_eager              cuda_B1_Q300_H8_E256_L4_P4     4.01  True
torch_eager              cuda_B2_Q100_H8_E256_L4_P4     4.02  True
torch_eager              cuda_B2_Q300_H8_E256_L4_P4     4.02  True
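
For quick experimentation outside the harness, the reference function can also be called directly. The sketch below builds synthetic inputs whose shapes mirror the B1_Q100_H8_E256_L4_P4 workload name (batch 1, 100 queries, 8 heads, 256-dim embedding split into 32 channels per head, 4 levels, 4 points); the feature-pyramid spatial shapes are an assumption chosen only to produce a plausible value tensor, not the harness's exact configuration, and torch_deformable_detr is the function defined in the benchmark cell above.

import torch

# Assumed workload dimensions (hypothetical; chosen to match the workload naming above)
bs, num_queries, num_heads, embed_dim, num_levels, num_points = 1, 100, 8, 256, 4, 4
channels = embed_dim // num_heads  # per-head channels

# Assumed feature pyramid: four levels of decreasing resolution
spatial_shapes = torch.tensor([[64, 64], [32, 32], [16, 16], [8, 8]], dtype=torch.long)
level_sizes = spatial_shapes[:, 0] * spatial_shapes[:, 1]
level_start_index = torch.cat((level_sizes.new_zeros(1), level_sizes.cumsum(0)[:-1]))
num_keys = int(level_sizes.sum())

device = "cuda" if torch.cuda.is_available() else "cpu"
value = torch.randn(bs, num_keys, num_heads, channels, device=device)
sampling_locations = torch.rand(bs, num_queries, num_heads, num_levels, num_points, 2, device=device)
attention_weights = torch.rand(bs, num_queries, num_heads, num_levels, num_points, device=device)
attention_weights = attention_weights / attention_weights.sum(dim=(-2, -1), keepdim=True)  # normalize over levels x points

out = torch_deformable_detr(
    value, spatial_shapes.to(device), level_start_index.to(device), sampling_locations, attention_weights
)
print(out.shape)  # expected: torch.Size([1, 100, 256]) == (bs, num_queries, num_heads * channels)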