Spaces:
Build error
Build error
| # This source code is licensed under the MIT license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| import os | |
| import shutil | |
| import tempfile | |
| import unittest | |
| from typing import Optional | |
class TestFileChunker(unittest.TestCase):
    """Tests for ``find_offsets`` and ``Chunker`` from ``fairseq.file_chunker_utils``.

    A temporary text file made of ``_num_lines`` identical lines is created once
    for the whole class and removed again when the class is torn down.
    """

    # Directory holding the fixture file; created in setUpClass.
    _tmpdir: Optional[str] = None
    # Path of the fixture file inside _tmpdir.
    _tmpfile: Optional[str] = None
    # Content of every line written to the fixture file.
    _line_content = "Hello, World\n"
    # UTF-8 encoded size of one line; computed in setUpClass.
    _num_bytes: Optional[int] = None
    # Number of lines written to the fixture file.
    _num_lines = 200
    # Number of chunks the tests ask find_offsets to split the file into.
    _num_splits = 20

    @classmethod
    def setUpClass(cls) -> None:
        """Create the temporary directory and write the fixture file."""
        cls._num_bytes = len(cls._line_content.encode("utf-8"))
        cls._tmpdir = tempfile.mkdtemp()
        # Write with the same encoding used to compute _num_bytes so the byte
        # arithmetic in the tests matches what is on disk.
        with open(
            os.path.join(cls._tmpdir, "test.txt"), "w", encoding="utf-8"
        ) as f:
            cls._tmpfile = f.name
            for _ in range(cls._num_lines):
                f.write(cls._line_content)

    @classmethod
    def tearDownClass(cls) -> None:
        """Remove the temporary working directory, if one was created."""
        tmpdir = cls._tmpdir
        if tmpdir is not None:
            shutil.rmtree(tmpdir)

    def test_find_offsets(self):
        """Check the number and the values of the offsets from find_offsets."""
        from fairseq.file_chunker_utils import find_offsets

        offsets = find_offsets(self._tmpfile, self._num_splits)
        self.assertEqual(len(offsets), self._num_splits + 1)

        zero, *real_offsets, last = offsets
        self.assertEqual(zero, 0)

        # The file size divides evenly by _num_splits, so integer arithmetic
        # is exact here.
        chunk_bytes = self._num_bytes * self._num_lines // self._num_splits
        for i, offset in enumerate(real_offsets):
            # Each interior offset is expected one full line past the even
            # split point -- presumably because find_offsets advances to the
            # end of the line it lands in; confirm against
            # fairseq.file_chunker_utils.
            self.assertEqual(offset, self._num_bytes + (i + 1) * chunk_bytes)

        self.assertEqual(last, self._num_bytes * self._num_lines)

    def test_readchunks(self):
        """Check that every chunk yields roughly its share of intact lines."""
        from fairseq.file_chunker_utils import Chunker, find_offsets

        offsets = find_offsets(self._tmpfile, self._num_splits)
        num_lines = self._num_lines / self._num_splits

        for start, end in zip(offsets, offsets[1:]):
            with Chunker(self._tmpfile, start, end) as lines:
                all_lines = list(lines)

            # Because we split on the bytes, we might end up with one
            # more/less line in a chunk.
            self.assertAlmostEqual(len(all_lines), num_lines, delta=1)
            self.assertListEqual(
                all_lines, [self._line_content] * len(all_lines)
            )