Spaces:
Build error
Build error
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
| import collections | |
| import unittest | |
| import numpy as np | |
| from fairseq.data import ListDataset, ResamplingDataset | |
class TestResamplingDataset(unittest.TestCase):
    """Exercise ``ResamplingDataset`` wrapped around a tiny ``ListDataset``.

    The fixture holds four strings of distinct lengths, one sampling weight
    per string, and a ``size_ratio`` of 2. Each test builds a
    ``ResamplingDataset`` with ``seed=0`` and checks the ordering of the
    returned indices and the empirical sampling frequencies.
    """

    def setUp(self):
        """Create the strings, their weights, and the backing ``ListDataset``."""
        self.strings = ["ab", "c", "def", "ghij"]
        self.weights = [4.0, 2.0, 7.0, 1.5]
        self.size_ratio = 2
        # The second argument carries the length of each string.
        self.dataset = ListDataset(
            self.strings, np.array([len(s) for s in self.strings])
        )

    def _make_resampling_dataset(self, batch_by_size):
        """Return a ``ResamplingDataset`` over the fixture with ``seed=0``."""
        return ResamplingDataset(
            self.dataset,
            self.weights,
            size_ratio=self.size_ratio,
            batch_by_size=batch_by_size,
            seed=0,
        )

    def _test_common(self, resampling_dataset, iters):
        """Run ``iters`` epochs over ``resampling_dataset`` and gather statistics.

        For every epoch this calls ``set_epoch``, walks ``ordered_indices()``
        and asserts that each index yields the same item on repeated access
        and that ``size(i)`` equals the length of that item.

        Args:
            resampling_dataset: dataset under test; must provide ``__len__``,
                ``__getitem__``, ``set_epoch``, ``ordered_indices`` and
                ``size``.
            iters: number of epochs to iterate.

        Returns:
            A dict with two keys: ``"ordered_by_size"`` is ``True`` when no
            epoch ever yielded an item smaller than the one before it, and
            ``"max_distribution_diff"`` is the largest absolute gap between a
            string's observed frequency and its normalised weight.
        """
        self.assertEqual(len(self.dataset), len(self.strings))
        self.assertEqual(len(self.strings), len(self.weights))
        self.assertEqual(
            len(resampling_dataset), self.size_ratio * len(self.strings)
        )

        ordered_by_size = True
        freqs = collections.Counter()

        for epoch_num in range(iters):
            resampling_dataset.set_epoch(epoch_num)
            indices = resampling_dataset.ordered_indices()
            self.assertEqual(len(indices), len(resampling_dataset))

            prev_size = -1
            for i in indices:
                cur_size = resampling_dataset.size(i)
                item = resampling_dataset[i]

                # Make sure indices map to same sequences within an epoch;
                # the second lookup is deliberate.
                self.assertEqual(item, resampling_dataset[i])

                # Make sure length of sequence is correct
                self.assertEqual(cur_size, len(item))

                freqs[item] += 1

                if prev_size > cur_size:
                    ordered_by_size = False
                prev_size = cur_size

        # Every string must have been drawn at least once. This also
        # guarantees a non-zero total before the divisions below.
        self.assertEqual(set(freqs), set(self.strings))

        total_freqs = sum(freqs.values())
        total_weight = sum(self.weights)
        max_distribution_diff = 0.0
        for s, weight in zip(self.strings, self.weights):
            freq = freqs[s] / total_freqs
            expected_freq = weight / total_weight
            max_distribution_diff = max(
                max_distribution_diff, abs(expected_freq - freq)
            )

        return {
            "ordered_by_size": ordered_by_size,
            "max_distribution_diff": max_distribution_diff,
        }

    def test_resampling_dataset_batch_by_size_false(self):
        """With ``batch_by_size=False`` the indices are not sorted by size."""
        resampling_dataset = self._make_resampling_dataset(batch_by_size=False)

        results = self._test_common(resampling_dataset, iters=1000)

        # For batch_by_size = False, the batches should be returned in
        # arbitrary order of size.
        self.assertFalse(results["ordered_by_size"])

        # Allow tolerance in distribution error of 2%.
        self.assertLess(results["max_distribution_diff"], 0.02)

    def test_resampling_dataset_batch_by_size_true(self):
        """With ``batch_by_size=True`` the indices come in increasing size."""
        resampling_dataset = self._make_resampling_dataset(batch_by_size=True)

        results = self._test_common(resampling_dataset, iters=1000)

        # For batch_by_size = True, the batches should be returned in
        # increasing order of size.
        self.assertTrue(results["ordered_by_size"])

        # Allow tolerance in distribution error of 2%.
        self.assertLess(results["max_distribution_diff"], 0.02)
# Run the test suite when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()