Spaces:
Runtime error
Runtime error
| import ast | |
| import os | |
| import json | |
| from matplotlib.patches import Polygon | |
| from matplotlib.collections import PatchCollection | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| import cv2 | |
| import inflect | |
# Shared inflect engine; used below to pluralize object names and to spell
# out object counts as words.
p = inflect.engine()

# Directory that the visualization helpers save their figures into.
img_dir = "imgs"

# Marker that separates the list of boxes from the background prompt in a
# response string.
bg_prompt_text = "Background prompt: "

# Canvas size as (height, width). `size` is an alias of `box_scale`.
box_scale = size = (512, 512)
size_h, size_w = size

print(f"Using box scale: {box_scale}")
def parse_input(text=None, no_input=False):
    """Split a response into its list of boxes and its background prompt.

    The response is expected to look like
    ``"Objects: <python literal>" + bg_prompt_text + "<background>"``.
    The ``"Objects: "`` prefix is optional.

    Args:
        text: The response string. When empty/``None`` it is read from stdin
            unless ``no_input`` is set.
        no_input: When true, never prompt on stdin; return ``None`` instead
            whenever a prompt would have been needed.

    Returns:
        ``(gen_boxes, bg_prompt)`` where ``gen_boxes`` is the evaluated
        literal (or ``[]`` for a plain-text "No objects" response) and
        ``bg_prompt`` is the stripped background prompt; or ``None`` when
        ``no_input`` is set and either ``text`` is empty or the background
        marker is missing.

    Raises:
        ValueError: If the background marker occurs more than once.
        SyntaxError: If the box list is not a valid literal and does not
            contain "No objects".
    """
    if not text:
        if no_input:
            return
        text = input("Enter the response: ")

    objects_marker = "Objects: "
    if objects_marker in text:
        text = text.split(objects_marker)[1]

    parts = text.split(bg_prompt_text)
    if len(parts) > 2:
        # The marker appeared more than once; the response is ambiguous.
        raise ValueError(f"text: {text}")

    if len(parts) == 2:
        gen_boxes, bg_prompt = parts
    else:
        # No marker in the response: the background prompt has to be asked for.
        if no_input:
            return
        gen_boxes = text
        bg_prompt = ""
        while not bg_prompt:
            # Ignore the empty lines in the response
            bg_prompt = input("Enter the background prompt: ").strip()
        if bg_prompt_text in bg_prompt:
            bg_prompt = bg_prompt.split(bg_prompt_text)[1]

    try:
        gen_boxes = ast.literal_eval(gen_boxes)
    except SyntaxError:
        # Sometimes the response is in plain text
        if "No objects" not in gen_boxes:
            raise
        gen_boxes = []

    return gen_boxes, bg_prompt.strip()
def _unpack_box(gen_box):
    """Return ``(name, [x, y, w, h])`` for a box given as a dict or as a pair."""
    if isinstance(gen_box, dict):
        return gen_box['name'], gen_box['bounding_box']
    return gen_box


