#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) Megvii Inc. All rights reserved.

import torch
import torch.nn as nn

from .network_blocks import BaseConv, DWConv

_TORCH_VER = [int(x) for x in torch.__version__.split(".")[:2]]


def meshgrid(*tensors):
    """
    Copied from YOLOX/yolox/utils/compat.py
    """
    if _TORCH_VER >= [1, 10]:
        return torch.meshgrid(*tensors, indexing="ij")
    else:
        return torch.meshgrid(*tensors)


def bboxes_iou(bboxes_a, bboxes_b, xyxy=True):
    """
    Copied from YOLOX/yolox/utils/boxes.py
    """
    if bboxes_a.shape[1] != 4 or bboxes_b.shape[1] != 4:
        raise IndexError

    if xyxy:
        tl = torch.max(bboxes_a[:, None, :2], bboxes_b[:, :2])
        br = torch.min(bboxes_a[:, None, 2:], bboxes_b[:, 2:])
        area_a = torch.prod(bboxes_a[:, 2:] - bboxes_a[:, :2], 1)
        area_b = torch.prod(bboxes_b[:, 2:] - bboxes_b[:, :2], 1)
    else:
        tl = torch.max(
            (bboxes_a[:, None, :2] - bboxes_a[:, None, 2:] / 2),
            (bboxes_b[:, :2] - bboxes_b[:, 2:] / 2),
        )
        br = torch.min(
            (bboxes_a[:, None, :2] + bboxes_a[:, None, 2:] / 2),
            (bboxes_b[:, :2] + bboxes_b[:, 2:] / 2),
        )
        area_a = torch.prod(bboxes_a[:, 2:], 1)
        area_b = torch.prod(bboxes_b[:, 2:], 1)

    en = (tl < br).type(tl.type()).prod(dim=2)
    area_i = torch.prod(br - tl, 2) * en  # * ((tl < br).all())
    return area_i / (area_a[:, None] + area_b - area_i)


class YOLOXHead(nn.Module):
    def __init__(
        self,
        num_classes,
        width=1.0,
        strides=[8, 16, 32],
        in_channels=[256, 512, 1024],
        act="silu",
        depthwise=False,
    ):
        """
        Args:
            act (str): activation type of conv. Default value: "silu".
            depthwise (bool): whether to apply depthwise conv in conv branch. Default value: False.
        """
        super().__init__()
        self.num_classes = num_classes
        self.decode_in_inference = True  # for deploy, set to False

        self.cls_convs = nn.ModuleList()
        self.reg_convs = nn.ModuleList()
        self.cls_preds = nn.ModuleList()
        self.reg_preds = nn.ModuleList()
        self.obj_preds = nn.ModuleList()
        self.stems = nn.ModuleList()
        Conv = DWConv if depthwise else BaseConv

        for i in range(len(in_channels)):
            # 1x1 stem projecting each FPN level to a common 256 * width channels.
            self.stems.append(
                BaseConv(
                    in_channels=int(in_channels[i] * width),
                    out_channels=int(256 * width),
                    ksize=1,
                    stride=1,
                    act=act,
                )
            )
            self.cls_convs.append(
                nn.Sequential(
                    *[
                        Conv(
                            in_channels=int(256 * width),
                            out_channels=int(256 * width),
                            ksize=3,
                            stride=1,
                            act=act,
                        ),
                        Conv(
                            in_channels=int(256 * width),
                            out_channels=int(256 * width),
                            ksize=3,
                            stride=1,
                            act=act,
                        ),
                    ]
                )
            )
            self.reg_convs.append(
                nn.Sequential(
                    *[
                        Conv(
                            in_channels=int(256 * width),
                            out_channels=int(256 * width),
                            ksize=3,
                            stride=1,
                            act=act,
                        ),
                        Conv(
                            in_channels=int(256 * width),
                            out_channels=int(256 * width),
                            ksize=3,
                            stride=1,
                            act=act,
                        ),
                    ]
                )
            )
            self.cls_preds.append(
                nn.Conv2d(
                    in_channels=int(256 * width),
                    out_channels=self.num_classes,
                    kernel_size=1,
                    stride=1,
                    padding=0,
                )
            )
            self.reg_preds.append(
                nn.Conv2d(
                    in_channels=int(256 * width),
                    out_channels=4,
                    kernel_size=1,
                    stride=1,
                    padding=0,
                )
            )
            self.obj_preds.append(
                nn.Conv2d(
                    in_channels=int(256 * width),
                    out_channels=1,
                    kernel_size=1,
                    stride=1,
                    padding=0,
                )
            )

        self.use_l1 = False
        self.l1_loss = nn.L1Loss(reduction="none")
        self.bcewithlog_loss = nn.BCEWithLogitsLoss(reduction="none")
        self.iou_loss = None
        self.strides = strides
        self.grids = [torch.zeros(1)] * len(in_channels)

    def forward(self, xin, labels=None, imgs=None):
        outputs = []

        for k, (cls_conv, reg_conv, stride_this_level, x) in enumerate(
            zip(self.cls_convs, self.reg_convs, self.strides, xin)
        ):
            x = self.stems[k](x)
            cls_x = x
            reg_x = x
            # Per level: 4 box regressions, 1 objectness score, num_classes class scores.
            cls_feat = cls_conv(cls_x)
            cls_output = self.cls_preds[k](cls_feat)

            reg_feat = reg_conv(reg_x)
            reg_output = self.reg_preds[k](reg_feat)
            obj_output = self.obj_preds[k](reg_feat)

            output = torch.cat(
                [reg_output, obj_output.sigmoid(), cls_output.sigmoid()], 1
            )
            outputs.append(output)

        self.hw = [x.shape[-2:] for x in outputs]
        # [batch, n_anchors_all, 85]
        outputs = torch.cat(
            [x.flatten(start_dim=2) for x in outputs], dim=2
        ).permute(0, 2, 1)
        if self.decode_in_inference:
            return self.decode_outputs(outputs, dtype=xin[0].type())
        else:
            return outputs

    def get_output_and_grid(self, output, k, stride, dtype):
        # Decode one level's raw output in place against its (cached) grid.
        grid = self.grids[k]

        batch_size = output.shape[0]
        n_ch = 5 + self.num_classes
        hsize, wsize = output.shape[-2:]
        if grid.shape[2:4] != output.shape[2:4]:
            yv, xv = meshgrid([torch.arange(hsize), torch.arange(wsize)])
            grid = torch.stack((xv, yv), 2).view(1, 1, hsize, wsize, 2).type(dtype)
            self.grids[k] = grid

        output = output.view(batch_size, 1, n_ch, hsize, wsize)
        output = output.permute(0, 1, 3, 4, 2).reshape(
            batch_size, hsize * wsize, -1
        )
        grid = grid.view(1, -1, 2)
        output[..., :2] = (output[..., :2] + grid) * stride
        output[..., 2:4] = torch.exp(output[..., 2:4]) * stride
        return output, grid

    def decode_outputs(self, outputs, dtype):
        grids = []
        strides = []
        for (hsize, wsize), stride in zip(self.hw, self.strides):
            yv, xv = meshgrid([torch.arange(hsize), torch.arange(wsize)])
            grid = torch.stack((xv, yv), 2).view(1, -1, 2)
            grids.append(grid)
            shape = grid.shape[:2]
            strides.append(torch.full((*shape, 1), stride))

        grids = torch.cat(grids, dim=1).type(dtype)
        strides = torch.cat(strides, dim=1).type(dtype)

        # Decode grid-relative predictions to absolute boxes in input-image pixels:
        # (cx, cy) = (offset + grid cell) * stride, (w, h) = exp(pred) * stride.
        outputs = torch.cat([
            (outputs[..., 0:2] + grids) * strides,
            torch.exp(outputs[..., 2:4]) * strides,
            outputs[..., 4:]
        ], dim=-1)
        return outputs
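

# ---------------------------------------------------------------------------
# Minimal smoke-test sketch (not part of the original YOLOX code). It assumes
# the default in_channels=[256, 512, 1024] / strides=[8, 16, 32] configuration
# and FPN features for a 640x640 input, and that this module is executed in its
# package context (e.g. `python -m <package>.yolo_head`) so the relative import
# of network_blocks above resolves; the module path is only an assumption here.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    head = YOLOXHead(num_classes=80)
    head.eval()
    # One dummy feature map per FPN level: strides 8, 16, 32 of a 640x640 image.
    feats = [
        torch.randn(1, 256, 80, 80),
        torch.randn(1, 512, 40, 40),
        torch.randn(1, 1024, 20, 20),
    ]
    with torch.no_grad():
        preds = head(feats)
    # Decoded predictions: [batch, 6400 + 1600 + 400 = 8400 anchors, 4 box + 1 obj + 80 cls].
    print(preds.shape)  # torch.Size([1, 8400, 85])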