import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence, pad_packed_sequence, pack_padded_sequence
from transformers import AutoTokenizer, AutoModelWithLMHead, pipeline, set_seed, GPT2Tokenizer, GPT2Model, GPT2LMHeadModel, GPT2Config, GPT2ForSequenceClassification, MarianTokenizer

from fudge.constants import *
from fudge.util import pad_mask
from fudge.clickbait_classifier import BertClickbaitClassifier, ClickbaitConfig
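
# Note on the fudge.* dependencies used below: fudge.constants is expected to supply
# HIDDEN_DIM, RNN_DIM, GLOVE_DIM, COUNT_SYLLABLE_DIM and MAX_COUNT_SYLLABLE_DIST,
# fudge.util.pad_mask builds a (seq x batch) padding mask from a lengths tensor, and
# fudge.clickbait_classifier provides the pretrained clickbait classifier and its config.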

class Model(nn.Module):
    def __init__(self, args, gpt_pad_id, vocab_size, rhyme_group_size=None, glove_embeddings=None, verbose=True):
        super(Model, self).__init__()
        # self.topic = args.task == 'topic'
        self.formality = args.task == 'formality'
        self.iambic = args.task == 'iambic'
        self.rhyme = args.task == 'rhyme'
        self.newline = args.task == 'newline'
        self.clickbait = args.task == 'clickbait'
        # if self.topic:
        #     self.gpt_embed = nn.Embedding(gpt_pad_id + 1, HIDDEN_DIM, padding_idx=gpt_pad_id) # these are subwords, not words
        #     if glove_embeddings is None:
        #         if verbose:
        #             print('initializing word embeddings from scratch')
        #         self.word_embed = nn.Embedding(vocab_size, GLOVE_DIM, padding_idx=0)
        #     else:
        #         if verbose:
        #             print('initializing word embeddings from glove')
        #         self.word_embed = nn.Embedding.from_pretrained(glove_embeddings, padding_idx=0)
        #     self.rnn = nn.LSTM(HIDDEN_DIM, RNN_DIM, num_layers=3, bidirectional=True)
        #     self.attention_linear = nn.Linear(HIDDEN_DIM, HIDDEN_DIM)
        #     large_hidden_dim = HIDDEN_DIM
        #     self.embed_key_linear = nn.Linear(large_hidden_dim, HIDDEN_DIM)
        #     self.attention_value_linear = nn.Linear(HIDDEN_DIM, HIDDEN_DIM)
        #     self.out_embed_linear = nn.Linear(HIDDEN_DIM, HIDDEN_DIM)
        #     self.out_linear = nn.Linear(HIDDEN_DIM, HIDDEN_DIM)
        #     self.out_linear2 = nn.Linear(HIDDEN_DIM + large_hidden_dim, HIDDEN_DIM)
        #     self.out_linear3 = nn.Linear(HIDDEN_DIM, 1)
        #     self.nonlinear = nn.ReLU()
        # elif self.formality:
        if self.formality:
            self.marian_embed = nn.Embedding(gpt_pad_id + 1, HIDDEN_DIM, padding_idx=0) # 0 in marian is ''
            self.rnn = nn.LSTM(HIDDEN_DIM, HIDDEN_DIM, num_layers=3, bidirectional=False, dropout=0.5) # want it to be causal so we can learn all positions
            self.out_linear = nn.Linear(HIDDEN_DIM, 1)
        elif self.iambic:
            self.gpt_embed = nn.Embedding(gpt_pad_id + 1, HIDDEN_DIM, padding_idx=gpt_pad_id)
            self.rnn = nn.LSTM(HIDDEN_DIM, HIDDEN_DIM, num_layers=3, bidirectional=False, dropout=0) # want it to be causal so we can learn all positions
            self.out_linear = nn.Linear(HIDDEN_DIM, 1)
        elif self.rhyme:
            self.gpt_embed = nn.Embedding(gpt_pad_id + 1, HIDDEN_DIM, padding_idx=gpt_pad_id) # these are subwords, not words
            self.word_embed = nn.Embedding(rhyme_group_size + 1, GLOVE_DIM, padding_idx=0) # this embedding for future words will actually embed the rhyme group idx
            self.rnn = nn.LSTM(HIDDEN_DIM, RNN_DIM, num_layers=3, bidirectional=True)
            self.attention_linear = nn.Linear(HIDDEN_DIM, HIDDEN_DIM)
            large_hidden_dim = HIDDEN_DIM + COUNT_SYLLABLE_DIM
            self.embed_key_linear = nn.Linear(large_hidden_dim, HIDDEN_DIM)
            self.attention_value_linear = nn.Linear(HIDDEN_DIM, HIDDEN_DIM)
            self.out_embed_linear = nn.Linear(HIDDEN_DIM, HIDDEN_DIM)
            self.out_linear = nn.Linear(HIDDEN_DIM, HIDDEN_DIM)
            self.out_linear2 = nn.Linear(HIDDEN_DIM + large_hidden_dim, HIDDEN_DIM)
            self.out_linear3 = nn.Linear(HIDDEN_DIM, 1)
            self.count_syllable_embed = nn.Embedding(MAX_COUNT_SYLLABLE_DIST + 1, COUNT_SYLLABLE_DIM)
            self.nonlinear = nn.ReLU()
        elif self.newline:
            self.gpt_embed = nn.Embedding(gpt_pad_id + 1, HIDDEN_DIM, padding_idx=gpt_pad_id) # these are subwords, not words
            self.rnn = nn.LSTM(HIDDEN_DIM, HIDDEN_DIM, num_layers=3, bidirectional=False)
            self.count_syllable_embed = nn.Embedding(MAX_COUNT_SYLLABLE_DIST + 1, COUNT_SYLLABLE_DIM)
            self.out_linear = nn.Linear(HIDDEN_DIM + COUNT_SYLLABLE_DIM, HIDDEN_DIM)
            self.out_linear2 = nn.Linear(HIDDEN_DIM, HIDDEN_DIM)
            self.out_linear3 = nn.Linear(HIDDEN_DIM, 1)
            self.nonlinear = nn.ReLU()
        elif self.clickbait:
            # mpnet_config = ClickbaitConfig(
            #     model_type="mpnet",
            #     pretrained_model="sentence-transformers/all-mpnet-base-v2",
            #     num_labels=1,
            #     dropout=0.2,
            #     inner_dim1=256,
            #     inner_dim2=32,
            #     max_length=25,
            #     load_pretrained=True,
            #     freeze_bert=False,
            # )
            # TODO add a checkpoint to Classifier
            # print('add a checkpoint to Classifier')
            checkpoint = args.checkpoint  # 'ckpt/clickbait_classifier/checkpoint-1464'
            # self.classifier = BertClickbaitClassifier(config=mpnet_config).to(torch.device(args.device))
            self.classifier = BertClickbaitClassifier.from_pretrained(checkpoint).to(torch.device(args.device))
        else:
            raise NotImplementedError # TODO honestly this can/should be refactored into different models
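
    # Summary of the per-task heads built above:
    #   - formality / iambic: token embedding -> 3-layer unidirectional (causal) LSTM -> per-position linear logit
    #   - rhyme: bidirectional LSTM attended by embeddings of candidate future rhyme groups plus a
    #     syllables-to-go embedding, followed by a small MLP whose scores have log_probs subtracted
    #   - newline: causal LSTM concatenated with a syllables-to-go embedding, followed by a small MLP
    #   - clickbait: delegates scoring to a pretrained BertClickbaitClassifier checkpoint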

    def forward(self, inputs, lengths=None, future_words=None, log_probs=None, syllables_to_go=None, future_word_num_syllables=None, rhyme_group_index=None, run_classifier=False, attention_mask=None):
        """
        inputs: token ids, batch x seq, right-padded with 0s
        lengths: lengths of inputs; batch
        future_words: batch x N words to check if not predict next token, else batch
        log_probs: N
        syllables_to_go: batch
        future_word_num_syllables, rhyme_group_index, run_classifier: accepted for interface compatibility; not used by the branches below
        attention_mask: batch x seq mask forwarded to the clickbait classifier
        """
        # if self.topic:
        #     inputs = self.gpt_embed(inputs) # batch x seq x 300
        #     inputs = pack_padded_sequence(inputs.permute(1, 0, 2), lengths.cpu(), enforce_sorted=False)
        #     rnn_output, _ = self.rnn(inputs)
        #     rnn_output, _ = pad_packed_sequence(rnn_output)
        #     rnn_output = rnn_output.permute(1, 0, 2) # batch x seq x 300
        #     hidden = rnn_output
        #     attention_mask = pad_mask(lengths).permute(1, 0) # batch x seq
        #     embed = self.word_embed(future_words) # batch x N x 300
        #     embed_query = self.embed_key_linear(embed)
        #     attention_tensor = self.attention_linear(hidden).unsqueeze(2) * embed_query.unsqueeze(1) # batch x seq x N x 300
        #     attention_weights = F.softmax(attention_tensor.sum(dim=3), dim=1) # batch x seq x N
        #     attention_weights = attention_weights * attention_mask.unsqueeze(2)
        #     hidden = self.attention_value_linear(hidden)
        #     weighted_hidden = (hidden.unsqueeze(2) * attention_weights.unsqueeze(3)).sum(dim=1) # batch x seq x N x 768 -> batch x N x 768
        #     unnormalized_scores = (self.out_linear(weighted_hidden) * self.out_embed_linear(embed)) # batch x N x 300
        #     unnormalized_scores = torch.cat([unnormalized_scores, embed], dim=2)
        #     unnormalized_scores = self.nonlinear(self.out_linear2(self.nonlinear(unnormalized_scores)))
        #     unnormalized_scores = self.out_linear3(unnormalized_scores)
        #     scores = unnormalized_scores.squeeze(2) - log_probs.unsqueeze(0)
        #     return scores # batch x N of normalized scores or batch x
        # elif self.formality:
        if self.formality:
            inputs = self.marian_embed(inputs)
            inputs = pack_padded_sequence(inputs.permute(1, 0, 2), lengths.cpu(), enforce_sorted=False)
            rnn_output, _ = self.rnn(inputs)
            rnn_output, _ = pad_packed_sequence(rnn_output)
            rnn_output = rnn_output.permute(1, 0, 2) # batch x seq x 300
            return self.out_linear(rnn_output).squeeze(2)
        elif self.iambic:
            inputs = self.gpt_embed(inputs)
            inputs = pack_padded_sequence(inputs.permute(1, 0, 2), lengths.cpu(), enforce_sorted=False)
            rnn_output, _ = self.rnn(inputs)
            rnn_output, _ = pad_packed_sequence(rnn_output)
            rnn_output = rnn_output.permute(1, 0, 2) # batch x seq x 300
            return self.out_linear(rnn_output).squeeze(2)
        elif self.rhyme:
            inputs = self.gpt_embed(inputs) # batch x seq x 300
            inputs = pack_padded_sequence(inputs.permute(1, 0, 2), lengths.cpu(), enforce_sorted=False)
            rnn_output, _ = self.rnn(inputs)
            rnn_output, _ = pad_packed_sequence(rnn_output)
            rnn_output = rnn_output.permute(1, 0, 2) # batch x seq x 300
            hidden = rnn_output
            attention_mask = pad_mask(lengths).permute(1, 0) # batch x seq
            embed = self.word_embed(future_words) # batch x N x 300
            embedded_syllables_to_go = self.count_syllable_embed(syllables_to_go).unsqueeze(1).expand(-1, embed.shape[1], -1) # batch x N x COUNT_SYLLABLE_DIM
            auxiliary_embed = embedded_syllables_to_go
            embed_query = self.embed_key_linear(torch.cat([embed, auxiliary_embed], dim=2))
            attention_tensor = self.attention_linear(hidden).unsqueeze(2) * embed_query.unsqueeze(1) # batch x seq x N x 300
            attention_weights = F.softmax(attention_tensor.sum(dim=3), dim=1) # batch x seq x N
            attention_weights = attention_weights * attention_mask.unsqueeze(2)
            hidden = self.attention_value_linear(hidden)
            weighted_hidden = (hidden.unsqueeze(2) * attention_weights.unsqueeze(3)).sum(dim=1) # batch x seq x N x 300 -> batch x N x 300
            unnormalized_scores = (self.out_linear(weighted_hidden) * self.out_embed_linear(embed)) # batch x N x 300
            unnormalized_scores = torch.cat([unnormalized_scores, embed, auxiliary_embed], dim=2)
            unnormalized_scores = self.nonlinear(self.out_linear2(self.nonlinear(unnormalized_scores)))
            unnormalized_scores = self.out_linear3(unnormalized_scores)
            scores = unnormalized_scores.squeeze(2) - log_probs.unsqueeze(0)
            return scores # batch x N of normalized scores or batch x
        elif self.newline:
            inputs = self.gpt_embed(inputs) # batch x seq x 300
            inputs = pack_padded_sequence(inputs.permute(1, 0, 2), lengths.cpu(), enforce_sorted=False)
            rnn_output, _ = self.rnn(inputs)
            rnn_output, _ = pad_packed_sequence(rnn_output)
            rnn_output = rnn_output.permute(1, 0, 2) # batch x seq x 300
            hidden = torch.cat([rnn_output, self.count_syllable_embed(syllables_to_go).unsqueeze(1).expand(-1, rnn_output.shape[1], -1)], dim=2)
            return self.out_linear3(self.nonlinear(self.out_linear2(self.nonlinear(self.out_linear(hidden))))).squeeze(2)
        elif self.clickbait:
            # inputs already arrive as token ids; only wrap them in a tensor if needed
            input_ids = inputs if torch.is_tensor(inputs) else torch.tensor(inputs)
            classifier_output = self.classifier(input_ids=input_ids, attention_mask=attention_mask).logits # batch x num_labels
            classifier_output = classifier_output[None, :, :] # 1 x batch x num_labels
            # return self.out_linear(rnn_output).squeeze(2)
            return classifier_output.squeeze(2)
        else:
            raise NotImplementedError
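

# Minimal usage sketch (not part of the original module): shows how the formality scorer
# might be constructed and queried on a toy batch. It assumes fudge.constants defines
# HIDDEN_DIM; the `args` namespace, toy pad id, and token ids below are made up for
# illustration, and the real training/decoding scripts pass their own args and tokenizer ids.
if __name__ == '__main__':
    from argparse import Namespace

    args = Namespace(task='formality')
    toy_pad_id = 99  # hypothetical; use the actual Marian tokenizer's pad id in practice
    model = Model(args, gpt_pad_id=toy_pad_id, vocab_size=toy_pad_id + 1)
    model.eval()

    # batch x seq token ids, right-padded with 0s, plus the true length of each row
    inputs = torch.tensor([[21, 7, 3, 0, 0],
                           [5, 9, 12, 4, 2]])
    lengths = torch.tensor([3, 5])
    with torch.no_grad():
        scores = model(inputs, lengths=lengths)  # batch x seq, one logit per prefix position
    print(scores.shape)  # torch.Size([2, 5])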