Spaces:
Running
on
Zero
Running
on
Zero
| import random | |
| import json | |
| import numpy as np | |
| from pathlib import Path | |
| from typing import Iterable | |
| from omegaconf import ListConfig | |
| import cv2 | |
| import torch | |
| from functools import partial | |
| import torchvision as thv | |
| from torch.utils.data import Dataset | |
| from utils import util_sisr | |
| from utils import util_image | |
| from utils import util_common | |
| from basicsr.data.transforms import augment | |
| from basicsr.data.realesrgan_dataset import RealESRGANDataset | |
def get_transforms(transform_type, kwargs):
    '''
    Build the preprocessing pipeline selected by ``transform_type``.

    Args:
        transform_type: one of 'default', 'resize_ccrop_norm', 'ccrop_norm',
            'rcrop_aug_norm' or 'aug_norm'.
        kwargs: mapping of options; only the keys used by the selected pipeline are read.
            mean, std: scalar or sequence, for normalization (each defaults to 0.5)
            size: resize target and center-crop size ('resize_ccrop_norm', 'ccrop_norm')
            interpolation: forwarded to util_image.SmallestMaxSize ('resize_ccrop_norm')
            pch_size: random-crop patch size, defaults to 256 ('rcrop_aug_norm')
            only_hflip, only_vflip, only_hvflip: forwarded to util_image.SpatialAug,
                each defaults to False ('rcrop_aug_norm', 'aug_norm')
            max_value: forwarded to util_image.ToTensor ('rcrop_aug_norm')

    Returns:
        A ``thv.transforms.Compose`` holding the selected steps.

    Raises:
        ValueError: if ``transform_type`` is not one of the supported names.
    '''
    def make_normalize():
        # Built on demand so an unknown transform_type fails before any transform is created.
        return thv.transforms.Normalize(mean=kwargs.get('mean', 0.5), std=kwargs.get('std', 0.5))

    def make_spatial_aug():
        return util_image.SpatialAug(
            only_hflip=kwargs.get('only_hflip', False),
            only_vflip=kwargs.get('only_vflip', False),
            only_hvflip=kwargs.get('only_hvflip', False),
        )

    if transform_type == 'default':
        transform = thv.transforms.Compose([
            thv.transforms.ToTensor(),
            make_normalize(),
        ])
    elif transform_type == 'resize_ccrop_norm':
        transform = thv.transforms.Compose([
            util_image.SmallestMaxSize(
                max_size=kwargs.get('size'),
                interpolation=kwargs.get('interpolation'),
            ),
            thv.transforms.ToTensor(),
            thv.transforms.CenterCrop(size=kwargs.get('size', None)),
            make_normalize(),
        ])
    elif transform_type == 'ccrop_norm':
        transform = thv.transforms.Compose([
            thv.transforms.ToTensor(),
            thv.transforms.CenterCrop(size=kwargs.get('size', None)),
            make_normalize(),
        ])
    elif transform_type == 'rcrop_aug_norm':
        transform = thv.transforms.Compose([
            util_image.RandomCrop(pch_size=kwargs.get('pch_size', 256)),
            make_spatial_aug(),
            # NOTE(review): 'max_value' is forwarded as None when absent -- confirm
            # util_image.ToTensor accepts that.
            util_image.ToTensor(max_value=kwargs.get('max_value')),  # (ndarray, hwc) --> (Tensor, chw)
            make_normalize(),
        ])
    elif transform_type == 'aug_norm':
        transform = thv.transforms.Compose([
            make_spatial_aug(),
            util_image.ToTensor(),  # hwc --> chw
            make_normalize(),
        ])
    else:
        # The message previously formatted an undefined name, which turned this
        # ValueError into a NameError.
        raise ValueError(f'Unexpected transform_type {transform_type}')
    return transform
