Spaces:
Paused
Paused
| # coding=utf-8 | |
| # Copyright 2023 Meta AI and The HuggingFace Inc. team. All rights reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| # Reference: | |
| # * transformers/models/dinov2/modeling_dinov2.py | |
| # * https://github.com/facebookresearch/DiT/blob/main/models.py#L101 | |
| # * https://github.com/3DTopia/OpenLRM/tree/main/openlrm/models/encoders/dinov2 | |
| """PyTorch CLIP model.""" | |
| from typing import Dict, List, Optional, Set, Tuple, Union | |
| import torch | |
| import torch.nn as nn | |
| from .modeling_clip import ( | |
| CLIPConfig, | |
| CLIPTextConfig, | |
| CLIPVisionConfig, | |
| CLIPEncoderLayer, | |
| CLIPTextTransformer, | |
| CLIPVisionTransformer, | |
| CLIPModel, | |
| CLIPVisionEmbeddings, | |
| CLIPVisionModel, | |
| CLIPOutput, | |
| BaseModelOutput, | |
| BaseModelOutputWithPooling, | |
| ) | |
class ModLN(nn.Module):
    """Condition-dependent affine modulation applied on top of a normalised input.

    A ``SiLU -> Linear`` head maps a conditioning vector to a per-channel
    ``shift`` and ``scale``, and the input is transformed as
    ``x * (1 + scale) + shift``. The linear layer starts with all-zero weight
    and bias, so a freshly constructed module returns ``x`` unchanged.

    Args:
        inner_dim: channel size ``C`` of the tokens being modulated.
        mod_dim: size ``C_mod`` of the conditioning vector.
    """

    def __init__(self, inner_dim: int, mod_dim: int = 32):
        super().__init__()
        projection = nn.Linear(mod_dim, inner_dim * 2)
        # Zero-init: the modulation begins as an identity transform.
        nn.init.zeros_(projection.weight)
        nn.init.zeros_(projection.bias)
        # Index 0 = SiLU, index 1 = Linear -> parameters are "mlp.1.weight" / "mlp.1.bias".
        self.mlp = nn.Sequential(nn.SiLU(), projection)

    def forward(self, x: torch.Tensor, condition: torch.Tensor):
        """Shift and scale ``x`` using parameters predicted from ``condition``.

        Args:
            x: tokens of shape ``[N, M, C_in]`` (M = number of tokens).
            condition: conditioning vectors of shape ``[N, C_mod]``.

        Returns:
            Tensor shaped like ``x``; the per-sample shift/scale is broadcast
            over the token axis.
        """
        # [N, 2C] -> [N, 1, 2C] so it broadcasts across the M tokens.
        modulation = self.mlp(condition)[:, None]
        shift, scale = modulation.chunk(2, dim=-1)
        return shift + x * (scale + 1)
class ConditionalCLIPVisionConfig(CLIPVisionConfig):
    """``CLIPVisionConfig`` plus the size of the conditioning (modulation) vector.

    Args:
        modulation_dim: size of the conditioning vector fed to the ``ModLN``
            modules of each ``ConditionalCLIPEncoderLayer`` (default 32).
        *args, **kwargs: forwarded unchanged to ``CLIPVisionConfig``.

    Note:
        ``modulation_dim`` is the first positional parameter, so any positional
        arguments intended for ``CLIPVisionConfig`` are shifted by one. Pass the
        base-config fields by keyword.
    """

    def __init__(self, modulation_dim: int = 32, *args, **kwargs):
        super(ConditionalCLIPVisionConfig, self).__init__(*args, **kwargs)
        # Stored as a plain attribute alongside the base-config fields.
        self.modulation_dim = modulation_dim
