Spaces:
Build error
Build error
| # Copyright (c) Facebook, Inc. and its affiliates. | |
| # | |
| # This source code is licensed under the MIT license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| import contextlib | |
| import json | |
| import os | |
| import tempfile | |
| import unittest | |
| from io import StringIO | |
| import torch | |
| from . import test_binaries | |
class TestReproducibility(unittest.TestCase):
    """Check that resuming from a checkpoint reproduces an uninterrupted run.

    Each test trains a small translation model twice in the same temporary
    directory: once straight through, and once more after promoting an
    intermediate checkpoint to ``checkpoint_last.pt``.  The final train and
    validation statistics logged by both runs must agree within ``delta``.
    """

    def _get_last_log_stats_containing_string(self, log_records, search_string):
        """Return the parsed JSON of the newest record mentioning *search_string*.

        Args:
            log_records: sequence of ``logging.LogRecord`` objects, oldest first
                (e.g. the ``records`` attribute of an ``assertLogs`` context).
            search_string: substring that must occur in the record's message.

        Returns:
            The result of ``json.loads`` on the matching record's message.

        Fails the test if no captured record contains *search_string*, instead
        of returning ``None`` and letting the caller crash on a subscript.
        """
        for log_record in reversed(log_records):
            # Non-string messages (e.g. objects logged directly) cannot be
            # searched or JSON-decoded, so skip them.
            if isinstance(log_record.msg, str) and search_string in log_record.msg:
                return json.loads(log_record.msg)
        self.fail(
            "no captured log record contains {!r}".format(search_string)
        )

    def _test_reproducibility(
        self,
        name,
        extra_flags=None,
        delta=0.0001,
        resume_checkpoint="checkpoint1.pt",
        max_epoch=3,
    ):
        """Train, resume from *resume_checkpoint*, and compare the final stats.

        Args:
            name: label used as the suffix of the temporary data directory.
            extra_flags: additional command-line flags appended to the
                training arguments; ``None`` means no extra flags.
            delta: maximum absolute difference tolerated between the stats
                of the uninterrupted run and the resumed run.
            resume_checkpoint: file name (inside the data directory) of the
                checkpoint that is renamed to ``checkpoint_last.pt`` before
                the second run.
            max_epoch: value passed to ``--max-epoch`` for both runs.
        """
        if extra_flags is None:
            extra_flags = []

        train_flags = [
            "--dropout",
            "0.0",
            "--log-format",
            "json",
            "--log-interval",
            "1",
            "--max-epoch",
            str(max_epoch),
        ] + extra_flags

        def train_and_collect_stats(data_dir):
            """Run one training job and return its last (train, valid) stats."""
            with self.assertLogs() as logs:
                test_binaries.train_translation_model(
                    data_dir,
                    "fconv_iwslt_de_en",
                    # Pass a copy so both runs receive identical arguments
                    # even if the callee mutates the list.
                    list(train_flags),
                )
            return (
                self._get_last_log_stats_containing_string(
                    logs.records, "train_loss"
                ),
                self._get_last_log_stats_containing_string(
                    logs.records, "valid_loss"
                ),
            )

        # The first positional argument of TemporaryDirectory is the suffix.
        with tempfile.TemporaryDirectory(suffix=name) as data_dir:
            # Capture the log output of data preparation; the captured
            # records are not inspected.
            with self.assertLogs():
                test_binaries.create_dummy_data(data_dir)
                test_binaries.preprocess_translation_data(data_dir)

            # Reference run: train straight through up to max_epoch.
            train_log, valid_log = train_and_collect_stats(data_dir)

            # Resumed run: promote the intermediate checkpoint to
            # checkpoint_last.pt so the second job resumes from it
            # (NOTE(review): assumes the trainer restores checkpoint_last.pt
            # from the data directory -- confirm in test_binaries).
            os.rename(
                os.path.join(data_dir, resume_checkpoint),
                os.path.join(data_dir, "checkpoint_last.pt"),
            )
            train_res_log, valid_res_log = train_and_collect_stats(data_dir)

            for k in ["train_loss", "train_ppl", "train_num_updates", "train_gnorm"]:
                self.assertAlmostEqual(
                    float(train_log[k]),
                    float(train_res_log[k]),
                    delta=delta,
                    msg=k,
                )
            for k in [
                "valid_loss",
                "valid_ppl",
                "valid_num_updates",
                "valid_best_loss",
            ]:
                self.assertAlmostEqual(
                    float(valid_log[k]),
                    float(valid_res_log[k]),
                    delta=delta,
                    msg=k,
                )

    # Default settings.
    def test_reproducibility(self):
        self._test_reproducibility("test_reproducibility")

    # Runs with --fp16 and a looser tolerance.
    def test_reproducibility_fp16(self):
        self._test_reproducibility(
            "test_reproducibility_fp16",
            [
                "--fp16",
                "--fp16-init-scale",
                "4096",
            ],
            delta=0.011,
        )

    # Runs with --memory-efficient-fp16 at the default tolerance.
    def test_reproducibility_memory_efficient_fp16(self):
        self._test_reproducibility(
            "test_reproducibility_memory_efficient_fp16",
            [
                "--memory-efficient-fp16",
                "--fp16-init-scale",
                "4096",
            ],
        )

    # Runs with --amp and a looser tolerance.
    def test_reproducibility_amp(self):
        self._test_reproducibility(
            "test_reproducibility_amp",
            [
                "--amp",
                "--fp16-init-scale",
                "4096",
            ],
            delta=0.011,
        )

    # Saves a checkpoint every 3 updates and resumes from checkpoint_1_3.pt
    # within a single epoch.
    def test_mid_epoch_reproducibility(self):
        self._test_reproducibility(
            "test_mid_epoch_reproducibility",
            ["--save-interval-updates", "3"],
            resume_checkpoint="checkpoint_1_3.pt",
            max_epoch=1,
        )
# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()