Spaces:
Runtime error
Runtime error
| import os | |
| import torch | |
| import pytorch_lightning as pl | |
| from omegaconf import OmegaConf | |
| from functools import partial | |
| from ldm.util import instantiate_from_config | |
| from torch.utils.data import random_split, DataLoader, Dataset, Subset | |
class WrappedDataset(Dataset):
    """Adapter exposing any sized, indexable object as a PyTorch ``Dataset``.

    The wrapped object only has to implement ``__len__`` and ``__getitem__``;
    both calls are forwarded to it unchanged.
    """

    def __init__(self, dataset):
        # Keep a reference to the wrapped object; nothing is copied.
        self.data = dataset

    def __len__(self):
        """Return the number of items reported by the wrapped object."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return whatever the wrapped object yields for ``idx``."""
        return self.data[idx]
class DataModuleFromConfig(pl.LightningDataModule):
    """Lightning data module whose dataset splits are built from configs.

    Each of ``train``, ``validation``, ``test`` and ``predict`` is an optional
    config object handed to ``instantiate_from_config``. The matching
    ``*_dataloader`` hook is attached to the instance only for the splits
    that were actually configured.
    """

    def __init__(self, batch_size, train=None, validation=None, test=None, predict=None,
                 wrap=False, num_workers=None, shuffle_test_loader=False, use_worker_init_fn=False,
                 shuffle_val_dataloader=False):
        """Store the split configs and register the dataloader hooks.

        Args:
            batch_size: Batch size used by every dataloader.
            train, validation, test, predict: Optional dataset configs.
            wrap: If true, ``setup`` wraps each dataset in ``WrappedDataset``.
            num_workers: DataLoader worker count; defaults to
                ``batch_size * 2`` when ``None``.
            shuffle_test_loader: Whether the test dataloader shuffles.
            use_worker_init_fn: Request a worker init function for the test
                and predict dataloaders (see ``_worker_init_fn_for``).
            shuffle_val_dataloader: Whether the validation dataloader shuffles.
        """
        super().__init__()
        self.batch_size = batch_size
        self.dataset_configs = dict()
        self.num_workers = num_workers if num_workers is not None else batch_size * 2
        self.use_worker_init_fn = use_worker_init_fn
        if train is not None:
            self.dataset_configs["train"] = train
            self.train_dataloader = self._train_dataloader
        if validation is not None:
            self.dataset_configs["validation"] = validation
            self.val_dataloader = partial(self._val_dataloader, shuffle=shuffle_val_dataloader)
        if test is not None:
            self.dataset_configs["test"] = test
            self.test_dataloader = partial(self._test_dataloader, shuffle=shuffle_test_loader)
        if predict is not None:
            self.dataset_configs["predict"] = predict
            self.predict_dataloader = self._predict_dataloader
        self.wrap = wrap

    def prepare_data(self):
        """Instantiate every configured dataset once and discard the result."""
        for data_cfg in self.dataset_configs.values():
            instantiate_from_config(data_cfg)

    def setup(self, stage=None):
        """Build ``self.datasets`` (split name -> dataset) from the configs.

        ``stage`` is accepted for hook compatibility and is not used.
        """
        self.datasets = {
            name: instantiate_from_config(cfg)
            for name, cfg in self.dataset_configs.items()
        }
        if self.wrap:
            for name in self.datasets:
                self.datasets[name] = WrappedDataset(self.datasets[name])

    @staticmethod
    def _is_iterable(dataset):
        """Return True if ``dataset`` is an iterable-style PyTorch dataset."""
        return isinstance(dataset, torch.utils.data.IterableDataset)

    def _worker_init_fn_for(self, dataset):
        """Return the worker init function to use for ``dataset``, or ``None``.

        A function is only requested for iterable-style datasets or when
        ``use_worker_init_fn`` is set. It is looked up by the name
        ``worker_init_fn`` in this module's globals, so a module that does
        not define one gets ``None`` instead of a ``NameError``.
        """
        if not (self._is_iterable(dataset) or self.use_worker_init_fn):
            return None
        return globals().get("worker_init_fn")

    def _train_dataloader(self):
        """Return the shuffled training dataloader (no worker init function)."""
        return DataLoader(self.datasets["train"], batch_size=self.batch_size,
                          num_workers=self.num_workers, shuffle=True,
                          worker_init_fn=None)

    def _val_dataloader(self, shuffle=False):
        """Return the validation dataloader (no worker init function)."""
        return DataLoader(self.datasets["validation"],
                          batch_size=self.batch_size,
                          num_workers=self.num_workers,
                          worker_init_fn=None,
                          shuffle=shuffle)

    def _test_dataloader(self, shuffle=False):
        """Return the test dataloader.

        The iterable check is made on the test dataset itself, so this works
        when no training split is configured.
        """
        dataset = self.datasets["test"]
        init_fn = self._worker_init_fn_for(dataset)
        # do not shuffle dataloader for iterable dataset
        shuffle = shuffle and not self._is_iterable(dataset)
        return DataLoader(dataset, batch_size=self.batch_size,
                          num_workers=self.num_workers, worker_init_fn=init_fn,
                          shuffle=shuffle)

    def _predict_dataloader(self, shuffle=False):
        """Return the predict dataloader; ``shuffle`` is accepted but unused."""
        dataset = self.datasets["predict"]
        init_fn = self._worker_init_fn_for(dataset)
        return DataLoader(dataset, batch_size=self.batch_size,
                          num_workers=self.num_workers, worker_init_fn=init_fn)
def create_data(config):
    """Instantiate the data module described by ``config.data`` and ready it.

    Returns the object produced by ``instantiate_from_config`` after its
    ``prepare_data()`` and ``setup()`` have been called, in that order.
    """
    datamodule = instantiate_from_config(config.data)
    # NOTE: per https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html
    # these two hooks should not need to be invoked by hand, but in practice
    # they do. Lightning still handles the multiprocessing side itself.
    datamodule.prepare_data()
    datamodule.setup()
    return datamodule