|
|
|
|
|
import spaces |
|
|
import os |
|
|
import sys |
|
|
import importlib.util |
|
|
import re |
|
|
import gradio as gr |
|
|
from PIL import Image |
|
|
import torch |
|
|
import requests |
|
|
import shutil |
|
|
|
|
|
|
|
|
# Report the CUDA setup at startup. This is informational only, so a failure
# while probing the GPU is reported instead of aborting the app.
try:
    print(f"CUDA available: {torch.cuda.is_available()}")
    print(f"CUDA version: {torch.version.cuda}")
    # Presumably the call that fails on CPU-only hosts (no device 0).
    print(f"GPU device: {torch.cuda.get_device_name(0)}")
except Exception:
    # Narrowed from a bare `except:` so Ctrl-C / SystemExit still propagate.
    print('CUDA is not available !')
|
|
|
|
|
|
|
|
# Import mmdet from a patched copy of its __init__ source. Everything from the
# `mmcv_minimum_version` line up to `__all__` is stripped. That span is
# presumably mmdet's mmcv version-compatibility assertion; confirm against the
# installed mmdet. This lets mmdet import alongside the mmcv installed here.
spec = importlib.util.find_spec('mmdet')
if spec and spec.origin:
    # Read via a context manager so the file handle is closed promptly.
    with open(spec.origin, encoding='utf-8') as fh:
        src = fh.read()
    patched = re.sub(r'(?ms)^[ \t]*mmcv_minimum_version.*?^__all__', '__all__', src)
    m = importlib.util.module_from_spec(spec)
    m.__loader__ = spec.loader
    m.__file__ = spec.origin
    # Mark it as a package so `mmdet.<submodule>` imports keep working.
    m.__path__ = spec.submodule_search_locations
    # Register before executing so mmdet's own internal imports resolve to
    # this module object rather than triggering a fresh, unpatched import.
    sys.modules['mmdet'] = m
    try:
        exec(compile(patched, spec.origin, 'exec'), m.__dict__)
    except BaseException:
        # Mirror the import system: don't leave a half-initialised module
        # registered if executing the patched source fails.
        sys.modules.pop('mmdet', None)
        raise
