Spaces:
Running
on
Zero
Running
on
Zero
| import numpy as np | |
| import torch | |
| from accelerate import Accelerator | |
| import torch.utils | |
| from torch.utils.data import BatchSampler, Sampler | |
| import torch.utils.data | |
class CustomRandomSampler(Sampler):
    """Shuffled sampler whose per-batch "feature" and "view size" are constant.

    Every yielded index is a 3-tuple ``(sample_idx, feat_idx, view_size)``.
    Consecutive runs of ``batch_size`` tuples share the same ``feat_idx``
    (e.g. an aspect-ratio slot drawn from a pool of ``pool_size`` entries) and
    the same ``view_size`` (drawn from ``[min_view_size, max_view_size]``,
    both ends included).

    The draw is fully determined by the epoch number, so ``set_epoch`` must be
    called before iterating.

    Args:
        dataset: Any object supporting ``len()``; only its length is used.
        batch_size: Number of consecutive indices sharing one feature/view size.
        pool_size: Number of distinct feature indices to draw from.
        min_view_size: Smallest view size that can be drawn (inclusive).
        max_view_size: Largest view size that can be drawn (inclusive).
        world_size: Accepted for interface compatibility; unused here.
        warmup: Accepted for interface compatibility; unused here.
        drop_last: Stored on the instance; not consulted by this class.
    """

    def __init__(
        self,
        dataset,
        batch_size,
        pool_size,
        min_view_size,
        max_view_size,
        world_size,
        warmup=1,
        drop_last=True,
    ):
        # Batch layout configuration.
        self.batch_size = batch_size
        self.pool_size = pool_size
        self.drop_last = drop_last

        # Inclusive bounds for the per-batch view size.
        self.min_view_size = min_view_size
        self.max_view_size = max_view_size

        # One index is emitted per dataset element.
        num_samples = len(dataset)
        self.len_dataset = num_samples
        self.total_size = num_samples

        # The epoch drives the RNG seed and has to be set explicitly.
        self.epoch = None
        self.epochf = 0.0

    def __len__(self):
        """Return the number of index tuples produced per epoch."""
        return self.total_size

    def set_epoch(self, epoch):
        """Record the epoch number used to seed the next iteration."""
        self.epoch = epoch

    def __iter__(self):
        """Yield ``(sample_idx, feat_idx, view_size)`` tuples for this epoch.

        Raises:
            ValueError: If ``set_epoch`` has not been called yet (raised when
                the first element is requested, since this is a generator).
        """
        if self.epoch is None:
            raise ValueError(
                "Epoch number not set. Please call 'set_epoch(epoch)' before iterating."
            )

        # Same epoch -> same seed -> same sequence.
        rng = np.random.default_rng(seed=self.epoch + 788)
        total = self.total_size
        per_batch = self.batch_size

        # 1) Random permutation of all sample indices.
        order = np.arange(total)
        rng.shuffle(order)

        # Number of batches, counting a trailing partial batch.
        n_batches = (total + per_batch - 1) // per_batch

        def spread(batch_values):
            # Repeat each per-batch value for every slot of its batch, then
            # trim the overhang of the last (possibly partial) batch.
            return np.repeat(batch_values, per_batch)[:total]

        # 2) One feature index per batch. With more than one choice, the first
        #    half of the pool is weighted twice as heavily as the second half.
        if self.pool_size > 1:
            weights = np.ones(self.pool_size)
            weights[: self.pool_size // 2] *= 2
            weights = weights / weights.sum()
            batch_feats = rng.choice(self.pool_size, size=n_batches, p=weights)
        else:
            batch_feats = rng.integers(self.pool_size, size=n_batches)
        feat_column = spread(batch_feats)

        # 3) One view size per batch, upper bound inclusive.
        batch_views = rng.integers(
            self.min_view_size, self.max_view_size + 1, size=n_batches
        )
        view_column = spread(batch_views)

        # Assemble rows of (sample_idx, feat_idx, view_size) and emit as tuples.
        rows = np.column_stack((order, feat_column, view_column))
        for row in rows:
            yield tuple(row)
class BatchedRandomSampler(BatchSampler):
    """Group the index tuples of a ``CustomRandomSampler`` into batches.

    Only the attributes read by the inherited ``BatchSampler`` iteration
    logic are stored; the parent constructor is intentionally not invoked.

    Args:
        sampler: The underlying per-sample sampler to draw indices from.
        batch_size: Number of indices per emitted batch.
        drop_last: Whether a trailing incomplete batch is discarded.
    """

    def __init__(self, sampler: CustomRandomSampler, batch_size, drop_last=True):
        self.drop_last = drop_last
        self.batch_size = batch_size
        # The wrapped per-sample sampler that produces the index tuples.
        self.sampler = sampler

    def set_epoch(self, epoch):
        """Forward the epoch number to the wrapped sampler."""
        wrapped = self.sampler
        wrapped.set_epoch(epoch)
def round_by(total, multiple, up=False):
    """Round ``total`` to a multiple of ``multiple``.

    Args:
        total: Value to round.
        multiple: Step the result must be a multiple of.
        up: If true, round up to the next multiple; otherwise round down.

    Returns:
        ``total`` rounded down (default) or up to a multiple of ``multiple``,
        using floor division.
    """
    # Rounding up is rounding down after adding (multiple - 1).
    shifted = total + multiple - 1 if up else total
    return (shifted // multiple) * multiple