import torch
from tqdm import tqdm
from typing import Optional, Tuple
from torch.nn import CrossEntropyLoss
from transformers import AutoModelForCausalLM
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel, GPT2Model


class GPT2ForInContextClassification(GPT2LMHeadModel):
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,             # input token ids, [n, option_size, len]
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        label_masks: Optional[torch.LongTensor] = None,            # positions with mask == 1 contribute to the loss
        options: Optional[list] = None,                            # candidate labels for classification tasks
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        assert len(input_ids.shape) == 3 and input_ids.shape[1] == len(options)  # [n, option_size, len]
        batch_size = input_ids.shape[0]
        option_size = input_ids.shape[1]
        seq_len = input_ids.shape[2]
        # flatten the option dimension so the GPT-2 backbone sees ordinary 2D batches
        input_ids = input_ids.view(-1, seq_len)                                                    # [n * option_size, len]
        attention_mask = attention_mask.view(-1, seq_len) if attention_mask is not None else None  # [n * option_size, len]
        token_type_ids = token_type_ids.view(-1, seq_len) if token_type_ids is not None else None  # [n * option_size, len]
        transformer_outputs = self.transformer(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = transformer_outputs[0]                             # [n * option_size, len, hidden_size]
        lm_logits = self.lm_head(hidden_states)                            # [n * option_size, len, vocab_size]
        lm_logits = lm_logits.view(batch_size, option_size, seq_len, -1)   # [n, option_size, len, vocab_size]
        losses, loss_logits = None, None
        if labels is not None:
            losses = []
            loss_fct = CrossEntropyLoss(reduction="none")
            for i, (label, lm_logit) in enumerate(zip(labels, lm_logits)):
                # label: [option_size, len]
                # lm_logit: [option_size, len, vocab_size]
                shift_logits = lm_logit[..., :-1, :].contiguous()
                shift_labels = label[..., 1:].contiguous()
                # per-token loss for every option: [option_size, len - 1]
                loss = torch.stack([
                    loss_fct(shift_logit.view(-1, shift_logit.size(-1)), shift_label.view(-1))
                    for shift_logit, shift_label in zip(shift_logits, shift_labels)
                ])
                if label_masks is not None:
                    # mask == 1 marks the positions whose loss should be counted;
                    # shift the mask so it lines up with the shifted labels
                    mask = label_masks[i][..., 1:].contiguous().to(loss.dtype)
                    loss = torch.sum(loss * mask, dim=1) / torch.sum(mask, dim=1)  # [option_size]
                else:
                    loss = loss.mean(dim=1)                                        # [option_size]
                losses.append(loss)
            losses = torch.stack(losses)                  # [n, option_size]
            # Treat each option's negated loss as a logit: the smaller the loss,
            # the larger the probability assigned to that option.
            loss_logits = torch.softmax(-losses, -1)      # [n, option_size]
        if not return_dict:
            output = (lm_logits,) + transformer_outputs[1:]
            return ((losses,) + output) if losses is not None else output

        return CausalLMOutputWithCrossAttentions(
            loss=losses,
            logits=loss_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
            cross_attentions=transformer_outputs.cross_attentions,
        )
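

# --- Optional helper: a minimal sketch, NOT part of the original class ---------------------
# It shows one way to build the [n, option_size, len] tensors that the forward pass above
# expects: each prompt is paired with every candidate answer, and `label_masks` marks the
# final answer token of each option. The helper name `build_option_inputs` is purely
# illustrative; it assumes the tokenizer already has a pad token set (as in __main__ below).
def build_option_inputs(tokenizer, prompt, candidates, max_length=60):
    # one full sequence per candidate answer
    texts = [prompt + " " + cand for cand in candidates]
    enc = tokenizer(texts, return_tensors="pt", max_length=max_length, padding="max_length")
    option_size, seq_len = enc["input_ids"].shape
    input_ids = enc["input_ids"].view(1, option_size, seq_len)            # [n = 1, option_size, len]
    attention_mask = enc["attention_mask"].view(1, option_size, seq_len)
    # mask only the last non-padding token of every option (the answer token)
    label_masks = torch.zeros_like(input_ids, dtype=torch.float)
    last = enc["attention_mask"].sum(dim=-1).long() - 1                   # [option_size]
    for j in range(option_size):
        label_masks[0, j, last[j]] = 1
    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": input_ids,
        "label_masks": label_masks,
    }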


if __name__ == "__main__":
    from transformers import GPT2Tokenizer

    tokenizer = GPT2Tokenizer.from_pretrained("/Users/wangjianing/Desktop/开源代码与数据模型/模型/gpt2")
    model = GPT2ForInContextClassification.from_pretrained("/Users/wangjianing/Desktop/开源代码与数据模型/模型/gpt2")
    # Two copies of the same in-context prompt, ending with the two candidate answers "Great" / "Bad".
    input_text1 = "What are the emotions of the following inputs? \n\n Input: The book is very nice.\n Output: Great. \n\n Input: I never eat chocolate!\n Output: Bad. \n\n Input: This film is not wonderful.\n Output: Great"
    input_text2 = "What are the emotions of the following inputs? \n\n Input: The book is very nice.\n Output: Great. \n\n Input: I never eat chocolate!\n Output: Bad. \n\n Input: This film is not wonderful.\n Output: Bad"
    tokenizer.pad_token = tokenizer.eos_token
    inputs = tokenizer(
        [input_text1, input_text2],
        return_tensors="pt",
        max_length=60,
        padding="max_length",
    )
    option_size, seq_len = inputs["input_ids"].shape
    # one example (n = 1) with two options
    inputs["input_ids"] = inputs["input_ids"].view(1, option_size, seq_len)
    inputs["attention_mask"] = inputs["attention_mask"].view(1, option_size, seq_len)
    inputs["labels"] = inputs["input_ids"]
    inputs["options"] = torch.Tensor([[0, 1], [0, 1]]).long()
    print(inputs["input_ids"].shape)
    # score only the final answer token ("Great" / "Bad") of each option
    label_mask = torch.zeros([1, option_size, seq_len])
    last_positions = inputs["attention_mask"].sum(dim=-1).long() - 1   # [1, option_size]
    label_mask[0, 0, last_positions[0, 0]] = 1
    label_mask[0, 1, last_positions[0, 1]] = 1
    inputs["label_masks"] = label_mask
    print(label_mask)
    output = model(**inputs, return_dict=True)
    losses, logits = output["loss"], output["logits"]
    print("loss=", losses)
    print("logits=", logits)