Spaces:
Running
Running
zhang-ziang
commited on
Commit
·
0366edb
1
Parent(s):
74503df
infer aug + remove bkg
Browse files
app.py
CHANGED
|
@@ -210,10 +210,10 @@ def remove_outliers_and_average_circular(tensor, threshold=1.5):
|
|
| 210 |
|
| 211 |
return mean_angle
|
| 212 |
|
| 213 |
-
def get_3angle_infer_aug(
|
| 214 |
|
| 215 |
# image = Image.open(image_path).convert('RGB')
|
| 216 |
-
image = get_crop_images(
|
| 217 |
image_inputs = val_preprocess(images = image)
|
| 218 |
image_inputs['pixel_values'] = torch.from_numpy(np.array(image_inputs['pixel_values'])).to(device)
|
| 219 |
with torch.no_grad():
|
|
@@ -267,21 +267,22 @@ def figure_to_img(fig):
|
|
| 267 |
return image
|
| 268 |
|
| 269 |
def infer_func(img, do_rm_bkg, do_infer_aug):
|
| 270 |
-
|
| 271 |
-
img = background_preprocess(img, do_rm_bkg)
|
| 272 |
if do_infer_aug:
|
| 273 |
-
|
|
|
|
| 274 |
else:
|
| 275 |
-
|
|
|
|
| 276 |
|
| 277 |
fig, ax = plt.subplots(figsize=(8, 8))
|
| 278 |
|
| 279 |
-
w, h =
|
| 280 |
if h>w:
|
| 281 |
extent = [-5*w/h, 5*w/h, -5, 5]
|
| 282 |
else:
|
| 283 |
extent = [-5, 5, -5*h/w, 5*h/w]
|
| 284 |
-
ax.imshow(
|
| 285 |
|
| 286 |
origin = np.array([0, 0])
|
| 287 |
|
|
|
|
| 210 |
|
| 211 |
return mean_angle
|
| 212 |
|
| 213 |
+
def get_3angle_infer_aug(origin_img, rm_bkg_img):
|
| 214 |
|
| 215 |
# image = Image.open(image_path).convert('RGB')
|
| 216 |
+
image = get_crop_images(origin_img, num=3) + get_crop_images(rm_bkg_img, num=3)
|
| 217 |
image_inputs = val_preprocess(images = image)
|
| 218 |
image_inputs['pixel_values'] = torch.from_numpy(np.array(image_inputs['pixel_values'])).to(device)
|
| 219 |
with torch.no_grad():
|
|
|
|
| 267 |
return image
|
| 268 |
|
| 269 |
def infer_func(img, do_rm_bkg, do_infer_aug):
|
| 270 |
+
origin_img = Image.fromarray(img)
|
|
|
|
| 271 |
if do_infer_aug:
|
| 272 |
+
rm_bkg_img = background_preprocess(origin_img, True)
|
| 273 |
+
angles = get_3angle_infer_aug(origin_img, rm_bkg_img)
|
| 274 |
else:
|
| 275 |
+
rm_bkg_img = background_preprocess(origin_img, do_rm_bkg)
|
| 276 |
+
angles = get_3angle(rm_bkg_img)
|
| 277 |
|
| 278 |
fig, ax = plt.subplots(figsize=(8, 8))
|
| 279 |
|
| 280 |
+
w, h = rm_bkg_img.size
|
| 281 |
if h>w:
|
| 282 |
extent = [-5*w/h, 5*w/h, -5, 5]
|
| 283 |
else:
|
| 284 |
extent = [-5, 5, -5*h/w, 5*h/w]
|
| 285 |
+
ax.imshow(rm_bkg_img, extent=extent, zorder=0, aspect ='auto') # extent 设置图片的显示范围
|
| 286 |
|
| 287 |
origin = np.array([0, 0])
|
| 288 |
|