class ConditionalCLIPEncoderLayer(CLIPEncoderLayer):
    """CLIP encoder block with condition modulation after each pre-norm.

    This corresponds to the Block class in the original implementation.
    The attention module, MLP and the two LayerNorms are inherited from
    ``CLIPEncoderLayer``. This subclass adds one ``ModLN`` after each
    LayerNorm, so ``condition`` shifts/scales the normalised tokens before
    attention and before the MLP.
    """

    def __init__(self, config: ConditionalCLIPVisionConfig) -> None:
        super().__init__(config)
        # ConditionalCLIPModel accepts any CLIPVisionConfig, which need not carry
        # ``modulation_dim``. Fall back to ConditionalCLIPVisionConfig's default
        # (32) instead of failing with AttributeError.
        modulation_dim = getattr(config, "modulation_dim", 32)
        self.mod_norm1 = ModLN(config.hidden_size, modulation_dim)
        self.mod_norm2 = ModLN(config.hidden_size, modulation_dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        causal_attention_mask: torch.Tensor,
        condition: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
        """Run one pre-norm transformer block with modulated norms.

        Args:
            hidden_states: token sequence (``[N, M, C]`` per ``ModLN``'s contract).
            attention_mask: forwarded unchanged to ``self.self_attn``.
            causal_attention_mask: forwarded unchanged to ``self.self_attn``.
            condition: ``[N, C_mod]`` modulation vector. It is typed as optional
                but is effectively required, because ``ModLN`` feeds it
                directly into its MLP.
            output_attentions: if true, the attention weights are also returned.

        Returns:
            ``(hidden_states,)``, or ``(hidden_states, attn_weights)`` when
            ``output_attentions`` is true.
        """
        # Attention sub-block: LayerNorm -> modulation -> attention -> residual.
        residual = hidden_states
        hidden_states = self.mod_norm1(self.layer_norm1(hidden_states), condition)
        hidden_states, attn_weights = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            causal_attention_mask=causal_attention_mask,
            output_attentions=output_attentions,
        )
        hidden_states = residual + hidden_states

        # MLP sub-block: LayerNorm -> modulation -> MLP -> residual.
        residual = hidden_states
        hidden_states = self.mod_norm2(self.layer_norm2(hidden_states), condition)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)
        if output_attentions:
            outputs += (attn_weights,)
        return outputs
class ConditionalCLIPEncoder(nn.Module):
    """Stack of ``ConditionalCLIPEncoderLayer`` blocks that share one condition.

    Despite the ``CLIPConfig`` annotation, ``ConditionalCLIPVisionTransformer``
    passes its vision config here. That config must provide
    ``num_hidden_layers``, the fields read by each layer, and the
    ``output_attentions`` / ``output_hidden_states`` / ``use_return_dict``
    defaults used by ``forward``.
    """

    def __init__(self, config: CLIPConfig) -> None:
        super().__init__()
        self.config = config
        self.layers = nn.ModuleList(
            [
                ConditionalCLIPEncoderLayer(config)
                for _ in range(config.num_hidden_layers)
            ]
        )
        # When True (and in training mode), each layer runs through
        # ``self._gradient_checkpointing_func``. That function is not defined in
        # this class; presumably it is installed by transformers'
        # gradient-checkpointing helpers.
        self.gradient_checkpointing = False

    def forward(
        self,
        inputs_embeds,
        attention_mask: Optional[torch.Tensor] = None,
        causal_attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        condition: Optional[torch.Tensor] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple, BaseModelOutput]:
        """Run every encoder layer over ``inputs_embeds``.

        Args:
            inputs_embeds: input token embeddings (presumably ``[N, M, C]``).
            attention_mask: forwarded to every layer.
            causal_attention_mask: forwarded to every layer.
            output_attentions: collect per-layer attention weights; defaults to
                ``self.config.output_attentions``.
            output_hidden_states: collect hidden states; defaults to
                ``self.config.output_hidden_states``.
            condition: modulation vector forwarded to every layer.
            return_dict: return a ``BaseModelOutput`` instead of a tuple;
                defaults to ``self.config.use_return_dict``.

        Returns:
            ``BaseModelOutput``, or a tuple
            ``(last_hidden_state, hidden_states?, attentions?)`` with the
            disabled (``None``) entries dropped. ``hidden_states`` contains the
            input to every layer followed by the final output.
        """
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        encoder_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        hidden_states = inputs_embeds
        for encoder_layer in self.layers:
            if output_hidden_states:
                encoder_states = encoder_states + (hidden_states,)
            if self.gradient_checkpointing and self.training:
                # Arguments are passed positionally, in the layer's parameter
                # order. The reentrant torch.utils.checkpoint variant rejects
                # keyword arguments ("Unexpected keyword arguments").
                layer_outputs = self._gradient_checkpointing_func(
                    encoder_layer.__call__,
                    hidden_states,
                    attention_mask,
                    causal_attention_mask,
                    condition,
                    output_attentions,
                )
            else:
                layer_outputs = encoder_layer(
                    hidden_states,
                    attention_mask,
                    causal_attention_mask,
                    condition=condition,
                    output_attentions=output_attentions,
                )

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)

        if output_hidden_states:
            encoder_states = encoder_states + (hidden_states,)

        if not return_dict:
            return tuple(
                v
                for v in [hidden_states, encoder_states, all_attentions]
                if v is not None
            )
        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=encoder_states,
            attentions=all_attentions,
        )
