Spaces:
Running
on
Zero
Running
on
Zero
| # Copyright (c) Facebook, Inc. and its affiliates. | |
| import os | |
| import tempfile | |
| import unittest | |
| from collections import OrderedDict | |
| import torch | |
| from iopath.common.file_io import PathHandler, PathManager | |
| from torch import nn | |
| from detectron2.checkpoint import DetectionCheckpointer | |
| from detectron2.checkpoint.c2_model_loading import ( | |
| _longest_common_prefix_str, | |
| align_and_update_state_dicts, | |
| ) | |
| from detectron2.utils.logger import setup_logger | |
class TestCheckpointer(unittest.TestCase):
    """Tests for loading checkpoints whose keys only partially match the model."""

    def setUp(self):
        setup_logger()

    def create_complex_model(self):
        """Build a small nested model and a state dict that only partially matches it.

        The model's parameters are ``block1.layer1.*``, ``layer2.*`` and
        ``res.layer2.*``. The returned state dict holds ``layer1.*`` (without the
        ``block1.`` prefix), ``layer2.*`` and ``res.layer2.*``. Loading it by exact
        key therefore leaves ``block1.layer1`` untouched; the name-matching
        heuristics are needed to map ``layer1.*`` onto it.

        The state dict's insertion order (layer1, layer2, res.layer2) follows the
        order in which the submodules are registered on the model.
        ``test_complex_model_loaded`` relies on this when it pairs entries by position.

        Returns:
            tuple: ``(model, state_dict)``. ``state_dict`` is an ``OrderedDict`` of
            ``torch.rand`` tensors shaped like the corresponding parameters.
        """
        m = nn.Module()
        m.block1 = nn.Module()
        m.block1.layer1 = nn.Linear(2, 3)
        m.layer2 = nn.Linear(3, 2)
        m.res = nn.Module()
        m.res.layer2 = nn.Linear(3, 2)

        state_dict = OrderedDict()
        state_dict["layer1.weight"] = torch.rand(3, 2)
        state_dict["layer1.bias"] = torch.rand(3)
        state_dict["layer2.weight"] = torch.rand(2, 3)
        state_dict["layer2.bias"] = torch.rand(2)
        state_dict["res.layer2.weight"] = torch.rand(2, 3)
        state_dict["res.layer2.bias"] = torch.rand(2)
        return m, state_dict

    def test_complex_model_loaded(self):
        """align_and_update_state_dicts maps every checkpoint entry onto the model.

        Checked both for the bare model and for the model wrapped in
        ``nn.DataParallel``. The wrapper stores the model as ``.module``, so
        every model key gains a ``module.`` prefix.
        """
        for add_data_parallel in [False, True]:
            with self.subTest(add_data_parallel=add_data_parallel):
                model, state_dict = self.create_complex_model()
                if add_data_parallel:
                    model = nn.DataParallel(model)
                # The tensors returned by state_dict() share storage with the
                # parameters, so `model_sd` reflects what load_state_dict writes below.
                model_sd = model.state_dict()

                sd_to_load = align_and_update_state_dicts(model_sd, state_dict)
                model.load_state_dict(sd_to_load)

                # zip() silently stops at the shorter input; make sure nothing is skipped.
                self.assertEqual(len(model_sd), len(state_dict))
                # Entries are paired by position (see create_complex_model for the order).
                for loaded, stored in zip(model_sd.values(), state_dict.values()):
                    # different tensor references
                    self.assertIsNot(loaded, stored)
                    # same content
                    self.assertTrue(loaded.to(stored).equal(stored))

    def test_load_with_matching_heuristics(self):
        """Name matching is applied only when the path has ``?matching_heuristics=True``."""
        with tempfile.TemporaryDirectory(prefix="detectron2_test") as d:
            model, state_dict = self.create_complex_model()
            ckpt_path = os.path.join(d, "checkpoint.pth")
            torch.save({"model": state_dict}, ckpt_path)
            checkpointer = DetectionCheckpointer(model, save_dir=d)

            with torch.no_grad():
                # torch.rand samples from [0, 1), so an all-ones weight can never
                # coincide with the stored `layer1.weight`.
                model.block1.layer1.weight.fill_(1)

            # Without matching heuristics, `layer1.*` does not match
            # `block1.layer1.*`, so that weight must stay all ones.
            checkpointer.load(ckpt_path)
            self.assertTrue(model.block1.layer1.weight.equal(torch.ones(3, 2)))

            # With matching heuristics, the stored `layer1.weight` is mapped onto
            # `block1.layer1.weight`. Check the exact value, not just "changed".
            checkpointer.load(ckpt_path + "?matching_heuristics=True")
            self.assertFalse(model.block1.layer1.weight.equal(torch.ones(3, 2)))
            self.assertTrue(model.block1.layer1.weight.equal(state_dict["layer1.weight"]))

    def test_custom_path_manager_handler(self):
        """Checkpoints can be loaded through a custom iopath ``PathHandler``."""
        with tempfile.TemporaryDirectory(prefix="detectron2_test") as d:

            class CustomPathManagerHandler(PathHandler):
                """Resolve ``detectron2_test://<name>`` to ``<d>/<name>``."""

                PREFIX = "detectron2_test://"

                def _get_supported_prefixes(self):
                    return [self.PREFIX]

                def _get_local_path(self, path, **kwargs):
                    name = path[len(self.PREFIX) :]
                    return os.path.join(d, name)

                def _open(self, path, mode="r", **kwargs):
                    return open(self._get_local_path(path), mode, **kwargs)

            pathmgr = PathManager()
            pathmgr.register_handler(CustomPathManagerHandler())

            model, state_dict = self.create_complex_model()
            torch.save({"model": state_dict}, os.path.join(d, "checkpoint.pth"))
            checkpointer = DetectionCheckpointer(model, save_dir=d)
            checkpointer.path_manager = pathmgr

            # Exactly-matching keys are loaded through the custom handler.
            checkpointer.load("detectron2_test://checkpoint.pth")
            self.assertTrue(model.layer2.weight.equal(state_dict["layer2.weight"]))

            # The `?matching_heuristics=True` query must also take effect for a
            # path that uses the custom prefix.
            checkpointer.load("detectron2_test://checkpoint.pth?matching_heuristics=True")
            self.assertTrue(model.block1.layer1.weight.equal(state_dict["layer1.weight"]))

    def test_lcp(self):
        """_longest_common_prefix_str returns the leading substring shared by all names."""
        self.assertEqual(_longest_common_prefix_str(["class", "dlaps_model"]), "")
        self.assertEqual(_longest_common_prefix_str(["classA", "classB"]), "class")
        self.assertEqual(_longest_common_prefix_str(["classA", "classB", "clab"]), "cla")
if __name__ == "__main__":
    # Executing this file as a script hands control to the unittest runner.
    unittest.main()