import gradio as gr
import numpy as np
from PIL import Image
import torch
import os
import tempfile
from paths import *
from vision_tower import VGGT_OriAny_Ref
from inference import *
from app_utils import *
from axis_renderer import BlendRenderer
from huggingface_hub import hf_hub_download
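# Download the pretrained checkpoint from the Hugging Face Hub (cached under the current directory)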
ckpt_path = hf_hub_download(repo_id=ORIANY_V2, filename=REMOTE_CKPT_PATH, repo_type="model", cache_dir='./', resume_download=True)
print(ckpt_path)
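# Prefer bfloat16 on GPUs with compute capability >= 8 (Ampere or newer); otherwise fall back to float16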
if torch.cuda.is_available():
    mark_dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16
    device = torch.device('cuda')
else:
    mark_dtype = torch.float16
    device = torch.device('cpu')
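# Build the model, load the checkpoint weights on CPU, then move the model to the selected device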
model = VGGT_OriAny_Ref(out_dim=900, dtype=mark_dtype, nopretrain=True)
model.load_state_dict(torch.load(ckpt_path, map_location='cpu'))
model.eval()
model = model.to(device)
print('Model loaded.')
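# Renderer used to draw the orientation axes that are composited onto the input images below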
axis_renderer = BlendRenderer(RENDER_FILE)
# ====== Utility: safe image handling ======
def safe_image_input(image):
    """Return a valid numpy array, or None."""
    if image is None:
        return None
    if isinstance(image, np.ndarray):
        return image
    try:
        return np.array(image)
    except Exception:
        return None
# ====== Inference function ======
@torch.no_grad()
def run_inference(image_ref, image_tgt, do_rm_bkg):
    image_ref = safe_image_input(image_ref)
    image_tgt = safe_image_input(image_tgt)
    if image_ref is None:
        raise gr.Error("Please upload a reference image before running inference.")

    # Convert to PIL (used for background removal and the axis overlay later on)
    pil_ref = Image.fromarray(image_ref.astype(np.uint8)).convert("RGB")
    pil_tgt = None
    if image_tgt is not None:
        pil_tgt = Image.fromarray(image_tgt.astype(np.uint8)).convert("RGB")
        if do_rm_bkg:
            pil_ref = background_preprocess(pil_ref, True)
            pil_tgt = background_preprocess(pil_tgt, True)
    else:
        if do_rm_bkg:
            pil_ref = background_preprocess(pil_ref, True)

    try:
        ans_dict = inf_single_case(model, pil_ref, pil_tgt)
    except Exception as e:
        print("Inference error:", e)
        raise gr.Error(f"Inference failed: {str(e)}")
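    # Extract the predicted reference pose; safe_float tolerates missing or non-numeric entries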
    def safe_float(val, default=0.0):
        try:
            return float(val)
        except (TypeError, ValueError):
            return float(default)

    az = safe_float(ans_dict.get('ref_az_pred', 0))
    el = safe_float(ans_dict.get('ref_el_pred', 0))
    ro = safe_float(ans_dict.get('ref_ro_pred', 0))
    alpha = int(ans_dict.get('ref_alpha_pred', 1))  # Note: the target defaults to alpha=1, but the reference may not
    # ===== Write the rendered axes to temporary files =====
    tmp_ref = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
    tmp_tgt = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
    tmp_ref.close()
    tmp_tgt.close()

    try:
        # ===== Render the axes for the reference image =====
        axis_renderer.render_axis(az, el, ro, alpha, save_path=tmp_ref.name)
        axis_ref = Image.open(tmp_ref.name).convert("RGBA")

        # Composite the axes onto the reference image, making sure the sizes match
        if axis_ref.size != pil_ref.size:
            pil_ref = pil_ref.resize(axis_ref.size, Image.BICUBIC)
        pil_ref_rgba = pil_ref.convert("RGBA")
        overlaid_ref = Image.alpha_composite(pil_ref_rgba, axis_ref).convert("RGB")

        # ===== Handle the target image (if provided) =====
        if pil_tgt is not None:
            rel_az = safe_float(ans_dict.get('rel_az_pred', 0))
            rel_el = safe_float(ans_dict.get('rel_el_pred', 0))
            rel_ro = safe_float(ans_dict.get('rel_ro_pred', 0))
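            # Compose the reference pose with the predicted relative pose to get the target's absolute pose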
            tgt_azi, tgt_ele, tgt_rot = Get_target_azi_ele_rot(az, el, ro, rel_az, rel_el, rel_ro)
            print("Target: Azi", tgt_azi, "Ele", tgt_ele, "Rot", tgt_rot)

            # The target uses alpha=1 by default
            axis_renderer.render_axis(tgt_azi, tgt_ele, tgt_rot, alpha=1, save_path=tmp_tgt.name)
            axis_tgt = Image.open(tmp_tgt.name).convert("RGBA")
            if axis_tgt.size != pil_tgt.size:
                pil_tgt = pil_tgt.resize(axis_tgt.size, Image.BICUBIC)
            pil_tgt_rgba = pil_tgt.convert("RGBA")
            overlaid_tgt = Image.alpha_composite(pil_tgt_rgba, axis_tgt).convert("RGB")
        else:
            overlaid_tgt = None
            rel_az = rel_el = rel_ro = 0.0
    finally:
        # Remove the temporary files even if rendering failed
        if os.path.exists(tmp_ref.name):
            os.remove(tmp_ref.name)
            print('cleaned {}'.format(tmp_ref.name))
        if os.path.exists(tmp_tgt.name):
            os.remove(tmp_tgt.name)
            print('cleaned {}'.format(tmp_tgt.name))
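    # Ordered to match the `outputs` list wired to run_btn.click below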
    return [
        overlaid_ref,   # reference image with the rendered axes composited on top
        overlaid_tgt,   # target image with the rendered axes composited on top (None if no target)
        f"{az:.2f}",
        f"{el:.2f}",
        f"{ro:.2f}",
        str(alpha),
        f"{rel_az:.2f}",
        f"{rel_el:.2f}",
        f"{rel_ro:.2f}",
    ]
# ====== Gradio Blocks UI ======
with gr.Blocks(title="Orient-Anything-V2 Demo") as demo:
    gr.Markdown("# Orient-Anything-V2 Demo")
    gr.Markdown("Upload a **reference image** (required). Optionally upload a **target image** for relative pose.")

    with gr.Row():
        # Left: input images (reference + target in one row)
        with gr.Column():
            with gr.Row():
                ref_img = gr.Image(
                    label="Reference Image (required)",
                    type="numpy",
                    height=256,
                    width=256,
                    value=None,
                    interactive=True
                )
                tgt_img = gr.Image(
                    label="Target Image (optional)",
                    type="numpy",
                    height=256,
                    width=256,
                    value=None,
                    interactive=True
                )
            rm_bkg = gr.Checkbox(label="Remove Background", value=True)
            run_btn = gr.Button("Run Inference", variant="primary")

            # === Example inputs ===
            with gr.Row():
                gr.Examples(
                    examples=[
                        ["assets/examples/F35-0.jpg", "assets/examples/F35-1.jpg"],
                        ["assets/examples/skateboard-0.jpg", "assets/examples/skateboard-1.jpg"],
                    ],
                    inputs=[ref_img, tgt_img],
                    examples_per_page=2,
                    label="Example Inputs (click to load)"
                )
                gr.Examples(
                    examples=[
                        ["assets/examples/table-0.jpg", "assets/examples/table-1.jpg"],
                        ["assets/examples/bottle.jpg", None],
                    ],
                    inputs=[ref_img, tgt_img],
                    examples_per_page=2,
                    label=""
                )
        # Right: result images + text outputs
        with gr.Column():
            # Result images: reference result + target result (optional)
            with gr.Row():
                res_ref_img = gr.Image(
                    label="Rendered Reference",
                    type="pil",
                    height=256,
                    width=256,
                    interactive=False
                )
                res_tgt_img = gr.Image(
                    label="Rendered Target (if provided)",
                    type="pil",
                    height=256,
                    width=256,
                    interactive=False
                )

            # Text outputs below the images
            with gr.Row():
                with gr.Column():
                    gr.Markdown("### Absolute Pose (Reference)")
                    az_out = gr.Textbox(label="Azimuth (0~360°)")
                    el_out = gr.Textbox(label="Polar (-90~90°)")
                    ro_out = gr.Textbox(label="Rotation (-90~90°)")
                    alpha_out = gr.Textbox(label="Number of Directions (0/1/2/4)")
                with gr.Column():
                    gr.Markdown("### Relative Pose (Target w.r.t. Reference)")
                    rel_az_out = gr.Textbox(label="Relative Azimuth (0~360°)")
                    rel_el_out = gr.Textbox(label="Relative Polar (-90~90°)")
                    rel_ro_out = gr.Textbox(label="Relative Rotation (-90~90°)")
    # Bind the click event
    run_btn.click(
        fn=run_inference,
        inputs=[ref_img, tgt_img, rm_bkg],
        outputs=[res_ref_img, res_tgt_img, az_out, el_out, ro_out, alpha_out, rel_az_out, rel_el_out, rel_ro_out],
        preprocess=True,
        postprocess=True
    )

# Launch (API disabled to avoid schema errors)
demo.launch(show_api=False)
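# Usage note (assuming this script is saved as app.py): run `python app.py` and open the local URL
# that Gradio prints (http://127.0.0.1:7860 by default).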