import gradio as gr
import numpy as np
from PIL import Image
import torch
import os
import tempfile

from paths import *
from vision_tower import VGGT_OriAny_Ref
from inference import *
from app_utils import *
from axis_renderer import BlendRenderer

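# Download the checkpoint from the Hugging Face Hub into the local cache.
# (resume_download is deprecated in recent huggingface_hub releases, where
# interrupted downloads always resume; it is kept here for older versions.)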
from huggingface_hub import hf_hub_download
ckpt_path = hf_hub_download(repo_id=ORIANY_V2, filename=REMOTE_CKPT_PATH, repo_type="model", cache_dir='./', resume_download=True)
print(ckpt_path)

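# bfloat16 requires compute capability >= 8 (Ampere or newer); older GPUs
# fall back to float16, and CPU-only machines also use float16.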
if torch.cuda.is_available():
    mark_dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16
    device = torch.device('cuda')
else:
    mark_dtype = torch.float16
    device = torch.device('cpu')

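# Build the model skeleton (nopretrain=True presumably skips the backbone's
# pretrained init, since the checkpoint below already contains all weights),
# load the checkpoint on CPU, then move the model to the target device.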
model = VGGT_OriAny_Ref(out_dim=900, dtype=mark_dtype, nopretrain=True)
model.load_state_dict(torch.load(ckpt_path, map_location='cpu'))
model.eval()
model = model.to(device)
print('Model loaded.')

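# BlendRenderer renders RGBA axis images from RENDER_FILE; judging by its use
# below, render_axis() writes a PNG for a given (az, el, ro, alpha) pose.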
axis_renderer = BlendRenderer(RENDER_FILE)


# ====== Utility: safe image handling ======
def safe_image_input(image):
    """Return a valid numpy array, or None if the input cannot be converted."""
    if image is None:
        return None
    if isinstance(image, np.ndarray):
        return image
    try:
        return np.array(image)
    except Exception:
        return None


# ====== Inference ======
@torch.no_grad()
def run_inference(image_ref, image_tgt, do_rm_bkg):
    image_ref = safe_image_input(image_ref)
    image_tgt = safe_image_input(image_tgt)

    if image_ref is None:
        raise gr.Error("Please upload a reference image before running inference.")

    # Convert to PIL (used for background removal and the axis overlay later)
    pil_ref = Image.fromarray(image_ref.astype(np.uint8)).convert("RGB")
    pil_tgt = None

    if image_tgt is not None:
        pil_tgt = Image.fromarray(image_tgt.astype(np.uint8)).convert("RGB")
        if do_rm_bkg:
            pil_ref = background_preprocess(pil_ref, True)
            pil_tgt = background_preprocess(pil_tgt, True)
    else:
        if do_rm_bkg:
            pil_ref = background_preprocess(pil_ref, True)

    try:
        ans_dict = inf_single_case(model, pil_ref, pil_tgt)
    except Exception as e:
        print("Inference error:", e)
        raise gr.Error(f"Inference failed: {str(e)}")

    def safe_float(val, default=0.0):
        try:
            return float(val)
        except (TypeError, ValueError):
            return float(default)

    az = safe_float(ans_dict.get('ref_az_pred', 0))
    el = safe_float(ans_dict.get('ref_el_pred', 0))
    ro = safe_float(ans_dict.get('ref_ro_pred', 0))
    alpha = int(ans_dict.get('ref_alpha_pred', 1))  # Note: the target defaults to alpha=1, but the reference may not

    # ===== Save the rendered axes to temporary files =====
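    # delete=False keeps the files on disk after close() so the renderer can
    # write to the paths; the finally block below removes them manually.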
    tmp_ref = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
    tmp_tgt = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
    tmp_ref.close()
    tmp_tgt.close()

    try:
        # ===== Render the axis overlay for the reference image =====
        axis_renderer.render_axis(az, el, ro, alpha, save_path=tmp_ref.name)
        axis_ref = Image.open(tmp_ref.name).convert("RGBA")

        # Composite the axis overlay onto the reference image,
        # resizing the photo first so the two layers match in size.
        if axis_ref.size != pil_ref.size:
            pil_ref = pil_ref.resize(axis_ref.size, Image.BICUBIC)
        pil_ref_rgba = pil_ref.convert("RGBA")
        overlaid_ref = Image.alpha_composite(pil_ref_rgba, axis_ref).convert("RGB")

        # ===== Handle the target image (if provided) =====
        if pil_tgt is not None:
            rel_az = safe_float(ans_dict.get('rel_az_pred', 0))
            rel_el = safe_float(ans_dict.get('rel_el_pred', 0))
            rel_ro = safe_float(ans_dict.get('rel_ro_pred', 0))

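            # Compose the reference pose with the predicted relative rotation
            # to get the target's absolute pose.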
            tgt_azi, tgt_ele, tgt_rot = Get_target_azi_ele_rot(az, el, ro, rel_az, rel_el, rel_ro)
            print("Target: Azi",tgt_azi,"Ele",tgt_ele,"Rot",tgt_rot)
            
            # The target is always rendered with alpha=1
            axis_renderer.render_axis(tgt_azi, tgt_ele, tgt_rot, alpha=1, save_path=tmp_tgt.name)
            axis_tgt = Image.open(tmp_tgt.name).convert("RGBA")

            if axis_tgt.size != pil_tgt.size:
                pil_tgt = pil_tgt.resize(axis_tgt.size, Image.BICUBIC)
            pil_tgt_rgba = pil_tgt.convert("RGBA")
            overlaid_tgt = Image.alpha_composite(pil_tgt_rgba, axis_tgt).convert("RGB")
        else:
            overlaid_tgt = None
            rel_az = rel_el = rel_ro = 0.0
    finally:
        # Always remove the temp files, even if rendering failed
        if os.path.exists(tmp_ref.name):
            os.remove(tmp_ref.name)
            print('cleaned {}'.format(tmp_ref.name))
        if os.path.exists(tmp_tgt.name):
            os.remove(tmp_tgt.name)
            print('cleaned {}'.format(tmp_tgt.name))

    return [
        overlaid_ref,  # reference image with the rendered axes composited on top
        overlaid_tgt,  # target image with rendered axes (may be None)
        f"{az:.2f}",
        f"{el:.2f}",
        f"{ro:.2f}",
        str(alpha),
        f"{rel_az:.2f}",
        f"{rel_el:.2f}",
        f"{rel_ro:.2f}",
    ]


# ====== Gradio Blocks UI ======
with gr.Blocks(title="Orient-Anything-V2 Demo") as demo:
    gr.Markdown("# Orient-Anything-V2 Demo")
    gr.Markdown("Upload a **reference image** (required). Optionally upload a **target image** for relative pose.")

    with gr.Row():
        # Left column: input images (reference + target, side by side)
        with gr.Column():
            with gr.Row():
                ref_img = gr.Image(
                    label="Reference Image (required)",
                    type="numpy",
                    height=256,
                    width=256,
                    value=None,
                    interactive=True
                )
                tgt_img = gr.Image(
                    label="Target Image (optional)",
                    type="numpy",
                    height=256,
                    width=256,
                    value=None,
                    interactive=True
                )
            rm_bkg = gr.Checkbox(label="Remove Background", value=True)
            run_btn = gr.Button("Run Inference", variant="primary")
            # === Example inputs ===
            with gr.Row():
                gr.Examples(
                    examples=[
                        ["assets/examples/F35-0.jpg", "assets/examples/F35-1.jpg"],
                        ["assets/examples/skateboard-0.jpg", "assets/examples/skateboard-1.jpg"],
                    ],
                    inputs=[ref_img, tgt_img],
                    examples_per_page=2,
                    label="Example Inputs (click to load)"
                )
                gr.Examples(
                    examples=[
                        ["assets/examples/table-0.jpg", "assets/examples/table-1.jpg"],
                        ["assets/examples/bottle.jpg", None],
                    ],
                    inputs=[ref_img, tgt_img],
                    examples_per_page=2,
                    label=""
                )

        # Right column: result images + text outputs
        with gr.Column():
            # Result images: reference result + target result (optional)
            with gr.Row():
                res_ref_img = gr.Image(
                    label="Rendered Reference",
                    type="pil",
                    height=256,
                    width=256,
                    interactive=False
                )
                res_tgt_img = gr.Image(
                    label="Rendered Target (if provided)",
                    type="pil",
                    height=256,
                    width=256,
                    interactive=False
                )
                
            # Text outputs go below the images
            with gr.Row():
                with gr.Column():
                    gr.Markdown("### Absolute Pose (Reference)")
                    az_out = gr.Textbox(label="Azimuth (0~360°)")
                    el_out = gr.Textbox(label="Polar (-90~90°)")
                    ro_out = gr.Textbox(label="Rotation (-90~90°)")
                    alpha_out = gr.Textbox(label="Number of Directions (0/1/2/4)")
                with gr.Column():
                    gr.Markdown("### Relative Pose (Target w.r.t Reference)")
                    rel_az_out = gr.Textbox(label="Relative Azimuth (0~360°)")
                    rel_el_out = gr.Textbox(label="Relative Polar (-90~90°)")
                    rel_ro_out = gr.Textbox(label="Relative Rotation (-90~90°)")

    # Wire the button click to the inference function
    run_btn.click(
        fn=run_inference,
        inputs=[ref_img, tgt_img, rm_bkg],
        outputs=[res_ref_img, res_tgt_img, az_out, el_out, ro_out, alpha_out, rel_az_out, rel_el_out, rel_ro_out],
        preprocess=True,
        postprocess=True
    )

# Launch (disable the API to avoid schema errors)
demo.launch(show_api=False)