Spaces:
Paused
Paused
| # coding=utf-8 | |
| # Copyright 2023 Meta AI and The HuggingFace Inc. team. All rights reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| # Reference: | |
| # * transformers/models/dinov2/modeling_dinov2.py | |
| # * https://github.com/facebookresearch/DiT/blob/main/models.py#L101 | |
| # * https://github.com/3DTopia/OpenLRM/tree/main/openlrm/models/encoders/dinov2 | |
| """PyTorch DINOv2 model.""" | |
| from typing import Dict, List, Optional, Set, Tuple, Union | |
| import torch | |
| import torch.nn as nn | |
| from .modeling_dinov2 import ( | |
| Dinov2Config, | |
| Dinov2Layer, | |
| Dinov2Model, | |
| Dinov2Embeddings, | |
| BaseModelOutput, | |
| BaseModelOutputWithPooling, | |
| ) | |
class ModLN(nn.Module):
    """Condition-driven shift-and-scale modulation of token features.

    A small ``SiLU -> Linear`` MLP maps a per-sample condition vector to a
    ``(shift, scale)`` pair. The pair is broadcast over the token axis and
    applied as ``x * (1 + scale) + shift``.

    The projection is zero-initialized, so a freshly constructed module is the
    identity map.

    Args:
        inner_dim: Channel size of ``x`` (its last dimension).
        mod_dim: Channel size of the condition vector.
    """

    def __init__(self, inner_dim: int, mod_dim: int = 1024):
        super().__init__()
        # Keep this exact Sequential layout: the state_dict keys are
        # ``mlp.1.weight`` / ``mlp.1.bias``, and saved weights rely on them.
        self.mlp = nn.Sequential(
            nn.SiLU(),
            nn.Linear(mod_dim, inner_dim * 2),
        )
        # Zero-init => shift == 0 and scale == 0, i.e. identity at start.
        projection = self.mlp[1]
        nn.init.zeros_(projection.weight)
        nn.init.zeros_(projection.bias)

    def forward(
        self, x: torch.Tensor, condition: Optional[torch.Tensor]
    ) -> torch.Tensor:
        """Modulate ``x`` with ``condition``.

        Args:
            x: Token features ``[N, M, C_in]``, where M is the number of tokens.
            condition: Condition vectors ``[N, C_mod]``. ``None`` means "no
                conditioning" and returns ``x`` unchanged. The layers and
                models that call this default ``condition`` to ``None``, which
                previously crashed inside ``SiLU``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        if condition is None:
            return x
        # [N, 2*C_in] -> [N, 1, 2*C_in] so it broadcasts over the token axis.
        shift, scale = self.mlp(condition).unsqueeze(1).chunk(2, dim=-1)
        return x * (1 + scale) + shift
class ConditionalDinov2Config(Dinov2Config):
    """``Dinov2Config`` plus the width of the per-layer modulation condition.

    Args:
        modulation_dim: Size of the condition vector fed to the ``ModLN``
            modules of each ``ConditionalDinov2Layer``.
        *args, **kwargs: Passed through unchanged to ``Dinov2Config``.
    """

    def __init__(self, modulation_dim: int = 1024, *args, **kwargs):
        # Let the parent populate every standard DINOv2 field first ...
        super().__init__(*args, **kwargs)
        # ... then record the one extra hyper-parameter this variant needs.
        self.modulation_dim = modulation_dim
class ConditionalDinov2Layer(Dinov2Layer):
    """DINOv2 block whose two pre-norms are each followed by a ``ModLN``.

    This corresponds to the Block class in the original implementation. Per
    call, with ``c`` the condition::

        h = drop_path(layer_scale1(attn(mod_norm1(norm1(x), c)))) + x
        y = drop_path(layer_scale2(mlp(mod_norm2(norm2(h), c)))) + h
    """

    def __init__(self, config: ConditionalDinov2Config) -> None:
        super().__init__(config)
        width, cond_width = config.hidden_size, config.modulation_dim
        self.mod_norm1 = ModLN(width, cond_width)
        self.mod_norm2 = ModLN(width, cond_width)

    def forward(
        self,
        hidden_states: torch.Tensor,
        head_mask: Optional[torch.Tensor] = None,
        condition: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
        """Apply one conditioned transformer block.

        Args:
            hidden_states: Input tokens for this layer.
            head_mask: Optional per-head mask forwarded to ``self.attention``.
            condition: Modulation input forwarded to both ``ModLN`` modules.
            output_attentions: Forwarded to ``self.attention``.

        Returns:
            ``(layer_output,)`` followed by whatever extra items
            ``self.attention`` returned after its first element (e.g.
            attention weights when ``output_attentions`` is true).
        """
        # --- attention branch (layernorm + modulation *before* attention) ---
        attn_in = self.mod_norm1(self.norm1(hidden_states), condition)
        attn_result = self.attention(
            attn_in,
            head_mask,
            output_attentions=output_attentions,
        )
        attn_out = self.layer_scale1(attn_result[0])
        extras = attn_result[1:]  # attention weights, if requested
        hidden_states = self.drop_path(attn_out) + hidden_states

        # --- MLP branch (layernorm + modulation after the first residual) ---
        mlp_in = self.mod_norm2(self.norm2(hidden_states), condition)
        mlp_out = self.layer_scale2(self.mlp(mlp_in))
        layer_output = self.drop_path(mlp_out) + hidden_states

        return (layer_output,) + extras
# Adapted from transformers.models.vit.modeling_vit.ViTEncoder (ViT->Dinov2).
# Not a verbatim copy: every layer additionally receives ``condition``.
class ConditionalDinov2Encoder(nn.Module):
    """Stack of ``ConditionalDinov2Layer`` blocks that share one condition."""

    def __init__(self, config: ConditionalDinov2Config) -> None:
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList(
            ConditionalDinov2Layer(config) for _ in range(config.num_hidden_layers)
        )
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        condition: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ) -> Union[tuple, BaseModelOutput]:
        """Run every layer in order, threading ``condition`` through each one.

        Args:
            hidden_states: Input tokens to the first layer.
            head_mask: Optional per-layer head masks; entry ``i`` goes to
                layer ``i``.
            output_attentions: Collect each layer's second output.
            output_hidden_states: Collect the input of every layer plus the
                final output.
            condition: Modulation input passed to every layer.
            return_dict: Return a ``BaseModelOutput`` instead of a tuple.

        Returns:
            ``BaseModelOutput``. If ``return_dict`` is false, a tuple of
            ``(last_hidden_state, hidden_states, attentions)`` with ``None``
            entries dropped.
        """
        hidden_trace = [] if output_hidden_states else None
        attention_trace = [] if output_attentions else None

        for idx, block in enumerate(self.layer):
            if hidden_trace is not None:
                hidden_trace.append(hidden_states)

            block_mask = None if head_mask is None else head_mask[idx]
            block_args = (hidden_states, block_mask, condition, output_attentions)

            if self.gradient_checkpointing and self.training:
                # NOTE(review): ``_gradient_checkpointing_func`` is not defined in
                # this class; presumably transformers installs it when gradient
                # checkpointing is enabled — confirm.
                block_outputs = self._gradient_checkpointing_func(
                    block.__call__, *block_args
                )
            else:
                block_outputs = block(*block_args)

            hidden_states = block_outputs[0]
            if attention_trace is not None:
                attention_trace.append(block_outputs[1])

        if hidden_trace is not None:
            hidden_trace.append(hidden_states)

        all_hidden_states = None if hidden_trace is None else tuple(hidden_trace)
        all_self_attentions = (
            None if attention_trace is None else tuple(attention_trace)
        )

        if return_dict:
            return BaseModelOutput(
                last_hidden_state=hidden_states,
                hidden_states=all_hidden_states,
                attentions=all_self_attentions,
            )
        candidates = (hidden_states, all_hidden_states, all_self_attentions)
        return tuple(item for item in candidates if item is not None)
class ConditionalDinov2Model(Dinov2Model):
    """DINOv2 backbone whose encoder layers are modulated by a condition tensor."""

    config_class = ConditionalDinov2Config

    def __init__(self, config: ConditionalDinov2Config):
        super().__init__(config)
        self.config = config
        # NOTE(review): embeddings and layernorm are rebuilt after
        # ``super().__init__``. Presumably the parent already created modules of
        # the same kind, so only the encoder swap is essential. Confirm before
        # dropping them: removal changes RNG consumption during init.
        self.embeddings = Dinov2Embeddings(config)
        self.encoder = ConditionalDinov2Encoder(config)
        self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        # Initialize weights and apply final processing.
        self.post_init()

    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        bool_masked_pos: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        condition: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        """Embed ``pixel_values``, run the conditioned encoder, then final-norm.

        Args:
            pixel_values: Required image batch.
            bool_masked_pos: Forwarded to ``self.embeddings``.
            head_mask: Expanded via ``get_head_mask`` and sent to the encoder.
            condition: Modulation input shared by every encoder layer.
            output_attentions / output_hidden_states / return_dict: Default to
                the matching ``self.config`` values when ``None``.

        Returns:
            ``BaseModelOutputWithPooling``. If ``return_dict`` is false, the
            tuple ``(sequence_output, pooled_output, *encoder_extras)``.

        Raises:
            ValueError: If ``pixel_values`` is ``None``.
        """
        cfg = self.config
        if output_attentions is None:
            output_attentions = cfg.output_attentions
        if output_hidden_states is None:
            output_hidden_states = cfg.output_hidden_states
        if return_dict is None:
            return_dict = cfg.use_return_dict

        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        # A 1.0 in head_mask keeps that head. Accepts [num_heads] or
        # [num_hidden_layers x num_heads] and expands it per layer.
        head_mask = self.get_head_mask(head_mask, cfg.num_hidden_layers)

        tokens = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos)
        encoded = self.encoder(
            tokens,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            condition=condition,
            return_dict=return_dict,
        )

        sequence_output = self.layernorm(encoded[0])
        # Token at position 0 — presumably the CLS token from Dinov2Embeddings.
        pooled_output = sequence_output[:, 0, :]

        if return_dict:
            return BaseModelOutputWithPooling(
                last_hidden_state=sequence_output,
                pooler_output=pooled_output,
                hidden_states=encoded.hidden_states,
                attentions=encoded.attentions,
            )
        return (sequence_output, pooled_output) + encoded[1:]