import json, os, random, math
from collections import defaultdict
from copy import deepcopy
import torch
from torch.utils.data import Dataset
import torchvision.transforms as transforms
import numpy as np
from PIL import Image, ImageOps
from .base_dataset import BaseDataset, check_filenames_in_zipdata
from io import BytesIO


def clean_annotations(annotations):
    for anno in annotations:
        anno.pop("segmentation", None)
        anno.pop("area", None)
        anno.pop("iscrowd", None)
        anno.pop("id", None)


def make_a_sentence(obj_names, clean=False):

    if clean:
        obj_names = [name[:-6] if ("-other" in name) else name for name in obj_names]

    caption = ""
    tokens_positive = []
    for obj_name in obj_names:
        start_len = len(caption)
        caption += obj_name
        end_len = len(caption)
        caption += ", "
        tokens_positive.append(
            [[start_len, end_len]]  # in a real caption, positive tokens can be disjoint, hence the list of lists
        )
    caption = caption[:-2]  # remove last ", "
    return caption  # , tokens_positive
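
# Illustrative example (not part of the original file): with clean=True the
# "-other" suffix is stripped from category names before building the caption, e.g.
#   make_a_sentence(["person", "grass-other"], clean=True)  ->  "person, grass"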


class LayoutDataset(BaseDataset):
    """
    Note: this dataset could largely be reproduced with cd_dataset.CDDataset:
    setting prob_real_caption=0 there makes that dataset use only detection
    annotations. However, CDDataset removes boxes rather than images.

    In layout2img works, raw images are simply resized to 256*256, so box sizes
    are pre-computed and min_box_size is applied before min/max_boxes_per_image;
    images that do not satisfy these rules are removed entirely.

    These two filtering strategies lead to different numbers of training/val
    images, so this dataset is intended only for layout2img.
    """
    def __init__(self,
                 image_root,
                 instances_json_path,
                 stuff_json_path,
                 category_embedding_path,
                 fake_caption_type='empty',
                 image_size=256,
                 max_samples=None,
                 min_box_size=0.02,
                 min_boxes_per_image=3,
                 max_boxes_per_image=8,
                 include_other=False,
                 random_flip=True
                 ):
        super().__init__(random_crop=None, random_flip=None, image_size=None)  # we only use the vis_getitem func in BaseDataset; do not use the others
        assert fake_caption_type in ['empty', 'made']
        self.image_root = image_root
        self.instances_json_path = instances_json_path
        self.stuff_json_path = stuff_json_path
        self.category_embedding_path = category_embedding_path
        self.fake_caption_type = fake_caption_type
        self.image_size = image_size
        self.max_samples = max_samples
        self.min_box_size = min_box_size
        self.min_boxes_per_image = min_boxes_per_image
        self.max_boxes_per_image = max_boxes_per_image
        self.include_other = include_other
        self.random_flip = random_flip

        self.transform = transforms.Compose([transforms.Resize((image_size, image_size)),
                                             transforms.ToTensor(),
                                             transforms.Lambda(lambda t: (t * 2) - 1)])
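        # (The transform above resizes to image_size x image_size, converts to a
        # tensor in [0, 1], then rescales pixel values to [-1, 1].)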

        # Load all jsons
        with open(instances_json_path, 'r') as f:
            instances_data = json.load(f)  # keys: 'info', 'images', 'licenses', 'categories', 'annotations'
        clean_annotations(instances_data["annotations"])
        self.instances_data = instances_data

        with open(stuff_json_path, 'r') as f:
            stuff_data = json.load(f)  # keys: 'info', 'images', 'licenses', 'categories', 'annotations'
        clean_annotations(stuff_data["annotations"])
        self.stuff_data = stuff_data

        # Load preprocessed name embedding
        self.category_embeddings = torch.load(category_embedding_path)
        self.embedding_len = list(self.category_embeddings.values())[0].shape[0]

        # Misc
        self.image_ids = []             # main list for selecting images
        self.image_id_to_filename = {}  # file names used to read image
        self.image_id_to_size = {}      # original size of this image
        assert instances_data['images'] == stuff_data["images"]
        for image_data in instances_data['images']:
            image_id = image_data['id']
            filename = image_data['file_name']
            width = image_data['width']
            height = image_data['height']
            self.image_ids.append(image_id)
            self.image_id_to_filename[image_id] = filename
            self.image_id_to_size[image_id] = (width, height)

        # All category names (including things and stuff)
        self.things_id_list = []
        self.stuff_id_list = []
        self.object_idx_to_name = {}
        for category_data in instances_data['categories']:
            self.things_id_list.append(category_data['id'])
            self.object_idx_to_name[category_data['id']] = category_data['name']
        for category_data in stuff_data['categories']:
            self.stuff_id_list.append(category_data['id'])
            self.object_idx_to_name[category_data['id']] = category_data['name']
        self.all_categories = [self.object_idx_to_name.get(k, None) for k in range(183 + 1)]

        # Add object data from instances and stuff
        self.image_id_to_objects = defaultdict(list)
        self.select_objects(instances_data['annotations'])
        self.select_objects(stuff_data['annotations'])

        # Prune images that have too few or too many objects
        new_image_ids = []
        for image_id in self.image_ids:
            num_objs = len(self.image_id_to_objects[image_id])
            if self.min_boxes_per_image <= num_objs <= self.max_boxes_per_image:
                new_image_ids.append(image_id)
        self.image_ids = new_image_ids

        # Check if all filenames can be found in the zip file
        all_filenames = [self.image_id_to_filename[idx] for idx in self.image_ids]
        check_filenames_in_zipdata(all_filenames, image_root)

    def select_objects(self, annotations):
        for object_anno in annotations:
            image_id = object_anno['image_id']
            _, _, w, h = object_anno['bbox']
            W, H = self.image_id_to_size[image_id]
            box_area = (w * h) / (W * H)
            box_ok = box_area > self.min_box_size
            object_name = self.object_idx_to_name[object_anno['category_id']]
            other_ok = object_name != 'other' or self.include_other
            if box_ok and other_ok:
                self.image_id_to_objects[image_id].append(object_anno)
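    # Illustration (not from the original file): box_area is normalized by the image
    # area, so with the default min_box_size=0.02 a 32x32 box in a 640x480 image
    # (1024 / 307200 ≈ 0.003) is dropped, while a 100x100 box (≈ 0.033) is kept.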

    def total_images(self):
        return len(self)

    def __getitem__(self, index):
        if self.max_boxes_per_image > 99:
            assert False, "Are you sure you want to set such a large number of boxes?"

        out = {}

        image_id = self.image_ids[index]
        out['id'] = image_id

        flip = self.random_flip and random.random() < 0.5

        # Image
        filename = self.image_id_to_filename[image_id]
        zip_file = self.fetch_zipfile(self.image_root)
        image = Image.open(BytesIO(zip_file.read(filename))).convert('RGB')
        WW, HH = image.size
        if flip:
            image = ImageOps.mirror(image)
        out["image"] = self.transform(image)

        this_image_obj_annos = deepcopy(self.image_id_to_objects[image_id])

        # Make a sentence
        obj_names = []  # used for making a sentence
        boxes = torch.zeros(self.max_boxes_per_image, 4)
        masks = torch.zeros(self.max_boxes_per_image)
        positive_embeddings = torch.zeros(self.max_boxes_per_image, self.embedding_len)
        for idx, object_anno in enumerate(this_image_obj_annos):
            obj_name = self.object_idx_to_name[object_anno['category_id']]
            obj_names.append(obj_name)
            x, y, w, h = object_anno['bbox']
            x0 = x / WW
            y0 = y / HH
            x1 = (x + w) / WW
            y1 = (y + h) / HH
            if flip:
                x0, x1 = 1 - x1, 1 - x0
            boxes[idx] = torch.tensor([x0, y0, x1, y1])
            masks[idx] = 1
            positive_embeddings[idx] = self.category_embeddings[obj_name]

        if self.fake_caption_type == 'empty':
            caption = ""
        else:
            caption = make_a_sentence(obj_names, clean=True)

        out["caption"] = caption
        out["boxes"] = boxes
        out["masks"] = masks
        out["positive_embeddings"] = positive_embeddings

        return out

    def __len__(self):
        if self.max_samples is None:
            return len(self.image_ids)
        return min(len(self.image_ids), self.max_samples)
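
# Minimal usage sketch (illustrative only; the paths below are hypothetical, and
# image_root is assumed to be a zip archive readable by BaseDataset.fetch_zipfile):
#
#   dataset = LayoutDataset(
#       image_root="DATA/coco/train2017.zip",
#       instances_json_path="DATA/coco/annotations/instances_train2017.json",
#       stuff_json_path="DATA/coco/annotations/stuff_train2017.json",
#       category_embedding_path="DATA/coco/category_embeddings.pth",
#       fake_caption_type="made",
#   )
#   sample = dataset[0]  # dict with 'id', 'image', 'caption', 'boxes', 'masks', 'positive_embeddings'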