|
|
|
|
|
from mmpose.apis.inferencers import MMPoseInferencer |
|
|
|
|
|
|
|
|
# Selectable checkpoints: key -> download URL. The keys populate the UI
# dropdown, and get_checkpoint() uses them as the cached file name.
REMOTE_CHECKPOINTS = {
    # COCO checkpoints (OpenMMLab).
    "rtmo-s_8xb32-600e_coco": "https://download.openmmlab.com/mmpose/v1/projects/rtmo/rtmo-s_8xb32-600e_coco-640x640-8db55a59_20231211.pth",
    "rtmo-m_16xb16-600e_coco": "https://download.openmmlab.com/mmpose/v1/projects/rtmo/rtmo-m_16xb16-600e_coco-640x640-6f4e0306_20231211.pth",
    "rtmo-l_16xb16-600e_coco": "https://download.openmmlab.com/mmpose/v1/projects/rtmo/rtmo-l_16xb16-600e_coco-640x640-516a421f_20231211.pth",

    # body7 checkpoints (OpenMMLab).
    "rtmo-t_8xb32-600e_body7": "https://download.openmmlab.com/mmpose/v1/projects/rtmo/rtmo-t_8xb32-600e_body7-416x416-f48f75cb_20231219.pth",
    "rtmo-s_8xb32-600e_body7": "https://download.openmmlab.com/mmpose/v1/projects/rtmo/rtmo-s_8xb32-600e_body7-640x640-dac2bf74_20231211.pth",
    "rtmo-m_16xb16-600e_body7": "https://download.openmmlab.com/mmpose/v1/projects/rtmo/rtmo-m_16xb16-600e_body7-640x640-39e78cc4_20231211.pth",
    "rtmo-l_16xb16-600e_body7": "https://download.openmmlab.com/mmpose/v1/projects/rtmo/rtmo-l_16xb16-600e_body7-640x640-b37118ce_20231211.pth",

    # CrowdPose checkpoints (OpenMMLab).
    "rtmo-s_8xb32-700e_crowdpose": "https://download.openmmlab.com/mmpose/v1/projects/rtmo/rtmo-s_8xb32-700e_crowdpose-640x640-79f81c0d_20231211.pth",
    # File name previously had a stray leading "r" ("rrtmo-m_..."), which
    # pointed at a non-existent object.
    "rtmo-m_16xb16-700e_crowdpose": "https://download.openmmlab.com/mmpose/v1/projects/rtmo/rtmo-m_16xb16-700e_crowdpose-640x640-0eaf670d_20231211.pth",
    "rtmo-l_16xb16-700e_crowdpose": "https://download.openmmlab.com/mmpose/v1/projects/rtmo/rtmo-l_16xb16-700e_crowdpose-640x640-1008211f_20231211.pth",

    # Retrainable RTMO-s variants hosted on the Hugging Face Hub.
    "rtmo-s_coco_retrainable": "https://huggingface.co/Luigi/Retrainable-RTMO-s/resolve/main/rtmo-s_coco_retrainable.pth",
    "rtmo-s_body6_retrainable": "https://huggingface.co/Luigi/Retrainable-RTMO-s/resolve/main/body6_epoch_600.pth",
}
|
|
|
|
|
|
|
|
# Out-channel count of the backbone stem conv weight -> MMPose RTMO config
# name (all body7 configs). Looked up by detect_rtmo_variant().
VARIANT_PREFIX = dict(
    (
        (24, "rtmo-t_8xb32-600e_body7-416x416"),
        (32, "rtmo-s_8xb32-600e_body7-640x640"),
        (48, "rtmo-m_16xb16-600e_body7-640x640"),
        (64, "rtmo-l_16xb16-600e_body7-640x640"),
    )
)
|
|
|
|
|
|
|
|
def get_checkpoint(path_or_key: str) -> str:
    """Resolve a checkpoint key or path to a local file path.

    If *path_or_key* is a key of REMOTE_CHECKPOINTS, the checkpoint is
    downloaded once to ``/tmp/<key>.pth`` and that path is returned; later
    calls reuse the cached file. Any other value (e.g. a local path) is
    returned unchanged.

    Raises:
        requests.HTTPError: if the server answers with an error status.
        requests.RequestException: on connection problems or timeouts.
    """
    if path_or_key not in REMOTE_CHECKPOINTS:
        return path_or_key

    url = REMOTE_CHECKPOINTS[path_or_key]
    local_path = f"/tmp/{path_or_key}.pth"
    if os.path.exists(local_path):
        return local_path

    # Stream into a sibling ".part" file and rename only once the download
    # completed successfully. Otherwise an HTTP error page or a truncated
    # transfer would be cached and reused by the existence check above.
    tmp_path = local_path + ".part"
    try:
        with requests.get(url, stream=True, timeout=60) as r:
            r.raise_for_status()
            with open(tmp_path, 'wb') as f:
                for chunk in r.iter_content(chunk_size=1 << 20):
                    f.write(chunk)
        os.replace(tmp_path, local_path)
    except BaseException:
        # Best-effort cleanup of the partial file; the original error is
        # what matters, so it is always re-raised.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise
    return local_path
|
|
|
|
|
|
|
|
def detect_rtmo_variant(checkpoint_path: str) -> str:
    """Return the RTMO config name matching a checkpoint's backbone width.

    The checkpoint is loaded on CPU, its ``state_dict`` is used when present
    (otherwise the loaded object itself), and the out-channel count of
    ``backbone.stem.conv.conv.weight`` is looked up in VARIANT_PREFIX. Unknown
    widths fall back to ``'rtmo-s_8xb32-600e_body7-640x640'``.

    NOTE(review): ``torch.load`` unpickles the file, and this function is also
    called with user-uploaded checkpoints from the UI. Unpickling untrusted
    data can execute arbitrary code; consider ``weights_only=True`` if the
    checkpoints' metadata allows it.
    NOTE(review): every width maps to a body7 config, including for the COCO
    and CrowdPose entries of REMOTE_CHECKPOINTS. Confirm those checkpoints'
    heads are compatible with the body7 configs.

    Raises:
        KeyError: if the stem conv weight is missing from the checkpoint.
    """
    stem_key = 'backbone.stem.conv.conv.weight'
    fallback = 'rtmo-s_8xb32-600e_body7-640x640'

    loaded = torch.load(checkpoint_path, map_location='cpu')
    weights = loaded.get('state_dict', loaded)

    if stem_key not in weights:
        raise KeyError(f"Cannot find '{stem_key}' in checkpoint.")

    # Dim 0 of a conv weight is its number of output channels.
    stem_channels = weights[stem_key].shape[0]
    return VARIANT_PREFIX.get(stem_channels, fallback)
