# NOTE(review): the original lines here ("Spaces:" / "Runtime error" twice) are
# paste artifacts from an external code viewer, not Python source — they must
# not remain as bare statements in a runnable file.
| import os | |
| import orjson | |
| import json | |
| import webdataset as wds | |
| from tqdm import tqdm, trange | |
| import h5py | |
| import numpy as np | |
| from utils import MAXCOUNT, NAMING, check_sample | |
# Output directory for the generated webdataset shards.
OUT_DIR = "/gpfs/u/home/LMCG/LMCGljnn/scratch-shared/junyan/raw/instruct/vg_relation"
# Boxes in VG-SGG.h5 are stored rescaled so the longest image side is BOX_SCALE
# pixels (the main block below rescales with BOX_SCALE / max(height, width)).
BOX_SCALE = 512
def load_image_filenames(image_file, image_dir, expected_count=108073):
    """
    Load the Visual Genome image filenames from the JSON metadata file.

    This matches the preprocessing in scene-graph-TF-release/data_tools/vg_to_imdb.py.

    :param image_file: JSON file; each element contains "image_id", "height", "width".
    :param image_dir: directory where the VisualGenome images are located.
    :param expected_count: number of good images expected after filtering
        (108073 for the standard VG release); an AssertionError is raised on
        mismatch, same as the original hard-coded check.
    :return: list of [filename, height, width] for every good image
        (not corrupted and present on disk).
    """
    with open(image_file, 'r') as f:
        im_data = json.load(f)
    # These four images are known-corrupted in the VG_100K release; skip them
    # before touching any other metadata. Set for O(1) membership tests.
    corrupted_ims = {'1592.jpg', '1722.jpg', '4616.jpg', '4617.jpg'}
    fns = []
    for img in tqdm(im_data):
        basename = '{}.jpg'.format(img['image_id'])
        if basename in corrupted_ims:
            continue
        filename = os.path.join(image_dir, basename)
        # Skip entries whose image file is missing on disk.
        if os.path.exists(filename):
            fns.append([filename, int(img['height']), int(img['width'])])
    assert len(fns) == expected_count
    return fns
