File size: 4,970 Bytes
9b13459
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
# /// script
# requires-python = ">=3.10"
# dependencies = [
#     "numpy",
#     "torch==2.8.0",
#     "kernels-benchmark-tools",
# ]
#
# [tool.uv.sources]
# kernels-benchmark-tools = { path = "../../../../../tools", editable = true }
# ///
import torch
import sys
from kernels_benchmark_tools import KernelTypeEnum, run_benchmark


def _gather_corner(grid, batch_idx, head_idx, yi, xi, in_bounds):
    """Fetch one bilinear corner from ``grid`` and zero it where the corner is off the map.

    ``grid`` is (bs, h, w, num_heads, channels); the index tensors broadcast to
    (bs, num_queries, num_heads, num_points), so the result is
    (bs, num_queries, num_heads, num_points, channels).
    """
    h, w = grid.shape[1], grid.shape[2]
    # Clamp only to keep the gather in range; out-of-range corners are
    # discarded by the mask right after.
    picked = grid[batch_idx, yi.clamp(0, h - 1), xi.clamp(0, w - 1), head_idx]
    return picked * in_bounds.unsqueeze(-1).float()


def torch_deformable_detr(
    value, spatial_shapes, level_start_index, sampling_locations, attention_weights, im2col_step=64
):
    """
    PyTorch native reference implementation of multi-scale deformable attention.

    For every query, head and level, ``num_points`` locations are sampled from
    that level's feature map with bilinear interpolation (corners outside the
    map contribute zero), scaled by the attention weights and summed over
    points and levels.

    Args:
        value: Flattened per-level features of shape
            (bs, sum(h * w over levels), num_heads, channels).
        spatial_shapes: Tensor of (h, w) pairs, one row per level.
        level_start_index: Not used by this implementation; presumably kept so
            the signature matches the kernel under test.
        sampling_locations: Normalized (x, y) sampling coordinates of shape
            (bs, num_queries, num_heads, num_levels, num_points, 2).
        attention_weights: Per-sample weights of shape
            (bs, num_queries, num_heads, num_levels, num_points).
        im2col_step: Not used by this implementation.

    Returns:
        Tensor of shape (bs, num_queries, num_heads * channels) with the dtype
        and device of ``value``.
    """
    bs, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape
    _, _, _, channels = value.shape

    output = torch.zeros(bs, num_queries, num_heads, channels, device=value.device, dtype=value.dtype)

    # Pull every level's (h, w) to the host once; calling .tolist() per level
    # would force an extra device sync on each iteration.
    level_shapes = [(int(h), int(w)) for h, w in spatial_shapes.tolist()]

    # Split value tensor by levels
    value_list = value.split([h * w for h, w in level_shapes], dim=1)

    # Batch/head index tensors do not depend on the level; they broadcast
    # against the (bs, num_queries, num_heads, num_points) pixel indices.
    batch_idx = torch.arange(bs, device=value.device).view(bs, 1, 1, 1)
    head_idx = torch.arange(num_heads, device=value.device).view(1, 1, num_heads, 1)

    # Levels have different spatial sizes, so they are processed one at a time.
    for level_idx in range(num_levels):
        h, w = level_shapes[level_idx]

        # Row-major spatial grid: (bs, h, w, num_heads, channels)
        grid = value_list[level_idx].reshape(bs, h, w, num_heads, channels)

        # loc: (bs, num_queries, num_heads, num_points, 2)
        loc = sampling_locations[:, :, :, level_idx]
        # weight: (bs, num_queries, num_heads, num_points)
        weight = attention_weights[:, :, :, level_idx]

        # Convert normalized coordinates to pixel coordinates
        # loc[..., 0] is x (width), loc[..., 1] is y (height)
        x = loc[..., 0] * w - 0.5
        y = loc[..., 1] * h - 0.5

        # Integer corners surrounding each sample
        x0 = torch.floor(x).long()
        y0 = torch.floor(y).long()
        x1 = x0 + 1
        y1 = y0 + 1

        # Fractional offsets must come from the unclamped corners, otherwise
        # samples that straddle the border get the wrong weights.
        lw = x - x0.float()
        lh = y - y0.float()
        hw = 1 - lw
        hh = 1 - lh

        # Per-corner in-bounds tests
        top_ok = y0 >= 0
        bottom_ok = y1 <= h - 1
        left_ok = x0 >= 0
        right_ok = x1 <= w - 1

        v_tl = _gather_corner(grid, batch_idx, head_idx, y0, x0, top_ok & left_ok)
        v_tr = _gather_corner(grid, batch_idx, head_idx, y0, x1, top_ok & right_ok)
        v_bl = _gather_corner(grid, batch_idx, head_idx, y1, x0, bottom_ok & left_ok)
        v_br = _gather_corner(grid, batch_idx, head_idx, y1, x1, bottom_ok & right_ok)

        # Bilinear interpolation
        sampled = (
            (hh * hw).unsqueeze(-1) * v_tl
            + (hh * lw).unsqueeze(-1) * v_tr
            + (lh * hw).unsqueeze(-1) * v_bl
            + (lh * lw).unsqueeze(-1) * v_br
        )

        # Drop samples that lie entirely outside the feature map
        valid = (y > -1) & (x > -1) & (y < h) & (x < w)
        sampled = sampled * valid.unsqueeze(-1).float()

        # Apply attention weights and sum over points:
        # (bs, num_queries, num_heads, channels)
        output += (sampled * weight.unsqueeze(-1)).sum(dim=3)

    # Flatten last two dimensions to match kernel output
    return output.reshape(bs, num_queries, num_heads * channels)


# Hand the eager PyTorch reference above to the benchmark harness, tagged as
# the "pytorch"/"eager" implementation and requesting dtype "float32".
run_benchmark(
    impl_func=torch_deformable_detr,
    impl_name="torch_eager",
    impl_tags={"family": "pytorch", "backend": "eager"},
    kernel_type=KernelTypeEnum.DEFORMABLE_DETR,
    dtype="float32",
)