from typing import Optional

import torch
import lightning.pytorch as pl
from torch.utils.data import DataLoader

from data.audiotext_dataset import AudioTextDataset

class DataModule(pl.LightningDataModule):
    def __init__(
        self,
        train_dataset: object,
        batch_size: int,
        num_workers: int,
    ) -> None:
| r"""Data module. To get one batch of data: | |
| code-block:: python | |
| data_module.setup() | |
| for batch_data_dict in data_module.train_dataloader(): | |
| print(batch_data_dict.keys()) | |
| break | |
| Args: | |
| train_sampler: Sampler object | |
| train_dataset: Dataset object | |
| num_workers: int | |
| distributed: bool | |
| """ | |
        super().__init__()
        self._train_dataset = train_dataset
        self.num_workers = num_workers
        self.batch_size = batch_size
        self.collate_fn = collate_fn

    def prepare_data(self):
        # Download, split, etc. Only called on one GPU/TPU in distributed mode.
        pass

    def setup(self, stage: Optional[str] = None) -> None:
        r"""Called on every device."""
        # Make assignments here (val/train/test split);
        # called on every process in DDP.
        # SegmentSampler is used for selecting segments for training.
        # On multiple devices, each SegmentSampler samples a part of the
        # mini-batch data.
        self.train_dataset = self._train_dataset

    def train_dataloader(self) -> DataLoader:
        r"""Get train loader."""
        train_loader = DataLoader(
            dataset=self.train_dataset,
            batch_size=self.batch_size,
            collate_fn=self.collate_fn,
            num_workers=self.num_workers,
            pin_memory=True,
            persistent_workers=False,
            shuffle=True,
        )

        return train_loader

    def val_dataloader(self):
        # val_split = Dataset(...)
        # return DataLoader(val_split)
        pass

    def test_dataloader(self):
        # test_split = Dataset(...)
        # return DataLoader(test_split)
        pass

    def teardown(self, stage: Optional[str] = None) -> None:
        # Clean up after fit or test; called on every process in DDP.
        pass

def collate_fn(list_data_dict):
    r"""Collate mini-batch data to inputs and targets for training.

    Args:
        list_data_dict: e.g., [
            {
                'text': 'a sound of dog',
                'waveform': (1, samples),
                'modality': 'audio_text'
            },
            ...
        ]

    Returns:
        data_dict: e.g.,
            'audio_text': {
                'text': ['a sound of dog', ...],
                'waveform': (batch_size, 1, samples)
            }
    """
    at_list_data_dict = [
        data_dict for data_dict in list_data_dict
        if data_dict['modality'] == 'audio_text'
    ]

    at_data_dict = {}

    if len(at_list_data_dict) > 0:
        for key in at_list_data_dict[0].keys():
            # Gather this key's values across the mini-batch.
            at_data_dict[key] = [d[key] for d in at_list_data_dict]
            if key == 'waveform':
                # (batch_size, 1, samples)
                at_data_dict[key] = torch.stack(at_data_dict[key])

    data_dict = {
        'audio_text': at_data_dict
    }

    return data_dict
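
# ---------------------------------------------------------------------------
# Minimal smoke test: an illustrative sketch, not part of the training code.
# The toy texts, waveforms, and batch size below are made up; the only
# assumption is that each item follows the dict layout documented in
# collate_fn above.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    toy_data = [
        {
            'text': 'a sound of dog',
            'waveform': torch.zeros(1, 16000),
            'modality': 'audio_text',
        },
        {
            'text': 'a sound of rain',
            'waveform': torch.zeros(1, 16000),
            'modality': 'audio_text',
        },
    ]

    # Collate directly.
    batch = collate_fn(toy_data)
    print(batch['audio_text']['text'])            # ['a sound of dog', 'a sound of rain']
    print(batch['audio_text']['waveform'].shape)  # torch.Size([2, 1, 16000])

    # Or go through the DataModule. A plain list works as a map-style
    # dataset here, since DataLoader only needs __len__ and __getitem__.
    data_module = DataModule(train_dataset=toy_data, batch_size=2, num_workers=0)
    data_module.setup()

    for batch_data_dict in data_module.train_dataloader():
        print(batch_data_dict['audio_text']['waveform'].shape)  # torch.Size([2, 1, 16000])
        break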