class ConditionalCLIPVisionTransformer(CLIPVisionTransformer):
    """CLIP vision tower whose encoder blocks are modulated by ``condition``.

    Pipeline: patch/position embeddings -> ``pre_layrnorm`` ->
    ``ConditionalCLIPEncoder`` -> ``post_layernorm`` applied to the first token
    to produce the pooled output.
    """

    def __init__(self, config: ConditionalCLIPVisionConfig):
        super().__init__(config)
        self.config = config
        width = config.hidden_size
        eps = config.layer_norm_eps

        # Submodules are (re)built in this order: embeddings, pre-norm, encoder,
        # post-norm. "pre_layrnorm" (sic) is kept as-is; presumably it matches
        # the upstream CLIP parameter names.
        self.embeddings = CLIPVisionEmbeddings(config)
        self.pre_layrnorm = nn.LayerNorm(width, eps=eps)
        self.encoder = ConditionalCLIPEncoder(config)
        self.post_layernorm = nn.LayerNorm(width, eps=eps)

    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        condition: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        """Encode ``pixel_values`` with every encoder block modulated by ``condition``.

        Args:
            pixel_values: input images; required.
            condition: modulation vector forwarded to every encoder layer.
            output_attentions: defaults to ``self.config.output_attentions``.
            output_hidden_states: defaults to ``self.config.output_hidden_states``.
            return_dict: defaults to ``self.config.use_return_dict``.

        Returns:
            ``BaseModelOutputWithPooling``, or the tuple
            ``(last_hidden_state, pooled_output, *extra encoder outputs)``.

        Raises:
            ValueError: if ``pixel_values`` is ``None``.
        """
        cfg = self.config
        if output_attentions is None:
            output_attentions = cfg.output_attentions
        if output_hidden_states is None:
            output_hidden_states = cfg.output_hidden_states
        if return_dict is None:
            return_dict = cfg.use_return_dict

        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        embedded = self.pre_layrnorm(self.embeddings(pixel_values))
        encoded = self.encoder(
            inputs_embeds=embedded,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            condition=condition,
            return_dict=return_dict,
        )

        tokens = encoded[0]
        # Pool by normalising the first token (presumably the class token).
        pooled = self.post_layernorm(tokens[:, 0, :])

        if return_dict:
            return BaseModelOutputWithPooling(
                last_hidden_state=tokens,
                pooler_output=pooled,
                hidden_states=encoded.hidden_states,
                attentions=encoded.attentions,
            )
        return (tokens, pooled) + encoded[1:]
