import gradio as gr
import torch
import numpy as np
import cv2
from PIL import Image
import supervision as sv
from transformers import (
    RTDetrForObjectDetection,
    RTDetrImageProcessor,
    VitPoseConfig,
    VitPoseForPoseEstimation,
    VitPoseImageProcessor,
)

KEYPOINT_LABEL_MAP = {
    0: "Nose",
    1: "L_Eye",
    2: "R_Eye",
    3: "L_Ear",
    4: "R_Ear",
    5: "L_Shoulder",
    6: "R_Shoulder",
    7: "L_Elbow",
    8: "R_Elbow",
    9: "L_Wrist",
    10: "R_Wrist",
    11: "L_Hip",
    12: "R_Hip",
    13: "L_Knee",
    14: "R_Knee",
    15: "L_Ankle",
    16: "R_Ankle",
}
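# The indices above follow the standard 17-keypoint COCO ordering
# (nose, eyes, ears, shoulders, elbows, wrists, hips, knees, ankles),
# which is the order the pose model returns keypoints in.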


class KeypointDetector:
    def __init__(self):
        self.person_detector = None
        self.person_processor = None
        self.pose_model = None
        self.pose_processor = None
        self.load_models()

    def load_models(self):
        """Load all required models"""
        # Person detection model (RT-DETR)
        self.person_processor = RTDetrImageProcessor.from_pretrained("PekingU/rtdetr_r50vd_coco_o365")
        self.person_detector = RTDetrForObjectDetection.from_pretrained("PekingU/rtdetr_r50vd_coco_o365")
        # Pose estimation model (ViTPose)
        self.pose_processor = VitPoseImageProcessor.from_pretrained("nielsr/vitpose-base-simple")
        self.pose_model = VitPoseForPoseEstimation.from_pretrained("nielsr/vitpose-base-simple")
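        # Optional GPU placement (a sketch, not part of the original app):
        #   device = "cuda" if torch.cuda.is_available() else "cpu"
        #   self.person_detector.to(device)
        #   self.pose_model.to(device)
        # The tensors built in detect_persons/detect_keypoints would then also
        # need to be moved to the same device.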

    @staticmethod
    def pascal_voc_to_coco(bboxes: np.ndarray) -> np.ndarray:
        """Convert Pascal VOC boxes (x1, y1, x2, y2) to COCO format (x, y, w, h)"""
        bboxes = bboxes.copy()  # copy to avoid modifying the input in place
        bboxes[:, 2] = bboxes[:, 2] - bboxes[:, 0]
        bboxes[:, 3] = bboxes[:, 3] - bboxes[:, 1]
        return bboxes

    @staticmethod
    def coco_to_xyxy(bboxes: np.ndarray) -> np.ndarray:
        """Convert COCO boxes (x, y, w, h) to xyxy format (x1, y1, x2, y2)"""
        bboxes = bboxes.copy()
        bboxes[:, 2] = bboxes[:, 0] + bboxes[:, 2]
        bboxes[:, 3] = bboxes[:, 1] + bboxes[:, 3]
        return bboxes
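    # Worked example for the two helpers above (values chosen for illustration):
    #   pascal_voc_to_coco([[10, 20, 110, 220]]) -> [[10, 20, 100, 200]]  (x, y, w, h)
    #   coco_to_xyxy([[10, 20, 100, 200]])       -> [[10, 20, 110, 220]]  (x1, y1, x2, y2)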

    def detect_persons(self, image: Image.Image):
        """Detect persons in the image and return their boxes (xyxy) and scores"""
        inputs = self.person_processor(images=image, return_tensors="pt")
        with torch.no_grad():
            outputs = self.person_detector(**inputs)
        results = self.person_processor.post_process_object_detection(
            outputs,
            target_sizes=torch.tensor([(image.height, image.width)]),
            threshold=0.3,
        )
        dets = sv.Detections.from_transformers(results[0]).with_nms(0.5)
        # Keep only the person class (label 0 in the model's COCO label mapping)
        boxes = dets.xyxy[dets.class_id == 0]
        scores = dets.confidence[dets.class_id == 0]
        return boxes, scores

    def detect_keypoints(self, image: Image.Image):
        """Detect persons, then estimate keypoints for each detected person"""
        # Detect persons first
        boxes, scores = self.detect_persons(image)
        # The pose processor expects one list of COCO-format boxes per image
        boxes_coco = [self.pascal_voc_to_coco(boxes)]
        # Detect pose keypoints
        pixel_values = self.pose_processor(image, boxes=boxes_coco, return_tensors="pt").pixel_values
        with torch.no_grad():
            outputs = self.pose_model(pixel_values)
        pose_results = self.pose_processor.post_process_pose_estimation(outputs, boxes=boxes_coco)[0]
        return pose_results, boxes, scores
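    # In recent transformers releases, each entry of pose_results is a dict with
    # "keypoints" (tensor of shape (17, 2)) and "scores" (tensor of shape (17,)),
    # one entry per detected person; the methods below assume that layout.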

    def visualize_detections(self, image: Image.Image, pose_results, boxes, scores):
        """Visualize both bounding boxes and keypoints on the image"""
        # Convert image to a numpy array if needed
        image_array = np.array(image)
        # Set up detections for the bounding boxes
        detections = sv.Detections(
            xyxy=boxes,
            confidence=scores,
            class_id=np.array([0] * len(scores)),
        )
        # Box annotator for the person boxes
        box_annotator = sv.BoxAnnotator(
            color=sv.ColorPalette.DEFAULT,
            thickness=2,
        )
        # Edge annotator for the keypoint skeleton
        edge_annotator = sv.EdgeAnnotator(
            color=sv.Color.GREEN,
            thickness=3,
        )
        # Convert keypoints to supervision format: (num_persons, num_keypoints, 2)
        key_points = sv.KeyPoints(
            xy=torch.cat([pose_result["keypoints"].unsqueeze(0) for pose_result in pose_results]).cpu().numpy()
        )
        # Annotate the image with boxes first
        annotated_frame = box_annotator.annotate(
            scene=image_array.copy(),
            detections=detections,
        )
        # Then add the keypoint edges on top
        annotated_frame = edge_annotator.annotate(
            scene=annotated_frame,
            key_points=key_points,
        )
        return Image.fromarray(annotated_frame)

    def process_image(self, input_image):
        """Process an image and return the visualization plus a text summary"""
        if input_image is None:
            return None, ""
        # Convert to a PIL Image if necessary
        if isinstance(input_image, np.ndarray):
            image = Image.fromarray(input_image)
        else:
            image = input_image
        # Detect keypoints and boxes
        pose_results, boxes, scores = self.detect_keypoints(image)
        # Visualize results
        result_image = self.visualize_detections(image, pose_results, boxes, scores)
        # Build the detection information text
        info_text = []
        for i, (box, score) in enumerate(zip(boxes, scores)):
            info_text.append(f"\nPerson {i + 1} (confidence: {score:.2f})")
            info_text.append(f"Bounding Box: x1={box[0]:.1f}, y1={box[1]:.1f}, x2={box[2]:.1f}, y2={box[3]:.1f}")
            # Keypoint information for this person; coordinates come from "keypoints",
            # per-keypoint confidences from the separate "scores" tensor
            pose_result = pose_results[i]
            for j, (keypoint, kp_score) in enumerate(zip(pose_result["keypoints"], pose_result["scores"])):
                x, y = keypoint
                info_text.append(f"Keypoint {KEYPOINT_LABEL_MAP[j]}: x={x:.1f}, y={y:.1f}, confidence={kp_score:.2f}")
        return result_image, "\n".join(info_text)


def create_gradio_interface():
    """Create the Gradio interface"""
    detector = KeypointDetector()

    with gr.Blocks() as interface:
        gr.Markdown("# Human Detection and Keypoint Estimation using ViTPose")
        gr.Markdown("Upload an image to detect people and their keypoints. The model will:")
        gr.Markdown("1. Detect people in the image (shown as bounding boxes)")
        gr.Markdown("2. Identify keypoints for each detected person (shown as connected green lines)")
        gr.Markdown("Huge shoutout to @NielsRogge and @SangbumChoi for this work!")

        with gr.Row():
            with gr.Column():
                input_image = gr.Image(label="Input Image")
                process_button = gr.Button("Detect People & Keypoints")
            with gr.Column():
                output_image = gr.Image(label="Detection Results")
                detection_info = gr.Textbox(
                    label="Detection Information",
                    lines=10,
                    placeholder="Detection details will appear here...",
                )

        process_button.click(
            fn=detector.process_image,
            inputs=input_image,
            outputs=[output_image, detection_info],
        )

        gr.Examples(
            examples=[
                "http://images.cocodataset.org/val2017/000000000139.jpg"
            ],
            inputs=input_image,
        )

    return interface


if __name__ == "__main__":
    interface = create_gradio_interface()
    interface.launch()
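
# Programmatic usage sketch (the image path below is hypothetical):
#
#   detector = KeypointDetector()
#   image = Image.open("example.jpg").convert("RGB")
#   annotated, info = detector.process_image(image)
#   annotated.save("annotated.jpg")
#   print(info)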