Spaces:
Running
on
Zero
Running
on
Zero
| # Copyright (C) 2022-present Naver Corporation. All rights reserved. | |
| # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). | |
| # -------------------------------------------------------- | |
| # CroCo model for downstream tasks | |
| # -------------------------------------------------------- | |
| import torch | |
| from .croco import CroCoNet | |
def croco_args_from_ckpt(ckpt):
    """Recover the CroCoNet constructor keyword arguments stored in a checkpoint.

    Three checkpoint layouts are recognised, tried in this order:

    * ``"croco_kwargs"`` is a key of ``ckpt`` (CroCo v2 released models):
      that value is returned as is.
    * ``ckpt["args"]`` has a ``model`` attribute (pretrained using the official
      code release): the attribute is a string such as
      ``"CroCoNet(enc_embed_dim=1024, enc_num_heads=16, enc_depth=24)"`` and
      its keyword arguments are returned as a dictionary.
    * anything else (CroCo v1 released models): an empty dictionary.

    SECURITY NOTE: earlier versions turned the model string into
    ``"dict(...)"`` and passed it to ``eval``, which executes arbitrary code
    embedded in a checkpoint.  The string is now parsed with ``ast`` and every
    keyword value must be a Python literal (numbers, strings, booleans, None,
    tuples, lists, dicts, sets).  Non-literal values are rejected.

    Args:
        ckpt: mapping loaded from a checkpoint file.

    Returns:
        dict of keyword arguments for ``CroCoNet``.

    Raises:
        AssertionError: the model string does not start with ``"CroCoNet("``.
        SyntaxError: the model string is not a valid Python expression.
        ValueError: the model string is not a single keyword-only
            ``CroCoNet(...)`` call, or one of its values is not a literal.
    """
    import ast  # function-scope import: only this helper needs it

    if "croco_kwargs" in ckpt:  # CroCo v2 released models
        return ckpt["croco_kwargs"]

    if "args" in ckpt and hasattr(ckpt["args"], "model"):
        # pretrained using the official code release
        model_str = ckpt["args"].model
        assert model_str.startswith("CroCoNet(")

        call = ast.parse(model_str, mode="eval").body
        # Structural check that also holds when asserts are stripped (-O):
        # the whole expression must be exactly one call to the name CroCoNet.
        is_croco_call = (
            isinstance(call, ast.Call)
            and isinstance(call.func, ast.Name)
            and call.func.id == "CroCoNet"
        )
        if not is_croco_call:
            raise ValueError(
                "expected a single CroCoNet(...) call, got: %r" % (model_str,)
            )
        if call.args or any(kw.arg is None for kw in call.keywords):
            raise ValueError(
                "only keyword arguments are supported in: %r" % (model_str,)
            )
        # literal_eval never calls functions or resolves names, so nothing
        # stored in the checkpoint string can be executed here.
        return {kw.arg: ast.literal_eval(kw.value) for kw in call.keywords}

    # CroCo v1 released models
    return {}
class CroCoDownstreamMonocularEncoder(CroCoNet):
    """Encoder-only CroCoNet feeding a user-supplied head (monocular tasks)."""

    def __init__(self, head, **kwargs):
        """Create the encoder and attach the downstream head.

        ``head`` is an extra argument on top of the CroCoNet keyword
        arguments.  It is called in ``forward`` with the encoder features and
        an ``img_info`` dictionary holding the 'height' and 'width' keys, and
        it is configured here through ``head.setup(self)``.

        NOTE: this relies on calling the parent constructor while the
        ``_set_*`` hooks below are overridden (presumably the parent
        constructor invokes them -- confirm against CroCoNet).
        """
        super().__init__(**kwargs)
        head.setup(self)
        self.head = head

    def _set_mask_generator(self, *args, **kwargs):
        """Intentionally empty: no mask generator is built."""

    def _set_mask_token(self, *args, **kwargs):
        """No mask token is created; the attribute is set to None."""
        self.mask_token = None

    def _set_decoder(self, *args, **kwargs):
        """Intentionally empty: no decoder is built."""

    def _set_prediction_head(self, *args, **kwargs):
        """Intentionally empty: downstream tasks have no 'prediction head'."""

    def forward(self, img):
        """Encode ``img`` without masking and run the head on the result.

        img is of size batch_size x 3 x h x w
        """
        _, _, height, width = img.size()
        # The head may ask for the output of every block through an optional
        # ``return_all_blocks`` attribute; absent means False.
        want_all_blocks = getattr(self.head, "return_all_blocks", False)
        features, _, _ = self._encode_image(
            img, do_mask=False, return_all_blocks=want_all_blocks
        )
        return self.head(features, {"height": height, "width": width})
