Spaces:
Running
Running
| import os | |
| from torch.utils.data import Dataset, DataLoader | |
| from diffusers.image_processor import VaeImageProcessor | |
| from tqdm import tqdm | |
| from PIL import Image, ImageFilter | |
class InferenceDataset(Dataset):
    """Base dataset yielding preprocessed (person, cloth, mask) samples for inference.

    Subclasses override :meth:`load_data` to return a list of records. Each
    record is a dict with the keys ``'person_name'``, ``'person'``, ``'cloth'``
    and ``'mask'``. The last three are image file paths opened in
    :meth:`__getitem__`.

    ``args`` must expose ``height`` and ``width`` (the target size handed to the
    image processors). Subclasses may read further attributes from it.
    """

    def __init__(self, args):
        self.args = args
        # Person/cloth images: processed with the processor's default
        # settings (normalization included).
        self.vae_processor = VaeImageProcessor(vae_scale_factor=8)
        # Masks: grayscale + binarized, without normalization.
        self.mask_processor = VaeImageProcessor(
            vae_scale_factor=8,
            do_normalize=False,
            do_binarize=True,
            do_convert_grayscale=True,
        )
        # Called last so subclass implementations can rely on self.args.
        self.data = self.load_data()

    def load_data(self):
        """Return the list of sample records; the base class provides none."""
        return []

    def __len__(self):
        return len(self.data)

    @staticmethod
    def _open_image(path, mode=None):
        """Read the image at *path* into memory and close the file handle.

        If *mode* is given, the image is converted to that PIL mode;
        otherwise an in-memory copy is returned.
        """
        with Image.open(path) as img:
            # convert()/copy() force the pixel data to load, so the returned
            # image remains usable after the file is closed by the `with`.
            return img.convert(mode) if mode else img.copy()

    def __getitem__(self, idx):
        """Return the preprocessed sample at position *idx*.

        The result is a dict with ``'index'``, ``'person_name'`` and the
        processed ``'person'``, ``'cloth'`` and ``'mask'`` entries. Each entry is
        element ``[0]`` of the corresponding processor's ``preprocess`` output.
        That is presumably a single (C, H, W) tensor; confirm against the
        installed diffusers version.
        """
        record = self.data[idx]
        height, width = self.args.height, self.args.width
        # vae_processor is built without do_convert_rgb (diffusers' default is
        # False -- verify for the installed version), so convert explicitly.
        # Otherwise grayscale or RGBA inputs would reach the VAE with the
        # wrong channel count.
        person = self._open_image(record['person'], 'RGB')
        cloth = self._open_image(record['cloth'], 'RGB')
        # The mask processor performs its own grayscale conversion.
        mask = self._open_image(record['mask'])
        return {
            'index': idx,
            'person_name': record['person_name'],
            'person': self.vae_processor.preprocess(person, height, width)[0],
            'cloth': self.vae_processor.preprocess(cloth, height, width)[0],
            'mask': self.mask_processor.preprocess(mask, height, width)[0],
        }
class VITONHDTestDataset(InferenceDataset):
    """VITON-HD-style dataset driven by a ``<split>_pairs.txt`` listing.

    Reads ``args.data_root_path``, ``args.is_train``, ``args.output_dir`` and
    ``args.eval_pair``, in addition to the ``height``/``width`` used by the
    base class.
    """

    def load_data(self):
        """Build sample records from ``<data_root_path>/<split>_pairs.txt``.

        The split is ``"train"`` when ``args.is_train`` is truthy, else
        ``"samples"``. Each non-blank line of the pairs file holds a person image
        name and a cloth image name separated by whitespace. Images are looked up
        under ``<data_root_path>/<split>/{image,cloth,agnostic-mask}``.

        A pair is skipped when a file named after the person image already exists
        in ``<output_dir>/vitonhd/{paired,unpaired}``. This is presumably so an
        interrupted run can resume. When ``args.eval_pair`` is set, each person
        is paired with the cloth of the same file name.

        Returns:
            list[dict]: records with keys ``'person_name'``, ``'person'``,
            ``'cloth'`` and ``'mask'``.

        Raises:
            FileNotFoundError: if the pairs file does not exist.
            ValueError: if a non-blank line does not contain exactly two names.
        """
        split = "train" if self.args.is_train else "samples"
        pair_txt = os.path.join(self.args.data_root_path, f'{split}_pairs.txt')
        # Explicit raise rather than `assert`, so the check survives `python -O`.
        if not os.path.exists(pair_txt):
            raise FileNotFoundError(f"File {pair_txt} does not exist.")
        with open(pair_txt, 'r', encoding='utf-8') as f:
            lines = f.readlines()

        # Resolve the split directory locally instead of overwriting the shared
        # self.args.data_root_path. Overwriting it made any later construction
        # with the same args look for <root>/<split>/<split>_pairs.txt.
        split_root = os.path.join(self.args.data_root_path, split)
        output_dir = os.path.join(
            self.args.output_dir,
            "vitonhd",
            'paired' if self.args.eval_pair else 'unpaired',
        )

        data = []
        for line in lines:
            fields = line.split()
            if not fields:
                # Tolerate blank lines, e.g. a trailing newline at EOF.
                continue
            person_img, cloth_img = fields
            if os.path.exists(os.path.join(output_dir, person_img)):
                continue
            if self.args.eval_pair:
                cloth_img = person_img
            # Swap only the final extension; str.replace('.jpg', ...) left
            # non-.jpg names untouched and could hit '.jpg' mid-name.
            mask_name = os.path.splitext(person_img)[0] + '_mask.png'
            data.append({
                'person_name': person_img,
                'person': os.path.join(split_root, 'image', person_img),
                'cloth': os.path.join(split_root, 'cloth', cloth_img),
                'mask': os.path.join(split_root, 'agnostic-mask', mask_name),
            })
        return data