zhangziang commited on
Commit
a1a407f
·
1 Parent(s): 9868529

update layout,model,example

Browse files
app.py CHANGED
@@ -2,8 +2,9 @@ import gradio as gr
2
  import numpy as np
3
  from PIL import Image
4
  import torch
 
 
5
 
6
- # ====== 你的原有导入和模型加载保持不变 ======
7
  from paths import *
8
  from vision_tower import VGGT_OriAny_Ref
9
  from inference import *
@@ -81,37 +82,52 @@ def run_inference(image_ref, image_tgt, do_rm_bkg):
81
  ro = safe_float(ans_dict.get('ref_ro_pred', 0))
82
  alpha = int(ans_dict.get('ref_alpha_pred', 1)) # 注意:target 默认 alpha=1,但 ref 可能不是
83
 
84
- # ===== 渲染参考图的坐标轴 =====
85
- axis_renderer.render_axis(az, el, ro, alpha, save_path=REF_AXIS_IMAGE)
86
- axis_ref = Image.open(REF_AXIS_IMAGE).convert("RGBA")
87
-
88
- # 叠加坐标轴到参考图
89
- # 确保尺寸一致
90
- if axis_ref.size != pil_ref.size:
91
- axis_ref = axis_ref.resize(pil_ref.size, Image.LANCZOS)
92
- pil_ref_rgba = pil_ref.convert("RGBA")
93
- overlaid_ref = Image.alpha_composite(pil_ref_rgba, axis_ref).convert("RGB")
94
-
95
- # ===== 处理目标图(如果有)=====
96
- if pil_tgt is not None:
97
- rel_az = safe_float(ans_dict.get('rel_az_pred', 0))
98
- rel_el = safe_float(ans_dict.get('rel_el_pred', 0))
99
- rel_ro = safe_float(ans_dict.get('rel_ro_pred', 0))
100
-
101
- tgt_azi, tgt_ele, tgt_rot = Get_target_azi_ele_rot(az, el, ro, rel_az, rel_el, rel_ro)
102
- print("Target: Azi",tgt_azi,"Ele",tgt_ele,"Rot",tgt_rot)
103
-
104
- # target 默认 alpha=1(根据你的说明)
105
- axis_renderer.render_axis(tgt_azi, tgt_ele, tgt_rot, alpha=1, save_path=TGT_AXIS_IMAGE)
106
- axis_tgt = Image.open(TGT_AXIS_IMAGE).convert("RGBA")
107
-
108
- if axis_tgt.size != pil_tgt.size:
109
- axis_tgt = axis_tgt.resize(pil_tgt.size, Image.LANCZOS)
110
- pil_tgt_rgba = pil_tgt.convert("RGBA")
111
- overlaid_tgt = Image.alpha_composite(pil_tgt_rgba, axis_tgt).convert("RGB")
112
- else:
113
- overlaid_tgt = None
114
- rel_az = rel_el = rel_ro = 0.0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
115
 
116
  return [
117
  overlaid_ref, # 渲染+叠加后的参考图
@@ -153,6 +169,26 @@ with gr.Blocks(title="Orient-Anything Demo") as demo:
153
  )
154
  rm_bkg = gr.Checkbox(label="Remove Background", value=True)
155
  run_btn = gr.Button("Run Inference", variant="primary")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
156
 
157
  # 右侧:结果图像 + 文本输出
158
  with gr.Column():
@@ -175,17 +211,17 @@ with gr.Blocks(title="Orient-Anything Demo") as demo:
175
 
176
  # 文本输出放在图像下方
177
  with gr.Row():
178
- with gr.Column(scale=1):
179
  gr.Markdown("### Absolute Pose (Reference)")
180
- az_out = gr.Textbox(label="Azimuth (0~360°)",scale=0.5)
181
- el_out = gr.Textbox(label="Polar (-90~90°)",scale=0.5)
182
- ro_out = gr.Textbox(label="Rotation (-90~90°)",scale=0.5)
183
- alpha_out = gr.Textbox(label="Number of Directions (0/1/2/4)",scale=0.5)
184
- with gr.Column(scale=1):
185
  gr.Markdown("### Relative Pose (Target w.r.t Reference)")
186
- rel_az_out = gr.Textbox(label="Relative Azimuth (0~360°)",scale=0.5)
187
- rel_el_out = gr.Textbox(label="Relative Polar (-90~90°)",scale=0.5)
188
- rel_ro_out = gr.Textbox(label="Relative Rotation (-90~90°)",scale=0.5)
189
 
190
  # 绑定点击事件
