Commit 17d77a8 · Tobias Cornille committed · Parent: 391271a

Make more robust + fix segments annotations
app.py
CHANGED
@@ -110,7 +110,7 @@ def dino_detection(
         visualization = Image.fromarray(annotated_frame)
         return boxes, category_ids, visualization
     else:
-        return boxes, category_ids
+        return boxes, category_ids, phrases


 def sam_masks_from_dino_boxes(predictor, image_array, boxes, device):
@@ -156,13 +156,16 @@ def clipseg_segmentation(
     ).to(device)
     with torch.no_grad():
         outputs = model(**inputs)
+    logits = outputs.logits
+    if len(logits.shape) == 2:
+        logits = logits.unsqueeze(0)
     # resize the outputs
-    upscaled_logits = nn.functional.interpolate(
-        outputs.logits.unsqueeze(1),
+    upscaled_logits = nn.functional.interpolate(
+        logits.unsqueeze(1),
         size=(image.size[1], image.size[0]),
         mode="bilinear",
     )
-    preds = torch.sigmoid(upscaled_logits.squeeze())
+    preds = torch.sigmoid(upscaled_logits.squeeze(dim=1))
     semantic_inds = preds_to_semantic_inds(preds, background_threshold)
     return preds, semantic_inds

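The guard added here accounts for CLIPSeg returning 2-D logits when queried with a single text prompt and 3-D logits for multiple prompts. A minimal sketch of the same shape-normalization pattern, with hypothetical tensor sizes standing in for real CLIPSeg outputs:

```python
import torch

# Hypothetical stand-ins for CLIPSeg logits: (H, W) for one prompt,
# (N, H, W) for N prompts. The sizes here are only illustrative.
single_prompt = torch.randn(352, 352)
multi_prompt = torch.randn(3, 352, 352)

for logits in (single_prompt, multi_prompt):
    if len(logits.shape) == 2:
        logits = logits.unsqueeze(0)  # add the missing batch dimension
    # interpolate expects (N, C, H, W), hence the extra channel dim,
    # which squeeze(dim=1) removes again afterwards
    upscaled = torch.nn.functional.interpolate(
        logits.unsqueeze(1), size=(480, 640), mode="bilinear"
    )
    preds = torch.sigmoid(upscaled.squeeze(dim=1))
    print(preds.shape)  # torch.Size([N, 480, 640])
```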
@@ -195,7 +198,7 @@ def clip_and_shrink_preds(semantic_inds, preds, shrink_kernel_size, num_categori
         torch.sum(bool_masks[i].int()).item() for i in range(1, bool_masks.size(0))
     ]
     max_size = max(sizes)
-    relative_sizes = [size / max_size for size in sizes]
+    relative_sizes = [size / max_size for size in sizes] if max_size > 0 else sizes

     # use bool masks to clip preds
     clipped_preds = torch.zeros_like(preds)
@@ -240,7 +243,7 @@ def upsample_pred(pred, image_source):
     else:
         target_height = int(upsampled_tensor.shape[2] * aspect_ratio)
         upsampled_tensor = upsampled_tensor[:, :, :target_height, :]
-    return upsampled_tensor.squeeze()
+    return upsampled_tensor.squeeze(dim=1)


 def sam_mask_from_points(predictor, image_array, points):
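Both squeeze fixes in this commit follow the same robustness pattern: a bare `squeeze()` drops every size-1 dimension, including the batch axis when only one prediction is present, while `squeeze(dim=1)` removes only the channel dimension that `interpolate` required. A tiny illustration:

```python
import torch

t = torch.randn(1, 1, 480, 640)   # one upsampled prediction, (N, C, H, W)
print(t.squeeze().shape)          # torch.Size([480, 640]) - batch axis lost too
print(t.squeeze(dim=1).shape)     # torch.Size([1, 480, 640]) - only C removed
```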
@@ -262,26 +265,30 @@ def sam_mask_from_points(predictor, image_array, points):


 def inds_to_segments_format(
-    panoptic_inds, thing_category_ids, stuff_category_ids, output_file_path
+    panoptic_inds, thing_category_ids, stuff_category_names, category_name_to_id
 ):
     panoptic_inds_array = panoptic_inds.numpy().astype(np.uint32)
     bitmap_file = bitmap2file(panoptic_inds_array, is_segmentation_bitmap=True)
-    …
+    segmentation_bitmap = Image.open(bitmap_file)
+
+    stuff_category_ids = [
+        category_name_to_id[stuff_category_name]
+        for stuff_category_name in stuff_category_names
+    ]

     unique_inds = np.unique(panoptic_inds_array)
     stuff_annotations = [
-        {"id": i…
-        for i…
+        {"id": i, "category_id": stuff_category_ids[i - 1]}
+        for i in range(1, len(stuff_category_names) + 1)
         if i in unique_inds
     ]
     thing_annotations = [
-        {"id": len(…
+        {"id": len(stuff_category_names) + 1 + i, "category_id": thing_category_id}
         for i, thing_category_id in enumerate(thing_category_ids)
     ]
     annotations = stuff_annotations + thing_annotations

-    return annotations
+    return segmentation_bitmap, annotations


 def generate_panoptic_mask(
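The reworked `inds_to_segments_format` derives the stuff category ids from the category names and lays out instance ids so that stuff occupies 1..len(stuff) and things continue from there. A self-contained sketch of that id arithmetic, using hypothetical category names (the real function additionally filters stuff ids through `unique_inds` so absent categories are skipped):

```python
# Hypothetical categories; only the id layout mirrors the diff above.
stuff_category_names = ["sky", "road"]
category_name_to_id = {"sky": 0, "road": 1, "car": 2}
thing_category_ids = [2, 2]  # e.g. two detected cars

stuff_category_ids = [category_name_to_id[name] for name in stuff_category_names]
stuff_annotations = [
    {"id": i, "category_id": stuff_category_ids[i - 1]}
    for i in range(1, len(stuff_category_names) + 1)
]
thing_annotations = [
    {"id": len(stuff_category_names) + 1 + i, "category_id": category_id}
    for i, category_id in enumerate(thing_category_ids)
]
print(stuff_annotations + thing_annotations)
# [{'id': 1, 'category_id': 0}, {'id': 2, 'category_id': 1},
#  {'id': 3, 'category_id': 2}, {'id': 4, 'category_id': 2}]
```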
@@ -295,7 +302,7 @@ def generate_panoptic_mask(
     num_samples_factor=1000,
     task_attributes_json="",
 ):
-    if task_attributes_json:
+    if task_attributes_json != "":
         task_attributes = json.loads(task_attributes_json)
         categories = task_attributes["categories"]
         category_name_to_id = {
@@ -334,67 +341,89 @@ def generate_panoptic_mask(
     image = image.convert("RGB")
     image_array = np.asarray(image)

-    # detect boxes for "thing" categories using Grounding DINO
-    thing_boxes, thing_category_ids = dino_detection(
-        dino_model,
-        image,
-        image_array,
-        thing_category_names,
-        category_name_to_id,
-        dino_box_threshold,
-        dino_text_threshold,
-        device,
-    )
     # compute SAM image embedding
     sam_predictor.set_image(image_array)
-    …
+
+    # detect boxes for "thing" categories using Grounding DINO
+    thing_category_ids = []
+    thing_masks = []
+    thing_boxes = []
+    detected_thing_category_names = []
+    if len(thing_category_names) > 0:
+        thing_boxes, thing_category_ids, detected_thing_category_names = dino_detection(
+            dino_model,
+            image,
+            image_array,
+            thing_category_names,
+            category_name_to_id,
+            dino_box_threshold,
+            dino_text_threshold,
+            device,
+        )
+        if len(thing_boxes) > 0:
+            # get segmentation masks for the thing boxes
+            thing_masks = sam_masks_from_dino_boxes(
+                sam_predictor, image_array, thing_boxes, device
+            )
+    detected_stuff_category_names = []
+    if len(stuff_category_names) > 0:
+        # get rough segmentation masks for "stuff" categories using CLIPSeg
+        clipseg_preds, clipseg_semantic_inds = clipseg_segmentation(
+            clipseg_processor,
+            clipseg_model,
+            image,
+            stuff_category_names,
+            segmentation_background_threshold,
+            device,
+        )
+        # remove things from stuff masks
+        clipseg_semantic_inds_without_things = clipseg_semantic_inds.clone()
+        if len(thing_boxes) > 0:
+            combined_things_mask = torch.any(thing_masks, dim=0)
+            clipseg_semantic_inds_without_things[combined_things_mask[0]] = 0
+        # clip CLIPSeg preds based on non-overlapping semantic segmentation inds (+ optionally shrink the mask of each category)
+        # also returns the relative size of each category
+        clipsed_clipped_preds, relative_sizes = clip_and_shrink_preds(
+            clipseg_semantic_inds_without_things,
+            clipseg_preds,
+            shrink_kernel_size,
+            len(stuff_category_names) + 1,
+        )
+        # get finer segmentation masks for the "stuff" categories using SAM
+        sam_preds = torch.zeros_like(clipsed_clipped_preds)
+        for i in range(clipsed_clipped_preds.shape[0]):
+            clipseg_pred = clipsed_clipped_preds[i]
+            # for each "stuff" category, sample points in the rough segmentation mask
+            num_samples = int(relative_sizes[i] * num_samples_factor)
+            if num_samples == 0:
+                continue
+            points = sample_points_based_on_preds(
+                clipseg_pred.cpu().numpy(), num_samples
+            )
+            if len(points) == 0:
+                continue
+            # use SAM to get mask for points
+            pred = sam_mask_from_points(sam_predictor, image_array, points)
+            sam_preds[i] = pred
+        sam_semantic_inds = preds_to_semantic_inds(
+            sam_preds, segmentation_background_threshold
+        )
+        detected_stuff_category_names = [
+            category_name
+            for i, category_name in enumerate(category_names)
+            if i + 1 in np.unique(sam_semantic_inds.numpy())
+        ]
+
     # combine the thing inds and the stuff inds into panoptic inds
-    panoptic_inds = sam_semantic_inds.clone()
+    panoptic_inds = (
+        sam_semantic_inds.clone()
+        if len(stuff_category_names) > 0
+        else torch.zeros(image_array.shape[0], image_array.shape[1], dtype=torch.long)
+    )
     ind = len(stuff_category_names) + 1
     for thing_mask in thing_masks:
         # overlay thing mask on panoptic inds
-        panoptic_inds[thing_mask.squeeze()] = ind
+        panoptic_inds[thing_mask.squeeze(dim=0)] = ind
         ind += 1

     panoptic_bool_masks = (
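The combination step now starts from the stuff semantic indices when stuff categories exist (or an all-zero map otherwise) and paints each thing mask with the next free index, so things overwrite overlapping stuff. A toy version of that overlay loop with made-up 2x3 tensors:

```python
import torch

# Two stuff categories (inds 1 and 2) and one detected thing (gets ind 3).
stuff_inds = torch.tensor([[1, 1, 0],
                           [2, 2, 0]])
thing_masks = [torch.tensor([[False, False, True],
                             [False, False, True]])]

panoptic_inds = stuff_inds.clone()
ind = 2 + 1  # len(stuff_category_names) + 1
for thing_mask in thing_masks:
    panoptic_inds[thing_mask] = ind  # things overwrite stuff where they overlap
    ind += 1
print(panoptic_inds)  # tensor([[1, 1, 3], [2, 2, 3]])
```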
@@ -403,23 +432,19 @@ def generate_panoptic_mask(
         .astype(int)
     )
     panoptic_names = (
-        ["…
-        + stuff_category_names
-        + [category_names[category_id] for category_id in thing_category_ids]
+        ["unlabeled"] + detected_stuff_category_names + detected_thing_category_names
     )
     subsection_label_pairs = [
         (panoptic_bool_masks[i], panoptic_name)
         for i, panoptic_name in enumerate(panoptic_names)
     ]

-    …
-    annotations = inds_to_segments_format(
-        panoptic_inds, thing_category_ids, stuff_category_ids, output_file_path
+    segmentation_bitmap, annotations = inds_to_segments_format(
+        panoptic_inds, thing_category_ids, stuff_category_names, category_name_to_id
     )
     annotations_json = json.dumps(annotations)

-    return (image_array, subsection_label_pairs), …
+    return (image_array, subsection_label_pairs), segmentation_bitmap, annotations_json


 config_file = "GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py"
@@ -497,7 +522,7 @@ if __name__ == "__main__":
             step=0.001,
         )
         segmentation_background_threshold = gr.Slider(
-            label="Segmentation background threshold (under this threshold, a pixel is considered background)",
+            label="Segmentation background threshold (under this threshold, a pixel is considered background/unlabeled)",
             minimum=0.0,
             maximum=1.0,
             value=0.1,
@@ -529,11 +554,11 @@ if __name__ == "__main__":
                 The segmentation bitmap is a 32-bit RGBA png image which contains the segmentation masks.
                 The alpha channel is set to 255, and the remaining 24-bit values in the RGB channels correspond to the object ids in the annotations list.
                 Unlabeled regions have a value of 0.
-                Because of the large dynamic range, …
+                Because of the large dynamic range, the segmentation bitmap appears black in the image viewer.
                 """
             )
             segmentation_bitmap = gr.Image(
-                type="…
+                type="pil", label="Segmentation bitmap"
             )
             annotations_json = gr.Textbox(
                 label="Annotations JSON",
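Per the description in the hunk above, the bitmap packs each object id into the 24 RGB bits of an RGBA png, which is why it renders as near-black. A sketch of how the ids could be read back with numpy, assuming the little-endian byte order used by segments-ai's `bitmap2file` (R is the least significant byte, alpha is ignored) and a hypothetical file name:

```python
import numpy as np
from PIL import Image

# "segmentation_bitmap.png" is a hypothetical path to a saved bitmap.
rgba = np.array(Image.open("segmentation_bitmap.png").convert("RGBA"), dtype=np.uint32)
ids = rgba[..., 0] | (rgba[..., 1] << 8) | (rgba[..., 2] << 16)
print(np.unique(ids))  # 0 = unlabeled; other values match "id" in the annotations
```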