|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch |
|
|
import torchvision |
|
|
|
|
|
|
|
|
def postprocess(prediction, num_classes, conf_thre=0.7, nms_thre=0.45, class_agnostic=False):
    """
    Adapted from YOLOX/yolox/utils/boxes.py

    Filter raw detector outputs by confidence, then apply non-maximum suppression.

    Args:
        prediction: tensor indexed as ``(batch, boxes, columns)``. Columns 0-3 are
            box centre x, centre y, width and height; column 4 is a per-box score
            (objectness in YOLOX); columns ``5:5 + num_classes`` are per-class scores.
            NOTE: columns 0-3 are overwritten IN PLACE with corner coordinates
            ``(x1, y1, x2, y2)``, matching the upstream YOLOX behaviour.
        num_classes: number of class-score columns that follow column 4.
        conf_thre: a box is kept when ``score * best_class_score >= conf_thre``.
        nms_thre: IoU threshold passed to torchvision NMS.
        class_agnostic: if True, overlapping boxes suppress each other regardless
            of class; otherwise NMS is applied separately per predicted class.

    Returns:
        A list with one entry per image: ``None`` if no box survives, otherwise a
        ``(K, 7)`` tensor whose rows are
        ``[x1, y1, x2, y2, score, class_conf, class_pred]``.
    """
    # (cx, cy, w, h) -> (x1, y1, x2, y2). The right-hand side is a fresh tensor,
    # so writing it back into the caller's tensor cannot alias the inputs.
    centre = prediction[:, :, :2]
    half_size = prediction[:, :, 2:4] / 2
    prediction[:, :, :4] = torch.cat((centre - half_size, centre + half_size), dim=2)

    output = [None for _ in range(len(prediction))]
    for i, image_pred in enumerate(prediction):
        if not image_pred.size(0):
            continue

        # Best class score and its index for every box; both shaped (N, 1).
        class_conf, class_pred = torch.max(image_pred[:, 5: 5 + num_classes], 1, keepdim=True)

        # Select column 0 explicitly instead of squeeze(): with a single box a bare
        # squeeze() yields a 0-d mask, and boolean indexing with a 0-d mask adds a
        # leading dimension rather than selecting rows.
        conf_mask = image_pred[:, 4] * class_conf[:, 0] >= conf_thre

        detections = torch.cat((image_pred[:, :5], class_conf, class_pred.float()), 1)
        detections = detections[conf_mask]
        if not detections.size(0):
            continue

        scores = detections[:, 4] * detections[:, 5]
        if class_agnostic:
            keep = torchvision.ops.nms(detections[:, :4], scores, nms_thre)
        else:
            # Column 6 holds the predicted class, used to group boxes per class.
            keep = torchvision.ops.batched_nms(detections[:, :4], scores, detections[:, 6], nms_thre)

        # Each image is visited exactly once, so the result is assigned directly.
        output[i] = detections[keep]

    return output
|
|
|