# -*- coding: utf-8 -*-
# @Time    : 2022/4/21 5:30 PM
# @Author  : JianingWang
# @File    : global_pointer.py
from typing import Optional

import torch
import torch.nn as nn
from dataclasses import dataclass
from transformers import MegatronBertModel, MegatronBertPreTrainedModel
from transformers.file_utils import ModelOutput
from transformers.models.bert import BertPreTrainedModel, BertModel
from transformers.models.roberta.modeling_roberta import RobertaModel, RobertaPreTrainedModel
from roformer import RoFormerPreTrainedModel, RoFormerModel


class RawGlobalPointer(nn.Module):
    def __init__(self, encoder, ent_type_size, inner_dim, RoPE=True):
        # encoder: a pretrained encoder, e.g. RoBERTa-Large
        # inner_dim: 64
        # ent_type_size: number of entity types (ent_cls_num)
        super().__init__()
        self.encoder = encoder
        self.ent_type_size = ent_type_size
        self.inner_dim = inner_dim
        self.hidden_size = encoder.config.hidden_size
        self.dense = nn.Linear(self.hidden_size, self.ent_type_size * self.inner_dim * 2)
        self.RoPE = RoPE

    def sinusoidal_position_embedding(self, batch_size, seq_len, output_dim):
        # self.device is set in forward() before this method is called
        position_ids = torch.arange(0, seq_len, dtype=torch.float).unsqueeze(-1)
        indices = torch.arange(0, output_dim // 2, dtype=torch.float)
        indices = torch.pow(10000, -2 * indices / output_dim)
        embeddings = position_ids * indices
        embeddings = torch.stack([torch.sin(embeddings), torch.cos(embeddings)], dim=-1)
        embeddings = embeddings.repeat((batch_size, *([1] * len(embeddings.shape))))
        embeddings = torch.reshape(embeddings, (batch_size, seq_len, output_dim))
        embeddings = embeddings.to(self.device)
        return embeddings

    def forward(self, input_ids, attention_mask, token_type_ids):
        self.device = input_ids.device
        context_outputs = self.encoder(input_ids, attention_mask, token_type_ids)
        # last_hidden_state: (batch_size, seq_len, hidden_size)
        last_hidden_state = context_outputs[0]
        batch_size = last_hidden_state.size()[0]
        seq_len = last_hidden_state.size()[1]
        outputs = self.dense(last_hidden_state)
        outputs = torch.split(outputs, self.inner_dim * 2, dim=-1)
        outputs = torch.stack(outputs, dim=-2)
        qw, kw = outputs[..., :self.inner_dim], outputs[..., self.inner_dim:]
        if self.RoPE:
            # pos_emb: (batch_size, seq_len, inner_dim)
            pos_emb = self.sinusoidal_position_embedding(batch_size, seq_len, self.inner_dim)
            cos_pos = pos_emb[..., None, 1::2].repeat_interleave(2, dim=-1)
            sin_pos = pos_emb[..., None, ::2].repeat_interleave(2, dim=-1)
            qw2 = torch.stack([-qw[..., 1::2], qw[..., ::2]], -1)
            qw2 = qw2.reshape(qw.shape)
            qw = qw * cos_pos + qw2 * sin_pos
            kw2 = torch.stack([-kw[..., 1::2], kw[..., ::2]], -1)
            kw2 = kw2.reshape(kw.shape)
            kw = kw * cos_pos + kw2 * sin_pos
        # logits: (batch_size, ent_type_size, seq_len, seq_len)
        logits = torch.einsum("bmhd,bnhd->bhmn", qw, kw)
        # padding mask
        pad_mask = attention_mask.unsqueeze(1).unsqueeze(1).expand(batch_size, self.ent_type_size, seq_len, seq_len)
        logits = logits * pad_mask - (1 - pad_mask) * 1e12
        # mask out the lower triangle so only spans with start <= end remain
        mask = torch.tril(torch.ones_like(logits), -1)
        logits = logits - mask * 1e12
        return logits / self.inner_dim ** 0.5
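

# Illustrative usage sketch for RawGlobalPointer (not part of the original
# module): the checkpoint name and hyper-parameters below are assumptions,
# chosen only to show the expected input and output shapes.
def _example_raw_global_pointer():
    from transformers import BertModel, BertTokenizerFast

    tokenizer = BertTokenizerFast.from_pretrained("bert-base-chinese")  # assumed checkpoint
    encoder = BertModel.from_pretrained("bert-base-chinese")
    model = RawGlobalPointer(encoder, ent_type_size=9, inner_dim=64)

    batch = tokenizer(["An example sentence."], return_tensors="pt")
    logits = model(batch["input_ids"], batch["attention_mask"], batch["token_type_ids"])
    # logits: (batch_size, ent_type_size, seq_len, seq_len); entry [b, t, i, j]
    # scores the token span (i, j) as an entity of type t (lower triangle masked).
    return logits.shape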


class SinusoidalPositionEmbedding(nn.Module):
    """Sinusoidal (sin-cos) position embedding."""

    def __init__(self, output_dim, merge_mode="add", custom_position_ids=False):
        super(SinusoidalPositionEmbedding, self).__init__()
        self.output_dim = output_dim
        self.merge_mode = merge_mode
        self.custom_position_ids = custom_position_ids

    def forward(self, inputs):
        if self.custom_position_ids:
            # in this mode, inputs is a pair (inputs, position_ids)
            inputs, position_ids = inputs
            seq_len = inputs.shape[1]
            position_ids = position_ids.type(torch.float)
        else:
            input_shape = inputs.shape
            batch_size, seq_len = input_shape[0], input_shape[1]
            position_ids = torch.arange(seq_len).type(torch.float)[None]
        indices = torch.arange(self.output_dim // 2).type(torch.float)
        indices = torch.pow(10000.0, -2 * indices / self.output_dim)
        embeddings = torch.einsum("bn,d->bnd", position_ids, indices)
        embeddings = torch.stack([torch.sin(embeddings), torch.cos(embeddings)], dim=-1)
        embeddings = torch.reshape(embeddings, (-1, seq_len, self.output_dim))
        if self.merge_mode == "add":
            return inputs + embeddings.to(inputs.device)
        elif self.merge_mode == "mul":
            return inputs * (embeddings + 1.0).to(inputs.device)
        elif self.merge_mode == "zero":
            return embeddings.to(inputs.device)
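

# Small sketch (not from the original file) of the "zero" merge mode used for
# RoPE below: the layer ignores the input values and only returns sin/cos
# position codes that broadcast over the batch dimension.
def _example_sinusoidal_position_embedding():
    x = torch.zeros(2, 8, 128)                        # (batch, seq_len, hidden)
    pos = SinusoidalPositionEmbedding(64, "zero")(x)  # (1, seq_len, 64), broadcastable over batch
    # pos[..., 0::2] holds the sin terms and pos[..., 1::2] the cos terms
    return pos.shape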


def multilabel_categorical_crossentropy(y_pred, y_true):
    """Multilabel categorical cross-entropy: y_true holds 0/1 targets and
    y_pred holds raw logits (no sigmoid applied)."""
    y_pred = (1 - 2 * y_true) * y_pred  # positive classes get factor -1, negative classes factor 1
    y_pred_neg = y_pred - y_true * 1e12  # mask the predictions of positive classes
    y_pred_pos = y_pred - (1 - y_true) * 1e12  # mask the predictions of negative classes
    zeros = torch.zeros_like(y_pred[..., :1])
    y_pred_neg = torch.cat([y_pred_neg, zeros], dim=-1)
    y_pred_pos = torch.cat([y_pred_pos, zeros], dim=-1)
    neg_loss = torch.logsumexp(y_pred_neg, dim=-1)
    pos_loss = torch.logsumexp(y_pred_pos, dim=-1)
    return (neg_loss + pos_loss).mean()
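

# Tiny sanity check with made-up logits (not from a real dataset): class 0 is
# the only positive label, so a confident correct prediction should give a loss
# near zero and a flipped prediction a large one.
def _example_multilabel_ce():
    y_true = torch.tensor([[1.0, 0.0]])
    good = multilabel_categorical_crossentropy(torch.tensor([[6.0, -6.0]]), y_true)
    bad = multilabel_categorical_crossentropy(torch.tensor([[-6.0, 6.0]]), y_true)
    return good.item(), bad.item()  # roughly 0.005 vs. 12.0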


def multilabel_categorical_crossentropy2(y_pred, y_true):
    """Variant of the loss above that masks with -inf in place instead of
    subtracting a large constant."""
    y_pred = (1 - 2 * y_true) * y_pred  # positive classes get factor -1, negative classes factor 1
    y_pred_neg = y_pred.clone()
    y_pred_pos = y_pred.clone()
    y_pred_neg[y_true > 0] -= float("inf")  # mask the predictions of positive classes
    y_pred_pos[y_true < 1] -= float("inf")  # mask the predictions of negative classes
    zeros = torch.zeros_like(y_pred[..., :1])
    y_pred_neg = torch.cat([y_pred_neg, zeros], dim=-1)
    y_pred_pos = torch.cat([y_pred_pos, zeros], dim=-1)
    neg_loss = torch.logsumexp(y_pred_neg, dim=-1)
    pos_loss = torch.logsumexp(y_pred_pos, dim=-1)
    return (neg_loss + pos_loss).mean()
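

# Quick equivalence sketch (random inputs, illustrative only): on finite logits
# the in-place -inf variant should agree with the 1e12 variant above up to
# floating-point error.
def _example_compare_ce_variants():
    y_pred = torch.randn(4, 16)
    y_true = (torch.rand(4, 16) > 0.8).float()
    a = multilabel_categorical_crossentropy(y_pred, y_true)
    b = multilabel_categorical_crossentropy2(y_pred, y_true)
    return torch.allclose(a, b, atol=1e-4)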


@dataclass
class GlobalPointerOutput(ModelOutput):
    loss: Optional[torch.FloatTensor] = None
    topk_probs: torch.FloatTensor = None
    topk_indices: torch.IntTensor = None


class BertForEffiGlobalPointer(BertPreTrainedModel):
    def __init__(self, config):
        # encoder: BERT / RoBERTa-style encoder
        # inner_dim: 64
        # ent_type_size: number of entity types (ent_cls_num)
        super().__init__(config)
        self.bert = BertModel(config)
        self.ent_type_size = config.ent_type_size
        self.inner_dim = config.inner_dim
        self.hidden_size = config.hidden_size
        self.RoPE = config.RoPE
        self.dense_1 = nn.Linear(self.hidden_size, self.inner_dim * 2)
        self.dense_2 = nn.Linear(self.hidden_size, self.ent_type_size * 2)  # the original GlobalPointer uses a (inner_dim * 2, ent_type_size * 2) dense_2

    def sequence_masking(self, x, mask, value="-inf", axis=None):
        if mask is None:
            return x
        else:
            if value == "-inf":
                value = -1e12
            elif value == "inf":
                value = 1e12
            assert axis > 0, "axis must be greater than 0"
            for _ in range(axis - 1):
                mask = torch.unsqueeze(mask, 1)
            for _ in range(x.ndim - mask.ndim):
                mask = torch.unsqueeze(mask, mask.ndim)
            return x * mask + value * (1 - mask)

    def add_mask_tril(self, logits, mask):
        if mask.dtype != logits.dtype:
            mask = mask.type(logits.dtype)
        logits = self.sequence_masking(logits, mask, "-inf", logits.ndim - 2)
        logits = self.sequence_masking(logits, mask, "-inf", logits.ndim - 1)
        # mask out the lower triangle so only spans with start <= end remain
        mask = torch.tril(torch.ones_like(logits), diagonal=-1)
        logits = logits - mask * 1e12
        return logits

    def forward(self, input_ids, attention_mask, token_type_ids, labels=None, short_labels=None):
        context_outputs = self.bert(input_ids, attention_mask, token_type_ids)
        last_hidden_state = context_outputs.last_hidden_state  # [bz, seq_len, hidden_dim]
        outputs = self.dense_1(last_hidden_state)  # [bz, seq_len, 2 * inner_dim]
        # split the last dimension with stride 2: even positions form qw, odd positions form kw
        qw, kw = outputs[..., ::2], outputs[..., 1::2]
        batch_size = input_ids.shape[0]
        if self.RoPE:
            # "zero" merge mode: return only the sin/cos position codes
            pos = SinusoidalPositionEmbedding(self.inner_dim, "zero")(outputs)
            cos_pos = pos[..., 1::2].repeat_interleave(2, dim=-1)  # e.g. [0.34, 0.90] -> [0.34, 0.34, 0.90, 0.90]
            sin_pos = pos[..., ::2].repeat_interleave(2, dim=-1)
            qw2 = torch.stack([-qw[..., 1::2], qw[..., ::2]], 3)
            qw2 = torch.reshape(qw2, qw.shape)
            qw = qw * cos_pos + qw2 * sin_pos
            kw2 = torch.stack([-kw[..., 1::2], kw[..., ::2]], 3)
            kw2 = torch.reshape(kw2, kw.shape)
            kw = kw * cos_pos + kw2 * sin_pos
        logits = torch.einsum("bmd,bnd->bmn", qw, kw) / self.inner_dim ** 0.5
        bias = torch.einsum("bnh->bhn", self.dense_2(last_hidden_state)) / 2
        logits = logits[:, None] + bias[:, ::2, None] + bias[:, 1::2, :, None]  # logits[:, None] adds the entity-type dimension
        loss = None
        # upper-triangular mask over valid (start <= end) token pairs inside the attention mask
        mask = torch.triu(attention_mask.unsqueeze(2) * attention_mask.unsqueeze(1))
        if labels is not None:
            y_pred = logits - (1 - mask.unsqueeze(1)) * 1e12
            y_true = labels.view(input_ids.shape[0] * self.ent_type_size, -1)
            y_pred = y_pred.view(input_ids.shape[0] * self.ent_type_size, -1)
            loss = multilabel_categorical_crossentropy(y_pred, y_true)
        with torch.no_grad():
            prob = torch.sigmoid(logits) * mask.unsqueeze(1)
            topk = torch.topk(prob.view(batch_size, self.ent_type_size, -1), 50, dim=-1)
        return GlobalPointerOutput(
            loss=loss,
            topk_probs=topk.values,
            topk_indices=topk.indices
        )
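

# Decoding sketch for GlobalPointerOutput (assumptions: batch element 0 and a
# 0.5 probability threshold). prob was flattened over the (seq_len, seq_len)
# grid, so each top-k index encodes a span as start * seq_len + end.
def _example_decode_entities(output, seq_len, threshold=0.5):
    entities = []
    probs, indices = output.topk_probs[0], output.topk_indices[0]
    for type_id in range(probs.shape[0]):
        for prob, index in zip(probs[type_id], indices[type_id]):
            if prob < threshold:
                break  # top-k values are sorted in descending order
            start, end = divmod(index.item(), seq_len)
            entities.append((type_id, start, end, prob.item()))
    return entities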


class RobertaForEffiGlobalPointer(RobertaPreTrainedModel):
    def __init__(self, config):
        # encoder: RoBERTa (e.g. RoBERTa-Large)
        # inner_dim: 64
        # ent_type_size: number of entity types (ent_cls_num)
        super().__init__(config)
        self.roberta = RobertaModel(config)
        self.ent_type_size = config.ent_type_size
        self.inner_dim = config.inner_dim
        self.hidden_size = config.hidden_size
        self.RoPE = config.RoPE
        self.dense_1 = nn.Linear(self.hidden_size, self.inner_dim * 2)
        self.dense_2 = nn.Linear(self.hidden_size, self.ent_type_size * 2)  # the original GlobalPointer uses a (inner_dim * 2, ent_type_size * 2) dense_2

    def sequence_masking(self, x, mask, value="-inf", axis=None):
        if mask is None:
            return x
        else:
            if value == "-inf":
                value = -1e12
            elif value == "inf":
                value = 1e12
            assert axis > 0, "axis must be greater than 0"
            for _ in range(axis - 1):
                mask = torch.unsqueeze(mask, 1)
            for _ in range(x.ndim - mask.ndim):
                mask = torch.unsqueeze(mask, mask.ndim)
            return x * mask + value * (1 - mask)

    def add_mask_tril(self, logits, mask):
        if mask.dtype != logits.dtype:
            mask = mask.type(logits.dtype)
        logits = self.sequence_masking(logits, mask, "-inf", logits.ndim - 2)
        logits = self.sequence_masking(logits, mask, "-inf", logits.ndim - 1)
        # mask out the lower triangle so only spans with start <= end remain
        mask = torch.tril(torch.ones_like(logits), diagonal=-1)
        logits = logits - mask * 1e12
        return logits

    def forward(self, input_ids, attention_mask, token_type_ids, labels=None, short_labels=None):
        context_outputs = self.roberta(input_ids, attention_mask, token_type_ids)
        last_hidden_state = context_outputs.last_hidden_state  # [bz, seq_len, hidden_dim]
        outputs = self.dense_1(last_hidden_state)  # [bz, seq_len, 2 * inner_dim]
        # split the last dimension with stride 2: even positions form qw, odd positions form kw
        qw, kw = outputs[..., ::2], outputs[..., 1::2]
        batch_size = input_ids.shape[0]
        if self.RoPE:
            # "zero" merge mode: return only the sin/cos position codes
            pos = SinusoidalPositionEmbedding(self.inner_dim, "zero")(outputs)
            cos_pos = pos[..., 1::2].repeat_interleave(2, dim=-1)  # e.g. [0.34, 0.90] -> [0.34, 0.34, 0.90, 0.90]
            sin_pos = pos[..., ::2].repeat_interleave(2, dim=-1)
            qw2 = torch.stack([-qw[..., 1::2], qw[..., ::2]], 3)
            qw2 = torch.reshape(qw2, qw.shape)
            qw = qw * cos_pos + qw2 * sin_pos
            kw2 = torch.stack([-kw[..., 1::2], kw[..., ::2]], 3)
            kw2 = torch.reshape(kw2, kw.shape)
            kw = kw * cos_pos + kw2 * sin_pos
        logits = torch.einsum("bmd,bnd->bmn", qw, kw) / self.inner_dim ** 0.5
        bias = torch.einsum("bnh->bhn", self.dense_2(last_hidden_state)) / 2
        logits = logits[:, None] + bias[:, ::2, None] + bias[:, 1::2, :, None]  # logits[:, None] adds the entity-type dimension
        loss = None
        # upper-triangular mask over valid (start <= end) token pairs inside the attention mask
        mask = torch.triu(attention_mask.unsqueeze(2) * attention_mask.unsqueeze(1))
        if labels is not None:
            y_pred = logits - (1 - mask.unsqueeze(1)) * 1e12
            y_true = labels.view(input_ids.shape[0] * self.ent_type_size, -1)
            y_pred = y_pred.view(input_ids.shape[0] * self.ent_type_size, -1)
            loss = multilabel_categorical_crossentropy(y_pred, y_true)
        with torch.no_grad():
            prob = torch.sigmoid(logits) * mask.unsqueeze(1)
            topk = torch.topk(prob.view(batch_size, self.ent_type_size, -1), 50, dim=-1)
        return GlobalPointerOutput(
            loss=loss,
            topk_probs=topk.values,
            topk_indices=topk.indices
        )


class RoformerForEffiGlobalPointer(RoFormerPreTrainedModel):
    def __init__(self, config):
        # encoder: RoFormer
        # inner_dim: 64
        # ent_type_size: number of entity types (ent_cls_num)
        super().__init__(config)
        self.roformer = RoFormerModel(config)
        self.ent_type_size = config.ent_type_size
        self.inner_dim = config.inner_dim
        self.hidden_size = config.hidden_size
        self.RoPE = config.RoPE
        self.dense_1 = nn.Linear(self.hidden_size, self.inner_dim * 2)
        self.dense_2 = nn.Linear(self.hidden_size, self.ent_type_size * 2)  # the original GlobalPointer uses a (inner_dim * 2, ent_type_size * 2) dense_2

    def sequence_masking(self, x, mask, value="-inf", axis=None):
        if mask is None:
            return x
        else:
            if value == "-inf":
                value = -1e12
            elif value == "inf":
                value = 1e12
            assert axis > 0, "axis must be greater than 0"
            for _ in range(axis - 1):
                mask = torch.unsqueeze(mask, 1)
            for _ in range(x.ndim - mask.ndim):
                mask = torch.unsqueeze(mask, mask.ndim)
            return x * mask + value * (1 - mask)

    def add_mask_tril(self, logits, mask):
        if mask.dtype != logits.dtype:
            mask = mask.type(logits.dtype)
        logits = self.sequence_masking(logits, mask, "-inf", logits.ndim - 2)
        logits = self.sequence_masking(logits, mask, "-inf", logits.ndim - 1)
        # mask out the lower triangle so only spans with start <= end remain
        mask = torch.tril(torch.ones_like(logits), diagonal=-1)
        logits = logits - mask * 1e12
        return logits

    def forward(self, input_ids, attention_mask, token_type_ids, labels=None, short_labels=None):
        context_outputs = self.roformer(input_ids, attention_mask, token_type_ids)
        last_hidden_state = context_outputs.last_hidden_state  # [bz, seq_len, hidden_dim]
        outputs = self.dense_1(last_hidden_state)  # [bz, seq_len, 2 * inner_dim]
        # split the last dimension with stride 2: even positions form qw, odd positions form kw
        qw, kw = outputs[..., ::2], outputs[..., 1::2]
        batch_size = input_ids.shape[0]
        if self.RoPE:
            # "zero" merge mode: return only the sin/cos position codes
            pos = SinusoidalPositionEmbedding(self.inner_dim, "zero")(outputs)
            cos_pos = pos[..., 1::2].repeat_interleave(2, dim=-1)  # e.g. [0.34, 0.90] -> [0.34, 0.34, 0.90, 0.90]
            sin_pos = pos[..., ::2].repeat_interleave(2, dim=-1)
            qw2 = torch.stack([-qw[..., 1::2], qw[..., ::2]], 3)
            qw2 = torch.reshape(qw2, qw.shape)
            qw = qw * cos_pos + qw2 * sin_pos
            kw2 = torch.stack([-kw[..., 1::2], kw[..., ::2]], 3)
            kw2 = torch.reshape(kw2, kw.shape)
            kw = kw * cos_pos + kw2 * sin_pos
        logits = torch.einsum("bmd,bnd->bmn", qw, kw) / self.inner_dim ** 0.5
        bias = torch.einsum("bnh->bhn", self.dense_2(last_hidden_state)) / 2
        logits = logits[:, None] + bias[:, ::2, None] + bias[:, 1::2, :, None]  # logits[:, None] adds the entity-type dimension
        loss = None
        # upper-triangular mask over valid (start <= end) token pairs inside the attention mask
        mask = torch.triu(attention_mask.unsqueeze(2) * attention_mask.unsqueeze(1))
        if labels is not None:
            y_pred = logits - (1 - mask.unsqueeze(1)) * 1e12
            y_true = labels.view(input_ids.shape[0] * self.ent_type_size, -1)
            y_pred = y_pred.view(input_ids.shape[0] * self.ent_type_size, -1)
            loss = multilabel_categorical_crossentropy(y_pred, y_true)
        with torch.no_grad():
            prob = torch.sigmoid(logits) * mask.unsqueeze(1)
            topk = torch.topk(prob.view(batch_size, self.ent_type_size, -1), 50, dim=-1)
        return GlobalPointerOutput(
            loss=loss,
            topk_probs=topk.values,
            topk_indices=topk.indices
        )


class MegatronForEffiGlobalPointer(MegatronBertPreTrainedModel):
    def __init__(self, config):
        # encoder: MegatronBERT
        # inner_dim: 64
        # ent_type_size: number of entity types (ent_cls_num)
        super().__init__(config)
        self.bert = MegatronBertModel(config)
        self.ent_type_size = config.ent_type_size
        self.inner_dim = config.inner_dim
        self.hidden_size = config.hidden_size
        self.RoPE = config.RoPE
        self.dense_1 = nn.Linear(self.hidden_size, self.inner_dim * 2)
        self.dense_2 = nn.Linear(self.hidden_size, self.ent_type_size * 2)  # the original GlobalPointer uses a (inner_dim * 2, ent_type_size * 2) dense_2

    def sequence_masking(self, x, mask, value="-inf", axis=None):
        if mask is None:
            return x
        else:
            if value == "-inf":
                value = -1e12
            elif value == "inf":
                value = 1e12
            assert axis > 0, "axis must be greater than 0"
            for _ in range(axis - 1):
                mask = torch.unsqueeze(mask, 1)
            for _ in range(x.ndim - mask.ndim):
                mask = torch.unsqueeze(mask, mask.ndim)
            return x * mask + value * (1 - mask)

    def add_mask_tril(self, logits, mask):
        if mask.dtype != logits.dtype:
            mask = mask.type(logits.dtype)
        logits = self.sequence_masking(logits, mask, "-inf", logits.ndim - 2)
        logits = self.sequence_masking(logits, mask, "-inf", logits.ndim - 1)
        # mask out the lower triangle so only spans with start <= end remain
        mask = torch.tril(torch.ones_like(logits), diagonal=-1)
        logits = logits - mask * 1e12
        return logits

    def forward(self, input_ids, attention_mask, token_type_ids, labels=None, short_labels=None):
        context_outputs = self.bert(input_ids, attention_mask, token_type_ids)
        last_hidden_state = context_outputs.last_hidden_state  # [bz, seq_len, hidden_dim]
        outputs = self.dense_1(last_hidden_state)  # [bz, seq_len, 2 * inner_dim]
        # split the last dimension with stride 2: even positions form qw, odd positions form kw
        qw, kw = outputs[..., ::2], outputs[..., 1::2]
        batch_size = input_ids.shape[0]
        if self.RoPE:
            # "zero" merge mode: return only the sin/cos position codes
            pos = SinusoidalPositionEmbedding(self.inner_dim, "zero")(outputs)
            cos_pos = pos[..., 1::2].repeat_interleave(2, dim=-1)  # e.g. [0.34, 0.90] -> [0.34, 0.34, 0.90, 0.90]
            sin_pos = pos[..., ::2].repeat_interleave(2, dim=-1)
            qw2 = torch.stack([-qw[..., 1::2], qw[..., ::2]], 3)
            qw2 = torch.reshape(qw2, qw.shape)
            qw = qw * cos_pos + qw2 * sin_pos
            kw2 = torch.stack([-kw[..., 1::2], kw[..., ::2]], 3)
            kw2 = torch.reshape(kw2, kw.shape)
            kw = kw * cos_pos + kw2 * sin_pos
        logits = torch.einsum("bmd,bnd->bmn", qw, kw) / self.inner_dim ** 0.5
        bias = torch.einsum("bnh->bhn", self.dense_2(last_hidden_state)) / 2
        logits = logits[:, None] + bias[:, ::2, None] + bias[:, 1::2, :, None]  # logits[:, None] adds the entity-type dimension
        loss = None
        # upper-triangular mask over valid (start <= end) token pairs inside the attention mask
        mask = torch.triu(attention_mask.unsqueeze(2) * attention_mask.unsqueeze(1))
        if labels is not None:
            y_pred = logits - (1 - mask.unsqueeze(1)) * 1e12
            y_true = labels.view(input_ids.shape[0] * self.ent_type_size, -1)
            y_pred = y_pred.view(input_ids.shape[0] * self.ent_type_size, -1)
            loss = multilabel_categorical_crossentropy(y_pred, y_true)
        with torch.no_grad():
            prob = torch.sigmoid(logits) * mask.unsqueeze(1)
            topk = torch.topk(prob.view(batch_size, self.ent_type_size, -1), 50, dim=-1)
        return GlobalPointerOutput(
            loss=loss,
            topk_probs=topk.values,
            topk_indices=topk.indices
        )
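

# Loading sketch (not part of the original file): ent_type_size, inner_dim and
# RoPE are custom attributes these classes read from the config, so they must
# be set before building the model; the checkpoint name and values below are
# assumptions.
def _example_build_bert_effi_global_pointer():
    from transformers import BertConfig

    config = BertConfig.from_pretrained("bert-base-chinese")
    config.ent_type_size = 9  # number of entity types
    config.inner_dim = 64     # per-span head size used by dense_1
    config.RoPE = True        # apply rotary position information to qw/kw
    return BertForEffiGlobalPointer.from_pretrained("bert-base-chinese", config=config)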