Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
File size: 8,815 Bytes
f783161 a1a407f f783161 939a4f3 f783161 a1a407f f783161 939a4f3 f783161 a1a407f f783161 a1a407f f783161 a1a407f f783161 a1a407f f783161 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 |
import gradio as gr
import numpy as np
from PIL import Image
import torch
import os
import tempfile
from paths import *
from vision_tower import VGGT_OriAny_Ref
from inference import *
from app_utils import *
from axis_renderer import BlendRenderer
from huggingface_hub import hf_hub_download
# Download the model checkpoint from the Hugging Face Hub into the local cache.
ckpt_path = hf_hub_download(repo_id=ORIANY_V2, filename=REMOTE_CKPT_PATH, repo_type="model", cache_dir='./', resume_download=True)
print(ckpt_path)
# Pick compute device and dtype: bf16 on GPUs with compute capability >= 8
# (Ampere or newer), fp16 otherwise; CPU path also uses fp16.
if torch.cuda.is_available():
    mark_dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16
    device = torch.device('cuda')
else:
    mark_dtype = torch.float16
    device = torch.device('cpu')
# Build the model skeleton (no pretrained weights) and load the checkpoint.
# NOTE(review): torch.load unpickles the file; if the checkpoint is a plain
# state_dict, weights_only=True would be safer — confirm checkpoint contents.
model = VGGT_OriAny_Ref(out_dim=900, dtype=mark_dtype, nopretrain=True)
model.load_state_dict(torch.load(ckpt_path, map_location='cpu'))
model.eval()
model = model.to(device)
print('Model loaded.')
# Off-screen renderer used to draw orientation axes that get overlaid on images.
axis_renderer = BlendRenderer(RENDER_FILE)
# ====== Utility: safe image coercion ======
def safe_image_input(image):
    """Return *image* as a numpy array, or None when it cannot be converted."""
    if image is None:
        return None
    # Already an array: hand it back untouched.
    if isinstance(image, np.ndarray):
        return image
    # Anything else (e.g. a PIL image): best-effort conversion.
    try:
        converted = np.array(image)
    except Exception:
        return None
    return converted
# ====== Inference entry point ======
@torch.no_grad()
def run_inference(image_ref, image_tgt, do_rm_bkg):
    """Predict orientation for a reference image (and optional target image).

    Parameters
    ----------
    image_ref : np.ndarray | PIL.Image.Image | None
        Reference image (required). Anything ``safe_image_input`` can coerce.
    image_tgt : np.ndarray | PIL.Image.Image | None
        Optional target image used for relative-pose prediction.
    do_rm_bkg : bool
        When True, remove image backgrounds before inference.

    Returns
    -------
    list
        ``[overlaid_ref, overlaid_tgt, az, el, ro, alpha, rel_az, rel_el,
        rel_ro]`` — images with rendered axes composited on top (``overlaid_tgt``
        may be None) followed by formatted angle strings.

    Raises
    ------
    gr.Error
        If no reference image was supplied, or the model inference fails.
    """
    image_ref = safe_image_input(image_ref)
    image_tgt = safe_image_input(image_tgt)
    if image_ref is None:
        raise gr.Error("Please upload a reference image before running inference.")

    # Convert to PIL (used for background removal and the axis overlay below).
    pil_ref = Image.fromarray(image_ref.astype(np.uint8)).convert("RGB")
    pil_tgt = None
    if image_tgt is not None:
        pil_tgt = Image.fromarray(image_tgt.astype(np.uint8)).convert("RGB")
        if do_rm_bkg:
            pil_ref = background_preprocess(pil_ref, True)
            pil_tgt = background_preprocess(pil_tgt, True)
    elif do_rm_bkg:
        pil_ref = background_preprocess(pil_ref, True)

    try:
        ans_dict = inf_single_case(model, pil_ref, pil_tgt)
    except Exception as e:
        print("Inference error:", e)
        raise gr.Error(f"Inference failed: {str(e)}")

    def safe_float(val, default=0.0):
        # Coerce model output to float; fall back for non-numeric values.
        # (Was a bare `except:`, which would also swallow KeyboardInterrupt.)
        try:
            return float(val)
        except (TypeError, ValueError):
            return float(default)

    az = safe_float(ans_dict.get('ref_az_pred', 0))
    el = safe_float(ans_dict.get('ref_el_pred', 0))
    ro = safe_float(ans_dict.get('ref_ro_pred', 0))
    # NOTE: targets are rendered with alpha=1, but the reference may differ.
    alpha = int(ans_dict.get('ref_alpha_pred', 1))

    # ===== Render axes through temporary files (removed in `finally`) =====
    tmp_ref = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
    tmp_tgt = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
    tmp_ref.close()
    tmp_tgt.close()
    try:
        # ===== Render the reference image's orientation axes =====
        axis_renderer.render_axis(az, el, ro, alpha, save_path=tmp_ref.name)
        axis_ref = Image.open(tmp_ref.name).convert("RGBA")
        # Match sizes before alpha-compositing the axes onto the photo.
        if axis_ref.size != pil_ref.size:
            pil_ref = pil_ref.resize(axis_ref.size, Image.BICUBIC)
        pil_ref_rgba = pil_ref.convert("RGBA")
        overlaid_ref = Image.alpha_composite(pil_ref_rgba, axis_ref).convert("RGB")

        # ===== Process the target image (if provided) =====
        if pil_tgt is not None:
            rel_az = safe_float(ans_dict.get('rel_az_pred', 0))
            rel_el = safe_float(ans_dict.get('rel_el_pred', 0))
            rel_ro = safe_float(ans_dict.get('rel_ro_pred', 0))
            tgt_azi, tgt_ele, tgt_rot = Get_target_azi_ele_rot(az, el, ro, rel_az, rel_el, rel_ro)
            print("Target: Azi", tgt_azi, "Ele", tgt_ele, "Rot", tgt_rot)
            # Targets use alpha=1 by convention.
            axis_renderer.render_axis(tgt_azi, tgt_ele, tgt_rot, alpha=1, save_path=tmp_tgt.name)
            axis_tgt = Image.open(tmp_tgt.name).convert("RGBA")
            if axis_tgt.size != pil_tgt.size:
                pil_tgt = pil_tgt.resize(axis_tgt.size, Image.BICUBIC)
            pil_tgt_rgba = pil_tgt.convert("RGBA")
            overlaid_tgt = Image.alpha_composite(pil_tgt_rgba, axis_tgt).convert("RGB")
        else:
            overlaid_tgt = None
            rel_az = rel_el = rel_ro = 0.0
    finally:
        # Always delete the temp files, even when rendering/inference failed.
        for tmp_path in (tmp_ref.name, tmp_tgt.name):
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
                print('cleaned {}'.format(tmp_path))

    return [
        overlaid_ref,  # reference image with axes composited on top
        overlaid_tgt,  # target image with axes composited on top (may be None)
        f"{az:.2f}",
        f"{el:.2f}",
        f"{ro:.2f}",
        str(alpha),
        f"{rel_az:.2f}",
        f"{rel_el:.2f}",
        f"{rel_ro:.2f}",
    ]