191
  run_btn.click(
 
2
  import numpy as np
3
  from PIL import Image
4
  import torch
5
+ import os
6
+ import tempfile
7
 
 
8
  from paths import *
9
  from vision_tower import VGGT_OriAny_Ref
10
  from inference import *
 
82
  ro = safe_float(ans_dict.get('ref_ro_pred', 0))
83
  alpha = int(ans_dict.get('ref_alpha_pred', 1)) # 注意:target 默认 alpha=1,但 ref 可能不是
84
 
85
+ # ===== 用临时文件保存渲染结果 =====
86
+ tmp_ref = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
87
+ tmp_tgt = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
88
+ tmp_ref.close()
89
+ tmp_tgt.close()
90
+
91
+ try:
92
+ # ===== 渲染参考图的坐标轴 =====
93
+ axis_renderer.render_axis(az, el, ro, alpha, save_path=tmp_ref.name)
94
+ axis_ref = Image.open(tmp_ref.name).convert("RGBA")
95
+
96
+ # 叠加坐标轴到参考图
97
+ # 确保尺寸一致
98
+ if axis_ref.size != pil_ref.size:
99
+ pil_ref = pil_ref.resize(axis_ref.size, Image.BICUBIC)
100
+ pil_ref_rgba = pil_ref.convert("RGBA")
101
+ overlaid_ref = Image.alpha_composite(pil_ref_rgba, axis_ref).convert("RGB")
102
+
103
+ # ===== 处理目标图(如果有)=====
104
+ if pil_tgt is not None:
105
+ rel_az = safe_float(ans_dict.get('rel_az_pred', 0))
106
+ rel_el = safe_float(ans_dict.get('rel_el_pred', 0))
107
+ rel_ro = safe_float(ans_dict.get('rel_ro_pred', 0))
108
+
109
+ tgt_azi, tgt_ele, tgt_rot = Get_target_azi_ele_rot(az, el, ro, rel_az, rel_el, rel_ro)
110
+ print("Target: Azi",tgt_azi,"Ele",tgt_ele,"Rot",tgt_rot)
111
+
112
+ # target 默认 alpha=1(根据你的说明)
113
+ axis_renderer.render_axis(tgt_azi, tgt_ele, tgt_rot, alpha=1, save_path=tmp_tgt.name)
114
+ axis_tgt = Image.open(tmp_tgt.name).convert("RGBA")
115
+
116
+ if axis_tgt.size != pil_tgt.size:
117
+ pil_tgt = pil_tgt.resize(axis_tgt.size, Image.BICUBIC)
118
+ pil_tgt_rgba = pil_tgt.convert("RGBA")
119
+ overlaid_tgt = Image.alpha_composite(pil_tgt_rgba, axis_tgt).convert("RGB")
120
+ else:
121
+ overlaid_tgt = None
122
+ rel_az = rel_el = rel_ro = 0.0
123
+ finally:
124
+ # 安全删除临时文件(即使出错也清理)
125
+ if os.path.exists(tmp_ref.name):
126
+ os.remove(tmp_ref.name)
127
+ print('cleaned {}'.format(tmp_ref.name))
128
+ if os.path.exists(tmp_tgt.name):
129
+ os.remove(tmp_tgt.name)
130
+ print('cleaned {}'.format(tmp_tgt.name))
131
 
132
  return [
133
  overlaid_ref, # 渲染+叠加后的参考图
 
169
  )
170
  rm_bkg = gr.Checkbox(label="Remove Background", value=True)
171
  run_btn = gr.Button("Run Inference", variant="primary")
172
+ # === 在这里插入示例 ===
173
+ with gr.Row():
174
+ gr.Examples(
175
+ examples=[
176
+ ["assets/examples/F35-0.jpg", "assets/examples/F35-1.jpg"],
177
+ ["assets/examples/skateboard-0.jpg", "assets/examples/skateboard-1.jpg"],
178
+ ],
179
+ inputs=[ref_img, tgt_img],
180
+ examples_per_page=2,
181
+ label="Example Inputs (click to load)"
182
+ )
183
+ gr.Examples(
184
+ examples=[
185
+ ["assets/examples/table-0.jpg", "assets/examples/table-1.jpg"],
186
+ ["assets/examples/bottle.jpg", None],
187
+ ],
188
+ inputs=[ref_img, tgt_img],
189
+ examples_per_page=2,
190
+ label=""
191
+ )
192
 
193
  # 右侧:结果图像 + 文本输出
194
  with gr.Column():
 
211
 
212
  # 文本输出放在图像下方
213
  with gr.Row():
214
+ with gr.Column():
215
  gr.Markdown("### Absolute Pose (Reference)")
216
+ az_out = gr.Textbox(label="Azimuth (0~360°)")
217
+ el_out = gr.Textbox(label="Polar (-90~90°)")
218
+ ro_out = gr.Textbox(label="Rotation (-90~90°)")
219
+ alpha_out = gr.Textbox(label="Number of Directions (0/1/2/4)")
220
+ with gr.Column():
221
  gr.Markdown("### Relative Pose (Target w.r.t Reference)")
222
+ rel_az_out = gr.Textbox(label="Relative Azimuth (0~360°)")
223
+ rel_el_out = gr.Textbox(label="Relative Polar (-90~90°)")
224
+ rel_ro_out = gr.Textbox(label="Relative Rotation (-90~90°)")
225
 
226
  # 绑定点击事件
227
  run_btn.click(
assets/{axis_ref.png → examples/F35-0.jpg} RENAMED
File without changes
assets/{axis_tgt.png → examples/F35-1.jpg} RENAMED
File without changes
assets/examples/bottle.jpg ADDED

Git LFS Details

  • SHA256: a06e763afda918b0bbd5d5fe248e09b88bdeb72cd10c0343babf4eb14209a062
  • Pointer size: 130 Bytes
  • Size of remote file: 20.1 kB
assets/examples/hat.jpg ADDED

Git LFS Details

  • SHA256: 4fec959db7ee5ccad6cb20e74ffa68ef697bf0a1aa207ccc694b1faf29fccbcb
  • Pointer size: 131 Bytes
  • Size of remote file: 132 kB
assets/examples/skateboard-0.jpg ADDED

Git LFS Details

  • SHA256: 0f5783079542c8aec34b3ce041245f0e2050484c7fdabe7658cefa9fca05c078
  • Pointer size: 131 Bytes
  • Size of remote file: 258 kB
assets/examples/skateboard-1.jpg ADDED

Git LFS Details

  • SHA256: e88dbf8357d1d83960004e8a4723f5b9e4aee1a4eb4b3dc5384b75467019d576
  • Pointer size: 130 Bytes
  • Size of remote file: 65 kB
assets/examples/table-0.jpg ADDED

Git LFS Details

  • SHA256: 9eee2a881ea0be52bf5381b0948a0eea9bb173b4982dbd5e867b95b2368e7031
  • Pointer size: 129 Bytes
  • Size of remote file: 4.81 kB
assets/examples/table-1.jpg ADDED

Git LFS Details

  • SHA256: 590d5702941a56a3296f4c86cc7f8110ce8dae4b6103002ce0993cfb24b82489
  • Pointer size: 129 Bytes
  • Size of remote file: 4.6 kB
inference.py CHANGED
@@ -51,11 +51,11 @@ def val_fit_alpha(distribute):
51
  mu_fit, kappa_fit = saved_params[max_index]
52
  r_squared = saved_r_squared[max_index]
53
 
54
- if alpha == 1. and kappa_fit>=0.5 and r_squared>=0.5:
55
  pass
56
- elif alpha == 2. and kappa_fit>=0.35 and r_squared>=0.35:
57
  pass
58
- elif alpha == 4. and kappa_fit>=0.25 and r_squared>=0.25:
59
  pass
60
  else:
61
  alpha=0.
 
51
  mu_fit, kappa_fit = saved_params[max_index]
52
  r_squared = saved_r_squared[max_index]
53
 
54
+ if alpha == 1. and kappa_fit>=0.6 and r_squared>=0.45:
55
  pass
56
+ elif alpha == 2. and kappa_fit>=0.4 and r_squared>=0.45:
57
  pass
58
+ elif alpha == 4. and kappa_fit>=0.25 and r_squared>=0.45:
59
  pass
60
  else:
61
  alpha=0.
orianyV2_demo.ipynb DELETED
The diff for this file is too large to render. See raw diff
 
paths.py CHANGED
@@ -7,10 +7,7 @@ VGGT_1B = "facebook/VGGT-1B"
7
 
8
  ORIANY_V2 = "Viglong/OriAnyV2_ckpt"
9
 
10
- REMOTE_CKPT_PATH = "demo_ckpts/acc8mask20lowlr.pt"
11
-
12
 
13
  RENDER_FILE = "assets/axis_render.blend"
14
- REF_AXIS_IMAGE = "assets/axis_ref.png"
15
- TGT_AXIS_IMAGE = "assets/axis_tgt.png"
16
 
 
7
 
8
  ORIANY_V2 = "Viglong/OriAnyV2_ckpt"
9
 
10
+ REMOTE_CKPT_PATH = "demo_ckpts/rotmod_realrotaug_best.pt"
 
11
 
12
  RENDER_FILE = "assets/axis_render.blend"
 
 
13