# Evaluation script for generated poetry: formal metrics (iambic meter, rhyme,
# syllable count), distinctness, grammaticality, and conditional perplexity.
| from argparse import ArgumentParser | |
| import math | |
| import string | |
| from tqdm import tqdm | |
| import numpy as np | |
| import torch | |
| import torch.nn.functional as F | |
| from transformers import AutoTokenizer, AutoModelWithLMHead, AutoModelForSequenceClassification | |
| from poetry_util import is_iambic, perfect_rhyme_end, count_syllables | |
| from constants import * | |
def conditional_perplexity(prefix, pred, tokenizer, model, device='cuda', sep_losses=False):
    """Compute the perplexity of `pred` conditioned on `prefix`.

    The total negative log-probability of the prefix is subtracted from that of
    the full sequence (prefix + pred), leaving the conditional NLL of `pred`,
    which is averaged over the predicted tokens and exponentiated.

    Args:
        prefix: conditioning text; EOT_TOKEN occurrences are replaced by spaces.
        pred: continuation to score.
        tokenizer: HF tokenizer; token id 0 is decoded and used as a start marker.
        model: HF language model whose output[0], given `labels`, is the loss.
        device: torch device string for the input tensors.
        sep_losses: if True, output[0] is assumed to be per-position losses and
            is summed (e.g. Transfo-XL); otherwise output[0] is a mean loss and
            is rescaled by (seq_len - 1) to recover the total NLL.

    Returns:
        float: exp(average conditional NLL per predicted token).
    """
    # calculate perplexity on pred only, conditioned on prefix
    sentence = prefix + pred
    sos_token = tokenizer.decode([0])
    # Pure evaluation: no_grad avoids building the autograd graph and
    # matches the inference-only usage elsewhere in this script.
    with torch.no_grad():
        prefix_tensor_input = tokenizer.encode(sos_token + prefix.replace(EOT_TOKEN, ' ').strip(), return_tensors='pt').to(device)
        full_tensor_input = tokenizer.encode(sos_token + sentence.replace(EOT_TOKEN, ' ').strip(), return_tensors='pt').to(device)
        if sep_losses:
            # output[0] holds per-position losses; summing gives the total NLL.
            prefix_loss = model(prefix_tensor_input, labels=prefix_tensor_input)[0].sum()
            full_loss = model(full_tensor_input, labels=full_tensor_input)[0].sum()
        else:
            # output[0] is the mean loss over (seq_len - 1) predictions; undo the mean.
            prefix_loss = model(prefix_tensor_input, labels=prefix_tensor_input)[0] * (prefix_tensor_input.shape[1] - 1)  # neg log prob of prefix
            full_loss = model(full_tensor_input, labels=full_tensor_input)[0] * (full_tensor_input.shape[1] - 1)  # neg log prob of full seq
        pred_loss = full_loss - prefix_loss  # neg log prob of preds given prefix
        avg_pred_loss = pred_loss / (full_tensor_input.shape[1] - prefix_tensor_input.shape[1])
    return math.exp(avg_pred_loss.item())
def grammaticality(sentences, tokenizer, model, device='cuda'):
    """Average probability that each sentence is grammatical, per a binary
    acceptability classifier (e.g. textattack/roberta-base-CoLA).

    Args:
        sentences: non-empty list of strings to score.
        tokenizer: HF tokenizer paired with `model`.
        model: HF sequence-classification model; output[0] are the two logits.
        device: torch device string for the input tensors.

    Returns:
        float: mean softmax probability of the "acceptable" class (index 1).
    """
    with torch.no_grad():
        total_good = 0.0
        for sent in tqdm(sentences, total=len(sentences)):
            logits = model(tokenizer.encode(sent, return_tensors='pt').to(device))[0].flatten()
            # .item() extracts a Python float so we don't accumulate live
            # tensors across the loop (and the function returns a float,
            # not a 0-d tensor).
            total_good += F.softmax(logits, dim=0)[1].item()
    return total_good / len(sentences)  # avg probability of grammaticality according to model
def distinctness(sentences):
    """Distinct-1/2/3 metrics: unique uni/bi/trigrams divided by total words.

    Words are produced by splitting each sentence on single spaces (so an
    empty sentence still contributes one empty "word", matching the original
    metric definition used throughout this script).

    Args:
        sentences: list of strings.

    Returns:
        (d1, d2, d3) tuple of floats. Returns (0.0, 0.0, 0.0) for an empty
        list instead of raising ZeroDivisionError.
    """
    unigrams, bigrams, trigrams = set(), set(), set()
    total_words = 0
    for sentence in sentences:
        words = sentence.split(' ')
        total_words += len(words)
        unigrams.update(words)
        bigrams.update('_'.join(pair) for pair in zip(words, words[1:]))
        trigrams.update('_'.join(tri) for tri in zip(words, words[1:], words[2:]))
    if total_words == 0:
        # No sentences at all: define all distinctness scores as 0.
        return 0.0, 0.0, 0.0
    return len(unigrams) / total_words, len(bigrams) / total_words, len(trigrams) / total_words