def filter_boxes(gen_boxes, scale_boxes=True, ignore_background=True, max_scale=3, box_size=None):
    """Drop unusable boxes and optionally rescale the rest to span the canvas width.

    Args:
        gen_boxes: Sequence of boxes, each either a ``(name, [x, y, w, h])``
            pair or a dict with ``'name'`` and ``'bounding_box'`` keys.
        scale_boxes: When true, shift and scale all boxes horizontally so that
            together they span the canvas width (scale capped at
            ``max_scale``), then shift them vertically so they stay in bounds.
        ignore_background: When true, drop boxes that cover the whole canvas
            or whose origin lies beyond the right/bottom edge.
        max_scale: Upper bound for the scale factor.
        box_size: Optional ``(height, width)`` of the canvas. Defaults to the
            module-level ``size``.

    Returns:
        A list of boxes in the same format as the input, with a trailing
        ``"."`` stripped from each name and the bounding box rounded to a
        tuple of ints. Returns ``[]`` when no box survives filtering.
    """
    if len(gen_boxes) == 0:
        return []

    canvas_h, canvas_w = size if box_size is None else box_size

    kept = []
    for gen_box in gen_boxes:
        _, (bbox_x, bbox_y, bbox_w, bbox_h) = _unpack_box(gen_box)

        if bbox_w <= 0 or bbox_h <= 0:
            # Empty boxes
            continue

        if ignore_background:
            covers_canvas = bbox_w >= canvas_w and bbox_h >= canvas_h
            starts_outside = bbox_x > canvas_w or bbox_y > canvas_h
            if covers_canvas or starts_outside:
                # Ignore the background boxes
                continue

        kept.append(gen_box)

    if not kept:
        return []

    bboxes = [_unpack_box(gen_box)[1] for gen_box in kept]
    bbox_left_x_min = min(bbox[0] for bbox in bboxes)
    bbox_right_x_max = max(bbox[0] + bbox[2] for bbox in bboxes)
    bbox_top_y_min = min(bbox[1] for bbox in bboxes)
    bbox_bottom_y_max = max(bbox[1] + bbox[3] for bbox in bboxes)

    # Guard against a zero horizontal extent before dividing by it.
    if (bbox_right_x_max - bbox_left_x_min) == 0:
        return []

    # Used if scale_boxes is True
    shift = -bbox_left_x_min
    scale = min(canvas_w / (bbox_right_x_max - bbox_left_x_min), max_scale)

    # The vertical offset depends only on the overall extent of the boxes, so
    # it is computed once rather than once per box.
    # TODO: verify this makes the y center not moving
    bbox_y_offset = 0
    if bbox_top_y_min * scale + bbox_y_offset < 0:
        bbox_y_offset -= bbox_top_y_min * scale
    if bbox_bottom_y_max * scale + bbox_y_offset >= canvas_h:
        bbox_y_offset -= bbox_bottom_y_max * scale - canvas_h

    filtered_gen_boxes = []
    for gen_box in kept:
        name, (bbox_x, bbox_y, bbox_w, bbox_h) = _unpack_box(gen_box)

        if scale_boxes:
            # Vertical: move the boxes if out of bound
            # Horizontal: move and scale the boxes so it spans the horizontal line
            bbox_x = (bbox_x + shift) * scale
            bbox_y = bbox_y * scale
            bbox_w, bbox_h = bbox_w * scale, bbox_h * scale
            bbox_y += bbox_y_offset

            if bbox_y < 0:
                # Clip the part that sticks out above the canvas: bbox_y is
                # negative here, so adding it shortens the box by the overhang.
                bbox_h += bbox_y
                bbox_y = 0
                if bbox_h <= 0:
                    # Nothing of the box is left inside the canvas.
                    continue

        name = name.rstrip(".")
        bounding_box = tuple(int(np.round(value)) for value in (bbox_x, bbox_y, bbox_w, bbox_h))
        if isinstance(gen_box, dict):
            filtered_gen_boxes.append({'name': name, 'bounding_box': bounding_box})
        else:
            filtered_gen_boxes.append((name, bounding_box))

    return filtered_gen_boxes
def draw_boxes(anns):
    """Draw labelled rectangle outlines on the current matplotlib axes.

    Args:
        anns: Iterable of dicts. Each has a ``'bbox'`` entry ``[x, y, w, h]``
            and either a ``'name'`` or a ``'category_id'`` entry used as the
            text label.
    """
    ax = plt.gca()
    ax.set_autoscale_on(False)

    patches = []
    edge_colors = []
    for ann in anns:
        # Random colour with every channel in [0.4, 1.0).
        edge_color = np.random.random((1, 3)) * 0.6 + 0.4

        left, top, width, height = ann['bbox']
        right, bottom = left + width, top + height
        corners = np.array([[left, top], [left, bottom],
                            [right, bottom], [right, top]]).reshape((4, 2))
        patches.append(Polygon(corners))
        edge_colors.append(edge_color)

        label = ann['name'] if 'name' in ann else str(ann['category_id'])
        ax.text(left, top, label, style='italic',
                bbox={'facecolor': 'white', 'alpha': 0.7, 'pad': 5})

    outlines = PatchCollection(patches, facecolor='none',
                               edgecolors=edge_colors, linewidths=2)
    ax.add_collection(outlines)
