PyTorch Native - Deformable DETR

GPU Info

Cell: nv | 0.22s
import subprocess
print(subprocess.run(["nvidia-smi"], capture_output=True, text=True).stdout)
Mon Nov 10 21:58:17 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 580.95.05              Driver Version: 580.95.05      CUDA Version: 13.0     |
+-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA L40S                    On  |   00000000:4D:00.0 Off |                    0 |
| N/A   28C    P0             79W /  350W |       0MiB /  46068MiB |     11%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+

+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI              PID   Type   Process name                        GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|  No running processes found                                                             |
+-----------------------------------------------------------------------------------------+

Deformable DETR Multi-Scale Deformable Attention Benchmark (PyTorch Native)

Cell: benchmark | 5.50s
# /// script
# requires-python = ">=3.10"
# dependencies = [
#     "numpy",
#     "torch==2.8.0",
#     "kernels-benchmark-tools",
# ]
#
# [tool.uv.sources]
# kernels-benchmark-tools = { path = "../../../../../tools", editable = true }
# ///
import torch
from kernels_benchmark_tools import KernelTypeEnum, run_benchmark


def torch_deformable_detr(
    value, spatial_shapes, level_start_index, sampling_locations, attention_weights, im2col_step=64
):
    """
    PyTorch native reference implementation of multi-scale deformable attention.
    Uses vectorized bilinear interpolation for reasonable performance.
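
    Expected tensor shapes (as unpacked below):
        value:              (bs, sum of h*w over levels, num_heads, channels)
        spatial_shapes:     (num_levels, 2) holding (h, w) per level
        level_start_index:  (num_levels,)
        sampling_locations: (bs, num_queries, num_heads, num_levels, num_points, 2), normalized to [0, 1]
        attention_weights:  (bs, num_queries, num_heads, num_levels, num_points)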
    """
    bs, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape
    _, _, _, channels = value.shape

    output = torch.zeros(bs, num_queries, num_heads, channels, device=value.device, dtype=value.dtype)

    # Split value tensor by levels
    value_list = value.split([int(h * w) for h, w in spatial_shapes.tolist()], dim=1)

    # Iterate over feature levels; num_levels is small, so a Python-level loop is acceptable here
    for level_idx in range(num_levels):
        h, w = spatial_shapes[level_idx].tolist()
        value_level = value_list[level_idx]  # (bs, h*w, num_heads, channels)

        # Reshape to spatial grid: (bs, num_heads, channels, h, w)
        value_spatial = value_level.reshape(bs, h, w, num_heads, channels).permute(0, 3, 4, 1, 2)

        # Get sampling locations and weights for this level
        # loc: (bs, num_queries, num_heads, num_points, 2)
        loc = sampling_locations[:, :, :, level_idx, :, :]
        # weight: (bs, num_queries, num_heads, num_points)
        weight = attention_weights[:, :, :, level_idx, :]

        # Convert normalized coordinates to pixel coordinates
        # loc[..., 0] is x (width), loc[..., 1] is y (height)
        x = loc[..., 0] * w - 0.5  # (bs, num_queries, num_heads, num_points)
        y = loc[..., 1] * h - 0.5

        # Get integer coordinates for bilinear interpolation
        x0 = torch.floor(x).long()
        y0 = torch.floor(y).long()
        x1 = x0 + 1
        y1 = y0 + 1

        # Compute interpolation weights before clamping, so the fractional offsets reflect the true sample position
        lw = x - x0.float()  # weight for x direction
        lh = y - y0.float()  # weight for y direction
        hw = 1 - lw
        hh = 1 - lh

        # Create mask for valid sample locations
        valid = (y > -1) & (x > -1) & (y < h) & (x < w)
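        # Samples strictly outside (-1, h) x (-1, w) contribute nothing; partially
        # out-of-bounds samples keep only their in-bounds corners (see corner masks below)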

        # Create masks for each corner being in bounds
        mask_tl = ((y0 >= 0) & (x0 >= 0)).unsqueeze(-1).float()
        mask_tr = ((y0 >= 0) & (x1 <= w - 1)).unsqueeze(-1).float()
        mask_bl = ((y1 <= h - 1) & (x0 >= 0)).unsqueeze(-1).float()
        mask_br = ((y1 <= h - 1) & (x1 <= w - 1)).unsqueeze(-1).float()

        # Clamp coordinates for safe indexing
        x0_clamped = torch.clamp(x0, 0, w - 1)
        x1_clamped = torch.clamp(x1, 0, w - 1)
        y0_clamped = torch.clamp(y0, 0, h - 1)
        y1_clamped = torch.clamp(y1, 0, h - 1)

        # Bilinear interpolation weights for all 4 corners
        w_tl = (hh * hw).unsqueeze(-1)  # top-left: (bs, num_queries, num_heads, num_points, 1)
        w_tr = (hh * lw).unsqueeze(-1)  # top-right
        w_bl = (lh * hw).unsqueeze(-1)  # bottom-left
        w_br = (lh * lw).unsqueeze(-1)  # bottom-right
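        # Applied below as: sampled = hh*hw*v_tl + hh*lw*v_tr + lh*hw*v_bl + lh*lw*v_br,
        # where lh/lw are the fractional offsets of the sample from its top-left corner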

        # Gather values from the 4 corners using advanced indexing
        batch_idx = torch.arange(bs, device=value.device).view(bs, 1, 1, 1).expand(bs, num_queries, num_heads, num_points)
        head_idx = torch.arange(num_heads, device=value.device).view(1, 1, num_heads, 1).expand(bs, num_queries, num_heads, num_points)
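
        # Note: with a slice between advanced indices, PyTorch moves the advanced-indexed
        # dims to the front, so each gather below returns (bs, num_queries, num_heads, num_points, channels)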

        # Gather corner values with clamped indices, then apply corner masks
        v_tl = value_spatial[batch_idx, head_idx, :, y0_clamped, x0_clamped] * mask_tl
        v_tr = value_spatial[batch_idx, head_idx, :, y0_clamped, x1_clamped] * mask_tr
        v_bl = value_spatial[batch_idx, head_idx, :, y1_clamped, x0_clamped] * mask_bl
        v_br = value_spatial[batch_idx, head_idx, :, y1_clamped, x1_clamped] * mask_br

        # Bilinear interpolation
        sampled = w_tl * v_tl + w_tr * v_tr + w_bl * v_bl + w_br * v_br

        # Apply valid mask (only accumulate if entire sample location is valid)
        sampled = sampled * valid.unsqueeze(-1).float()

        # Apply attention weights and sum over points
        # weight: (bs, num_queries, num_heads, num_points)
        # Expand weight: (bs, num_queries, num_heads, num_points, 1)
        weighted_sampled = sampled * weight.unsqueeze(-1)

        # Sum over points: (bs, num_queries, num_heads, channels)
        output += weighted_sampled.sum(dim=3)

    # Flatten last two dimensions to match kernel output
    return output.reshape(bs, num_queries, num_heads * channels)


run_benchmark(
    kernel_type=KernelTypeEnum.DEFORMABLE_DETR,
    impl_name="torch_eager",
    impl_tags={"family": "pytorch", "backend": "eager"},
    impl_func=torch_deformable_detr,
    dtype="float32",
)
Running deformable_detr benchmark on cuda with 4 workloads.

======================================================================
PROFILE TRACE: torch_eager | cuda_B1_Q100_H8_E256_L4_P4
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                            torch_eager         0.00%       0.000us         0.00%       0.000us       0.000us      20.386ms      1374.87%      20.386ms      20.386ms             1  
                                            torch_eager        20.04%       4.485ms        99.97%      22.369ms      22.369ms       0.000us         0.00%       1.484ms       1.484ms             1  
                                            aten::index         4.49%       1.004ms        16.23%       3.633ms      75.679us     237.283us        16.00%     370.795us       7.725us            48  
                                            aten::copy_         4.62%       1.034ms        11.24%       2.516ms      11.489us     365.611us        24.66%     365.611us       1.669us           219  
                                              aten::mul         5.81%       1.299ms        10.43%       2.335ms      12.160us     293.820us        19.82%     293.820us       1.530us           192  
void at::native::index_elementwise_kernel<128, 4, at...         0.00%       0.000us         0.00%       0.000us       0.000us     237.283us        16.00%     237.283us       4.943us            48  
                                               aten::to         0.57%     127.097us        11.08%       2.479ms      14.499us       0.000us         0.00%     232.099us       1.357us           171  
                                         aten::_to_copy         2.30%     514.876us        10.51%       2.352ms      19.124us       0.000us         0.00%     232.099us       1.887us           123  
void at::native::elementwise_kernel<128, 2, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     202.015us        13.62%     202.015us       1.683us           120  
void at::native::unrolled_elementwise_kernel<at::nat...         0.00%       0.000us         0.00%       0.000us       0.000us     167.684us        11.31%     167.684us       1.996us            84  
                                       aten::contiguous         0.35%      77.804us         8.37%       1.873ms      19.513us       0.000us         0.00%     133.512us       1.391us            96  
                                            aten::clone         0.74%     165.226us         8.02%       1.795ms      18.702us       0.000us         0.00%     133.512us       1.391us            96  
void at::native::elementwise_kernel<128, 2, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     133.512us         9.00%     133.512us       1.391us            96  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us     115.524us         7.79%     115.524us       1.203us            96  
                                          aten::__and__         1.20%     268.284us         4.94%       1.105ms      13.160us       0.000us         0.00%      99.070us       1.179us            84  
                                      aten::bitwise_and         2.22%     496.516us         3.74%     837.149us       9.966us      99.070us         6.68%      99.070us       1.179us            84  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us      99.070us         6.68%      99.070us       1.179us            84  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us      86.210us         5.81%      86.210us       1.197us            72  
                                              aten::sub         2.17%     485.693us         3.77%     844.019us      11.722us      79.300us         5.35%      79.300us       1.101us            72  
                                              aten::add         1.71%     382.016us         2.87%     642.388us      10.706us      74.367us         5.02%      74.367us       1.239us            60  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 22.377ms
Self CUDA time total: 1.483ms



======================================================================
PROFILE TRACE: torch_eager | cuda_B1_Q300_H8_E256_L4_P4
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                            torch_eager         0.00%       0.000us         0.00%       0.000us       0.000us      18.901ms      1183.82%      18.901ms      18.901ms             1  
                                            torch_eager        19.58%       4.093ms        99.97%      20.894ms      20.894ms       0.000us         0.00%       1.598ms       1.598ms             1  
                                            aten::index         4.47%     934.204us        16.39%       3.425ms      71.358us     251.679us        15.76%     384.126us       8.003us            48  
                                            aten::copy_         4.82%       1.008ms        11.62%       2.429ms      11.090us     366.752us        22.97%     366.752us       1.675us           219  
                                              aten::mul         6.02%       1.258ms        10.56%       2.208ms      11.499us     358.660us        22.46%     358.660us       1.868us           192  
void at::native::elementwise_kernel<128, 2, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     266.913us        16.72%     266.913us       2.224us           120  
void at::native::index_elementwise_kernel<128, 4, at...         0.00%       0.000us         0.00%       0.000us       0.000us     251.679us        15.76%     251.679us       5.243us            48  
                                               aten::to         0.53%     111.534us        10.80%       2.257ms      13.199us       0.000us         0.00%     234.305us       1.370us           171  
                                         aten::_to_copy         1.86%     389.526us        10.27%       2.146ms      17.443us       0.000us         0.00%     234.305us       1.905us           123  
void at::native::unrolled_elementwise_kernel<at::nat...         0.00%       0.000us         0.00%       0.000us       0.000us     169.699us        10.63%     169.699us       2.020us            84  
                                       aten::contiguous         0.36%      76.248us         8.65%       1.808ms      18.835us       0.000us         0.00%     132.447us       1.380us            96  
                                            aten::clone         0.75%     157.022us         8.29%       1.732ms      18.040us       0.000us         0.00%     132.447us       1.380us            96  
void at::native::elementwise_kernel<128, 2, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     132.447us         8.30%     132.447us       1.380us            96  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us     117.700us         7.37%     117.700us       1.226us            96  
                                          aten::__and__         0.39%      80.574us         4.34%     907.528us      10.804us       0.000us         0.00%     104.931us       1.249us            84  
                                      aten::bitwise_and         2.39%     499.734us         3.96%     826.954us       9.845us     104.931us         6.57%     104.931us       1.249us            84  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us     104.931us         6.57%     104.931us       1.249us            84  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us     104.254us         6.53%     104.254us       1.448us            72  
                                              aten::add         1.76%     366.940us         2.98%     622.302us      10.372us      91.679us         5.74%      91.679us       1.528us            60  
                                              aten::sub         2.26%     472.751us         3.91%     817.040us      11.348us      80.412us         5.04%      80.412us       1.117us            72  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 20.900ms
Self CUDA time total: 1.597ms



======================================================================
PROFILE TRACE: torch_eager | cuda_B2_Q100_H8_E256_L4_P4
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                            torch_eager         0.00%       0.000us         0.00%       0.000us       0.000us      19.237ms      1248.03%      19.237ms      19.237ms             1  
                                            torch_eager        19.69%       4.158ms        99.97%      21.112ms      21.112ms       0.000us         0.00%       1.542ms       1.542ms             1  
                                            aten::index         4.41%     930.777us        16.28%       3.439ms      71.641us     244.707us        15.88%     379.074us       7.897us            48  
                                            aten::copy_         4.79%       1.012ms        11.88%       2.509ms      11.455us     367.613us        23.85%     367.613us       1.679us           219  
                                              aten::mul         6.03%       1.274ms        10.79%       2.279ms      11.869us     324.897us        21.08%     324.897us       1.692us           192  
void at::native::index_elementwise_kernel<128, 4, at...         0.00%       0.000us         0.00%       0.000us       0.000us     244.707us        15.88%     244.707us       5.098us            48  
void at::native::elementwise_kernel<128, 2, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     233.822us        15.17%     233.822us       1.949us           120  
                                               aten::to         0.53%     111.710us        11.01%       2.324ms      13.591us       0.000us         0.00%     233.246us       1.364us           171  
                                         aten::_to_copy         1.89%     399.701us        10.48%       2.212ms      17.986us       0.000us         0.00%     233.246us       1.896us           123  
void at::native::unrolled_elementwise_kernel<at::nat...         0.00%       0.000us         0.00%       0.000us       0.000us     168.798us        10.95%     168.798us       2.010us            84  
                                       aten::contiguous         0.36%      76.215us         8.56%       1.808ms      18.834us       0.000us         0.00%     134.367us       1.400us            96  
                                            aten::clone         0.70%     147.727us         8.20%       1.732ms      18.040us       0.000us         0.00%     134.367us       1.400us            96  
void at::native::elementwise_kernel<128, 2, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     134.367us         8.72%     134.367us       1.400us            96  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us     116.097us         7.53%     116.097us       1.209us            96  
                                          aten::__and__         0.38%      80.351us         4.40%     929.654us      11.067us       0.000us         0.00%     104.257us       1.241us            84  
                                      aten::bitwise_and         2.34%     493.964us         4.02%     849.303us      10.111us     104.257us         6.76%     104.257us       1.241us            84  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us     104.257us         6.76%     104.257us       1.241us            84  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us      96.124us         6.24%      96.124us       1.335us            72  
                                              aten::add         1.63%     344.862us         2.97%     627.717us      10.462us      83.898us         5.44%      83.898us       1.398us            60  
                                              aten::sub         2.25%     476.045us         3.91%     826.060us      11.473us      79.295us         5.14%      79.295us       1.101us            72  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 21.118ms
Self CUDA time total: 1.541ms



======================================================================
PROFILE TRACE: torch_eager | cuda_B2_Q300_H8_E256_L4_P4
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                            torch_eager         0.00%       0.000us         0.00%       0.000us       0.000us      19.519ms      1100.37%      19.519ms      19.519ms             1  
                                            torch_eager        20.47%       4.142ms        99.97%      20.229ms      20.229ms       0.000us         0.00%       1.775ms       1.775ms             1  
                                              aten::mul         6.23%       1.261ms        11.26%       2.279ms      11.871us     452.223us        25.49%     452.223us       2.355us           192  
                                            aten::index         5.19%       1.050ms        17.90%       3.622ms      75.460us     284.479us        16.04%     422.205us       8.796us            48  
                                            aten::copy_         4.94%       1.000ms        12.35%       2.500ms      11.414us     371.807us        20.96%     371.807us       1.698us           219  
void at::native::elementwise_kernel<128, 2, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     357.379us        20.15%     357.379us       2.978us           120  
void at::native::index_elementwise_kernel<128, 4, at...         0.00%       0.000us         0.00%       0.000us       0.000us     284.479us        16.04%     284.479us       5.927us            48  
                                               aten::to         0.55%     111.602us        11.50%       2.327ms      13.611us       0.000us         0.00%     234.081us       1.369us           171  
                                         aten::_to_copy         2.05%     415.176us        10.95%       2.216ms      18.015us       0.000us         0.00%     234.081us       1.903us           123  
void at::native::unrolled_elementwise_kernel<at::nat...         0.00%       0.000us         0.00%       0.000us       0.000us     168.127us         9.48%     168.127us       2.002us            84  
                                       aten::contiguous         0.39%      79.104us         9.03%       1.827ms      19.029us       0.000us         0.00%     137.726us       1.435us            96  
                                            aten::clone         0.75%     151.809us         8.64%       1.748ms      18.205us       0.000us         0.00%     137.726us       1.435us            96  
void at::native::elementwise_kernel<128, 2, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     137.726us         7.76%     137.726us       1.435us            96  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us     129.254us         7.29%     129.254us       1.795us            72  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us     120.034us         6.77%     120.034us       1.250us            96  
                                              aten::add         1.70%     344.853us         3.02%     611.127us      10.185us     113.603us         6.40%     113.603us       1.893us            60  
                                          aten::__and__         0.42%      84.251us         4.73%     957.185us      11.395us       0.000us         0.00%     108.833us       1.296us            84  
                                      aten::bitwise_and         2.53%     511.745us         4.31%     872.934us      10.392us     108.833us         6.14%     108.833us       1.296us            84  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us     108.833us         6.14%     108.833us       1.296us            84  
                                              aten::sub         2.33%     472.119us         4.10%     828.789us      11.511us      84.547us         4.77%      84.547us       1.174us            72  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 20.235ms
Self CUDA time total: 1.774ms


impl                     wl                  p50(ms)  ok
torch_eager              cuda_B1_Q100_H8_E256_L4_P4     3.28  True
torch_eager              cuda_B1_Q300_H8_E256_L4_P4     4.01  True
torch_eager              cuda_B2_Q100_H8_E256_L4_P4     4.03  True
torch_eager              cuda_B2_Q300_H8_E256_L4_P4     4.14  True
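
As a standalone sanity check outside run_benchmark, the reference function can be called directly with inputs matching the smallest workload, cuda_B1_Q100_H8_E256_L4_P4. The sketch below is illustrative only: the per-level feature-map sizes are assumed (the benchmark tool defines its own), and random attention weights are normalized to sum to 1 over levels and points, mimicking the softmax in Deformable DETR. It verifies shapes, not numerics against the CUDA kernel.

import torch

# Assumed workload: B=1, Q=100, H=8, embed=256 (32 channels per head), L=4, P=4.
# The per-level spatial shapes below are illustrative, not the benchmark's.
bs, num_queries, num_heads, channels = 1, 100, 8, 32
spatial_shapes = torch.tensor([[64, 64], [32, 32], [16, 16], [8, 8]], device="cuda")
level_sizes = spatial_shapes[:, 0] * spatial_shapes[:, 1]
level_start_index = torch.cat((level_sizes.new_zeros(1), level_sizes.cumsum(0)[:-1]))
num_value = int(level_sizes.sum())

value = torch.randn(bs, num_value, num_heads, channels, device="cuda")
sampling_locations = torch.rand(bs, num_queries, num_heads, 4, 4, 2, device="cuda")
attention_weights = torch.rand(bs, num_queries, num_heads, 4, 4, device="cuda")
attention_weights = attention_weights / attention_weights.sum(dim=(-2, -1), keepdim=True)

out = torch_deformable_detr(
    value, spatial_shapes, level_start_index, sampling_locations, attention_weights
)
print(out.shape)  # expected: torch.Size([1, 100, 256])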