import torch
import torch.nn as nn

from utils.bert_model import BertForSequenceEncoder
class sentence_retrieval_model(nn.Module):
    """Score encoded input sequences with a BERT encoder and a linear head.

    The encoder output is passed through dropout, projected to a single
    value per example by a linear layer, and squashed with ``tanh`` so every
    score lies in ``[-1, 1]``.
    """

    def __init__(self, args):
        """Build the encoder, the dropout layer and the scoring head.

        Args:
            args: Mapping indexed with string keys. The keys read here are
                ``'bert_pretrain'`` (forwarded to
                ``BertForSequenceEncoder.from_pretrained``),
                ``'bert_hidden_dim'`` (input width of the scoring head) and
                ``'dropout'`` (probability handed to ``nn.Dropout``).
        """
        super().__init__()
        self.pred_model = BertForSequenceEncoder.from_pretrained(args['bert_pretrain'])
        self.bert_hidden_dim = args['bert_hidden_dim']
        self.dropout = nn.Dropout(args['dropout'])
        # One output unit: the head emits a single raw score per encoded vector.
        self.proj_match = nn.Linear(self.bert_hidden_dim, 1)

    def forward(self, inp_tensor, msk_tensor, seg_tensor) -> torch.Tensor:
        """Return a tanh-bounded score for each encoded input.

        Args:
            inp_tensor: First positional argument handed to the encoder
                (presumably token ids -- confirm against the caller).
            msk_tensor: Second positional argument handed to the encoder
                (presumably the attention mask -- confirm against the caller).
            seg_tensor: Third positional argument handed to the encoder
                (presumably segment ids -- confirm against the caller).

        Returns:
            Tensor of scores in ``[-1, 1]``: the linear head's output with its
            trailing size-1 dimension removed, passed through ``tanh``.
        """
        # The encoder returns a pair; only its second element is scored.
        # NOTE(review): presumably the pooled sequence representation whose
        # last dimension equals ``bert_hidden_dim`` -- verify in
        # utils.bert_model.
        _, encoded = self.pred_model(inp_tensor, msk_tensor, seg_tensor)
        encoded = self.dropout(encoded)
        # Linear yields (..., 1); squeeze(-1) drops that trailing axis.
        raw_score = self.proj_match(encoded).squeeze(-1)
        return torch.tanh(raw_score)