Spaces:
Running
Running
| # -*- coding: utf-8 -*- | |
| # @Time : 2022/4/21 5:30 下午 | |
| # @Author : JianingWang | |
| # @File : span_proto.py | |
| """ | |
This code is implemented for the paper "SpanProto: A Two-stage Span-based Prototypical Network for Few-shot Named Entity Recognition"
| """ | |
| import os | |
| from typing import Optional | |
| import torch | |
| import numpy as np | |
| import torch.nn as nn | |
| from typing import Union | |
| from dataclasses import dataclass | |
| from torch.nn import BCEWithLogitsLoss | |
| from transformers import MegatronBertModel, MegatronBertPreTrainedModel | |
| from transformers.file_utils import ModelOutput | |
| from transformers.models.bert import BertPreTrainedModel, BertModel | |
| a = torch.nn.Embedding(10, 20) | |
| a.parameters | |
| class RawGlobalPointer(nn.Module): | |
| def __init__(self, encoder, ent_type_size, inner_dim, RoPE=True): | |
| # encodr: RoBerta-Large as encoder | |
| # inner_dim: 64 | |
| # ent_type_size: ent_cls_num | |
| super().__init__() | |
| self.encoder = encoder | |
| self.ent_type_size = ent_type_size | |
| self.inner_dim = inner_dim | |
| self.hidden_size = encoder.config.hidden_size | |
| self.dense = nn.Linear(self.hidden_size, self.ent_type_size * self.inner_dim * 2) | |
| self.RoPE = RoPE | |
| def sinusoidal_position_embedding(self, batch_size, seq_len, output_dim): | |
| position_ids = torch.arange(0, seq_len, dtype=torch.float).unsqueeze(-1) | |
| indices = torch.arange(0, output_dim // 2, dtype=torch.float) | |
| indices = torch.pow(10000, -2 * indices / output_dim) | |
| embeddings = position_ids * indices | |
| embeddings = torch.stack([torch.sin(embeddings), torch.cos(embeddings)], dim=-1) | |
| embeddings = embeddings.repeat((batch_size, *([1] * len(embeddings.shape)))) | |
| embeddings = torch.reshape(embeddings, (batch_size, seq_len, output_dim)) | |
| embeddings = embeddings.to(self.device) | |
| return embeddings | |
| def forward(self, input_ids, attention_mask, token_type_ids): | |
| self.device = input_ids.device | |
| context_outputs = self.encoder(input_ids, attention_mask, token_type_ids) | |
| # last_hidden_state:(batch_size, seq_len, hidden_size) | |
| last_hidden_state = context_outputs[0] | |
| batch_size = last_hidden_state.size()[0] | |
| seq_len = last_hidden_state.size()[1] | |
| outputs = self.dense(last_hidden_state) | |
| outputs = torch.split(outputs, self.inner_dim * 2, dim=-1) | |
| outputs = torch.stack(outputs, dim=-2) | |
| qw, kw = outputs[..., :self.inner_dim], outputs[..., self.inner_dim:] | |
| if self.RoPE: | |
| # pos_emb:(batch_size, seq_len, inner_dim) | |
| pos_emb = self.sinusoidal_position_embedding(batch_size, seq_len, self.inner_dim) | |
| cos_pos = pos_emb[..., None, 1::2].repeat_interleave(2, dim=-1) | |
| sin_pos = pos_emb[..., None, ::2].repeat_interleave(2, dim=-1) | |
| qw2 = torch.stack([-qw[..., 1::2], qw[..., ::2]], -1) | |
| qw2 = qw2.reshape(qw.shape) | |
| qw = qw * cos_pos + qw2 * sin_pos | |
| kw2 = torch.stack([-kw[..., 1::2], kw[..., ::2]], -1) | |
| kw2 = kw2.reshape(kw.shape) | |
| kw = kw * cos_pos + kw2 * sin_pos | |
| # logits:(batch_size, ent_type_size, seq_len, seq_len) | |
| logits = torch.einsum("bmhd,bnhd->bhmn", qw, kw) | |
| # padding mask | |
| pad_mask = attention_mask.unsqueeze(1).unsqueeze(1).expand(batch_size, self.ent_type_size, seq_len, seq_len) | |
| logits = logits * pad_mask - (1 - pad_mask) * 1e12 | |
| # 排除下三角 | |
| mask = torch.tril(torch.ones_like(logits), -1) | |
| logits = logits - mask * 1e12 | |
| return logits / self.inner_dim ** 0.5 | |
| class SinusoidalPositionEmbedding(nn.Module): | |
| """定义Sin-Cos位置Embedding | |
| """ | |
| def __init__( | |
| self, output_dim, merge_mode="add", custom_position_ids=False): | |
| super(SinusoidalPositionEmbedding, self).__init__() | |
| self.output_dim = output_dim | |
| self.merge_mode = merge_mode | |
| self.custom_position_ids = custom_position_ids | |
| def forward(self, inputs): | |
| if self.custom_position_ids: | |
| seq_len = inputs.shape[1] | |
| inputs, position_ids = inputs | |
| position_ids = position_ids.type(torch.float) | |
| else: | |
| input_shape = inputs.shape | |
| batch_size, seq_len = input_shape[0], input_shape[1] | |
| position_ids = torch.arange(seq_len).type(torch.float)[None] | |
| indices = torch.arange(self.output_dim // 2).type(torch.float) | |
| indices = torch.pow(10000.0, -2 * indices / self.output_dim) | |
| embeddings = torch.einsum("bn,d->bnd", position_ids, indices) | |
| embeddings = torch.stack([torch.sin(embeddings), torch.cos(embeddings)], dim=-1) | |
| embeddings = torch.reshape(embeddings, (-1, seq_len, self.output_dim)) | |
| if self.merge_mode == "add": | |
| return inputs + embeddings.to(inputs.device) | |
| elif self.merge_mode == "mul": | |
| return inputs * (embeddings + 1.0).to(inputs.device) | |
| elif self.merge_mode == "zero": | |
| return embeddings.to(inputs.device) | |
| def multilabel_categorical_crossentropy(y_pred, y_true): | |
| y_pred = (1 - 2 * y_true) * y_pred # -1 -> pos classes, 1 -> neg classes | |
| y_pred_neg = y_pred - y_true * 1e12 # mask the pred outputs of pos classes | |
| y_pred_pos = y_pred - (1 - y_true) * 1e12 # mask the pred outputs of neg classes | |
| zeros = torch.zeros_like(y_pred[..., :1]) | |
| y_pred_neg = torch.cat([y_pred_neg, zeros], dim=-1) | |
| y_pred_pos = torch.cat([y_pred_pos, zeros], dim=-1) | |
| neg_loss = torch.logsumexp(y_pred_neg, dim=-1) | |
| pos_loss = torch.logsumexp(y_pred_pos, dim=-1) | |
| # print(y_pred, y_true, pos_loss) | |
| return (neg_loss + pos_loss).mean() | |
| def multilabel_categorical_crossentropy2(y_pred, y_true): | |
| y_pred = (1 - 2 * y_true) * y_pred # -1 -> pos classes, 1 -> neg classes | |
| y_pred_neg = y_pred.clone() | |
| y_pred_pos = y_pred.clone() | |
| y_pred_neg[y_true>0] -= float("inf") | |
| y_pred_pos[y_true<1] -= float("inf") | |
| # y_pred_neg = y_pred - y_true * float("inf") # mask the pred outputs of pos classes | |
| # y_pred_pos = y_pred - (1 - y_true) * float("inf") # mask the pred outputs of neg classes | |
| zeros = torch.zeros_like(y_pred[..., :1]) | |
| y_pred_neg = torch.cat([y_pred_neg, zeros], dim=-1) | |
| y_pred_pos = torch.cat([y_pred_pos, zeros], dim=-1) | |
| neg_loss = torch.logsumexp(y_pred_neg, dim=-1) | |
| pos_loss = torch.logsumexp(y_pred_pos, dim=-1) | |
| # print(y_pred, y_true, pos_loss) | |
| return (neg_loss + pos_loss).mean() | |
| class GlobalPointerOutput(ModelOutput): | |
| loss: Optional[torch.FloatTensor] = None | |
| topk_probs: torch.FloatTensor = None | |
| topk_indices: torch.IntTensor = None | |
| last_hidden_state: torch.FloatTensor = None | |
| class SpanProtoOutput(ModelOutput): | |
| loss: Optional[torch.FloatTensor] = None | |
| query_spans: list = None | |
| proto_logits: list = None | |
| topk_probs: torch.FloatTensor = None | |
| topk_indices: torch.IntTensor = None | |
| class SpanDetector(BertPreTrainedModel): | |
| def __init__(self, config): | |
| # encodr: RoBerta-Large as encoder | |
| # inner_dim: 64 | |
| # ent_type_size: ent_cls_num | |
| super().__init__(config) | |
| self.bert = BertModel(config) | |
| # self.ent_type_size = config.ent_type_size | |
| self.ent_type_size = 1 | |
| self.inner_dim = 64 | |
| self.hidden_size = config.hidden_size | |
| self.RoPE = True | |
| self.dense_1 = nn.Linear(self.hidden_size, self.inner_dim * 2) | |
| self.dense_2 = nn.Linear(self.hidden_size, self.ent_type_size * 2) # 原版的dense2是(inner_dim * 2, ent_type_size * 2) | |
| def sequence_masking(self, x, mask, value="-inf", axis=None): | |
| if mask is None: | |
| return x | |
| else: | |
| if value == "-inf": | |
| value = -1e12 | |
| elif value == "inf": | |
| value = 1e12 | |
| assert axis > 0, "axis must be greater than 0" | |
| for _ in range(axis - 1): | |
| mask = torch.unsqueeze(mask, 1) | |
| for _ in range(x.ndim - mask.ndim): | |
| mask = torch.unsqueeze(mask, mask.ndim) | |
| return x * mask + value * (1 - mask) | |
| def add_mask_tril(self, logits, mask): | |
| if mask.dtype != logits.dtype: | |
| mask = mask.type(logits.dtype) | |
| logits = self.sequence_masking(logits, mask, "-inf", logits.ndim - 2) | |
| logits = self.sequence_masking(logits, mask, "-inf", logits.ndim - 1) | |
| # 排除下三角 | |
| mask = torch.tril(torch.ones_like(logits), diagonal=-1) | |
| logits = logits - mask * 1e12 | |
| return logits | |
| def forward(self, input_ids, attention_mask, token_type_ids, labels=None, short_labels=None): | |
| # with torch.no_grad(): | |
| context_outputs = self.bert(input_ids, attention_mask, token_type_ids) | |
| last_hidden_state = context_outputs.last_hidden_state # [bz, seq_len, hidden_dim] | |
| del context_outputs | |
| outputs = self.dense_1(last_hidden_state) # [bz, seq_len, 2*inner_dim] | |
| qw, kw = outputs[..., ::2], outputs[..., 1::2] # 从0,1开始间隔为2 最后一个维度,从0开始,取奇数位置所有向量汇总 | |
| batch_size = input_ids.shape[0] | |
| if self.RoPE: # 是否使用RoPE旋转位置编码 | |
| pos = SinusoidalPositionEmbedding(self.inner_dim, "zero")(outputs) | |
| cos_pos = pos[..., 1::2].repeat_interleave(2, dim=-1) # e.g. [0.34, 0.90] -> [0.34, 0.34, 0.90, 0.90] | |
| sin_pos = pos[..., ::2].repeat_interleave(2, dim=-1) | |
| qw2 = torch.stack([-qw[..., 1::2], qw[..., ::2]], 3) | |
| qw2 = torch.reshape(qw2, qw.shape) | |
| qw = qw * cos_pos + qw2 * sin_pos | |
| kw2 = torch.stack([-kw[..., 1::2], kw[..., ::2]], 3) | |
| kw2 = torch.reshape(kw2, kw.shape) | |
| kw = kw * cos_pos + kw2 * sin_pos | |
| logits = torch.einsum("bmd,bnd->bmn", qw, kw) / self.inner_dim ** 0.5 | |
| bias = torch.einsum("bnh->bhn", self.dense_2(last_hidden_state)) / 2 | |
| logits = logits[:, None] + bias[:, ::2, None] + bias[:, 1::2, :, None] # logits[:, None] 增加一个维度 | |
| # logit_mask = self.add_mask_tril(logits, mask=attention_mask) | |
| loss = None | |
| mask = torch.triu(attention_mask.unsqueeze(2) * attention_mask.unsqueeze(1)) # 上三角矩阵 | |
| # mask = torch.where(mask > 0, 0.0, 1) | |
| if labels is not None: | |
| # y_pred = torch.zeros(input_ids.shape[0], self.ent_type_size, input_ids.shape[1], input_ids.shape[1], device=input_ids.device) | |
| # for i in range(input_ids.shape[0]): | |
| # for j in range(self.ent_type_size): | |
| # y_pred[i, j, labels[i, j, 0], labels[i, j, 1]] = 1 | |
| # y_true = labels.reshape(input_ids.shape[0] * self.ent_type_size, -1) | |
| # y_pred = logit_mask.reshape(input_ids.shape[0] * self.ent_type_size, -1) | |
| # loss = multilabel_categorical_crossentropy(y_pred, y_true) | |
| # | |
| # weight = ((labels == 0).sum() / labels.sum())/5 | |
| # loss_fct = nn.BCEWithLogitsLoss(weight=weight) | |
| # loss_fct = nn.BCEWithLogitsLoss(reduction="none") | |
| # unmask_labels = labels.view(-1)[mask.view(-1) > 0] | |
| # loss = loss_fct(logits.view(-1)[mask.view(-1) > 0], unmask_labels.float()) | |
| # if unmask_labels.sum() > 0: | |
| # loss = (loss[unmask_labels > 0].mean()+loss[unmask_labels < 1].mean())/2 | |
| # else: | |
| # loss = loss[unmask_labels < 1].mean() | |
| # y_pred = logits.view(-1)[mask.view(-1) > 0] | |
| # y_true = labels.view(-1)[mask.view(-1) > 0] | |
| # loss = multilabel_categorical_crossentropy2(y_pred, y_true) | |
| # y_pred = logits - torch.where(mask > 0, 0.0, float("inf")).unsqueeze(1) | |
| y_pred = logits - (1-mask.unsqueeze(1))*1e12 | |
| y_true = labels.view(input_ids.shape[0] * self.ent_type_size, -1) | |
| y_pred = y_pred.view(input_ids.shape[0] * self.ent_type_size, -1) | |
| loss = multilabel_categorical_crossentropy(y_pred, y_true) | |
| with torch.no_grad(): | |
| prob = torch.sigmoid(logits) * mask.unsqueeze(1) | |
| topk = torch.topk(prob.view(batch_size, self.ent_type_size, -1), 50, dim=-1) | |
| return GlobalPointerOutput( | |
| loss=loss, | |
| topk_probs=topk.values, | |
| topk_indices=topk.indices, | |
| last_hidden_state=last_hidden_state | |
| ) | |
| class SpanProto(nn.Module): | |
| def __init__(self, config): | |
| """ | |
| word_encoder: Sentence encoder | |
| You need to set self.cost as your own loss function. | |
| """ | |
| nn.Module.__init__(self) | |
| self.config = config | |
| self.output_dir = "./outputs" | |
| # self.predict_dir = self.predict_result_path(self.output_dir) | |
| self.drop = nn.Dropout() | |
| self.global_span_detector = SpanDetector(config=self.config) # global span detector | |
| self.projector = nn.Sequential( # projector | |
| nn.Linear(self.config.hidden_size, self.config.hidden_size), | |
| nn.Sigmoid(), | |
| # nn.LayerNorm(2) | |
| ) | |
| self.tag_embeddings = nn.Embedding(2, self.config.hidden_size) # tag for labeled / unlabeled span set | |
| # self.tag_mlp = nn.Linear(self.config.hidden_size, self.config.hidden_size) | |
| self.max_length = 64 | |
| self.margin_distance = 6.0 | |
| self.global_step = 0 | |
| def predict_result_path(self, path=None): | |
| if path is None: | |
| predict_dir = os.path.join( | |
| self.output_dir, "{}-{}-{}".format(self.mode, self.num_class, self.num_example), "predict" | |
| ) | |
| else: | |
| predict_dir = os.path.join( | |
| path, "predict" | |
| ) | |
| # if os.path.exists(predict_dir): | |
| # os.rmdir(predict_dir) # 删除历史记录 | |
| if not os.path.exists(predict_dir): # 重新创建一个新的目录 | |
| os.makedirs(predict_dir) | |
| return predict_dir | |
| def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], *model_args, **kwargs): | |
| config = kwargs.pop("config", None) | |
| model = SpanProto(config=config) | |
| # 将bert部分参数加载进去 | |
| model.global_span_detector = SpanDetector.from_pretrained( | |
| pretrained_model_name_or_path, | |
| *model_args, | |
| **kwargs | |
| ) | |
| # 将剩余的参数加载进来 | |
| return model | |
| # @classmethod | |
| # def resize_token_embeddings(self, new_num_tokens: Optional[int] = None): | |
| # self.global_span_detector.resize_token_embeddings(new_num_tokens) | |
| def __dist__(self, x, y, dim, use_dot=False): | |
| # x: [1, class_num, hidden_dim], y: [span_num, 1, hidden_dim] | |
| # x - y: [span_num, class_num, hidden_dim] | |
| # (x - y)^2.sum(2): [span_num, class_num] | |
| if use_dot: | |
| return (x * y).sum(dim) | |
| else: | |
| return -(torch.pow(x - y, 2)).sum(dim) | |
| def __get_proto__(self, support_emb: torch, support_span: list, support_span_type: list, use_tag=False): | |
| """ | |
| support_emb: [n", seq_len, dim] | |
| support_span: [n", m, 2] e.g. [[[3, 6], [12, 13]], [[1, 3]], ...] | |
| support_span_type: [n", m] e.g. [[2, 1], [5], ...] | |
| """ | |
| prototype = list() # 每个类的proto type | |
| all_span_embs = list() # 保存每个span的embedding | |
| all_span_tags = list() | |
| # 遍历每个类 | |
| for tag in range(self.num_class): | |
| # tag_id = torch.Tensor([1 if tag == self.num_class else 0]).long().cuda() | |
| # tag_embeddings = self.tag_embeddings(tag_id).view(-1) | |
| tag_prototype = list() # [k, dim] | |
| # 遍历当前episode内的每个句子 | |
| for emb, span, type in zip(support_emb, support_span, support_span_type): | |
| # emb: [seq_len, dim], span: [m, 2], type: [m] | |
| span = torch.Tensor(span).long().cuda() # e.g. [[3, 4], [9, 11]] | |
| type = torch.Tensor(type).long().cuda() # e.g. [1, 4] | |
| # 获取当前句子中属于tag类的span | |
| try: | |
| tag_span = span[type == tag] # e.g. span==[[3, 4]], tag==1 | |
| # 遍历每个检索到的span,获得其span embedding | |
| for (s, e) in tag_span: | |
| # tag_emb = torch.cat([emb[s], emb[e - 1]]) # [2*dim] | |
| tag_emb = emb[s] + emb[e] # [dim] | |
| # if use_tag: # 添加是否为unlabeled的标记,0对应embedding表示当前的span是labeled span,否则为unlabeled span | |
| # tag_emb = tag_emb + tag_embeddings | |
| tag_prototype.append(tag_emb) | |
| all_span_embs.append(tag_emb) | |
| all_span_tags.append(tag) | |
| except: | |
| # 说明当前类不存在对应的span,则随机 | |
| tag_prototype.append(torch.randn(support_emb.shape[-1]).cuda()) | |
| # assert 1 > 2 | |
| try: | |
| prototype.append(torch.mean(torch.stack(tag_prototype), dim=0)) | |
| except: | |
| # print("the class {} has no span".format(tag)) | |
| prototype.append(torch.randn(support_emb.shape[-1]).cuda()) | |
| # assert 1 > 2 | |
| all_span_embs = torch.stack(all_span_embs).detach().cpu().numpy().tolist() | |
| return torch.stack(prototype), all_span_embs, all_span_tags # [num_class + 1, dim] | |
| def __batch_dist__(self, prototype: torch, query_emb: torch, query_spans: list, query_span_type: Union[list, None]): | |
| """ | |
| 该函数用于获得query到各个prototype的分类 | |
| """ | |
| # 首先获得当前episode的每个句子的每个span的表征向量 | |
| # 遍历每个句子 | |
| all_logits = list() # 保存每个episode,每个句子所有span的预测概率 | |
| all_types = list() | |
| visual_all_types, visual_all_embs = list(), list() # 用于展示可视化 | |
| # num = 0 | |
| for emb, span in zip(query_emb, query_spans): # 遍历每个句子 | |
| # assert len(span) == len(query_span_type[num]), "span={}\ntype{}".format(span, query_span_type[num]) | |
| # print("len(span)={}, len(type)= {}".format(len(span), len(query_span_type[num]))) | |
| span_emb = list() # 保存当前句子所有span的embedding [m", dim] | |
| try: | |
| for (s, e) in span: # 遍历每个span | |
| tag_emb = emb[s] + emb[e] # [dim] | |
| span_emb.append(tag_emb) | |
| except: | |
| span_emb = [] | |
| if len(span_emb) != 0: | |
| span_emb = torch.stack(span_emb) # [span_num, dim] | |
| # 每个span与prototype计算距离 | |
| logits = self.__dist__(prototype.unsqueeze(0), span_emb.unsqueeze(1), 2) # [span_num, num_class] | |
| # pred_types = torch.argmax(logits, -1).detach().cpu().numpy().tolist() | |
| with torch.no_grad(): | |
| pred_dist, pred_types = torch.max(logits, -1) # 获得每个query与所有prototype的距离的最近的类及其距离的平方 | |
| pred_dist = torch.pow(-1 * pred_dist, 0.5) | |
| # print("pred_dist=", pred_dist) | |
| # 如果最近的距离超过了margin distant,则该span视为unlabeled span,标注为特殊的类 | |
| pred_types[pred_dist > self.margin_distance] = self.num_class | |
| pred_types = pred_types.detach().cpu().numpy().tolist() | |
| # # 获得概率分布 | |
| # with torch.no_grad(): | |
| # prob = torch.softmax(logits, -1) | |
| # pred_proba, pred_types = torch.max(logits, -1) # 获得每个span预测概率最大的类及其概率 | |
| # pred_types[pred_proba <= 0.6] = self.num_class # 如果当前预测的最大概率不满足,则说明其可能是一个其他实体 | |
| # pred_types = pred_types.detach().cpu().numpy().tolist() | |
| all_logits.append(logits) | |
| all_types.append(pred_types) | |
| visual_all_types.extend(pred_types) | |
| visual_all_embs.extend(span_emb.detach().cpu().numpy().tolist()) | |
| else: | |
| all_logits.append([]) | |
| all_types.append([]) | |
| # num += 1 | |
| if query_span_type is not None: | |
| # query_span_type: [n", m] | |
| try: | |
| all_type = torch.Tensor([type for types in query_span_type for type in types]).long().cuda() # [span_num] | |
| loss = nn.CrossEntropyLoss()(torch.cat(all_logits, 0), all_type) | |
| except: | |
| all_logit, all_type = list(), list() | |
| for logits, types in zip(all_logits, query_span_type): | |
| if len(logits) != 0 and len(types) != 0 and len(logits) == len(types): | |
| # print("len(logits)=", len(logits)) | |
| # print("len(types)=", len(types)) | |
| # print("logits=", logits) | |
| all_logit.append(logits) | |
| all_type.extend(types) | |
| # print("all_logit=", all_logit) | |
| if len(all_logit) != 0: | |
| all_logit = torch.cat(all_logit, 0) | |
| all_type = torch.Tensor(all_type).long().cuda() | |
| # print("len(all_logits)=", len(all_logits)) | |
| # print("len(query_span_type)=", len(query_span_type)) | |
| # print("types.shape=", torch.Tensor(all_type).shape) | |
| # min_len = min(len(all_type), len(all_type)) | |
| # all_logit, all_type = all_logit[: min_len], all_type[: min_len] | |
| # print("logits.shape=", all_logit.shape) | |
| # print("all_type=", all_type) | |
| loss = nn.CrossEntropyLoss()(all_logit, all_type) | |
| else: | |
| loss = 0. | |
| else: | |
| loss = None | |
| all_logits = [i.detach().cpu().numpy().tolist() for i in all_logits if len(i) != 0] | |
| return loss, all_logits, all_types, visual_all_types, visual_all_embs | |
| def __batch_margin__(self, prototype: torch, query_emb: torch, query_unlabeled_spans: list, | |
| query_labeled_spans: list, query_span_type: list): | |
| """ | |
| 该函数用于拉开unlabeled span与各个prototype的距离,拉近labeled span到对应类别的距离 | |
| """ | |
| # prototype: [num_class, dim], negative: [span_num, dim] | |
| # 获得每个unlabeled span与每个prototype的距离的平方,目标是对于每个距离平方都要设置大于margin阈值 | |
| def distance(input1, input2, p=2, eps=1e-6): | |
| # Compute the distance (p-norm) | |
| norm = torch.pow(torch.abs((input1 - input2 + eps)), p) | |
| pnorm = torch.pow(torch.sum(norm, -1), 1.0 / p) | |
| return pnorm | |
| unlabeled_span_emb, labeled_span_emb, labeled_span_type = list(), list(), list() | |
| for emb, span in zip(query_emb, query_unlabeled_spans): # 遍历每个句子 | |
| # 保存当前句子所有span的embedding [m", dim] | |
| for (s, e) in span: # 遍历每个span | |
| tag_emb = emb[s] + emb[e] # [dim] | |
| unlabeled_span_emb.append(tag_emb) | |
| # for emb, span, type in zip(query_emb, query_labeled_spans, query_span_type): # 遍历每个句子 | |
| # # 保存当前句子所有span的embedding [m", dim] | |
| # for (s, e) in span: # 遍历每个span | |
| # tag_emb = emb[s] + emb[e] # [dim] | |
| # labeled_span_emb.append(tag_emb) | |
| # labeled_span_type.extend(type) | |
| try: | |
| unlabeled_span_emb = torch.stack(unlabeled_span_emb) # [span_num, dim] | |
| # labeled_span_emb = torch.stack(labeled_span_emb) # [span_num, dim] | |
| # labeled_span_type = torch.stack(labeled_span_type) # [span_num] | |
| except: | |
| return 0. | |
| unlabeled_dist = distance(prototype.unsqueeze(0), unlabeled_span_emb.unsqueeze(1)) # [span_num, num_class] | |
| # labeled_dist = distance(prototype.unsqueeze(0), labeled_span_emb.unsqueeze(1)) # [span_num, num_class] | |
| # 获得每个span对应ground truth类别距离prototype的距离 | |
| # labeled_type_dist = torch.gather(labeled_dist, -1, labeled_span_type.unsqueeze(1)) # [span_num, 1] | |
| # print(dist) | |
| unlabeled_output = torch.maximum(torch.zeros_like(unlabeled_dist), self.margin_distance - unlabeled_dist) | |
| # labeled_output = torch.maximum(torch.zeros_like(labeled_type_dist), labeled_type_dist) | |
| # return torch.mean(unlabeled_output) + torch.mean(labeled_output) | |
| return torch.mean(unlabeled_output) | |
| def forward( | |
| self, | |
| episode_ids, | |
| support, query, | |
| num_class, | |
| num_example, | |
| mode=None, | |
| short_labels=None, | |
| stage:str ="train", | |
| path: str=None | |
| ): | |
| """ | |
| episode_ids: Input of the idx of each episode data. (only list) | |
| support: Inputs of the support set. | |
| query: Inputs of the query set. | |
| num_class: Num of classes | |
| K: Num of instances for each class in the support set | |
| Q: Num of instances for each class in the query set | |
| return: logits, pred | |
| """ | |
| if stage.startswith("train"): | |
| self.global_step += 1 | |
| self.num_class = num_class # N-way K-shot里的N | |
| self.num_example = num_example # N-way K-shot里的K | |
| # print("num_class=", num_class) | |
| self.mode = mode # FewNERD mode=inter/intra | |
| self.max_length = support["input_ids"].shape[1] | |
| support_inputs, support_attention_masks, support_type_ids = \ | |
| support["input_ids"], support["attention_mask"], support["token_type_ids"] # torch, [n, seq_len] | |
| query_inputs, query_attention_masks, query_type_ids = \ | |
| query["input_ids"], query["attention_mask"], query["token_type_ids"] # torch, [n, seq_len] | |
| support_labels = support["labels"] # torch, | |
| query_labels = query["labels"] # torch, | |
| # global span detector: obtain all mention span and loss | |
| support_detector_outputs = self.global_span_detector( | |
| support_inputs, support_attention_masks, support_type_ids, support_labels, short_labels=short_labels | |
| ) | |
| query_detector_outputs = self.global_span_detector( | |
| query_inputs, query_attention_masks, query_type_ids, query_labels, short_labels=short_labels | |
| ) | |
| device_id = support_inputs.device.index | |
| # if stage == "train_span": | |
| if self.global_step <= 500 and stage == "train": | |
| # only train span detector | |
| return SpanProtoOutput( | |
| loss=support_detector_outputs.loss, | |
| topk_probs=query_detector_outputs.topk_probs, | |
| topk_indices=query_detector_outputs.topk_indices, | |
| ) | |
| # obtain labeled span from the support set | |
| support_labeled_spans = support["labeled_spans"] # all labeled span, list, [n, m, 2], n sentence, m entity span, 2 (start / end) | |
| support_labeled_types = support["labeled_types"] # all labeled ent type id, list, [n, m], | |
| query_labeled_spans = query["labeled_spans"] # all labeled span, list, [n, m, 2], n sentence, m entity span, 2 (start / end) | |
| query_labeled_types = query["labeled_types"] # all labeled ent type id, list, [n, m], | |
| # for span, type in zip(query_labeled_spans, query_labeled_types): # 遍历每个句子 | |
| # assert len(span) == len(type), "span={}\ntype{}".format(span, type) | |
| # obtain unlabeled span from the support set | |
| # according to the detector, we can obtain multiple unlabeled span, which generated by the detector | |
| # but not labeled in n-way k-shot episode | |
| # support_predict_spans = self.get_topk_spans( # | |
| # support_detector_outputs.topk_probs, | |
| # support_detector_outputs.topk_indices, | |
| # support["input_ids"] | |
| # ) # [n, m, 2] | |
| # print("predicted support span num={}".format([len(i) for i in support_predict_spans])) | |
| # e.g. 打印一个所有句子,每个元素表示每个句子中的span个数,[5, 50, 4, 43, 5, 5, 1, 50, 2, 5, 6, 4, 50, 8, 12, 28, 17] | |
| # we can also obtain all predicted span from the query set | |
| query_predict_spans = self.get_topk_spans( # | |
| query_detector_outputs.topk_probs, | |
| query_detector_outputs.topk_indices, | |
| query["input_ids"], | |
| threshold=0.9 if stage.startswith("train") else 0.95, | |
| is_query=True | |
| ) # [n, m, 2] | |
| # print("predicted query span num={}".format([len(i) for i in query_predict_spans])) | |
| # merge predicted span and labeled span, and generate other class for unlabeled span set | |
| # support_all_spans, support_span_types = self.merge_span( | |
| # labeled_spans=support_labeled_spans, | |
| # labeled_types=support_labeled_types, | |
| # predict_spans=support_predict_spans, | |
| # stage=stage | |
| # ) # [n, m, 2] n 个句子,每个句子有若干个span | |
| # print("merged support span num={}".format([len(i) for i in support_all_spans])) | |
| if stage.startswith("train"): | |
| # 在训练阶段,需要知道detector识别的所有区间中,哪些是labeled,哪些是unlabeled,将unlabeled span全部分离出来 | |
| query_unlabeled_spans = self.split_span( # 拆分出unlabeled span,用于后面的margin loss | |
| labeled_spans=query_labeled_spans, | |
| labeled_types=query_labeled_types, | |
| predict_spans=query_predict_spans, | |
| stage=stage | |
| ) # [n, m, 2] n 个句子,每个句子有若干个span | |
| # print("merged query span num={}".format([len(i) for i in query_all_spans])) | |
| query_all_spans = query_labeled_spans | |
| query_span_types = query_labeled_types | |
| else: | |
| # 在推理阶段,直接全部merge | |
| query_unlabeled_spans = None | |
| query_all_spans, _ = self.merge_span( | |
| labeled_spans=query_labeled_spans, | |
| labeled_types=query_labeled_types, | |
| predict_spans=query_predict_spans, | |
| stage=stage | |
| ) # [n, m, 2] n 个句子,每个句子有若干个span | |
| # 在dev和test时,此时query部分的span完全靠detector识别 | |
| # query_all_spans = query_predict_spans | |
| query_span_types = None | |
| # 用于查看推理阶段dev或test的query上detector的预测结果 | |
| # for query_label, query_pred in zip(query_labeled_spans, query_predict_spans): | |
| # print(" ==== ") | |
| # print("query_labeled_spans=", query_label) | |
| # print("query_predict_spans=", query_pred) | |
| # obtain representations of each token | |
| support_emb, query_emb = support_detector_outputs.last_hidden_state, \ | |
| query_detector_outputs.last_hidden_state # [n, seq_len, dim] | |
| support_emb, query_emb = self.projector(support_emb), self.projector(query_emb) # [n, seq_len, dim] | |
| # all_query_spans = list() # 保存每个episode的所有句子所有的预测span | |
| # all_proto_logits = list() # 保存每个episode的所有句子每个预测span对应的entity type | |
| batch_result = dict() | |
| proto_losses = list() # 保存每个episode的loss | |
| # batch_visual = list() # 保存每个episode所有span的表征向量,用于可视化 | |
| current_support_num = 0 | |
| current_query_num = 0 | |
| typing_loss = None | |
| # 遍历每个episode | |
| for i, sent_support_num in enumerate(support["sentence_num"]): | |
| sent_query_num = query["sentence_num"][i] | |
| id_ = episode_ids[i] # 当前episode的编号 | |
| # 对于support,只对labeled span获得prototype | |
| # locate one episode and obtain the span prototype | |
| # [n", seq_len, dim] n" sentence in one episode | |
| # support_proto [num_class + 1, dim] | |
| support_proto, all_span_embs, all_span_tags = self.__get_proto__( | |
| support_emb[current_support_num: current_support_num + sent_support_num], # [n", seq_len, dim] | |
| support_labeled_spans[current_support_num: current_support_num + sent_support_num], # [n", m] | |
| support_labeled_types[current_support_num: current_support_num + sent_support_num], # [n", m] | |
| ) | |
| # 对于query set每个labeled span,使用标准的prototype learning | |
| # for each query, we first obtain corresponding span, and then calculate distance between it and each prototype | |
| # # [n", seq_len, dim] n" sentence in one episode | |
| proto_loss, proto_logits, all_types, visual_all_types, visual_all_embs = self.__batch_dist__( | |
| support_proto, | |
| query_emb[current_query_num: current_query_num + sent_query_num], # [n", seq_len, dim] | |
| query_all_spans[current_query_num: current_query_num + sent_query_num], # [n", m] | |
| query_span_types[current_query_num: current_query_num + sent_query_num] if query_span_types else None, # [n", m] | |
| ) | |
| visual_data = { | |
| "data": all_span_embs + visual_all_embs, | |
| "target": all_span_tags + visual_all_types, | |
| } | |
| # 对于query unlabeled span,遍历每个span,拉开与所有prototype的距离,选择margin loss | |
| if stage.startswith("train"): | |
| margin_loss = self.__batch_margin__( | |
| support_proto, | |
| query_emb[current_query_num: current_query_num + sent_query_num], # [n", seq_len, dim] | |
| query_unlabeled_spans[current_query_num: current_query_num + sent_query_num], # [n", span_num] | |
| query_all_spans[current_query_num: current_query_num + sent_query_num], | |
| query_span_types[current_query_num: current_query_num + sent_query_num], | |
| ) | |
| proto_losses.append(proto_loss + margin_loss) | |
| batch_result[id_] = { | |
| "spans": query_all_spans[current_query_num: current_query_num + sent_query_num], | |
| "types": all_types, | |
| "visualization": visual_data | |
| } | |
| current_query_num += sent_query_num | |
| current_support_num += sent_support_num | |
| # proto_logits = torch.stack(proto_logits) | |
| if stage.startswith("train"): | |
| typing_loss = torch.mean(torch.stack(proto_losses), dim=-1) | |
| if not stage.startswith("train"): | |
| self.__save_evaluate_predicted_result__(batch_result, device_id=device_id, stage=stage, path=path) | |
| # return SpanProtoOutput( | |
| # loss=((support_detector_outputs.loss + query_detector_outputs.loss) / 2.0 + typing_loss) | |
| # if stage.startswith("train") else (support_detector_outputs.loss + query_detector_outputs.loss), | |
| # ) # 返回部分的所有logits不论最外层是list还是tuple,最里层一定要包含一个张量,否则huggingface里的nested_detach函数会报错 | |
| return SpanProtoOutput( | |
| loss=(support_detector_outputs.loss + typing_loss) | |
| if stage.startswith("train") else query_detector_outputs.loss, | |
| ) # 返回部分的所有logits不论最外层是list还是tuple,最里层一定要包含一个张量,否则huggingface里的nested_detach函数会报错 | |
| def __save_evaluate_predicted_result__(self, new_result: dict, device_id: int = 0, stage="dev", path=None): | |
| """ | |
| 本函数用于在forward时保存每一个batch内的预测span以及span type | |
| new_result / result: { | |
| "(id)": { # id-th episode query | |
| "spans": [[[1, 4], [6, 7], xxx], ... ] # [sent_num, span_num, 2] | |
| "types": [[2, 0, xxx], ...] # [sent_num, span_num] | |
| }, | |
| xxx | |
| } | |
| """ | |
| # 拉取当前任务中已经预测的结果 | |
| self.predict_dir = self.predict_result_path(path) | |
| npy_file_name = os.path.join(self.predict_dir, "{}_predictions_{}.npy".format(stage, device_id)) | |
| result = dict() | |
| if os.path.exists(npy_file_name): | |
| result = np.load(npy_file_name, allow_pickle=True)[()] | |
| # 合并 | |
| for episode_id, query_res in new_result.items(): | |
| result[episode_id] = query_res | |
| # 保存 | |
| np.save(npy_file_name, result, allow_pickle=True) | |
| def get_topk_spans(self, probs, indices, input_ids, threshold=0.60, low_threshold=0.1, is_query=False): | |
| """ | |
| probs: [n, m] | |
| indices: [n, m] | |
| input_texts: [n, seq_len] | |
| is_query: if true, each sentence must recall at least one span | |
| """ | |
| probs = probs.squeeze(1).detach().cpu() # topk结果的概率 [n, m] # 返回的已经是按照概率进行降序排列的结果 | |
| indices = indices.squeeze(1).detach().cpu() # topk结果的索引 [n, m] # 返回的已经是按照概率进行降序排列的结果 | |
| input_ids = input_ids.detach().cpu() | |
| # print("probs=", probs) # [n, m] | |
| # print("indices=", indices) # [n, m] | |
| predict_span = list() | |
| if is_query: | |
| low_threshold = 0.0 | |
| for prob, index, text in zip(probs, indices, input_ids): # 遍历每个句子,其对应若干预测的span及其概率 | |
| threshold_ = threshold | |
| index_ids = torch.Tensor([i for i in range(len(index))]).long() | |
| span = set() | |
| # TODO 1. 调节阈值 2. 处理输出实体重叠问题 | |
| entity_index = index[prob >= low_threshold] | |
| index_ids = index_ids[prob >= low_threshold] | |
| while threshold_ >= low_threshold: # 动态控制阈值,以确保可以召回出span数量是尽可能均匀的(如果所有句子使用同一个阈值,那么每个句子被召回的span数量参差不齐) | |
| for ei, entity in enumerate(entity_index): | |
| p = prob[index_ids[ei]] | |
| if p < threshold_: # 如果此时候选的span得分已经低于阈值,由于获得的结果已经是降序排列的,则后续的结果一定都低于阈值,则直接结束 | |
| break | |
| # 1D index转2D index | |
| start_end = np.unravel_index(entity, (self.max_length, self.max_length)) | |
| # print("self.max_length=", self.max_length) | |
| s, e = start_end[0], start_end[1] | |
| ans = text[s: e] | |
| # if ans not in answer: | |
| # answer.append(ans) | |
| # topk_answer_dict[ans] = {"prob": float(prob[index_ids[ei]]), "pos": [(s, e)]} | |
| span.add((s, e)) | |
| # 满足下列几个条件的,动态调低阈值,并重新筛选 | |
| if len(span) <= 3: | |
| threshold_ -= 0.05 | |
| else: | |
| break | |
| if len(span) == 0: | |
| # 如果当前没有召回出任何span,则直接选择[cls]作为结果(相当于MRC的unanswerable) | |
| span = [[0, 0]] | |
| span = [list(i) for i in list(span)] | |
| # print("prob=", prob) e.g. [0.96, 0.85, 0.04, 0.00, ...] | |
| # print("span=", span) e.g. [[20, 23], [11, 14]] | |
| predict_span.append(span) | |
| return predict_span | |
| def split_span(self, labeled_spans: list, labeled_types: list, predict_spans: list, stage: str = "train"): | |
| """ | |
| # 对detector预测的所有span,划分出哪些是labeled span,哪些是unlabeled span | |
| """ | |
| def check_similar_span(span1, span2): | |
| """ | |
| 检测两个span是否接近,例如[12, 16], [11, 16], [13, 15], [12, 17]是接近的 | |
| """ | |
| # 考虑一个特殊情况,例如 [12, 12], [13, 13] | |
| if len(span1) == 0 or len(span2) == 0: | |
| return False | |
| if span1[0] == span1[1] and span2[0] == span2[1] and abs(span1[0] - span2[0]) == 1: | |
| return False | |
| if abs(span1[0] - span2[0]) <= 1 and abs(span1[1] - span2[1]) <= 1: # 两个区间的起点和终点分别相差1以内 | |
| return True | |
| return False | |
| all_spans, span_types = list(), list() # [n, m] | |
| num = 0 | |
| unlabeled_spans = list() | |
| for labeled_span, labeled_type, predict_span in zip(labeled_spans, labeled_types, predict_spans): | |
| # 对detector预测的所有span,划分出哪些是labeled span,哪些是unlabeled span | |
| unlabeled_span = list() | |
| # if len(all_span) != len(span_type): | |
| # length = min(len(all_span), len(span_type)) | |
| # all_span, span_type = all_span[: length], span_type[: length] | |
| for span in predict_span: # 遍历每个预测的span | |
| if span not in labeled_span: # 如果span没有存在,则说明当前的span是unlabeled的 | |
| # 可能存在一些临界点非常接近的(global pointer预测的临界点有时候很模糊),对于临界点相近的予以排除 | |
| is_remove = False | |
| for span_x in labeled_span: # 遍历所有已经被merge的span | |
| is_remove = check_similar_span(span_x, span) # 如果已存在的span,和当前的span很接近,则排除当前的span | |
| if is_remove is True: | |
| break | |
| if is_remove is True: | |
| continue | |
| unlabeled_span.append(span) | |
| # if self.global_step % 1000 == 0: | |
| # print(" === ") | |
| # print("labeled_span=", labeled_span) # [[1, 3], [12, 14], [25, 25], [7, 7]] | |
| # print("predict_span=", predict_span) # [[25, 25], [1, 3], [12, 14], [7, 7]] | |
| # if len(unlabeled_span) == 0 and stage.startswith("train"): | |
| # # 如果当前句子没有一个unlabeled span,则需要进行负采样,以确保unlabeled不为空 | |
| # # print("unlabeled span is empty, so we randomly select one span as the unlabeled span") | |
| # # all_span.append([0, 0]) | |
| # # span_type.append(self.num_class) | |
| # while True: | |
| # random_span = np.random.randint(0, 32, 2).tolist() | |
| # if abs(random_span[0] - random_span[1]) > 10: | |
| # continue | |
| # random_span = [random_span[1], random_span[0]] if random_span[0] > random_span[1] else random_span | |
| # if random_span in labeled_span or random_span in unlabeled_span: | |
| # continue | |
| # unlabeled_span.append(random_span) | |
| # break | |
| num += len(unlabeled_span) | |
| unlabeled_spans.append(unlabeled_span) | |
| # print("num=", num) | |
| return unlabeled_spans | |
| def merge_span(self, labeled_spans: list, labeled_types: list, predict_spans: list, stage: str = "train"): | |
| def check_similar_span(span1, span2): | |
| """ | |
| 检测两个span是否接近,例如[12, 16], [11, 16], [13, 15], [12, 17]是接近的 | |
| """ | |
| # 考虑一个特殊情况,例如 [12, 12], [13, 13] | |
| if len(span1) == 0 or len(span2) == 0: | |
| return False | |
| if span1[0] == span1[1] and span2[0] == span2[1] and abs(span1[0] - span2[0]) == 1: | |
| return False | |
| if abs(span1[0] - span2[0]) <= 1 and abs(span1[1] - span2[1]) <= 1: # 两个区间的起点和终点分别相差1以内 | |
| return True | |
| return False | |
| all_spans, span_types = list(), list() # [n, m] | |
| for labeled_span, labeled_type, predict_span in zip(labeled_spans, labeled_types, predict_spans): | |
| # 遍历每个句子,对它们的span进行合并 | |
| unlabeled_num = 0 | |
| all_span, span_type = labeled_span, labeled_type # 先加入所有labeled span | |
| if len(all_span) != len(span_type): | |
| length = min(len(all_span), len(span_type)) | |
| all_span, span_type = all_span[: length], span_type[: length] | |
| for span in predict_span: # 遍历每个预测的span | |
| if span not in all_span: # 如果span没有存在,则说明当前的span是unlabeled的 | |
| # 可能存在一些临界点非常接近的(global pointer预测的临界点有时候很模糊),对于临界点相近的予以排除 | |
| is_remove = False | |
| for span_x in all_span: # 遍历所有已经被merge的span | |
| is_remove = check_similar_span(span_x, span) # 如果已存在的span,和当前的span很接近,则排除当前的span | |
| if is_remove is True: | |
| break | |
| if is_remove is True: | |
| continue | |
| all_span.append(span) | |
| span_type.append(self.num_class) # e.g. 5-way问题,已有标签为0,1,2,3,4,因此新增一个标签为5 | |
| unlabeled_num += 1 | |
| # if self.global_step % 1000 == 0: | |
| # print(" === ") | |
| # print("labeled_span=", labeled_span) # [[1, 3], [12, 14], [25, 25], [7, 7]] | |
| # print("predict_span=", predict_span) # [[25, 25], [1, 3], [12, 14], [7, 7]] | |
| if unlabeled_num == 0 and stage.startswith("train"): | |
| # 如果当前句子没有一个unlabeled span,则需要进行负采样,以确保unlabeled不为空 | |
| # print("unlabeled span is empty, so we randomly select one span as the unlabeled span") | |
| # all_span.append([0, 0]) | |
| # span_type.append(self.num_class) | |
| while True: | |
| random_span = np.random.randint(0, 32, 2).tolist() | |
| if abs(random_span[0] - random_span[1]) > 10: | |
| continue | |
| random_span = [random_span[1], random_span[0]] if random_span[0] > random_span[1] else random_span | |
| if random_span in all_span: | |
| continue | |
| all_span.append(random_span) | |
| span_type.append(self.num_class) | |
| break | |
| # if len(all_span) != len(span_type): | |
| # all_span = [[0, 0]] | |
| # span_type = [self.num_class] | |
| all_spans.append(all_span) | |
| span_types.append(span_type) | |
| return all_spans, span_types | |