# -*- coding: utf-8 -*-
# @Time : 2022/4/21 5:30 PM
# @Author : JianingWang
# @File : span_proto.py
"""
This code is implemented for the paper "SpanProto: A Two-stage Span-based Prototypical Network for Few-shot Named Entity Recognition".
"""
import os
from typing import Optional, Union
import torch
import numpy as np
import torch.nn as nn
from dataclasses import dataclass
from torch.nn import BCEWithLogitsLoss
from transformers import MegatronBertModel, MegatronBertPreTrainedModel
from transformers.file_utils import ModelOutput
from transformers.models.bert import BertPreTrainedModel, BertModel


@dataclass
class TokenProtoOutput(ModelOutput):
    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None


class TokenProto(nn.Module):
    def __init__(self, config):
        """
        word_encoder: Sentence encoder
        You need to set self.cost as your own loss function.
        """
        nn.Module.__init__(self)
        self.config = config
        self.output_dir = "./outputs"
        # self.predict_dir = self.predict_result_path(self.output_dir)
        # NOTE (assumption): forward() relies on self.word_encoder and __dist__ on self.dot,
        # neither of which was set here; defaults are provided so the module is usable.
        self.word_encoder = None  # sentence encoder (e.g. a BERT model), expected to be attached before forward()
        self.dot = False  # if True use dot-product similarity, otherwise negative squared Euclidean distance
        self.drop = nn.Dropout()
        self.projector = nn.Sequential(  # projector
            nn.Linear(self.config.hidden_size, self.config.hidden_size),
            nn.Sigmoid(),
            # nn.LayerNorm(2)
        )
        self.tag_embeddings = nn.Embedding(2, self.config.hidden_size)  # tag for labeled / unlabeled span set
        # self.tag_mlp = nn.Linear(self.config.hidden_size, self.config.hidden_size)
        self.max_length = 64
        self.margin_distance = 6.0
        self.global_step = 0

    def predict_result_path(self, path=None):
        if path is None:
            predict_dir = os.path.join(
                self.output_dir, "{}-{}-{}".format(self.mode, self.num_class, self.num_example), "predict"
            )
        else:
            predict_dir = os.path.join(
                path, "predict"
            )
        # if os.path.exists(predict_dir):
        #     os.rmdir(predict_dir)  # remove previous prediction records
        if not os.path.exists(predict_dir):  # create a fresh directory
            os.makedirs(predict_dir)
        return predict_dir

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], *model_args, **kwargs):
        # Unlike the usual transformers from_pretrained, this only builds a fresh
        # TokenProto from the config in kwargs; no pretrained weights are loaded here.
        config = kwargs.pop("config", None)
        model = TokenProto(config=config)
        return model

    def __dist__(self, x, y, dim):
        # Similarity between x and y along dim: dot product if self.dot,
        # otherwise negative squared Euclidean distance (larger = closer).
        if self.dot:
            return (x * y).sum(dim)
        else:
            return -(torch.pow(x - y, 2)).sum(dim)

    def __batch_dist__(self, S, Q, q_mask):
        # S [class, embed_dim], Q [num_of_sent, num_of_tokens, embed_dim]
        assert Q.size()[:2] == q_mask.size()
        Q = Q[q_mask==1].view(-1, Q.size(-1))  # [num_of_all_text_tokens, embed_dim]
        # distance of every valid query token to every class prototype: [num_of_all_text_tokens, class]
        return self.__dist__(S.unsqueeze(0), Q.unsqueeze(1), 2)

    def __get_proto__(self, embedding, tag, mask):
        # Build one prototype per class by averaging the embeddings of the
        # support tokens carrying that class label.
        proto = []
        embedding = embedding[mask==1].view(-1, embedding.size(-1))
        tag = torch.cat(tag, 0)
        assert tag.size(0) == embedding.size(0)
        for label in range(torch.max(tag)+1):
            proto.append(torch.mean(embedding[tag==label], 0))
        proto = torch.stack(proto)  # [class_num, embed_dim]
        return proto, embedding
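
    # For example, with concatenated tags [0, 0, 1, 1, 0], prototype 0 is the mean of the
    # 1st, 2nd and 5th valid-token embeddings and prototype 1 the mean of the 3rd and 4th,
    # so proto has shape [2, embed_dim].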

    def forward(self, support, query):
        """
        support: Inputs of the support set.
        query: Inputs of the query set.
        N: Num of classes
        K: Num of instances for each class in the support set
        Q: Num of instances in the query set
        support/query = {"index": [], "word": [], "mask": [], "label": [], "sentence_num": [], "text_mask": []}
        """
        # Feed the support set and the query set through the word encoder (BERT) to obtain token representations.
        support_emb = self.word_encoder(support["word"], support["mask"])  # [num_sent, number_of_tokens, 768]
        query_emb = self.word_encoder(query["word"], query["mask"])  # [num_sent, number_of_tokens, 768]
        support_emb = self.drop(support_emb)
        query_emb = self.drop(query_emb)
        # Prototypical Networks
        logits = []
        current_support_num = 0
        current_query_num = 0
        assert support_emb.size()[:2] == support["mask"].size()
        assert query_emb.size()[:2] == query["mask"].size()
        for i, sent_support_num in enumerate(support["sentence_num"]):  # iterate over each sampled N-way K-shot task
            sent_query_num = query["sentence_num"][i]
            # Calculate prototype for each class.
            # One batch can hold several episodes, so the slice
            # current_support_num : current_support_num + sent_support_num marks which sentences
            # in the input tensors belong to the current N-way K-shot episode.
            support_proto, embedding = self.__get_proto__(
                support_emb[current_support_num:current_support_num+sent_support_num],
                support["label"][current_support_num:current_support_num+sent_support_num],
                support["text_mask"][current_support_num: current_support_num+sent_support_num])
            # calculate distance to each prototype
            logits.append(self.__batch_dist__(
                support_proto,
                query_emb[current_query_num:current_query_num+sent_query_num],
                query["text_mask"][current_query_num: current_query_num+sent_query_num]))  # [num_of_query_tokens, class_num]
            current_query_num += sent_query_num
            current_support_num += sent_support_num
        logits = torch.cat(logits, 0)  # score of every query token against each class prototype of the support set
        _, pred = torch.max(logits, 1)  # take the class of the closest prototype as the prediction
        # return logits, pred, embedding
        # No matter whether the outermost container is a list or a tuple, the returned logits must
        # contain tensors at the innermost level, otherwise HuggingFace's nested_detach function raises an error.
        return TokenProtoOutput(
            logits=logits
        )
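

if __name__ == "__main__":
    # Minimal smoke test with dummy inputs. This is an illustrative sketch only: the real
    # pipeline attaches a BERT-style encoder and builds episodes in the few-shot data loader;
    # the DummyEncoder, hidden size, and tensors below are assumptions made just to exercise forward().
    from types import SimpleNamespace

    class DummyEncoder(nn.Module):
        """Stand-in word encoder: a plain embedding table instead of BERT."""

        def __init__(self, vocab_size=100, hidden_size=32):
            super().__init__()
            self.embed = nn.Embedding(vocab_size, hidden_size)

        def forward(self, word, mask):
            # A real encoder would use the attention mask; the dummy one ignores it.
            return self.embed(word)

    config = SimpleNamespace(hidden_size=32)
    model = TokenProto(config)
    model.word_encoder = DummyEncoder(hidden_size=config.hidden_size)
    model.eval()

    num_tokens = 5
    support = {
        "word": torch.randint(0, 100, (2, num_tokens)),
        "mask": torch.ones(2, num_tokens, dtype=torch.long),
        "text_mask": torch.ones(2, num_tokens, dtype=torch.long),
        "label": [torch.tensor([0, 0, 1, 1, 0]), torch.tensor([1, 1, 0, 0, 1])],
        "sentence_num": [2],  # one episode containing two support sentences
    }
    query = {
        "word": torch.randint(0, 100, (1, num_tokens)),
        "mask": torch.ones(1, num_tokens, dtype=torch.long),
        "text_mask": torch.ones(1, num_tokens, dtype=torch.long),
        "sentence_num": [1],  # one query sentence in the same episode
    }

    with torch.no_grad():
        output = model(support, query)
    print(output.logits.shape)  # expected: [num_query_tokens, num_classes] = [5, 2]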