#!/usr/bin/env python
"""A demo of the DAB-DETR model.

This code is based on the implementation from the Colab notebook:
https://colab.research.google.com/drive/1e8fcby5rhKZWcr9LSN8mNbQ0TU4Dxxpo
"""

import pathlib
import tempfile

import cv2
import gradio as gr
import numpy as np
import PIL.Image
import spaces
import supervision as sv
import torch
import tqdm.auto
from transformers import AutoModelForObjectDetection, AutoProcessor

DESCRIPTION = "# DAB-DETR"

MAX_NUM_FRAMES = 300

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

checkpoint = "IDEA-Research/dab-detr-resnet-50-dc5-pat3"
image_processor = AutoProcessor.from_pretrained(checkpoint)
model = AutoModelForObjectDetection.from_pretrained(checkpoint, device_map=device)
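
# NOTE: AutoModelForObjectDetection resolves the concrete detector class for
# this checkpoint from its config, so the code keeps working if the checkpoint
# is swapped for another transformers object-detection model.
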

@spaces.GPU
def process_image(image: PIL.Image.Image) -> tuple[PIL.Image.Image, list[dict]]:
    inputs = image_processor(images=image, return_tensors="pt").to(device)
    with torch.inference_mode():
        outputs = model(**inputs)
    results = image_processor.post_process_object_detection(
        outputs, target_sizes=torch.tensor([(image.height, image.width)]), threshold=0.3
    )
    result = results[0]  # take first image results

    detections = sv.Detections(xyxy=result["boxes"].cpu().numpy())
    bounding_box_annotator = sv.BoxAnnotator(color=sv.Color.WHITE, color_lookup=sv.ColorLookup.INDEX, thickness=1)
    # annotate bounding boxes
    annotated_frame = bounding_box_annotator.annotate(scene=image.copy(), detections=detections)

    # report each detection as a JSON-serializable dict for the gr.JSON output
    labels = [
        {
            "label": model.config.id2label[int(label)],
            "score": round(float(score), 4),
            "box": [round(float(v), 2) for v in box],
        }
        for score, label, box in zip(result["scores"], result["labels"], result["boxes"])
    ]
    return annotated_frame, labels
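
# Quick local check without the Gradio UI (hypothetical file paths):
#   image = PIL.Image.open("images/example.jpg")
#   annotated, detections = process_image(image)
#   annotated.save("annotated.jpg")
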

@spaces.GPU
def process_video(
    video_path: str,
    progress: gr.Progress = gr.Progress(track_tqdm=True),  # noqa: ARG001, B008
) -> str:
    cap = cv2.VideoCapture(video_path)

    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    fps = cap.get(cv2.CAP_PROP_FPS)
    num_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as out_file:
        writer = cv2.VideoWriter(out_file.name, fourcc, fps, (width, height))
        for _ in tqdm.auto.tqdm(range(min(MAX_NUM_FRAMES, num_frames))):
            ok, frame = cap.read()
            if not ok:
                break
            # OpenCV decodes frames as BGR; convert to RGB for the model
            rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            annotated_frame, _ = process_image(PIL.Image.fromarray(rgb_frame))
            # convert the annotated RGB frame back to BGR before writing
            writer.write(cv2.cvtColor(np.asarray(annotated_frame), cv2.COLOR_RGB2BGR))
        writer.release()
    cap.release()
    return out_file.name
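
# NOTE: the "mp4v" fourcc avoids extra codec dependencies, but some browsers
# cannot play MPEG-4 Part 2 streams in the gr.Video player; re-encoding the
# output to H.264 (e.g. with ffmpeg) is a common workaround if playback fails.
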

with gr.Blocks(css_paths="style.css") as demo:
    gr.Markdown(DESCRIPTION)

    with gr.Tabs():
        with gr.Tab("Image"):
            with gr.Row():
                with gr.Column():
                    input_image = gr.Image(label="Input Image", type="pil")
                    run_button_image = gr.Button()
                with gr.Column():
                    output_image = gr.Image(label="Output Image")
                    output_json = gr.JSON(label="Output JSON")
            gr.Examples(
                examples=sorted(pathlib.Path("images").glob("*.jpg")),
                inputs=input_image,
                outputs=[output_image, output_json],
                fn=process_image,
            )

            run_button_image.click(
                fn=process_image,
                inputs=input_image,
                outputs=[output_image, output_json],
            )
        with gr.Tab("Video"):
            gr.Markdown(f"The input video will be truncated to {MAX_NUM_FRAMES} frames.")
            with gr.Row():
                with gr.Column():
                    input_video = gr.Video(label="Input Video")
                    run_button_video = gr.Button()
                with gr.Column():
                    output_video = gr.Video(label="Output Video")
            gr.Examples(
                examples=sorted(pathlib.Path("videos").glob("*.mp4")),
                inputs=input_video,
                outputs=output_video,
                fn=process_video,
                cache_examples=False,
            )

            run_button_video.click(
                fn=process_video,
                inputs=input_video,
                outputs=output_video,
            )

if __name__ == "__main__":
    demo.queue(max_size=20).launch()