Spaces:
Sleeping
Sleeping
| import os | |
| import cv2 | |
| import math | |
| import torch | |
| import numpy as np | |
| import gradio as gr | |
| import matplotlib.pyplot as plt | |
| from PIL import Image | |
| from glob import glob | |
| from utils.plots import Annotator, colors | |
| from utils.augmentations import letterbox | |
| from models.common import DetectMultiBackend | |
| from utils.general import non_max_suppression, scale_boxes | |
| from utils.torch_utils import select_device, smart_inference_mode | |
| from pytorch_grad_cam import EigenCAM | |
| import torchvision.transforms as transforms | |
| from pytorch_grad_cam.utils.image import show_cam_on_image, scale_cam_image | |
# Checkpoint to load and the dataset description handed to the backend.
weights = "runs/train/best_striped.pt"
data = "data.yaml"

# Load model
device = select_device('cpu')
model = DetectMultiBackend(weights=weights, device=device, fp16=False, data=data)

# Layer whose activations EigenCAM visualises: the second-to-last entry of
# the wrapped network's module list.
target_layers = [model.model.model[-2]]

# Paths of the stored false-detection samples, normalised to forward slashes.
false_detection_data = [
    path.replace('\\', '/')
    for path in glob(os.path.join("false_detection", '*.jpg'))
]
def resize_image_pil(image, new_width, new_height):
    """Scale *image* to fit inside ``new_width`` x ``new_height`` and crop.

    The image is converted to a PIL image, scaled by the smaller of the two
    width/height ratios (so the aspect ratio is kept) using nearest-neighbour
    sampling, and then cropped with the box ``(0, 0, new_width, new_height)``.

    Returns the resulting PIL image.
    """
    source = Image.fromarray(np.array(image))
    src_w, src_h = source.size

    # The smaller ratio keeps both sides within the requested size.
    factor = min(new_width / src_w, new_height / src_h)
    target_size = (int(src_w * factor), int(src_h * factor))

    shrunk = source.resize(target_size, Image.NEAREST)

    # The crop box is always the full requested size, anchored at the
    # top-left corner, even when one scaled side is shorter than requested.
    return shrunk.crop((0, 0, new_width, new_height))
def display_false_detection_data(false_detection_data, number_of_samples):
    """Plot the first *number_of_samples* false-detection images in a grid.

    Args:
        false_detection_data: Sequence of image file paths readable by OpenCV.
        number_of_samples: How many images to show. The value is truncated to
            an int and clamped to the number of available paths.

    Returns:
        The matplotlib figure holding a grid that is five columns wide.
    """
    x_count = 5

    # UI sliders may hand over floats, and the folder may hold fewer images
    # than requested; never index past what is actually available.
    sample_count = max(0, min(int(number_of_samples), len(false_detection_data)))

    # Round up so a partially filled last row still gets grid cells
    # (rounding down made e.g. 7 samples request cell 6 of a 1x5 grid).
    y_count = max(1, math.ceil(sample_count / x_count))

    fig = plt.figure(figsize=(10, 10))
    for i, image_path in enumerate(false_detection_data[:sample_count]):
        img = cv2.imread(image_path)
        if img is None:
            # cv2.imread returns None for unreadable files; leave the cell
            # empty instead of failing the whole gallery.
            continue
        ax = fig.add_subplot(y_count, x_count, i + 1)
        ax.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
        ax.set_xticks([])
        ax.set_yticks([])
    return fig
