import os, yaml, pickle, shutil, tarfile, glob
import cv2
import albumentations
import PIL
import numpy as np
import torchvision.transforms.functional as TF
from omegaconf import OmegaConf
from functools import partial
from PIL import Image
from tqdm import tqdm
from torch.utils.data import Dataset, Subset

import taming.data.utils as tdu
from taming.data.imagenet import str_to_indices, give_synsets_from_indices, download, retrieve
from taming.data.imagenet import ImagePaths

from ldm.modules.image_degradation import degradation_fn_bsr, degradation_fn_bsr_light


def synset2idx(path_to_yaml="data/index_synset.yaml"):
    with open(path_to_yaml) as f:
        di2s = yaml.load(f, Loader=yaml.FullLoader)  # explicit Loader; required by recent PyYAML
    return dict((v, k) for k, v in di2s.items())


class ImageNetBase(Dataset):
    def __init__(self, config=None):
        self.config = config or OmegaConf.create()
        if not type(self.config)==dict:
            self.config = OmegaConf.to_container(self.config)
        self.keep_orig_class_label = self.config.get("keep_orig_class_label", False)
        self.process_images = True  # if False we skip loading & processing images and self.data contains filepaths
        self._prepare()
        self._prepare_synset_to_human()
        self._prepare_idx_to_synset()
        self._prepare_human_to_integer_label()
        self._load()

    def __len__(self):
        return len(self.data)

    def __getitem__(self, i):
        return self.data[i]

    def _prepare(self):
        raise NotImplementedError()

    def _filter_relpaths(self, relpaths):
        ignore = set([
            "n06596364_9591.JPEG",
        ])
        relpaths = [rpath for rpath in relpaths if not rpath.split("/")[-1] in ignore]
        if "sub_indices" in self.config:
            indices = str_to_indices(self.config["sub_indices"])
            synsets = give_synsets_from_indices(indices, path_to_yaml=self.idx2syn)  # returns a list of strings
            self.synset2idx = synset2idx(path_to_yaml=self.idx2syn)
            files = []
            for rpath in relpaths:
                syn = rpath.split("/")[0]
                if syn in synsets:
                    files.append(rpath)
            return files
        else:
            return relpaths

    def _prepare_synset_to_human(self):
        SIZE = 2655750
        URL = "https://heibox.uni-heidelberg.de/f/9f28e956cd304264bb82/?dl=1"
        self.human_dict = os.path.join(self.root, "synset_human.txt")
        if (not os.path.exists(self.human_dict) or
                not os.path.getsize(self.human_dict)==SIZE):
            download(URL, self.human_dict)

    def _prepare_idx_to_synset(self):
        URL = "https://heibox.uni-heidelberg.de/f/d835d5b6ceda4d3aa910/?dl=1"
        self.idx2syn = os.path.join(self.root, "index_synset.yaml")
        if (not os.path.exists(self.idx2syn)):
            download(URL, self.idx2syn)

    def _prepare_human_to_integer_label(self):
        URL = "https://heibox.uni-heidelberg.de/f/2362b797d5be43b883f6/?dl=1"
        self.human2integer = os.path.join(self.root, "imagenet1000_clsidx_to_labels.txt")
        if (not os.path.exists(self.human2integer)):
            download(URL, self.human2integer)
        with open(self.human2integer, "r") as f:
            lines = f.read().splitlines()
            assert len(lines) == 1000
            self.human2integer_dict = dict()
            for line in lines:
                value, key = line.split(":")
                self.human2integer_dict[key] = int(value)

    def _load(self):
        with open(self.txt_filelist, "r") as f:
            self.relpaths = f.read().splitlines()
            l1 = len(self.relpaths)
            self.relpaths = self._filter_relpaths(self.relpaths)
            print("Removed {} files from filelist during filtering.".format(l1 - len(self.relpaths)))

        self.synsets = [p.split("/")[0] for p in self.relpaths]
        self.abspaths = [os.path.join(self.datadir, p) for p in self.relpaths]

        unique_synsets = np.unique(self.synsets)
        class_dict = dict((synset, i) for i, synset in enumerate(unique_synsets))
        if not self.keep_orig_class_label:
            self.class_labels = [class_dict[s] for s in self.synsets]
        else:
            self.class_labels = [self.synset2idx[s] for s in self.synsets]

        with open(self.human_dict, "r") as f:
            human_dict = f.read().splitlines()
            human_dict = dict(line.split(maxsplit=1) for line in human_dict)

        self.human_labels = [human_dict[s] for s in self.synsets]

        labels = {
            "relpath": np.array(self.relpaths),
            "synsets": np.array(self.synsets),
            "class_label": np.array(self.class_labels),
            "human_label": np.array(self.human_labels),
        }

        if self.process_images:
            self.size = retrieve(self.config, "size", default=256)
            self.data = ImagePaths(self.abspaths,
                                   labels=labels,
                                   size=self.size,
                                   random_crop=self.random_crop,
                                   )
        else:
            self.data = self.abspaths


class ImageNetTrain(ImageNetBase):
    NAME = "ILSVRC2012_train"
    URL = "http://www.image-net.org/challenges/LSVRC/2012/"
    AT_HASH = "a306397ccf9c2ead27155983c254227c0fd938e2"
    FILES = [
        "ILSVRC2012_img_train.tar",
    ]
    SIZES = [
        147897477120,
    ]

    def __init__(self, process_images=True, data_root=None, **kwargs):
        self.process_images = process_images
        self.data_root = data_root
        super().__init__(**kwargs)

    def _prepare(self):
        if self.data_root:
            self.root = os.path.join(self.data_root, self.NAME)
        else:
            cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache"))
            self.root = os.path.join(cachedir, "autoencoders/data", self.NAME)

        self.datadir = os.path.join(self.root, "data")
        self.txt_filelist = os.path.join(self.root, "filelist.txt")
        self.expected_length = 1281167
        self.random_crop = retrieve(self.config, "ImageNetTrain/random_crop",
                                    default=True)
        if not tdu.is_prepared(self.root):
            # prep
            print("Preparing dataset {} in {}".format(self.NAME, self.root))

            datadir = self.datadir
            if not os.path.exists(datadir):
                path = os.path.join(self.root, self.FILES[0])
                if not os.path.exists(path) or not os.path.getsize(path)==self.SIZES[0]:
                    import academictorrents as at
                    atpath = at.get(self.AT_HASH, datastore=self.root)
                    assert atpath == path

                print("Extracting {} to {}".format(path, datadir))
                os.makedirs(datadir, exist_ok=True)
                with tarfile.open(path, "r:") as tar:
                    tar.extractall(path=datadir)

                print("Extracting sub-tars.")
                subpaths = sorted(glob.glob(os.path.join(datadir, "*.tar")))
                for subpath in tqdm(subpaths):
                    subdir = subpath[:-len(".tar")]
                    os.makedirs(subdir, exist_ok=True)
                    with tarfile.open(subpath, "r:") as tar:
                        tar.extractall(path=subdir)

            filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG"))
            filelist = [os.path.relpath(p, start=datadir) for p in filelist]
            filelist = sorted(filelist)
            filelist = "\n".join(filelist)+"\n"
            with open(self.txt_filelist, "w") as f:
                f.write(filelist)

            tdu.mark_prepared(self.root)


class ImageNetValidation(ImageNetBase):
    NAME = "ILSVRC2012_validation"
    URL = "http://www.image-net.org/challenges/LSVRC/2012/"
    AT_HASH = "5d6d0df7ed81efd49ca99ea4737e0ae5e3a5f2e5"
    VS_URL = "https://heibox.uni-heidelberg.de/f/3e0f6e9c624e45f2bd73/?dl=1"
    FILES = [
        "ILSVRC2012_img_val.tar",
        "validation_synset.txt",
    ]
    SIZES = [
        6744924160,
        1950000,
    ]

    def __init__(self, process_images=True, data_root=None, **kwargs):
        self.data_root = data_root
        self.process_images = process_images
        super().__init__(**kwargs)

    def _prepare(self):
        if self.data_root:
            self.root = os.path.join(self.data_root, self.NAME)
        else:
            cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache"))
            self.root = os.path.join(cachedir, "autoencoders/data", self.NAME)

        self.datadir = os.path.join(self.root, "data")
        self.txt_filelist = os.path.join(self.root, "filelist.txt")
        self.expected_length = 50000
        self.random_crop = retrieve(self.config, "ImageNetValidation/random_crop",
                                    default=False)
        if not tdu.is_prepared(self.root):
            # prep
            print("Preparing dataset {} in {}".format(self.NAME, self.root))

            datadir = self.datadir
            if not os.path.exists(datadir):
                path = os.path.join(self.root, self.FILES[0])
                if not os.path.exists(path) or not os.path.getsize(path)==self.SIZES[0]:
                    import academictorrents as at
                    atpath = at.get(self.AT_HASH, datastore=self.root)
                    assert atpath == path

                print("Extracting {} to {}".format(path, datadir))
                os.makedirs(datadir, exist_ok=True)
                with tarfile.open(path, "r:") as tar:
                    tar.extractall(path=datadir)

                vspath = os.path.join(self.root, self.FILES[1])
                if not os.path.exists(vspath) or not os.path.getsize(vspath)==self.SIZES[1]:
                    download(self.VS_URL, vspath)

                with open(vspath, "r") as f:
                    synset_dict = f.read().splitlines()
                    synset_dict = dict(line.split() for line in synset_dict)

                print("Reorganizing into synset folders")
                synsets = np.unique(list(synset_dict.values()))
                for s in synsets:
                    os.makedirs(os.path.join(datadir, s), exist_ok=True)
                for k, v in synset_dict.items():
                    src = os.path.join(datadir, k)
                    dst = os.path.join(datadir, v)
                    shutil.move(src, dst)

            filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG"))
            filelist = [os.path.relpath(p, start=datadir) for p in filelist]
            filelist = sorted(filelist)
            filelist = "\n".join(filelist)+"\n"
            with open(self.txt_filelist, "w") as f:
                f.write(filelist)

            tdu.mark_prepared(self.root)


class ImageNetSR(Dataset):
    def __init__(self, size=None,
                 degradation=None, downscale_f=4, min_crop_f=0.5, max_crop_f=1.,
                 random_crop=True):
        """
        ImageNet super-resolution dataloader.
        Performs the following ops in order:
        1. takes a square crop of side length s from the image (random or center crop)
        2. resizes the crop to `size` with cv2 area interpolation
        3. degrades the resized crop with the chosen degradation_fn

        :param size: target side length of the HR crop after resizing
        :param degradation: degradation_fn, e.g. cv_bicubic or bsrgan_light
        :param downscale_f: low-resolution downsampling factor
        :param min_crop_f: determines the crop size s,
            where s = c * min_img_side_len with c sampled uniformly from (min_crop_f, max_crop_f)
        :param max_crop_f: upper bound of the crop-factor interval (must be <= 1.0)
        :param random_crop: use a random crop if True, a center crop otherwise
        """
        self.base = self.get_base()
        assert size
        assert (size / downscale_f).is_integer()
        self.size = size
        self.LR_size = int(size / downscale_f)
        self.min_crop_f = min_crop_f
        self.max_crop_f = max_crop_f
        assert(max_crop_f <= 1.)
        self.center_crop = not random_crop

        self.image_rescaler = albumentations.SmallestMaxSize(max_size=size, interpolation=cv2.INTER_AREA)

        self.pil_interpolation = False  # reset below if the chosen interpolation op comes from Pillow

        if degradation == "bsrgan":
            self.degradation_process = partial(degradation_fn_bsr, sf=downscale_f)

        elif degradation == "bsrgan_light":
            self.degradation_process = partial(degradation_fn_bsr_light, sf=downscale_f)

        else:
            interpolation_fn = {
                "cv_nearest": cv2.INTER_NEAREST,
                "cv_bilinear": cv2.INTER_LINEAR,
                "cv_bicubic": cv2.INTER_CUBIC,
                "cv_area": cv2.INTER_AREA,
                "cv_lanczos": cv2.INTER_LANCZOS4,
                "pil_nearest": PIL.Image.NEAREST,
                "pil_bilinear": PIL.Image.BILINEAR,
                "pil_bicubic": PIL.Image.BICUBIC,
                "pil_box": PIL.Image.BOX,
                "pil_hamming": PIL.Image.HAMMING,
                "pil_lanczos": PIL.Image.LANCZOS,
            }[degradation]

            self.pil_interpolation = degradation.startswith("pil_")

            if self.pil_interpolation:
                self.degradation_process = partial(TF.resize, size=self.LR_size, interpolation=interpolation_fn)

            else:
                self.degradation_process = albumentations.SmallestMaxSize(max_size=self.LR_size,
                                                                          interpolation=interpolation_fn)

    def __len__(self):
        return len(self.base)

    def __getitem__(self, i):
        example = self.base[i]
        image = Image.open(example["file_path_"])

        if not image.mode == "RGB":
            image = image.convert("RGB")

        image = np.array(image).astype(np.uint8)

        min_side_len = min(image.shape[:2])
        crop_side_len = min_side_len * np.random.uniform(self.min_crop_f, self.max_crop_f, size=None)
        crop_side_len = int(crop_side_len)

        if self.center_crop:
            self.cropper = albumentations.CenterCrop(height=crop_side_len, width=crop_side_len)
        else:
            self.cropper = albumentations.RandomCrop(height=crop_side_len, width=crop_side_len)

        image = self.cropper(image=image)["image"]
        image = self.image_rescaler(image=image)["image"]

        if self.pil_interpolation:
            image_pil = PIL.Image.fromarray(image)
            LR_image = self.degradation_process(image_pil)
            LR_image = np.array(LR_image).astype(np.uint8)
        else:
            LR_image = self.degradation_process(image=image)["image"]

        example["image"] = (image/127.5 - 1.0).astype(np.float32)
        example["LR_image"] = (LR_image/127.5 - 1.0).astype(np.float32)

        return example


class ImageNetSRTrain(ImageNetSR):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def get_base(self):
        with open("data/imagenet_train_hr_indices.p", "rb") as f:
            indices = pickle.load(f)
        dset = ImageNetTrain(process_images=False,)
        return Subset(dset, indices)


class ImageNetSRValidation(ImageNetSR):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def get_base(self):
        with open("data/imagenet_val_hr_indices.p", "rb") as f:
            indices = pickle.load(f)
        dset = ImageNetValidation(process_images=False,)
        return Subset(dset, indices)
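

# A minimal, hedged smoke test of the crop -> resize -> degrade pipeline implemented by
# ImageNetSR above. It is not part of the original module; the subclass and the image
# path below are illustrative assumptions so the pipeline can be exercised on a single
# local RGB image without downloading ImageNet or the pickled HR index files.
if __name__ == "__main__":
    class _SingleImageSR(ImageNetSR):
        def get_base(self):
            # the base only needs to be indexable and yield dicts with a "file_path_" key
            return [{"file_path_": "some_image.jpg"}]  # placeholder path (assumption)

    dataset = _SingleImageSR(size=256, degradation="pil_bicubic", downscale_f=4)
    example = dataset[0]
    print("HR:", example["image"].shape)     # (256, 256, 3), values scaled to [-1, 1]
    print("LR:", example["LR_image"].shape)  # (64, 64, 3)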