import os
import cv2
import tqdm
import uuid
import logging

import torch
import spaces
import trackers
import numpy as np
import gradio as gr
import imageio.v3 as iio
import supervision as sv

from pathlib import Path
from functools import lru_cache
from typing import List, Optional, Tuple

from PIL import Image
from transformers import AutoModelForObjectDetection, AutoImageProcessor
from transformers.image_utils import load_image

# Configuration constants
CHECKPOINTS = [
    "ustc-community/dfine-medium-obj2coco",
    "ustc-community/dfine-medium-coco",
    "ustc-community/dfine-medium-obj365",
    "ustc-community/dfine-nano-coco",
    "ustc-community/dfine-small-coco",
    "ustc-community/dfine-large-coco",
    "ustc-community/dfine-xlarge-coco",
    "ustc-community/dfine-small-obj365",
    "ustc-community/dfine-large-obj365",
    "ustc-community/dfine-xlarge-obj365",
    "ustc-community/dfine-small-obj2coco",
    "ustc-community/dfine-large-obj2coco-e25",
    "ustc-community/dfine-xlarge-obj2coco",
]
DEFAULT_CHECKPOINT = CHECKPOINTS[0]
DEFAULT_CONFIDENCE_THRESHOLD = 0.3
TORCH_DTYPE = torch.float32

# Image
IMAGE_EXAMPLES = [
    {"path": "./examples/images/tennis.jpg", "use_url": False, "url": "", "label": "Local Image"},
    {"path": "./examples/images/dogs.jpg", "use_url": False, "url": "", "label": "Local Image"},
    {"path": "./examples/images/nascar.jpg", "use_url": False, "url": "", "label": "Local Image"},
    {"path": "./examples/images/crossroad.jpg", "use_url": False, "url": "", "label": "Local Image"},
    {
        "path": None,
        "use_url": True,
        "url": "https://live.staticflickr.com/65535/33021460783_1646d43c54_b.jpg",
        "label": "Flickr Image",
    },
]

# Video
MAX_NUM_FRAMES = 250
BATCH_SIZE = 4
ALLOWED_VIDEO_EXTENSIONS = {".mp4", ".avi", ".mov"}
VIDEO_OUTPUT_DIR = Path("static/videos")
VIDEO_OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

class TrackingAlgorithm:
    BYTETRACK = "ByteTrack (2021)"
    DEEPSORT = "DeepSORT (2017)"
    SORT = "SORT (2016)"


TRACKERS = [None, TrackingAlgorithm.BYTETRACK, TrackingAlgorithm.DEEPSORT, TrackingAlgorithm.SORT]

VIDEO_EXAMPLES = [
    {"path": "./examples/videos/dogs_running.mp4", "label": "Local Video", "tracker": None, "classes": "all"},
    {"path": "./examples/videos/traffic.mp4", "label": "Local Video", "tracker": TrackingAlgorithm.BYTETRACK, "classes": "car, truck, bus"},
    {"path": "./examples/videos/fast_and_furious.mp4", "label": "Local Video", "tracker": None, "classes": "all"},
    {"path": "./examples/videos/break_dance.mp4", "label": "Local Video", "tracker": None, "classes": "all"},
]

# Create a color palette for visualization
# These hex color codes define different colors for tracking different objects
color = sv.ColorPalette.from_hex([
    "#ffff00", "#ff9b00", "#ff8080", "#ff66b2", "#ff66ff", "#b266ff",
    "#9999ff", "#3399ff", "#66ffff", "#33ff99", "#66ff66", "#99ff00",
])

logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)

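# Assumption: the otherwise-unused `lru_cache` import suggests the loader below was
# decorated and the decorator was lost in extraction. Re-adding a small cache (size
# is a guess) so repeated requests for the same checkpoint don't reload the model.
@lru_cache(maxsize=1)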
| def get_model_and_processor(checkpoint: str): | |
| model = AutoModelForObjectDetection.from_pretrained(checkpoint, torch_dtype=TORCH_DTYPE) | |
| image_processor = AutoImageProcessor.from_pretrained(checkpoint) | |
| return model, image_processor | |
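# Assumption: the otherwise-unused `spaces` import suggests the detection entry point
# carried a ZeroGPU decorator; `spaces.GPU` is the standard form on Hugging Face Spaces.
@spaces.GPU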
def detect_objects(
    checkpoint: str,
    images: List[np.ndarray] | np.ndarray,
    confidence_threshold: float = DEFAULT_CONFIDENCE_THRESHOLD,
    target_size: Optional[Tuple[int, int]] = None,
    batch_size: int = BATCH_SIZE,
    classes: Optional[List[str]] = None,
):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model, image_processor = get_model_and_processor(checkpoint)
    model = model.to(device)

    if classes is not None:
        wrong_classes = [cls for cls in classes if cls not in model.config.label2id]
        if wrong_classes:
            gr.Warning(f"Classes not found in model config: {wrong_classes}")
        keep_ids = [model.config.label2id[cls] for cls in classes if cls in model.config.label2id]
    else:
        keep_ids = None

    if isinstance(images, np.ndarray) and images.ndim == 4:
        images = [x for x in images]  # split video array into a list of images

    batches = [images[i:i + batch_size] for i in range(0, len(images), batch_size)]

    results = []
    for batch in tqdm.tqdm(batches, desc="Processing frames"):

        # preprocess images
        inputs = image_processor(images=batch, return_tensors="pt")
        inputs = inputs.to(device).to(TORCH_DTYPE)

        # forward pass
        with torch.no_grad():
            outputs = model(**inputs)

        # postprocess outputs
        if target_size:
            target_sizes = [target_size] * len(batch)
        else:
            target_sizes = [(image.shape[0], image.shape[1]) for image in batch]
        batch_results = image_processor.post_process_object_detection(
            outputs, target_sizes=target_sizes, threshold=confidence_threshold
        )
        results.extend(batch_results)

    # move results to cpu
    for i, result in enumerate(results):
        results[i] = {k: v.cpu() for k, v in result.items()}
        if keep_ids is not None:
            keep = torch.isin(results[i]["labels"], torch.tensor(keep_ids))
            results[i] = {k: v[keep] for k, v in results[i].items()}

    return results, model.config.id2label

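# Each entry of `results` is a dict with "scores", "labels" and "boxes" tensors
# (already on CPU, boxes as absolute xyxy pixel coordinates for the requested
# target size), returned alongside the model's id2label mapping. A minimal
# standalone sketch, assuming one of the bundled example images:
#
#   image = np.array(Image.open("./examples/images/tennis.jpg"))
#   results, id2label = detect_objects(DEFAULT_CHECKPOINT, [image])
#   for label, score in zip(results[0]["labels"], results[0]["scores"]):
#       print(id2label[label.item()], round(score.item(), 2))
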
def process_image(
    checkpoint: str = DEFAULT_CHECKPOINT,
    image: Optional[Image.Image] = None,
    url: Optional[str] = None,
    use_url: bool = False,
    confidence_threshold: float = DEFAULT_CONFIDENCE_THRESHOLD,
):
    if not use_url:
        url = None

    # Exactly one of `image` / `url` must be set; XOR catches "both" and "neither".
    if (image is None) ^ bool(url):
        raise ValueError("Either image or url must be provided, but not both.")

    if url:
        image = load_image(url)

    results, id2label = detect_objects(
        checkpoint=checkpoint,
        images=[np.array(image)],
        confidence_threshold=confidence_threshold,
    )
    result = results[0]  # first image in batch (we have batch size 1)

    annotations = []
    for label, score, box in zip(result["labels"], result["scores"], result["boxes"]):
        text_label = id2label[label.item()]
        formatted_label = f"{text_label} ({score:.2f})"
        x_min, y_min, x_max, y_max = box.cpu().numpy().round().astype(int)
        x_min = max(0, x_min)
        y_min = max(0, y_min)
        x_max = min(image.width - 1, x_max)
        y_max = min(image.height - 1, y_max)
        annotations.append(((x_min, y_min, x_max, y_max), formatted_label))

    # gr.AnnotatedImage expects (base image, [((x_min, y_min, x_max, y_max), label), ...])
    return (image, annotations)

def get_target_size(image_height, image_width, max_size: int):
    if image_height < max_size and image_width < max_size:
        new_height, new_width = image_height, image_width
    elif image_height > image_width:
        new_height = max_size
        new_width = int(image_width * max_size / image_height)
    else:
        new_width = max_size
        new_height = int(image_height * max_size / image_width)

    # make even (for video codec compatibility)
    new_height = new_height // 2 * 2
    new_width = new_width // 2 * 2

    return new_width, new_height

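# For example, a 1920x1080 frame with max_size=1080 keeps the longer side at 1080
# and scales the other proportionally: int(1080 * 1080 / 1920) = 607, rounded down
# to the nearest even value, so get_target_size(1080, 1920, 1080) returns
# (1080, 606) as (width, height).
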
def read_video_k_frames(video_path: str, k: int, read_every_i_frame: int = 1):
    cap = cv2.VideoCapture(video_path)
    frames = []
    i = 0
    progress_bar = tqdm.tqdm(total=k, desc="Reading frames")
    while cap.isOpened() and len(frames) < k:
        ret, frame = cap.read()
        if not ret:
            break
        if i % read_every_i_frame == 0:
            frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            progress_bar.update(1)
        i += 1
    cap.release()
    progress_bar.close()
    return frames

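# For example, a 50 fps source read with read_every_i_frame=2 keeps every other
# frame, i.e. an effective ~25 fps; with the MAX_NUM_FRAMES=250 cap that covers
# roughly 10 seconds of footage.
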
def get_tracker(tracker: str, fps: float):
    if tracker == TrackingAlgorithm.SORT:
        return trackers.SORTTracker(frame_rate=fps)
    elif tracker == TrackingAlgorithm.DEEPSORT:
        feature_extractor = trackers.DeepSORTFeatureExtractor.from_timm(
            "mobilenetv4_conv_small.e1200_r224_in1k", device="cpu"
        )
        return trackers.DeepSORTTracker(feature_extractor, frame_rate=fps)
    elif tracker == TrackingAlgorithm.BYTETRACK:
        return sv.ByteTrack(frame_rate=int(fps))
    else:
        raise ValueError(f"Invalid tracker: {tracker}")

def update_tracker(tracker, detections, frame):
    tracker_name = tracker.__class__.__name__
    if tracker_name == "SORTTracker":
        return tracker.update(detections)
    elif tracker_name == "DeepSORTTracker":
        return tracker.update(detections, frame)
    elif tracker_name == "ByteTrack":
        return tracker.update_with_detections(detections)
    else:
        raise ValueError(f"Invalid tracker: {tracker}")

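# The three trackers expose slightly different update APIs: SORTTracker takes
# detections only, DeepSORTTracker also needs the raw frame for appearance
# features, and supervision's ByteTrack uses update_with_detections(). The
# dispatch above normalizes them behind a single call.
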
def process_video(
    video_path: str,
    checkpoint: str,
    tracker_algorithm: Optional[str] = None,
    classes: str = "all",
    confidence_threshold: float = DEFAULT_CONFIDENCE_THRESHOLD,
    progress: gr.Progress = gr.Progress(track_tqdm=True),
) -> str:

    if not video_path or not os.path.isfile(video_path):
        raise ValueError(f"Invalid video path: {video_path}")

    ext = os.path.splitext(video_path)[1].lower()
    if ext not in ALLOWED_VIDEO_EXTENSIONS:
        raise ValueError(f"Unsupported video format: {ext}, supported formats: {ALLOWED_VIDEO_EXTENSIONS}")

    video_info = sv.VideoInfo.from_video_path(video_path)
    read_each_i_frame = max(1, video_info.fps // 25)
    target_fps = video_info.fps / read_each_i_frame
    target_width, target_height = get_target_size(video_info.height, video_info.width, 1080)

    n_frames_to_read = min(MAX_NUM_FRAMES, video_info.total_frames // read_each_i_frame)
    frames = read_video_k_frames(video_path, n_frames_to_read, read_each_i_frame)
    frames = [cv2.resize(frame, (target_width, target_height), interpolation=cv2.INTER_CUBIC) for frame in frames]

    # Set the color lookup mode to assign colors by track ID,
    # so objects with the same track ID are annotated in the same color
    color_lookup = sv.ColorLookup.TRACK if tracker_algorithm else sv.ColorLookup.CLASS
    box_annotator = sv.BoxAnnotator(color, color_lookup=color_lookup, thickness=1)
    label_annotator = sv.LabelAnnotator(color, color_lookup=color_lookup, text_scale=0.5)
    trace_annotator = sv.TraceAnnotator(color, color_lookup=color_lookup, thickness=1, trace_length=100)

    # preprocess classes
    if classes != "all":
        classes_list = [cls.strip().lower() for cls in classes.split(",")]
    else:
        classes_list = None

    results, id2label = detect_objects(
        images=np.array(frames),
        checkpoint=checkpoint,
        confidence_threshold=confidence_threshold,
        target_size=(target_height, target_width),
        classes=classes_list,
    )

    annotated_frames = []

    # detections
    if tracker_algorithm:
        tracker = get_tracker(tracker_algorithm, target_fps)
        for frame, result in progress.tqdm(zip(frames, results), desc="Tracking objects", total=len(frames)):
            detections = sv.Detections.from_transformers(result, id2label=id2label)
            detections = detections.with_nms(threshold=0.95, class_agnostic=True)
            detections = update_tracker(tracker, detections, frame)
            labels = [
                f"#{tracker_id} {id2label[class_id]}"
                for class_id, tracker_id in zip(detections.class_id, detections.tracker_id)
            ]
            annotated_frame = box_annotator.annotate(scene=frame, detections=detections)
            annotated_frame = label_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels)
            annotated_frame = trace_annotator.annotate(scene=annotated_frame, detections=detections)
            annotated_frames.append(annotated_frame)
    else:
        for frame, result in tqdm.tqdm(zip(frames, results), desc="Annotating frames", total=len(frames)):
            detections = sv.Detections.from_transformers(result, id2label=id2label)
            detections = detections.with_nms(threshold=0.95, class_agnostic=True)
            annotated_frame = box_annotator.annotate(scene=frame, detections=detections)
            annotated_frame = label_annotator.annotate(scene=annotated_frame, detections=detections)
            annotated_frames.append(annotated_frame)

    output_filename = os.path.join(VIDEO_OUTPUT_DIR, f"output_{uuid.uuid4()}.mp4")
    iio.imwrite(output_filename, annotated_frames, fps=target_fps, codec="h264")
    return output_filename

def create_image_inputs() -> List[gr.components.Component]:
    return [
        gr.Image(
            label="Upload Image",
            type="pil",
            sources=["upload", "webcam"],
            interactive=True,
            elem_classes="input-component",
        ),
        gr.Checkbox(label="Use Image URL Instead", value=False),
        gr.Textbox(
            label="Image URL",
            placeholder="https://example.com/image.jpg",
            visible=False,
            elem_classes="input-component",
        ),
        gr.Dropdown(
            choices=CHECKPOINTS,
            label="Select Model Checkpoint",
            value=DEFAULT_CHECKPOINT,
            elem_classes="input-component",
        ),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=DEFAULT_CONFIDENCE_THRESHOLD,
            step=0.1,
            label="Confidence Threshold",
            elem_classes="input-component",
        ),
    ]


def create_video_inputs() -> List[gr.components.Component]:
    return [
        gr.Video(
            label="Upload Video",
            sources=["upload"],
            interactive=True,
            format="mp4",  # Ensure MP4 format
            elem_classes="input-component",
        ),
        gr.Dropdown(
            choices=CHECKPOINTS,
            label="Select Model Checkpoint",
            value=DEFAULT_CHECKPOINT,
            elem_classes="input-component",
        ),
        gr.Dropdown(
            choices=TRACKERS,
            label="Select Tracker (Optional)",
            value=None,
            elem_classes="input-component",
        ),
        gr.TextArea(
            label="Specify Class Names to Detect (comma separated)",
            value="all",
            lines=1,
            elem_classes="input-component",
        ),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=DEFAULT_CONFIDENCE_THRESHOLD,
            step=0.1,
            label="Confidence Threshold",
            elem_classes="input-component",
        ),
    ]


def create_button_row() -> List[gr.Button]:
    return [
        gr.Button("Detect Objects", variant="primary", elem_classes="action-button"),
        gr.Button("Clear", variant="secondary", elem_classes="action-button"),
    ]

# Gradio interface
with gr.Blocks(theme=gr.themes.Ocean()) as demo:
    gr.Markdown(
        """
        # Object Detection Demo
        Experience state-of-the-art object detection with USTC's [D-FINE](https://huggingface.co/docs/transformers/main/model_doc/d_fine) models.
        - **Image** and **Video** modes are supported.
        - Select a model and adjust the confidence threshold to see detections!
        - In video mode, you can enable tracking powered by [Supervision](https://github.com/roboflow/supervision) and [Trackers](https://github.com/roboflow/trackers) from Roboflow.
        """,
        elem_classes="header-text",
    )

    with gr.Tabs():
        with gr.Tab("Image"):
            with gr.Row():
                with gr.Column(scale=1, min_width=300):
                    with gr.Group():
                        (
                            image_input,
                            use_url,
                            url_input,
                            image_model_checkpoint,
                            image_confidence_threshold,
                        ) = create_image_inputs()
                    image_detect_button, image_clear_button = create_button_row()

                with gr.Column(scale=2):
                    image_output = gr.AnnotatedImage(
                        label="Detection Results",
                        show_label=True,
                        color_map=None,
                        elem_classes="output-component",
                    )

            gr.Examples(
                examples=[
                    [
                        DEFAULT_CHECKPOINT,
                        example["path"],
                        example["url"],
                        example["use_url"],
                        DEFAULT_CONFIDENCE_THRESHOLD,
                    ]
                    for example in IMAGE_EXAMPLES
                ],
                inputs=[
                    image_model_checkpoint,
                    image_input,
                    url_input,
                    use_url,
                    image_confidence_threshold,
                ],
                outputs=[image_output],
                fn=process_image,
                label="Select an image example to populate inputs",
                cache_examples=True,
                cache_mode="lazy",
            )

        with gr.Tab("Video"):
            gr.Markdown(
                f"The input video will be processed at ~25 FPS (up to {MAX_NUM_FRAMES} frames in the result)."
            )
            with gr.Row():
                with gr.Column(scale=1, min_width=300):
                    with gr.Group():
                        video_input, video_checkpoint, video_tracker, video_classes, video_confidence_threshold = create_video_inputs()
                    video_detect_button, video_clear_button = create_button_row()

                with gr.Column(scale=2):
                    video_output = gr.Video(
                        label="Detection Results",
                        format="mp4",  # Explicit MP4 format
                        elem_classes="output-component",
                    )

            gr.Examples(
                examples=[
                    [example["path"], DEFAULT_CHECKPOINT, example["tracker"], example["classes"], DEFAULT_CONFIDENCE_THRESHOLD]
                    for example in VIDEO_EXAMPLES
                ],
                inputs=[video_input, video_checkpoint, video_tracker, video_classes, video_confidence_threshold],
                outputs=[video_output],
                fn=process_video,
                cache_examples=False,
                label="Select a video example to populate inputs",
            )

    # Dynamic visibility for URL input
    use_url.change(
        fn=lambda x: gr.update(visible=x),
        inputs=use_url,
        outputs=url_input,
    )

    # Image clear button
    image_clear_button.click(
        fn=lambda: (
            None,
            False,
            "",
            DEFAULT_CHECKPOINT,
            DEFAULT_CONFIDENCE_THRESHOLD,
            None,
        ),
        outputs=[
            image_input,
            use_url,
            url_input,
            image_model_checkpoint,
            image_confidence_threshold,
            image_output,
        ],
    )

    # Video clear button
    video_clear_button.click(
        fn=lambda: (
            None,
            DEFAULT_CHECKPOINT,
            None,
            "all",
            DEFAULT_CONFIDENCE_THRESHOLD,
            None,
        ),
        outputs=[
            video_input,
            video_checkpoint,
            video_tracker,
            video_classes,
            video_confidence_threshold,
            video_output,
        ],
    )

    # Image detect button
    image_detect_button.click(
        fn=process_image,
        inputs=[
            image_model_checkpoint,
            image_input,
            url_input,
            use_url,
            image_confidence_threshold,
        ],
        outputs=[image_output],
    )

    # Video detect button
    video_detect_button.click(
        fn=process_video,
        inputs=[video_input, video_checkpoint, video_tracker, video_classes, video_confidence_threshold],
        outputs=[video_output],
    )


if __name__ == "__main__":
    demo.queue(max_size=20).launch()