import os

import gradio as gr
import numpy as np
import spaces
import supervision as sv
import torch
from tqdm import tqdm
from transformers import (
    AutoProcessor,
    RTDetrForObjectDetection,
    VitPoseForPoseEstimation,
)

from render import draw_links, draw_points, keypoint_colors, link_colors

css = """
.feedback textarea {font-size: 24px !important}
"""
| device = "cuda" | |


def calculate_end_frame_index(source_video_path):
    video_info = sv.VideoInfo.from_video_path(source_video_path)
    return video_info.total_frames


# `spaces.GPU` requests a GPU when the demo is hosted on ZeroGPU hardware; it is
# a no-op in other environments.
@spaces.GPU
def process_image(
    input_image,
    model_variant,
    progress=gr.Progress(track_tqdm=True),
):
    # The person detector below can be swapped for any other detection model.
    person_image_processor = AutoProcessor.from_pretrained(
        "PekingU/rtdetr_r50vd_coco_o365"
    )
    person_model = RTDetrForObjectDetection.from_pretrained(
        "PekingU/rtdetr_r50vd_coco_o365", device_map=device
    )

    if model_variant == "Base":
        model_name = "yonigozlan/synthpose-vitpose-base-hf"
    else:
        model_name = "yonigozlan/synthpose-vitpose-huge-hf"
    image_processor = AutoProcessor.from_pretrained(model_name)
    model = VitPoseForPoseEstimation.from_pretrained(model_name, device_map=device)
    keypoint_edges = model.config.edges

    frame = np.array(input_image)

    # ------------------------------------------------------------------------
    # Stage 1. Detect humans in the image
    # ------------------------------------------------------------------------
    inputs = person_image_processor(images=frame, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = person_model(**inputs)

    results = person_image_processor.post_process_object_detection(
        outputs,
        target_sizes=torch.tensor([(frame.shape[0], frame.shape[1])]),
        threshold=0.4,
    )
    result = results[0]  # take results for the first (only) image

    # Label 0 corresponds to "person" in the COCO label set.
    person_boxes = result["boxes"][result["labels"] == 0]
    person_boxes = person_boxes.cpu().numpy()

    # Convert boxes from VOC (x1, y1, x2, y2) to COCO (x1, y1, w, h) format,
    # e.g. (10, 20, 110, 220) -> (10, 20, 100, 200).
    person_boxes[:, 2] = person_boxes[:, 2] - person_boxes[:, 0]
    person_boxes[:, 3] = person_boxes[:, 3] - person_boxes[:, 1]

    # ------------------------------------------------------------------------
    # Stage 2. Detect keypoints for each person found
    # ------------------------------------------------------------------------
    if len(person_boxes) == 0:
        # No people detected: return the image unchanged.
        return frame

    inputs = image_processor(frame, boxes=[person_boxes], return_tensors="pt").to(
        device
    )
    with torch.no_grad():
        outputs = model(**inputs)

    pose_results = image_processor.post_process_pose_estimation(
        outputs, boxes=[person_boxes]
    )
    image_pose_result = pose_results[0]  # results for the first (only) image

    for pose_result in image_pose_result:
        scores = np.array(pose_result["scores"])
        keypoints = np.array(pose_result["keypoints"])

        # Draw each detected keypoint.
        draw_points(
            frame,
            keypoints,
            scores,
            keypoint_colors,
            keypoint_score_threshold=0.3,
            radius=max(2, int(max(frame.shape[0], frame.shape[1]) / 500)),
            show_keypoint_weight=False,
        )

        # Draw the links between keypoints.
        draw_links(
            frame,
            keypoints,
            scores,
            keypoint_edges,
            link_colors,
            keypoint_score_threshold=0.3,
            thickness=max(2, int(max(frame.shape[0], frame.shape[1]) / 1000)),
            show_keypoint_weight=False,
        )

    return frame
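

# A minimal sketch of calling process_image outside Gradio; the image path below
# is only an illustrative assumption, not an asset bundled with this Space:
#
#     from PIL import Image
#     annotated = process_image(Image.open("example.jpg"), model_variant="Base")
#     Image.fromarray(annotated).save("annotated.jpg")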


# As above, `spaces.GPU` only has an effect on ZeroGPU hardware; long videos may
# need a larger `duration` than the decorator's default GPU allocation.
@spaces.GPU
def process_video(
    input_video,
    model_variant,
    progress=gr.Progress(track_tqdm=True),
):
    video_info = sv.VideoInfo.from_video_path(input_video)
    total = calculate_end_frame_index(input_video)
    frame_generator = sv.get_video_frames_generator(source_path=input_video, end=total)

    result_file_name = "output.mp4"
    result_file_path = os.path.join(os.getcwd(), result_file_name)

    # The person detector below can be swapped for any other detection model.
    person_image_processor = AutoProcessor.from_pretrained(
        "PekingU/rtdetr_r50vd_coco_o365"
    )
    person_model = RTDetrForObjectDetection.from_pretrained(
        "PekingU/rtdetr_r50vd_coco_o365", device_map=device
    )

    if model_variant == "Base":
        model_name = "yonigozlan/synthpose-vitpose-base-hf"
    else:
        model_name = "yonigozlan/synthpose-vitpose-huge-hf"
    image_processor = AutoProcessor.from_pretrained(model_name)
    model = VitPoseForPoseEstimation.from_pretrained(model_name, device_map=device)
    keypoint_edges = model.config.edges

    with sv.VideoSink(result_file_path, video_info=video_info) as sink:
        for _ in tqdm(range(total), desc="Processing video..."):
            try:
                frame = next(frame_generator)
            except StopIteration:
                break

            # ------------------------------------------------------------------
            # Stage 1. Detect humans in the frame
            # ------------------------------------------------------------------
            inputs = person_image_processor(images=frame, return_tensors="pt").to(
                device
            )
            with torch.no_grad():
                outputs = person_model(**inputs)

            results = person_image_processor.post_process_object_detection(
                outputs,
                target_sizes=torch.tensor([(frame.shape[0], frame.shape[1])]),
                threshold=0.4,
            )
            result = results[0]  # take results for the first (only) image

            # Label 0 corresponds to "person" in the COCO label set.
            person_boxes = result["boxes"][result["labels"] == 0]
            person_boxes = person_boxes.cpu().numpy()

            # Convert boxes from VOC (x1, y1, x2, y2) to COCO (x1, y1, w, h) format
            person_boxes[:, 2] = person_boxes[:, 2] - person_boxes[:, 0]
            person_boxes[:, 3] = person_boxes[:, 3] - person_boxes[:, 1]

            # ------------------------------------------------------------------
            # Stage 2. Detect keypoints for each person found
            # ------------------------------------------------------------------
            if len(person_boxes) == 0:
                # No people detected: write the frame unchanged.
                sink.write_frame(frame)
                continue

            inputs = image_processor(
                frame, boxes=[person_boxes], return_tensors="pt"
            ).to(device)
            with torch.no_grad():
                outputs = model(**inputs)

            pose_results = image_processor.post_process_pose_estimation(
                outputs, boxes=[person_boxes]
            )
            image_pose_result = pose_results[0]  # results for the first (only) image

            for pose_result in image_pose_result:
                scores = np.array(pose_result["scores"])
                keypoints = np.array(pose_result["keypoints"])

                # Draw each detected keypoint.
                draw_points(
                    frame,
                    keypoints,
                    scores,
                    keypoint_colors,
                    keypoint_score_threshold=0.3,
                    radius=max(2, int(frame.shape[0] / 500)),
                    show_keypoint_weight=False,
                )

                # Draw the links between keypoints.
                draw_links(
                    frame,
                    keypoints,
                    scores,
                    keypoint_edges,
                    link_colors,
                    keypoint_score_threshold=0.3,
                    thickness=max(1, int(frame.shape[0] / 1000)),
                    show_keypoint_weight=False,
                )

            sink.write_frame(frame)

    return result_file_path
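

# A minimal sketch of calling process_video outside Gradio, using one of the
# bundled example clips; the annotated file is written to the working directory:
#
#     out_path = process_video("./tennis.mp4", model_variant="Huge")
#     print(f"Annotated video written to {out_path}")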


with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
    gr.Markdown("## Markerless Motion Capture with SynthPose")
    gr.Markdown(
        """
SynthPose is a new approach that uses synthetic data to fine-tune pre-trained 2D human pose models so they can predict an arbitrarily denser set of keypoints for accurate kinematic analysis.
More details are available in [OpenCapBench: A Benchmark to Bridge Pose Estimation and Biomechanics](https://arxiv.org/abs/2406.09788).<br />
This particular variant was fine-tuned on a set of keypoints commonly used in motion capture setups, and it predicts the COCO keypoints as well.<br />
The keypoints that form the skeleton are the COCO keypoints, while the pink ones are the additional anatomical markers.
"""
    )
    gr.Markdown(
        "Simply upload a video or an image, then press Run to start the inference! You can also try the examples below."
    )

    with gr.Row():
        with gr.Column():
            input_choice = gr.Radio(
                ["Video", "Image"], label="Input Type", value="Video", interactive=True
            )
            model_variant = gr.Radio(
                ["Base", "Huge"], label="Model Variant", value="Base", interactive=True
            )
            input_video = gr.Video(label="Input Video")
            input_image = gr.Image(label="Input Image", visible=False)
        with gr.Column():
            output_video = gr.Video(label="Output Video")
            output_image = gr.Image(label="Output Image", visible=False)

    with gr.Row():
        submit_video = gr.Button(variant="primary")
        submit_image = gr.Button(variant="primary", visible=False)

    def switch_input_type(input_choice):
        # Toggle the visibility of the video/image widgets in lockstep.
        if input_choice == "Video":
            return [
                gr.update(visible=True),   # input_video
                gr.update(visible=False),  # input_image
                gr.update(visible=True),   # output_video
                gr.update(visible=False),  # output_image
                gr.update(visible=True),   # submit_video
                gr.update(visible=False),  # submit_image
            ]
        else:
            return [
                gr.update(visible=False),  # input_video
                gr.update(visible=True),   # input_image
                gr.update(visible=False),  # output_video
                gr.update(visible=True),   # output_image
                gr.update(visible=False),  # submit_video
                gr.update(visible=True),   # submit_image
            ]

    input_choice.change(
        switch_input_type,
        inputs=input_choice,
        outputs=[
            input_video,
            input_image,
            output_video,
            output_image,
            submit_video,
            submit_image,
        ],
    )

    example = gr.Examples(
        examples=[
            ["./tennis.mp4"],
            ["./football.mp4"],
            ["./basket.mp4"],
            ["./hurdles.mp4"],
        ],
        inputs=[input_video],
        outputs=output_video,
    )

    submit_video.click(
        fn=process_video,
        inputs=[input_video, model_variant],
        outputs=[output_video],
    )
    submit_image.click(
        fn=process_image,
        inputs=[input_image, model_variant],
        outputs=[output_image],
    )


if __name__ == "__main__":
    demo.launch(show_error=True)