Spaces:
Running
on
Zero
Running
on
Zero
| # Copyright (C) 2022-present Naver Corporation. All rights reserved. | |
| # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). | |
| # -------------------------------------------------------- | |
| # Heads for downstream tasks | |
| # -------------------------------------------------------- | |
| """ | |
| A head is a module where the __init__ defines only the head hyperparameters. | |
| A method setup(croconet) takes a CroCoNet and sets all layers according to the head and croconet attributes. | |
| The forward takes the features as well as a dictionary img_info containing the keys 'width' and 'height'. | |
| """ | |
| import torch | |
| import torch.nn as nn | |
| from .dpt_block import DPTOutputAdapter | |
class PixelwiseTaskWithDPT(nn.Module):
    """DPT head for CroCo pixel-wise downstream tasks.

    When ``hooks_idx`` is not given, :meth:`setup` picks 4 hook layers:

    * encoder-only backbone: 4 encoder layers spaced by ``enc_depth // 4``,
      the last one being the final encoder layer;
    * encoder+decoder backbone: the final decoder layer plus 3 earlier layers
      spaced by a depth-dependent step, so that the earliest hook reaches back
      to roughly the last encoder layer (for ``dec_depth == 8`` the hand-tuned
      step of 3 makes it land one layer before the last encoder layer).

    Hook indices count encoder blocks first (``0 .. enc_depth - 1``), then
    decoder blocks (``enc_depth .. enc_depth + dec_depth - 1``).
    """

    # Hand-tuned spacing between decoder hooks for the standard decoder
    # depths; kept verbatim so the default hooks for these depths are unchanged.
    _DEC_DEPTH_TO_STEP = {8: 3, 12: 4, 24: 8}
    _DEFAULT_LAYER_DIMS = (96, 192, 384, 768)

    def __init__(
        self,
        *,
        hooks_idx=None,
        layer_dims=None,
        output_width_ratio=1,
        num_channels=1,
        postprocess=None,
        **kwargs,
    ):
        """Store the head hyperparameters; layers are built later in setup().

        Args:
            hooks_idx: indices of the backbone blocks whose features feed the
                DPT; ``None`` means "choose automatically in setup()".
            layer_dims: forwarded to ``DPTOutputAdapter`` as ``layer_dims``;
                ``None`` means ``[96, 192, 384, 768]``.
            output_width_ratio: forwarded to ``DPTOutputAdapter``.
            num_channels: forwarded to ``DPTOutputAdapter``.
            postprocess: optional callable applied to the DPT output in
                forward().
            **kwargs: accepted and ignored.
        """
        super().__init__()
        self.return_all_blocks = True  # backbone needs to return all layers
        self.postprocess = postprocess
        self.output_width_ratio = output_width_ratio
        self.num_channels = num_channels
        self.hooks_idx = hooks_idx
        # Copy so that instances never share (and mutate) one default list.
        self.layer_dims = list(
            self._DEFAULT_LAYER_DIMS if layer_dims is None else layer_dims
        )

    @classmethod
    def _default_hooks_idx(cls, croconet):
        """Return the 4 default hook indices (increasing) for ``croconet``.

        Raises:
            ValueError: for an encoder+decoder backbone whose ``dec_depth`` has
                no hand-tuned step and is too shallow (< 3) to space 3 hooks.
        """
        if hasattr(croconet, "dec_blocks"):  # encoder + decoder
            dec_depth = croconet.dec_depth
            step = cls._DEC_DEPTH_TO_STEP.get(dec_depth)
            if step is None:
                if dec_depth < 3:
                    raise ValueError(
                        f"PixelwiseTaskWithDPT: cannot choose default hooks for "
                        f"dec_depth={dec_depth}; pass hooks_idx explicitly"
                    )
                # Spreads 3 hooks over the decoder, the earliest landing at or
                # after the last encoder layer (matches the 12 and 24 entries).
                step = dec_depth // 3
            last = croconet.enc_depth + dec_depth - 1
        else:  # encoder only
            # NOTE(review): enc_depth < 4 yields step 0, i.e. 4 identical hooks.
            step = croconet.enc_depth // 4
            last = croconet.enc_depth - 1
        return [last - i * step for i in range(3, -1, -1)]

    def setup(self, croconet):
        """Build the DPT adapter from this head's settings and ``croconet``.

        Reads ``enc_depth``, ``enc_embed_dim`` and ``dec_embed_dim`` from
        ``croconet`` (plus ``dec_depth`` when it has ``dec_blocks`` and hooks
        must be chosen automatically).
        """
        if self.hooks_idx is None:
            self.hooks_idx = self._default_hooks_idx(croconet)
            print(
                f" PixelwiseTaskWithDPT: automatically setting hook_idxs={self.hooks_idx}"
            )
        self.dpt = DPTOutputAdapter(
            output_width_ratio=self.output_width_ratio,
            num_channels=self.num_channels,
            hooks=self.hooks_idx,
            layer_dims=self.layer_dims,
        )
        # Hooks below enc_depth point at encoder blocks, the rest at decoder
        # blocks, so each hooked feature gets the matching embedding width.
        dim_tokens = [
            croconet.enc_embed_dim
            if hook < croconet.enc_depth
            else croconet.dec_embed_dim
            for hook in self.hooks_idx
        ]
        self.dpt.init(dim_tokens_enc=dim_tokens)

    def forward(self, x, img_info):
        """Run the DPT on backbone features ``x`` at the image's resolution.

        Args:
            x: backbone features passed straight to the DPT adapter
                (presumably the per-block outputs, since return_all_blocks
                is set — confirm against the backbone).
            img_info: dict with integer keys ``"height"`` and ``"width"``.
        """
        out = self.dpt(x, image_size=(img_info["height"], img_info["width"]))
        if self.postprocess is not None:
            out = self.postprocess(out)
        return out