"""Gradio demo: run a YOLOv9 shirt/t-shirt detector on an uploaded image.

The app returns the input image annotated with detection boxes and,
optionally, a grid of previously collected false-detection samples.
"""
import os
import cv2
import math
import torch
import numpy as np
import gradio as gr
import albumentations
import matplotlib.pyplot as plt
from matplotlib.figure import Figure
from glob import glob
from PIL import Image
from pytorch_grad_cam import EigenCAM
from models.common import DetectMultiBackend
from albumentations.pytorch import ToTensorV2
from utils.augmentations import letterbox
from utils.plots import Annotator, colors
from pytorch_grad_cam.utils.image import show_cam_on_image, scale_cam_image
from utils.torch_utils import select_device, smart_inference_mode
from utils.general import check_img_size, Profile, non_max_suppression, scale_boxes

weights = "runs/train/best_striped.pt"
data = "data.yaml"

# Load model
device = select_device('cpu')
model = DetectMultiBackend(weights, device=device, dnn=False, data=data, fp16=False)
# target_layers = [model.model.model[-1]]

# The inference size never changes, so validate it against the model stride
# and warm the model up once at start-up instead of on every request.
IMG_SIZE = check_img_size((640, 640), s=model.stride)  # check image size
model.warmup(imgsz=(1, 3, *IMG_SIZE))

# Sorted so the same samples are shown on every run (glob order is
# filesystem-dependent); backslashes are normalised for Windows paths.
false_detection_data = sorted(glob(os.path.join("false_detection", '*.jpg')))
false_detection_data = [x.replace('\\', '/') for x in false_detection_data]


def resize_image_pil(image, new_width, new_height):
    """Scale *image* to fit inside (new_width, new_height), then crop to that size.

    The aspect ratio is preserved while scaling (nearest-neighbour). The final
    crop box is always (0, 0, new_width, new_height), so when the scaled image
    is smaller than the target along one axis the crop extends past its edge.

    Args:
        image: Anything ``np.array`` can turn into an image array.
        new_width: Target width in pixels.
        new_height: Target height in pixels.

    Returns:
        A ``PIL.Image.Image`` of size (new_width, new_height).
    """
    # Convert to PIL image
    img = Image.fromarray(np.array(image))

    # Use the smaller ratio so the whole image fits inside the target box
    width, height = img.size
    scale = min(new_width / width, new_height / height)

    # Resize
    resized = img.resize((int(width * scale), int(height * scale)), Image.NEAREST)

    # Crop to exact size
    return resized.crop((0, 0, new_width, new_height))


def display_false_detection_data(false_detection_data, number_of_samples):
    """Build a 5-column grid figure of false-detection sample images.

    Args:
        false_detection_data: Sequence of image file paths.
        number_of_samples: How many images to show. May be a float (e.g. when
            it comes from a UI slider); it is truncated to an int and capped
            at the number of available paths.

    Returns:
        A matplotlib ``Figure``. It is empty when there is nothing to show.
    """
    # Built directly rather than through pyplot: pyplot keeps every figure in
    # a global registry until it is closed, which leaks in a long-running
    # server that creates one figure per request.
    fig = Figure(figsize=(10, 10))

    count = min(int(number_of_samples), len(false_detection_data))
    if count <= 0:
        return fig

    x_count = 5
    # Round up so a partially filled last row still has a slot in the grid
    y_count = math.ceil(count / x_count)

    for i, path in enumerate(false_detection_data[:count]):
        img = cv2.imread(path)
        if img is None:
            # cv2.imread returns None for unreadable files; leave the cell empty
            continue
        ax = fig.add_subplot(y_count, x_count, i + 1)
        ax.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
        ax.set_xticks([])
        ax.set_yticks([])
    return fig


def inference(input_img, conf_thres, iou_thres, is_false_detection_images=True, num_false_detection_images=10):
    """Detect objects in *input_img* and draw labelled boxes on a copy of it.

    Args:
        input_img: H x W x 3 uint8 image array as delivered by ``gr.Image``
            (RGB channel order).
        conf_thres: Confidence threshold passed to non-max suppression.
        iou_thres: IoU threshold passed to non-max suppression.
        is_false_detection_images: Whether to also build the false-detection grid.
        num_false_detection_images: Number of false-detection samples to plot.

    Returns:
        Tuple ``(annotated_image, figure_or_None)``.

    Raises:
        gr.Error: If no input image was provided.
    """
    if input_img is None:
        raise gr.Error("Please provide an input image.")

    im0 = input_img.copy()
    names = model.names

    # Pre-process: padded resize, HWC -> CHW, scale to [0, 1], add batch dim.
    # No channel reversal here: that step belongs to pipelines that load
    # images with cv2 (BGR), whereas Gradio already hands us RGB.
    im = letterbox(input_img, IMG_SIZE, stride=model.stride, auto=True)[0]  # padded resize
    im = np.ascontiguousarray(im.transpose((2, 0, 1)))
    im = torch.from_numpy(im).to(model.device)
    im = im.half() if model.fp16 else im.float()  # uint8 to fp16/32
    im /= 255  # 0 - 255 to 0.0 - 1.0
    if len(im.shape) == 3:
        im = im[None]  # expand for batch dim

    # Inference + NMS; gradients are not needed for prediction
    with torch.no_grad():
        pred = model(im, augment=False, visualize=False)
        pred = non_max_suppression(pred, conf_thres, iou_thres, None, False, max_det=10)

    # Process predictions
    for det in pred:  # per image (batch size is 1)
        annotator = Annotator(im0, line_width=2, example=str(names))
        if len(det):
            # Rescale boxes from the network input size to the im0 size
            det[:, :4] = scale_boxes(im.shape[2:], det[:, :4], im0.shape).round()

            # Write results
            for *xyxy, conf, cls in reversed(det):
                c = int(cls)  # integer class
                label = f'{names[c]} {conf:.2f}'
                annotator.box_label(xyxy, label, color=colors(c, True))
        # Take the image back from the annotator instead of relying on it
        # having drawn in place on im0.
        im0 = annotator.result()

    if is_false_detection_images:
        # Plot the misclassified data
        misclassified_images = display_false_detection_data(
            false_detection_data, number_of_samples=num_false_detection_images)
    else:
        misclassified_images = None

    # TODO: EigenCAM visualisation is currently disabled.
    # rgb_img = cv2.resize(im0, (640, 640))
    # cam = EigenCAM(model, target_layers)
    # grayscale_cam = cam(im)[0, :]
    # cam_image = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True)
    return im0, misclassified_images


title = "YOLOv9 model to detect shirt/tshirt"
description = "A simple Gradio interface to infer on YOLOv9 model and detect tshirt in image"
# image_1.jpg ... image_10.jpg, all with the same default settings
examples = [[f"image_{i}.jpg", 0.25, 0.45, True, 10] for i in range(1, 11)]

demo = gr.Interface(
    inference,
    inputs=[
        gr.Image(width=320, height=320, label="Input Image"),
        gr.Slider(0, 1, 0.25, label="Confidence Threshold"),
        gr.Slider(0, 1, 0.45, label="IoU Thresold"),
        gr.Checkbox(label="Show False Detection"),
        gr.Slider(5, 35, value=10, step=5, label="Number of False Detection"),
    ],
    outputs=[
        gr.Image(width=640, height=640, label="Output"),
        gr.Plot(label="False Detection"),
    ],
    title=title,
    description=description,
    examples=examples,
)
demo.launch()