Spaces:
Runtime error
Runtime error
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# Copyright (c) Megvii, Inc. and its affiliates.
import copy
import unittest

import torch
from torch import nn

from yolox.exp import get_exp
from yolox.utils import adjust_status, freeze_module
class TestModelUtils(unittest.TestCase):
    """Unit tests for ``adjust_status`` and ``freeze_module`` from ``yolox.utils``."""

    def setUp(self):
        # A fresh YOLOX-S model is built for every test because the tests
        # below mutate its train/eval flags and parameters' ``requires_grad``.
        self.model: nn.Module = get_exp(exp_name="yolox-s").get_model()

    @staticmethod
    def _snapshot_state(module: nn.Module) -> dict:
        """Return an independent copy of ``module.state_dict()``.

        ``state_dict()`` hands back tensors that share storage with the
        module's parameters and buffers. BatchNorm updates its running
        statistics in place during a training-mode forward pass, so an
        un-copied "before" state would silently track those updates and a
        before/after comparison would always report "unchanged".
        """
        return copy.deepcopy(module.state_dict())

    @staticmethod
    def _states_match(before: dict, after: dict) -> bool:
        """Return True if every entry of ``before`` is allclose to ``after[key]``."""
        return all(torch.allclose(value, after[key]) for key, value in before.items())

    def test_model_state_adjust_status(self):
        data = torch.ones(1, 10, 10, 10)
        # Use BatchNorm because a forward pass in training mode updates its
        # running statistics, while eval mode leaves them untouched.
        model = nn.BatchNorm2d(10)
        prev_state = self._snapshot_state(model)

        # Pairs of (training flag passed to adjust_status,
        # expected "state unchanged w.r.t. the snapshot taken above").
        modes = [False, True]
        results = [True, False]

        # test under eval/train mode
        for mode, result in zip(modes, results):
            with adjust_status(model, training=mode):
                model(data)
            model_state = model.state_dict()
            self.assertEqual(len(model_state), len(prev_state))
            self.assertEqual(result, self._states_match(prev_state, model_state))

        # test recursive (nested) context case: eval inside eval must not
        # update the running statistics either
        prev_state = self._snapshot_state(model)
        with adjust_status(model, training=False):
            with adjust_status(model, training=False):
                model(data)
        model_state = model.state_dict()
        self.assertEqual(len(model_state), len(prev_state))
        self.assertTrue(self._states_match(prev_state, model_state))

    def test_model_effect_adjust_status(self):
        # test context effect: every submodule is in eval mode inside the context
        self.model.train()
        with adjust_status(self.model, training=False):
            for module in self.model.modules():
                self.assertFalse(module.training)
        # all training after exit
        for module in self.model.modules():
            self.assertTrue(module.training)

        # only backbone set to eval; the context must restore that mixed state
        self.model.backbone.eval()
        with adjust_status(self.model, training=False):
            for module in self.model.modules():
                self.assertFalse(module.training)

        for name, module in self.model.named_modules():
            if "backbone" in name:
                self.assertFalse(module.training)
            else:
                self.assertTrue(module.training)

    def test_freeze_module(self):
        model = nn.Sequential(
            nn.Conv2d(3, 10, 1),
            nn.BatchNorm2d(10),
            nn.ReLU(),
        )
        data = torch.rand(1, 3, 10, 10)
        model.train()
        self.assertIsInstance(model[1], nn.BatchNorm2d)
        before_states = self._snapshot_state(model[1])
        freeze_module(model[1])
        model(data)
        after_states = model[1].state_dict()
        # A frozen BatchNorm must keep its state unchanged across a forward
        # pass, even though model.train() was called on the Sequential.
        self.assertTrue(self._states_match(before_states, after_states))

        # yolox test
        self.model.train()
        for module in self.model.modules():
            self.assertTrue(module.training)

        freeze_module(self.model, "backbone")
        for module in self.model.backbone.modules():
            self.assertFalse(module.training)
        for p in self.model.backbone.parameters():
            self.assertFalse(p.requires_grad)

        for module in self.model.head.modules():
            self.assertTrue(module.training)
        for p in self.model.head.parameters():
            self.assertTrue(p.requires_grad)
if __name__ == "__main__":
    # Run this module's test cases when the file is executed as a script.
    unittest.main()