Spaces: Running on CPU Upgrade
| import gradio as gr | |
| import numpy as np | |
| from PIL import Image | |
| import torch | |
| import os | |
| import tempfile | |
| from paths import * | |
| from vision_tower import VGGT_OriAny_Ref | |
| from inference import * | |
| from app_utils import * | |
| from axis_renderer import BlendRenderer | |
| from huggingface_hub import hf_hub_download | |
# ====== Module-level setup: download checkpoint, pick device/dtype, load model ======
# Fetch the checkpoint from the Hub (resumable; cached under ./ so Space restarts reuse it).
ckpt_path = hf_hub_download(repo_id=ORIANY_V2, filename=REMOTE_CKPT_PATH, repo_type="model", cache_dir='./', resume_download=True)
print(ckpt_path)
if torch.cuda.is_available():
    # bfloat16 needs compute capability >= 8 (Ampere+); older GPUs fall back to float16.
    mark_dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16
    device = torch.device('cuda')
else:
    # NOTE(review): float16 on CPU is unusual (CPU kernels often expect float32) — confirm intended.
    mark_dtype = torch.float16
    device = torch.device('cpu')
# nopretrain=True: skip downloading backbone weights, since the full state dict is loaded below.
model = VGGT_OriAny_Ref(out_dim=900, dtype=mark_dtype, nopretrain=True)
model.load_state_dict(torch.load(ckpt_path, map_location='cpu'))
model.eval()
model = model.to(device)
print('Model loaded.')
# Renderer that draws the orientation axes (backed by RENDER_FILE).
axis_renderer = BlendRenderer(RENDER_FILE)
# ====== Helper: defensive image coercion ======
def safe_image_input(image):
    """Coerce *image* into a numpy array; return None when that is impossible."""
    if isinstance(image, np.ndarray):
        return image
    if image is None:
        return None
    try:
        converted = np.array(image)
    except Exception:
        return None
    return converted
# ====== Inference entry point ======
def run_inference(image_ref, image_tgt, do_rm_bkg):
    """Predict orientation for a reference (and optional target) image and overlay axes.

    Args:
        image_ref: reference image (numpy array or PIL-convertible); required.
        image_tgt: optional target image; when given, relative pose is also predicted.
        do_rm_bkg: when True, remove the background before inference.

    Returns:
        [overlaid_ref, overlaid_tgt, az, el, ro, alpha, rel_az, rel_el, rel_ro] —
        overlaid images (overlaid_tgt may be None) followed by formatted pose strings.

    Raises:
        gr.Error: when the reference image is missing or inference fails.
    """
    image_ref = safe_image_input(image_ref)
    image_tgt = safe_image_input(image_tgt)
    if image_ref is None:
        raise gr.Error("Please upload a reference image before running inference.")

    # Convert to PIL (needed for background removal and the axis overlay below).
    pil_ref = Image.fromarray(image_ref.astype(np.uint8)).convert("RGB")
    pil_tgt = None
    if image_tgt is not None:
        pil_tgt = Image.fromarray(image_tgt.astype(np.uint8)).convert("RGB")

    # Single background-removal path (the original duplicated this in two branches).
    if do_rm_bkg:
        pil_ref = background_preprocess(pil_ref, True)
        if pil_tgt is not None:
            pil_tgt = background_preprocess(pil_tgt, True)

    try:
        ans_dict = inf_single_case(model, pil_ref, pil_tgt)
    except Exception as e:
        print("Inference error:", e)
        raise gr.Error(f"Inference failed: {str(e)}")

    def safe_float(val, default=0.0):
        # Narrow except: a bare `except:` would also swallow KeyboardInterrupt/SystemExit.
        try:
            return float(val)
        except (TypeError, ValueError):
            return float(default)

    az = safe_float(ans_dict.get('ref_az_pred', 0))
    el = safe_float(ans_dict.get('ref_el_pred', 0))
    ro = safe_float(ans_dict.get('ref_ro_pred', 0))
    # NOTE: the target is rendered with alpha=1 by convention, but the reference may not be.
    # Go through safe_float so a malformed prediction cannot crash the int() conversion.
    alpha = int(safe_float(ans_dict.get('ref_alpha_pred', 1), 1))

    # ===== Render axes into temporary files (always cleaned up in `finally`) =====
    tmp_ref = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
    tmp_tgt = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
    tmp_ref.close()
    tmp_tgt.close()
    try:
        # ===== Render the reference axes and composite them over the reference image =====
        axis_renderer.render_axis(az, el, ro, alpha, save_path=tmp_ref.name)
        axis_ref = Image.open(tmp_ref.name).convert("RGBA")
        # Sizes must match for alpha_composite; resize the photo to the rendered axes.
        if axis_ref.size != pil_ref.size:
            pil_ref = pil_ref.resize(axis_ref.size, Image.BICUBIC)
        overlaid_ref = Image.alpha_composite(pil_ref.convert("RGBA"), axis_ref).convert("RGB")

        # ===== Target image (if provided) =====
        if pil_tgt is not None:
            rel_az = safe_float(ans_dict.get('rel_az_pred', 0))
            rel_el = safe_float(ans_dict.get('rel_el_pred', 0))
            rel_ro = safe_float(ans_dict.get('rel_ro_pred', 0))
            tgt_azi, tgt_ele, tgt_rot = Get_target_azi_ele_rot(az, el, ro, rel_az, rel_el, rel_ro)
            print("Target: Azi", tgt_azi, "Ele", tgt_ele, "Rot", tgt_rot)
            # Target is always rendered with alpha=1.
            axis_renderer.render_axis(tgt_azi, tgt_ele, tgt_rot, alpha=1, save_path=tmp_tgt.name)
            axis_tgt = Image.open(tmp_tgt.name).convert("RGBA")
            if axis_tgt.size != pil_tgt.size:
                pil_tgt = pil_tgt.resize(axis_tgt.size, Image.BICUBIC)
            overlaid_tgt = Image.alpha_composite(pil_tgt.convert("RGBA"), axis_tgt).convert("RGB")
        else:
            overlaid_tgt = None
            rel_az = rel_el = rel_ro = 0.0
    finally:
        # Remove temp files even when rendering fails (single loop, no duplicated cleanup).
        for tmp_name in (tmp_ref.name, tmp_tgt.name):
            if os.path.exists(tmp_name):
                os.remove(tmp_name)
                print('cleaned {}'.format(tmp_name))

    return [
        overlaid_ref,        # reference image with axes composited on top
        overlaid_tgt,        # target image with axes (None when no target was given)
        f"{az:.2f}",
        f"{el:.2f}",
        f"{ro:.2f}",
        str(alpha),
        f"{rel_az:.2f}",
        f"{rel_el:.2f}",
        f"{rel_ro:.2f}",
    ]
