import os
from torch.utils.data import Dataset, DataLoader
from diffusers.image_processor import VaeImageProcessor
from tqdm import tqdm
from PIL import Image, ImageFilter


class InferenceDataset(Dataset):
    """Base dataset yielding preprocessed (person, cloth, mask) samples for inference.

    Subclasses override :meth:`load_data` to return a list of records (dicts) with the
    keys ``'person_name'``, ``'person'``, ``'cloth'`` and ``'mask'``. The last three are
    image file paths. ``args`` must provide ``height`` and ``width``, which are passed
    to ``VaeImageProcessor.preprocess``.
    """

    def __init__(self, args):
        self.args = args
        # Person / cloth images: default VAE preprocessing.
        self.vae_processor = VaeImageProcessor(vae_scale_factor=8)
        # Masks: no normalization, binarized, converted to single-channel grayscale.
        self.mask_processor = VaeImageProcessor(
            vae_scale_factor=8,
            do_normalize=False,
            do_binarize=True,
            do_convert_grayscale=True,
        )
        self.data = self.load_data()

    def load_data(self):
        """Return the list of sample records. The base class provides none."""
        return []

    def __len__(self):
        return len(self.data)

    @staticmethod
    def _load_image(path, mode):
        """Open the image at *path*, convert it to *mode*, and close the file."""
        # Image.open is lazy and keeps the file handle open. convert() forces the
        # decode and returns an independent image, so the handle can be released here.
        with Image.open(path) as img:
            return img.convert(mode)

    def __getitem__(self, idx):
        """Return the preprocessed sample at *idx*.

        ``person`` and ``cloth`` are forced to RGB, because the VAE processor is not
        configured to convert channels itself. ``mask`` is loaded as grayscale
        ("L"), which the mask processor would convert to anyway. Each value is the
        first element (``[0]``) of the processor's ``preprocess`` output. That output
        is presumably a batch of one; confirm against the diffusers version in use.
        """
        record = self.data[idx]
        height, width = self.args.height, self.args.width
        person = self._load_image(record['person'], 'RGB')
        cloth = self._load_image(record['cloth'], 'RGB')
        mask = self._load_image(record['mask'], 'L')
        return {
            'index': idx,
            'person_name': record['person_name'],
            'person': self.vae_processor.preprocess(person, height, width)[0],
            'cloth': self.vae_processor.preprocess(cloth, height, width)[0],
            'mask': self.mask_processor.preprocess(mask, height, width)[0],
        }


class VITONHDTestDataset(InferenceDataset):
    """VITON-HD layout dataset driven by a ``<split>_pairs.txt`` file."""

    def load_data(self):
        """Build sample records from ``<data_root_path>/<split>_pairs.txt``.

        ``split`` is ``"train"`` when ``args.is_train`` is truthy, otherwise
        ``"samples"``. Each non-blank line holds ``<person_img> <cloth_img>``.
        Images are read from ``<data_root_path>/<split>/image``, ``.../cloth`` and
        ``.../agnostic-mask``. A pair is skipped when
        ``<output_dir>/vitonhd/<paired|unpaired>/<person_img>`` already exists.
        When ``args.eval_pair`` is truthy, each person is paired with their own
        image name as the cloth.

        Raises:
            FileNotFoundError: if the pairs file does not exist.
            ValueError: if a non-blank line does not contain exactly two names.
        """
        split = "train" if self.args.is_train else "samples"
        pair_txt = os.path.join(self.args.data_root_path, f'{split}_pairs.txt')
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if not os.path.exists(pair_txt):
            raise FileNotFoundError(f"File {pair_txt} does not exist.")
        with open(pair_txt, 'r', encoding='utf-8') as f:
            lines = f.readlines()

        # Resolve the split directory locally rather than overwriting
        # self.args.data_root_path. Mutating the shared args made any later
        # construction with the same args look in "<root>/<split>/<split>".
        split_root = os.path.join(self.args.data_root_path, split)
        output_dir = os.path.join(
            self.args.output_dir, "vitonhd", 'paired' if self.args.eval_pair else 'unpaired'
        )

        data = []
        for line in lines:
            fields = line.split()
            if not fields:  # tolerate blank / trailing empty lines
                continue
            person_img, cloth_img = fields
            # Skip pairs whose result was already written (resumable runs).
            if os.path.exists(os.path.join(output_dir, person_img)):
                continue
            if self.args.eval_pair:
                cloth_img = person_img
            mask_name = f"{os.path.splitext(person_img)[0]}_mask.png"
            data.append({
                'person_name': person_img,
                'person': os.path.join(split_root, 'image', person_img),
                'cloth': os.path.join(split_root, 'cloth', cloth_img),
                'mask': os.path.join(split_root, 'agnostic-mask', mask_name),
            })
        return data