# -*- coding: utf-8 -*-
# @Time   : 2023/4/05 18:02 PM
# @Author : NuoChen
# @File   : code_generation.py
from transformers import PLBartTokenizer, PLBartForSequenceClassification, PLBartConfig, PLBartForConditionalGeneration
from typing import Any, Dict, List, Optional, Tuple, Union
from transformers.modeling_outputs import (
    BaseModelOutput,
    BaseModelOutputWithPastAndCrossAttentions,
    CausalLMOutputWithCrossAttentions,
    Seq2SeqLMOutput,
    Seq2SeqModelOutput,
    Seq2SeqSequenceClassifierOutput,
)
import torch
from torch import nn
from torch.nn import CrossEntropyLoss, MSELoss, BCEWithLogitsLoss
from transformers import RobertaModel, RobertaPreTrainedModel
from transformers.models.plbart.modeling_plbart import PLBartPreTrainedModel, PLBartModel
from transformers.models.plbart.configuration_plbart import PLBartConfig


def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int):
    """
    Shift input ids one token to the right, and wrap the last non-pad token (the <LID> token). Note that PLBart,
    like MBart, does not have a single `decoder_start_token_id`, in contrast to other Bart-like models.
    """
    prev_output_tokens = input_ids.clone()

    if pad_token_id is None:
        raise ValueError("self.model.config.pad_token_id has to be defined.")
    # replace possible -100 values in labels by `pad_token_id`
    prev_output_tokens.masked_fill_(prev_output_tokens == -100, pad_token_id)

    # position of the last non-pad token in each row (the language-id / EOS token)
    index_of_eos = (prev_output_tokens.ne(pad_token_id).sum(dim=1) - 1).unsqueeze(-1)
    decoder_start_tokens = prev_output_tokens.gather(1, index_of_eos).squeeze()
    prev_output_tokens[:, 1:] = prev_output_tokens[:, :-1].clone()
    prev_output_tokens[:, 0] = decoder_start_tokens
    return prev_output_tokens
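# Illustrative trace of shift_tokens_right (not part of the original file; the token ids
# below are made up and pad_token_id=1 is an assumption):
#
#   >>> shift_tokens_right(torch.tensor([[5, 6, 7, 2, 1]]), pad_token_id=1)
#   tensor([[2, 5, 6, 7, 2]])
#
# The last non-pad token (here 2, sitting in the <LID>/EOS position) is wrapped around to
# position 0 and every other token moves one step to the right.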
class PLBARTForCodeGeneration(PLBartPreTrainedModel):
    base_model_prefix = "model"
    _keys_to_ignore_on_load_missing = [
        r"final_logits_bias",
        r"encoder.version",
        r"decoder.version",
        r"lm_head.weight",
    ]

    def __init__(self, config: PLBartConfig):
        super().__init__(config)
        self.model = PLBartModel(config)
        self.register_buffer("final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings)))
        self.lm_head = nn.Linear(config.d_model, self.model.shared.num_embeddings, bias=False)

        self.init_weights()

    def get_encoder(self):
        return self.model.get_encoder()

    def get_decoder(self):
        return self.model.get_decoder()

    def resize_token_embeddings(self, new_num_tokens: int) -> nn.Embedding:
        new_embeddings = super().resize_token_embeddings(new_num_tokens)
        self._resize_final_logits_bias(new_num_tokens)
        return new_embeddings

    def _resize_final_logits_bias(self, new_num_tokens: int) -> None:
        old_num_tokens = self.final_logits_bias.shape[-1]
        if new_num_tokens <= old_num_tokens:
            new_bias = self.final_logits_bias[:, :new_num_tokens]
        else:
            extra_bias = torch.zeros((1, new_num_tokens - old_num_tokens), device=self.final_logits_bias.device)
            new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1)
        self.register_buffer("final_logits_bias", new_bias)

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        decoder_head_mask: Optional[torch.LongTensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        encoder_outputs: Optional[List[torch.FloatTensor]] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], Seq2SeqLMOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked); the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Returns:
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if labels is not None:
            if decoder_input_ids is None:
                decoder_input_ids = shift_tokens_right(labels, self.config.pad_token_id)

        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            encoder_outputs=encoder_outputs,
            decoder_attention_mask=decoder_attention_mask,
            head_mask=head_mask,
            decoder_head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            decoder_inputs_embeds=decoder_inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        lm_logits = self.lm_head(outputs[0]) + self.final_logits_bias

        masked_lm_loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1))

        if not return_dict:
            output = (lm_logits,) + outputs[1:]
            return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output

        return Seq2SeqLMOutput(
            loss=masked_lm_loss,
            logits=lm_logits,
            past_key_values=outputs.past_key_values,
            decoder_hidden_states=outputs.decoder_hidden_states,
            decoder_attentions=outputs.decoder_attentions,
            cross_attentions=outputs.cross_attentions,
            encoder_last_hidden_state=outputs.encoder_last_hidden_state,
            encoder_hidden_states=outputs.encoder_hidden_states,
            encoder_attentions=outputs.encoder_attentions,
        )
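    # Training sketch (illustrative, not from the original file). With `labels` supplied and
    # `decoder_input_ids` omitted, the labels are shifted via `shift_tokens_right` to build the
    # decoder inputs (teacher forcing) and a token-level cross-entropy loss is returned:
    #
    #   out = model(input_ids=batch["input_ids"],
    #               attention_mask=batch["attention_mask"],
    #               labels=batch["labels"])   # `batch` is an assumed tokenized source/target code pair
    #   out.loss.backward()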
    def prepare_inputs_for_generation(
        self,
        decoder_input_ids: torch.LongTensor,
        past: Optional[List[torch.FloatTensor]] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        decoder_head_mask: Optional[torch.Tensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        encoder_outputs: Optional[List[torch.FloatTensor]] = None,
        **kwargs,  # TODO: Check if this is needed. It is unused?
    ) -> Dict[str, Any]:
        # cut decoder_input_ids if past is used
        if past is not None:
            decoder_input_ids = decoder_input_ids[:, -1:]

        return {
            "input_ids": None,  # encoder_outputs is defined. input_ids not needed
            "encoder_outputs": encoder_outputs,
            "past_key_values": past,
            "decoder_input_ids": decoder_input_ids,
            "attention_mask": attention_mask,
            "head_mask": head_mask,
            "decoder_head_mask": decoder_head_mask,
            "cross_attn_head_mask": cross_attn_head_mask,
            "use_cache": use_cache,  # change this to avoid caching (presumably for debugging)
        }

    def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
        return shift_tokens_right(labels, self.config.pad_token_id)

    @staticmethod
    def _reorder_cache(past, beam_idx):
        reordered_past = ()
        for layer_past in past:
            # cached cross_attention states don't have to be reordered -> they are always the same
            reordered_past += (
                tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
            )
        return reordered_past
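

if __name__ == "__main__":
    # Minimal generation sketch, not part of the original file. The `uclanlp/plbart-base`
    # checkpoint, the `java`/`python` language codes, and the `__python__` code token are
    # assumptions; swap in a fine-tuned checkpoint for real code generation. Exact
    # `generate()` behaviour also depends on the installed transformers version.
    tokenizer = PLBartTokenizer.from_pretrained("uclanlp/plbart-base", src_lang="java", tgt_lang="python")
    model = PLBARTForCodeGeneration.from_pretrained("uclanlp/plbart-base")

    source = "public static int add(int a, int b) { return a + b; }"
    inputs = tokenizer(source, return_tensors="pt")
    generated_ids = model.generate(
        **inputs,
        # PLBart decodes with the target-language code as the first token; "__python__" is the
        # code token assumed here for Python output.
        decoder_start_token_id=tokenizer.convert_tokens_to_ids("__python__"),
        max_length=64,
        num_beams=4,
    )
    print(tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0])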