Spaces:
Runtime error
Runtime error
Commit
Β·
eabc299
1
Parent(s):
3e95b67
:sparkles: return foreground only
Browse files
app.py
CHANGED
|
@@ -33,7 +33,7 @@ def get_scale_factor(im_h, im_w, ref_size=512):
|
|
| 33 |
MODEL_PATH = hf_hub_download('nateraw/background-remover-files', 'modnet.onnx', repo_type='dataset')
|
| 34 |
|
| 35 |
|
| 36 |
-
def main(image_path):
|
| 37 |
|
| 38 |
# read image
|
| 39 |
im = cv2.imread(image_path)
|
|
@@ -85,9 +85,17 @@ def main(image_path):
|
|
| 85 |
image = np.repeat(image, 3, axis=2)
|
| 86 |
elif image.shape[2] == 4:
|
| 87 |
image = image[:, :, 0:3]
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 91 |
|
| 92 |
|
| 93 |
title = "MODNet Background Remover"
|
|
@@ -99,9 +107,12 @@ image = Image.open(requests.get(url, stream=True).raw)
|
|
| 99 |
image.save('twitter_profile_pic.jpeg')
|
| 100 |
interface = gr.Interface(
|
| 101 |
fn=main,
|
| 102 |
-
inputs=
|
|
|
|
|
|
|
|
|
|
| 103 |
outputs='image',
|
| 104 |
-
examples=[['twitter_profile_pic.jpeg']],
|
| 105 |
title=title,
|
| 106 |
description=description,
|
| 107 |
article=article,
|
|
|
|
| 33 |
MODEL_PATH = hf_hub_download('nateraw/background-remover-files', 'modnet.onnx', repo_type='dataset')
|
| 34 |
|
| 35 |
|
| 36 |
+
def main(image_path, threshold):
|
| 37 |
|
| 38 |
# read image
|
| 39 |
im = cv2.imread(image_path)
|
|
|
|
| 85 |
image = np.repeat(image, 3, axis=2)
|
| 86 |
elif image.shape[2] == 4:
|
| 87 |
image = image[:, :, 0:3]
|
| 88 |
+
|
| 89 |
+
b, g, r = cv2.split(image)
|
| 90 |
+
|
| 91 |
+
mask = np.asarray(matte)
|
| 92 |
+
a = np.ones(mask.shape, dtype='uint8') * 255
|
| 93 |
+
alpha_im = cv2.merge([b, g, r, a], 4)
|
| 94 |
+
bg = np.zeros(alpha_im.shape)
|
| 95 |
+
new_mask = np.stack([mask, mask, mask, mask], axis=2)
|
| 96 |
+
foreground = np.where(new_mask > threshold, alpha_im, bg).astype(np.uint8)
|
| 97 |
+
|
| 98 |
+
return Image.fromarray(foreground)
|
| 99 |
|
| 100 |
|
| 101 |
title = "MODNet Background Remover"
|
|
|
|
| 107 |
image.save('twitter_profile_pic.jpeg')
|
| 108 |
interface = gr.Interface(
|
| 109 |
fn=main,
|
| 110 |
+
inputs=[
|
| 111 |
+
gr.inputs.Image(type='filepath'),
|
| 112 |
+
gr.inputs.Slider(minimum=0, maximum=250, default=100, step=5, label='Mask Cutoff Threshold'),
|
| 113 |
+
],
|
| 114 |
outputs='image',
|
| 115 |
+
examples=[['twitter_profile_pic.jpeg', 120]],
|
| 116 |
title=title,
|
| 117 |
description=description,
|
| 118 |
article=article,
|