if __name__ == '__main__':
    parser = ArgumentParser()
    parser.add_argument('--pred_file', type=str)
    parser.add_argument('--prefix_file', type=str)
    parser.add_argument('--device', type=str, default='cuda', choices=['cpu', 'cuda'])
    args = parser.parse_args()

    preds = []
    with open(args.pred_file, 'r') as rf:
        for line in rf:
            # rstrip('\n') drops the newline but preserves leading spaces;
            # unlike line[:-1], it does not eat the last real character when
            # the final line of the file has no trailing newline.
            preds.append(line.rstrip('\n'))
    prefixes = []
    with open(args.prefix_file, 'r') as rf:
        for line in rf:
            prefixes.append(line.strip())
    assert len(prefixes) == len(preds)

    # Formal poetry metrics: meter, rhyme, syllable count, sentence ending.
    rhymes = 0
    iambic = 0
    ten_syllables = 0
    end = 0
    diff_rhymes = 0
    all_success = 0
    total = len(prefixes)
    for prefix, pred in zip(prefixes, preds):
        if is_iambic(pred):
            iambic += 1
        if perfect_rhyme_end(prefix, pred):
            rhymes += 1
            # Count rhymes that are not just the same word repeated
            # (compare final words with surrounding punctuation stripped).
            if prefix.split()[-1].strip(string.punctuation) != pred.split()[-1].strip(string.punctuation):
                diff_rhymes += 1
        if count_syllables(pred) == 10:
            ten_syllables += 1
        if pred.strip()[-1] in PHRASE_ENDS:
            end += 1
        # "all success": iambic pentameter with a perfect rhyme and a
        # sentence-ending final character, all at once.
        if is_iambic(pred) and perfect_rhyme_end(prefix, pred) and count_syllables(pred) == 10 and pred.strip()[-1] in PHRASE_ENDS:
            all_success += 1
    print('iambic', iambic, 'out of', total, ', frac', iambic / total)
    print('rhymes', rhymes, 'out of', total, ', frac', rhymes / total)
    print('end sentence', end, 'out of', total, ', frac', end / total)
    print('10 syllables', ten_syllables, 'out of', total, ', frac', ten_syllables / total)
    print('all success', all_success, 'out of', total, ', frac', all_success / total)
    print('rhymes with diff word', diff_rhymes, 'out of', total, ', frac', diff_rhymes / total)
    print('distinctness', distinctness(preds))

    # Grammaticality via a RoBERTa CoLA acceptability classifier.
    grammar_tokenizer = AutoTokenizer.from_pretrained('textattack/roberta-base-CoLA')
    grammar_model = AutoModelForSequenceClassification.from_pretrained('textattack/roberta-base-CoLA').to(args.device)
    grammar_model.eval()
    print('grammaticality', grammaticality(preds, grammar_tokenizer, grammar_model, device=args.device))

    # Conditional perplexity under Transfo-XL (per-token losses -> sep_losses=True).
    perplexities = []
    eval_tokenizer = AutoTokenizer.from_pretrained('transfo-xl-wt103')
    eval_model = AutoModelWithLMHead.from_pretrained('transfo-xl-wt103').to(args.device)
    eval_model.eval()
    for prefix, pred in zip(prefixes, preds):
        perplexities.append(conditional_perplexity(prefix, pred, eval_tokenizer, eval_model, device=args.device, sep_losses=True))
    print('transformer xl perplexity', np.mean(perplexities), '+/-', np.std(perplexities))

    # Conditional perplexity under GPT (mean loss -> default sep_losses=False).
    perplexities = []
    eval_tokenizer = AutoTokenizer.from_pretrained('openai-gpt')
    eval_model = AutoModelWithLMHead.from_pretrained('openai-gpt').to(args.device)
    eval_model.eval()
    for prefix, pred in zip(prefixes, preds):
        perplexities.append(conditional_perplexity(prefix, pred, eval_tokenizer, eval_model, device=args.device))
    print('gpt perplexity', np.mean(perplexities), '+/-', np.std(perplexities))

    # NOTE: uncomment this section with the path to the Shakespeare-finetuned GPT to evaluate this metric. it's in ckpt/poetry/gpt_finetune_shakespeare.pth.tar.
    # eval_tokenizer = AutoTokenizer.from_pretrained('openai-gpt')
    # eval_model = AutoModelWithLMHead.from_pretrained('openai-gpt').to(args.device)
    # checkpoint = torch.load('***PATH_TO_SHAKESPEARE_FINETUNED_GPT***', map_location=args.device)
    # mod_dict = {}
    # for key in checkpoint['state_dict']:
    #     mod_dict[key.replace('classifier.', '')] = checkpoint['state_dict'][key]
    # eval_model.load_state_dict(mod_dict)
    # eval_model.eval()
    # perplexities = []
    # for prefix, pred in zip(prefixes, preds):
    #     perplexities.append(conditional_perplexity(prefix, pred, eval_tokenizer, eval_model, device=args.device))
    # print('shakespeare finetuned perplexity', np.mean(perplexities), '+/-', np.std(perplexities))