update from https://github.com/ArneBinder/argumentation-structure-identification/pull/529
d868d2e
verified
| import logging | |
| from typing import Optional, Union | |
| from pytorch_ie import PieDataModule | |
| from pytorch_ie.core.taskmodule import IterableTaskEncodingDataset, TaskEncodingDataset | |
| from torch.utils.data import DataLoader, Sampler | |
| from .components.sampler import ImbalancedDatasetSampler | |
# Module-level logger; used by train_dataloader to warn when shuffling is disabled.
logger = logging.getLogger(__name__)
class PieDataModuleWithSampler(PieDataModule):
    """A ``PieDataModule`` whose train dataloader can use a custom sampler.

    All arguments other than the ones below are passed through to
    ``PieDataModule`` via ``**kwargs``.

    Args:
        train_sampler: Name of the sampler to use for the train split, or
            ``None`` to use no custom sampler. Currently only
            ``"imbalanced_dataset"`` is supported. The name is validated
            lazily, when the train sampler is created.
        dont_shuffle_train: If ``True``, the train dataloader is never shuffled.
    """

    def __init__(
        self,
        train_sampler: Optional[str] = None,
        dont_shuffle_train: bool = False,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        self.train_sampler_name = train_sampler
        self.dont_shuffle_train = dont_shuffle_train

    def get_train_sampler(
        self,
        dataset: Union[TaskEncodingDataset, IterableTaskEncodingDataset],
    ) -> Optional[Sampler]:
        """Create the sampler for the train dataloader.

        Args:
            dataset: The encoded train dataset the sampler draws from.

        Returns:
            ``None`` if no sampler is configured, otherwise the sampler instance.

        Raises:
            ValueError: If the configured sampler name is unknown, or if a
                sampler is requested for a streamed (iterable) dataset.
        """
        if self.train_sampler_name is None:
            return None
        if self.train_sampler_name == "imbalanced_dataset":
            # Fail early for streamed datasets: the label callback below iterates
            # the whole dataset, and torch's DataLoader does not accept a sampler
            # together with an IterableDataset anyway.
            if isinstance(dataset, IterableTaskEncodingDataset):
                raise ValueError(
                    f"sampler '{self.train_sampler_name}' can not be used with a "
                    f"streamed (iterable) dataset"
                )
            # For now, this works only with targets that have a single entry:
            # the label of each encoding is taken to be its first target.
            return ImbalancedDatasetSampler(
                dataset, callback_get_label=lambda ds: [x.targets[0] for x in ds]
            )
        raise ValueError(f"unknown sampler name: {self.train_sampler_name}")

    def train_dataloader(self) -> DataLoader:
        """Create the dataloader for the train split.

        Shuffling is disabled in three cases: ``dont_shuffle_train`` is set,
        the dataset is streamed (an ``IterableTaskEncodingDataset``), or a
        sampler is used. A warning is logged whenever shuffling is disabled.
        """
        ds = self.data_split(self.train_split)
        sampler = self.get_train_sampler(dataset=ds)
        # Don't shuffle if explicitly disabled via dont_shuffle_train, for
        # streamed datasets, or if a sampler is used (the sampler then defines
        # the order; DataLoader rejects shuffle=True together with a sampler).
        shuffle = not (
            self.dont_shuffle_train
            or isinstance(ds, IterableTaskEncodingDataset)
            or sampler is not None
        )
        if not shuffle:
            logger.warning("not shuffling train dataloader")
        return DataLoader(
            dataset=ds,
            sampler=sampler,
            collate_fn=self.taskmodule.collate,
            shuffle=shuffle,
            **self.dataloader_kwargs,
        )