import torch
import numpy as np
import gradio as gr
from PIL import Image
from models.common import DetectMultiBackend
from utils.augmentations import letterbox
from utils.plots import Annotator, colors
from utils.torch_utils import select_device, smart_inference_mode
from utils.general import check_img_size, Profile, non_max_suppression, scale_boxes

weights = "runs/train/best_striped.pt"
data = "data.yaml"
def resize_image_pil(image, new_width, new_height):
    # Convert to PIL image
    img = Image.fromarray(np.array(image))

    # Get original size
    width, height = img.size

    # Calculate scale so the image fits inside the target box
    width_scale = new_width / width
    height_scale = new_height / height
    scale = min(width_scale, height_scale)

    # Resize, preserving aspect ratio
    resized = img.resize((int(width * scale), int(height * scale)), Image.NEAREST)

    # Crop (and zero-pad) to the exact target size
    resized = resized.crop((0, 0, new_width, new_height))
    return resized
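
# Hypothetical usage sketch for the helper above (it is not called elsewhere
# in this app; the file name is assumed for illustration):
#   thumbnail = resize_image_pil(Image.open("image_1.jpg"), 640, 640)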
@smart_inference_mode()  # run without gradient tracking
def inference(input_img, conf_thres, iou_thres):
    im0 = input_img.copy()

    # Load model
    device = select_device('')  # pick GPU if available, else CPU
    model = DetectMultiBackend(weights, device=device, dnn=False, data=data, fp16=False)
    stride, names, pt = model.stride, model.names, model.pt
    imgsz = check_img_size((640, 640), s=stride)  # check image size is a multiple of stride
    bs = 1  # batch size

    # Run inference
    model.warmup(imgsz=(1 if pt or model.triton else bs, 3, *imgsz))
    seen, windows, dt = 0, [], (Profile(), Profile(), Profile())
    with dt[0]:
        im = letterbox(input_img, imgsz, stride=stride, auto=pt)[0]  # resize and pad to model input size
        im = im.transpose((2, 0, 1))  # HWC to CHW
        im = np.ascontiguousarray(im)
        im = torch.from_numpy(im).to(model.device)
        im = im.half() if model.fp16 else im.float()  # uint8 to fp16/32
        im /= 255  # 0 - 255 to 0.0 - 1.0
        if len(im.shape) == 3:
            im = im[None]  # expand for batch dim

    # Inference
    with dt[1]:
        pred = model(im, augment=False, visualize=False)

    # NMS
    with dt[2]:
        pred = non_max_suppression(pred, conf_thres, iou_thres, None, False, max_det=10)

    # Process predictions
    for i, det in enumerate(pred):  # per image
        seen += 1
        annotator = Annotator(im0, line_width=2, example=str(model.names))
        if len(det):
            # Rescale boxes from img_size to im0 size
            det[:, :4] = scale_boxes(im.shape[2:], det[:, :4], im0.shape).round()

            # Write results
            for *xyxy, conf, cls in reversed(det):
                c = int(cls)  # integer class
                label = f'{names[c]} {conf:.2f}'
                annotator.box_label(xyxy, label, color=colors(c, True))
    return im0
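
# Quick standalone sanity check, assuming a local test image named
# "image_1.jpg" (normally the Gradio app below drives this function):
#   out = inference(np.array(Image.open("image_1.jpg").convert("RGB")), 0.25, 0.45)
#   Image.fromarray(out).save("annotated.jpg")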
title = "YOLOv9 model to detect shirt/t-shirt"
description = "A simple Gradio interface for running inference with a YOLOv9 model to detect shirts and t-shirts in an image"
examples = [["image_1.jpg", 0.25, 0.45], ["image_2.jpg", 0.25, 0.45],
            ["image_3.jpg", 0.25, 0.45], ["image_4.jpg", 0.25, 0.45],
            ["image_5.jpg", 0.25, 0.45], ["image_6.jpg", 0.25, 0.45],
            ["image_7.jpg", 0.25, 0.45], ["image_8.jpg", 0.25, 0.45],
            ["image_9.jpg", 0.25, 0.45], ["image_10.jpg", 0.25, 0.45]]
demo = gr.Interface(inference,
                    inputs=[gr.Image(width=320, height=320, label="Input Image"),
                            gr.Slider(0, 1, 0.25, label="Confidence Threshold"),
                            gr.Slider(0, 1, 0.45, label="IoU Threshold")],
                    outputs=[gr.Image(width=640, height=640, label="Output")],
                    title=title,
                    description=description,
                    examples=examples)
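
# The Interface is defined but never started; a Hugging Face Space (or a
# local run) needs an explicit launch call. Minimal sketch with default
# server settings:
demo.launch()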