|
|
|
|
|
|
|
|
def load_inferencer(checkpoint_path=None, device=None):
    """Build an MMPoseInferencer for RTMO.

    With a truthy *checkpoint_path*, the pose config is picked by
    detect_rtmo_variant() and the file is passed as ``pose2d_weights``.
    Otherwise the plain ``'rtmo'`` model alias is requested with no explicit
    weights. The inferencer is always created with ``scope='mmpose'`` and
    ``det_cat_ids=[0]``, on *device* (None lets MMPose choose).
    """
    if not checkpoint_path:
        pose_model, extra = 'rtmo', {}
    else:
        pose_model = detect_rtmo_variant(checkpoint_path)
        extra = {'pose2d_weights': checkpoint_path}

    return MMPoseInferencer(
        scope='mmpose',
        device=device,
        det_cat_ids=[0],
        pose2d=pose_model,
        **extra,
    )
|
|
|
|
|
|
|
|
# File extensions treated as video. The same tuple is used both to classify
# the uploaded input and to locate the annotated file MMPose writes, so the
# two checks can never disagree.
_VIDEO_EXTS = (".mp4", ".mov", ".avi", ".mkv", ".webm")


def _upload_path(upload):
    """Return the filesystem path of a Gradio upload value.

    The value may be a dict with a ``'name'`` key, an object exposing
    ``.name``, or already a plain path string; all three are accepted.
    """
    if isinstance(upload, dict) and 'name' in upload:
        return upload['name']
    if hasattr(upload, "name"):
        return upload.name
    return upload


@spaces.GPU()
def predict(image: Image.Image,
            video,
            remote_ckpt: str,
            upload_ckpt,
            bbox_thr: float,
            nms_thr: float):
    """Run RTMO pose estimation on an uploaded image or video.

    Args:
        image: Uploaded PIL image; used only when *video* is empty.
        video: Uploaded video (dict, object with ``.name``, or path string);
            takes precedence over *image*.
        remote_ckpt: REMOTE_CHECKPOINTS key (or local path) used when no
            checkpoint file is uploaded.
        upload_ckpt: Optional uploaded ``.pth`` checkpoint; overrides
            *remote_ckpt* when given.
        bbox_thr: Bounding-box score threshold forwarded to the inferencer.
        nms_thr: NMS threshold forwarded to the inferencer.

    Returns:
        ``(annotated_image, annotated_video_path, active_checkpoint)``. For
        video inputs the image slot is None; for image inputs the video slot
        is None. The filled slot is also None if no output file was found.

    Raises:
        gr.Error: if neither an image nor a video was provided.
    """
    # --- Resolve the input file -------------------------------------------
    if video:
        inp_path = _upload_path(video)
    elif image is not None:
        inp_path = "/tmp/upload.jpg"
        # JPEG cannot store alpha/palette modes; normalise before saving.
        image.convert("RGB").save(inp_path)
    else:
        # Previously crashed with AttributeError on `None.save`.
        raise gr.Error("Please upload an image or a video.")

    is_video = os.path.splitext(inp_path)[1].lower() in _VIDEO_EXTS

    # --- Resolve the checkpoint (an upload wins over the dropdown) ----------
    if upload_ckpt:
        ckpt_path = _upload_path(upload_ckpt)
        active = os.path.basename(ckpt_path)
    else:
        ckpt_path = get_checkpoint(remote_ckpt)
        active = remote_ckpt

    # Start from an empty directory so only this run's output is picked up.
    vis_dir = "/tmp/vis"
    if os.path.exists(vis_dir):
        shutil.rmtree(vis_dir)
    os.makedirs(vis_dir, exist_ok=True)

    # --- Inference ------------------------------------------------------------
    inferencer = load_inferencer(checkpoint_path=ckpt_path, device=None)
    # The call returns an iterable of per-input results; it is exhausted so
    # every frame is processed and its visualisation written to vis_dir.
    for _ in inferencer(
        inputs=inp_path,
        bbox_thr=bbox_thr,
        nms_thr=nms_thr,
        pose_based_nms=True,
        show=False,
        vis_out_dir=vis_dir,
    ):
        pass

    # --- Collect the annotated output --------------------------------------
    out_files = sorted(os.listdir(vis_dir))
    if is_video:
        out_vid = next((f for f in out_files if f.lower().endswith(_VIDEO_EXTS)), None)
        return None, os.path.join(vis_dir, out_vid) if out_vid else None, active

    img_f = next((f for f in out_files if not f.lower().endswith(_VIDEO_EXTS)), None)
    if img_f is None:
        return None, None, active
    # Load into memory: Image.open is lazy and would otherwise keep a handle
    # into vis_dir, which the next call deletes.
    with Image.open(os.path.join(vis_dir, img_f)) as im:
        vis_img = im.copy()
    return vis_img, None, active
