Let user configure bbox_thr & nms_thr in GUI
Browse files
README.md
CHANGED
|
@@ -28,4 +28,9 @@ This HuggingFace Space runs the RTMO (Real-Time Multi-Person) 2D pose estimation
|
|
| 28 |
We use the `rtmo` alias defined in MMPose’s model zoo. To override, upload your own checkpoint.
|
| 29 |
|
| 30 |
## Development
|
| 31 |
-
If you need to update dependencies or change the model, modify `requirements.txt` and `app.py` accordingly.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
We use the `rtmo` alias defined in MMPose’s model zoo. To override, upload your own checkpoint.
|
| 29 |
|
| 30 |
## Development
|
| 31 |
+
If you need to update dependencies or change the model, modify `requirements.txt` and `app.py` accordingly.
|
| 32 |
+
|
| 33 |
+
## Todos
|
| 34 |
+
1. Let user configure bbox_thr & nms_thr in GUI
|
| 35 |
+
2. Support video input
|
| 36 |
+
3. Support models in ONNX format via rtmlib
|
app.py
CHANGED
|
@@ -56,7 +56,7 @@ def detect_rtmo_variant(checkpoint_path: str) -> str:
|
|
| 56 |
|
| 57 |
# ——— Gradio prediction function ———
|
| 58 |
@spaces.GPU()
|
| 59 |
-
def predict(image: Image.Image, checkpoint):
|
| 60 |
# save upload to temp file
|
| 61 |
inp_path = "/tmp/upload.jpg"
|
| 62 |
image.save(inp_path)
|
|
@@ -70,8 +70,8 @@ def predict(image: Image.Image, checkpoint):
|
|
| 70 |
inferencer = load_inferencer(checkpoint_path=ckpt_path, device=None)
|
| 71 |
for result in inferencer(
|
| 72 |
inputs=inp_path,
|
| 73 |
-
bbox_thr=
|
| 74 |
-
nms_thr=
|
| 75 |
pose_based_nms=True,
|
| 76 |
show=False,
|
| 77 |
vis_out_dir=vis_dir,
|
|
@@ -89,6 +89,8 @@ demo = gr.Interface(
|
|
| 89 |
inputs=[
|
| 90 |
gr.Image(type="pil", label="Upload Image"),
|
| 91 |
gr.File(file_types=['.pth'], label="Upload RTMO .pth Checkpoint (optional)")
|
|
|
|
|
|
|
| 92 |
],
|
| 93 |
outputs=gr.Image(type="pil", label="Annotated Image"),
|
| 94 |
title="RTMO Pose Demo",
|
|
|
|
| 56 |
|
| 57 |
# ——— Gradio prediction function ———
|
| 58 |
@spaces.GPU()
|
| 59 |
+
def predict(image: Image.Image, checkpoint, bbox_thr: float, nms_thr: float):
|
| 60 |
# save upload to temp file
|
| 61 |
inp_path = "/tmp/upload.jpg"
|
| 62 |
image.save(inp_path)
|
|
|
|
| 70 |
inferencer = load_inferencer(checkpoint_path=ckpt_path, device=None)
|
| 71 |
for result in inferencer(
|
| 72 |
inputs=inp_path,
|
| 73 |
+
bbox_thr=bbox_thr,
|
| 74 |
+
nms_thr=nms_thr,
|
| 75 |
pose_based_nms=True,
|
| 76 |
show=False,
|
| 77 |
vis_out_dir=vis_dir,
|
|
|
|
| 89 |
inputs=[
|
| 90 |
gr.Image(type="pil", label="Upload Image"),
|
| 91 |
gr.File(file_types=['.pth'], label="Upload RTMO .pth Checkpoint (optional)")
|
| 92 |
+
gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=0.1, label="Bounding Box Threshold (bbox_thr)"),
|
| 93 |
+
gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=0.65, label="NMS Threshold (nms_thr)"),
|
| 94 |
],
|
| 95 |
outputs=gr.Image(type="pil", label="Annotated Image"),
|
| 96 |
title="RTMO Pose Demo",
|