def create_dataset(dataset_config):
    '''
    Instantiate the dataset described by ``dataset_config``.

    ``dataset_config['type']`` selects the class and ``dataset_config['params']``
    supplies its arguments: unpacked as keyword arguments for 'base' and
    'base_meta', passed as a single positional object for 'realesrgan'.

    Raises:
        NotImplementedError: for any other ``type`` value (the message is that value).
    '''
    kind = dataset_config['type']
    if kind == 'base':
        return BaseData(**dataset_config['params'])
    if kind == 'base_meta':
        return BaseDataMetaCond(**dataset_config['params'])
    if kind == 'realesrgan':
        # RealESRGANDataset receives the params object itself, not keyword arguments.
        return RealESRGANDataset(dataset_config['params'])
    raise NotImplementedError(f"{dataset_config['type']}")
class BaseData(Dataset):
    '''
    Image dataset built from a folder scan and/or a text file of paths.

    Each item is a dict with:
        'image', 'lq': the same transformed image (both keys hold one object)
        'gt': the transformed image with the same file name under ``extra_dir_path``
            (only when ``extra_dir_path`` is given)
        'path': the source image path (only when ``need_path`` is True)
    '''
    # NOTE: the dict/list defaults below are shared between calls; this class only
    # forwards them to helpers and never mutates them itself.
    def __init__(
            self,
            dir_path,
            txt_path=None,
            transform_type='default',
            transform_kwargs={'mean':0.0, 'std':1.0},
            extra_dir_path=None,
            extra_transform_type=None,
            extra_transform_kwargs=None,
            length=None,
            need_path=False,
            im_exts=['png', 'jpg', 'jpeg', 'JPEG', 'bmp'],
            recursive=False,
            ):
        '''
        Args:
            dir_path: folder scanned via util_common.scan_files_from_folder, or None.
            txt_path: text file read via util_common.readline_txt, or None.
            transform_type, transform_kwargs: forwarded to get_transforms for the base image.
            extra_dir_path: optional folder holding a same-named companion image per item.
            extra_transform_type, extra_transform_kwargs: forwarded to get_transforms for
                the companion image; the type is required when extra_dir_path is given.
            length: number of paths to sample at random, or None to use every path.
            need_path: whether to include the source path in each item.
            im_exts, recursive: forwarded to util_common.scan_files_from_folder.
        '''
        super().__init__()

        file_paths_all = []
        if dir_path is not None:
            file_paths_all.extend(util_common.scan_files_from_folder(dir_path, im_exts, recursive))
        if txt_path is not None:
            file_paths_all.extend(util_common.readline_txt(txt_path))

        self.file_paths_all = file_paths_all
        self.length = length
        self.file_paths = self._select_paths()

        self.need_path = need_path
        self.transform = get_transforms(transform_type, transform_kwargs)

        self.extra_dir_path = extra_dir_path
        if extra_dir_path is not None:
            assert extra_transform_type is not None
            # get_transforms calls kwargs.get(...), so the None default must become a mapping.
            if extra_transform_kwargs is None:
                extra_transform_kwargs = {}
            self.extra_transform = get_transforms(extra_transform_type, extra_transform_kwargs)

    def _select_paths(self):
        '''Return every path when ``length`` is None, otherwise a random sample of that size.'''
        if self.length is None:
            return self.file_paths_all
        return random.sample(self.file_paths_all, self.length)

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, index):
        im_path_base = self.file_paths[index]
        im_base = util_image.imread(im_path_base, chn='rgb', dtype='float32')

        im_target = self.transform(im_base)
        out = {'image':im_target, 'lq':im_target}

        if self.extra_dir_path is not None:
            # The companion image is looked up by file name inside extra_dir_path.
            im_path_extra = Path(self.extra_dir_path) / Path(im_path_base).name
            im_extra = util_image.imread(im_path_extra, chn='rgb', dtype='float32')
            im_extra = self.extra_transform(im_extra)
            out['gt'] = im_extra

        if self.need_path:
            out['path'] = im_path_base

        return out

    def reset_dataset(self):
        '''
        Re-draw the active subset of paths.

        With ``length`` set this takes a fresh random sample; with ``length`` None it
        keeps the full list (previously this case raised TypeError in random.sample).
        '''
        self.file_paths = self._select_paths()