def show_boxes(gen_boxes, bg_prompt=None, ind=None, show=False):
    """Render the boxes on a white canvas and either display or save the figure.

    Args:
        gen_boxes: Sequence of boxes, all either ``(name, [x, y, w, h])``
            pairs or dicts with ``'name'`` and ``'bounding_box'`` keys (the
            format is decided from the first element). Nothing is drawn when
            it is empty.
        bg_prompt: Optional background prompt; when given it is written at
            the top-left corner and a black frame of the canvas size is drawn.
        ind: Optional index; when given and ``show`` is false, an additional
            copy is saved as ``boxes_{ind}.png``.
        show: When true call ``plt.show()``; otherwise save the figure under
            ``img_dir``.

    The current figure is not cleared afterwards.
    """
    if len(gen_boxes) == 0:
        return

    if isinstance(gen_boxes[0], dict):
        anns = [{'name': gen_box['name'], 'bbox': gen_box['bounding_box']}
                for gen_box in gen_boxes]
    else:
        anns = [{'name': gen_box[0], 'bbox': gen_box[1]} for gen_box in gen_boxes]

    # White background (to allow line to show on the edge)
    canvas = np.ones((size[0] + 4, size[1] + 4, 3), dtype=np.uint8) * 255

    plt.imshow(canvas)
    plt.axis('off')

    if bg_prompt is not None:
        ax = plt.gca()
        ax.text(0, 0, bg_prompt, style='italic',
                bbox={'facecolor': 'white', 'alpha': 0.7, 'pad': 5})

        # Black frame marking the extent of the canvas.
        frame_h, frame_w = size[0], size[1]
        frame = np.array([[0, 0], [0, frame_h],
                          [frame_w, frame_h], [frame_w, 0]]).reshape((4, 2))
        outline = PatchCollection([Polygon(frame)], facecolor='none',
                                  edgecolors=[np.zeros((1, 3))], linewidths=2)
        ax.add_collection(outline)

    draw_boxes(anns)

    if show:
        plt.show()
        return

    # savefig does not create missing directories, so make sure the output
    # directory exists before writing into it.
    os.makedirs(img_dir, exist_ok=True)
    if ind is not None:
        plt.savefig(f"{img_dir}/boxes_{ind}.png")
    plt.savefig(f"{img_dir}/boxes.png")
    # Report only after the files have actually been written.
    print("Saved to", f"{img_dir}/boxes.png", f"ind: {ind}")
def show_masks(masks):
    """Overlay the masks in random colours, save the image, show it, and clear the figure.

    Args:
        masks: Iterable of mask arrays. Each mask is multiplied by a random
            RGB colour and accumulated into an image of shape ``(*size, 3)``
            (assumes every mask is a 2-D array of shape ``size`` - TODO
            confirm against callers).

    The image is saved to ``{img_dir}/masks.png``.
    """
    masks_to_show = np.zeros((*size, 3), dtype=np.float32)
    for mask in masks:
        # Random colour with every channel in [0.4, 1.0).
        c = np.random.random((3,)) * 0.6 + 0.4
        masks_to_show += mask[..., None] * c[None, None, :]

    plt.imshow(masks_to_show)
    # savefig does not create missing directories, so make sure the output
    # directory exists before writing into it.
    os.makedirs(img_dir, exist_ok=True)
    plt.savefig(f"{img_dir}/masks.png")
    plt.show()
    plt.clf()
def convert_box(box, height, width):
    """Convert a pixel ``(x, y, w, h)`` box to fractional corner coordinates.

    Args:
        box: Indexable ``(x, y, w, h)`` in pixels.
        height: Canvas height used to divide the vertical values.
        width: Canvas width used to divide the horizontal values.

    Returns:
        ``(x_min, y_min, x_max, y_max)`` as fractions of the canvas size.
    """
    left = box[0] / width
    top = box[1] / height
    right = left + box[2] / width
    bottom = top + box[3] / height
    return left, top, right, bottom
