from dataclasses import dataclass
from typing import Optional, Tuple

import torch
from torch import nn
from transformers import RobertaPreTrainedModel, XLMRobertaConfig, XLMRobertaModel
from transformers.utils import ModelOutput


@dataclass
class TransformationModelOutput(ModelOutput):
    """
    Base class for text model's outputs that also contains a pooling of the last hidden states.

    Args:
        projection_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, project_dim)`, *optional*):
            The text embeddings obtained by applying the projection layer to the model's hidden states.
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    projection_state: Optional[torch.FloatTensor] = None
    last_hidden_state: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None


class RobertaSeriesConfig(XLMRobertaConfig):
    def __init__(
        self,
        pad_token_id=1,
        bos_token_id=0,
        eos_token_id=2,
        project_dim=512,
        pooler_fn="cls",
        learn_encoder=False,
        use_attention_mask=True,
        **kwargs,
    ):
        super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
        self.project_dim = project_dim
        self.pooler_fn = pooler_fn
        self.learn_encoder = learn_encoder
        self.use_attention_mask = use_attention_mask


class RobertaSeriesModelWithTransformation(RobertaPreTrainedModel):
    _keys_to_ignore_on_load_unexpected = [r"pooler", r"logit_scale"]
    _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
    base_model_prefix = "roberta"
    config_class = RobertaSeriesConfig

    def __init__(self, config):
        super().__init__(config)
        self.roberta = XLMRobertaModel(config)
        # Projects the encoder's hidden states into the text-embedding space of size `project_dim`.
        self.transformation = nn.Linear(config.hidden_size, config.project_dim)
        self.has_pre_transformation = getattr(config, "has_pre_transformation", False)
        if self.has_pre_transformation:
            # Optional second projection applied to the penultimate hidden layer after a LayerNorm.
            self.transformation_pre = nn.Linear(config.hidden_size, config.project_dim)
            self.pre_LN = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
    ):
        r""" """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.base_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
            # The pre-transformation path needs all hidden states to reach the penultimate layer.
            output_hidden_states=True if self.has_pre_transformation else output_hidden_states,
            return_dict=return_dict,
        )

        if self.has_pre_transformation:
            # Project the LayerNorm-ed penultimate hidden layer instead of the final one.
            sequence_output2 = outputs["hidden_states"][-2]
            sequence_output2 = self.pre_LN(sequence_output2)
            projection_state2 = self.transformation_pre(sequence_output2)

            return TransformationModelOutput(
                projection_state=projection_state2,
                last_hidden_state=outputs.last_hidden_state,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
            )
        else:
            # Default path: project the last hidden state directly.
            projection_state = self.transformation(outputs.last_hidden_state)
            return TransformationModelOutput(
                projection_state=projection_state,
                last_hidden_state=outputs.last_hidden_state,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
            )
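

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the upstream module): a minimal example
# of wiring the classes above together, assuming a tiny randomly initialized
# config purely for demonstration. Real use would load a trained checkpoint via
# RobertaSeriesModelWithTransformation.from_pretrained(...).
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Small hypothetical config so the example runs quickly on CPU.
    config = RobertaSeriesConfig(
        vocab_size=100,
        hidden_size=32,
        num_hidden_layers=2,
        num_attention_heads=2,
        intermediate_size=64,
        project_dim=16,
    )
    model = RobertaSeriesModelWithTransformation(config)
    model.eval()

    # Dummy token ids; ids 0/1/2 are reserved for bos/pad/eos in this config.
    input_ids = torch.randint(3, config.vocab_size, (1, 8))
    attention_mask = torch.ones_like(input_ids)

    with torch.no_grad():
        out = model(input_ids=input_ids, attention_mask=attention_mask)

    # projection_state has shape (batch_size, sequence_length, project_dim).
    print(out.projection_state.shape)  # expected: torch.Size([1, 8, 16])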