Add video input
README.md
CHANGED
@@ -61,13 +61,3 @@ The following variants are available out of the box:
 - **app.py**: Main Gradio application script.
 - **requirements.txt**: Python dependencies, including MMCV and MMPose.
 - **README.md**: This documentation file.
-
-## Development
-
-To update dependencies, edit `requirements.txt`. To extend functionality or add new variants, modify `app.py` accordingly.
-
-## Future Plans
-
-1. Support video input streams.
-2. Enable ONNX model inference via `rtmlib`.
-

app.py
CHANGED
@@ -1,10 +1,14 @@
 #!/usr/bin/env python3
 import spaces
 import os
+import sys
+import importlib.util
+import re
 import gradio as gr
 from PIL import Image
 import torch
 import requests  # for downloading remote checkpoints
+import shutil
 
 # CUDA info
 try:

@@ -92,22 +96,45 @@ def load_inferencer(checkpoint_path=None, device=None):
 # ──── Prediction function ────
 @spaces.GPU()
 def predict(image: Image.Image,
+            video,  # new video input
             remote_ckpt: str,
             upload_ckpt,
             bbox_thr: float,
             nms_thr: float):
-    inp_path = "/tmp/upload.jpg"
-    image.save(inp_path)
+    # 1) Write image or pick up video file
+    if video:
+        # Gradio Video can come in as a filepath string or dict
+        if isinstance(video, dict) and 'name' in video:
+            inp_path = video['name']
+        elif hasattr(video, "name"):
+            inp_path = video.name
+        else:
+            inp_path = video
+    else:
+        inp_path = "/tmp/upload.jpg"
+        image.save(inp_path)
+
+    # 2) Determine by extension if this is video
+    ext = os.path.splitext(inp_path)[1].lower()
+    is_video = ext in (".mp4", ".mov", ".avi", ".mkv", ".webm")
+
+    # checkpoint selection
     if upload_ckpt:
         ckpt_path = upload_ckpt.name
         active = os.path.basename(ckpt_path)
     else:
         ckpt_path = get_checkpoint(remote_ckpt)
         active = remote_ckpt
+
+    # prepare (and clear) output dir
     vis_dir = "/tmp/vis"
+    if os.path.exists(vis_dir):
+        shutil.rmtree(vis_dir)
     os.makedirs(vis_dir, exist_ok=True)
+
+    # run inferencer (handles both image & video)
     inferencer = load_inferencer(checkpoint_path=ckpt_path, device=None)
     for _ in inferencer(
         inputs=inp_path,
         bbox_thr=bbox_thr,
         nms_thr=nms_thr,

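The new input handling above decides between the image and video paths purely by file extension. As a quick, standalone restatement of that check (the `is_video_path` helper name and the sample paths are mine, not part of the commit):

```python
import os

# Same container extensions the diff treats as video; everything else falls back to the image path.
VIDEO_EXTS = (".mp4", ".mov", ".avi", ".mkv", ".webm")

def is_video_path(path: str) -> bool:
    # Mirrors: ext = os.path.splitext(inp_path)[1].lower(); is_video = ext in (...)
    return os.path.splitext(path)[1].lower() in VIDEO_EXTS

assert is_video_path("/tmp/clip.MP4")
assert not is_video_path("/tmp/upload.jpg")
assert not is_video_path("/tmp/archive.mp4.txt")  # only the final suffix counts
```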
@@ -116,9 +143,18 @@ def predict(image: Image.Image,
         vis_out_dir=vis_dir,
     ):
         pass
+
+    # collect and return results
     out_files = sorted(os.listdir(vis_dir))
+    if is_video:
+        # return only the annotated video path
+        out_vid = next((f for f in out_files if f.lower().endswith((".mp4", ".mov", ".avi"))), None)
+        return None, os.path.join(vis_dir, out_vid) if out_vid else None, active
+    else:
+        # return only the annotated image
+        img_f = out_files[0] if out_files else None
+        vis_img = Image.open(os.path.join(vis_dir, img_f)) if img_f and not img_f.lower().endswith((".mp4", ".mov", ".avi")) else None
+        return vis_img, None, active
 
 # ──── Gradio UI ────
 def main():

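One Python subtlety in the video branch above: in `return None, os.path.join(vis_dir, out_vid) if out_vid else None, active` the conditional expression binds tighter than the tuple commas, so the function returns a 3-tuple matching the three Gradio outputs, not a 2-tuple. A minimal check with made-up values:

```python
import os

vis_dir, out_vid, active = "/tmp/vis", "clip.mp4", "rtmo-s_8xb32-600e_coco"
result = None, os.path.join(vis_dir, out_vid) if out_vid else None, active
assert result == (None, "/tmp/vis/clip.mp4", "rtmo-s_8xb32-600e_coco")

out_vid = None  # no annotated video was produced
result = None, os.path.join(vis_dir, out_vid) if out_vid else None, active
assert result == (None, None, "rtmo-s_8xb32-600e_coco")  # join() is never evaluated
```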
@@ -126,43 +162,44 @@ def main():
         gr.Markdown("## RTMO Pose Demo")
         with gr.Row():
             with gr.Column(scale=1, min_width=300):
+                img_input = gr.Image(type="pil", label="Upload Image")
+                video_input = gr.Video(label="Upload Video")
+                remote_dd = gr.Dropdown(
+                    label="Select Remote Checkpoint",
+                    choices=list(REMOTE_CHECKPOINTS.keys()),
+                    value=list(REMOTE_CHECKPOINTS.keys())[0]
+                )
                 upload_ckpt = gr.File(file_types=['.pth'], label="Or Upload Your Own Checkpoint (optional)")
+                bbox_thr = gr.Slider(0.0, 1.0, value=0.1, step=0.01, label="Bounding Box Threshold")
+                nms_thr = gr.Slider(0.0, 1.0, value=0.65, step=0.01, label="NMS Threshold")
                 run_btn = gr.Button("Run Inference")
             with gr.Column(scale=2):
+                output_img = gr.Image(type="pil", label="Annotated Image", elem_id="output_image", interactive=False)
+                output_video = gr.Video(label="Annotated Video", interactive=False)
+                active_tb = gr.Textbox(label="Active Checkpoint", interactive=False)
+
         # Examples for quick testing
         gr.Examples(
             examples=[
-                ["https://images.pexels.com/photos/220453/pexels-photo-220453.jpeg?auto=compress&cs=tinysrgb&dpr=1&w=614",
-                 "rtmo-s_8xb32-600e_coco", None, 0.1, 0.65],
+                ["https://images.pexels.com/photos/1858175/pexels-photo-1858175.jpeg?auto=compress&cs=tinysrgb&h=614&w=614", None, "rtmo-s_coco_retrainable", None, 0.1, 0.65],
+                ["https://images.pexels.com/photos/3779706/pexels-photo-3779706.jpeg?auto=compress&cs=tinysrgb&h=614&w=614", None, "rtmo-t_8xb32-600e_body7", None, 0.1, 0.65],
+                ["https://images.pexels.com/photos/220453/pexels-photo-220453.jpeg?auto=compress&cs=tinysrgb&h=614&w=614", None, "rtmo-s_8xb32-600e_coco", None, 0.1, 0.65],
             ],
-            inputs=[img_input, remote_dd, upload_ckpt, bbox_thr, nms_thr],
-            outputs=[output_img, active_tb],
+            inputs=[img_input, video_input, remote_dd, upload_ckpt, bbox_thr, nms_thr],
+            outputs=[output_img, output_video, active_tb],
             fn=predict,
             cache_examples=False,
             label="Examples",
             examples_per_page=3
         )
 
         run_btn.click(
+            predict,
+            inputs=[img_input, video_input, remote_dd, upload_ckpt, bbox_thr, nms_thr],
+            outputs=[output_img, output_video, active_tb]
+        )
+
     demo.launch()
 
 if __name__ == "__main__":
     main()
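For orientation, a rough sketch of how the reworked `predict()` could be exercised outside the Gradio UI, e.g. as a local smoke test. It is not part of the commit and makes several assumptions: `app.py` imports cleanly outside a Hugging Face Space, the `@spaces.GPU()` decorator tolerates local execution, the test files `tests/person.jpg` and `tests/clip.mp4` are hypothetical, and the checkpoint key is one of those used in the Examples above.

```python
from PIL import Image

from app import predict  # assumes app.py is importable as a module

# Image case: the video slot is None, so predict() saves the PIL image to /tmp and
# returns (annotated PIL image, None, active checkpoint name).
img = Image.open("tests/person.jpg")  # hypothetical local test image
vis_img, vis_video, active = predict(
    img,                        # image input
    None,                       # no video input
    "rtmo-s_8xb32-600e_coco",   # remote checkpoint key, as in the Examples
    None,                       # no uploaded .pth checkpoint
    0.1,                        # bbox_thr
    0.65,                       # nms_thr
)
print(active, vis_img is not None, vis_video)

# Video case: pass a file path in the video slot; the image argument is then ignored
# and the annotated clip comes back as a file path under /tmp/vis.
vis_img, vis_video, active = predict(
    None, "tests/clip.mp4", "rtmo-s_8xb32-600e_coco", None, 0.1, 0.65
)
print(active, vis_img, vis_video)
```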