# Copyright (c) Meta Platforms, Inc. and affiliates.
import ast
import json
import logging
import math
import os
import random
import sys
import time
from dataclasses import dataclass
from multiprocessing import Value

import braceexpand
import numpy as np
import pandas as pd
import torch
import torchvision.datasets as datasets
import webdataset as wds
from PIL import Image
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler, IterableDataset, get_worker_info
from torch.utils.data.distributed import DistributedSampler
from webdataset.filters import _shuffle
from webdataset.tariterators import base_plus_ext, url_opener, tar_file_expander, valid_sample

try:
    import horovod.torch as hvd
except ImportError:
    hvd = None

from clip import tokenize


class CsvDataset(Dataset):
    def __init__(self, input_filename, transforms, img_key, caption_key, sep="\t"):
        logging.debug(f'Loading csv data from {input_filename}.')
        df = pd.read_csv(input_filename, sep=sep)

        self.images = df[img_key].tolist()
        self.captions = df[caption_key].tolist()
        self.transforms = transforms
        logging.debug('Done loading data.')

    def __len__(self):
        return len(self.captions)

    def __getitem__(self, idx):
        images = self.transforms(Image.open(str(self.images[idx])))
        texts = tokenize([str(self.captions[idx])])[0]
        return images, texts
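

# Example usage (a minimal sketch; the file path and column names are illustrative,
# the real keys come from --csv-img-key / --csv-caption-key, see get_csv_dataset below):
#   dataset = CsvDataset("train.tsv", preprocess_train, img_key="filepath",
#                        caption_key="title", sep="\t")
#   image_tensor, text_tokens = dataset[0]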


class SharedEpoch:
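    """Share the current epoch with dataloader worker processes via a multiprocessing Value."""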
    def __init__(self, epoch: int = 0):
        self.shared_epoch = Value('i', epoch)

    def set_value(self, epoch):
        self.shared_epoch.value = epoch

    def get_value(self):
        return self.shared_epoch.value


@dataclass
class DataInfo:
    dataloader: DataLoader
    sampler: DistributedSampler = None
    shared_epoch: SharedEpoch = None

    def set_epoch(self, epoch):
        if self.shared_epoch is not None:
            self.shared_epoch.set_value(epoch)
        if self.sampler is not None and isinstance(self.sampler, DistributedSampler):
            self.sampler.set_epoch(epoch)


def preprocess_txt(text):
    return tokenize([str(text)])[0]


def get_dataset_size(shards):
    shards_list = list(braceexpand.braceexpand(shards))
    dir_path = os.path.dirname(shards)
    sizes_filename = os.path.join(dir_path, 'sizes.json')
    len_filename = os.path.join(dir_path, '__len__')
    if os.path.exists(sizes_filename):
        sizes = json.load(open(sizes_filename, 'r'))
        total_size = sum([int(sizes[os.path.basename(shard)]) for shard in shards_list])
    elif os.path.exists(len_filename):
        # FIXME this used to be eval(open(...)) but that seemed rather unsafe
        total_size = ast.literal_eval(open(len_filename, 'r').read())
    else:
        total_size = None  # num samples undefined
        # some common dataset sizes (at time of authors last download)
        # CC3M (train): 2905954
        # CC12M: 10968539
        # LAION-400M: 407332084
        # LAION-2B (english): 2170337258
    num_shards = len(shards_list)
    return total_size, num_shards
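

# Sketch of the optional metadata files read above (contents are illustrative):
#   sizes.json  -> {"00000.tar": 10000, "00001.tar": 9876, ...}  maps shard basename to sample count
#   __len__     -> 407332084                                     a single integer, the total sample count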


def get_imagenet(args, preprocess_fns, split):
    assert split in ["train", "val", "v2"]
    is_train = split == "train"
    preprocess_train, preprocess_val = preprocess_fns

    if split == "v2":
        from imagenetv2_pytorch import ImageNetV2Dataset
        dataset = ImageNetV2Dataset(location=args.imagenet_v2, transform=preprocess_val)
    else:
        if is_train:
            data_path = args.imagenet_train
            preprocess_fn = preprocess_train
        else:
            data_path = args.imagenet_val
            preprocess_fn = preprocess_val
        assert data_path
        dataset = datasets.ImageFolder(data_path, transform=preprocess_fn)
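
    # For the train split, build a sampler that randomly keeps k images per class:
    # a shuffled 0/1 mask is built per class and its indices feed a SubsetRandomSampler.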
    if is_train:
        idxs = np.zeros(len(dataset.targets))
        target_array = np.array(dataset.targets)
        k = 50
        for c in range(1000):
            m = target_array == c
            n = len(idxs[m])
            arr = np.zeros(n)
            arr[:k] = 1
            np.random.shuffle(arr)
            idxs[m] = arr

        idxs = idxs.astype('int')
        sampler = SubsetRandomSampler(np.where(idxs)[0])
    else:
        sampler = None

    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.batch_size,
        num_workers=args.workers,
        sampler=sampler,
    )

    return DataInfo(dataloader=dataloader, sampler=sampler)


def count_samples(dataloader):
    os.environ["WDS_EPOCH"] = "0"
    n_elements, n_batches = 0, 0
    for images, texts in dataloader:
        n_batches += 1
        n_elements += len(images)
        assert len(images) == len(texts)
    return n_elements, n_batches


def filter_no_caption(sample):
    return 'txt' in sample


def log_and_continue(exn):
    """Call in an exception handler to ignore any exception, issue a warning, and continue."""
    logging.warning(f'Handling webdataset error ({repr(exn)}). Ignoring.')
    return True


def group_by_keys_nothrow(data, keys=base_plus_ext, lcase=True, suffixes=None, handler=None):
    """Return function over iterator that groups key, value pairs into samples.

    :param keys: function that splits the key into key and extension (base_plus_ext)
    :param lcase: convert suffixes to lower case (Default value = True)
    """
    current_sample = None
    for filesample in data:
        assert isinstance(filesample, dict)
        fname, value = filesample["fname"], filesample["data"]
        prefix, suffix = keys(fname)
        if prefix is None:
            continue
        if lcase:
            suffix = suffix.lower()
        # FIXME webdataset version throws if suffix in current_sample, but we have a potential for
        # this happening in the current LAION400m dataset if a tar ends with same prefix as the next
        # begins, rare, but can happen since prefixes aren't unique across tar files in that dataset
        if current_sample is None or prefix != current_sample["__key__"] or suffix in current_sample:
            if valid_sample(current_sample):
                yield current_sample
            current_sample = dict(__key__=prefix, __url__=filesample["__url__"])
        if suffixes is None or suffix in suffixes:
            current_sample[suffix] = value
    if valid_sample(current_sample):
        yield current_sample


def tarfile_to_samples_nothrow(src, handler=log_and_continue):
    # NOTE this is a re-impl of the webdataset impl with group_by_keys that doesn't throw
    streams = url_opener(src, handler=handler)
    files = tar_file_expander(streams, handler=handler)
    samples = group_by_keys_nothrow(files, handler=handler)
    return samples


def pytorch_worker_seed():
    """get dataloader worker seed from pytorch"""
    worker_info = get_worker_info()
    if worker_info is not None:
        # favour the seed already created for pytorch dataloader workers if it exists
        return worker_info.seed
    # fallback to wds rank based seed
    return wds.utils.pytorch_worker_seed()


_SHARD_SHUFFLE_SIZE = 2000
_SHARD_SHUFFLE_INITIAL = 500
_SAMPLE_SHUFFLE_SIZE = 5000
_SAMPLE_SHUFFLE_INITIAL = 1000
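

# detshuffle2 below is a deterministic, epoch-aware shuffle stage: its RNG is reseeded
# from (seed, epoch) each epoch, so shuffling is reproducible yet differs between epochs.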
class detshuffle2(wds.PipelineStage):
    def __init__(
            self,
            bufsize=1000,
            initial=100,
            seed=0,
            epoch=-1,
    ):
        self.bufsize = bufsize
        self.initial = initial
        self.seed = seed
        self.epoch = epoch

    def run(self, src):
        if isinstance(self.epoch, SharedEpoch):
            epoch = self.epoch.get_value()
        else:
            # NOTE: this epoch tracking is problematic in a multiprocess (dataloader workers or train)
            # situation as different workers may wrap at different times (or not at all).
            self.epoch += 1
            epoch = self.epoch
        rng = random.Random()
        if self.seed < 0:
            seed = pytorch_worker_seed() + epoch
        else:
            seed = self.seed + epoch
        rng.seed(seed)
        return _shuffle(src, self.bufsize, self.initial, rng)


class ResampledShards2(IterableDataset):
    """An iterable dataset yielding a list of urls."""

    def __init__(
        self,
        urls,
        nshards=sys.maxsize,
        worker_seed=None,
        deterministic=False,
        epoch=-1,
    ):
        """Sample shards from the shard list with replacement.

        :param urls: a list of URLs as a Python list or brace notation string
        """
        super().__init__()
        urls = wds.shardlists.expand_urls(urls)
        self.urls = urls
        assert isinstance(self.urls[0], str)
        self.nshards = nshards
        self.rng = random.Random()
        self.worker_seed = pytorch_worker_seed if worker_seed is None else worker_seed
        self.deterministic = deterministic
        self.epoch = epoch

    def __iter__(self):
        """Return an iterator over the shards."""
        if isinstance(self.epoch, SharedEpoch):
            epoch = self.epoch.get_value()
        else:
            # NOTE: this epoch tracking is problematic in a multiprocess (dataloader workers or train)
            # situation as different workers may wrap at different times (or not at all).
            self.epoch += 1
            epoch = self.epoch
        if self.deterministic:
            # reset seed w/ epoch if deterministic, worker seed should be deterministic due to arg.seed
            self.rng.seed(self.worker_seed() + epoch)
        for _ in range(self.nshards):
            yield dict(url=self.rng.choice(self.urls))


def get_wds_dataset(args, preprocess_img, is_train, epoch=0, floor=False):
    input_shards = args.train_data if is_train else args.val_data
    assert input_shards is not None
    resampled = getattr(args, 'dataset_resampled', False) and is_train

    num_samples, num_shards = get_dataset_size(input_shards)
    if not num_samples:
        if is_train:
            num_samples = args.train_num_samples
            if not num_samples:
                raise RuntimeError(
                    'Currently, number of dataset samples must be specified for training dataset. '
                    'Please specify via `--train-num-samples` if no dataset length info present.')
        else:
            num_samples = args.val_num_samples or 0  # eval will just exhaust the iterator if not specified

    shared_epoch = SharedEpoch(epoch=epoch)  # create a shared epoch store to sync epoch to dataloader worker proc
    if resampled:
        pipeline = [ResampledShards2(input_shards, deterministic=True, epoch=shared_epoch)]
    else:
        pipeline = [wds.SimpleShardList(input_shards)]

    # at this point we have an iterator over all the shards
    if is_train:
        if not resampled:
            pipeline.extend([
                detshuffle2(
                    bufsize=_SHARD_SHUFFLE_SIZE,
                    initial=_SHARD_SHUFFLE_INITIAL,
                    seed=args.seed,
                    epoch=shared_epoch,
                ),
                wds.split_by_node,
                wds.split_by_worker,
            ])
        pipeline.extend([
            # at this point, we have an iterator over the shards assigned to each worker at each node
            tarfile_to_samples_nothrow,  # wds.tarfile_to_samples(handler=log_and_continue),
            wds.shuffle(
                bufsize=_SAMPLE_SHUFFLE_SIZE,
                initial=_SAMPLE_SHUFFLE_INITIAL,
            ),
        ])
    else:
        pipeline.extend([
            wds.split_by_worker,
            # at this point, we have an iterator over the shards assigned to each worker
            wds.tarfile_to_samples(handler=log_and_continue),
        ])
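
    # common tail of the pipeline: keep only samples that have a caption, decode images
    # to PIL RGB, apply the image transform and the tokenizer, then batch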
    pipeline.extend([
        wds.select(filter_no_caption),
        wds.decode("pilrgb", handler=log_and_continue),
        wds.rename(image="jpg;png", text="txt"),
        wds.map_dict(image=preprocess_img, text=preprocess_txt),
        wds.to_tuple("image", "text"),
        wds.batched(args.batch_size, partial=not is_train),
    ])

    dataset = wds.DataPipeline(*pipeline)
    if is_train:
        if not resampled:
            assert num_shards >= args.workers * args.world_size, 'number of shards must be >= total workers'
        # roll over and repeat a few samples to get same number of full batches on each node
        round_fn = math.floor if floor else math.ceil
        global_batch_size = args.batch_size * args.world_size
        num_batches = round_fn(num_samples / global_batch_size)
        num_workers = max(1, args.workers)
        num_worker_batches = round_fn(num_batches / num_workers)  # per dataloader worker
        num_batches = num_worker_batches * num_workers
        num_samples = num_batches * global_batch_size
        dataset = dataset.with_epoch(num_worker_batches)  # each worker is iterating over this
    else:
        # last batches are partial, eval is done on single (master) node
        num_batches = math.ceil(num_samples / args.batch_size)

    dataloader = wds.WebLoader(
        dataset,
        batch_size=None,
        shuffle=False,
        num_workers=args.workers,
        persistent_workers=True,
    )

    # FIXME not clear which approach is better, with_epoch before vs after dataloader?
    # hoping to resolve via https://github.com/webdataset/webdataset/issues/169
    # if is_train:
    #     # roll over and repeat a few samples to get same number of full batches on each node
    #     global_batch_size = args.batch_size * args.world_size
    #     num_batches = math.ceil(num_samples / global_batch_size)
    #     num_workers = max(1, args.workers)
    #     num_batches = math.ceil(num_batches / num_workers) * num_workers
    #     num_samples = num_batches * global_batch_size
    #     dataloader = dataloader.with_epoch(num_batches)
    # else:
    #     # last batches are partial, eval is done on single (master) node
    #     num_batches = math.ceil(num_samples / args.batch_size)

    # add meta-data to dataloader instance for convenience
    dataloader.num_batches = num_batches
    dataloader.num_samples = num_samples

    return DataInfo(dataloader=dataloader, shared_epoch=shared_epoch)
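

# Example call (a minimal sketch; the argparse fields listed are the ones this function
# reads, the shard path and values are illustrative, and `preprocess_train` is assumed
# to be the image transform paired with the CLIP model):
#   from types import SimpleNamespace
#   args = SimpleNamespace(
#       train_data="/data/laion/{00000..01023}.tar", val_data=None,
#       dataset_resampled=False, train_num_samples=10_000_000, val_num_samples=None,
#       batch_size=256, workers=4, world_size=1, seed=0)
#   train_info = get_wds_dataset(args, preprocess_train, is_train=True, epoch=0)
#   train_info.set_epoch(1)  # call before each epoch so workers reshuffle deterministically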


def get_csv_dataset(args, preprocess_fn, is_train, epoch=0):
    input_filename = args.train_data if is_train else args.val_data
    assert input_filename
    dataset = CsvDataset(
        input_filename,
        preprocess_fn,
        img_key=args.csv_img_key,
        caption_key=args.csv_caption_key,
        sep=args.csv_separator)
    num_samples = len(dataset)
    sampler = DistributedSampler(dataset) if args.distributed and is_train else None
    shuffle = is_train and sampler is None

    dataloader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=shuffle,
        num_workers=args.workers,
        pin_memory=True,
        sampler=sampler,
        drop_last=is_train,
    )
    dataloader.num_samples = num_samples
    dataloader.num_batches = len(dataloader)

    return DataInfo(dataloader, sampler)


def get_dataset_fn(data_path, dataset_type):
    if dataset_type == "webdataset":
        return get_wds_dataset
    elif dataset_type == "csv":
        return get_csv_dataset
    elif dataset_type == "auto":
        ext = data_path.split('.')[-1]
        if ext in ['csv', 'tsv']:
            return get_csv_dataset
        elif ext in ['tar']:
            return get_wds_dataset
        else:
            raise ValueError(
                f"Tried to figure out dataset type, but failed for extension {ext}.")
    else:
        raise ValueError(f"Unsupported dataset type: {dataset_type}")


def get_data(args, preprocess_fns, epoch=0):
    preprocess_train, preprocess_val = preprocess_fns
    data = {}

    if args.train_data:
        data["train"] = get_dataset_fn(args.train_data, args.dataset_type)(
            args, preprocess_train, is_train=True, epoch=epoch)

    if args.val_data:
        data["val"] = get_dataset_fn(args.val_data, args.dataset_type)(
            args, preprocess_val, is_train=False)

    if args.imagenet_val is not None:
        data["imagenet-val"] = get_imagenet(args, preprocess_fns, "val")

    if args.imagenet_v2 is not None:
        data["imagenet-v2"] = get_imagenet(args, preprocess_fns, "v2")

    return data
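

# Typical call site (sketch; `preprocess_fns` is the (train, val) image transform pair,
# and the surrounding training loop is an assumption, not part of this module):
#   data = get_data(args, (preprocess_train, preprocess_val), epoch=start_epoch)
#   for epoch in range(start_epoch, args.epochs):
#       data["train"].set_epoch(epoch)  # sync epoch to samplers / dataloader worker processes
#       for images, texts in data["train"].dataloader:
#           ...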