Spaces:
Running
Running
zhang-ziang
committed on
Commit
·
6965bae
1
Parent(s):
0f72f6a
confidence added
Browse files
app.py
CHANGED
|
@@ -10,6 +10,7 @@ import io
|
|
| 10 |
from PIL import Image
|
| 11 |
import rembg
|
| 12 |
from typing import Any
|
|
|
|
| 13 |
|
| 14 |
|
| 15 |
from huggingface_hub import hf_hub_download
|
|
@@ -107,11 +108,31 @@ def get_3angle(image):
|
|
| 107 |
gaus_ax_pred = torch.argmax(dino_pred[:, 0:360], dim=-1)
|
| 108 |
gaus_pl_pred = torch.argmax(dino_pred[:, 360:360+180], dim=-1)
|
| 109 |
gaus_ro_pred = torch.argmax(dino_pred[:, 360+180:360+180+60], dim=-1)
|
| 110 |
-
|
|
|
|
| 111 |
angles[0] = gaus_ax_pred
|
| 112 |
angles[1] = gaus_pl_pred - 90
|
| 113 |
angles[2] = gaus_ro_pred - 30
|
|
|
|
|
|
|
|
|
|
|
|
|
| 114 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 115 |
return angles
|
| 116 |
|
| 117 |
def scale(x):
|
|
@@ -145,10 +166,13 @@ def figure_to_img(fig):
|
|
| 145 |
image = Image.open(buf).copy()
|
| 146 |
return image
|
| 147 |
|
| 148 |
-
def infer_func(img, do_rm_bkg):
|
| 149 |
img = Image.fromarray(img)
|
| 150 |
img = background_preprocess(img, do_rm_bkg)
|
| 151 |
-
|
|
|
|
|
|
|
|
|
|
| 152 |
|
| 153 |
fig, ax = plt.subplots(figsize=(8, 8))
|
| 154 |
|
|
@@ -197,21 +221,23 @@ def infer_func(img, do_rm_bkg):
|
|
| 197 |
|
| 198 |
res_img = figure_to_img(fig)
|
| 199 |
# axis_model = "axis.obj"
|
| 200 |
-
return [res_img, float(angles[0]), float(angles[1]), float(angles[2])]
|
| 201 |
|
| 202 |
# Gradio front end: wires `infer_func` to one image upload plus a
# background-removal toggle, and shows the rendered result image together
# with the three predicted angles as text.
server = gr.Interface(
    fn=infer_func,
    inputs=[
        gr.Image(height=512, width=512, label="upload your image"),
        gr.Checkbox(label="Remove Background", value=True),
    ],
    outputs=[
        gr.Image(height=512, width=512, label="result image"),
        # gr.Model3D(clear_color=[0.0, 0.0, 0.0, 0.0], label="3D Model"),
        gr.Textbox(lines=1, label='Azimuth(0~360°)'),
        gr.Textbox(lines=1, label='Polar(-90~90°)'),
        # NOTE(review): the rotation prediction in this file is an argmax over
        # 60 bins minus 30, so the displayed range may not match this label --
        # confirm before changing the user-facing text.
        gr.Textbox(lines=1, label='Rotation(-90~90°)'),
    ],
    flagging_mode='never',
)
|
| 217 |
|
|
|
|
| 10 |
from PIL import Image
|
| 11 |
import rembg
|
| 12 |
from typing import Any
|
| 13 |
+
import torch.nn.functional as F
|
| 14 |
|
| 15 |
|
| 16 |
from huggingface_hub import hf_hub_download
|
|
|
|
| 108 |
gaus_ax_pred = torch.argmax(dino_pred[:, 0:360], dim=-1)
|
| 109 |
gaus_pl_pred = torch.argmax(dino_pred[:, 360:360+180], dim=-1)
|
| 110 |
gaus_ro_pred = torch.argmax(dino_pred[:, 360+180:360+180+60], dim=-1)
|
| 111 |
+
confidence = F.softmax(dino_pred[:, -2:], dim=-1)[0]
|
| 112 |
+
angles = torch.zeros(4)
|
| 113 |
angles[0] = gaus_ax_pred
|
| 114 |
angles[1] = gaus_pl_pred - 90
|
| 115 |
angles[2] = gaus_ro_pred - 30
|
| 116 |
+
angles[3] = confidence
|
| 117 |
+
return angles
|
| 118 |
+
|
| 119 |
+
def get_3angle_infer_aug(image):
    """Predict three orientation angles and a confidence score for an image.

    The image is run through ``val_preprocess``, its pixel values are moved
    to ``device``, and the result is fed to ``dino`` with gradients disabled.

    Args:
        image: Image in whatever form ``val_preprocess(images=...)`` accepts
            (presumably a PIL image, as produced by the caller -- verify).

    Returns:
        A 1-D float tensor of length 4:
        ``[azimuth, polar, rotation, confidence]`` where azimuth is the
        argmax over prediction columns 0..359, polar is the argmax over the
        next 180 columns minus 90, rotation is the argmax over the next 60
        columns minus 30, and confidence is one softmax probability computed
        from the last two prediction columns.
    """
    image_inputs = val_preprocess(images=image)
    pixel_values = np.array(image_inputs['pixel_values'])
    image_inputs['pixel_values'] = torch.from_numpy(pixel_values).to(device)

    with torch.no_grad():
        dino_pred = dino(image_inputs)

    # Each angle is the index of the highest-scoring bin in its slice.
    gaus_ax_pred = torch.argmax(dino_pred[:, 0:360], dim=-1)
    gaus_pl_pred = torch.argmax(dino_pred[:, 360:360 + 180], dim=-1)
    gaus_ro_pred = torch.argmax(dino_pred[:, 360 + 180:360 + 180 + 60], dim=-1)

    # The softmax over the last two logits yields two probabilities per
    # sample. `angles[3]` is a single scalar slot, so storing the whole
    # 2-element row there raises a RuntimeError; select one entry instead.
    # NOTE(review): index 0 is assumed to be the "confident" class -- confirm
    # against the model head's training labels.
    confidence = F.softmax(dino_pred[:, -2:], dim=-1)[0][0]

    angles = torch.zeros(4)
    angles[0] = gaus_ax_pred
    angles[1] = gaus_pl_pred - 90   # shift bin index to a signed angle
    angles[2] = gaus_ro_pred - 30   # shift bin index to a signed angle
    angles[3] = confidence
    return angles
