Spaces:
Build error
Build error
| import torch | |
| from torch.utils.data import DataLoader | |
| from torchvision import transforms | |
| from torchvision.transforms.functional import InterpolationMode | |
| from data.coco_karpathy_dataset import coco_karpathy_train, coco_karpathy_caption_eval, coco_karpathy_retrieval_eval | |
| from data.nocaps_dataset import nocaps_eval | |
| from data.flickr30k_dataset import flickr30k_train, flickr30k_retrieval_eval | |
| from data.vqa_dataset import vqa_dataset | |
| from data.nlvr_dataset import nlvr_dataset | |
| from data.pretrain_dataset import pretrain_dataset | |
| from transform.randaugment import RandomAugment | |
def create_dataset(dataset, config, min_scale=0.5):
    """Build the dataset object(s) for the task named by ``dataset``.

    Args:
        dataset: Task name. One of 'pretrain', 'caption_coco', 'nocaps',
            'retrieval_coco', 'retrieval_flickr', 'vqa' or 'nlvr'.
        config: Mapping of settings. ``config['image_size']`` is always read;
            the remaining keys depend on the task ('train_file',
            'laion_path', 'image_root', 'ann_root', 'prompt', 'vqa_root',
            'vg_root', 'train_files').
        min_scale: Lower bound of the ``scale`` range passed to
            ``RandomResizedCrop`` in the training transform.

    Returns:
        * 'pretrain': a single dataset.
        * 'nocaps': ``(val_dataset, test_dataset)``.
        * 'vqa': ``(train_dataset, test_dataset)``.
        * every other task: ``(train_dataset, val_dataset, test_dataset)``.

    Raises:
        ValueError: If ``dataset`` is not one of the supported task names.
    """
    supported = (
        'pretrain', 'caption_coco', 'nocaps', 'retrieval_coco',
        'retrieval_flickr', 'vqa', 'nlvr',
    )
    # Fail fast: without this check an unknown name would fall through every
    # branch and silently return None.
    if dataset not in supported:
        raise ValueError(
            f"Unknown dataset {dataset!r}; expected one of {', '.join(supported)}"
        )

    image_size = config['image_size']
    image_root_key, ann_root_key = 'image_root', 'ann_root'

    # Per-channel (mean, std) used to normalise image tensors.
    normalize = transforms.Normalize(
        (0.48145466, 0.4578275, 0.40821073),
        (0.26862954, 0.26130258, 0.27577711),
    )

    # Training pipeline: random crop + flip + RandomAugment, then tensor/normalise.
    transform_train = transforms.Compose([
        transforms.RandomResizedCrop(
            image_size,
            scale=(min_scale, 1.0),
            interpolation=InterpolationMode.BICUBIC,
        ),
        transforms.RandomHorizontalFlip(),
        RandomAugment(
            2, 5, isPIL=True,
            augs=['Identity', 'AutoContrast', 'Brightness', 'Sharpness', 'Equalize',
                  'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Rotate'],
        ),
        transforms.ToTensor(),
        normalize,
    ])

    # Evaluation pipeline: deterministic resize to a square, then tensor/normalise.
    transform_test = transforms.Compose([
        transforms.Resize(
            (image_size, image_size),
            interpolation=InterpolationMode.BICUBIC,
        ),
        transforms.ToTensor(),
        normalize,
    ])

    if dataset == 'pretrain':
        return pretrain_dataset(config['train_file'], config['laion_path'], transform_train)

    if dataset == 'caption_coco':
        train_dataset = coco_karpathy_train(
            transform_train, config[image_root_key], config[ann_root_key],
            prompt=config['prompt'],
        )
        val_dataset = coco_karpathy_caption_eval(
            transform_test, config[image_root_key], config[ann_root_key], 'val')
        test_dataset = coco_karpathy_caption_eval(
            transform_test, config[image_root_key], config[ann_root_key], 'test')
        return train_dataset, val_dataset, test_dataset

    if dataset == 'nocaps':
        val_dataset = nocaps_eval(
            transform_test, config[image_root_key], config[ann_root_key], 'val')
        test_dataset = nocaps_eval(
            transform_test, config[image_root_key], config[ann_root_key], 'test')
        return val_dataset, test_dataset

    if dataset == 'retrieval_coco':
        train_dataset = coco_karpathy_train(
            transform_train, config[image_root_key], config[ann_root_key])
        val_dataset = coco_karpathy_retrieval_eval(
            transform_test, config[image_root_key], config[ann_root_key], 'val')
        test_dataset = coco_karpathy_retrieval_eval(
            transform_test, config[image_root_key], config[ann_root_key], 'test')
        return train_dataset, val_dataset, test_dataset

    if dataset == 'retrieval_flickr':
        train_dataset = flickr30k_train(
            transform_train, config[image_root_key], config[ann_root_key])
        val_dataset = flickr30k_retrieval_eval(
            transform_test, config[image_root_key], config[ann_root_key], 'val')
        test_dataset = flickr30k_retrieval_eval(
            transform_test, config[image_root_key], config[ann_root_key], 'test')
        return train_dataset, val_dataset, test_dataset

    if dataset == 'vqa':
        train_dataset = vqa_dataset(
            transform_train, config[ann_root_key], config['vqa_root'], config['vg_root'],
            train_files=config['train_files'], split='train',
        )
        test_dataset = vqa_dataset(
            transform_test, config[ann_root_key], config['vqa_root'], config['vg_root'],
            split='test',
        )
        return train_dataset, test_dataset

    # Only 'nlvr' can reach this point because the name was validated above.
    train_dataset = nlvr_dataset(
        transform_train, config[image_root_key], config[ann_root_key], 'train')
    val_dataset = nlvr_dataset(
        transform_test, config[image_root_key], config[ann_root_key], 'val')
    test_dataset = nlvr_dataset(
        transform_test, config[image_root_key], config[ann_root_key], 'test')
    return train_dataset, val_dataset, test_dataset
