# /// script
# requires-python = ">=3.10"
# dependencies = [
#     "numpy",
#     "torch==2.8.0",
#     "kernels-benchmark-tools",
# ]
#
# [tool.uv.sources]
# kernels-benchmark-tools = { path = "../../../../../tools", editable = true }
# ///
import sys

import torch
from kernels_benchmark_tools import KernelTypeEnum, run_benchmark


def torch_deformable_detr(
    value,
    spatial_shapes,
    level_start_index,
    sampling_locations,
    attention_weights,
    im2col_step=64,
):
    """
    PyTorch native reference implementation of multi-scale deformable attention.

    Uses vectorized bilinear interpolation for reasonable performance.
    """
    bs, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape
    _, _, _, channels = value.shape

    output = torch.zeros(
        bs, num_queries, num_heads, channels, device=value.device, dtype=value.dtype
    )

    # Split value tensor by levels
    value_list = value.split([int(h * w) for h, w in spatial_shapes.tolist()], dim=1)

    # Iterate through each level (can't avoid this loop easily)
    for level_idx in range(num_levels):
        h, w = spatial_shapes[level_idx].tolist()
        value_level = value_list[level_idx]  # (bs, h*w, num_heads, channels)

        # Reshape to spatial grid: (bs, num_heads, channels, h, w)
        value_spatial = value_level.reshape(bs, h, w, num_heads, channels).permute(0, 3, 4, 1, 2)

        # Get sampling locations and weights for this level
        # loc: (bs, num_queries, num_heads, num_points, 2)
        loc = sampling_locations[:, :, :, level_idx, :, :]
        # weight: (bs, num_queries, num_heads, num_points)
        weight = attention_weights[:, :, :, level_idx, :]

        # Convert normalized coordinates to pixel coordinates
        # loc[..., 0] is x (width), loc[..., 1] is y (height)
        x = loc[..., 0] * w - 0.5  # (bs, num_queries, num_heads, num_points)
        y = loc[..., 1] * h - 0.5

        # Get integer coordinates for bilinear interpolation
        x0 = torch.floor(x).long()
        y0 = torch.floor(y).long()
        x1 = x0 + 1
        y1 = y0 + 1

        # Compute interpolation weights BEFORE clamping (important!)
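        # The fractional offsets must come from the unclamped floor coordinates:
        # e.g. x = -0.3 gives x0 = -1 and lw = 0.7, and the out-of-bounds corner
        # is zeroed by its mask below. Clamping x0 to 0 first would instead yield
        # a (wrong) negative weight lw = -0.3 at the border.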
        lw = x - x0.float()  # weight for x direction
        lh = y - y0.float()  # weight for y direction
        hw = 1 - lw
        hh = 1 - lh

        # Create mask for valid sample locations
        valid = (y > -1) & (x > -1) & (y < h) & (x < w)

        # Create masks for each corner being in bounds
        mask_tl = ((y0 >= 0) & (x0 >= 0)).unsqueeze(-1).float()
        mask_tr = ((y0 >= 0) & (x1 <= w - 1)).unsqueeze(-1).float()
        mask_bl = ((y1 <= h - 1) & (x0 >= 0)).unsqueeze(-1).float()
        mask_br = ((y1 <= h - 1) & (x1 <= w - 1)).unsqueeze(-1).float()

        # Clamp coordinates for safe indexing
        x0_clamped = torch.clamp(x0, 0, w - 1)
        x1_clamped = torch.clamp(x1, 0, w - 1)
        y0_clamped = torch.clamp(y0, 0, h - 1)
        y1_clamped = torch.clamp(y1, 0, h - 1)

        # Bilinear interpolation weights for all 4 corners
        w_tl = (hh * hw).unsqueeze(-1)  # top-left: (bs, num_queries, num_heads, num_points, 1)
        w_tr = (hh * lw).unsqueeze(-1)  # top-right
        w_bl = (lh * hw).unsqueeze(-1)  # bottom-left
        w_br = (lh * lw).unsqueeze(-1)  # bottom-right

        # Gather values from the 4 corners using advanced indexing
        batch_idx = (
            torch.arange(bs, device=value.device)
            .view(bs, 1, 1, 1)
            .expand(bs, num_queries, num_heads, num_points)
        )
        head_idx = (
            torch.arange(num_heads, device=value.device)
            .view(1, 1, num_heads, 1)
            .expand(bs, num_queries, num_heads, num_points)
        )

        # Gather corner values with clamped indices, then apply corner masks
        v_tl = value_spatial[batch_idx, head_idx, :, y0_clamped, x0_clamped] * mask_tl
        v_tr = value_spatial[batch_idx, head_idx, :, y0_clamped, x1_clamped] * mask_tr
        v_bl = value_spatial[batch_idx, head_idx, :, y1_clamped, x0_clamped] * mask_bl
        v_br = value_spatial[batch_idx, head_idx, :, y1_clamped, x1_clamped] * mask_br

        # Bilinear interpolation
        sampled = w_tl * v_tl + w_tr * v_tr + w_bl * v_bl + w_br * v_br

        # Apply valid mask (only accumulate if entire sample location is valid)
        sampled = sampled * valid.unsqueeze(-1).float()

        # Apply attention weights and sum over points
        # weight: (bs, num_queries, num_heads, num_points)
        # expanded to: (bs, num_queries, num_heads, num_points, 1)
        weighted_sampled = sampled * weight.unsqueeze(-1)

        # Sum over points: (bs, num_queries, num_heads, channels)
        output += weighted_sampled.sum(dim=3)

    # Flatten last two dimensions to match kernel output
    return output.reshape(bs, num_queries, num_heads * channels)


run_benchmark(
    kernel_type=KernelTypeEnum.DEFORMABLE_DETR,
    impl_name="torch_eager",
    impl_tags={"family": "pytorch", "backend": "eager"},
    impl_func=torch_deformable_detr,
    dtype="float32",
)
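
# A minimal shape self-check for the reference implementation above (a sketch;
# the sizes below are illustrative assumptions, not values used by the benchmark
# harness). It is not invoked automatically; call it by hand when modifying the
# reference.
def _shape_selftest():
    bs, num_queries, num_heads, channels, num_points = 2, 8, 4, 16, 4
    spatial_shapes = torch.tensor([[8, 8], [4, 4]])  # two feature levels
    num_levels = spatial_shapes.shape[0]
    level_start_index = torch.tensor([0, 64])  # offsets into the flattened value
    total = int((spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum())  # 64 + 16
    value = torch.randn(bs, total, num_heads, channels)
    sampling_locations = torch.rand(bs, num_queries, num_heads, num_levels, num_points, 2)
    attention_weights = torch.softmax(
        torch.randn(bs, num_queries, num_heads, num_levels * num_points), dim=-1
    ).view(bs, num_queries, num_heads, num_levels, num_points)
    out = torch_deformable_detr(
        value, spatial_shapes, level_start_index, sampling_locations, attention_weights
    )
    assert out.shape == (bs, num_queries, num_heads * channels)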