class ConditionalCLIPVisionModel(CLIPVisionModel):
    """``CLIPVisionModel`` whose ``vision_model`` is a ``ConditionalCLIPVisionTransformer``.

    ``forward`` accepts an extra ``condition`` tensor and hands it to the
    vision tower together with the usual output options.
    """

    config_class = ConditionalCLIPVisionConfig

    def __init__(self, config: ConditionalCLIPVisionConfig):
        super().__init__(config)
        # Replace the parent's vision tower with the conditional one.
        self.vision_model = ConditionalCLIPVisionTransformer(config)
        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        condition: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        """Delegate to ``self.vision_model``, resolving ``return_dict`` from the config first."""
        if return_dict is None:
            return_dict = self.config.use_return_dict

        tower_kwargs = dict(
            pixel_values=pixel_values,
            condition=condition,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        return self.vision_model(**tower_kwargs)
class ConditionalCLIPModel(CLIPModel):
    """CLIP text/image model whose vision tower is condition-modulated.

    Same structure as ``CLIPModel`` (text tower, vision tower, two linear
    projections, learnable ``logit_scale``). The vision tower is a
    ``ConditionalCLIPVisionTransformer``, and the image-side entry points
    accept a ``condition`` tensor that is forwarded to it.
    """

    config_class = CLIPConfig

    def __init__(self, config: CLIPConfig):
        super().__init__(config)

        if not isinstance(config.text_config, CLIPTextConfig):
            raise ValueError(
                "config.text_config is expected to be of type CLIPTextConfig but is of type"
                f" {type(config.text_config)}."
            )

        if not isinstance(config.vision_config, CLIPVisionConfig):
            raise ValueError(
                "config.vision_config is expected to be of type CLIPVisionConfig but is of type"
                f" {type(config.vision_config)}."
            )

        text_config = config.text_config
        vision_config = config.vision_config

        self.projection_dim = config.projection_dim
        self.text_embed_dim = text_config.hidden_size
        self.vision_embed_dim = vision_config.hidden_size

        # NOTE(review): the parent constructor presumably already built these
        # submodules. They are re-created here so that the vision tower is the
        # conditional variant.
        self.text_model = CLIPTextTransformer(text_config)
        self.vision_model = ConditionalCLIPVisionTransformer(vision_config)

        self.visual_projection = nn.Linear(
            self.vision_embed_dim, self.projection_dim, bias=False
        )
        self.text_projection = nn.Linear(
            self.text_embed_dim, self.projection_dim, bias=False
        )
        self.logit_scale = nn.Parameter(
            torch.tensor(self.config.logit_scale_init_value)
        )

        # Initialize weights and apply final processing
        self.post_init()

    @staticmethod
    def _clip_loss(similarity: torch.Tensor) -> torch.Tensor:
        """Symmetric CLIP contrastive loss.

        Average of the cross-entropy over rows (text -> image) and over
        columns (image -> text) of ``similarity``. The target for row ``i``
        is column ``i``, i.e. the i-th text is paired with the i-th image.
        """

        def _contrastive(logits: torch.Tensor) -> torch.Tensor:
            # Row i's correct class is index i.
            targets = torch.arange(len(logits), device=logits.device)
            return nn.functional.cross_entropy(logits, targets)

        caption_loss = _contrastive(similarity)
        image_loss = _contrastive(similarity.t())
        return (caption_loss + image_loss) / 2.0

    def get_image_features(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        condition: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> torch.FloatTensor:
        """Project the pooled output of the conditional vision tower.

        Returns:
            ``visual_projection(pooled_output)``; not L2-normalised.
        """
        # Use CLIP model's config for some fields (if specified) instead of those of vision & text components.
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            condition=condition,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        pooled_output = vision_outputs[1]  # pooled_output
        image_features = self.visual_projection(pooled_output)
        return image_features

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        condition: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        return_loss: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CLIPOutput]:
        """Compute text-image similarity logits, and optionally the contrastive loss.

        Both embeddings are projected, L2-normalised and compared by a dot
        product scaled by ``exp(self.logit_scale)``.

        Returns:
            ``CLIPOutput``, or the tuple ``(logits_per_image, logits_per_text,
            text_embeds, image_embeds, text_outputs, vision_outputs)``, prefixed
            with ``loss`` when ``return_loss`` is true.
        """
        # Use CLIP model's config for some fields (if specified) instead of those of vision & text components.
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            condition=condition,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        text_outputs = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        image_embeds = vision_outputs[1]
        image_embeds = self.visual_projection(image_embeds)

        text_embeds = text_outputs[1]
        text_embeds = self.text_projection(text_embeds)

        # normalized features
        image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
        text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)

        # cosine similarity as logits
        logit_scale = self.logit_scale.exp()
        logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale
        logits_per_image = logits_per_text.t()

        loss = None
        if return_loss:
            # Uses the class-local helper; no module-level ``clip_loss`` is in scope in this file.
            loss = self._clip_loss(logits_per_text)

        if not return_dict:
            output = (
                logits_per_image,
                logits_per_text,
                text_embeds,
                image_embeds,
                text_outputs,
                vision_outputs,
            )
            return ((loss,) + output) if loss is not None else output

        return CLIPOutput(
            loss=loss,
            logits_per_image=logits_per_image,
            logits_per_text=logits_per_text,
            text_embeds=text_embeds,
            image_embeds=image_embeds,
            text_model_output=text_outputs,
            vision_model_output=vision_outputs,
        )