def load_graphs(graphs_file, mode='train', num_im=-1, num_val_im=0, filter_empty_rels=True,
                filter_non_overlap=False):
    """
    Load the file containing the GT boxes and relations, as well as the dataset split.

    :param graphs_file: HDF5 file (VG-SGG.h5) holding boxes, labels,
        relationships, predicates, and the per-image split/index arrays.
    :param mode: one of ('train', 'val', 'test').
    :param num_im: number of images we want (-1 keeps all of them).
    :param num_val_im: number of validation images carved off the front of the
        train split ('val' takes the first num_val_im, 'train' keeps the rest).
    :param filter_empty_rels: drop images with no relationships.
    :param filter_non_overlap: if training, filter images that dont overlap.
        NOTE(review): this path raises NotImplementedError immediately — the
        code after the raise is unreachable (see below).
    :return: split_mask: boolean mask over all images selecting the ones used
             boxes: list where each element is a [num_gt, 4] array of ground
                 truth boxes (x1, y1, x2, y2)
             gt_classes: list where each element is a [num_gt] array of classes
             relationships: list where each element is a [num_r, 3] array of
                 (box_ind_1, box_ind_2, predicate) relationships
    """
    if mode not in ('train', 'val', 'test'):
        raise ValueError('{} invalid'.format(mode))
    roi_h5 = h5py.File(graphs_file, 'r')
    data_split = roi_h5['split'][:]
    # Split codes in the h5: 2 = test; train and val both live under 0 and are
    # separated later via num_val_im.
    split = 2 if mode == 'test' else 0
    split_mask = data_split == split
    # Filter out images without bounding boxes (img_to_first_box == -1).
    split_mask &= roi_h5['img_to_first_box'][:] >= 0
    if filter_empty_rels:
        split_mask &= roi_h5['img_to_first_rel'][:] >= 0
    image_index = np.where(split_mask)[0]
    if num_im > -1:
        image_index = image_index[:num_im]
    if num_val_im > 0:
        if mode == 'val':
            image_index = image_index[:num_val_im]
        elif mode == 'train':
            image_index = image_index[num_val_im:]
    # Rebuild the mask so it reflects exactly the images kept above.
    split_mask = np.zeros_like(data_split).astype(bool)
    split_mask[image_index] = True

    # Get box information.
    all_labels = roi_h5['labels'][:, 0]
    all_boxes = roi_h5['boxes_{}'.format(BOX_SCALE)][:]  # will index later
    assert np.all(all_boxes[:, :2] >= 0)  # sanity check
    assert np.all(all_boxes[:, 2:] > 0)  # no empty box
    # convert from xc, yc, w, h to x1, y1, x2, y2
    # NOTE(review): if the h5 stores integer boxes, `/ 2` is truncated back to
    # int by the in-place assignment — confirm the dataset dtype.
    all_boxes[:, :2] = all_boxes[:, :2] - all_boxes[:, 2:] / 2
    all_boxes[:, 2:] = all_boxes[:, :2] + all_boxes[:, 2:]

    # Per-image [first, last] (inclusive) index ranges into the flat
    # box/label and relation arrays, restricted to the selected images.
    im_to_first_box = roi_h5['img_to_first_box'][:][split_mask]
    im_to_last_box = roi_h5['img_to_last_box'][:][split_mask]
    im_to_first_rel = roi_h5['img_to_first_rel'][:][split_mask]
    im_to_last_rel = roi_h5['img_to_last_rel'][:][split_mask]

    # load relation labels
    _relations = roi_h5['relationships'][:]
    _relation_predicates = roi_h5['predicates'][:, 0]
    assert (im_to_first_rel.shape[0] == im_to_last_rel.shape[0])
    assert (_relations.shape[0] == _relation_predicates.shape[0])  # sanity check

    # Get everything by image.
    boxes = []
    gt_classes = []
    relationships = []
    for i in trange(len(image_index)):
        # Ranges are inclusive, hence the +1 on the upper bound.
        boxes_i = all_boxes[im_to_first_box[i]:im_to_last_box[i] + 1, :]
        gt_classes_i = all_labels[im_to_first_box[i]:im_to_last_box[i] + 1]
        if im_to_first_rel[i] >= 0:
            predicates = _relation_predicates[im_to_first_rel[i]:im_to_last_rel[i] + 1]
            # Re-base relation endpoints from global box indices to
            # this image's local box indices.
            obj_idx = _relations[im_to_first_rel[i]:im_to_last_rel[i] + 1] - im_to_first_box[i]
            assert np.all(obj_idx >= 0)
            assert np.all(obj_idx < boxes_i.shape[0])
            rels = np.column_stack((obj_idx, predicates))
        else:
            # No relations for this image; only possible when not filtering.
            assert not filter_empty_rels
            rels = np.zeros((0, 3), dtype=np.int32)
        if filter_non_overlap:
            raise NotImplementedError
            # NOTE(review): everything below the raise is dead code, and it
            # calls bbox_overlaps, which is not defined anywhere in this file.
            assert mode == 'train'
            inters = bbox_overlaps(boxes_i, boxes_i)
            rel_overs = inters[rels[:, 0], rels[:, 1]]
            inc = np.where(rel_overs > 0.0)[0]
            if inc.size > 0:
                rels = rels[inc]
            else:
                split_mask[image_index[i]] = 0
                continue
        boxes.append(boxes_i)
        gt_classes.append(gt_classes_i)
        relationships.append(rels)
    return split_mask, boxes, gt_classes, relationships
def load_info(info_file):
    """
    Load the file containing the Visual Genome label meanings.

    :param info_file: JSON file with "label_to_idx" and "predicate_to_idx" maps.
    :return: ind_to_classes: class names sorted by their index
             ind_to_predicates: predicate names sorted by their index
    """
    # Use a context manager so the file handle is closed promptly
    # (the original json.load(open(...)) leaked it).
    with open(info_file, 'r') as f:
        info = json.load(f)
    # Index 0 is reserved for the background class/predicate.
    info['label_to_idx']['__background__'] = 0
    info['predicate_to_idx']['__background__'] = 0
    class_to_ind = info['label_to_idx']
    predicate_to_ind = info['predicate_to_idx']
    # Sort names by their assigned index so position == index.
    ind_to_classes = sorted(class_to_ind, key=lambda k: class_to_ind[k])
    ind_to_predicates = sorted(predicate_to_ind, key=lambda k: predicate_to_ind[k])
    return ind_to_classes, ind_to_predicates
