import os

import numpy as np
import torch
from PIL import Image

from pix2pix.data.base_dataset import BaseDataset
from pix2pix.data.image_folder import make_dataset
from pix2pix.util.guidedfilter import GuidedFilter


def normalize(img):
    """Map a tensor from [0, 1] to [-1, 1]."""
    return img * 2 - 1


def normalize01(img):
    """Min-max rescale a tensor to [0, 1]."""
    return (img - torch.min(img)) / (torch.max(img) - torch.min(img))
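
# Illustrative note (not part of the original file): composing the two
# helpers performs a min-max rescale into [-1, 1], e.g.
#   normalize(normalize01(torch.tensor([0.2, 0.4, 0.6])))  ->  tensor([-1., 0., 1.])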


class DepthMergeDataset(BaseDataset):
    """Dataset of paired 16-bit depth maps ('outer' and 'inner'), plus a
    pseudo ground truth ('gtfake') that is only loaded during training."""

    def __init__(self, opt):
        BaseDataset.__init__(self, opt)

        self.dir_outer = os.path.join(opt.dataroot, opt.phase, 'outer')
        self.dir_inner = os.path.join(opt.dataroot, opt.phase, 'inner')
        self.dir_gtfake = os.path.join(opt.dataroot, opt.phase, 'gtfake')

        self.outer_paths = sorted(make_dataset(self.dir_outer, opt.max_dataset_size))
        self.inner_paths = sorted(make_dataset(self.dir_inner, opt.max_dataset_size))
        self.gtfake_paths = sorted(make_dataset(self.dir_gtfake, opt.max_dataset_size))

        self.dataset_size = len(self.outer_paths)
        self.isTrain = opt.phase == 'train'

    def __getitem__(self, index):
        # Inputs are stored as 16-bit images; rescale pixel values to [0, 1].
        normalize_coef = np.float32(2 ** 16)

        data_outer = Image.open(self.outer_paths[index % self.dataset_size])
        data_outer = np.array(data_outer, dtype=np.float32) / normalize_coef

        data_inner = Image.open(self.inner_paths[index % self.dataset_size])
        data_inner = np.array(data_inner, dtype=np.float32) / normalize_coef

        if self.isTrain:
            data_gtfake = Image.open(self.gtfake_paths[index % self.dataset_size])
            data_gtfake = np.array(data_gtfake, dtype=np.float32) / normalize_coef

            # Pre-smooth both inputs with a guided filter, using the pseudo
            # ground truth as the guide image in both cases.
            data_inner = GuidedFilter(data_gtfake, data_inner, 64, 1e-8).smooth.astype('float32')
            data_outer = GuidedFilter(data_gtfake, data_outer, 64, 1e-8).smooth.astype('float32')

        # Convert to 1 x H x W tensors and rescale each map to [-1, 1].
        data_outer = torch.from_numpy(data_outer)
        data_outer = torch.unsqueeze(data_outer, 0)
        data_outer = normalize(normalize01(data_outer))

        data_inner = torch.from_numpy(data_inner)
        data_inner = torch.unsqueeze(data_inner, 0)
        data_inner = normalize(normalize01(data_inner))

        if self.isTrain:
            data_gtfake = torch.from_numpy(data_gtfake)
            data_gtfake = torch.unsqueeze(data_gtfake, 0)
            data_gtfake = normalize(normalize01(data_gtfake))

        image_path = self.outer_paths[index % self.dataset_size]

        if self.isTrain:
            return {'data_inner': data_inner, 'data_outer': data_outer,
                    'data_gtfake': data_gtfake, 'image_path': image_path}
        return {'data_inner': data_inner, 'data_outer': data_outer,
                'image_path': image_path}

    def __len__(self):
        """Return the total number of images."""
        return self.dataset_size
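

# Usage sketch (not part of the original file): builds a minimal `opt` by
# hand, whereas the surrounding framework normally creates it with its option
# parser, and the `dataroot` below is a hypothetical path. It then iterates
# the dataset through a standard torch DataLoader.
if __name__ == '__main__':
    from argparse import Namespace
    from torch.utils.data import DataLoader

    opt = Namespace(dataroot='./datasets/mergedepth',  # hypothetical path
                    phase='train',
                    max_dataset_size=float('inf'))
    dataset = DepthMergeDataset(opt)
    loader = DataLoader(dataset, batch_size=4, shuffle=True)

    batch = next(iter(loader))
    print(batch['data_inner'].shape, batch['data_outer'].shape)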