| """Custom models for few-shot learning specific operations.""" | |
| import torch | |
| import torch.nn as nn | |
| import transformers | |
| import torch.nn.functional as F | |
| from transformers import AutoConfig, AutoModelForSequenceClassification, AutoTokenizer, EvalPrediction | |
| from transformers.models.t5.modeling_t5 import T5ForConditionalGeneration | |
| from transformers.models.bert.modeling_bert import BertPreTrainedModel, BertForSequenceClassification, BertModel, BertOnlyMLMHead | |
| from transformers.models.roberta.modeling_roberta import RobertaForSequenceClassification, RobertaModel, RobertaLMHead, RobertaClassificationHead, RobertaPreTrainedModel | |
| from transformers.models.deberta_v2.modeling_deberta_v2 import DebertaV2PreTrainedModel, DebertaV2Model, StableDropout, ContextPooler, DebertaV2OnlyMLMHead | |
| from transformers.models.deberta.modeling_deberta import DebertaPreTrainedModel, DebertaModel, StableDropout, ContextPooler, DebertaOnlyMLMHead | |
| from transformers.modeling_outputs import SequenceClassifierOutput | |
| from transformers.modeling_utils import PreTrainedModel | |
| from transformers.models.bert.configuration_bert import BertConfig | |
| import logging | |
| from models.basic_modules.adapter import RobertaAdaModel, BertAdaModel | |
| import os | |
| from models.basic_modules.prefix_encoder import PrefixEncoder | |
| from tools.model_utils.parameter_freeze import ParameterFreeze | |
| freezer = ParameterFreeze() | |
| logger = logging.getLogger(__name__) | |
| # Note: if mask_pos is None, check whether the input template contains a <mask> token and whether the data_collator file has been modified | |
| """ | |
| Vanilla Prompt-tuning BERT | |
| """ | |
| class PromptBertForSequenceClassification(BertPreTrainedModel): | |
| def __init__(self, config): | |
| super().__init__(config) | |
| self.num_labels = config.num_labels | |
| self.pre_seq_len = self.config.pre_seq_len | |
| self.hidden_size = self.config.hidden_size | |
| # backbone | |
| self.bert = BertModel(config) | |
| if self.config.use_freezing: | |
| self.bert = freezer.freeze_lm(self.bert) | |
| # mlm head | |
| self.cls = BertOnlyMLMHead(config) | |
| self.init_weights() | |
| # These attributes should be assigned once the model is initialized | |
| self.label_word_list = torch.Tensor(self.config.label_word_list).long().to(self.bert.device) | |
| # For regression | |
| self.lb = None | |
| self.ub = None | |
| # For label search. | |
| self.return_full_softmax = None | |
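| # Illustrative usage sketch (not from the original repo; assumes the training script has already | |
| # attached the custom fields pre_seq_len, use_freezing and label_word_list to the config, and that | |
| # `tokenizer` was built with AutoTokenizer for the same checkpoint): | |
| #   config = AutoConfig.from_pretrained("bert-base-uncased", num_labels=2) | |
| #   config.pre_seq_len, config.use_freezing = 16, True | |
| #   config.label_word_list = tokenizer.convert_tokens_to_ids(["terrible", "great"])  # verbalizer ids | |
| #   model = PromptBertForSequenceClassification.from_pretrained("bert-base-uncased", config=config) | |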
| def freeze_backbone(self, use_freezing: bool=True): | |
| if use_freezing: | |
| self.bert = freezer.freeze_lm(self.bert) | |
| else: | |
| self.bert = freezer.unfreeze_lm(self.bert) | |
| def encode(self, input_ids=None, attention_mask=None, token_type_ids=None, mask_pos=None, inputs_embeds=None, return_full_softmax=False): | |
| """ | |
| Encoding and obtain logits at masked position | |
| """ | |
| if mask_pos is not None: | |
| mask_pos = mask_pos.squeeze() | |
| # Encode everything | |
| if inputs_embeds is None: | |
| outputs = self.bert( | |
| input_ids, | |
| attention_mask=attention_mask, | |
| token_type_ids=token_type_ids | |
| ) | |
| else: | |
| outputs = self.bert( | |
| None, | |
| attention_mask=attention_mask, | |
| token_type_ids=token_type_ids, | |
| inputs_embeds=inputs_embeds | |
| ) | |
| # Get <mask> token representation | |
| sequence_output, pooled_output = outputs[:2] | |
| sequence_mask_output = sequence_output[torch.arange(sequence_output.size(0)), mask_pos] | |
| # Logits over vocabulary tokens | |
| prediction_mask_scores = self.cls(sequence_mask_output) | |
| # Exit early and only return mask logits. | |
| if return_full_softmax: | |
| return prediction_mask_scores | |
| # Return logits for each label | |
| logits = [] | |
| for label_id in range(len(self.label_word_list)): | |
| logits.append(prediction_mask_scores[:, self.label_word_list[label_id]].unsqueeze(-1)) | |
| logits = torch.cat(logits, -1) | |
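| # logits: [batch_size, num_label_words] -- mask-position scores gathered at the verbalizer token ids | |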
| # Regression task | |
| if self.config.num_labels == 1: | |
| logsoftmax = nn.LogSoftmax(-1) | |
| logits = logsoftmax(logits) # Log prob of right polarity | |
| return logits, sequence_mask_output | |
| def forward( | |
| self, | |
| input_ids=None, | |
| attention_mask=None, | |
| token_type_ids=None, | |
| mask_pos=None, | |
| labels=None, | |
| inputs_embeds=None, | |
| block_flag=None, | |
| return_dict=None, | |
| ): | |
| logits, sequence_mask_output = self.encode(input_ids, attention_mask, token_type_ids, mask_pos, inputs_embeds) | |
| loss = None | |
| if labels is not None: | |
| if self.num_labels == 1: | |
| # Regression task | |
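| # The gold score is min-max scaled into [0, 1] with the (lb, ub) bounds and stacked into a | |
| # two-way soft target, which is compared against the two label-word logits with a KL loss. | |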
| loss_fct = nn.KLDivLoss(log_target=True) | |
| labels = torch.stack([1 - (labels.view(-1) - self.lb) / (self.ub - self.lb), (labels.view(-1) - self.lb) / (self.ub - self.lb)], -1) | |
| loss = loss_fct(logits.view(-1, 2), labels) | |
| else: | |
| if labels.shape == logits.shape: | |
| loss = F.kl_div(F.log_softmax(logits, dim=-1, dtype=torch.float32), | |
| labels, reduction="batchmean") | |
| else: | |
| loss_fct = nn.CrossEntropyLoss() | |
| loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1)) | |
| output = (logits,) | |
| if self.num_labels == 1: | |
| # Regression output | |
| output = (torch.exp(logits[..., 1].unsqueeze(-1)) * (self.ub - self.lb) + self.lb,) | |
| if not return_dict: | |
| return ((loss,) + output) if loss is not None else output | |
| return SequenceClassifierOutput( | |
| loss=loss, | |
| logits=logits, | |
| ) | |
| """ | |
| P-tuning BERT | |
| """ | |
| class PromptBertPtuningForSequenceClassification(BertPreTrainedModel): | |
| def __init__(self, config): | |
| super().__init__(config) | |
| self.num_labels = config.num_labels | |
| self.pre_seq_len = self.config.pre_seq_len | |
| self.hidden_size = self.config.hidden_size | |
| # backbone | |
| self.bert = BertModel(config) | |
| if self.config.use_freezing: | |
| self.bert = freezer.freeze_lm(self.bert) | |
| # mlm head | |
| self.cls = BertOnlyMLMHead(config) | |
| # prompt encoder | |
| self.prompt_encoder = None | |
| # plm embedding layer | |
| self.backbone_embeddings = self.bert.embeddings.word_embeddings | |
| # prompt embedding layer | |
| self.prompt_embeddings = torch.nn.Embedding(self.pre_seq_len, self.hidden_size) | |
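| # learnable prompt table: pre_seq_len continuous prompt vectors of hidden_size dimensions each | |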
| self.init_weights() | |
| # These attributes should be assigned once the model is initialized | |
| self.label_word_list = torch.Tensor(self.config.label_word_list).long().to(self.bert.device) | |
| # For regression | |
| self.lb = None | |
| self.ub = None | |
| # For label search. | |
| self.return_full_softmax = None | |
| def freeze_backbone(self, use_freezing: bool=True): | |
| if use_freezing: | |
| self.bert = freezer.freeze_lm(self.bert) | |
| else: | |
| self.bert = freezer.unfreeze_lm(self.bert) | |
| def generate_continuous_prompt_inputs(self, input_ids, block_flag=None, reparameterization=False): | |
| """ | |
| Generate continuous prompt embedding | |
| """ | |
| inputs_embeds = self.backbone_embeddings(input_ids) | |
| batch_size = inputs_embeds.shape[0] | |
| if block_flag is None: | |
| # the first token is set to 1, the others to 0 | |
| block_flag = torch.zeros_like(input_ids).long().to(inputs_embeds.device) | |
| block_flag[:, 0] = 1 | |
| try: | |
| replace_embeds = self.prompt_embeddings( | |
| torch.LongTensor(list(range(self.pre_seq_len))).to(inputs_embeds.device)) | |
| except: | |
| import pdb | |
| pdb.set_trace() | |
| replace_embeds = self.prompt_embeddings( | |
| torch.LongTensor(list(range(self.pre_seq_len)))) | |
| replace_embeds = replace_embeds.unsqueeze(0) # [1, pre_seq_len, hidden_size] | |
| if self.prompt_encoder is not None: | |
| replace_embeds = self.prompt_encoder(replace_embeds) | |
| # edit by wjn | |
| if reparameterization: | |
| # requires block_flag to mark exactly pre_seq_len positions per example | |
| blocked_indices = (block_flag == 1).nonzero(as_tuple=False).reshape((batch_size, self.pre_seq_len, 2))[:, :, 1] | |
| # reparameterization: overwrite the flagged input positions with the prompt embeddings | |
| for bidx in range(batch_size): | |
| for i in range(blocked_indices.shape[1]): | |
| inputs_embeds[bidx, blocked_indices[bidx, i], :] = replace_embeds[:, i, :].squeeze() | |
| else: | |
| replace_embeds = replace_embeds.expand(batch_size, self.pre_seq_len, -1).to(inputs_embeds.device) | |
| inputs_embeds = torch.cat((replace_embeds, inputs_embeds), dim=1) | |
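| # prepending the prompt grows the sequence to pre_seq_len + seq_len; encode() below extends | |
| # attention_mask and token_type_ids with a matching prefix before calling the backbone | |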
| return inputs_embeds | |
| def encode(self, input_ids=None, attention_mask=None, token_type_ids=None, mask_pos=None, inputs_embeds=None, return_full_softmax=False): | |
| """ | |
| Encoding and obtain logits at masked position | |
| """ | |
| batch_size = inputs_embeds.shape[0] if inputs_embeds is not None else input_ids.shape[0] | |
| if mask_pos is not None: | |
| mask_pos = mask_pos.squeeze() | |
| # Encode everything | |
| if inputs_embeds is None: | |
| outputs = self.bert( | |
| input_ids, | |
| attention_mask=attention_mask, | |
| token_type_ids=token_type_ids | |
| ) | |
| else: | |
| if inputs_embeds.shape[1] == attention_mask.shape[1]: | |
| outputs = self.bert( | |
| None, | |
| attention_mask=attention_mask, | |
| token_type_ids=token_type_ids, | |
| inputs_embeds=inputs_embeds | |
| ) | |
| # Get <mask> token representation | |
| sequence_output, pooled_output = outputs[:2] | |
| # sequence_output = sequence_output[:, self.pre_seq_len:, :].contiguous() | |
| else: | |
| if attention_mask is not None: | |
| prefix_attention_mask = torch.ones(batch_size, self.pre_seq_len).long().to(self.bert.device) | |
| attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1) | |
| if token_type_ids is not None: | |
| prefix_token_type_ids = torch.zeros(batch_size, self.pre_seq_len).long().to(self.bert.device) | |
| token_type_ids = torch.cat((prefix_token_type_ids, token_type_ids), dim=1) | |
| outputs = self.bert( | |
| None, | |
| attention_mask=attention_mask, | |
| token_type_ids=token_type_ids, | |
| inputs_embeds=inputs_embeds | |
| ) | |
| # Get <mask> token representation | |
| sequence_output, pooled_output = outputs[:2] | |
| sequence_output = sequence_output[:, self.pre_seq_len:, :].contiguous() | |
| sequence_mask_output = sequence_output[torch.arange(sequence_output.size(0)), mask_pos] | |
| # Logits over vocabulary tokens | |
| prediction_mask_scores = self.cls(sequence_mask_output) | |
| # Exit early and only return mask logits. | |
| if return_full_softmax: | |
| return prediction_mask_scores | |
| # Return logits for each label | |
| logits = [] | |
| for label_id in range(len(self.label_word_list)): | |
| logits.append(prediction_mask_scores[:, self.label_word_list[label_id]].unsqueeze(-1)) | |
| logits = torch.cat(logits, -1) | |
| # Regression task | |
| if self.config.num_labels == 1: | |
| logsoftmax = nn.LogSoftmax(-1) | |
| logits = logsoftmax(logits) # Log prob of right polarity | |
| return logits, sequence_mask_output | |
| def forward( | |
| self, | |
| input_ids=None, | |
| attention_mask=None, | |
| token_type_ids=None, | |
| mask_pos=None, | |
| labels=None, | |
| inputs_embeds=None, | |
| block_flag=None, | |
| return_dict=None, | |
| ): | |
| inputs_embeds = self.generate_continuous_prompt_inputs(input_ids, block_flag) | |
| logits, sequence_mask_output = self.encode(input_ids, attention_mask, token_type_ids, mask_pos, inputs_embeds) | |
| loss = None | |
| if labels is not None: | |
| if self.num_labels == 1: | |
| # Regression task | |
| loss_fct = nn.KLDivLoss(log_target=True) | |
| labels = torch.stack([1 - (labels.view(-1) - self.lb) / (self.ub - self.lb), (labels.view(-1) - self.lb) / (self.ub - self.lb)], -1) | |
| loss = loss_fct(logits.view(-1, 2), labels) | |
| else: | |
| if labels.shape == logits.shape: | |
| loss = F.kl_div(F.log_softmax(logits, dim=-1, dtype=torch.float32), | |
| labels, reduction="batchmean") | |
| else: | |
| loss_fct = nn.CrossEntropyLoss() | |
| loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1)) | |
| output = (logits,) | |
| if self.num_labels == 1: | |
| # Regression output | |
| output = (torch.exp(logits[..., 1].unsqueeze(-1)) * (self.ub - self.lb) + self.lb,) | |
| if not return_dict: | |
| return ((loss,) + output) if loss is not None else output | |
| return SequenceClassifierOutput( | |
| loss=loss, | |
| logits=logits, | |
| ) | |
| """ | |
| Prefix-tuning BERT | |
| """ | |
| class PromptBertPrefixForSequenceClassification(BertPreTrainedModel): | |
| def __init__(self, config): | |
| super().__init__(config) | |
| self.num_labels = config.num_labels | |
| self.pre_seq_len = self.config.pre_seq_len | |
| self.hidden_size = self.config.hidden_size | |
| self.n_layer = config.num_hidden_layers | |
| self.n_head = config.num_attention_heads | |
| self.n_embd = config.hidden_size // config.num_attention_heads | |
| # backbone | |
| self.bert = BertModel(config) | |
| if self.config.use_freezing: | |
| self.bert = freezer.freeze_lm(self.bert) | |
| # mlm head | |
| self.cls = BertOnlyMLMHead(config) | |
| self.dropout = torch.nn.Dropout(config.hidden_dropout_prob) | |
| # plm embedding layer | |
| self.backbone_embeddings = self.bert.embeddings.word_embeddings | |
| # prompt embedding layer | |
| self.prompt_embeddings = torch.nn.Embedding(self.pre_seq_len, self.hidden_size) | |
| # prefix encoder | |
| self.prefix_tokens = torch.arange(self.pre_seq_len).long() | |
| self.prefix_encoder = PrefixEncoder(config) | |
| self.init_weights() | |
| # These attributes should be assigned once the model is initialized | |
| self.label_word_list = torch.Tensor(self.config.label_word_list).long().to(self.bert.device) | |
| # For regression | |
| self.lb = None | |
| self.ub = None | |
| # For label search. | |
| self.return_full_softmax = None | |
| def freeze_backbone(self, use_freezing: bool=True): | |
| if use_freezing: | |
| self.bert = freezer.freeze_lm(self.bert) | |
| else: | |
| self.bert = freezer.unfreeze_lm(self.bert) | |
| def get_prompt(self, batch_size): | |
| prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(self.bert.device) | |
| past_key_values = self.prefix_encoder(prefix_tokens) | |
| # bsz, seqlen, _ = past_key_values.shape | |
| past_key_values = past_key_values.view( | |
| batch_size, | |
| self.pre_seq_len, | |
| self.n_layer * 2, | |
| self.n_head, | |
| self.n_embd | |
| ) | |
| past_key_values = self.dropout(past_key_values) | |
| past_key_values = past_key_values.permute([2, 0, 3, 1, 4]).split(2) | |
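| # result: a tuple of n_layer tensors, each [2, batch_size, n_head, pre_seq_len, n_embd], | |
| # i.e. one stacked (key, value) pair of prefix states per transformer layer | |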
| return past_key_values | |
| def embed_encode(self, input_ids): | |
| embedding_output = self.bert.embeddings.word_embeddings(input_ids) | |
| return embedding_output | |
| def encode(self, input_ids=None, attention_mask=None, token_type_ids=None, mask_pos=None, inputs_embeds=None, return_full_softmax=False): | |
| batch_size = input_ids.size(0) | |
| # add prefix for prompt-tuning | |
| past_key_values = self.get_prompt(batch_size=batch_size) | |
| prefix_attention_mask = torch.ones(batch_size, self.pre_seq_len).to(self.bert.device) | |
| attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1) | |
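| # all-ones prefix mask so that real tokens can attend to the pre_seq_len virtual prefix positions | |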
| if mask_pos is not None: | |
| mask_pos = mask_pos.squeeze() | |
| # Encode everything | |
| outputs = self.bert( | |
| input_ids, | |
| attention_mask=attention_mask, | |
| token_type_ids=token_type_ids, | |
| past_key_values=past_key_values, | |
| ) | |
| # Get <mask> token representation | |
| sequence_output, pooled_output = outputs[:2] | |
| # sequence_output = sequence_output[:, self.pre_seq_len:, :].contiguous() | |
| sequence_mask_output = sequence_output[torch.arange(sequence_output.size(0)), mask_pos] | |
| # Logits over vocabulary tokens | |
| prediction_mask_scores = self.cls(sequence_mask_output) | |
| # Exit early and only return mask logits. | |
| if return_full_softmax: | |
| return prediction_mask_scores | |
| # print("prediction_mask_scores.shape=", prediction_mask_scores.shape) # [batch_size, seq_len, vocab_size] | |
| # Return logits for each label | |
| logits = [] | |
| for label_id in range(len(self.label_word_list)): | |
| logits.append(prediction_mask_scores[:, self.label_word_list[label_id]].unsqueeze(-1)) | |
| logits = torch.cat(logits, -1) | |
| # Regression task | |
| if self.config.num_labels == 1: | |
| logsoftmax = nn.LogSoftmax(-1) | |
| logits = logsoftmax(logits) # Log prob of right polarity | |
| return logits, sequence_mask_output | |
| def forward( | |
| self, | |
| input_ids=None, | |
| attention_mask=None, | |
| token_type_ids=None, | |
| mask_pos=None, | |
| labels=None, | |
| inputs_embeds=None, | |
| block_flag=None, | |
| return_dict=None, | |
| ): | |
| logits, sequence_mask_output = self.encode(input_ids, attention_mask, token_type_ids, mask_pos, inputs_embeds) | |
| loss = None | |
| if labels is not None: | |
| if self.num_labels == 1: | |
| # Regression task | |
| loss_fct = nn.KLDivLoss(log_target=True) | |
| labels = torch.stack([1 - (labels.view(-1) - self.lb) / (self.ub - self.lb), (labels.view(-1) - self.lb) / (self.ub - self.lb)], -1) | |
| loss = loss_fct(logits.view(-1, 2), labels) | |
| else: | |
| if labels.shape == logits.shape: | |
| loss = F.kl_div(F.log_softmax(logits, dim=-1, dtype=torch.float32), | |
| labels, reduction="batchmean") | |
| else: | |
| loss_fct = nn.CrossEntropyLoss() | |
| loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1)) | |
| output = (logits,) | |
| if self.num_labels == 1: | |
| # Regression output | |
| output = (torch.exp(logits[..., 1].unsqueeze(-1)) * (self.ub - self.lb) + self.lb,) | |
| if not return_dict: | |
| return ((loss,) + output) if loss is not None else output | |
| return SequenceClassifierOutput( | |
| loss=loss, | |
| logits=logits, | |
| ) | |
| """ | |
| Adapter-tuning BERT | |
| """ | |
| class PromptBertAdapterForSequenceClassification(BertPreTrainedModel): | |
| def __init__(self, config): | |
| super().__init__(config) | |
| self.num_labels = config.num_labels | |
| self.bert = BertAdaModel(config) | |
| self.cls = BertOnlyMLMHead(config) | |
| self.init_weights() | |
| if self.config.use_freezing: | |
| self.bert = freezer.freeze_lm_component(self.bert, "adapter") | |
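| # presumably freezes the backbone except its adapter modules, so only the adapters (plus the MLM head outside the backbone) remain trainable | |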
| # These attributes should be assigned once the model is initialized | |
| self.model_args = None | |
| self.data_args = None | |
| self.label_word_list = torch.Tensor(self.config.label_word_list).long().to(self.bert.device) | |
| # For regression | |
| self.lb = None | |
| self.ub = None | |
| # For label search. | |
| self.return_full_softmax = None | |
| def freeze_backbone(self, use_freezing: bool=True): | |
| if use_freezing: | |
| self.bert = freezer.freeze_lm_component(self.bert, "adapter") | |
| else: | |
| self.bert = freezer.unfreeze_lm(self.bert) | |
| def embed_encode(self, input_ids): | |
| embedding_output = self.bert.embeddings.word_embeddings(input_ids) | |
| return embedding_output | |
| def encode(self, input_ids=None, attention_mask=None, token_type_ids=None, mask_pos=None, inputs_embeds=None, return_full_softmax=False): | |
| batch_size = input_ids.size(0) | |
| if mask_pos is not None: | |
| mask_pos = mask_pos.squeeze() | |
| # Encode everything | |
| if inputs_embeds is None: | |
| outputs = self.bert( | |
| input_ids, | |
| attention_mask=attention_mask, | |
| token_type_ids=token_type_ids | |
| ) | |
| else: | |
| outputs = self.bert( | |
| None, | |
| attention_mask=attention_mask, | |
| token_type_ids=token_type_ids, | |
| inputs_embeds=inputs_embeds | |
| ) | |
| # Get <mask> token representation | |
| sequence_output, pooled_output = outputs[:2] | |
| sequence_mask_output = sequence_output[torch.arange(sequence_output.size(0)), mask_pos] | |
| # Logits over vocabulary tokens | |
| prediction_mask_scores = self.cls(sequence_mask_output) | |
| # Exit early and only return mask logits. | |
| if return_full_softmax: | |
| return prediction_mask_scores | |
| # Return logits for each label | |
| logits = [] | |
| for label_id in range(len(self.label_word_list)): | |
| logits.append(prediction_mask_scores[:, self.label_word_list[label_id]].unsqueeze(-1)) | |
| logits = torch.cat(logits, -1) | |
| # Regression task | |
| if self.config.num_labels == 1: | |
| logsoftmax = nn.LogSoftmax(-1) | |
| logits = logsoftmax(logits) # Log prob of right polarity | |
| return logits, sequence_mask_output | |
| def forward( | |
| self, | |
| input_ids=None, | |
| attention_mask=None, | |
| token_type_ids=None, | |
| mask_pos=None, | |
| labels=None, | |
| inputs_embeds=None, | |
| block_flag=None, | |
| return_dict=None, | |
| ): | |
| logits, sequence_mask_output = self.encode(input_ids, attention_mask, token_type_ids, mask_pos, inputs_embeds) | |
| loss = None | |
| if labels is not None: | |
| if self.num_labels == 1: | |
| # Regression task | |
| loss_fct = nn.KLDivLoss(log_target=True) | |
| labels = torch.stack([1 - (labels.view(-1) - self.lb) / (self.ub - self.lb), (labels.view(-1) - self.lb) / (self.ub - self.lb)], -1) | |
| loss = loss_fct(logits.view(-1, 2), labels) | |
| else: | |
| if labels.shape == logits.shape: | |
| loss = F.kl_div(F.log_softmax(logits, dim=-1, dtype=torch.float32), | |
| labels, reduction="batchmean") | |
| else: | |
| loss_fct = nn.CrossEntropyLoss() | |
| loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1)) | |
| output = (logits,) | |
| if self.num_labels == 1: | |
| # Regression output | |
| output = (torch.exp(logits[..., 1].unsqueeze(-1)) * (self.ub - self.lb) + self.lb,) | |
| if not return_dict: | |
| return ((loss,) + output) if loss is not None else output | |
| return SequenceClassifierOutput( | |
| loss=loss, | |
| logits=logits, | |
| ) | |
| """ | |
| Vanilla Prompt-tuning RoBERTa | |
| """ | |
| class PromptRobertaForSequenceClassification(RobertaPreTrainedModel): | |
| def __init__(self, config): | |
| super().__init__(config) | |
| self.num_labels = config.num_labels | |
| self.pre_seq_len = self.config.pre_seq_len | |
| self.hidden_size = self.config.hidden_size | |
| # backbone | |
| self.roberta = RobertaModel(config) | |
| if self.config.use_freezing: | |
| self.roberta = freezer.freeze_lm(self.roberta) | |
| # mlm head | |
| self.cls = RobertaLMHead(config) | |
| self.init_weights() | |
| # These attributes should be assigned once the model is initialized | |
| self.label_word_list = torch.Tensor(self.config.label_word_list).long().to(self.roberta.device) | |
| # For regression | |
| self.lb = None | |
| self.ub = None | |
| # For label search. | |
| self.return_full_softmax = None | |
| def freeze_backbone(self, use_freezing: bool=True): | |
| if use_freezing: | |
| self.roberta = freezer.freeze_lm(self.roberta) | |
| else: | |
| self.roberta = freezer.unfreeze_lm(self.roberta) | |
| def encode(self, input_ids=None, attention_mask=None, token_type_ids=None, mask_pos=None, inputs_embeds=None, return_full_softmax=False): | |
| """ | |
| Encoding and obtain logits at masked position | |
| """ | |
| if mask_pos is not None: | |
| mask_pos = mask_pos.squeeze() | |
| # Encode everything | |
| if inputs_embeds is None: | |
| outputs = self.roberta( | |
| input_ids, | |
| attention_mask=attention_mask, | |
| token_type_ids=token_type_ids | |
| ) | |
| else: | |
| outputs = self.roberta( | |
| None, | |
| attention_mask=attention_mask, | |
| token_type_ids=token_type_ids, | |
| inputs_embeds=inputs_embeds | |
| ) | |
| # Get <mask> token representation | |
| sequence_output, pooled_output = outputs[:2] | |
| sequence_mask_output = sequence_output[torch.arange(sequence_output.size(0)), mask_pos] | |
| # Logits over vocabulary tokens | |
| prediction_mask_scores = self.cls(sequence_mask_output) | |
| # Exit early and only return mask logits. | |
| if return_full_softmax: | |
| return prediction_mask_scores | |
| # Return logits for each label | |
| logits = [] | |
| for label_id in range(len(self.label_word_list)): | |
| logits.append(prediction_mask_scores[:, self.label_word_list[label_id]].unsqueeze(-1)) | |
| logits = torch.cat(logits, -1) | |
| # Regression task | |
| if self.config.num_labels == 1: | |
| logsoftmax = nn.LogSoftmax(-1) | |
| logits = logsoftmax(logits) # Log prob of right polarity | |
| return logits, sequence_mask_output | |
| def forward( | |
| self, | |
| input_ids=None, | |
| attention_mask=None, | |
| token_type_ids=None, | |
| mask_pos=None, | |
| labels=None, | |
| inputs_embeds=None, | |
| block_flag=None, | |
| return_dict=None, | |
| ): | |
| logits, sequence_mask_output = self.encode(input_ids, attention_mask, token_type_ids, mask_pos, inputs_embeds) | |
| loss = None | |
| if labels is not None: | |
| if self.num_labels == 1: | |
| # Regression task | |
| loss_fct = nn.KLDivLoss(log_target=True) | |
| labels = torch.stack([1 - (labels.view(-1) - self.lb) / (self.ub - self.lb), (labels.view(-1) - self.lb) / (self.ub - self.lb)], -1) | |
| loss = loss_fct(logits.view(-1, 2), labels) | |
| else: | |
| if labels.shape == logits.shape: | |
| loss = F.kl_div(F.log_softmax(logits, dim=-1, dtype=torch.float32), | |
| labels, reduction="batchmean") | |
| else: | |
| loss_fct = nn.CrossEntropyLoss() | |
| loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1)) | |
| output = (logits,) | |
| if self.num_labels == 1: | |
| # Regression output | |
| output = (torch.exp(logits[..., 1].unsqueeze(-1)) * (self.ub - self.lb) + self.lb,) | |
| if not return_dict: | |
| return ((loss,) + output) if loss is not None else output | |
| return SequenceClassifierOutput( | |
| loss=loss, | |
| logits=logits, | |
| ) | |
| """ | |
| P-tuning RoBERTa | |
| """ | |
| class PromptRobertaPtuningForSequenceClassification(RobertaPreTrainedModel): | |
| def __init__(self, config): | |
| super().__init__(config) | |
| self.num_labels = config.num_labels | |
| self.pre_seq_len = self.config.pre_seq_len | |
| self.hidden_size = self.config.hidden_size | |
| # backbone | |
| self.roberta = RobertaModel(config) | |
| if self.config.use_freezing: | |
| self.roberta = freezer.freeze_lm(self.roberta) | |
| # mlm head | |
| self.cls = RobertaLMHead(config) | |
| # prompt encoder | |
| self.prompt_encoder = None | |
| # plm embedding layer | |
| self.backbone_embeddings = self.roberta.embeddings.word_embeddings | |
| # prompt embedding layer | |
| self.prompt_embeddings = torch.nn.Embedding(self.pre_seq_len, self.hidden_size) | |
| self.init_weights() | |
| # These attributes should be assigned once the model is initialized | |
| self.label_word_list = torch.Tensor(self.config.label_word_list).long().to(self.roberta.device) | |
| # For regression | |
| self.lb = None | |
| self.ub = None | |
| # For label search. | |
| self.return_full_softmax = None | |
| def freeze_backbone(self, use_freezing: bool=True): | |
| if use_freezing: | |
| self.roberta = freezer.freeze_lm(self.roberta) | |
| else: | |
| self.roberta = freezer.unfreeze_lm(self.roberta) | |
| def generate_continuous_prompt_inputs(self, input_ids, block_flag=None, reparameterization=False): | |
| """ | |
| Generate continuous prompt embedding | |
| """ | |
| inputs_embeds = self.backbone_embeddings(input_ids) | |
| batch_size = inputs_embeds.shape[0] | |
| if block_flag is None: | |
| # the first token is set to 1, the others to 0 | |
| block_flag = torch.zeros_like(input_ids).long().to(inputs_embeds.device) | |
| block_flag[:, 0] = 1 | |
| try: | |
| replace_embeds = self.prompt_embeddings( | |
| torch.LongTensor(list(range(self.pre_seq_len))).to(inputs_embeds.device)) | |
| except: | |
| import pdb | |
| pdb.set_trace() | |
| replace_embeds = self.prompt_embeddings(torch.LongTensor(list(range(self.pre_seq_len)))) | |
| replace_embeds = replace_embeds.unsqueeze(0) # [1, pre_seq_len, hidden_size] | |
| if self.prompt_encoder is not None: | |
| replace_embeds = self.prompt_encoder(replace_embeds) | |
| # edit by wjn | |
| if reparameterization: | |
| # requires block_flag to mark exactly pre_seq_len positions per example | |
| blocked_indices = (block_flag == 1).nonzero(as_tuple=False).reshape((batch_size, self.pre_seq_len, 2))[:, :, 1] | |
| # reparameterization: overwrite the flagged input positions with the prompt embeddings | |
| for bidx in range(batch_size): | |
| for i in range(blocked_indices.shape[1]): | |
| inputs_embeds[bidx, blocked_indices[bidx, i], :] = replace_embeds[:, i, :].squeeze() | |
| else: | |
| replace_embeds = replace_embeds.expand(batch_size, self.pre_seq_len, -1).to(inputs_embeds.device) | |
| inputs_embeds = torch.cat((replace_embeds, inputs_embeds), dim=1) | |
| return inputs_embeds | |
| def encode(self, input_ids=None, attention_mask=None, token_type_ids=None, mask_pos=None, inputs_embeds=None, return_full_softmax=False): | |
| """ | |
| Encoding and obtain logits at masked position | |
| """ | |
| batch_size = inputs_embeds.shape[0] if inputs_embeds is not None else input_ids.shape[0] | |
| if mask_pos is not None: | |
| mask_pos = mask_pos.squeeze() | |
| # Encode everything | |
| if inputs_embeds is None: | |
| outputs = self.roberta( | |
| input_ids, | |
| attention_mask=attention_mask, | |
| token_type_ids=token_type_ids | |
| ) | |
| else: | |
| if inputs_embeds.shape[1] == attention_mask.shape[1]: | |
| outputs = self.roberta( | |
| None, | |
| attention_mask=attention_mask, | |
| token_type_ids=token_type_ids, | |
| inputs_embeds=inputs_embeds | |
| ) | |
| # Get <mask> token representation | |
| sequence_output, pooled_output = outputs[:2] | |
| # sequence_output = sequence_output[:, self.pre_seq_len:, :].contiguous() | |
| else: | |
| if attention_mask is not None: | |
| prefix_attention_mask = torch.ones(batch_size, self.pre_seq_len).long().to(self.roberta.device) | |
| attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1) | |
| if token_type_ids is not None: | |
| prefix_token_type_ids = torch.zeros(batch_size, self.pre_seq_len).long().to(self.roberta.device) | |
| token_type_ids = torch.cat((prefix_token_type_ids, token_type_ids), dim=1) | |
| outputs = self.roberta( | |
| None, | |
| attention_mask=attention_mask, | |
| token_type_ids=token_type_ids, | |
| inputs_embeds=inputs_embeds | |
| ) | |
| # Get <mask> token representation | |
| sequence_output, pooled_output = outputs[:2] | |
| sequence_output = sequence_output[:, self.pre_seq_len:, :].contiguous() | |
| sequence_mask_output = sequence_output[torch.arange(sequence_output.size(0)), mask_pos] | |
| # Logits over vocabulary tokens | |
| prediction_mask_scores = self.cls(sequence_mask_output) | |
| # Exit early and only return mask logits. | |
| if return_full_softmax: | |
| return prediction_mask_scores | |
| # Return logits for each label | |
| logits = [] | |
| for label_id in range(len(self.label_word_list)): | |
| logits.append(prediction_mask_scores[:, self.label_word_list[label_id]].unsqueeze(-1)) | |
| logits = torch.cat(logits, -1) | |
| # Regression task | |
| if self.config.num_labels == 1: | |
| logsoftmax = nn.LogSoftmax(-1) | |
| logits = logsoftmax(logits) # Log prob of right polarity | |
| return logits, sequence_mask_output | |
| def forward( | |
| self, | |
| input_ids=None, | |
| attention_mask=None, | |
| token_type_ids=None, | |
| mask_pos=None, | |
| labels=None, | |
| inputs_embeds=None, | |
| block_flag=None, | |
| return_dict=None, | |
| ): | |
| inputs_embeds = self.generate_continuous_prompt_inputs(input_ids, block_flag) | |
| logits, sequence_mask_output = self.encode(input_ids, attention_mask, token_type_ids, mask_pos, inputs_embeds) | |
| loss = None | |
| if labels is not None: | |
| if self.num_labels == 1: | |
| # Regression task | |
| loss_fct = nn.KLDivLoss(log_target=True) | |
| labels = torch.stack([1 - (labels.view(-1) - self.lb) / (self.ub - self.lb), (labels.view(-1) - self.lb) / (self.ub - self.lb)], -1) | |
| loss = loss_fct(logits.view(-1, 2), labels) | |
| else: | |
| if labels.shape == logits.shape: | |
| loss = F.kl_div(F.log_softmax(logits, dim=-1, dtype=torch.float32), | |
| labels, reduction="batchmean") | |
| else: | |
| loss_fct = nn.CrossEntropyLoss() | |
| loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1)) | |
| output = (logits,) | |
| if self.num_labels == 1: | |
| # Regression output | |
| output = (torch.exp(logits[..., 1].unsqueeze(-1)) * (self.ub - self.lb) + self.lb,) | |
| if not return_dict: | |
| return ((loss,) + output) if loss is not None else output | |
| return SequenceClassifierOutput( | |
| loss=loss, | |
| logits=logits, | |
| ) | |
| """ | |
| Prefix-tuning RoBERTa | |
| """ | |
| class PromptRobertaPrefixForSequenceClassification(RobertaPreTrainedModel): | |
| def __init__(self, config): | |
| super().__init__(config) | |
| self.num_labels = config.num_labels | |
| self.pre_seq_len = self.config.pre_seq_len | |
| self.hidden_size = self.config.hidden_size | |
| self.n_layer = config.num_hidden_layers | |
| self.n_head = config.num_attention_heads | |
| self.n_embd = config.hidden_size // config.num_attention_heads | |
| # backbone | |
| self.roberta = RobertaModel(config) | |
| if self.config.use_freezing: | |
| self.roberta = freezer.freeze_lm(self.roberta) | |
| # mlm head | |
| self.cls = RobertaLMHead(config) | |
| self.dropout = torch.nn.Dropout(config.hidden_dropout_prob) | |
| # plm embedding layer | |
| self.backbone_embeddings = self.roberta.embeddings.word_embeddings | |
| # prompt embedding layer | |
| self.prompt_embeddings = torch.nn.Embedding(self.pre_seq_len, self.hidden_size) | |
| # prefix encoder | |
| self.prefix_tokens = torch.arange(self.pre_seq_len).long() | |
| self.prefix_encoder = PrefixEncoder(config) | |
| self.init_weights() | |
| # These attributes should be assigned once the model is initialized | |
| self.label_word_list = torch.Tensor(self.config.label_word_list).long().to(self.roberta.device) | |
| # For regression | |
| self.lb = None | |
| self.ub = None | |
| # For label search. | |
| self.return_full_softmax = None | |
| def freeze_backbone(self, use_freezing: bool=True): | |
| if use_freezing: | |
| self.roberta = freezer.freeze_lm(self.roberta) | |
| else: | |
| self.roberta = freezer.unfreeze_lm(self.roberta) | |
| def get_prompt(self, batch_size): | |
| prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(self.roberta.device) | |
| past_key_values = self.prefix_encoder(prefix_tokens) | |
| # bsz, seqlen, _ = past_key_values.shape | |
| past_key_values = past_key_values.view( | |
| batch_size, | |
| self.pre_seq_len, | |
| self.n_layer * 2, | |
| self.n_head, | |
| self.n_embd | |
| ) | |
| past_key_values = self.dropout(past_key_values) | |
| past_key_values = past_key_values.permute([2, 0, 3, 1, 4]).split(2) | |
| return past_key_values | |
| def embed_encode(self, input_ids): | |
| embedding_output = self.roberta.embeddings.word_embeddings(input_ids) | |
| return embedding_output | |
| def encode(self, input_ids=None, attention_mask=None, token_type_ids=None, mask_pos=None, inputs_embeds=None, return_full_softmax=False): | |
| batch_size = input_ids.size(0) | |
| # add prefix for prompt-tuning | |
| past_key_values = self.get_prompt(batch_size=batch_size) | |
| prefix_attention_mask = torch.ones(batch_size, self.pre_seq_len).to(self.roberta.device) | |
| attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1) | |
| if mask_pos is not None: | |
| mask_pos = mask_pos.squeeze() | |
| # Encode everything | |
| outputs = self.roberta( | |
| input_ids, | |
| attention_mask=attention_mask, | |
| token_type_ids=token_type_ids, | |
| past_key_values=past_key_values, | |
| ) | |
| # Get <mask> token representation | |
| sequence_output, pooled_output = outputs[:2] | |
| # sequence_output = sequence_output[:, self.pre_seq_len:, :].contiguous() | |
| sequence_mask_output = sequence_output[torch.arange(sequence_output.size(0)), mask_pos] | |
| # Logits over vocabulary tokens | |
| prediction_mask_scores = self.cls(sequence_mask_output) | |
| # Exit early and only return mask logits. | |
| if return_full_softmax: | |
| return prediction_mask_scores | |
| # Return logits for each label | |
| logits = [] | |
| for label_id in range(len(self.label_word_list)): | |
| logits.append(prediction_mask_scores[:, self.label_word_list[label_id]].unsqueeze(-1)) | |
| logits = torch.cat(logits, -1) | |
| # Regression task | |
| if self.config.num_labels == 1: | |
| logsoftmax = nn.LogSoftmax(-1) | |
| logits = logsoftmax(logits) # Log prob of right polarity | |
| return logits, sequence_mask_output | |
| def forward( | |
| self, | |
| input_ids=None, | |
| attention_mask=None, | |
| token_type_ids=None, | |
| mask_pos=None, | |
| labels=None, | |
| inputs_embeds=None, | |
| block_flag=None, | |
| return_dict=None, | |
| ): | |
| logits, sequence_mask_output = self.encode(input_ids, attention_mask, token_type_ids, mask_pos, inputs_embeds) | |
| loss = None | |
| if labels is not None: | |
| if self.num_labels == 1: | |
| # Regression task | |
| loss_fct = nn.KLDivLoss(log_target=True) | |
| labels = torch.stack([1 - (labels.view(-1) - self.lb) / (self.ub - self.lb), (labels.view(-1) - self.lb) / (self.ub - self.lb)], -1) | |
| loss = loss_fct(logits.view(-1, 2), labels) | |
| else: | |
| if labels.shape == logits.shape: | |
| loss = F.kl_div(F.log_softmax(logits, dim=-1, dtype=torch.float32), | |
| labels, reduction="batchmean") | |
| else: | |
| loss_fct = nn.CrossEntropyLoss() | |
| loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1)) | |
| output = (logits,) | |
| if self.num_labels == 1: | |
| # Regression output | |
| output = (torch.exp(logits[..., 1].unsqueeze(-1)) * (self.ub - self.lb) + self.lb,) | |
| if not return_dict: | |
| return ((loss,) + output) if loss is not None else output | |
| return SequenceClassifierOutput( | |
| loss=loss, | |
| logits=logits, | |
| ) | |
| """ | |
| Adapter-tuning RoBERTa | |
| """ | |
| class PromptRobertaAdapterForSequenceClassification(RobertaPreTrainedModel): | |
| def __init__(self, config): | |
| super().__init__(config) | |
| self.num_labels = config.num_labels | |
| self.roberta = RobertaAdaModel(config) | |
| self.cls = RobertaLMHead(config) | |
| self.init_weights() | |
| if self.config.use_freezing: | |
| self.roberta = freezer.freeze_lm_component(self.roberta, "adapter") | |
| # These attributes should be assigned once the model is initialized | |
| self.model_args = None | |
| self.data_args = None | |
| self.label_word_list = torch.Tensor(self.config.label_word_list).long().to(self.roberta.device) | |
| # For regression | |
| self.lb = None | |
| self.ub = None | |
| # For label search. | |
| self.return_full_softmax = None | |
| def freeze_backbone(self, use_freezing: bool=True): | |
| if use_freezing: | |
| self.roberta = freezer.freeze_lm_component(self.roberta, "adapter") | |
| else: | |
| self.roberta = freezer.unfreeze_lm(self.roberta) | |
| def embed_encode(self, input_ids): | |
| embedding_output = self.roberta.embeddings.word_embeddings(input_ids) | |
| return embedding_output | |
| def encode(self, input_ids=None, attention_mask=None, token_type_ids=None, mask_pos=None, inputs_embeds=None, return_full_softmax=False): | |
| batch_size = input_ids.size(0) | |
| if mask_pos is not None: | |
| mask_pos = mask_pos.squeeze() | |
| # Encode everything | |
| if inputs_embeds is None: | |
| outputs = self.roberta( | |
| input_ids, | |
| attention_mask=attention_mask, | |
| token_type_ids=token_type_ids | |
| ) | |
| else: | |
| outputs = self.roberta( | |
| None, | |
| attention_mask=attention_mask, | |
| token_type_ids=token_type_ids, | |
| inputs_embeds=inputs_embeds | |
| ) | |
| # Get <mask> token representation | |
| sequence_output, pooled_output = outputs[:2] | |
| sequence_mask_output = sequence_output[torch.arange(sequence_output.size(0)), mask_pos] | |
| # Logits over vocabulary tokens | |
| prediction_mask_scores = self.cls(sequence_mask_output) | |
| # Exit early and only return mask logits. | |
| if return_full_softmax: | |
| return prediction_mask_scores | |
| # Return logits for each label | |
| logits = [] | |
| for label_id in range(len(self.label_word_list)): | |
| logits.append(prediction_mask_scores[:, self.label_word_list[label_id]].unsqueeze(-1)) | |
| logits = torch.cat(logits, -1) | |
| # Regression task | |
| if self.config.num_labels == 1: | |
| logsoftmax = nn.LogSoftmax(-1) | |
| logits = logsoftmax(logits) # Log prob of right polarity | |
| return logits, sequence_mask_output | |
| def forward( | |
| self, | |
| input_ids=None, | |
| attention_mask=None, | |
| token_type_ids=None, | |
| mask_pos=None, | |
| labels=None, | |
| inputs_embeds=None, | |
| block_flag=None, | |
| return_dict=None, | |
| ): | |
| logits, sequence_mask_output = self.encode(input_ids, attention_mask, token_type_ids, mask_pos, inputs_embeds) | |
| loss = None | |
| if labels is not None: | |
| if self.num_labels == 1: | |
| # Regression task | |
| loss_fct = nn.KLDivLoss(log_target=True) | |
| labels = torch.stack([1 - (labels.view(-1) - self.lb) / (self.ub - self.lb), (labels.view(-1) - self.lb) / (self.ub - self.lb)], -1) | |
| loss = loss_fct(logits.view(-1, 2), labels) | |
| else: | |
| if labels.shape == logits.shape: | |
| loss = F.kl_div(F.log_softmax(logits, dim=-1, dtype=torch.float32), | |
| labels, reduction="batchmean") | |
| else: | |
| loss_fct = nn.CrossEntropyLoss() | |
| loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1)) | |
| output = (logits,) | |
| if self.num_labels == 1: | |
| # Regression output | |
| output = (torch.exp(logits[..., 1].unsqueeze(-1)) * (self.ub - self.lb) + self.lb,) | |
| if not return_dict: | |
| return ((loss,) + output) if loss is not None else output | |
| return SequenceClassifierOutput( | |
| loss=loss, | |
| logits=logits, | |
| ) | |
| # class DebertaForPromptFinetuning(DebertaPreTrainedModel): | |
| # _keys_to_ignore_on_load_unexpected = [r"pooler"] | |
| # _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"] | |
| # def __init__(self, config): | |
| # super().__init__(config) | |
| # self.num_labels = config.num_labels | |
| # #self.deberta = DebertaV2Model(config) | |
| # self.deberta = DebertaModel(config) | |
| # self.cls = DebertaOnlyMLMHead(config) | |
| # if self.config.use_freezing: | |
| # self.deberta = freezer.freeze_lm(self.deberta) | |
| # self.pooler = ContextPooler(config) | |
| # output_dim = self.pooler.output_dim | |
| # self.classifier = torch.nn.Linear(output_dim, self.num_labels) | |
| # drop_out = getattr(config, "cls_dropout", None) | |
| # drop_out = self.config.hidden_dropout_prob if drop_out is None else drop_out | |
| # self.dropout = StableDropout(drop_out) | |
| # classification_list = [self.pooler, self.dropout,self.classifier] | |
| # self.classifier = nn.Sequential(*classification_list) | |
| # # self.cls = DebertaV2OnlyMLMHead(config) | |
| # self.map = nn.Linear(config.hidden_size, config.hidden_size) | |
| # self.init_weights() | |
| # # These attributes should be assigned once the model is initialized | |
| # self.model_args = None | |
| # self.data_args = None | |
| # self.label_word_list = torch.Tensor(self.config.label_word_list).long().to(self.bert.device) | |
| # self.K = 1 | |
| # self.step_size=1e-5 | |
| # # import pdb | |
| # # pdb.set_trace() | |
| # #self.step_size=config.step_size | |
| # # For regression | |
| # self.lb = None | |
| # self.ub = None | |
| # self.pre_seq_len = self.config.pre_seq_len | |
| # # For auto label search. | |
| # self.return_full_softmax = None | |
| # def freeze_backbone(self, use_freezing: bool=True): | |
| # if use_freezing: | |
| # self.deberta = freezer.freeze_lm(self.deberta) | |
| # else: | |
| # self.deberta = freezer.unfreeze_lm(self.deberta) | |
| # def embed_encode(self, input_ids): | |
| # embedding_output = self.deberta.embeddings.word_embeddings(input_ids) | |
| # return embedding_output | |
| # def encode(self, input_ids=None, attention_mask=None, token_type_ids=None, mask_pos=None, inputs_embeds=None, | |
| # return_full_softmax=False): | |
| # batch_size = input_ids.size(0) | |
| # if mask_pos is not None: | |
| # mask_pos = mask_pos.squeeze() | |
| # # Encode everything | |
| # if inputs_embeds is None: | |
| # outputs = self.deberta( | |
| # input_ids, | |
| # attention_mask=attention_mask, | |
| # token_type_ids=token_type_ids | |
| # ) | |
| # else: | |
| # outputs = self.deberta( | |
| # None, | |
| # attention_mask=attention_mask, | |
| # token_type_ids=token_type_ids, | |
| # inputs_embeds=inputs_embeds | |
| # ) | |
| # # Get <mask> token representation | |
| # sequence_output = outputs[0] | |
| # sequence_output = sequence_output[:, self.pre_seq_len:, :].contiguous() | |
| # sequence_mask_output = sequence_output[torch.arange(sequence_output.size(0)), mask_pos] | |
| # # Logits over vocabulary tokens | |
| # prediction_mask_scores = self.cls(sequence_mask_output) | |
| # # sequence_mask_output = self.lm_head.dense(sequence_mask_output) | |
| # # Exit early and only return mask logits. | |
| # if return_full_softmax: | |
| # return prediction_mask_scores | |
| # # Return logits for each label | |
| # logits = [] | |
| # for label_id in range(len(self.label_word_list)): | |
| # logits.append(prediction_mask_scores[:, self.label_word_list[label_id]].unsqueeze(-1)) | |
| # logits = torch.cat(logits, -1) | |
| # # Regression task | |
| # if self.config.num_labels == 1: | |
| # logsoftmax = nn.LogSoftmax(-1) | |
| # logits = logsoftmax(logits) # Log prob of right polarity | |
| # if self.model_args.hybrid == 1: | |
| # cls_logits = self.classifier(sequence_output) | |
| # return (logits, cls_logits), sequence_mask_output | |
| # return logits, sequence_mask_output | |
| # def forward( | |
| # self, | |
| # input_ids=None, | |
| # attention_mask=None, | |
| # token_type_ids=None, | |
| # mask_pos=None, | |
| # labels=None, | |
| # inputs_embeds=None, | |
| # fwd_type=0, | |
| # block_flag=None | |
| # ): | |
| # if fwd_type == 2: | |
| # assert inputs_embeds is not None | |
| # return self.encode(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, | |
| # mask_pos=mask_pos, inputs_embeds=inputs_embeds) | |
| # elif fwd_type == 1: | |
| # return self.embed_encode(input_ids) | |
| # if (self.model_args.prompt_ptuning or self.model_args.prompt_prefix) and block_flag is not None: | |
| # inputs_embeds = self.generate_continuous_prompt_inputs(input_ids, block_flag) | |
| # logits, sequence_mask_output = self.encode(input_ids, attention_mask, token_type_ids, mask_pos, inputs_embeds) | |
| # if self.model_args.hybrid == 1: | |
| # logits = logits[0] | |
| # cls_logits = logits[1] | |
| # loss = None | |
| # if labels is not None: | |
| # if self.num_labels == 1: | |
| # # Regression task | |
| # loss_fct = nn.KLDivLoss(log_target=True) | |
| # labels = torch.stack([1 - (labels.view(-1) - self.lb) / (self.ub - self.lb), | |
| # (labels.view(-1) - self.lb) / (self.ub - self.lb)], -1) | |
| # loss = loss_fct(logits.view(-1, 2), labels) | |
| # else: | |
| # if labels.shape == logits.shape: | |
| # loss = F.kl_div(F.log_softmax(logits, dim=-1, dtype=torch.float32), | |
| # labels, reduction="batchmean") | |
| # else: | |
| # loss_fct = nn.CrossEntropyLoss() | |
| # loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1)) | |
| # output = (logits,) | |
| # if self.num_labels == 1: | |
| # # Regression output | |
| # output = (torch.exp(logits[..., 1].unsqueeze(-1)) * (self.ub - self.lb) + self.lb,) | |
| # return ((loss,) + output) if loss is not None else output | |
| # # add by wjn | |
| # # Prefix-tuning for Deberta | |
| # class DebertaPrefixForPromptFinetuning(DebertaPreTrainedModel): | |
| # def __init__(self, config): | |
| # super().__init__(config) | |
| # self.num_labels = config.num_labels | |
| # #self.deberta = DebertaV2Model(config) | |
| # self.deberta = DebertaModel(config) | |
| # self.cls = DebertaOnlyMLMHead(config) | |
| # self.pooler = ContextPooler(config) | |
| # output_dim = self.pooler.output_dim | |
| # self.classifier = torch.nn.Linear(output_dim, self.num_labels) | |
| # drop_out = getattr(config, "cls_dropout", None) | |
| # drop_out = self.config.hidden_dropout_prob if drop_out is None else drop_out | |
| # self.dropout = StableDropout(drop_out) | |
| # classification_list = [self.pooler, self.dropout,self.classifier] | |
| # self.classifier = nn.Sequential(*classification_list) | |
| # # self.cls = DebertaV2OnlyMLMHead(config) | |
| # self.map = nn.Linear(config.hidden_size, config.hidden_size) | |
| # self.init_weights() | |
| # if self.config.use_freezing: | |
| # self.deberta = freezer.freeze_lm(self.deberta) | |
| # self.pre_seq_len = config.pre_seq_len | |
| # self.n_layer = config.num_hidden_layers | |
| # self.n_head = config.num_attention_heads | |
| # self.n_embd = config.hidden_size // config.num_attention_heads | |
| # self.prefix_tokens = torch.arange(self.pre_seq_len).long() | |
| # self.prefix_encoder = PrefixEncoder(config) | |
| # # These attributes should be assigned once the model is initialized | |
| # self.model_args = None | |
| # self.data_args = None | |
| # self.label_word_list = torch.Tensor(self.config.label_word_list).long().to(self.bert.device) | |
| # self.K = 1 | |
| # self.step_size=1e-5 | |
| # # import pdb | |
| # # pdb.set_trace() | |
| # #self.step_size=config.step_size | |
| # # For regression | |
| # self.lb = None | |
| # self.ub = None | |
| # # For auto label search. | |
| # self.return_full_softmax = None | |
| # def freeze_backbone(self, use_freezing: bool=True): | |
| # if use_freezing: | |
| # self.deberta = freezer.freeze_lm(self.deberta) | |
| # else: | |
| # self.deberta = freezer.unfreeze_lm(self.deberta) | |
| # def get_prompt(self, batch_size): | |
| # prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(self.deberta.device) | |
| # past_key_values = self.prefix_encoder(prefix_tokens) | |
| # # bsz, seqlen, _ = past_key_values.shape | |
| # past_key_values = past_key_values.view( | |
| # batch_size, | |
| # self.pre_seq_len, | |
| # self.n_layer * 2, | |
| # self.n_head, | |
| # self.n_embd | |
| # ) | |
| # past_key_values = self.dropout(past_key_values) | |
| # past_key_values = past_key_values.permute([2, 0, 3, 1, 4]).split(2) | |
| # return past_key_values | |
| # def get_constrast_loss(self, | |
| # input_ids=None, | |
| # attention_mask=None, | |
| # mask_pos=None, | |
| # labels=None, | |
| # inputs_embeds=None): | |
| # self.cos = nn.CosineSimilarity(dim=-1) | |
| # _, sequence_mask_output_1 = self.encode(input_ids, attention_mask, mask_pos, inputs_embeds) | |
| # _, sequence_mask_output_2 = self.encode(input_ids, attention_mask, mask_pos, inputs_embeds) | |
| # sequence_mask_output_1= self.lm_head.dense(sequence_mask_output_1) | |
| # sequence_mask_output_2 = self.lm_head.dense(sequence_mask_output_2) | |
| # # input_args = [input_ids, attention_mask, mask_pos, labels, None, 1] | |
| # # embed = self.forward(*input_args) | |
| # # | |
| # # vat_args = [input_ids, attention_mask, mask_pos, labels, embed, 2] | |
| # # | |
| # # adv_logits, outputs = self.forward(*vat_args) | |
| # # | |
| # # logit_mask = F.softmax(logits, dim=-1)[torch.arange(adv_logits.size(0)), labels] > 0.7 | |
| # # | |
| # # outputs = outputs[logit_mask] | |
| # # seq_outputs = sequence_mask_output[logit_mask] | |
| # # new_label = labels[logit_mask] | |
| # # # | |
| # # # | |
| # # rand_perm = torch.randperm(outputs.size(0)) | |
| # # rand_outputs = outputs[rand_perm, :] | |
| # # rand_label = new_label[rand_perm] | |
| # # pair_label = (new_label == rand_label).long() | |
| # # | |
| # # seq_outputs = self.map(seq_outputs) | |
| # # rand_outputs = self.map(rand_outputs) | |
| # pair_labels = (labels.unsqueeze(1) == labels.unsqueeze(0)).float() | |
| # # import pdb | |
| # # pdb.set_trace() | |
| # contra_loss = self.contra_lc(sequence_mask_output_1.unsqueeze(1), sequence_mask_output_2.unsqueeze(0), pair_labels) | |
| # if torch.isnan(contra_loss): | |
| # return 0 | |
| # return contra_loss | |
| # def embed_encode(self, input_ids): | |
| # embedding_output = self.deberta.embeddings.word_embeddings(input_ids) | |
| # return embedding_output | |
| # def encode(self, input_ids=None, attention_mask=None, token_type_ids=None, mask_pos=None, inputs_embeds=None, return_full_softmax=False): | |
| # batch_size = input_ids.size(0) | |
| # # add prefix for prompt-tuning | |
| # past_key_values = self.get_prompt(batch_size=batch_size) | |
| # prefix_attention_mask = torch.ones(batch_size, self.pre_seq_len).to(self.deberta.device) | |
| # attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1) | |
| # if mask_pos is not None: | |
| # mask_pos = mask_pos.squeeze() | |
| # # Encode everything | |
| # outputs = self.deberta( | |
| # input_ids, | |
| # attention_mask=attention_mask, | |
| # token_type_ids=token_type_ids, | |
| # past_key_values=past_key_values, | |
| # ) | |
| # # Get <mask> token representation | |
| # sequence_output, pooled_output = outputs[:2] | |
| # # sequence_output = sequence_output[:, self.pre_seq_len:, :].contiguous() | |
| # sequence_mask_output = sequence_output[torch.arange(sequence_output.size(0)), mask_pos] | |
| # # Logits over vocabulary tokens | |
| # prediction_mask_scores = self.cls(sequence_mask_output) | |
| # #sequence_mask_output = self.lm_head.dense(sequence_mask_output) | |
| # # Exit early and only return mask logits. | |
| # if return_full_softmax: | |
| # return prediction_mask_scores | |
| # # Return logits for each label | |
| # logits = [] | |
| # for label_id in range(len(self.label_word_list)): | |
| # logits.append(prediction_mask_scores[:, self.label_word_list[label_id]].unsqueeze(-1)) | |
| # logits = torch.cat(logits, -1) | |
| # # Regression task | |
| # if self.config.num_labels == 1: | |
| # logsoftmax = nn.LogSoftmax(-1) | |
| # logits = logsoftmax(logits) # Log prob of right polarity | |
| # if self.model_args.hybrid == 1: | |
| # cls_logits = self.classifier(sequence_output) | |
| # return (logits, cls_logits), sequence_mask_output | |
| # return logits, sequence_mask_output | |
| # def forward( | |
| # self, | |
| # input_ids=None, | |
| # attention_mask=None, | |
| # token_type_ids=None, | |
| # mask_pos=None, | |
| # labels=None, | |
| # inputs_embeds=None, | |
| # fwd_type=0, | |
| # block_flag=None, | |
| # return_dict=None, | |
| # ): | |
| # if fwd_type == 2: | |
| # assert inputs_embeds is not None | |
| # return self.encode(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, | |
| # mask_pos=mask_pos, inputs_embeds=inputs_embeds) | |
| # elif fwd_type == 1: | |
| # return self.embed_encode(input_ids) | |
| # if (self.model_args.prompt_ptuning or self.model_args.prompt_prefix) and block_flag is not None: | |
| # inputs_embeds = self.generate_continuous_prompt_inputs(input_ids, block_flag) | |
| # logits, sequence_mask_output = self.encode(input_ids, attention_mask, token_type_ids, mask_pos, inputs_embeds) | |
| # if self.model_args.hybrid == 1: | |
| # logits = logits[0] | |
| # cls_logits = logits[1] | |
| # loss = None | |
| # if labels is not None: | |
| # if self.num_labels == 1: | |
| # # Regression task | |
| # loss_fct = nn.KLDivLoss(log_target=True) | |
| # labels = torch.stack([1 - (labels.view(-1) - self.lb) / (self.ub - self.lb), | |
| # (labels.view(-1) - self.lb) / (self.ub - self.lb)], -1) | |
| # loss = loss_fct(logits.view(-1, 2), labels) | |
| # else: | |
| # if labels.shape == logits.shape: | |
| # loss = F.kl_div(F.log_softmax(logits, dim=-1, dtype=torch.float32), | |
| # labels, reduction="batchmean") | |
| # else: | |
| # loss_fct = nn.CrossEntropyLoss() | |
| # loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1)) | |
| # output = (logits,) | |
| # if self.num_labels == 1: | |
| # # Regression output | |
| # output = (torch.exp(logits[..., 1].unsqueeze(-1)) * (self.ub - self.lb) + self.lb,) | |
| # if not return_dict: | |
| # return ((loss,) + output) if loss is not None else output | |
| # return SequenceClassifierOutput( | |
| # loss=loss, | |
| # logits=logits, | |
| # ) | |
| # class Debertav2ForPromptFinetuning(DebertaV2PreTrainedModel): | |
| # _keys_to_ignore_on_load_unexpected = [r"pooler"] | |
| # _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"] | |
| # def __init__(self, config): | |
| # super().__init__(config) | |
| # self.num_labels = config.num_labels | |
| # self.deberta = DebertaV2Model(config) | |
| # if self.config.use_freezing: | |
| # self.deberta = freezer.freeze_lm(self.deberta) | |
| # self.cls = DebertaV2OnlyMLMHead(config) | |
| # #self.deberta = DebertaModel(config) | |
| # #self.cls = DebertaOnlyMLMHead(config) | |
| # self.pooler = ContextPooler(config) | |
| # output_dim = self.pooler.output_dim | |
| # self.classifier = torch.nn.Linear(output_dim, self.num_labels) | |
| # drop_out = getattr(config, "cls_dropout", None) | |
| # drop_out = self.config.hidden_dropout_prob if drop_out is None else drop_out | |
| # self.dropout = StableDropout(drop_out) | |
| # classification_list = [self.pooler, self.dropout, self.classifier] | |
| # self.classifier = nn.Sequential(*classification_list) | |
| # self.map = nn.Linear(config.hidden_size, config.hidden_size) | |
| # self.init_weights() | |
| # # These attributes should be assigned once the model is initialized | |
| # self.model_args = None | |
| # self.data_args = None | |
| # self.label_word_list = torch.Tensor(self.config.label_word_list).long().to(self.deberta.device) | |
| # self.K = 1 | |
| # self.step_size = 1e-5 | |
| # #self.step_size=config.step_size | |
| # # For regression | |
| # self.lb = None | |
| # self.ub = None | |
| # self.pre_seq_len = self.config.pre_seq_len | |
| # # For auto label search. | |
| # self.return_full_softmax = None | |
| # def freeze_backbone(self, use_freezing: bool=True): | |
| # if use_freezing: | |
| # self.deberta = freezer.freeze_lm(self.deberta) | |
| # else: | |
| # self.deberta = freezer.unfreeze_lm(self.deberta) | |
| # def embed_encode(self, input_ids): | |
| # embedding_output = self.deberta.embeddings.word_embeddings(input_ids) | |
| # return embedding_output | |
| # def encode(self, input_ids=None, attention_mask=None, mask_pos=None, inputs_embeds=None, return_full_softmax=False): | |
| # batch_size = input_ids.size(0) | |
| # if mask_pos is not None: | |
| # mask_pos = mask_pos.squeeze() | |
| # # Encode everything | |
| # if inputs_embeds is None: | |
| # outputs = self.deberta( | |
| # input_ids, | |
| # attention_mask=attention_mask | |
| # ) | |
| # else: | |
| # outputs = self.deberta( | |
| # None, | |
| # attention_mask=attention_mask, | |
| # inputs_embeds=inputs_embeds | |
| # ) | |
| # # Get <mask> token representation | |
| # sequence_output = outputs[0] | |
| # sequence_output = sequence_output[:, self.pre_seq_len:, :].contiguous() | |
| # sequence_mask_output = sequence_output[torch.arange(sequence_output.size(0)), mask_pos] | |
| # # Logits over vocabulary tokens | |
| # prediction_mask_scores = self.cls(sequence_mask_output) | |
| # #sequence_mask_output = self.lm_head.dense(sequence_mask_output) | |
| # # Exit early and only return mask logits. | |
| # if return_full_softmax: | |
| # return prediction_mask_scores | |
| # # Return logits for each label | |
| # logits = [] | |
| # for label_id in range(len(self.label_word_list)): | |
| # logits.append(prediction_mask_scores[:, self.label_word_list[label_id]].unsqueeze(-1)) | |
| # logits = torch.cat(logits, -1) | |
| # # Regression task | |
| # if self.config.num_labels == 1: | |
| # logsoftmax = nn.LogSoftmax(-1) | |
| # logits = logsoftmax(logits) # Log prob of right polarity | |
| # return logits, sequence_mask_output | |
| # def forward( | |
| # self, | |
| # input_ids=None, | |
| # attention_mask=None, | |
| # mask_pos=None, | |
| # labels=None, | |
| # inputs_embeds=None, | |
| # fwd_type=0, | |
| # block_flag=None, | |
| # return_dict=None | |
| # ): | |
| # if fwd_type == 2: | |
| # assert inputs_embeds is not None | |
| # return self.encode(input_ids=input_ids, attention_mask=attention_mask, mask_pos=mask_pos, inputs_embeds=inputs_embeds) | |
| # elif fwd_type == 1: | |
| # return self.embed_encode(input_ids) | |
| # logits, sequence_mask_output = self.encode(input_ids, attention_mask, mask_pos, inputs_embeds) | |
| # loss = None | |
| # if labels is not None: | |
| # if self.num_labels == 1: | |
| # # Regression task | |
| # loss_fct = nn.KLDivLoss(log_target=True) | |
| # labels = torch.stack([1 - (labels.view(-1) - self.lb) / (self.ub - self.lb), (labels.view(-1) - self.lb) / (self.ub - self.lb)], -1) | |
| # loss = loss_fct(logits.view(-1, 2), labels) | |
| # else: | |
| # if labels.shape == logits.shape: | |
| # loss = F.kl_div(F.log_softmax(logits, dim=-1, dtype=torch.float32), | |
| # labels, reduction="batchmean") | |
| # else: | |
| # loss_fct = nn.CrossEntropyLoss() | |
| # loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1)) | |
| # if self.model_args.hybrid == 1: | |
| # # NOTE: `cls_logits` is never produced by `encode` in this class, so this | |
| # # hybrid branch is incomplete as written; it assumes logits from the CLS | |
| # # classification head (`self.classifier`). | |
| # cls_loss = loss_fct(cls_logits.view(-1, cls_logits.size(-1)), labels.view(-1)) | |
| # loss = loss + cls_loss | |
| # output = (logits,) | |
| # if self.num_labels == 1: | |
| # # Regression output | |
| # output = (torch.exp(logits[..., 1].unsqueeze(-1)) * (self.ub - self.lb) + self.lb,) | |
| # if not return_dict: | |
| # return ((loss,) + output) if loss is not None else output | |
| # return SequenceClassifierOutput( | |
| # loss=loss, | |
| # logits=logits, | |
| # ) | |
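| # --------------------------------------------------------------------------- | |
| # Illustrative sketch (hypothetical helpers, not part of the classes in this | |
| # file): the regression trick used in the commented-out forward passes maps a | |
| # scalar label in [lb, ub] onto a two-way distribution over the "low"/"high" | |
| # label words, and recovers a scalar prediction from log P("high"). | |
| # --------------------------------------------------------------------------- | |
| def regression_label_to_distribution(labels, lb, ub): | |
|     """Map scalar labels in [lb, ub] to [p_low, p_high] targets of shape [batch, 2].""" | |
|     p_high = (labels.view(-1) - lb) / (ub - lb) | |
|     return torch.stack([1 - p_high, p_high], dim=-1) | |
| def distribution_to_regression_pred(log_probs, lb, ub): | |
|     """Invert the mapping: use the model's log-probability of "high" to recover a scalar.""" | |
|     return torch.exp(log_probs[..., 1].unsqueeze(-1)) * (ub - lb) + lb | |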
| # class Debertav2PrefixForPromptFinetuning(DebertaV2PreTrainedModel): | |
| # _keys_to_ignore_on_load_unexpected = [r"pooler"] | |
| # _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"] | |
| # def __init__(self, config): | |
| # super().__init__(config) | |
| # self.num_labels = config.num_labels | |
| # self.deberta = DebertaV2Model(config) | |
| # self.cls = DebertaV2OnlyMLMHead(config) | |
| # #self.deberta = DebertaModel(config) | |
| # #self.cls = DebertaOnlyMLMHead(config) | |
| # self.pooler = ContextPooler(config) | |
| # output_dim = self.pooler.output_dim | |
| # self.classifier = torch.nn.Linear(output_dim, self.num_labels) | |
| # drop_out = getattr(config, "cls_dropout", None) | |
| # drop_out = self.config.hidden_dropout_prob if drop_out is None else drop_out | |
| # self.dropout = StableDropout(drop_out) | |
| # classification_list = [self.pooler, self.dropout, self.classifier] | |
| # self.classifier = nn.Sequential(*classification_list) | |
| # self.map = nn.Linear(config.hidden_size, config.hidden_size) | |
| # self.init_weights() | |
| # if self.config.use_freezing: | |
| # self.deberta = freezer.freeze_lm(self.deberta) | |
| # self.pre_seq_len = config.pre_seq_len | |
| # self.n_layer = config.num_hidden_layers | |
| # self.n_head = config.num_attention_heads | |
| # self.n_embd = config.hidden_size // config.num_attention_heads | |
| # self.prefix_tokens = torch.arange(self.pre_seq_len).long() | |
| # self.prefix_encoder = PrefixEncoder(config) | |
| # # These attributes should be assigned once the model is initialized | |
| # self.model_args = None | |
| # self.data_args = None | |
| # self.label_word_list = torch.Tensor(self.config.label_word_list).long().to(self.deberta.device) | |
| # self.K = 1 | |
| # self.step_size = 1e-5 | |
| # #self.step_size=config.step_size | |
| # # For regression | |
| # self.lb = None | |
| # self.ub = None | |
| # # For auto label search. | |
| # self.return_full_softmax = None | |
| # def freeze_backbone(self, use_freezing: bool=True): | |
| # if use_freezing: | |
| # self.deberta = freezer.freeze_lm(self.deberta) | |
| # else: | |
| # self.deberta = freezer.unfreeze_lm(self.deberta) | |
| # def get_prompt(self, batch_size): | |
| # prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(self.deberta.device) | |
| # past_key_values = self.prefix_encoder(prefix_tokens) | |
| # # bsz, seqlen, _ = past_key_values.shape | |
| # past_key_values = past_key_values.view( | |
| # batch_size, | |
| # self.pre_seq_len, | |
| # self.n_layer * 2, | |
| # self.n_head, | |
| # self.n_embd | |
| # ) | |
| # past_key_values = self.dropout(past_key_values) | |
| # past_key_values = past_key_values.permute([2, 0, 3, 1, 4]).split(2) | |
| # return past_key_values | |
| # def embed_encode(self, input_ids): | |
| # embedding_output = self.deberta.embeddings.word_embeddings(input_ids) | |
| # return embedding_output | |
| # def encode(self, input_ids=None, attention_mask=None, mask_pos=None, inputs_embeds=None, return_full_softmax=False): | |
| # batch_size = input_ids.size(0) | |
| # # add prefix for prompt-tuning | |
| # past_key_values = self.get_prompt(batch_size=batch_size) | |
| # prefix_attention_mask = torch.ones(batch_size, self.pre_seq_len).to(self.deberta.device) | |
| # attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1) | |
| # if mask_pos is not None: | |
| # mask_pos = mask_pos.squeeze() | |
| # # Encode everything | |
| # outputs = self.deberta( | |
| # input_ids, | |
| # attention_mask=attention_mask, | |
| # past_key_values=past_key_values, | |
| # ) | |
| # # Get <mask> token representation | |
| # sequence_output = outputs[0] | |
| # # sequence_output = sequence_output[:, self.pre_seq_len:, :].contiguous() | |
| # sequence_mask_output = sequence_output[torch.arange(sequence_output.size(0)), mask_pos] | |
| # # Logits over vocabulary tokens | |
| # prediction_mask_scores = self.cls(sequence_mask_output) | |
| # #sequence_mask_output = self.lm_head.dense(sequence_mask_output) | |
| # # Exit early and only return mask logits. | |
| # if return_full_softmax: | |
| # return prediction_mask_scores | |
| # # Return logits for each label | |
| # logits = [] | |
| # for label_id in range(len(self.label_word_list)): | |
| # logits.append(prediction_mask_scores[:, self.label_word_list[label_id]].unsqueeze(-1)) | |
| # logits = torch.cat(logits, -1) | |
| # # Regression task | |
| # if self.config.num_labels == 1: | |
| # logsoftmax = nn.LogSoftmax(-1) | |
| # logits = logsoftmax(logits) # Log prob of right polarity | |
| # return logits, sequence_mask_output | |
| # def forward( | |
| # self, | |
| # input_ids=None, | |
| # attention_mask=None, | |
| # mask_pos=None, | |
| # labels=None, | |
| # inputs_embeds=None, | |
| # fwd_type=0, | |
| # block_flag=None, | |
| # return_dict=None, | |
| # ): | |
| # if fwd_type == 2: | |
| # assert inputs_embeds is not None | |
| # return self.encode(input_ids=input_ids, attention_mask=attention_mask, mask_pos=mask_pos, inputs_embeds=inputs_embeds) | |
| # elif fwd_type == 1: | |
| # return self.embed_encode(input_ids) | |
| # logits, sequence_mask_output = self.encode(input_ids, attention_mask, mask_pos, inputs_embeds) | |
| # loss = None | |
| # if labels is not None: | |
| # if self.num_labels == 1: | |
| # # Regression task | |
| # loss_fct = nn.KLDivLoss(log_target=True) | |
| # labels = torch.stack([1 - (labels.view(-1) - self.lb) / (self.ub - self.lb), (labels.view(-1) - self.lb) / (self.ub - self.lb)], -1) | |
| # loss = loss_fct(logits.view(-1, 2), labels) | |
| # else: | |
| # if labels.shape == logits.shape: | |
| # loss = F.kl_div(F.log_softmax(logits, dim=-1, dtype=torch.float32), | |
| # labels, reduction="batchmean") | |
| # else: | |
| # loss_fct = nn.CrossEntropyLoss() | |
| # loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1)) | |
| # if self.model_args.hybrid == 1: | |
| # # NOTE: same caveat as in Debertav2ForPromptFinetuning: `encode` here does | |
| # # not return `cls_logits`, so this branch is incomplete as written. | |
| # cls_loss = loss_fct(cls_logits.view(-1, cls_logits.size(-1)), labels.view(-1)) | |
| # loss = loss + cls_loss | |
| # output = (logits,) | |
| # if self.num_labels == 1: | |
| # # Regression output | |
| # output = (torch.exp(logits[..., 1].unsqueeze(-1)) * (self.ub - self.lb) + self.lb,) | |
| # if not return_dict: | |
| # return ((loss,) + output) if loss is not None else output | |
| # return SequenceClassifierOutput( | |
| # loss=loss, | |
| # logits=logits, | |
| # ) | |
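| # --------------------------------------------------------------------------- | |
| # Illustrative sketch (hypothetical helper, standalone): how `get_prompt` above | |
| # reshapes the prefix-encoder output into per-layer key/value pairs that are | |
| # fed to the backbone as `past_key_values`. Dropout is omitted for clarity; | |
| # each returned element has shape [2, batch, n_head, pre_seq_len, head_dim]. | |
| # --------------------------------------------------------------------------- | |
| def split_prefix_past_key_values(prefix_out, n_layer, n_head): | |
|     """prefix_out: [batch, pre_seq_len, n_layer * 2 * n_head * head_dim].""" | |
|     batch_size, pre_seq_len, hidden = prefix_out.shape | |
|     head_dim = hidden // (n_layer * 2 * n_head) | |
|     pkv = prefix_out.view(batch_size, pre_seq_len, n_layer * 2, n_head, head_dim) | |
|     # -> [n_layer * 2, batch, n_head, pre_seq_len, head_dim], then split into | |
|     # n_layer chunks of two (one key tensor and one value tensor per layer). | |
|     return pkv.permute(2, 0, 3, 1, 4).split(2) | |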