import time
import argparse

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence

import numpy as np

from .decoder import DecoderBase


class LSTMDecoder(DecoderBase):
    """LSTM decoder with constant-length data"""
    def __init__(self, args, vocab, model_init, emb_init):
        super(LSTMDecoder, self).__init__()
        self.ni = args.ni
        self.nh = args.dec_nh
        self.nz = args.nz
        self.vocab = vocab
        self.device = args.device

        # no padding when setting padding_idx to -1
        self.embed = nn.Embedding(len(vocab), args.ni, padding_idx=-1)

        self.dropout_in = nn.Dropout(args.dec_dropout_in)
        self.dropout_out = nn.Dropout(args.dec_dropout_out)

        # for initializing hidden state and cell
        self.trans_linear = nn.Linear(args.nz, args.dec_nh, bias=False)

        # concatenate z with input
        self.lstm = nn.LSTM(input_size=args.ni + args.nz,
                            hidden_size=args.dec_nh,
                            num_layers=1,
                            batch_first=True)

        # prediction layer
        self.pred_linear = nn.Linear(args.dec_nh, len(vocab), bias=False)

        vocab_mask = torch.ones(len(vocab))
        # vocab_mask[vocab['<pad>']] = 0
        # per-token losses; summation over time happens in reconstruct_error
        self.loss = nn.CrossEntropyLoss(weight=vocab_mask, reduction='none')

        self.reset_parameters(model_init, emb_init)
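
    # Note (assumption, not part of the original code): `args` is expected to
    # behave like an argparse.Namespace exposing ni, dec_nh, nz,
    # dec_dropout_in, dec_dropout_out and device, and `vocab` is expected to
    # support len(), item lookup for tokens such as '<s>', '</s>', '<pad>'
    # and an id2word() method, as used by the decoding routines below.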

    def reset_parameters(self, model_init, emb_init):
        # for name, param in self.lstm.named_parameters():
        #     # self.initializer(param)
        #     if 'bias' in name:
        #         nn.init.constant_(param, 0.0)
        #         # model_init(param)
        #     elif 'weight' in name:
        #         model_init(param)

        # model_init(self.trans_linear.weight)
        # model_init(self.pred_linear.weight)
        for param in self.parameters():
            model_init(param)
        emb_init(self.embed.weight)

    def sample_text(self, input, z, EOS, device):
        """Sample a single sentence token by token.

        Assumes batch_size == n_sample == 1, i.e. input is (1, 1) and
        z is (1, 1, nz). Decoding stops at EOS or after 100 tokens.
        """
        sentence = [input]
        max_index = 0

        input_word = input
        batch_size, n_sample, _ = z.size()
        seq_len = 1
        # broadcast z over the single decoding step: (1, 1, nz)
        z_ = z.expand(batch_size, seq_len, self.nz)
        softmax = torch.nn.Softmax(dim=0)
        while max_index != EOS and len(sentence) < 100:
            # (1, 1, ni) --> (1, 1, ni + nz)
            word_embed = self.embed(input_word)
            word_embed = torch.cat((word_embed, z_), -1)

            if len(sentence) == 1:
                # initialize cell from z, hidden state as tanh of the cell
                c_init = self.trans_linear(z).unsqueeze(0).squeeze(dim=1)
                h_init = torch.tanh(c_init)
                output, hidden = self.lstm(word_embed, (h_init, c_init))
            else:
                output, hidden = self.lstm(word_embed, hidden)

            # (1, 1, vocab_size) --> (vocab_size,)
            output_logits = self.pred_linear(output).view(-1)
            probs = softmax(output_logits)
            # max_index = torch.argmax(output_logits)
            max_index = torch.multinomial(probs, num_samples=1)
            input_word = max_index.view(1, 1).to(device)
            sentence.append(max_index)

        return sentence
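
    # Usage sketch (assumption: vocab maps '<s>'/'</s>' to ids and the model
    # lives on `device`; not part of the original code):
    #
    #   start = torch.tensor([[vocab['<s>']]], device=device)   # (1, 1)
    #   z = torch.randn(1, 1, decoder.nz, device=device)         # (1, 1, nz)
    #   ids = decoder.sample_text(start, z, vocab['</s>'], device)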

    def decode(self, input, z):
        """
        Args:
            input: (batch_size, seq_len)
            z: (batch_size, n_sample, nz)
        """

        # not predicting start symbol
        # sents_len -= 1

        batch_size, n_sample, _ = z.size()
        seq_len = input.size(1)

        # (batch_size, seq_len, ni)
        word_embed = self.embed(input)
        word_embed = self.dropout_in(word_embed)

        if n_sample == 1:
            z_ = z.expand(batch_size, seq_len, self.nz)
        else:
            word_embed = word_embed.unsqueeze(1).expand(batch_size, n_sample, seq_len, self.ni) \
                                   .contiguous()

            # (batch_size * n_sample, seq_len, ni)
            word_embed = word_embed.view(batch_size * n_sample, seq_len, self.ni)

            z_ = z.unsqueeze(2).expand(batch_size, n_sample, seq_len, self.nz).contiguous()
            z_ = z_.view(batch_size * n_sample, seq_len, self.nz)

        # (batch_size * n_sample, seq_len, ni + nz)
        word_embed = torch.cat((word_embed, z_), -1)

        z = z.view(batch_size * n_sample, self.nz)
        # initialize the LSTM cell from z, hidden state as tanh of the cell
        c_init = self.trans_linear(z).unsqueeze(0)
        h_init = torch.tanh(c_init)
        # h_init = self.trans_linear(z).unsqueeze(0)
        # c_init = h_init.new_zeros(h_init.size())
        output, _ = self.lstm(word_embed, (h_init, c_init))

        output = self.dropout_out(output)

        # (batch_size * n_sample, seq_len, vocab_size)
        output_logits = self.pred_linear(output)

        return output_logits
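
    # Shape sketch for decode() (illustrative only): with batch_size=2,
    # n_sample=3, seq_len=5, the input (2, 5) and z (2, 3, nz) are tiled to
    # (6, 5, ni + nz) before the LSTM, and the returned logits have shape
    # (6, 5, vocab_size), i.e. (batch_size * n_sample, seq_len, vocab_size).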

    def reconstruct_error(self, x, z):
        """Cross entropy in the language case
        Args:
            x: (batch_size, seq_len)
            z: (batch_size, n_sample, nz)
        Returns:
            loss: (batch_size, n_sample). Loss
            across different sentences and z samples
        """

        # remove end symbol
        src = x[:, :-1]

        # remove start symbol
        tgt = x[:, 1:]

        batch_size, seq_len = src.size()
        n_sample = z.size(1)

        # (batch_size * n_sample, seq_len, vocab_size)
        output_logits = self.decode(src, z)

        if n_sample == 1:
            tgt = tgt.contiguous().view(-1)
        else:
            # (batch_size * n_sample * seq_len)
            tgt = tgt.unsqueeze(1).expand(batch_size, n_sample, seq_len) \
                     .contiguous().view(-1)

        # (batch_size * n_sample * seq_len)
        loss = self.loss(output_logits.view(-1, output_logits.size(2)),
                         tgt)

        # (batch_size, n_sample)
        return loss.view(batch_size, n_sample, -1).sum(-1)
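
    # Usage sketch (hypothetical tensors, not part of the original code):
    #
    #   x = torch.randint(0, len(vocab), (8, 20))   # (batch_size, seq_len) token ids
    #   z = torch.randn(8, 5, decoder.nz)           # (batch_size, n_sample, nz)
    #   rec = decoder.reconstruct_error(x, z)       # (8, 5), summed over time steps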

    def log_probability(self, x, z):
        """Cross entropy in the language case
        Args:
            x: (batch_size, seq_len)
            z: (batch_size, n_sample, nz)
        Returns:
            log_p: (batch_size, n_sample).
            log_p(x|z) across different x and z
        """

        return -self.reconstruct_error(x, z)

    def greedy_decode(self, z):
        return self.sample_decode(z, greedy=True)

    def sample_decode(self, z, greedy=False):
        """Sample (or greedily decode) sentences from z.
        Args:
            z: (batch_size, nz)
        Returns:
            decoded_batch: list of decoded word lists, one per batch element
        """

        batch_size = z.size(0)
        decoded_batch = [[] for _ in range(batch_size)]

        # (1, batch_size, dec_nh)
        c_init = self.trans_linear(z).unsqueeze(0)
        h_init = torch.tanh(c_init)

        decoder_hidden = (h_init, c_init)
        decoder_input = torch.tensor([self.vocab["<s>"]] * batch_size,
                                     dtype=torch.long, device=self.device).unsqueeze(1)
        end_symbol = torch.tensor([self.vocab["</s>"]] * batch_size,
                                  dtype=torch.long, device=self.device)

        # mask[i] stays True until sentence i has produced the end symbol
        mask = torch.ones(batch_size, dtype=torch.bool, device=self.device)
        length_c = 1

        while mask.sum().item() != 0 and length_c < 100:
            # (batch_size, 1, ni) --> (batch_size, 1, ni + nz)
            word_embed = self.embed(decoder_input)
            word_embed = torch.cat((word_embed, z.unsqueeze(1)), dim=-1)

            output, decoder_hidden = self.lstm(word_embed, decoder_hidden)

            # (batch_size, 1, vocab_size) --> (batch_size, vocab_size)
            decoder_output = self.pred_linear(output)
            output_logits = decoder_output.squeeze(1)

            # (batch_size)
            if greedy:
                max_index = torch.argmax(output_logits, dim=1)
            else:
                probs = F.softmax(output_logits, dim=1)
                max_index = torch.multinomial(probs, num_samples=1).squeeze(1)

            decoder_input = max_index.unsqueeze(1)
            length_c += 1

            for i in range(batch_size):
                if mask[i].item():
                    decoded_batch[i].append(self.vocab.id2word(max_index[i].item()))

            mask = torch.mul((max_index != end_symbol), mask)

        return decoded_batch
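
    # Usage sketch (z drawn from the prior here; not part of the original code):
    #
    #   z = torch.randn(4, decoder.nz, device=decoder.device)   # (batch_size, nz)
    #   sents = decoder.greedy_decode(z)                          # 4 word lists
    #   sents = decoder.sample_decode(z)                          # stochastic decoding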


class VarLSTMDecoder(LSTMDecoder):
    """LSTM decoder with variable-length data"""
    def __init__(self, args, vocab, model_init, emb_init):
        super(VarLSTMDecoder, self).__init__(args, vocab, model_init, emb_init)

        self.embed = nn.Embedding(len(vocab), args.ni, padding_idx=vocab['<pad>'])
        vocab_mask = torch.ones(len(vocab))
        vocab_mask[vocab['<pad>']] = 0
        self.loss = nn.CrossEntropyLoss(weight=vocab_mask, reduction='none')

        self.reset_parameters(model_init, emb_init)

    def decode(self, input, z):
        """
        Args:
            input: tuple which contains x and sents_len
                   x: (batch_size, seq_len)
                   sents_len: long tensor of sentence lengths
            z: (batch_size, n_sample, nz)
        """

        input, sents_len = input

        # not predicting start symbol
        sents_len = sents_len - 1

        batch_size, n_sample, _ = z.size()
        seq_len = input.size(1)

        # (batch_size, seq_len, ni)
        word_embed = self.embed(input)
        word_embed = self.dropout_in(word_embed)

        if n_sample == 1:
            z_ = z.expand(batch_size, seq_len, self.nz)
        else:
            word_embed = word_embed.unsqueeze(1).expand(batch_size, n_sample, seq_len, self.ni) \
                                   .contiguous()

            # (batch_size * n_sample, seq_len, ni)
            word_embed = word_embed.view(batch_size * n_sample, seq_len, self.ni)

            z_ = z.unsqueeze(2).expand(batch_size, n_sample, seq_len, self.nz).contiguous()
            z_ = z_.view(batch_size * n_sample, seq_len, self.nz)

        # (batch_size * n_sample, seq_len, ni + nz)
        word_embed = torch.cat((word_embed, z_), -1)

        sents_len = sents_len.unsqueeze(1).expand(batch_size, n_sample).contiguous().view(-1)
        # note: pack_padded_sequence expects lengths sorted in decreasing order
        # unless enforce_sorted=False is passed (available in PyTorch >= 1.1)
        packed_embed = pack_padded_sequence(word_embed, sents_len.tolist(), batch_first=True)

        z = z.view(batch_size * n_sample, self.nz)
        # h_init = self.trans_linear(z).unsqueeze(0)
        # c_init = h_init.new_zeros(h_init.size())
        c_init = self.trans_linear(z).unsqueeze(0)
        h_init = torch.tanh(c_init)
        output, _ = self.lstm(packed_embed, (h_init, c_init))
        output, _ = pad_packed_sequence(output, batch_first=True)

        output = self.dropout_out(output)

        # (batch_size * n_sample, seq_len, vocab_size)
        output_logits = self.pred_linear(output)

        return output_logits

    def reconstruct_error(self, x, z):
        """Cross entropy in the language case
        Args:
            x: tuple which contains x_ and sents_len
               x_: (batch_size, seq_len)
               sents_len: long tensor of sentence lengths
            z: (batch_size, n_sample, nz)
        Returns:
            loss: (batch_size, n_sample). Loss
            across different sentences and z samples
        """

        x, sents_len = x

        # remove end symbol
        src = x[:, :-1]

        # remove start symbol
        tgt = x[:, 1:]

        batch_size, seq_len = src.size()
        n_sample = z.size(1)

        # (batch_size * n_sample, seq_len, vocab_size)
        output_logits = self.decode((src, sents_len), z)

        if n_sample == 1:
            tgt = tgt.contiguous().view(-1)
        else:
            # (batch_size * n_sample * seq_len)
            tgt = tgt.unsqueeze(1).expand(batch_size, n_sample, seq_len) \
                     .contiguous().view(-1)

        # (batch_size * n_sample * seq_len)
        loss = self.loss(output_logits.view(-1, output_logits.size(2)),
                         tgt)

        # (batch_size, n_sample)
        return loss.view(batch_size, n_sample, -1).sum(-1)
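
# End-to-end sketch (hedged; names such as `args`, `vocab`, `sents_len` and
# `uniform_initializer` are assumptions about the surrounding training code,
# not defined in this module):
#
#   model_init = uniform_initializer(0.01)
#   emb_init = uniform_initializer(0.1)
#   decoder = LSTMDecoder(args, vocab, model_init, emb_init)
#
#   # training: z sampled from the encoder's approximate posterior
#   rec_loss = decoder.reconstruct_error(x, z)   # (batch_size, n_sample)
#   # evaluation: log p(x|z) is the negated reconstruction error
#   log_px_z = decoder.log_probability(x, z)
#
#   # for variable-length batches, pass (x, sents_len) and use VarLSTMDecoder
#   var_decoder = VarLSTMDecoder(args, vocab, model_init, emb_init)
#   rec_loss = var_decoder.reconstruct_error((x, sents_len), z)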