import torch
import numpy as np
import gradio as gr
from PIL import Image

from models.common import DetectMultiBackend
from utils.augmentations import letterbox  # padded resize used by the YOLOv9 dataloaders
from utils.plots import Annotator, colors
from utils.torch_utils import select_device, smart_inference_mode
from utils.general import check_img_size, Profile, non_max_suppression, scale_boxes

weights = "runs/train/best_striped.pt"
data = "data.yaml"


def resize_image_pil(image, new_width, new_height):
    # Convert to PIL image
    img = Image.fromarray(np.array(image))

    # Get original size
    width, height = img.size

    # Calculate the scale that fits the image inside the target size
    width_scale = new_width / width
    height_scale = new_height / height
    scale = min(width_scale, height_scale)

    # Resize while keeping the aspect ratio
    resized = img.resize((int(width * scale), int(height * scale)), Image.NEAREST)

    # Crop to the exact size (PIL pads with black where the box extends past the image)
    resized = resized.crop((0, 0, new_width, new_height))
    return resized


@smart_inference_mode()
def inference(input_img, conf_thres, iou_thres):
    im0 = input_img.copy()

    # Load model (reloaded on every request for simplicity; loading once at
    # module scope would be faster)
    device = select_device("")  # empty string picks CUDA if available, else CPU
    model = DetectMultiBackend(weights, device=device, dnn=False, data=data, fp16=False)
    stride, names, pt = model.stride, model.names, model.pt
    imgsz = check_img_size((640, 640), s=stride)  # verify image size is a stride multiple
    bs = 1  # batch size

    # Run inference
    model.warmup(imgsz=(1 if pt or model.triton else bs, 3, *imgsz))
    seen, windows, dt = 0, [], (Profile(), Profile(), Profile())
    with dt[0]:
        im = letterbox(input_img, imgsz, stride=stride, auto=pt)[0]  # padded resize
        im = np.ascontiguousarray(im.transpose((2, 0, 1)))  # HWC to CHW (Gradio already gives RGB)
        im = torch.from_numpy(im).to(model.device)
        im = im.half() if model.fp16 else im.float()  # uint8 to fp16/32
        im /= 255  # 0 - 255 to 0.0 - 1.0
        if len(im.shape) == 3:
            im = im[None]  # expand for batch dim

    # Inference
    with dt[1]:
        pred = model(im, augment=False, visualize=False)

    # NMS
    with dt[2]:
        pred = non_max_suppression(pred, conf_thres, iou_thres, None, False, max_det=10)

    # Process predictions
    for i, det in enumerate(pred):  # per image
        seen += 1
        annotator = Annotator(im0, line_width=2, example=str(names))
        if len(det):
            # Rescale boxes from img_size to im0 size
            det[:, :4] = scale_boxes(im.shape[2:], det[:, :4], im0.shape).round()

            # Write results
            for *xyxy, conf, cls in reversed(det):
                c = int(cls)  # integer class
                label = f"{names[c]} {conf:.2f}"
                annotator.box_label(xyxy, label, color=colors(c, True))

    return im0


title = "YOLOv9 model to detect shirts/t-shirts"
description = "A simple Gradio interface to run a YOLOv9 model and detect shirts/t-shirts in an image"
examples = [["image_1.jpg", 0.25, 0.45],
            ["image_2.jpg", 0.25, 0.45],
            ["image_3.jpg", 0.25, 0.45],
            ["image_4.jpg", 0.25, 0.45],
            ["image_5.jpg", 0.25, 0.45],
            ["image_6.jpg", 0.25, 0.45],
            ["image_7.jpg", 0.25, 0.45],
            ["image_8.jpg", 0.25, 0.45],
            ["image_9.jpg", 0.25, 0.45],
            ["image_10.jpg", 0.25, 0.45]]

demo = gr.Interface(inference,
                    inputs=[gr.Image(width=320, height=320, label="Input Image"),
                            gr.Slider(0, 1, 0.25, label="Confidence Threshold"),
                            gr.Slider(0, 1, 0.45, label="IoU Threshold")],
                    outputs=[gr.Image(width=640, height=640, label="Output")],
                    title=title,
                    description=description,
                    examples=examples)
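
# The Interface above is defined but never started, so the app would do nothing
# when run. A minimal launch guard, assuming a plain local run with Gradio's
# default queue/share settings (on Hugging Face Spaces a bare demo.launch()
# at module level is the usual pattern):
if __name__ == "__main__":
    demo.launch()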