# Virtual-Cloths-TryOn / VITON_Dataset.py
# Last commit: 569254a by harsh99 ("bug fixes")
import os
from torch.utils.data import Dataset, DataLoader
from diffusers.image_processor import VaeImageProcessor
from tqdm import tqdm
from PIL import Image, ImageFilter
class InferenceDataset(Dataset):
    """Base dataset yielding preprocessed person / cloth / mask samples.

    Subclasses override :meth:`load_data` to return a list of records, each a
    dict with the keys ``'person_name'``, ``'person'``, ``'cloth'`` and
    ``'mask'``; the last three are image file paths that are opened and
    preprocessed in :meth:`__getitem__`.
    """

    def __init__(self, args):
        """Set up the image processors and build the sample list.

        Args:
            args: Configuration object. ``args.height`` and ``args.width`` are
                the target size passed to the processors for every image;
                subclasses may read further attributes in :meth:`load_data`.
        """
        self.args = args
        # Person and cloth images: processor with its default settings.
        self.vae_processor = VaeImageProcessor(vae_scale_factor=8)
        # Masks: converted to grayscale and binarized, without normalization.
        self.mask_processor = VaeImageProcessor(
            vae_scale_factor=8,
            do_normalize=False,
            do_binarize=True,
            do_convert_grayscale=True,
        )
        # Called last so subclass implementations can rely on self.args and
        # the processors already being set.
        self.data = self.load_data()

    def load_data(self):
        """Return the list of sample records; the base class has none."""
        return []

    def __len__(self):
        """Return the number of sample records."""
        return len(self.data)

    def _preprocess(self, processor, path, mode=None):
        """Open the image at *path*, preprocess it, and close the file.

        Args:
            processor: The ``VaeImageProcessor`` to apply.
            path: Image file path.
            mode: Optional PIL mode to convert the image to first.

        Returns:
            The first (only) item of the preprocessed batch.
        """
        # Preprocess inside the ``with`` so the pixel data is fully read
        # before the file handle is released, rather than leaving the handle
        # to be closed by the garbage collector.
        with Image.open(path) as image:
            if mode is not None:
                image = image.convert(mode)
            return processor.preprocess(image, self.args.height, self.args.width)[0]

    def __getitem__(self, idx):
        """Return the preprocessed sample at position *idx*.

        Returns:
            dict with ``'index'`` (*idx*), ``'person_name'`` (copied from the
            record) and the preprocessed ``'person'``, ``'cloth'`` and
            ``'mask'`` images (presumably ``(C, H, W)`` tensors — confirm
            against the diffusers version in use).
        """
        data = self.data[idx]
        return {
            'index': idx,
            'person_name': data['person_name'],
            # vae_processor is built without do_convert_rgb, so convert here:
            # grayscale or RGBA sources would otherwise not yield 3 channels.
            'person': self._preprocess(self.vae_processor, data['person'], mode='RGB'),
            'cloth': self._preprocess(self.vae_processor, data['cloth'], mode='RGB'),
            # The mask processor does its own grayscale conversion.
            'mask': self._preprocess(self.mask_processor, data['mask']),
        }
class VITONHDTestDataset(InferenceDataset):
    """VITON-HD style dataset described by a ``<split>_pairs.txt`` file.

    Layout read by :meth:`load_data`, relative to ``args.data_root_path``::

        <split>_pairs.txt              one "<person_img> <cloth_img>" pair per line
        <split>/image/<person_img>
        <split>/cloth/<cloth_img>
        <split>/agnostic-mask/<person_img with '.jpg' replaced by '_mask.png'>

    ``<split>`` is ``"train"`` when ``args.is_train`` is truthy, otherwise
    ``"samples"``.
    """

    def load_data(self):
        """Parse the pairs file into sample records.

        Reads ``args.data_root_path``, ``args.is_train``, ``args.output_dir``
        and ``args.eval_pair``; ``args`` itself is not modified.

        Returns:
            list of dicts with keys ``'person_name'``, ``'person'``,
            ``'cloth'`` and ``'mask'`` (the last three are file paths).

        Raises:
            FileNotFoundError: if the pairs file does not exist.
            ValueError: if a non-blank line does not hold exactly two names.
        """
        split = "train" if self.args.is_train else "samples"
        pair_txt = os.path.join(self.args.data_root_path, f'{split}_pairs.txt')
        # An explicit raise (unlike ``assert``) still fires under ``python -O``.
        if not os.path.exists(pair_txt):
            raise FileNotFoundError(f"File {pair_txt} does not exist.")

        # Keep the split directory in a local instead of overwriting
        # args.data_root_path: overwriting it made a second call (or a second
        # dataset built from the same args) look under <root>/<split>/<split>.
        split_root = os.path.join(self.args.data_root_path, split)
        output_dir = os.path.join(
            self.args.output_dir, "vitonhd", 'paired' if self.args.eval_pair else 'unpaired'
        )

        data = []
        with open(pair_txt, 'r', encoding='utf-8') as f:
            for line in f:
                # split() with no separator tolerates tabs / repeated spaces.
                fields = line.split()
                if not fields:
                    continue  # blank line, e.g. a trailing newline
                person_img, cloth_img = fields
                # Skip pairs whose result already exists in output_dir
                # (presumably so an interrupted run can resume).
                if os.path.exists(os.path.join(output_dir, person_img)):
                    continue
                if self.args.eval_pair:
                    # Paired mode: use the cloth file named like the person
                    # image (presumably the garment being worn — confirm).
                    cloth_img = person_img
                data.append({
                    'person_name': person_img,
                    'person': os.path.join(split_root, 'image', person_img),
                    'cloth': os.path.join(split_root, 'cloth', cloth_img),
                    'mask': os.path.join(split_root, 'agnostic-mask', person_img.replace('.jpg', '_mask.png')),
                })
        return data