def inference(input_img, conf_thres, iou_thres, is_eigen_cam=True, is_false_detection_images=True, num_false_detection_images=10):
    """Run detection on one image and build the optional visualisations.

    Args:
        input_img: Image array supplied by the Gradio image input (H x W x 3).
        conf_thres: Confidence threshold passed to non-max suppression.
        iou_thres: IoU threshold passed to non-max suppression.
        is_eigen_cam: When true, also compute an EigenCAM overlay.
        is_false_detection_images: When true, also plot stored false detections.
        num_false_detection_images: Number of false-detection samples to plot.

    Returns:
        A tuple ``(annotated_image, cam_image, misclassified_images)``; the
        last two are ``None`` when the corresponding option is switched off.

    Raises:
        gr.Error: If no input image was provided.
    """
    if input_img is None:
        # Submitting the form without an image hands over None.
        raise gr.Error("Please provide an input image.")

    stride, names = model.stride, model.names

    # Load image
    img0 = input_img.copy()
    img = letterbox(img0, 640, stride=stride, auto=True)[0]
    # Only reorder HWC -> CHW. The Gradio image input already delivers RGB,
    # so the BGR->RGB channel flip used for cv2-loaded images would feed the
    # detector BGR data; the EigenCAM branch below uses the same array
    # unflipped as RGB.
    img = np.ascontiguousarray(img.transpose(2, 0, 1))
    img = torch.from_numpy(img).to(device).float()
    img /= 255.0
    if img.ndimension() == 3:
        img = img.unsqueeze(0)

    # Inference + NMS; no gradients are needed for plain detection.
    with torch.no_grad():
        pred = model(img, augment=False, visualize=False)
        pred = non_max_suppression(pred, conf_thres, iou_thres, classes=None, max_det=1000)

    # Process predictions
    annotator = Annotator(img0, line_width=2, example=str(names))
    for det in pred:  # per image
        if len(det):
            # Rescale boxes from img_size to im0 size
            det[:, :4] = scale_boxes(img.shape[2:], det[:, :4], img0.shape).round()
            # Write results
            for *xyxy, conf, cls in reversed(det):
                c = int(cls)  # integer class
                label = f'{names[c]} {conf:.2f}'
                annotator.box_label(xyxy, label, color=colors(c, True))
    # Ask the annotator for the final image rather than relying on it having
    # drawn into img0 in place.
    annotated = annotator.result()

    if is_false_detection_images:
        # Plot the misclassified data
        misclassified_images = display_false_detection_data(false_detection_data, number_of_samples=num_false_detection_images)
    else:
        misclassified_images = None

    if is_eigen_cam:
        img_GC = np.float32(cv2.resize(input_img, (640, 640))) / 255
        tensor = transforms.ToTensor()(img_GC).unsqueeze(0)
        # The context manager releases the hooks EigenCAM registers on the
        # target layers, so they do not pile up across requests.
        with EigenCAM(model, target_layers) as cam:
            grayscale_cam = cam(tensor)[0, :, :]
        cam_image = show_cam_on_image(img_GC, grayscale_cam, use_rgb=True)
    else:
        cam_image = None

    return annotated, cam_image, misclassified_images
title = "YOLOv9 model to detect shirt/tshirt"
description = "A simple Gradio interface to infer on YOLOv9 model and detect tshirt in image"

# One example row per bundled sample image (image_1.jpg ... image_10.jpg),
# all with the same thresholds and both visualisations enabled.
examples = [[f"image_{i}.jpg", 0.25, 0.45, True, True, 10] for i in range(1, 11)]

# Input order must match the positional parameters of `inference`.
demo = gr.Interface(
    inference,
    inputs=[
        gr.Image(width=320, height=320, label="Input Image"),
        gr.Slider(0, 1, 0.25, label="Confidence Threshold"),
        gr.Slider(0, 1, 0.45, label="IoU Threshold"),
        gr.Checkbox(label="Show Eigen CAM"),
        gr.Checkbox(label="Show False Detection"),
        gr.Slider(5, 35, value=10, step=5, label="Number of False Detection"),
    ],
    outputs=[
        gr.Image(width=640, height=640, label="Output"),
        gr.Image(label="EigenCAM"),
        gr.Plot(label="False Detection"),
    ],
    title=title,
    description=description,
    examples=examples,
)
demo.launch()