Tobias Cornille committed
Commit 672ba8c · Parent(s): 27a9b54

Fix GPU + add examples
Files changed:
- .gitattributes +3 -0
- app.py +17 -4
.gitattributes CHANGED
@@ -32,3 +32,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+a2d2.png filter=lfs diff=lfs merge=lfs -text
+bxl.png filter=lfs diff=lfs merge=lfs -text
+dogs.png filter=lfs diff=lfs merge=lfs -text
app.py CHANGED
@@ -143,13 +143,15 @@ def preds_to_semantic_inds(preds, threshold):
     return semantic_inds
 
 
-def clipseg_segmentation(processor, model, image, category_names, background_threshold):
+def clipseg_segmentation(
+    processor, model, image, category_names, background_threshold, device
+):
     inputs = processor(
         text=category_names,
         images=[image] * len(category_names),
         padding="max_length",
         return_tensors="pt",
-    )
+    ).to(device)
     with torch.no_grad():
         outputs = model(**inputs)
     # resize the outputs
@@ -183,7 +185,7 @@ def clip_and_shrink_preds(semantic_inds, preds, shrink_kernel_size, num_categori
     # convert semantic_inds to shrunken bool masks
     bool_masks = semantic_inds_to_shrunken_bool_masks(
         semantic_inds, shrink_kernel_size, num_categories
-    )
+    ).to(preds.device)
 
     sizes = [
         torch.sum(bool_masks[i].int()).item() for i in range(1, bool_masks.size(0))
@@ -306,6 +308,7 @@ def generate_panoptic_mask(
         image,
         stuff_category_names,
         segmentation_background_threshold,
+        device,
     )
     # remove things from stuff masks
     combined_things_mask = torch.any(thing_masks, dim=0)
@@ -327,7 +330,7 @@ def generate_panoptic_mask(
         num_samples = int(relative_sizes[i] * num_samples_factor)
         if num_samples == 0:
             continue
-        points = sample_points_based_on_preds(clipseg_pred.numpy(), num_samples)
+        points = sample_points_based_on_preds(clipseg_pred.cpu().numpy(), num_samples)
         if len(points) == 0:
             continue
         # use SAM to get mask for points
@@ -381,6 +384,16 @@ clipseg_model = CLIPSegForImageSegmentation.from_pretrained(
 clipseg_model.to(device)
 
 
+title = "Interactive demo: panoptic segment anything"
+description = "Demo for zero-shot panoptic segmentation using Segment Anything, Grounding DINO, and CLIPSeg. To use it, simply upload an image and add a text to mask (identify in the image), or use one of the examples below and click 'submit'."
+article = "<p style='text-align: center'><a href='https://github.com/segments-ai/panoptic-segment-anything'>Github</a></p>"
+
+examples = [
+    ["a2d2.png", "car, bus, person", "road, sky, buildings", 0.3, 0.25, 0.1, 20, 1000],
+    ["dogs.png", "dog, wooden stick", "sky, sand"],
+    ["bxl.png", "car, tram, motorcycle, person", "road, buildings, sky"],
+]
+
 if __name__ == "__main__":
     parser = argparse.ArgumentParser("Panoptic Segment Anything demo", add_help=True)
     parser.add_argument("--debug", action="store_true", help="using debug mode")
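Note on the GPU fix: it follows the usual Transformers pattern of moving the processor's output onto the model's device before the forward pass, and moving predictions back to the CPU before calling `.numpy()`. The snippet below is a minimal, self-contained sketch of that pattern with CLIPSeg, not the Space's actual app.py; the checkpoint name, placeholder image, and category list are assumptions made for the example. The new `title`, `description`, `article`, and `examples` variables added at the bottom of the diff are presumably wired into the Space's `gr.Interface` call elsewhere in app.py (not shown in this diff).

# Minimal sketch of the device-handling pattern introduced in this commit
# (illustrative only, not the Space's actual app.py; checkpoint name and
# inputs are assumptions for the example).
import torch
from PIL import Image
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation

device = "cuda" if torch.cuda.is_available() else "cpu"

processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")
model.to(device)

image = Image.new("RGB", (352, 352))           # placeholder image
category_names = ["road", "sky", "buildings"]  # example "stuff" categories

# Move the processor's BatchEncoding onto the model's device before the
# forward pass, mirroring the `.to(device)` added to clipseg_segmentation.
inputs = processor(
    text=category_names,
    images=[image] * len(category_names),
    padding="max_length",
    return_tensors="pt",
).to(device)

with torch.no_grad():
    outputs = model(**inputs)

# Bring predictions back to the CPU before converting to NumPy, mirroring
# the clipseg_pred.cpu().numpy() change in generate_panoptic_mask.
preds = torch.sigmoid(outputs.logits).cpu().numpy()
print(preds.shape)  # one low-resolution mask per category name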