class CroCoDownstreamBinocular(CroCoNet):
    """CroCoNet encoder + decoder feeding a user-supplied head (binocular tasks)."""

    def __init__(self, head, **kwargs):
        """Create the network and attach the downstream head.

        ``head`` is an extra argument on top of the CroCoNet keyword
        arguments.  It is called in ``forward`` with the features and an
        ``img_info`` dictionary holding the 'height' and 'width' keys, and it
        is configured here through ``head.setup(self)``.
        """
        super().__init__(**kwargs)
        head.setup(self)
        self.head = head

    def _set_mask_generator(self, *args, **kwargs):
        """Intentionally empty: no mask generator is built."""

    def _set_mask_token(self, *args, **kwargs):
        """No mask token is created; the attribute is set to None."""
        self.mask_token = None

    def _set_prediction_head(self, *args, **kwargs):
        """Intentionally empty: downstream tasks define their own head."""

    def encode_image_pairs(self, img1, img2, return_all_blocks=False):
        """Run the encoder (no masking) on a pair of images.

        Returns ``(out, out2, pos, pos2)``.  With ``return_all_blocks`` set,
        ``out`` is a list with one entry per returned block for the first
        image, while ``out2`` is only the last entry for the second image.
        """
        # Both images go through the encoder as one batch: per the original
        # authors this is ~5% faster than encoding them with two separate
        # calls to self._encode_image.
        both = torch.cat((img1, img2), dim=0)
        encoded, positions, _ = self._encode_image(
            both,
            do_mask=False,
            return_all_blocks=return_all_blocks,
        )
        if return_all_blocks:
            # Split every block output in two along the batch dimension, then
            # regroup the halves per image.
            per_image_1, per_image_2 = zip(
                *(block.chunk(2, dim=0) for block in encoded)
            )
            out = list(per_image_1)
            out2 = per_image_2[-1]
        else:
            out, out2 = encoded.chunk(2, dim=0)
        pos, pos2 = positions.chunk(2, dim=0)
        return out, out2, pos, pos2

    def forward(self, img1, img2):
        """Encode both images, decode the first against the second, run the head.

        ``img_info`` passed to the head is built from the last two dimensions
        of ``img1`` (which must be 4-dimensional).
        """
        _, _, height, width = img1.size()
        img_info = {"height": height, "width": width}
        # The head may ask for the output of every block through an optional
        # ``return_all_blocks`` attribute; absent means False.
        all_blocks = getattr(self.head, "return_all_blocks", False)
        feat1, feat2, pos1, pos2 = self.encode_image_pairs(
            img1, img2, return_all_blocks=all_blocks
        )
        # With all blocks requested, feat1 is a list: the decoder consumes
        # only its last entry.
        decoder_input = feat1[-1] if all_blocks else feat1
        decoded = self._decoder(
            decoder_input, pos1, None, feat2, pos2, return_all_blocks=all_blocks
        )
        if all_blocks:
            # Concatenate the encoder block outputs with the decoder outputs.
            decoded = feat1 + decoded
        return self.head(decoded, img_info)