import os
import cv2
import math
import torch
import numpy as np
import gradio as gr
import albumentations
import matplotlib.pyplot as plt
from glob import glob
from PIL import Image
from pytorch_grad_cam import EigenCAM
from models.common import DetectMultiBackend
from albumentations.pytorch import ToTensorV2
from utils.augmentations import letterbox
from utils.plots import Annotator, colors
from pytorch_grad_cam.utils.image import show_cam_on_image, scale_cam_image
from utils.torch_utils import select_device, smart_inference_mode
from utils.general import check_img_size, Profile, non_max_suppression, scale_boxes
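
# Note: EigenCAM and show_cam_on_image only back the commented-out CAM overlay
# in inference() below; albumentations, ToTensorV2, scale_cam_image and
# smart_inference_mode are currently unused and could be dropped.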

weights = "runs/train/best_striped.pt"
data = "data.yaml"

# Load model
device = select_device('cpu')
model = DetectMultiBackend(weights, device=device, dnn=False, data=data, fp16=False)
# target_layers = [model.model.model[-1]]
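# Assumption: for EigenCAM on a YOLO-style model, target_layers is usually a
# late neck/backbone layer rather than the detection head, so something like
# model.model.model[-2] may give cleaner activation maps than the head above.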

false_detection_data = glob(os.path.join("false_detection", '*.jpg'))
false_detection_data = [x.replace('\\', '/') for x in false_detection_data]  # normalise Windows path separators

def resize_image_pil(image, new_width, new_height):
    # Convert to PIL image
    img = Image.fromarray(np.array(image))
    # Get original size
    width, height = img.size
    # Calculate the scale that preserves aspect ratio
    width_scale = new_width / width
    height_scale = new_height / height
    scale = min(width_scale, height_scale)
    # Resize
    resized = img.resize((int(width * scale), int(height * scale)), Image.NEAREST)
    # Crop to exact size
    resized = resized.crop((0, 0, new_width, new_height))
    return resized
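
# Hypothetical usage (resize_image_pil is not called anywhere in this app yet):
#   thumb = resize_image_pil(np.array(Image.open("image_1.jpg")), 640, 640)
# PIL's crop fills regions outside the source with black, so the result is
# always exactly new_width x new_height.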

def display_false_detection_data(false_detection_data, number_of_samples):
    fig = plt.figure(figsize=(10, 10))
    x_count = 5
    # Use ceil so a partial last row still gets subplot slots
    y_count = 1 if number_of_samples <= 5 else math.ceil(number_of_samples / x_count)
    for i in range(number_of_samples):
        plt.subplot(y_count, x_count, i + 1)
        img = cv2.imread(false_detection_data[i])
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        plt.imshow(img)
        plt.xticks([])
        plt.yticks([])
    return fig
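
# Caveat: number_of_samples is assumed to be <= len(false_detection_data);
# otherwise false_detection_data[i] raises IndexError. A guard such as
#   number_of_samples = min(number_of_samples, len(false_detection_data))
# at the top of the function would make this robust.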

def inference(input_img, conf_thres, iou_thres, is_false_detection_images=True, num_false_detection_images=10):
    im0 = input_img.copy()
    rgb_img = cv2.resize(im0, (640, 640))  # kept for the commented-out EigenCAM overlay below
    stride, names, pt = model.stride, model.names, model.pt
    imgsz = check_img_size((640, 640), s=stride)  # check image size
    bs = 1  # batch size

    # Run inference
    model.warmup(imgsz=(1 if pt or model.triton else bs, 3, *imgsz))
    seen, windows, dt = 0, [], (Profile(), Profile(), Profile())
    with dt[0]:
        im = letterbox(input_img, imgsz, stride=stride, auto=True)[0]  # padded resize
        im = im.transpose((2, 0, 1))[::-1]  # HWC to CHW, reverse channel order
        im = np.ascontiguousarray(im)
        im = torch.from_numpy(im).to(model.device)
        im = im.half() if model.fp16 else im.float()  # uint8 to fp16/32
        im /= 255  # 0 - 255 to 0.0 - 1.0
        if len(im.shape) == 3:
            im = im[None]  # expand for batch dim
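    # Note: letterbox resizes with unchanged aspect ratio and pads to a
    # stride-multiple, which is why scale_boxes below can map detections
    # back onto the original-resolution im0.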

    # Inference
    with dt[1]:
        pred = model(im, augment=False, visualize=False)

    # NMS
    with dt[2]:
        pred = non_max_suppression(pred, conf_thres, iou_thres, None, False, max_det=10)
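    # conf_thres/iou_thres come straight from the Gradio sliders; classes=None
    # keeps every class and max_det=10 caps the boxes drawn per image.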

    # Process predictions
    for i, det in enumerate(pred):  # per image
        seen += 1
        annotator = Annotator(im0, line_width=2, example=str(model.names))
        if len(det):
            # Rescale boxes from img_size to im0 size
            det[:, :4] = scale_boxes(im.shape[2:], det[:, :4], im0.shape).round()
            # Write results
            for *xyxy, conf, cls in reversed(det):
                c = int(cls)  # integer class
                label = f'{names[c]} {conf:.2f}'
                annotator.box_label(xyxy, label, color=colors(c, True))

    if is_false_detection_images:
        # Plot the misclassified data
        misclassified_images = display_false_detection_data(false_detection_data, number_of_samples=num_false_detection_images)
    else:
        misclassified_images = None

    # cam = EigenCAM(model, target_layers)
    # grayscale_cam = cam(im)[0, :]
    # cam_image = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True)
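    # To re-enable the CAM overlay (untested sketch): uncomment target_layers at
    # the top and the three lines above, scale rgb_img to float32 in [0, 1]
    # (show_cam_on_image requires that), and return cam_image as a third output.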
    return im0, misclassified_images

title = "YOLOv9 model to detect shirt/t-shirt"
description = "A simple Gradio interface to run inference with a YOLOv9 model and detect t-shirts in an image"
examples = [["image_1.jpg", 0.25, 0.45, True, 10],
            ["image_2.jpg", 0.25, 0.45, True, 10],
            ["image_3.jpg", 0.25, 0.45, True, 10],
            ["image_4.jpg", 0.25, 0.45, True, 10],
            ["image_5.jpg", 0.25, 0.45, True, 10],
            ["image_6.jpg", 0.25, 0.45, True, 10],
            ["image_7.jpg", 0.25, 0.45, True, 10],
            ["image_8.jpg", 0.25, 0.45, True, 10],
            ["image_9.jpg", 0.25, 0.45, True, 10],
            ["image_10.jpg", 0.25, 0.45, True, 10]]

demo = gr.Interface(inference,
                    inputs=[gr.Image(width=320, height=320, label="Input Image"),
                            gr.Slider(0, 1, 0.25, label="Confidence Threshold"),
                            gr.Slider(0, 1, 0.45, label="IoU Threshold"),
                            gr.Checkbox(label="Show False Detections"),
                            gr.Slider(5, 35, value=10, step=5, label="Number of False Detection Images")],
                    outputs=[gr.Image(width=640, height=640, label="Output"),
                             gr.Plot(label="False Detections")],
                    title=title,
                    description=description,
                    examples=examples)
demo.launch()
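# On Hugging Face Spaces the bare launch() above is all that's needed; running
# locally, demo.launch(share=True) would also expose a temporary public URL.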