|
|
|
|
|
|
|
|
def main():
    """Build the RTMO pose demo UI with Gradio and launch it."""
    # Shared component order for both the examples table and the button;
    # it must match predict()'s parameter order and return tuple.
    with gr.Blocks() as demo:
        gr.Markdown("## RTMO Pose Demo")

        with gr.Row():
            # Left column: inputs and inference controls.
            with gr.Column(scale=1, min_width=300):
                img_input = gr.Image(type="pil", label="Upload Image")
                video_input = gr.Video(label="Upload Video")
                default_ckpt = list(REMOTE_CHECKPOINTS)[0]
                remote_dd = gr.Dropdown(
                    label="Select Remote Checkpoint",
                    choices=list(REMOTE_CHECKPOINTS),
                    value=default_ckpt,
                )
                upload_ckpt = gr.File(file_types=['.pth'], label="Or Upload Your Own Checkpoint (optional)")
                bbox_thr = gr.Slider(0.0, 1.0, value=0.1, step=0.01, label="Bounding Box Threshold")
                nms_thr = gr.Slider(0.0, 1.0, value=0.65, step=0.01, label="NMS Threshold")
                run_btn = gr.Button("Run Inference")

            # Right column: annotated results.
            with gr.Column(scale=2):
                output_img = gr.Image(type="pil", label="Annotated Image", elem_id="output_image", interactive=False)
                output_video = gr.Video(label="Annotated Video", interactive=False)
                active_tb = gr.Textbox(label="Active Checkpoint", interactive=False)

        pose_inputs = (img_input, video_input, remote_dd, upload_ckpt, bbox_thr, nms_thr)
        pose_outputs = (output_img, output_video, active_tb)

        # Every example uses the same (bbox_thr, nms_thr) pair.
        thresholds = (0.1, 0.65)
        example_rows = [
            ["https://images.pexels.com/photos/1858175/pexels-photo-1858175.jpeg?auto=compress&cs=tinysrgb&h=614&w=614", None, "rtmo-s_coco_retrainable", None, *thresholds],
            ["https://images.pexels.com/photos/3779706/pexels-photo-3779706.jpeg?auto=compress&cs=tinysrgb&h=614&w=614", None, "rtmo-t_8xb32-600e_body7", None, *thresholds],
            ["https://images.pexels.com/photos/220453/pexels-photo-220453.jpeg?auto=compress&cs=tinysrgb&h=614&w=614", None, "rtmo-s_8xb32-600e_coco", None, *thresholds],
            # Video example: no image, a video URL instead.
            [None,
             "https://archive.org/download/fred-otts-sneeze/Fred%20Ott%20Sneeze%201894%20GG%20Restore.mp4",
             "rtmo-s_coco_retrainable", None, *thresholds],
        ]

        gr.Examples(
            examples=example_rows,
            inputs=list(pose_inputs),
            outputs=list(pose_outputs),
            fn=predict,
            cache_examples=False,
            label="Examples",
            examples_per_page=4,
        )

        run_btn.click(
            predict,
            inputs=list(pose_inputs),
            outputs=list(pose_outputs),
        )

    demo.launch()


if __name__ == "__main__":
    main()