Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
zhangziang
committed on
Commit
·
a1a407f
1
Parent(s):
9868529
update layout,model,example
Browse files
- app.py +77 -41
- assets/{axis_ref.png → examples/F35-0.jpg} +2 -2
- assets/{axis_tgt.png → examples/F35-1.jpg} +2 -2
- assets/examples/bottle.jpg +3 -0
- assets/examples/hat.jpg +3 -0
- assets/examples/skateboard-0.jpg +3 -0
- assets/examples/skateboard-1.jpg +3 -0
- assets/examples/table-0.jpg +3 -0
- assets/examples/table-1.jpg +3 -0
- inference.py +3 -3
- orianyV2_demo.ipynb +0 -0
- paths.py +1 -4
app.py
CHANGED
|
@@ -2,8 +2,9 @@ import gradio as gr
|
|
| 2 |
import numpy as np
|
| 3 |
from PIL import Image
|
| 4 |
import torch
|
|
|
|
|
|
|
| 5 |
|
| 6 |
-
# ====== 你的原有导入和模型加载保持不变 ======
|
| 7 |
from paths import *
|
| 8 |
from vision_tower import VGGT_OriAny_Ref
|
| 9 |
from inference import *
|
|
@@ -81,37 +82,52 @@ def run_inference(image_ref, image_tgt, do_rm_bkg):
|
|
| 81 |
ro = safe_float(ans_dict.get('ref_ro_pred', 0))
|
| 82 |
alpha = int(ans_dict.get('ref_alpha_pred', 1)) # 注意:target 默认 alpha=1,但 ref 可能不是
|
| 83 |
|
| 84 |
-
# =====
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 115 |
|
| 116 |
return [
|
| 117 |
overlaid_ref, # 渲染+叠加后的参考图
|
|
@@ -153,6 +169,26 @@ with gr.Blocks(title="Orient-Anything Demo") as demo:
|
|
| 153 |
)
|
| 154 |
rm_bkg = gr.Checkbox(label="Remove Background", value=True)
|
| 155 |
run_btn = gr.Button("Run Inference", variant="primary")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 156 |
|
| 157 |
# 右侧:结果图像 + 文本输出
|
| 158 |
with gr.Column():
|
|
@@ -175,17 +211,17 @@ with gr.Blocks(title="Orient-Anything Demo") as demo:
|
|
| 175 |
|
| 176 |
# 文本输出放在图像下方
|
| 177 |
with gr.Row():
|
| 178 |
-
with gr.Column(
|
| 179 |
gr.Markdown("### Absolute Pose (Reference)")
|
| 180 |
-
az_out = gr.Textbox(label="Azimuth (0~360°)"
|
| 181 |
-
el_out = gr.Textbox(label="Polar (-90~90°)"
|
| 182 |
-
ro_out = gr.Textbox(label="Rotation (-90~90°)"
|
| 183 |
-
alpha_out = gr.Textbox(label="Number of Directions (0/1/2/4)"
|
| 184 |
-
with gr.Column(
|
| 185 |
gr.Markdown("### Relative Pose (Target w.r.t Reference)")
|
| 186 |
-
rel_az_out = gr.Textbox(label="Relative Azimuth (0~360°)"
|
| 187 |
-
rel_el_out = gr.Textbox(label="Relative Polar (-90~90°)"
|
| 188 |
-
rel_ro_out = gr.Textbox(label="Relative Rotation (-90~90°)"
|
| 189 |
|
| 190 |
# 绑定点击事件
|
| 191 |
run_btn.click(
|
|
|
|
| 2 |
import numpy as np
|
| 3 |
from PIL import Image
|
| 4 |
import torch
|
| 5 |
+
import os
|
| 6 |
+
import tempfile
|
| 7 |
|
|
|
|
| 8 |
from paths import *
|
| 9 |
from vision_tower import VGGT_OriAny_Ref
|
| 10 |
from inference import *
|
|
|
|
| 82 |
ro = safe_float(ans_dict.get('ref_ro_pred', 0))
|
| 83 |
alpha = int(ans_dict.get('ref_alpha_pred', 1)) # 注意:target 默认 alpha=1,但 ref 可能不是
|
| 84 |
|
| 85 |
+
# ===== 用临时文件保存渲染结果 =====
|
| 86 |
+
tmp_ref = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
|
| 87 |
+
tmp_tgt = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
|
| 88 |
+
tmp_ref.close()
|
| 89 |
+
tmp_tgt.close()
|
| 90 |
+
|
| 91 |
+
try:
|
| 92 |
+
# ===== 渲染参考图的坐标轴 =====
|
| 93 |
+
axis_renderer.render_axis(az, el, ro, alpha, save_path=tmp_ref.name)
|
| 94 |
+
axis_ref = Image.open(tmp_ref.name).convert("RGBA")
|
| 95 |
+
|
| 96 |
+
# 叠加坐标轴到参考图
|
| 97 |
+
# 确保尺寸一致
|
| 98 |
+
if axis_ref.size != pil_ref.size:
|
| 99 |
+
pil_ref = pil_ref.resize(axis_ref.size, Image.BICUBIC)
|
| 100 |
+
pil_ref_rgba = pil_ref.convert("RGBA")
|
| 101 |
+
overlaid_ref = Image.alpha_composite(pil_ref_rgba, axis_ref).convert("RGB")
|
| 102 |
+
|
| 103 |
+
# ===== 处理目标图(如果有)=====
|
| 104 |
+
if pil_tgt is not None:
|
| 105 |
+
rel_az = safe_float(ans_dict.get('rel_az_pred', 0))
|
| 106 |
+
rel_el = safe_float(ans_dict.get('rel_el_pred', 0))
|
| 107 |
+
rel_ro = safe_float(ans_dict.get('rel_ro_pred', 0))
|
| 108 |
+
|
| 109 |
+
tgt_azi, tgt_ele, tgt_rot = Get_target_azi_ele_rot(az, el, ro, rel_az, rel_el, rel_ro)
|
| 110 |
+
print("Target: Azi",tgt_azi,"Ele",tgt_ele,"Rot",tgt_rot)
|
| 111 |
+
|
| 112 |
+
# target 默认 alpha=1(根据你的说明)
|
| 113 |
+
axis_renderer.render_axis(tgt_azi, tgt_ele, tgt_rot, alpha=1, save_path=tmp_tgt.name)
|
| 114 |
+
axis_tgt = Image.open(tmp_tgt.name).convert("RGBA")
|
| 115 |
+
|
| 116 |
+
if axis_tgt.size != pil_tgt.size:
|
| 117 |
+
pil_tgt = pil_tgt.resize(axis_tgt.size, Image.BICUBIC)
|
| 118 |
+
pil_tgt_rgba = pil_tgt.convert("RGBA")
|
| 119 |
+
overlaid_tgt = Image.alpha_composite(pil_tgt_rgba, axis_tgt).convert("RGB")
|
| 120 |
+
else:
|
| 121 |
+
overlaid_tgt = None
|
| 122 |
+
rel_az = rel_el = rel_ro = 0.0
|
| 123 |
+
finally:
|
| 124 |
+
# 安全删除临时文件(即使出错也清理)
|
| 125 |
+
if os.path.exists(tmp_ref.name):
|
| 126 |
+
os.remove(tmp_ref.name)
|
| 127 |
+
print('cleaned {}'.format(tmp_ref.name))
|
| 128 |
+
if os.path.exists(tmp_tgt.name):
|
| 129 |
+
os.remove(tmp_tgt.name)
|
| 130 |
+
print('cleaned {}'.format(tmp_tgt.name))
|
| 131 |
|
| 132 |
return [
|
| 133 |
overlaid_ref, # 渲染+叠加后的参考图
|
|
|
|
| 169 |
)
|
| 170 |
rm_bkg = gr.Checkbox(label="Remove Background", value=True)
|
| 171 |
run_btn = gr.Button("Run Inference", variant="primary")
|
| 172 |
+
# === 在这里插入示例 ===
|
| 173 |
+
with gr.Row():
|
| 174 |
+
gr.Examples(
|
| 175 |
+
examples=[
|
| 176 |
+
["assets/examples/F35-0.jpg", "assets/examples/F35-1.jpg"],
|
| 177 |
+
["assets/examples/skateboard-0.jpg", "assets/examples/skateboard-1.jpg"],
|
| 178 |
+
],
|
| 179 |
+
inputs=[ref_img, tgt_img],
|
| 180 |
+
examples_per_page=2,
|
| 181 |
+
label="Example Inputs (click to load)"
|
| 182 |
+
)
|
| 183 |
+
gr.Examples(
|
| 184 |
+
examples=[
|
| 185 |
+
["assets/examples/table-0.jpg", "assets/examples/table-1.jpg"],
|
| 186 |
+
["assets/examples/bottle.jpg", None],
|
| 187 |
+
],
|
| 188 |
+
inputs=[ref_img, tgt_img],
|
| 189 |
+
examples_per_page=2,
|
| 190 |
+
label=""
|
| 191 |
+
)
|
| 192 |
|
| 193 |
# 右侧:结果图像 + 文本输出
|
| 194 |
with gr.Column():
|
|
|
|
| 211 |
|
| 212 |
# 文本输出放在图像下方
|
| 213 |
with gr.Row():
|
| 214 |
+
with gr.Column():
|
| 215 |
gr.Markdown("### Absolute Pose (Reference)")
|
| 216 |
+
az_out = gr.Textbox(label="Azimuth (0~360°)")
|
| 217 |
+
el_out = gr.Textbox(label="Polar (-90~90°)")
|
| 218 |
+
ro_out = gr.Textbox(label="Rotation (-90~90°)")
|
| 219 |
+
alpha_out = gr.Textbox(label="Number of Directions (0/1/2/4)")
|
| 220 |
+
with gr.Column():
|
| 221 |
gr.Markdown("### Relative Pose (Target w.r.t Reference)")
|
| 222 |
+
rel_az_out = gr.Textbox(label="Relative Azimuth (0~360°)")
|
| 223 |
+
rel_el_out = gr.Textbox(label="Relative Polar (-90~90°)")
|
| 224 |
+
rel_ro_out = gr.Textbox(label="Relative Rotation (-90~90°)")
|
| 225 |
|
| 226 |
# 绑定点击事件
|
| 227 |
run_btn.click(
|
assets/{axis_ref.png → examples/F35-0.jpg}
RENAMED
|
File without changes
|
assets/{axis_tgt.png → examples/F35-1.jpg}
RENAMED
|
File without changes
|
assets/examples/bottle.jpg
ADDED
|
Git LFS Details
|
assets/examples/hat.jpg
ADDED
|
Git LFS Details
|
assets/examples/skateboard-0.jpg
ADDED
|
Git LFS Details
|
assets/examples/skateboard-1.jpg
ADDED
|
Git LFS Details
|
assets/examples/table-0.jpg
ADDED
|
Git LFS Details
|
assets/examples/table-1.jpg
ADDED
|
Git LFS Details
|
inference.py
CHANGED
|
@@ -51,11 +51,11 @@ def val_fit_alpha(distribute):
|
|
| 51 |
mu_fit, kappa_fit = saved_params[max_index]
|
| 52 |
r_squared = saved_r_squared[max_index]
|
| 53 |
|
| 54 |
-
if alpha == 1. and kappa_fit>=0.
|
| 55 |
pass
|
| 56 |
-
elif alpha == 2. and kappa_fit>=0.
|
| 57 |
pass
|
| 58 |
-
elif alpha == 4. and kappa_fit>=0.25 and r_squared>=0.
|
| 59 |
pass
|
| 60 |
else:
|
| 61 |
alpha=0.
|
|
|
|
| 51 |
mu_fit, kappa_fit = saved_params[max_index]
|
| 52 |
r_squared = saved_r_squared[max_index]
|
| 53 |
|
| 54 |
+
if alpha == 1. and kappa_fit>=0.6 and r_squared>=0.45:
|
| 55 |
pass
|
| 56 |
+
elif alpha == 2. and kappa_fit>=0.4 and r_squared>=0.45:
|
| 57 |
pass
|
| 58 |
+
elif alpha == 4. and kappa_fit>=0.25 and r_squared>=0.45:
|
| 59 |
pass
|
| 60 |
else:
|
| 61 |
alpha=0.
|
orianyV2_demo.ipynb
DELETED
|
The diff for this file is too large to render.
See raw diff
|
|
|
paths.py
CHANGED
|
@@ -7,10 +7,7 @@ VGGT_1B = "facebook/VGGT-1B"
|
|
| 7 |
|
| 8 |
ORIANY_V2 = "Viglong/OriAnyV2_ckpt"
|
| 9 |
|
| 10 |
-
REMOTE_CKPT_PATH = "demo_ckpts/
|
| 11 |
-
|
| 12 |
|
| 13 |
RENDER_FILE = "assets/axis_render.blend"
|
| 14 |
-
REF_AXIS_IMAGE = "assets/axis_ref.png"
|
| 15 |
-
TGT_AXIS_IMAGE = "assets/axis_tgt.png"
|
| 16 |
|
|
|
|
| 7 |
|
| 8 |
ORIANY_V2 = "Viglong/OriAnyV2_ckpt"
|
| 9 |
|
| 10 |
+
REMOTE_CKPT_PATH = "demo_ckpts/rotmod_realrotaug_best.pt"
|
|
|
|
| 11 |
|
| 12 |
RENDER_FILE = "assets/axis_render.blend"
|
|
|
|
|
|
|
| 13 |
|