# ====== Gradio Blocks UI ======
with gr.Blocks(title="Orient-Anything-V2 Demo") as demo:
    gr.Markdown("# Orient-Anything-V2 Demo")
    gr.Markdown("Upload a **reference image** (required). Optionally upload a **target image** for relative pose.")
    with gr.Row():
        # Left column: input images (reference + target side by side)
        with gr.Column():
            with gr.Row():
                ref_img = gr.Image(
                    label="Reference Image (required)",
                    type="numpy",
                    height=256,
                    width=256,
                    value=None,
                    interactive=True
                )
                tgt_img = gr.Image(
                    label="Target Image (optional)",
                    type="numpy",
                    height=256,
                    width=256,
                    value=None,
                    interactive=True
                )
            rm_bkg = gr.Checkbox(label="Remove Background", value=True)
            run_btn = gr.Button("Run Inference", variant="primary")
            # === Clickable example inputs ===
            with gr.Row():
                gr.Examples(
                    examples=[
                        ["assets/examples/F35-0.jpg", "assets/examples/F35-1.jpg"],
                        ["assets/examples/skateboard-0.jpg", "assets/examples/skateboard-1.jpg"],
                    ],
                    inputs=[ref_img, tgt_img],
                    examples_per_page=2,
                    label="Example Inputs (click to load)"
                )
                gr.Examples(
                    examples=[
                        ["assets/examples/table-0.jpg", "assets/examples/table-1.jpg"],
                        ["assets/examples/bottle.jpg", None],
                    ],
                    inputs=[ref_img, tgt_img],
                    examples_per_page=2,
                    label=""
                )
        # Right column: result images + text outputs
        with gr.Column():
            # Result images: reference result + target result (optional)
            with gr.Row():
                res_ref_img = gr.Image(
                    label="Rendered Reference",
                    type="pil",
                    height=256,
                    width=256,
                    interactive=False
                )
                res_tgt_img = gr.Image(
                    label="Rendered Target (if provided)",
                    type="pil",
                    height=256,
                    width=256,
                    interactive=False
                )
            # Text outputs placed below the images
            with gr.Row():
                with gr.Column():
                    gr.Markdown("### Absolute Pose (Reference)")
                    az_out = gr.Textbox(label="Azimuth (0~360°)")
                    el_out = gr.Textbox(label="Polar (-90~90°)")
                    ro_out = gr.Textbox(label="Rotation (-90~90°)")
                    alpha_out = gr.Textbox(label="Number of Directions (0/1/2/4)")
                with gr.Column():
                    gr.Markdown("### Relative Pose (Target w.r.t Reference)")
                    rel_az_out = gr.Textbox(label="Relative Azimuth (0~360°)")
                    rel_el_out = gr.Textbox(label="Relative Polar (-90~90°)")
                    rel_ro_out = gr.Textbox(label="Relative Rotation (-90~90°)")
    # Wire the button click to the inference function
    run_btn.click(
        fn=run_inference,
        inputs=[ref_img, tgt_img, rm_bkg],
        outputs=[res_ref_img, res_tgt_img, az_out, el_out, ro_out, alpha_out, rel_az_out, rel_el_out, rel_ro_out],
        preprocess=True,
        postprocess=True
    )
# Launch (API disabled to avoid schema errors)
demo.launch(show_api=False)
|