if __name__ == "__main__":
    root = "/gpfs/u/home/LMCG/LMCGljnn/scratch/datasets/raw/vg"
    # [filename, height, width] for every good VG image, in h5 image order.
    filenames = load_image_filenames(os.path.join(root, "image_data.json"), os.path.join(root, "VG_100K"))
    split_mask, boxes, gt_classes, relationships = load_graphs(
        graphs_file=os.path.join(root, "VG-SGG.h5"),
        mode="train",
    )
    # Keep only the filenames selected by the split mask, preserving order so
    # they line up 1:1 with boxes/gt_classes/relationships below.
    split_filenames = []
    for i, mask in enumerate(split_mask):
        if mask:
            split_filenames.append(filenames[i])
    filenames = split_filenames
    ind_to_classes, ind_to_predicates = load_info(os.path.join(root, "VG-SGG-dicts.json"))
    assert len(filenames) == len(boxes)
    assert len(filenames) == len(gt_classes)
    assert len(filenames) == len(relationships)
    # Running sample counter used to build unique shard keys
    # (NOTE(review): shadows the stdlib `uuid` module name; harmless here).
    uuid = 0
    os.makedirs(OUT_DIR, exist_ok=True)
    pbar = tqdm()
    # Emit one webdataset sample per (subject box, object box, predicate) triple.
    with wds.ShardWriter(os.path.join(OUT_DIR, NAMING), maxcount=MAXCOUNT) as sink:
        for box, box_class, relationship, (filename, height, width) in zip(boxes, gt_classes, relationships, filenames):
            # Boxes are stored at BOX_SCALE resolution (longest side scaled to
            # BOX_SCALE); undo that scaling, then normalize coords to [0, 1].
            size = float(BOX_SCALE) / max(height, width)
            size = np.array([width, height, width, height]) * size
            box = (box.astype(float) / size).clip(0, 1)
            for relation in relationship:
                box1_id = relation[0]
                box2_id = relation[1]
                predicate = ind_to_predicates[relation[2]]
                # Pair each normalized box with its human-readable class name.
                box1 = [box[box1_id], ind_to_classes[box_class[box1_id]]]
                box2 = [box[box2_id], ind_to_classes[box_class[box2_id]]]
                data = [box1, box2, predicate]
                dataset = "vg_relation"
                image_path = filename
                key = f"{dataset}_{uuid}"
                uuid += 1
                pbar.update()
                sample = {
                    "__key__": key,
                    "image_path.txt": image_path,
                    "dataset.txt": dataset,
                    "data.pyd": data,
                }
                check_sample(sample)
                sink.write(sample)
| # if __name__ == "__main__": | |
| # root = "/gpfs/u/home/LMCG/LMCGljnn/scratch/datasets/raw/vg" | |
| # relationships = orjson.loads(open("/gpfs/u/home/LMCG/LMCGljnn/scratch/datasets/raw/vg/relationships.json").read()) | |
| # image_data = orjson.loads(open("/gpfs/u/home/LMCG/LMCGljnn/scratch/datasets/raw/vg/image_data.json").read()) | |
| # image_id_to_filename = {} | |
| # image_id_to_wh = {} | |
| # for image in tqdm(image_data): | |
| # image_id = image["image_id"] | |
| # subfolder, filename = image['url'].split("/")[-2:] | |
| # image_id_to_filename[image_id] = os.path.join(root, subfolder, filename) | |
| # image_id_to_wh[image_id] = (image["width"], image["height"]) | |
| # unique_predicates = [] | |
| # # with wds.ShardWriter(os.path.join(OUT_DIR, "%05d.tar"), maxcount=500) as sink: | |
| # for relation_per_image in tqdm(relationships): | |
| # image_id = relation_per_image["image_id"] | |
| # for relation in relation_per_image["relationships"]: | |
| # predicate = relation["predicate"] | |
| # unique_predicates.append(predicate) | |
| # object = { | |
| # "name": relation["object"]["name"], | |
| # "x": relation["object"]["x"], | |
| # "y": relation["object"]["y"], | |
| # "w": relation["object"]["w"], | |
| # "h": relation["object"]["h"], | |
| # } | |
| # subject = { | |
| # "name": relation["subject"]["name"], | |
| # "x": relation["subject"]["x"], | |
| # "y": relation["subject"]["y"], | |
| # "w": relation["subject"]["w"], | |
| # "h": relation["subject"]["h"], | |
| # } | |