# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import inspect
from typing import Dict

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmengine.model import xavier_init
from mmengine.registry import MODELS

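# Note (editorial, not part of the original file): nn.Upsample is registered
# under both names so that a config's ``type`` string can double as the
# interpolation ``mode`` -- ``build_upsample_layer`` below copies the type
# string into ``cfg_['mode']`` whenever the resolved class is nn.Upsample.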
MODELS.register_module('nearest', module=nn.Upsample)
MODELS.register_module('bilinear', module=nn.Upsample)


class PixelShufflePack(nn.Module):
    """Pixel Shuffle upsample layer.

    This module packs `F.pixel_shuffle()` and an `nn.Conv2d` module together
    to achieve a simple upsampling with pixel shuffle.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        scale_factor (int): Upsample ratio.
        upsample_kernel (int): Kernel size of the conv layer to expand the
            channels.
    """

    def __init__(self, in_channels: int, out_channels: int, scale_factor: int,
                 upsample_kernel: int):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.scale_factor = scale_factor
        self.upsample_kernel = upsample_kernel
        self.upsample_conv = nn.Conv2d(
            self.in_channels,
            self.out_channels * scale_factor * scale_factor,
            self.upsample_kernel,
            padding=(self.upsample_kernel - 1) // 2)
        self.init_weights()

    def init_weights(self):
        xavier_init(self.upsample_conv, distribution='uniform')

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.upsample_conv(x)
        x = F.pixel_shuffle(x, self.scale_factor)
        return x
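

# Example (illustrative sketch, not part of the original file): upsample a
# 64-channel feature map by 2x. The conv expands to 64 * 2 * 2 channels and
# ``F.pixel_shuffle`` rearranges them into spatial resolution:
#
#     layer = PixelShufflePack(
#         in_channels=64, out_channels=64, scale_factor=2, upsample_kernel=3)
#     out = layer(torch.rand(1, 64, 32, 32))  # out.shape == (1, 64, 64, 64)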


def build_upsample_layer(cfg: Dict, *args, **kwargs) -> nn.Module:
    """Build upsample layer.

    Args:
        cfg (dict): The upsample layer config, which should contain:

            - type (str): Layer type.
            - scale_factor (int): Upsample ratio, which is not applicable to
              deconv.
            - layer args: Args needed to instantiate an upsample layer.
        args (argument list): Arguments passed to the ``__init__``
            method of the corresponding upsample layer.
        kwargs (keyword arguments): Keyword arguments passed to the
            ``__init__`` method of the corresponding upsample layer.

    Returns:
        nn.Module: Created upsample layer.
    """
    if not isinstance(cfg, dict):
        raise TypeError(f'cfg must be a dict, but got {type(cfg)}')
    if 'type' not in cfg:
        raise KeyError(
            f'the cfg dict must contain the key "type", but got {cfg}')
    cfg_ = cfg.copy()

    layer_type = cfg_.pop('type')
    if inspect.isclass(layer_type):
        upsample = layer_type
    # Switch registry to the target scope. If `upsample` cannot be found
    # in the registry, fall back to searching for it in mmengine.MODELS.
    else:
        with MODELS.switch_scope_and_registry(None) as registry:
            upsample = registry.get(layer_type)
        if upsample is None:
            raise KeyError(f'Cannot find {layer_type} in registry under '
                           f'scope name {registry.scope}')
        if upsample is nn.Upsample:
            cfg_['mode'] = layer_type

    layer = upsample(*args, **kwargs, **cfg_)
    return layer
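

# Minimal usage sketch (illustrative, not part of the original file): build
# upsample layers from config dicts. String types resolve through the
# registry entries above; a class type such as ``PixelShufflePack`` is used
# directly because of the ``inspect.isclass`` branch.
if __name__ == '__main__':
    nearest_up = build_upsample_layer(dict(type='nearest', scale_factor=2))
    bilinear_up = build_upsample_layer(
        dict(type='bilinear', scale_factor=2, align_corners=False))
    shuffle_up = build_upsample_layer(
        dict(
            type=PixelShufflePack,
            in_channels=64,
            out_channels=64,
            scale_factor=2,
            upsample_kernel=3))

    x = torch.rand(1, 64, 32, 32)
    for layer in (nearest_up, bilinear_up, shuffle_up):
        # Each layer doubles the spatial resolution: (1, 64, 64, 64).
        print(type(layer).__name__, tuple(layer(x).shape))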