def create_sampler(datasets, shuffles, num_tasks, global_rank):
    """Create one ``DistributedSampler`` per dataset.

    Args:
        datasets: Iterable of datasets to shard.
        shuffles: Iterable of booleans, paired positionally with ``datasets``,
            giving each sampler's ``shuffle`` flag.
        num_tasks: Passed to every sampler as ``num_replicas``.
        global_rank: Passed to every sampler as ``rank``.

    Returns:
        A list of samplers in the same order as ``datasets``.
    """
    return [
        torch.utils.data.DistributedSampler(
            ds,
            num_replicas=num_tasks,
            rank=global_rank,
            shuffle=do_shuffle,
        )
        for ds, do_shuffle in zip(datasets, shuffles)
    ]
def create_loader(datasets, samplers, batch_size, num_workers, is_trains, collate_fns):
    """Create one ``DataLoader`` per dataset.

    All six arguments are iterables that are consumed in lockstep, so the
    i-th loader is built from the i-th element of each.

    Args:
        datasets: Datasets to wrap.
        samplers: Sampler (or ``None``) for each dataset.
        batch_size: Batch size for each loader.
        num_workers: Worker-process count for each loader.
        is_trains: Truthy entries mark training loaders, which drop the last
            incomplete batch and shuffle when no sampler is given.
        collate_fns: ``collate_fn`` (or ``None``) for each loader.

    Returns:
        A list of ``DataLoader`` objects in the same order as ``datasets``.
    """
    loaders = []
    per_loader_args = zip(datasets, samplers, batch_size, num_workers, is_trains, collate_fns)
    for ds, smp, size, workers, training, collate in per_loader_args:
        training = bool(training)
        # Shuffle only for training loaders with no sampler supplied;
        # evaluation loaders never shuffle and keep their final partial batch.
        loaders.append(
            DataLoader(
                ds,
                batch_size=size,
                num_workers=workers,
                pin_memory=True,
                sampler=smp,
                shuffle=training and smp is None,
                collate_fn=collate,
                drop_last=training,
            )
        )
    return loaders