# -*- coding: utf-8 -*-
# @Time   : 2023/3/11 8:02 AM
# @Author : NuoChen
# @File   : code_classification.py

import copy
import logging
from typing import Optional, List, Union, Tuple

import torch
import torch.nn.functional as F
from torch import nn, Tensor
from torch.nn import CrossEntropyLoss, MSELoss, BCEWithLogitsLoss

from transformers import RobertaModel
from transformers.activations import ACT2FN
from transformers.models.electra import ElectraModel
from transformers.models.roformer import RoFormerModel
from transformers.models.albert import AlbertModel
from transformers.models.bert import BertModel, BertPreTrainedModel
from transformers.models.bert.modeling_bert import BertForSequenceClassification
from transformers.models.deberta_v2 import DebertaV2Model, DebertaV2PreTrainedModel
from transformers.models.megatron_bert import MegatronBertPreTrainedModel, MegatronBertModel
from transformers.models.roberta import RobertaPreTrainedModel
from transformers.models.plbart.modeling_plbart import PLBartPreTrainedModel, PLBartClassificationHead, PLBartModel
from transformers.models.plbart.configuration_plbart import PLBartConfig
from transformers.models.t5.modeling_t5 import T5PreTrainedModel  # , T5ClassificationHead, T5Model
from transformers.models.t5.modeling_t5 import T5Config, T5Stack
from transformers.modeling_outputs import (
    SequenceClassifierOutput,
    Seq2SeqSequenceClassifierOutput,
    SequenceClassifierOutputWithPast,
)
# Required by CodeT5ForCodeClassification.parallelize()/deparallelize() below.
from transformers.utils.model_parallel_utils import assert_device_map, get_device_map

from models.basic_modules.prefix_encoder import PrefixEncoder
from models.basic_modules.adapter import BertAdaModel, RobertaAdaModel, init_adapter
from tools.model_utils.parameter_freeze import ParameterFreeze

# Shared helper used by the models below to (un)freeze the backbone encoder.
freezer = ParameterFreeze()

## ======== Roberta ========
# Vanilla Fine-tuning For Roberta
class RobertaForCodeClassification(RobertaPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config
        self.roberta = RobertaModel(config)
        if self.config.use_freezing:
            self.roberta = freezer.freeze_lm(self.roberta)
        self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)
        self.classifier = torch.nn.Linear(config.hidden_size, config.num_labels)
        self.init_weights()

    def freeze_backbone(self, use_freezing: bool = True):
        if use_freezing:
            self.roberta = freezer.freeze_lm(self.roberta)
        else:
            self.roberta = freezer.unfreeze_lm(self.roberta)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
            Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ...,
            config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
            If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.roberta(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        pooled_output = outputs[1]
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

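# Illustrative usage sketch for the vanilla fine-tuning head above. The
# checkpoint name, label count, and the explicit `use_freezing` flag are
# assumptions made for this example only; they are not values taken from this
# repository's training scripts.
def _demo_roberta_code_classification():
    """Minimal sketch: classify one code snippet with RobertaForCodeClassification."""
    from transformers import RobertaConfig, RobertaTokenizer

    config = RobertaConfig.from_pretrained("roberta-base", num_labels=2)
    config.use_freezing = False  # attribute expected by the constructor above
    tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
    model = RobertaForCodeClassification.from_pretrained("roberta-base", config=config)

    batch = tokenizer(["def add(a, b): return a + b"], return_tensors="pt")
    out = model(**batch, labels=torch.tensor([1]))
    return out.loss, out.logits  # scalar loss, logits of shape (1, num_labels)
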
## ======== CodeBERT ========
# Vanilla Fine-tuning For CodeBERT
class CodeBERTForCodeClassification(RobertaPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config
        self.roberta = RobertaModel(config)
        if self.config.use_freezing:
            self.roberta = freezer.freeze_lm(self.roberta)
        self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)
        self.classifier = torch.nn.Linear(config.hidden_size, config.num_labels)
        self.init_weights()

    def freeze_backbone(self, use_freezing: bool = True):
        if use_freezing:
            self.roberta = freezer.freeze_lm(self.roberta)
        else:
            self.roberta = freezer.unfreeze_lm(self.roberta)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
            Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ...,
            config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
            If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.roberta(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        pooled_output = outputs[1]
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

## ======== GraphCodeBERT ========
# Vanilla Fine-tuning For GraphCodeBERT
class GraphCodeBERTForCodeClassification(RobertaPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config
        self.roberta = RobertaModel(config)
        if self.config.use_freezing:
            self.roberta = freezer.freeze_lm(self.roberta)
        self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)
        self.classifier = torch.nn.Linear(config.hidden_size, config.num_labels)
        self.init_weights()

    def freeze_backbone(self, use_freezing: bool = True):
        if use_freezing:
            self.roberta = freezer.freeze_lm(self.roberta)
        else:
            self.roberta = freezer.unfreeze_lm(self.roberta)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
            Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ...,
            config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
            If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.roberta(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        pooled_output = outputs[1]
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

## ======== PLBART ========
# Vanilla Fine-tuning For PLBART
class PLBARTForCodeClassification(PLBartPreTrainedModel):
    _keys_to_ignore_on_load_missing = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]

    def __init__(self, config: PLBartConfig, **kwargs):
        super().__init__(config, **kwargs)
        self.model = PLBartModel(config)
        self.classification_head = PLBartClassificationHead(
            config.d_model,
            config.d_model,
            config.num_labels,
            config.classifier_dropout,
        )
        self.model._init_weights(self.classification_head.dense)
        self.model._init_weights(self.classification_head.out_proj)

    # Copied from transformers.models.bart.modeling_bart.BartForSequenceClassification.forward
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        decoder_head_mask: Optional[torch.Tensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        encoder_outputs: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, Seq2SeqSequenceClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        if labels is not None:
            use_cache = False

        if input_ids is None and inputs_embeds is not None:
            raise NotImplementedError(
                f"Passing input embeddings is currently not supported for {self.__class__.__name__}"
            )

        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            head_mask=head_mask,
            decoder_head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            encoder_outputs=encoder_outputs,
            inputs_embeds=inputs_embeds,
            decoder_inputs_embeds=decoder_inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = outputs[0]  # last hidden state

        eos_mask = input_ids.eq(self.config.eos_token_id).to(hidden_states.device)
        if len(torch.unique_consecutive(eos_mask.sum(1))) > 1:
            raise ValueError("All examples must have the same number of <eos> tokens.")
        sentence_representation = hidden_states[eos_mask, :].view(hidden_states.size(0), -1, hidden_states.size(-1))[
            :, -1, :
        ]
        logits = self.classification_head(sentence_representation)

        loss = None
        if labels is not None:
            if self.config.problem_type is None:
                if self.config.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.config.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.config.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return Seq2SeqSequenceClassifierOutput(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            decoder_hidden_states=outputs.decoder_hidden_states,
            decoder_attentions=outputs.decoder_attentions,
            cross_attentions=outputs.cross_attentions,
            encoder_last_hidden_state=outputs.encoder_last_hidden_state,
            encoder_hidden_states=outputs.encoder_hidden_states,
            encoder_attentions=outputs.encoder_attentions,
        )

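# Illustrative usage sketch for the PLBART head above. The `uclanlp/plbart-base`
# checkpoint and the label count are assumptions made for this example only.
# Since the head pools the decoder hidden state at the final <eos> token, inputs
# should be produced by the PLBART tokenizer with its default special tokens.
def _demo_plbart_code_classification():
    """Minimal sketch: classify one code snippet with PLBARTForCodeClassification."""
    from transformers import PLBartTokenizer

    tokenizer = PLBartTokenizer.from_pretrained("uclanlp/plbart-base")
    model = PLBARTForCodeClassification.from_pretrained("uclanlp/plbart-base", num_labels=2)

    batch = tokenizer(["def add(a, b): return a + b"], return_tensors="pt")
    out = model(**batch, labels=torch.tensor([1]))
    return out.loss, out.logits  # scalar loss, logits of shape (1, num_labels)
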
## ======== CodeT5 ========
# Vanilla Fine-tuning For CodeT5
class CodeT5ForCodeClassification(T5PreTrainedModel):
    _keys_to_ignore_on_load_missing = [r"encoder.embed_tokens.weight"]

    def __init__(self, config: T5Config):
        super().__init__(config)
        self.model_dim = config.d_model
        self.config.problem_type = None
        self.config.is_encoder_decoder = False

        self.shared = nn.Embedding(config.vocab_size, config.d_model)

        # Encoder-only stack: the T5 decoder is not needed for classification.
        encoder_config = copy.deepcopy(config)
        encoder_config.is_decoder = False
        encoder_config.is_encoder_decoder = False
        encoder_config.use_cache = False
        self.encoder = T5Stack(encoder_config, self.shared)

        classifier_dropout = (
            config.classifier_dropout if hasattr(config, "classifier_dropout") else config.dropout_rate
        )
        self.dropout = nn.Dropout(classifier_dropout)
        self.classifier = nn.Linear(config.d_model, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

        # Model parallel
        self.model_parallel = False
        self.device_map = None

    def parallelize(self, device_map=None):
        self.device_map = (
            get_device_map(len(self.encoder.block), range(torch.cuda.device_count()))
            if device_map is None
            else device_map
        )
        assert_device_map(self.device_map, len(self.encoder.block))
        self.encoder.parallelize(self.device_map)
        self.classifier.to(self.encoder.first_device)
        self.model_parallel = True

    def deparallelize(self):
        self.encoder.deparallelize()
        self.encoder = self.encoder.to("cpu")
        self.classifier = self.classifier.to("cpu")
        self.model_parallel = False
        self.device_map = None
        torch.cuda.empty_cache()

    def get_input_embeddings(self):
        return self.shared

    def set_input_embeddings(self, new_embeddings):
        self.shared = new_embeddings
        self.encoder.set_input_embeddings(new_embeddings)

    def get_encoder(self):
        return self.encoder

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.block[layer].layer[0].SelfAttention.prune_heads(heads)

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.FloatTensor], SequenceClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
        Returns:
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Get the index of the last non-padding token for each sequence.
        # (batch_size) -> (batch_size, 1) -> (batch_size, hidden_size) -> (batch_size, 1, hidden_size)
        last_hidden_indices = (
            (input_ids != self.config.pad_token_id).sum(dim=-1) - 1
        ).unsqueeze(dim=-1).repeat(1, outputs[0].size(-1)).unsqueeze(1)
        sequence_output = outputs[0].gather(dim=1, index=last_hidden_indices).squeeze(1)

        sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output)

        loss = None
        if labels is not None:
            if self.config.problem_type is None:
                if self.config.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.config.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = nn.MSELoss()
                if self.config.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = nn.CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = nn.BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

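# Illustrative usage sketch for the encoder-only CodeT5 head above. The
# `Salesforce/codet5-base` checkpoint and the label count are assumptions made
# for this example only; decoder and LM-head weights from the checkpoint are
# simply left unused.
def _demo_codet5_code_classification():
    """Minimal sketch: classify one code snippet with CodeT5ForCodeClassification."""
    from transformers import RobertaTokenizer  # CodeT5 ships a Roberta-style BPE tokenizer

    tokenizer = RobertaTokenizer.from_pretrained("Salesforce/codet5-base")
    model = CodeT5ForCodeClassification.from_pretrained("Salesforce/codet5-base", num_labels=2)

    batch = tokenizer(["def add(a, b): return a + b"], return_tensors="pt", padding=True)
    out = model(
        input_ids=batch["input_ids"],
        attention_mask=batch["attention_mask"],
        labels=torch.tensor([1]),
    )
    return out.loss, out.logits  # scalar loss, logits of shape (1, num_labels)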