from typing import ClassVar

import torch
from torch import nn

from modeling_florence2 import Florence2ForConditionalGeneration, Florence2VisionLanguageModel
from configuration_florence2 import Florence2Config


class ColFlor2Old(Florence2ForConditionalGeneration):
    """
    ColFlor2 model implementation from the "ColPali: Efficient Document Retrieval with Vision Language Models" paper.
    """

    main_input_name: ClassVar[str] = "doc_input_ids"  # transformers-related

    def __init__(self, config: Florence2Config, use_cache: bool = False):
        super().__init__(config=config)
        self.dim = 128
        self.custom_text_proj = nn.Linear(self.config.text_config.d_model, self.dim)

        # Initialize the projection head weights explicitly.
        self.custom_text_proj.weight.data.normal_(mean=0.0, std=0.02)
        self.custom_text_proj.bias.data.zero_()

        self.padding_side = "right"
        self.post_init()

    def forward(self, *args, **kwargs) -> torch.Tensor:
        # Drop output_hidden_states from kwargs; only the encoder output is used.
        kwargs.pop("output_hidden_states", None)

        # TO BE DELETED
        kwargs["decoder_input_ids"] = kwargs["input_ids"]

        # Use the full attention mask that also covers the image tokens, if
        # provided; otherwise fall back to the text attention mask.
        if "full_attention_mask" in kwargs:
            full_attention_mask = kwargs.pop("full_attention_mask")
        else:
            full_attention_mask = kwargs["attention_mask"]

        outputs = super().forward(*args, **kwargs)
        last_hidden_states = outputs["encoder_last_hidden_state"]  # (batch_size, sequence_length, hidden_size)

        proj = self.custom_text_proj(last_hidden_states)  # (batch_size, sequence_length, dim)

        # L2 normalization, then zero out padded positions.
        proj = proj / proj.norm(dim=-1, keepdim=True)  # (batch_size, sequence_length, dim)
        proj = proj * full_attention_mask.unsqueeze(-1)  # (batch_size, sequence_length, dim)
        return proj
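

# The forward pass above returns L2-normalized multi-vector embeddings of shape
# (batch_size, sequence_length, dim). In the ColPali paper, queries and documents
# embedded this way are compared with ColBERT-style late interaction (MaxSim):
# for each query token, take the maximum similarity over all document tokens,
# then sum over query tokens. Below is a minimal sketch of that scoring; the
# helper name `maxsim_score` is illustrative and not part of the original code.
def maxsim_score(query_embeddings: torch.Tensor, doc_embeddings: torch.Tensor) -> torch.Tensor:
    """Late-interaction (MaxSim) scores between query and document embeddings.

    query_embeddings: (num_queries, query_len, dim), L2-normalized
    doc_embeddings: (num_docs, doc_len, dim), L2-normalized
    Returns a (num_queries, num_docs) score matrix.
    """
    # Pairwise token similarities: (num_queries, num_docs, query_len, doc_len).
    sims = torch.einsum("qnd,bmd->qbnm", query_embeddings, doc_embeddings)
    # Best-matching document token per query token, summed over query tokens.
    # Padded positions were zeroed in forward, so they contribute nothing here.
    return sims.max(dim=3).values.sum(dim=2)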


class ColFlor(Florence2VisionLanguageModel):
    """
    ColFlor model implementation from the "ColPali: Efficient Document Retrieval with Vision Language Models" paper.
    """

    main_input_name: ClassVar[str] = "doc_input_ids"  # transformers-related

    def __init__(self, config: Florence2Config, use_cache: bool = False):
        super().__init__(config=config)
        self.dim = 128
        self.custom_text_proj = nn.Linear(self.config.text_config.d_model, self.dim)

        # Initialize the projection head weights explicitly.
        self.custom_text_proj.weight.data.normal_(mean=0.0, std=0.02)
        self.custom_text_proj.bias.data.zero_()

        self.padding_side = "right"
        self.post_init()

    def forward(self, *args, **kwargs) -> torch.Tensor:
        # Drop output_hidden_states from kwargs; only the encoder output is used.
        kwargs.pop("output_hidden_states", None)

        # Use the full attention mask that covers both the image and text tokens,
        # if provided; otherwise fall back to the text attention mask.
        if "full_attention_mask" in kwargs:
            full_attention_mask = kwargs.pop("full_attention_mask")
        else:
            full_attention_mask = kwargs["attention_mask"]

        outputs = super().forward(*args, **kwargs)
        last_hidden_states = outputs["encoder_last_hidden_state"]  # (batch_size, sequence_length, hidden_size)

        proj = self.custom_text_proj(last_hidden_states)  # (batch_size, sequence_length, dim)

        # L2 normalization, then zero out padded positions.
        proj = proj / proj.norm(dim=-1, keepdim=True)  # (batch_size, sequence_length, dim)
        proj = proj * full_attention_mask.unsqueeze(-1)  # (batch_size, sequence_length, dim)
        return proj
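

# A minimal end-to-end usage sketch, assuming a ColFlor checkpoint exported with
# this modeling code. The checkpoint path, image file, prompt, and query string
# below are placeholders, and the processor behavior is assumed to match the
# accompanying Florence2 processing code, which is responsible for producing
# input_ids, attention_mask, pixel_values, and the optional full_attention_mask.
if __name__ == "__main__":
    from PIL import Image
    from transformers import AutoProcessor

    model = ColFlor.from_pretrained("path/to/colflor-checkpoint").eval()
    processor = AutoProcessor.from_pretrained("path/to/colflor-checkpoint", trust_remote_code=True)

    # Embed one document page (image plus a dummy text prompt) into multi-vectors.
    page = Image.open("page.png").convert("RGB")
    doc_inputs = processor(text="<OCR>", images=page, return_tensors="pt")
    with torch.no_grad():
        doc_embeddings = model(**doc_inputs)  # (1, doc_len, 128)

    # Embed a text-only query with the underlying tokenizer.
    query_inputs = processor.tokenizer("What was the revenue in 2023?", return_tensors="pt")
    with torch.no_grad():
        query_embeddings = model(**query_inputs)  # (1, query_len, 128)

    # Late-interaction relevance score of the page for the query.
    print(maxsim_score(query_embeddings, doc_embeddings))  # (1, 1)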