# Copyright (C) 2022-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).

import os

from torch.utils.data import Dataset
from PIL import Image

from datasets.transforms import get_pair_transforms
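
# Utilities to build lists of (image1, image2) file paths from cached pair
# files or listing files, plus a torch Dataset that loads such pairs.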


def load_image(impath):
    return Image.open(impath)


def load_pairs_from_cache_file(fname, root=""):
    assert os.path.isfile(
        fname
    ), "cannot parse pairs from {:s}, file does not exist".format(fname)
    with open(fname, "r") as fid:
        lines = fid.read().strip().splitlines()
    pairs = [
        (os.path.join(root, l.split()[0]), os.path.join(root, l.split()[1]))
        for l in lines
    ]
    return pairs
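

# Note: a pairs cache file stores one pair per line as two whitespace-separated
# paths relative to `root`, e.g. (illustrative names):
#   scene_a/00000_1.jpeg scene_a/00000_2.jpeg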


def load_pairs_from_list_file(fname, root=""):
    assert os.path.isfile(
        fname
    ), "cannot parse pairs from {:s}, file does not exist".format(fname)
    with open(fname, "r") as fid:
        lines = fid.read().strip().splitlines()
    pairs = [
        (os.path.join(root, l + "_1.jpg"), os.path.join(root, l + "_2.jpg"))
        for l in lines
        if not l.startswith("#")
    ]
    return pairs
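

# Note: a listing file stores one pair per line as a shared path prefix; the
# two images are <prefix>_1.jpg and <prefix>_2.jpg, and lines starting with
# '#' are ignored.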


def write_cache_file(fname, pairs, root=""):
    if len(root) > 0:
        if not root.endswith("/"):
            root += "/"
        assert os.path.isdir(root)
    s = ""
    for im1, im2 in pairs:
        if len(root) > 0:
            assert im1.startswith(root), im1
            assert im2.startswith(root), im2
        # store paths relative to root
        s += "{:s} {:s}\n".format(im1[len(root) :], im2[len(root) :])
    with open(fname, "w") as fid:
        fid.write(s[:-1])  # drop the trailing newline
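

# write_cache_file is the inverse of load_pairs_from_cache_file: it strips
# `root` from absolute pair paths and writes them back one pair per line.
# Round-trip sketch (hypothetical paths; `root` must be an existing dir):
#   write_cache_file("/tmp/pairs.txt", [("/data/a_1.jpeg", "/data/a_2.jpeg")], root="/data/")
#   pairs = load_pairs_from_cache_file("/tmp/pairs.txt", root="/data/")
#   # pairs == [("/data/a_1.jpeg", "/data/a_2.jpeg")]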


def parse_and_cache_all_pairs(dname, data_dir="./data/"):
    if dname == "habitat_release":
        dirname = os.path.join(data_dir, "habitat_release")
        assert os.path.isdir(dirname), (
            "cannot find folder for habitat_release pairs: " + dirname
        )
        cache_file = os.path.join(dirname, "pairs.txt")
        assert not os.path.isfile(cache_file), (
            "cache file already exists: " + cache_file
        )
        print("Parsing pairs for dataset: " + dname)
        pairs = []
        for root, dirs, files in os.walk(dirname):
            if "val" in root:
                continue  # skip validation scenes
            dirs.sort()
            pairs += [
                (
                    os.path.join(root, f),
                    os.path.join(root, f[: -len("_1.jpeg")] + "_2.jpeg"),
                )
                for f in sorted(files)
                if f.endswith("_1.jpeg")
            ]
        print("Found {:,} pairs".format(len(pairs)))
        print("Writing cache to: " + cache_file)
        write_cache_file(cache_file, pairs, root=dirname)
    else:
        raise NotImplementedError("Unknown dataset: " + dname)


def dnames_to_image_pairs(dnames, data_dir="./data/"):
    """
    dnames: '+'-separated string of dataset names, e.g. "habitat_release+ARKitScenes"
    """
    all_pairs = []
    for dname in dnames.split("+"):
        if dname == "habitat_release":
            dirname = os.path.join(data_dir, "habitat_release")
            assert os.path.isdir(dirname), (
                "cannot find folder for habitat_release pairs: " + dirname
            )
            cache_file = os.path.join(dirname, "pairs.txt")
            assert os.path.isfile(cache_file), (
                "cannot find cache file for habitat_release pairs, please first create the cache file, see instructions. "
                + cache_file
            )
            pairs = load_pairs_from_cache_file(cache_file, root=dirname)
        elif dname in ["ARKitScenes", "MegaDepth", "3DStreetView", "IndoorVL"]:
            dirname = os.path.join(data_dir, dname + "_crops")
            assert os.path.isdir(
                dirname
            ), "cannot find folder for {:s} pairs: {:s}".format(dname, dirname)
            list_file = os.path.join(dirname, "listing.txt")
            assert os.path.isfile(
                list_file
            ), "cannot find list file for {:s} pairs, see instructions. {:s}".format(
                dname, list_file
            )
            pairs = load_pairs_from_list_file(list_file, root=dirname)
        else:
            raise NotImplementedError("Unknown dataset: " + dname)
        print("  {:s}: {:,} pairs".format(dname, len(pairs)))
        all_pairs += pairs
    if "+" in dnames:
        print(" Total: {:,} pairs".format(len(all_pairs)))
    return all_pairs
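

# Usage sketch (assumes the corresponding data, cache and listing files exist
# under ./data/, as described in the instructions):
#   pairs = dnames_to_image_pairs("habitat_release+ARKitScenes")
#   im1path, im2path = pairs[0]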


class PairsDataset(Dataset):
    def __init__(
        self, dnames, trfs="", totensor=True, normalize=True, data_dir="./data/"
    ):
        super().__init__()
        self.image_pairs = dnames_to_image_pairs(dnames, data_dir=data_dir)
        self.transforms = get_pair_transforms(
            transform_str=trfs, totensor=totensor, normalize=normalize
        )

    def __len__(self):
        return len(self.image_pairs)

    def __getitem__(self, index):
        im1path, im2path = self.image_pairs[index]
        im1 = load_image(im1path)
        im2 = load_image(im2path)
        if self.transforms is not None:
            im1, im2 = self.transforms(im1, im2)
        return im1, im2
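

# Usage sketch (assumes the habitat_release cache file has been created):
#   from torch.utils.data import DataLoader
#   dataset = PairsDataset("habitat_release", trfs="", totensor=True, normalize=True)
#   loader = DataLoader(dataset, batch_size=8, shuffle=True, num_workers=4)
#   im1, im2 = next(iter(loader))  # two batches of paired image tensors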


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(
        prog="Computing and caching list of pairs for a given dataset"
    )
    parser.add_argument(
        "--data_dir", default="./data/", type=str, help="path where data are stored"
    )
    parser.add_argument(
        "--dataset", default="habitat_release", type=str, help="name of the dataset"
    )
    args = parser.parse_args()
    parse_and_cache_all_pairs(dname=args.dataset, data_dir=args.data_dir)
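
# Example invocation (file location assumed to be datasets/pairs_dataset.py;
# run from the repository root so the `datasets` package is importable):
#   python -m datasets.pairs_dataset --dataset habitat_release --data_dir ./data/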