def _strip_leading_article(name):
    """Return *name* without a single leading ``"a "`` / ``"an "`` article."""
    for article in ("an ", "a "):
        if name.startswith(article):
            return name[len(article):]
    return name


def convert_spec(spec, height, width, include_counts=True, verbose=False):
    """Turn a layout spec into per-object prompts and one overall prompt.

    Args:
        spec: Mapping with ``'gen_boxes'`` (a list of
            ``(name, (x, y, w, h))`` pairs) and ``'bg_prompt'``.
        height: Canvas height passed to ``convert_box``.
        width: Canvas width passed to ``convert_box``.
        include_counts: When true, a name that occurs several times is
            prefixed with its count in words (e.g. "two cats").
        verbose: When true, print the three returned values.

    Returns:
        A tuple ``(so_prompt_phrase_word_box_list, overall_prompt,
        overall_phrases_words_bboxes)``:

        * one ``(prompt, phrase, word, box)`` entry per box, sorted by name;
        * the background prompt joined with the comma-separated object
          phrases (or whichever of the two is non-empty);
        * one ``(phrase, word, bboxes)`` entry per distinct object name.
    """
    # Infer from spec
    gen_boxes, bg_prompt = spec['gen_boxes'], spec['bg_prompt']

    # This ensures the same objects appear together because flattened `overall_phrases_bboxes` should EXACTLY correspond to `so_prompt_phrase_box_list`.
    gen_boxes = sorted(gen_boxes, key=lambda gen_box: gen_box[0])
    gen_boxes = [(name, convert_box(box, height=height, width=width)) for name, box in gen_boxes]

    # NOTE: so phrase should include all the words associated to the object (otherwise "an orange dog" may be recognized as "an orange" by the model generating the background).
    # so word should have one token that includes the word to transfer cross attention (the object name).
    # Currently using the last word of the object name as word.
    if bg_prompt:
        so_prompt_phrase_word_box_list = [(f"{bg_prompt} with {name}", name, name.split(" ")[-1], box) for name, box in gen_boxes]
    else:
        so_prompt_phrase_word_box_list = [(f"{name}", name, name.split(" ")[-1], box) for name, box in gen_boxes]

    objects = [name for name, _ in gen_boxes]
    objects_unique, objects_count = np.unique(objects, return_counts=True)

    num_total_matched_boxes = 0
    overall_phrases_words_bboxes = []
    for object_name, count in zip(objects_unique, objects_count):
        # Work with plain Python values rather than NumPy scalars.
        object_name, count = str(object_name), int(count)
        bboxes = [box for name, box in gen_boxes if name == object_name]

        if count > 1:
            # Only the leading article is removed; replacing "a "/"an "
            # everywhere would corrupt names such as "a panda bear".
            phrase = p.plural_noun(_strip_leading_article(object_name))
            if include_counts:
                phrase = p.number_to_words(count) + " " + phrase
        else:
            phrase = object_name

        # Currently using the last word of the phrase as word.
        word = phrase.split(' ')[-1]

        num_total_matched_boxes += len(bboxes)
        overall_phrases_words_bboxes.append((phrase, word, bboxes))

    # Internal consistency check: every box must belong to exactly one group.
    assert num_total_matched_boxes == len(gen_boxes), f"{num_total_matched_boxes} != {len(gen_boxes)}"

    objects_str = ", ".join([phrase for phrase, _, _ in overall_phrases_words_bboxes])
    if objects_str:
        if bg_prompt:
            overall_prompt = f"{bg_prompt} with {objects_str}"
        else:
            overall_prompt = objects_str
    else:
        overall_prompt = bg_prompt

    if verbose:
        print("so_prompt_phrase_word_box_list:", so_prompt_phrase_word_box_list)
        print("overall_prompt:", overall_prompt)
        print("overall_phrases_words_bboxes:", overall_phrases_words_bboxes)

    return so_prompt_phrase_word_box_list, overall_prompt, overall_phrases_words_bboxes