#!/usr/bin/env python

import pathlib
import tempfile

import cv2
import gradio as gr
import numpy as np
import PIL.Image
import spaces
import supervision as sv
import torch
import tqdm.auto
from transformers import AutoProcessor, RTDetrForObjectDetection, VitPoseForPoseEstimation
| DESCRIPTION = """ | |
| # ViTPose | |
| <div style="display: flex; gap: 10px;"> | |
| <a href="https://huggingface.co/docs/transformers/en/model_doc/vitpose"> | |
| <img src="https://img.shields.io/badge/Huggingface-FFD21E?style=flat&logo=Huggingface&logoColor=black" alt="Huggingface"> | |
| </a> | |
| <a href="https://arxiv.org/abs/2204.12484"> | |
| <img src="https://img.shields.io/badge/Arvix-B31B1B?style=flat&logo=arXiv&logoColor=white" alt="Paper"> | |
| </a> | |
| <a href="https://github.com/ViTAE-Transformer/ViTPose"> | |
| <img src="https://img.shields.io/badge/Github-100000?style=flat&logo=github&logoColor=white" alt="Github"> | |
| </a> | |
| </div> | |
| ViTPose is a state-of-the-art human pose estimation model based on Vision Transformers (ViT). It employs a standard, non-hierarchical ViT backbone and a simple decoder head to predict keypoint heatmaps from images. Despite its simplicity, ViTPose achieves top results on the MS COCO Keypoint Detection benchmark. | |
| ViTPose++ further improves performance with a mixture-of-experts (MoE) module and extensive pre-training. The model is scalable, flexible, and demonstrates strong transferability across pose estimation tasks. | |
| **Key features:** | |
| - PyTorch implementation | |
| - Scalable model size (100M to 1B parameters) | |
| - Flexible training and inference | |
| - State-of-the-art accuracy on challenging benchmarks | |
| """ | |
COLORS = [
    "#A351FB",
    "#FF4040",
    "#FFA1A0",
    "#FF7633",
    "#FFB633",
    "#D1D435",
    "#4CFB12",
    "#94CF1A",
    "#40DE8A",
    "#1B9640",
    "#00D6C1",
    "#2E9CAA",
    "#00C4FF",
    "#364797",
    "#6675FF",
    "#0019EF",
    "#863AFF",
]
COLORS = [sv.Color.from_hex(color_hex=c) for c in COLORS]
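# One color per COCO keypoint (17 in total), consumed by the vertex label annotator below.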
MAX_NUM_FRAMES = 300
keypoint_score = 0.3
enable_labels_annotator = True
enable_vertices_annotator = True

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

person_detector_name = "PekingU/rtdetr_r50vd_coco_o365"
person_image_processor = AutoProcessor.from_pretrained(person_detector_name)
person_model = RTDetrForObjectDetection.from_pretrained(person_detector_name, device_map=device)

pose_model_name = "usyd-community/vitpose-base-simple"
pose_image_processor = AutoProcessor.from_pretrained(pose_model_name)
pose_model = VitPoseForPoseEstimation.from_pretrained(pose_model_name, device_map=device)
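# Top-down pipeline: RT-DETR first detects person bounding boxes, then ViTPose estimates
# keypoints for each detected box.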
@spaces.GPU
def detect_pose_image(
    image: PIL.Image.Image,
    threshold: float = 0.3,
    enable_labels_annotator: bool = True,
    enable_vertices_annotator: bool = True,
) -> tuple[PIL.Image.Image, list[dict]]:
    """Detects persons and estimates their poses in a single image.

    Args:
        image (PIL.Image.Image): Input image in which to detect persons and estimate poses.
        threshold (float): Confidence threshold for pose keypoints.
        enable_labels_annotator (bool): Whether to enable annotating labels for pose keypoints.
        enable_vertices_annotator (bool): Whether to enable annotating vertices for pose keypoints.

    Returns:
        tuple[PIL.Image.Image, list[dict]]:
            - Annotated image with bounding boxes and pose keypoints drawn.
            - List of dictionaries containing human-readable pose estimation results for each detected person.
    """
    inputs = person_image_processor(images=image, return_tensors="pt").to(device)
    outputs = person_model(**inputs)
    results = person_image_processor.post_process_object_detection(
        outputs, target_sizes=torch.tensor([(image.height, image.width)]), threshold=threshold
    )
    result = results[0]  # take first image results
    detections = sv.Detections.from_transformers(result)
    person_detections_xywh = sv.xyxy_to_xywh(detections[detections.class_id == 0].xyxy)
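    # The ViTPose image processor expects person boxes in COCO (x, y, width, height) format,
    # hence the xyxy -> xywh conversion above; class_id == 0 is the detector's "person" class.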
    inputs = pose_image_processor(image, boxes=[person_detections_xywh], return_tensors="pt").to(device)

    # for vitpose-plus-base checkpoint we should additionally provide dataset_index
    # to specify which MOE experts to use for inference
    if pose_model.config.backbone_config.num_experts > 1:
        dataset_index = torch.tensor([0] * len(inputs["pixel_values"]))
        dataset_index = dataset_index.to(inputs["pixel_values"].device)
        inputs["dataset_index"] = dataset_index

    outputs = pose_model(**inputs)

    pose_results = pose_image_processor.post_process_pose_estimation(outputs, boxes=[person_detections_xywh])
    image_pose_result = pose_results[0]  # results for first image

    # make results more human-readable
    human_readable_results = []
    person_pose_labels = []
    for i, person_pose in enumerate(image_pose_result):
        data = {
            "person_id": i,
            "bbox": person_pose["bbox"].numpy().tolist(),
            "keypoints": [],
        }
        for keypoint, label, score in zip(
            person_pose["keypoints"], person_pose["labels"], person_pose["scores"], strict=True
        ):
            keypoint_name = pose_model.config.id2label[label.item()]
            person_pose_labels.append(keypoint_name)
            x, y = keypoint
            data["keypoints"].append({"name": keypoint_name, "x": x.item(), "y": y.item(), "score": score.item()})
        human_readable_results.append(data)
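    # Scale annotation thickness and text size to the input resolution before drawing.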
    line_thickness = sv.calculate_optimal_line_thickness(resolution_wh=(image.width, image.height))
    text_scale = sv.calculate_optimal_text_scale(resolution_wh=(image.width, image.height))

    edge_annotator = sv.EdgeAnnotator(color=sv.Color.WHITE, thickness=line_thickness)
    vertex_annotator = sv.VertexAnnotator(color=sv.Color.BLUE, radius=3)
    box_annotator = sv.BoxAnnotator(color=sv.Color.WHITE, color_lookup=sv.ColorLookup.INDEX, thickness=3)
    vertex_label_annotator = sv.VertexLabelAnnotator(
        color=COLORS, smart_position=True, border_radius=3, text_thickness=2, text_scale=text_scale
    )

    annotated_frame = box_annotator.annotate(scene=image.copy(), detections=detections)

    for person_pose in image_pose_result:
        person_keypoints = sv.KeyPoints.from_transformers([person_pose])
        person_labels = [pose_model.config.id2label[label.item()] for label in person_pose["labels"]]

        # annotate edges and vertices for this person
        annotated_frame = edge_annotator.annotate(scene=annotated_frame, key_points=person_keypoints)

        # annotate labels for this person
        if enable_labels_annotator:
            annotated_frame = vertex_label_annotator.annotate(
                scene=np.array(annotated_frame), key_points=person_keypoints, labels=person_labels
            )

        # annotate vertices for this person
        if enable_vertices_annotator:
            annotated_frame = vertex_annotator.annotate(scene=annotated_frame, key_points=person_keypoints)

    return annotated_frame, human_readable_results
# Decorate this function with `@spaces.GPU` to ensure that ZeroGPU is allocated once for the entire
# video processing. Although `detect_pose_image` (called per frame) is already decorated, without this
# decorator ZeroGPU would be invoked for each frame, causing significant overhead and slowdowns. By
# decorating this function, all frames are processed sequentially after a single GPU allocation.
@spaces.GPU
def detect_pose_video(
    video_path: str,
    threshold: float,
    enable_labels_annotator: bool = True,
    enable_vertices_annotator: bool = True,
    progress: gr.Progress = gr.Progress(track_tqdm=True),  # noqa: ARG001, B008
) -> str:
    """Detects persons and estimates their poses for each frame in a video, saving the annotated video.

    Args:
        video_path (str): Path to the input video file.
        threshold (float): Confidence threshold for pose keypoints.
        enable_labels_annotator (bool): Whether to enable annotating labels for pose keypoints.
        enable_vertices_annotator (bool): Whether to enable annotating vertices for pose keypoints.
        progress (gr.Progress, optional): Gradio progress tracker. Defaults to gr.Progress(track_tqdm=True).

    Returns:
        str: Path to the output video file with annotated bounding boxes and pose keypoints.
    """
    cap = cv2.VideoCapture(video_path)

    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    fps = cap.get(cv2.CAP_PROP_FPS)
    num_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as out_file:
        writer = cv2.VideoWriter(out_file.name, fourcc, fps, (width, height))
        for _ in tqdm.auto.tqdm(range(min(MAX_NUM_FRAMES, num_frames))):
            ok, frame = cap.read()
            if not ok:
                break
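            # OpenCV decodes frames in BGR order; reverse the channel axis to get RGB for PIL here,
            # and convert back to BGR when writing the annotated frame below.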
            rgb_frame = frame[:, :, ::-1]
            annotated_frame, _ = detect_pose_image(
                PIL.Image.fromarray(rgb_frame),
                threshold=threshold,
                enable_labels_annotator=enable_labels_annotator,
                enable_vertices_annotator=enable_vertices_annotator,
            )
            writer.write(np.asarray(annotated_frame)[:, :, ::-1])
        writer.release()
    cap.release()
    return out_file.name
with gr.Blocks(css_paths="style.css") as demo:
    gr.Markdown(DESCRIPTION)

    keypoint_score = gr.Slider(
        minimum=0.0,
        maximum=1.0,
        value=0.6,
        step=0.01,
        info="Adjust the confidence threshold for keypoint detection.",
        label="Keypoint Score Threshold",
    )
    enable_labels_annotator = gr.Checkbox(interactive=True, value=True, label="Enable Labels")
    enable_vertices_annotator = gr.Checkbox(interactive=True, value=True, label="Enable Vertices")

    with gr.Tabs():
        with gr.Tab("Image"):
            with gr.Row():
                with gr.Column():
                    input_image = gr.Image(label="Input Image", type="pil")
                    run_button_image = gr.Button()
                with gr.Column():
                    output_image = gr.Image(label="Output Image")
                    output_json = gr.JSON(label="Output JSON")
            gr.Examples(
                examples=[[str(img), 0.5, True, True] for img in sorted(pathlib.Path("images").glob("*.jpg"))],
                inputs=[input_image, keypoint_score, enable_labels_annotator, enable_vertices_annotator],
                outputs=[output_image, output_json],
                fn=detect_pose_image,
            )

            run_button_image.click(
                fn=detect_pose_image,
                inputs=[input_image, keypoint_score, enable_labels_annotator, enable_vertices_annotator],
                outputs=[output_image, output_json],
            )

        with gr.Tab("Video"):
            gr.Markdown(f"The input video will be truncated to {MAX_NUM_FRAMES} frames.")
            with gr.Row():
                with gr.Column():
                    input_video = gr.Video(label="Input Video")
                    run_button_video = gr.Button()
                with gr.Column():
                    output_video = gr.Video(label="Output Video")
            gr.Examples(
                examples=[[str(video), 0.5, True, True] for video in sorted(pathlib.Path("videos").glob("*.mp4"))],
                inputs=[input_video, keypoint_score, enable_labels_annotator, enable_vertices_annotator],
                outputs=output_video,
                fn=detect_pose_video,
                cache_examples=False,
            )

            run_button_video.click(
                fn=detect_pose_video,
                inputs=[input_video, keypoint_score, enable_labels_annotator, enable_vertices_annotator],
                outputs=output_video,
            )

if __name__ == "__main__":
    demo.launch(mcp_server=True, ssr_mode=False)
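
# A minimal sketch of calling this demo programmatically with `gradio_client`, assuming the Space id
# placeholder below is replaced with the real one and that Gradio exposes its default API names derived
# from the function names ("/detect_pose_image", "/detect_pose_video"):
#
#     from gradio_client import Client, handle_file
#
#     client = Client("<space-id>")  # hypothetical Space id, e.g. "user/vitpose-demo"
#     annotated_image, keypoints_json = client.predict(
#         handle_file("person.jpg"),  # input image
#         0.3,                        # keypoint score threshold
#         True,                       # enable labels
#         True,                       # enable vertices
#         api_name="/detect_pose_image",
#     )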