# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import tempfile
import unittest
from pathlib import Path
from typing import Any, Dict, Sequence

import fairseq.data.indexed_dataset as indexed_dataset
import fairseq.options
import fairseq.tasks.online_backtranslation as obt
import torch

from tests import utils
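

# Build a minimal fairseq-style sample: a batch of identical sentences plus a
# target that drops the leading (BOS) token, mirroring the usual shift between
# prev_output_tokens and target.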
def mk_sample(tokens: Sequence[int], batch_size: int = 2) -> Dict[str, Any]:
    batch = torch.stack([torch.tensor(tokens, dtype=torch.long)] * batch_size)
    sample = {
        "net_input": {
            "src_tokens": batch,
            "prev_output_tokens": batch,
            "src_lengths": torch.tensor([len(tokens)] * batch_size, dtype=torch.long),
        },
        "target": batch[:, 1:],
    }
    return sample
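

# Write a small random dataset in fairseq's indexed binary format (a .bin file
# plus its .idx index). Every row starts with token 0 and has a random length
# in [3, max_len).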
def mk_dataset(num_samples: int, max_len: int, output: Path):
    output.parent.mkdir(exist_ok=True)
    idx = indexed_dataset.IndexedDatasetBuilder(str(output))
    data = torch.randint(5, 100, (num_samples, max_len))
    lengths = torch.randint(3, max_len, (num_samples,))
    for d, l in zip(data, lengths):
        d[0] = 0
        idx.add_item(d[:l])
    idx.finalize(output.with_suffix(".idx"))
    assert output.exists()
    assert output.with_suffix(".idx").exists()


class OnlineBacktranslationTest(unittest.TestCase):

    tmp_dir = Path(tempfile.mkdtemp(suffix="OnlineBacktranslationTest"))
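
    # Build a tiny OnlineBackTranslationTask and matching model over a shared
    # dummy dictionary. When `data` is given, the dictionary is copied next to
    # the data so the task can find it; otherwise the class tmp dir is used.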
    @classmethod
    def obt_task(
        cls, languages: Sequence[str], data: Path = None, language_mapping: str = None
    ):
        dict_path = cls.tmp_dir / "dict.txt"
        if not dict_path.exists():
            dictionary = utils.dummy_dictionary(100)
            dictionary.save(str(dict_path))

        if data is not None:
            (data / "dict.txt").write_text(dict_path.read_text())
        else:
            data = cls.tmp_dir
        assert len(languages) >= 2

        kwargs = {
            "arch": "transformer",
            # --max-sentences=1 for better predictability of batches
            "max_sentences": 1,
            # Use small, distinguishable dimensions so shape mismatches are easy to spot
            "encoder_layers": 3,
            "encoder_embed_dim": 12,
            "encoder_ffn_embed_dim": 14,
            "encoder_attention_heads": 4,
            "decoder_layers": 3,
            "decoder_embed_dim": 12,
            "decoder_output_dim": 12,
            "decoder_ffn_embed_dim": 14,
            "decoder_attention_heads": 4,
            # Disable dropout so that the tests are deterministic and comparable.
            "dropout": 0,
            "attention_dropout": 0,
            "activation_dropout": 0,
            "encoder_layerdrop": 0,
        }

        args = fairseq.options.get_args(
            data,
            task="online_backtranslation",
            mono_langs=",".join(languages),
            valid_lang_pairs=f"{languages[0]}-{languages[1]}",
            tokens_per_sample=256,
            language_mapping=language_mapping,
            **kwargs,
        )
        task = obt.OnlineBackTranslationTask.setup_task(args)
        # We need to build the model to have the correct dictionary.
        model = task.build_model(task.args)
        return task, model

    def tmp_path(self, test_case: str) -> Path:
        return Path(tempfile.mkdtemp(test_case, dir=self.tmp_dir))
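
    # Each monolingual language gets a dedicated "__lang__" token in the shared
    # dictionary; it is used as the BOS of a sentence to identify its language.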
    def test_lang_tokens(self):
        task, model = self.obt_task(["en", "ro", "zh"])
        assert obt._lang_token("en") in task.dictionary
        assert obt._lang_token("ro") in task.dictionary
        assert obt._lang_token("zh") in task.dictionary

        en_bos = obt._lang_token_index(task.common_dict, "en")
        assert "en" == task.common_dict[en_bos].strip("_")
        zh_bos = obt._lang_token_index(task.common_dict, "zh")
        assert "zh" == task.common_dict[zh_bos].strip("_")

        zh_sample = mk_sample([zh_bos, 16, 14, 12, 10])
        # For a zh sample, we expect the BOS of the translation target (en).
        assert task.get_bos_token_from_sample(zh_sample) == en_bos
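
    # backtranslate_sample mutates the sample in place: the original sentence
    # becomes the target, and the generated translation (prefixed with the
    # target-language BOS) becomes the new source.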
    def test_backtranslate_sample(self):
        task, model = self.obt_task(["en", "ro", "zh"])

        en_bos = obt._lang_token_index(task.common_dict, "en")
        zh_bos = obt._lang_token_index(task.common_dict, "zh")

        sample = mk_sample([zh_bos, 16, 14, 12, 10])
        task.backtranslate_sample(sample, "zh", "en")

        target_zh = list(sample["target"][0])
        assert target_zh == [16, 14, 12, 10]  # original zh sentence
        generated_en = sample["net_input"]["src_tokens"][0]
        assert generated_en[0] == en_bos

    def test_train_dataset(self):
        data = self.tmp_path("test_train_dataset")
        mk_dataset(20, 10, data / "en" / "train.bin")
        mk_dataset(10, 10, data / "zh" / "train.bin")
        task, model = self.obt_task(["en", "zh"], data)
        task.load_dataset("train")

        en_bos = obt._lang_token_index(task.common_dict, "en")
        zh_bos = obt._lang_token_index(task.common_dict, "zh")

        train = task.datasets["train"]
        train.ordered_indices()
        train.prefetch([0, 19])
        sample_0 = train[0]
        sample_19 = train[19]
        self.assertEqual(
            set(sample_0.keys()), {"en-BT", "en-DENOISE", "zh-BT", "zh-DENOISE"}
        )
        for sample in (sample_0, sample_19):
            self.assertEqual(sample["en-BT"]["source"][0], en_bos)
            # The BT target isn't ready to look at yet.
            self.assertEqual(sample["en-DENOISE"]["source"][0], en_bos)
            # TODO: what could we check on the target side?

        for i in range(10):
            # The zh dataset is shorter, so it wraps around the en dataset:
            # samples i and i + 10 should contain the same zh sentence.
            train.prefetch([i, i + 10])
            self.assertEqual(
                list(train[i]["zh-DENOISE"]["source"]),
                list(train[i + 10]["zh-DENOISE"]["source"]),
            )
            self.assertEqual(train[i]["zh-DENOISE"]["source"][0].item(), zh_bos)

        # Samples are sorted by increasing length.
        self.assertLess(
            len(sample_0["en-BT"]["source"]), len(sample_19["en-BT"]["source"])
        )

    def test_valid_dataset(self):
        data = self.tmp_path("test_valid_dataset")
        mk_dataset(10, 21, data / "valid.en-zh.en.bin")
        mk_dataset(10, 21, data / "valid.en-zh.zh.bin")
        task, model = self.obt_task(["en", "zh"], data)
        valid = task.load_dataset("valid")
        en_bos = obt._lang_token_index(task.common_dict, "en")

        assert valid is not None
        valid.prefetch(range(10))
        sample_0 = valid[0]
        sample_9 = valid[9]
        self.assertEqual(sample_0["id"], 0)
        self.assertEqual(sample_9["id"], 9)
        self.assertEqual(sample_0["source"][0], en_bos)
        self.assertEqual(sample_9["source"][0], en_bos)
        # TODO: could we test the target side?

    def assertFnMatch(self, fn, values):
        for x, y in values.items():
            fn_x = fn(x)
            self.assertEqual(fn_x, y, f"Fn has wrong value: fn({x}) = {fn_x} != {y}")
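
    # PiecewiseLinearFn.from_string parses "x0:y0,x1:y1,..." breakpoints,
    # interpolates linearly between consecutive breakpoints, and clamps to the
    # nearest endpoint outside that range; a bare number means a constant fn.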
    def test_piecewise_linear_fn(self):
        self.assertFnMatch(
            obt.PiecewiseLinearFn.from_string("1.0"), {0: 1, 100: 1, 500: 1, 1000: 1}
        )
        self.assertFnMatch(
            obt.PiecewiseLinearFn.from_string("0:1,1000:0"),
            {0: 1, 500: 0.5, 1000: 0, 2000: 0},
        )
        self.assertFnMatch(
            obt.PiecewiseLinearFn.from_string("0:0,1000:1"),
            {0: 0, 500: 0.5, 1000: 1, 2000: 1},
        )
        self.assertFnMatch(
            obt.PiecewiseLinearFn.from_string("0:0,1000:1,2000:0"),
            {0: 0, 500: 0.5, 1000: 1, 1500: 0.5, 2000: 0, 3000: 0},
        )
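

if __name__ == "__main__":
    # Run the tests when this module is executed directly; test runners such
    # as pytest discover the TestCase without this entry point.
    unittest.main()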