# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
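
"""Unit tests for the checkpoint averaging helper in scripts.average_checkpoints."""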

import collections
import os
import shutil
import tempfile
import unittest

import numpy as np
import torch
from scripts.average_checkpoints import average_checkpoints
from torch import nn


class ModelWithSharedParameter(nn.Module):
    def __init__(self):
        super(ModelWithSharedParameter, self).__init__()
        self.embedding = nn.Embedding(1000, 200)
        self.FC1 = nn.Linear(200, 200)
        self.FC2 = nn.Linear(200, 200)

        # tie weight in FC2 to FC1
        self.FC2.weight = nn.Parameter(self.FC1.weight)
        self.FC2.bias = nn.Parameter(self.FC1.bias)

        self.relu = nn.ReLU()

    def forward(self, input):
        # use the self.relu module defined in __init__ (self.ReLU does not exist)
        return self.FC2(self.relu(self.FC1(input))) + self.FC1(input)
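

# Because FC2 reuses FC1's weight and bias, a saved state_dict contains FC2
# entries that mirror FC1's values; the shared-parameter test below checks
# that averaging treats these tied entries consistently.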
class TestAverageCheckpoints(unittest.TestCase):
    def test_average_checkpoints(self):
        params_0 = collections.OrderedDict(
            [
                ("a", torch.DoubleTensor([100.0])),
                ("b", torch.FloatTensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])),
                ("c", torch.IntTensor([7, 8, 9])),
            ]
        )
        params_1 = collections.OrderedDict(
            [
                ("a", torch.DoubleTensor([1.0])),
                ("b", torch.FloatTensor([[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]])),
                ("c", torch.IntTensor([2, 2, 2])),
            ]
        )
        params_avg = collections.OrderedDict(
            [
                ("a", torch.DoubleTensor([50.5])),
                ("b", torch.FloatTensor([[1.0, 1.5, 2.0], [2.5, 3.0, 3.5]])),
                # We expect truncation for integer division
                ("c", torch.IntTensor([4, 5, 5])),
            ]
        )
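
        # Expected values above: "a" -> (100.0 + 1.0) / 2 = 50.5, "b" -> the
        # element-wise float mean, "c" -> truncated integer division, e.g.
        # (7 + 2) // 2 = 4, (8 + 2) // 2 = 5, (9 + 2) // 2 = 5.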

        fd_0, path_0 = tempfile.mkstemp()
        fd_1, path_1 = tempfile.mkstemp()
        torch.save(collections.OrderedDict([("model", params_0)]), path_0)
        torch.save(collections.OrderedDict([("model", params_1)]), path_1)

        output = average_checkpoints([path_0, path_1])["model"]

        os.close(fd_0)
        os.remove(path_0)
        os.close(fd_1)
        os.remove(path_1)

        for (k_expected, v_expected), (k_out, v_out) in zip(
            params_avg.items(), output.items()
        ):
            self.assertEqual(
                k_expected,
                k_out,
                "Key mismatch - expected {} but found {}. "
                "(Expected list of keys: {} vs actual list of keys: {})".format(
                    k_expected, k_out, params_avg.keys(), output.keys()
                ),
            )
            np.testing.assert_allclose(
                v_expected.numpy(),
                v_out.numpy(),
                err_msg="Tensor value mismatch for key {}".format(k_expected),
            )

    def test_average_checkpoints_with_shared_parameters(self):
        def _construct_model_with_shared_parameters(path, value):
            m = ModelWithSharedParameter()
            nn.init.constant_(m.FC1.weight, value)
            torch.save({"model": m.state_dict()}, path)
            return m

        tmpdir = tempfile.mkdtemp()
        paths = []

        path = os.path.join(tmpdir, "m1.pt")
        m1 = _construct_model_with_shared_parameters(path, 1.0)
        paths.append(path)

        path = os.path.join(tmpdir, "m2.pt")
        m2 = _construct_model_with_shared_parameters(path, 2.0)
        paths.append(path)

        path = os.path.join(tmpdir, "m3.pt")
        m3 = _construct_model_with_shared_parameters(path, 3.0)
        paths.append(path)
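
        # FC1.weight (and the tied FC2.weight) was set to 1.0, 2.0 and 3.0 in
        # the three checkpoints, so the averaged model should hold their mean.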

        new_model = average_checkpoints(paths)
        self.assertTrue(
            torch.equal(
                new_model["model"]["embedding.weight"],
                (m1.embedding.weight + m2.embedding.weight + m3.embedding.weight)
                / 3.0,
            )
        )

        self.assertTrue(
            torch.equal(
                new_model["model"]["FC1.weight"],
                (m1.FC1.weight + m2.FC1.weight + m3.FC1.weight) / 3.0,
            )
        )

        self.assertTrue(
            torch.equal(
                new_model["model"]["FC2.weight"],
                (m1.FC2.weight + m2.FC2.weight + m3.FC2.weight) / 3.0,
            )
        )

        shutil.rmtree(tmpdir)


if __name__ == "__main__":
    unittest.main()