```python
import torch
from torch import nn
from transformers import MT5EncoderModel, MT5PreTrainedModel


class MT5EncoderWithProjection(MT5PreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.mt5_encoder = MT5EncoderModel(config)
        # Bias-free linear head applied to the pooled sentence embedding.
        self.projection = nn.Linear(config.d_model, config.d_model, bias=False)
        self.post_init()

    def forward(self, **input_args):
        # Token-level encoder outputs, shape (batch, seq_len, d_model).
        hidden_states = self.mt5_encoder(**input_args).last_hidden_state
        # Masked mean pooling: zero out padding positions, then divide by
        # the number of real tokens in each sequence.
        mask = input_args["attention_mask"]
        batch_embeddings = torch.sum(
            hidden_states * mask[:, :, None], dim=1
        ) / torch.sum(mask, dim=1)[:, None]
        # Project the pooled embedding back into d_model space.
        batch_embeddings = self.projection(batch_embeddings)
        return batch_embeddings
```
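To round this off, here is a minimal usage sketch. The checkpoint name `google/mt5-small` and the two example sentences are illustrative assumptions, not part of the original snippet. Note that the projection layer is freshly initialized by `post_init`, so its output is only meaningful after the head has been trained; the sketch therefore loads pretrained weights into the encoder submodule directly rather than calling `from_pretrained` on the wrapper, whose `mt5_encoder.*` parameter names would not match the stock checkpoint.

```python
from transformers import AutoTokenizer, MT5Config

# Assumed checkpoint for illustration; swap in your own fine-tuned weights.
checkpoint = "google/mt5-small"

# Build the wrapper from the config, then load pretrained encoder weights
# into its encoder submodule (the projection stays randomly initialized).
config = MT5Config.from_pretrained(checkpoint)
model = MT5EncoderWithProjection(config)
model.mt5_encoder = MT5EncoderModel.from_pretrained(checkpoint)
model.eval()

tokenizer = AutoTokenizer.from_pretrained(checkpoint)
batch = tokenizer(
    ["A short example sentence.", "A second, slightly longer example sentence."],
    padding=True,
    return_tensors="pt",
)

# The tokenizer output provides input_ids and attention_mask, which is
# exactly what forward(**input_args) expects.
with torch.no_grad():
    embeddings = model(**batch)

print(embeddings.shape)  # (2, config.d_model)
```

Padding makes the masked mean pooling in `forward` matter: without the mask, the pad positions of the shorter sentence would drag its average embedding toward the padding representation.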