import os
import random
import time
import pickle
import math
from argparse import ArgumentParser
from typing import Iterable, List, Optional, Tuple

from tqdm import tqdm
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
# AutoModelWithLMHead is deprecated; AutoModelForSeq2SeqLM is the correct class for Pegasus.
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

from fudge.data import Dataset
from fudge.model import Model
from fudge.util import num_params
from fudge.constants import *

tokenizer = AutoTokenizer.from_pretrained('google/pegasus-xsum')
classifier_tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/all-mpnet-base-v2')
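# Two tokenizers are needed: the Pegasus tokenizer drives the base summarizer,
# while the MPNet tokenizer re-encodes candidate prefixes for the FUDGE
# conditioning classifier, which was trained with its own vocabulary.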
def main(args):
    with open(args.dataset_info, 'rb') as rf:
        dataset_info = pickle.load(rf)

    article_content = """Australian actor Guy Pearce will return for the iconic soap Neighbours finale on August 1 to reprise his role as Mike Young.
Guy, 54, played the troubled Mike from 1986 to 1989, and is now set to make a comeback on the show after 33 years, Metro.co.uk reports.
The star's character arcs explored the implications of domestic abuse, student-teacher relationships and dealing with loss of loved ones.
Speaking to Metro.co.uk, Guy said: 'It is very exciting and surreal at the same time being back on set again, however it feels like coming home.
'It's where it all started for me professionally. I've been asked to come back on occasions over the years and wondered if it was the right thing
to do, but once I knew the show was finishing, I knew I had to do it.' He added that there is 'nothing like being here all together again',
even though he's had a chance to catch up with other cast members."""

    tokenizer.add_special_tokens({'pad_token': PAD_TOKEN})
    pad_id = tokenizer.encode(PAD_TOKEN)[0]
    # Load the clickbait summarizer (Pegasus) as the base generation model.
    model = AutoModelForSeq2SeqLM.from_pretrained(args.model_string, return_dict=True).to(args.device)
    model.eval()

    checkpoint = torch.load(args.ckpt, map_location=args.device)
    model_args = checkpoint['args']
    # No need to reload GloVe embeddings; they are saved in the model checkpoint.
    conditioning_model = Model(model_args, pad_id, len(dataset_info.index2word))
    conditioning_model.load_state_dict(checkpoint['state_dict'])
    conditioning_model = conditioning_model.to(args.device)
    conditioning_model.eval()
    print("=> loaded checkpoint '{}' (epoch {})".format(args.ckpt, checkpoint['epoch']))
    print('num params', num_params(conditioning_model))
    results = generate_clickbait(model,
                                 tokenizer,
                                 conditioning_model,
                                 [args.input_text],
                                 dataset_info,
                                 precondition_topk=args.precondition_topk,
                                 do_sample=args.do_sample,
                                 length_cutoff=args.length_cutoff,
                                 condition_lambda=args.condition_lambda,
                                 article_content=article_content,
                                 device=args.device)
    print(results)
def generate_clickbait(model,
                       tokenizer,
                       conditioning_model,
                       input_text,
                       dataset_info,
                       precondition_topk,
                       length_cutoff,
                       condition_lambda=1.0,
                       do_sample=False,
                       article_content=None,
                       device='cuda'):
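    """Decode a summary with FUDGE: at each step, take the base model's top-k
    next-token candidates, score each extended prefix with the conditioning
    classifier, and re-rank candidates by adding lambda-weighted classifier
    log-probabilities to the base logits before choosing the next token."""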
    with torch.no_grad():
        batch_size = len(input_text)

        max_input_length = 512
        encoded_input_article = tokenizer(article_content,
                                          return_tensors='pt',
                                          add_special_tokens=False,
                                          truncation=True,
                                          max_length=max_input_length).to(device)  # batch x seq

        # Start decoding from the pad token, which Pegasus uses as its decoder start token.
        encoded_input = tokenizer('<pad>', return_tensors='pt', add_special_tokens=False).to(device)
        encoded_input = encoded_input['input_ids']  # batch x seq
        lengths = torch.LongTensor([encoded_input.shape[1]]).to(device)

        use_cache = True
        model_kwargs = {'encoder_outputs': model.get_encoder()(input_ids=encoded_input_article['input_ids'],
                                                               attention_mask=encoded_input_article['attention_mask'],
                                                               return_dict=True,
                                                               output_attentions=False,
                                                               output_hidden_states=False)}
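        # The article is encoded exactly once; every decoding step below reuses
        # the cached encoder outputs instead of re-running the encoder.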
        while lengths.max() < length_cutoff:
            model_inputs = model.prepare_inputs_for_generation(
                input_ids=encoded_input_article['input_ids'],
                decoder_input_ids=encoded_input,
                attention_mask=encoded_input_article['attention_mask'],
                use_cache=use_cache,
                **model_kwargs
            )
            outputs = model(**model_inputs, return_dict=True)
            logits = outputs.logits[:, -1, :]  # batch x vocab
            if "past_key_values" in outputs:
                # Cache decoder state (older transformers versions call this 'past').
                model_kwargs["past"] = outputs.past_key_values
            top_logits, top_indices = logits.topk(precondition_topk, dim=1)  # batch x topk
            # Append each top-k token to the current prefix to form candidate continuations.
            new_input_candidates = torch.cat([encoded_input.unsqueeze(1).expand(-1, precondition_topk, -1), top_indices.unsqueeze(2)], dim=2)  # batch x topk x seq+1
            expanded_lengths = (lengths + 1).unsqueeze(1).expand(batch_size, precondition_topk)  # batch x topk
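            # FUDGE re-ranking: score every candidate prefix with the conditioning
            # classifier (skipped when condition_lambda == 0, which reduces to
            # ordinary top-k decoding from the base model).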
            if condition_lambda == 0:
                condition_logits = torch.zeros_like(top_logits).float()
                condition_logits = condition_logits.view(batch_size, precondition_topk, -1)  # batch x topk x N
            else:
                # Decode candidates back to text and re-tokenize them for the classifier.
                decoded_outputs = tokenizer.batch_decode(new_input_candidates.view(-1, new_input_candidates.size(-1)), clean_up_tokenization_spaces=False)
                resulting_tokenization = classifier_tokenizer(decoded_outputs, add_special_tokens=False, padding='longest')
                encoded_with_classifier = resulting_tokenization['input_ids']
                attention_mask = torch.tensor(resulting_tokenization['attention_mask']).to(model.device)
                tplus1_candidates_classifier = torch.tensor(encoded_with_classifier).view(batch_size, precondition_topk, -1).to(model.device)

                condition_logits = conditioning_model(tplus1_candidates_classifier.flatten(0, 1),  # batch*topk x seq+1
                                                      expanded_lengths.flatten(0, 1),  # batch*topk
                                                      None,
                                                      None,
                                                      None,
                                                      attention_mask=attention_mask)
                condition_logits = condition_logits.view(batch_size, precondition_topk, -1)  # batch x topk x N
                # log sigmoid(x) = x - log(1 + e^x): convert classifier logits to log-probabilities.
                condition_logits = F.logsigmoid(condition_logits)
                condition_logits = torch.mean(condition_logits, dim=2)
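            # FUDGE combination: base log-probability plus lambda-weighted classifier
            # log-probability, i.e. log p(x_{1:t+1}) + lambda * log p(c | x_{1:t+1}).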
            full_logits = top_logits + condition_logits * condition_lambda  # batch x topk
            post_logits, post_indices = full_logits.topk(precondition_topk, dim=1)
            post_probs = F.softmax(post_logits, dim=1)
            if do_sample:
                chosen = torch.multinomial(post_probs, 1).flatten()  # batch
            else:
                chosen = post_probs.argmax(dim=1)  # batch
            index_into_top_indices = post_indices[torch.arange(batch_size).to(post_indices.device), chosen]  # batch
            next_indices = top_indices[torch.arange(batch_size).to(top_indices.device), index_into_top_indices]  # batch
            encoded_input = torch.cat([encoded_input, next_indices.unsqueeze(1)], dim=1)  # batch x seq+1
            lengths = lengths + 1  # batch
        return [tokenizer.decode(s) for s in encoded_input]
if __name__ == '__main__':
    parser = ArgumentParser()

    # DATA
    parser.add_argument('--ckpt', type=str, required=True)
    parser.add_argument('--dataset_info', type=str, required=True, help='saved dataset info')
    parser.add_argument('--model_string', type=str, default='google/pegasus-xsum', help='base summarization model; should match the tokenizer loaded above')
    parser.add_argument('--input_text', type=str, default=None, required=True, help='text to run prediction on')
    parser.add_argument('--precondition_topk', type=int, default=200, help='consider top k outputs from text generation at each step before conditioning and re-pruning')
    parser.add_argument('--do_sample', action='store_true', default=False, help='sample instead of greedy')
    parser.add_argument('--condition_lambda', type=float, default=1.0, help='lambda weight on conditioning model')
    parser.add_argument('--length_cutoff', type=int, default=512, help='max length')
    parser.add_argument('--seed', type=int, default=1, help='random seed')
    parser.add_argument('--device', type=str, default='cuda', choices=['cpu', 'cuda'])
    parser.add_argument('--debug', action='store_true', default=False)

    args = parser.parse_args()

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    main(args)
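# Example invocation (script name and file paths are placeholders; substitute
# your own FUDGE checkpoint and dataset_info):
#   python predict_clickbait.py --ckpt ckpt/clickbait_classifier.pth.tar \
#       --dataset_info ckpt/dataset_info \
#       --input_text "Guy Pearce returns to Neighbours" \
#       --precondition_topk 200 --condition_lambda 1.0 --device cuda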