Support all variants with added variant detection
Browse files
app.py
CHANGED
|
@@ -3,6 +3,7 @@ import spaces
|
|
| 3 |
import os, sys, importlib.util, re
|
| 4 |
import gradio as gr
|
| 5 |
from PIL import Image
|
|
|
|
| 6 |
|
| 7 |
# ─── Monkey-patch mmdet to remove its mmcv-version assertion ───
|
| 8 |
spec = importlib.util.find_spec('mmdet')
|
|
@@ -23,8 +24,33 @@ def load_inferencer(checkpoint_path=None, device=None):
|
|
| 23 |
kwargs = {'pose2d': 'rtmo', 'scope': 'mmpose', 'device': device, 'det_cat_ids': [0]}
|
| 24 |
if checkpoint_path:
|
| 25 |
kwargs['pose2d_weights'] = checkpoint_path
|
|
|
|
|
|
|
|
|
|
| 26 |
return MMPoseInferencer(**kwargs)
|
| 27 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
# ─── Gradio prediction function ───
|
| 29 |
@spaces.GPU()
|
| 30 |
def predict(image: Image.Image, checkpoint):
|
|
|
|
| 3 |
import os, sys, importlib.util, re
|
| 4 |
import gradio as gr
|
| 5 |
from PIL import Image
|
| 6 |
+
import torch
|
| 7 |
|
| 8 |
# ─── Monkey-patch mmdet to remove its mmcv-version assertion ───
|
| 9 |
spec = importlib.util.find_spec('mmdet')
|
|
|
|
| 24 |
kwargs = {'pose2d': 'rtmo', 'scope': 'mmpose', 'device': device, 'det_cat_ids': [0]}
|
| 25 |
if checkpoint_path:
|
| 26 |
kwargs['pose2d_weights'] = checkpoint_path
|
| 27 |
+
# detect model variant
|
| 28 |
+
variant = detect_rtmo_variant(checkpoint_path)
|
| 29 |
+
kwargs['pose2d'] = variant
|
| 30 |
return MMPoseInferencer(**kwargs)
|
| 31 |
|
| 32 |
+
def detect_rtmo_variant(checkpoint_path: str) -> str:
    """Return the RTMO config name that matches a ``.pth`` checkpoint.

    The checkpoint is loaded onto the CPU and the number of output channels
    of the stem convolution weight (``backbone.stem.conv.conv.weight``) is
    looked up in a fixed table of known widths.

    Args:
        checkpoint_path: Path of the checkpoint file, handed to
            ``torch.load``.

    Returns:
        The full config name of the detected variant, one of
        ``"rtmo-t_8xb32-600e_body7-416x416"``,
        ``"rtmo-s_8xb32-600e_body7-640x640"``,
        ``"rtmo-m_16xb16-600e_body7-640x640"`` or
        ``"rtmo-l_16xb16-600e_body7-640x640"``. When the stem width is not
        in the table, the string ``"unknown (stem out_channels=N)"`` is
        returned instead of raising.

    Raises:
        KeyError: If the stem convolution weight is absent from the
            checkpoint.
    """
    # NOTE(review): depending on the installed torch version and its
    # ``weights_only`` default, ``torch.load`` may unpickle arbitrary
    # objects. Only feed it checkpoints from a trusted source.
    ckpt = torch.load(checkpoint_path, map_location='cpu')

    # Some checkpoints wrap the weights under 'state_dict'; others are the
    # bare weight mapping itself, so fall back to the loaded object.
    state_dict = ckpt.get('state_dict', ckpt)

    stem_key = 'backbone.stem.conv.conv.weight'
    if stem_key not in state_dict:
        raise KeyError(f"Cannot find '{stem_key}' in checkpoint.")

    # The first dimension of the stem weight is its output-channel count,
    # which is what distinguishes the variants in the table below.
    out_ch = state_dict[stem_key].shape[0]

    variant_by_stem_width = {
        24: "rtmo-t_8xb32-600e_body7-416x416",
        32: "rtmo-s_8xb32-600e_body7-640x640",
        48: "rtmo-m_16xb16-600e_body7-640x640",
        64: "rtmo-l_16xb16-600e_body7-640x640",
    }
    return variant_by_stem_width.get(
        out_ch, f'unknown (stem out_channels={out_ch})'
    )
|
| 53 |
+
|
| 54 |
# ─── Gradio prediction function ───
|
| 55 |
@spaces.GPU()
|
| 56 |
def predict(image: Image.Image, checkpoint):
|