import os
import os.path as osp
import pkgutil
import warnings
from collections import OrderedDict
from importlib import import_module

import torch
import torchvision
import torch.nn as nn
from torch import distributed as dist
from torch.nn import functional as F
from torch.nn.parallel import DataParallel, DistributedDataParallel
from torch.utils import model_zoo

TORCH_VERSION = torch.__version__

def resize(input,
           size=None,
           scale_factor=None,
           mode='nearest',
           align_corners=None,
           warning=True):
    if warning:
        if size is not None and align_corners:
            input_h, input_w = tuple(int(x) for x in input.shape[2:])
            output_h, output_w = tuple(int(x) for x in size)
            # warn only when upsampling, i.e. output exceeds input in a dim
            if output_h > input_h or output_w > input_w:
                if ((output_h > 1 and output_w > 1 and input_h > 1
                     and input_w > 1) and (output_h - 1) % (input_h - 1)
                        and (output_w - 1) % (input_w - 1)):
                    warnings.warn(
                        f'When align_corners={align_corners}, '
                        'the output would be more aligned if '
                        f'input size {(input_h, input_w)} is `x+1` and '
                        f'out size {(output_h, output_w)} is `nx+1`')
    if isinstance(size, torch.Size):
        size = tuple(int(x) for x in size)
    return F.interpolate(input, size, scale_factor, mode, align_corners)

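# Usage sketch (illustrative, not part of the original module): upsampling
# decoder logits back to the input resolution, as a segmentation head would.
#
#   logits = torch.randn(2, 19, 64, 64)  # N, C, H, W
#   out = resize(logits, size=(512, 512), mode='bilinear', align_corners=False)
#   assert out.shape == (2, 19, 512, 512)
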
def normal_init(module, mean=0, std=1, bias=0):
    if hasattr(module, 'weight') and module.weight is not None:
        nn.init.normal_(module.weight, mean, std)
    if hasattr(module, 'bias') and module.bias is not None:
        nn.init.constant_(module.bias, bias)

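# Usage sketch (illustrative): a small-std Gaussian init for a prediction
# head, a common choice when the backbone itself is loaded from a checkpoint.
#
#   conv = nn.Conv2d(256, 19, kernel_size=1)
#   normal_init(conv, mean=0, std=0.01, bias=0)
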
def is_module_wrapper(module):
    module_wrappers = (DataParallel, DistributedDataParallel)
    return isinstance(module, module_wrappers)

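# Example (illustrative): unwrapping keeps checkpoint keys aligned with the
# plain module's state_dict, since DataParallel prefixes keys with 'module.'.
#
#   wrapped = DataParallel(nn.Linear(4, 4))
#   inner = wrapped.module if is_module_wrapper(wrapped) else wrapped
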
def get_dist_info():
    if TORCH_VERSION < '1.0':
        initialized = dist._initialized
    else:
        if dist.is_available():
            initialized = dist.is_initialized()
        else:
            initialized = False
    if initialized:
        rank = dist.get_rank()
        world_size = dist.get_world_size()
    else:
        rank = 0
        world_size = 1
    return rank, world_size

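# Usage sketch: rank-0 gating, the same pattern load_state_dict uses below so
# that mismatch messages are printed only once per distributed job.
#
#   rank, world_size = get_dist_info()
#   if rank == 0:
#       print(f'running on {world_size} process(es)')
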
def load_state_dict(module, state_dict, strict=False, logger=None):
    """Load state_dict into a module.

    This method is modified from :meth:`torch.nn.Module.load_state_dict`.
    The default value for ``strict`` is set to ``False`` and the message for
    param mismatch is shown even if strict is False.

    Args:
        module (Module): Module that receives the state_dict.
        state_dict (OrderedDict): Weights.
        strict (bool): Whether to strictly enforce that the keys
            in :attr:`state_dict` match the keys returned by this module's
            :meth:`~torch.nn.Module.state_dict` function. Default: ``False``.
        logger (:obj:`logging.Logger`, optional): Logger to log the error
            message. If not specified, the print function is used.
    """
    unexpected_keys = []
    all_missing_keys = []
    err_msg = []

    metadata = getattr(state_dict, '_metadata', None)
    state_dict = state_dict.copy()
    if metadata is not None:
        state_dict._metadata = metadata

    # use _load_from_state_dict to enable checkpoint version control
    def load(module, prefix=''):
        # recursively check parallel module in case that the model has a
        # complicated structure, e.g., nn.Module(nn.Module(DDP))
        if is_module_wrapper(module):
            module = module.module
        local_metadata = {} if metadata is None else metadata.get(
            prefix[:-1], {})
        module._load_from_state_dict(state_dict, prefix, local_metadata, True,
                                     all_missing_keys, unexpected_keys,
                                     err_msg)
        for name, child in module._modules.items():
            if child is not None:
                load(child, prefix + name + '.')

    load(module)
    load = None  # break load->load reference cycle

    # ignore "num_batches_tracked" of BN layers
    missing_keys = [
        key for key in all_missing_keys if 'num_batches_tracked' not in key
    ]

    if unexpected_keys:
        err_msg.append('unexpected key in source '
                       f'state_dict: {", ".join(unexpected_keys)}\n')
    if missing_keys:
        err_msg.append(
            f'missing keys in source state_dict: {", ".join(missing_keys)}\n')

    rank, _ = get_dist_info()
    if len(err_msg) > 0 and rank == 0:
        err_msg.insert(
            0, 'The model and loaded state dict do not match exactly\n')
        err_msg = '\n'.join(err_msg)
        if strict:
            raise RuntimeError(err_msg)
        elif logger is not None:
            logger.warning(err_msg)
        else:
            print(err_msg)

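# Usage sketch (path is hypothetical): non-strict loading of a backbone-only
# checkpoint, where head keys are expected to be reported as missing rather
# than raising.
#
#   import logging
#   weights = torch.load('backbone.pth', map_location='cpu')
#   load_state_dict(model, weights, strict=False,
#                   logger=logging.getLogger(__name__))
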
def load_url_dist(url, model_dir=None):
    """In a distributed setting, only download the checkpoint at local rank
    0; other ranks wait at a barrier and then load from the cache."""
    rank, world_size = get_dist_info()
    rank = int(os.environ.get('LOCAL_RANK', rank))
    if rank == 0:
        checkpoint = model_zoo.load_url(url, model_dir=model_dir)
    if world_size > 1:
        torch.distributed.barrier()
        if rank > 0:
            checkpoint = model_zoo.load_url(url, model_dir=model_dir)
    return checkpoint

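# Note: this rank-0-first pattern relies on model_zoo.load_url caching the
# downloaded file on disk, so ranks > 0 hit the cache after the barrier
# instead of re-downloading. Usage sketch (URL is illustrative):
#
#   checkpoint = load_url_dist(
#       'https://download.pytorch.org/models/resnet50-19c8e357.pth')
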
def get_torchvision_models():
    model_urls = dict()
    for _, name, ispkg in pkgutil.walk_packages(torchvision.models.__path__):
        if ispkg:
            continue
        _zoo = import_module(f'torchvision.models.{name}')
        if hasattr(_zoo, 'model_urls'):
            _urls = getattr(_zoo, 'model_urls')
            model_urls.update(_urls)
    return model_urls

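# Caveat: this discovery relies on the legacy per-module ``model_urls`` dicts,
# which torchvision deprecated in favor of the Weights enums and later
# removed, so on recent torchvision versions the returned mapping may be
# empty.
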
def _load_checkpoint(filename, map_location=None):
    """Load checkpoint from somewhere (modelzoo or file).

    Args:
        filename (str): Either a local filepath or a ``modelzoo://xxx``
            identifier resolved through the torchvision model urls. Please
            refer to ``docs/model_zoo.md`` for details.
        map_location (str | None): Same as :func:`torch.load`. Default: None.

    Returns:
        dict | OrderedDict: The loaded checkpoint. It can be either an
            OrderedDict storing model weights or a dict containing other
            information, which depends on the checkpoint.
    """
    if filename.startswith('modelzoo://'):
        warnings.warn('The URL scheme of "modelzoo://" is deprecated, please '
                      'use "torchvision://" instead')
        model_urls = get_torchvision_models()
        model_name = filename[11:]  # strip the 'modelzoo://' prefix
        checkpoint = load_url_dist(model_urls[model_name])
    else:
        if not osp.isfile(filename):
            raise IOError(f'{filename} is not a checkpoint file')
        checkpoint = torch.load(filename, map_location=map_location)
    return checkpoint

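# Usage sketch:
#
#   ckpt = _load_checkpoint('modelzoo://resnet50')       # via torchvision urls
#   ckpt = _load_checkpoint('weights/model.pth', 'cpu')  # hypothetical path
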
def load_checkpoint(model,
                    filename,
                    map_location='cpu',
                    strict=False,
                    logger=None):
    """Load checkpoint from a file or URI.

    Args:
        model (Module): Module to load the checkpoint into.
        filename (str): Either a local filepath or a ``modelzoo://xxx``
            identifier. Please refer to ``docs/model_zoo.md`` for details.
        map_location (str): Same as :func:`torch.load`.
        strict (bool): Whether to raise an error when the params of the model
            and the checkpoint do not match exactly.
        logger (:mod:`logging.Logger` or None): The logger for error messages.

    Returns:
        dict or OrderedDict: The loaded checkpoint.
    """
    checkpoint = _load_checkpoint(filename, map_location)
    # OrderedDict is a subclass of dict
    if not isinstance(checkpoint, dict):
        raise RuntimeError(
            f'No state_dict found in checkpoint file {filename}')
    # fall back to warnings.warn when no logger is given (logger may be None)
    warn = logger.warning if logger is not None else warnings.warn
    # get state_dict from checkpoint
    if 'state_dict' in checkpoint:
        state_dict = checkpoint['state_dict']
    elif 'model' in checkpoint:
        state_dict = checkpoint['model']
    else:
        state_dict = checkpoint
    # strip the 'module.' prefix left by (Distributed)DataParallel
    if list(state_dict.keys())[0].startswith('module.'):
        state_dict = {k[7:]: v for k, v in state_dict.items()}
    # for MoBY, load the model of the online branch
    if sorted(list(state_dict.keys()))[0].startswith('encoder'):
        state_dict = {
            k.replace('encoder.', ''): v
            for k, v in state_dict.items() if k.startswith('encoder.')
        }
    # reshape absolute position embedding
    if state_dict.get('absolute_pos_embed') is not None:
        absolute_pos_embed = state_dict['absolute_pos_embed']
        N1, L, C1 = absolute_pos_embed.size()
        N2, C2, H, W = model.absolute_pos_embed.size()
        if N1 != N2 or C1 != C2 or L != H * W:
            warn('Error in loading absolute_pos_embed, pass')
        else:
            state_dict['absolute_pos_embed'] = absolute_pos_embed.view(
                N2, H, W, C2).permute(0, 3, 1, 2)
    # interpolate relative position bias tables if the window size changed
    relative_position_bias_table_keys = [
        k for k in state_dict.keys() if 'relative_position_bias_table' in k
    ]
    for table_key in relative_position_bias_table_keys:
        table_pretrained = state_dict[table_key]
        table_current = model.state_dict()[table_key]
        L1, nH1 = table_pretrained.size()
        L2, nH2 = table_current.size()
        if nH1 != nH2:
            warn(f'Error in loading {table_key}, pass')
        elif L1 != L2:
            S1 = int(L1 ** 0.5)
            S2 = int(L2 ** 0.5)
            table_pretrained_resized = F.interpolate(
                table_pretrained.permute(1, 0).view(1, nH1, S1, S1),
                size=(S2, S2), mode='bicubic')
            state_dict[table_key] = table_pretrained_resized.view(
                nH2, L2).permute(1, 0)
    # load state_dict
    load_state_dict(model, state_dict, strict, logger)
    return checkpoint

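# Usage sketch (checkpoint name is illustrative): loading Swin weights
# pretrained at one window size into a model with a different one; mismatched
# relative position bias tables are bicubically resized above instead of
# being dropped.
#
#   import logging
#   checkpoint = load_checkpoint(model,
#                                'swin_tiny_patch4_window7_224.pth',
#                                map_location='cpu', strict=False,
#                                logger=logging.getLogger(__name__))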