# ====== Gradio Blocks UI ======
with gr.Blocks(title="Orient-Anything-V2 Demo") as demo:
    gr.Markdown("# Orient-Anything-V2 Demo")
    gr.Markdown("Upload a **reference image** (required). Optionally upload a **target image** for relative pose.")
    with gr.Row():
        # Left column: input images (reference + target side by side) and controls.
        with gr.Column():
            with gr.Row():
                ref_img = gr.Image(
                    label="Reference Image (required)",
                    type="numpy",
                    height=256,
                    width=256,
                    value=None,
                    interactive=True
                )
                tgt_img = gr.Image(
                    label="Target Image (optional)",
                    type="numpy",
                    height=256,
                    width=256,
                    value=None,
                    interactive=True
                )
            rm_bkg = gr.Checkbox(label="Remove Background", value=True)
            run_btn = gr.Button("Run Inference", variant="primary")
            # === Clickable example inputs ===
            with gr.Row():
                gr.Examples(
                    examples=[
                        ["assets/examples/F35-0.jpg", "assets/examples/F35-1.jpg"],
                        ["assets/examples/skateboard-0.jpg", "assets/examples/skateboard-1.jpg"],
                    ],
                    inputs=[ref_img, tgt_img],
                    examples_per_page=2,
                    label="Example Inputs (click to load)"
                )
                gr.Examples(
                    examples=[
                        ["assets/examples/table-0.jpg", "assets/examples/table-1.jpg"],
                        ["assets/examples/bottle.jpg", None],
                    ],
                    inputs=[ref_img, tgt_img],
                    examples_per_page=2,
                    label=""
                )
        # Right column: rendered result images + text outputs.
        with gr.Column():
            # Result images: reference result + (optional) target result.
            with gr.Row():
                res_ref_img = gr.Image(
                    label="Rendered Reference",
                    type="pil",
                    height=256,
                    width=256,
                    interactive=False
                )
                res_tgt_img = gr.Image(
                    label="Rendered Target (if provided)",
                    type="pil",
                    height=256,
                    width=256,
                    interactive=False
                )
            # Text outputs below the images; order matches run_inference's return list.
            with gr.Row():
                with gr.Column():
                    gr.Markdown("### Absolute Pose (Reference)")
                    az_out = gr.Textbox(label="Azimuth (0~360°)")
                    el_out = gr.Textbox(label="Polar (-90~90°)")
                    ro_out = gr.Textbox(label="Rotation (-90~90°)")
                    alpha_out = gr.Textbox(label="Number of Directions (0/1/2/4)")
                with gr.Column():
                    gr.Markdown("### Relative Pose (Target w.r.t Reference)")
                    rel_az_out = gr.Textbox(label="Relative Azimuth (0~360°)")
                    rel_el_out = gr.Textbox(label="Relative Polar (-90~90°)")
                    rel_ro_out = gr.Textbox(label="Relative Rotation (-90~90°)")
    # Wire the button click to the inference function.
    run_btn.click(
        fn=run_inference,
        inputs=[ref_img, tgt_img, rm_bkg],
        outputs=[res_ref_img, res_tgt_img, az_out, el_out, ro_out, alpha_out, rel_az_out, rel_el_out, rel_ro_out],
        preprocess=True,
        postprocess=True
    )
# Launch with the API disabled (avoids schema errors).
demo.launch(show_api=False)