# Spaces:
# Build error
# Build error
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import unittest

import torch
from torch import nn

from fairseq.distributed import ModuleProxyWrapper

from .utils import objects_are_equal
class MockDDPWrapper(nn.Module):
    """A simple wrapper with an interface similar to DistributedDataParallel.

    The wrapped network is stored on the ``module`` attribute and ``forward``
    simply delegates to it.
    """

    def __init__(self, module):
        """Store ``module`` as the wrapped sub-module.

        Args:
            module: The ``nn.Module`` to wrap.
        """
        super().__init__()
        self.module = module

    def forward(self, x):
        """Return the wrapped module's output for ``x``."""
        return self.module(x)
class Model(nn.Module):
    """Tiny model used by the tests: one linear layer plus a plain attribute."""

    def __init__(self):
        """Create the 5 -> 10 linear layer and set ``xyz`` to ``"hello"``."""
        super().__init__()
        self.linear = nn.Linear(5, 10)
        # Plain (non-tensor) attribute, read back both directly and via get_xyz().
        self.xyz = "hello"

    def forward(self, x):
        """Apply the linear layer to ``x`` and return the result."""
        return self.linear(x)

    def get_xyz(self):
        """Return the current value of this instance's ``xyz`` attribute."""
        return self.xyz
class TestModuleProxyWrapper(unittest.TestCase):
    """Tests for ModuleProxyWrapper applied on top of a DDP-like wrapper."""

    def _get_module(self):
        """Build ``ModuleProxyWrapper(MockDDPWrapper(Model()))``.

        Returns:
            A ``(wrapped_module, module)`` tuple: the outer proxy and the bare
            ``Model`` instance it ultimately wraps.
        """
        module = Model()
        wrapped_module = MockDDPWrapper(module)
        wrapped_module = ModuleProxyWrapper(wrapped_module)
        return wrapped_module, module

    @staticmethod
    def _assert_tensors_close(actual, expected):
        """Assert that two tensors are numerically close.

        Prefers ``torch.testing.assert_close``; falls back to the deprecated
        ``torch.testing.assert_allclose`` when the newer helper is missing.
        """
        compare = getattr(torch.testing, "assert_close", None)
        if compare is None:
            # Older PyTorch releases only provide assert_allclose.
            compare = torch.testing.assert_allclose
        compare(actual, expected)

    def test_getattr_forwarding(self):
        """Reads through the proxy see the model; writes stay on the proxy."""
        wrapped_module, module = self._get_module()
        self.assertEqual(module.xyz, "hello")
        self.assertEqual(module.get_xyz(), "hello")
        self.assertEqual(wrapped_module.xyz, "hello")

        wrapped_module.xyz = "world"
        self.assertEqual(wrapped_module.xyz, "world")
        # Assigning on the proxy is expected to leave the inner model untouched.
        self.assertEqual(module.get_xyz(), "hello")

    def test_state_dict(self):
        """The proxy's state_dict must equal the bare model's state_dict."""
        wrapped_module, module = self._get_module()
        self.assertTrue(
            objects_are_equal(wrapped_module.state_dict(), module.state_dict())
        )

    def test_load_state_dict(self):
        """Loading the bare model's state_dict through the proxy must work."""
        wrapped_module, module = self._get_module()
        wrapped_module.load_state_dict(module.state_dict())
        x = torch.rand(4, 5)
        self._assert_tensors_close(wrapped_module(x), module(x))

    def test_forward(self):
        """Calling the proxy must give the same output as the bare model."""
        wrapped_module, module = self._get_module()
        x = torch.rand(4, 5)
        self._assert_tensors_close(wrapped_module(x), module(x))
# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()