|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch |
|
|
import sys |
|
|
from kernels_benchmark_tools import KernelTypeEnum, run_benchmark |
|
|
|
|
|
|
|
|
def torch_deformable_detr(
    value, spatial_shapes, level_start_index, sampling_locations, attention_weights, im2col_step=64
):
    """
    PyTorch native reference implementation of multi-scale deformable attention.

    For every query, head and feature level, ``num_points`` locations are
    sampled from that level's value map with bilinear interpolation (corners
    that fall outside the map contribute zero), scaled by the matching
    attention weight, and summed over points and levels. Each level is
    processed with vectorized tensor operations; only the loop over levels
    runs in Python.

    Args:
        value: Tensor of shape ``(bs, sum(h * w over levels), num_heads, channels)``
            holding the flattened (row-major ``h`` x ``w``) value maps of all
            levels, concatenated along dim 1.
        spatial_shapes: Tensor of shape ``(num_levels, 2)`` with the
            ``(h, w)`` size of each level.
        level_start_index: Not used by this implementation; presumably kept so
            the signature matches the other kernels run by the benchmark
            harness (confirm against the caller).
        sampling_locations: Tensor of shape
            ``(bs, num_queries, num_heads, num_levels, num_points, 2)`` whose
            last dimension is ``(x, y)``. Coordinates are multiplied by the
            level width/height, so ``[0, 1]`` spans the whole map.
        attention_weights: Tensor of shape
            ``(bs, num_queries, num_heads, num_levels, num_points)``.
        im2col_step: Not used by this implementation (see ``level_start_index``).

    Returns:
        Tensor of shape ``(bs, num_queries, num_heads * channels)`` with the
        dtype and device of ``value``.
    """
    bs, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape
    _, _, _, channels = value.shape
    device = value.device

    output = torch.zeros(bs, num_queries, num_heads, channels, device=device, dtype=value.dtype)

    # Convert the level sizes to Python ints once: a single tensor-to-host
    # transfer instead of one per level inside the loop.
    level_shapes = [(int(h), int(w)) for h, w in spatial_shapes.tolist()]
    value_list = value.split([h * w for h, w in level_shapes], dim=1)

    # Batch/head gather indices do not depend on the level, so build them once.
    index_shape = (bs, num_queries, num_heads, num_points)
    batch_idx = torch.arange(bs, device=device).view(bs, 1, 1, 1).expand(index_shape)
    head_idx = torch.arange(num_heads, device=device).view(1, 1, num_heads, 1).expand(index_shape)

    for level_idx in range(num_levels):
        h, w = level_shapes[level_idx]

        # (bs, h * w, num_heads, channels) -> (bs, num_heads, channels, h, w)
        value_spatial = value_list[level_idx].reshape(bs, h, w, num_heads, channels).permute(0, 3, 4, 1, 2)

        loc = sampling_locations[:, :, :, level_idx, :, :]  # (bs, nq, nh, np, 2)
        weight = attention_weights[:, :, :, level_idx, :]  # (bs, nq, nh, np)

        # Scale to pixel units; the -0.5 shift puts integer coordinates on
        # pixel centres.
        x = loc[..., 0] * w - 0.5
        y = loc[..., 1] * h - 0.5

        # Integer corners surrounding each sample point.
        x0 = torch.floor(x).long()
        y0 = torch.floor(y).long()
        x1 = x0 + 1
        y1 = y0 + 1

        # Fractional offsets (lw, lh) and their complements (hw, hh).
        lw = x - x0.float()
        lh = y - y0.float()
        hw = 1 - lw
        hh = 1 - lh

        # A sample contributes only if it lies within one pixel of the map.
        valid = (y > -1) & (x > -1) & (y < h) & (x < w)

        # Per-corner in-bounds masks; out-of-map corners contribute zero.
        mask_tl = ((y0 >= 0) & (x0 >= 0)).unsqueeze(-1).float()
        mask_tr = ((y0 >= 0) & (x1 <= w - 1)).unsqueeze(-1).float()
        mask_bl = ((y1 <= h - 1) & (x0 >= 0)).unsqueeze(-1).float()
        mask_br = ((y1 <= h - 1) & (x1 <= w - 1)).unsqueeze(-1).float()

        # Clamp so the gather below never indexes out of range; the masks
        # above zero out whatever the clamped reads return.
        x0_clamped = torch.clamp(x0, 0, w - 1)
        x1_clamped = torch.clamp(x1, 0, w - 1)
        y0_clamped = torch.clamp(y0, 0, h - 1)
        y1_clamped = torch.clamp(y1, 0, h - 1)

        # Bilinear interpolation weights, one per corner.
        w_tl = (hh * hw).unsqueeze(-1)
        w_tr = (hh * lw).unsqueeze(-1)
        w_bl = (lh * hw).unsqueeze(-1)
        w_br = (lh * lw).unsqueeze(-1)

        # The index tensors are separated by a slice, so the indexed dims come
        # first in the result: (bs, nq, nh, np, channels).
        v_tl = value_spatial[batch_idx, head_idx, :, y0_clamped, x0_clamped] * mask_tl
        v_tr = value_spatial[batch_idx, head_idx, :, y0_clamped, x1_clamped] * mask_tr
        v_bl = value_spatial[batch_idx, head_idx, :, y1_clamped, x0_clamped] * mask_bl
        v_br = value_spatial[batch_idx, head_idx, :, y1_clamped, x1_clamped] * mask_br

        sampled = w_tl * v_tl + w_tr * v_tr + w_bl * v_bl + w_br * v_br
        sampled = sampled * valid.unsqueeze(-1).float()

        # Apply the attention weights and reduce over the points dimension.
        weighted_sampled = sampled * weight.unsqueeze(-1)
        output += weighted_sampled.sum(dim=3)

    return output.reshape(bs, num_queries, num_heads * channels)
|
|
|
|
|
|
|
|
# Run the eager PyTorch reference implementation through the benchmark
# harness as the "torch_eager" entry for the DEFORMABLE_DETR kernel type,
# using float32 inputs.
run_benchmark(
    kernel_type=KernelTypeEnum.DEFORMABLE_DETR,
    impl_name="torch_eager",
    impl_tags={"family": "pytorch", "backend": "eager"},
    impl_func=torch_deformable_detr,
    dtype="float32",
)