class BaseDataMetaCond(Dataset):
    '''
    Dataset driven by one JSON metadata file per sample.

    Every ``*.json`` file directly inside the given directories is one sample. The
    keys read from it are 'source' (image path), 'prompt', the optional 'latent'
    (path loaded with np.load) and ``cond_key`` (path of the condition image).

    Each item is a dict with:
        'image': the transformed source image
        'path': the source image path (only when ``need_path`` is True)
        'latent': the loaded array (only when the metadata has a 'latent' entry)
        'txt': the prompt
        'cond': the transformed condition image
    '''
    # NOTE: the dict defaults below are shared between calls; this class only forwards
    # them to get_transforms and never mutates them itself.
    def __init__(
            self,
            meta_dir,
            transform_type='default',
            transform_kwargs={'mean':0.5, 'std':0.5},
            length=None,
            need_path=False,
            cond_key='canny',
            cond_transform_type='default',
            cond_transform_kwargs={'mean':0.5, 'std':0.5},
            ):
        '''
        Args:
            meta_dir: one directory, or a ListConfig / list / tuple of directories.
            transform_type, transform_kwargs: forwarded to get_transforms for the source image.
            length: keep only the first ``length`` metadata files, or None to keep all.
            need_path: whether to include the source image path in each item.
            cond_key: metadata key of the condition image; 'canny' and 'seg' are handled.
            cond_transform_type, cond_transform_kwargs: forwarded to get_transforms for
                the condition image.
        '''
        super().__init__()

        # A single directory is wrapped; sequences of directories are iterated as given.
        # Plain lists/tuples used to be wrapped as well, which then failed inside Path(...).
        if not isinstance(meta_dir, (ListConfig, list, tuple)):
            meta_dir = [meta_dir,]

        # Only the paths are stored; each JSON file is parsed on access in __getitem__.
        meta_list = []
        for current_dir in meta_dir:
            meta_list.extend(sorted(str(x) for x in Path(current_dir).glob("*.json")))
        self.meta_list = meta_list if length is None else meta_list[:length]

        self.cond_key = cond_key
        self.length = length
        self.need_path = need_path
        self.transform = get_transforms(transform_type, transform_kwargs)
        # Attribute name kept as originally spelled so existing external accesses keep working.
        self.cond_trasform = get_transforms(cond_transform_type, cond_transform_kwargs)

    def __len__(self):
        return len(self.meta_list)

    def _load_meta_info(self, index):
        '''Parse and return the metadata dict stored in the ``index``-th JSON file.'''
        # Decode explicitly as UTF-8 rather than with the platform's locale encoding.
        with open(self.meta_list[index], 'r', encoding='utf-8') as json_file:
            return json.load(json_file)

    def __getitem__(self, index):
        meta_info = self._load_meta_info(index)

        # images
        im_path = meta_info['source']
        im_source = util_image.imread(im_path, chn='rgb', dtype='uint8')
        im_source = self.transform(im_source)
        out = {'image': im_source,}
        if self.need_path:
            out['path'] = im_path

        # latent
        if 'latent' in meta_info:
            latent_path = meta_info['latent']
            out['latent'] = np.load(latent_path)

        # prompt
        out['txt'] = meta_info['prompt']

        # condition
        cond_key = self.cond_key
        cond_path = meta_info[cond_key]
        if cond_key == 'canny':
            # Read as gray, then append a trailing axis of length 1.
            cond = util_image.imread(cond_path, chn='gray', dtype='uint8')[:, :, None]
        elif cond_key == 'seg':
            cond = util_image.imread(cond_path, chn='rgb', dtype='uint8')
        else:
            raise ValueError(f"Unexpected cond key: {cond_key}")
        cond = self.cond_trasform(cond)
        out['cond'] = cond

        return out