|
| 137 |
|
| 138 |
def scale(x):
|
|
|
|
| 166 |
image = Image.open(buf).copy()
|
| 167 |
return image
|
| 168 |
|
| 169 |
+
def infer_func(img, do_rm_bkg, do_infer_aug):
|
| 170 |
img = Image.fromarray(img)
|
| 171 |
img = background_preprocess(img, do_rm_bkg)
|
| 172 |
+
if do_infer_aug:
|
| 173 |
+
angles = get_3angle_infer_aug(img)
|
| 174 |
+
else:
|
| 175 |
+
angles = get_3angle(img)
|
| 176 |
|
| 177 |
fig, ax = plt.subplots(figsize=(8, 8))
|
| 178 |
|
|
|
|
| 221 |
|
| 222 |
res_img = figure_to_img(fig)
|
| 223 |
# axis_model = "axis.obj"
|
| 224 |
+
return [res_img, float(angles[0]), float(angles[1]), float(angles[2]), float(angles[3])]
|
| 225 |
|
| 226 |
# Gradio front end: wires `infer_func` to one image upload plus two toggles
# (background removal, inference-time augmentation), and shows the rendered
# result image together with the three predicted angles and the confidence
# score as text.
server = gr.Interface(
    fn=infer_func,
    inputs=[
        gr.Image(height=512, width=512, label="upload your image"),
        gr.Checkbox(label="Remove Background", value=True),
        gr.Checkbox(label="Inference time augmentation", value=False),
    ],
    outputs=[
        gr.Image(height=512, width=512, label="result image"),
        # gr.Model3D(clear_color=[0.0, 0.0, 0.0, 0.0], label="3D Model"),
        gr.Textbox(lines=1, label='Azimuth(0~360°)'),
        gr.Textbox(lines=1, label='Polar(-90~90°)'),
        # NOTE(review): the rotation prediction in this file is an argmax over
        # 60 bins minus 30, so the displayed range may not match this label --
        # confirm before changing the user-facing text.
        gr.Textbox(lines=1, label='Rotation(-90~90°)'),
        gr.Textbox(lines=1, label='Confidence(0~1)'),
    ],
    flagging_mode='never',
)
|
| 243 |
|