import random
import math
import os
import pickle
from collections import defaultdict, namedtuple
import string

os.environ['TOKENIZERS_PARALLELISM'] = 'false' # turn off since we're using multiple threads for loading anyway
from transformers import AutoTokenizer, AutoModelWithLMHead, pipeline, set_seed, GPT2Tokenizer, GPT2Model
import numpy as np
from tqdm import tqdm
import torch

from fudge.util import suppress_stdout
from fudge.poetry_util import is_iambic, count_syllables, get_rhymes, get_rhyme_group
from fudge.constants import *

DatasetInfo = namedtuple('DatasetInfo',
        ['index2word', 'word2index', 'total_words', 'vocab', 'glove_embeddings'])
RhymeInfo = namedtuple('RhymeInfo',
        ['word2rhyme_group', 'rhyme_group_counts', 'rhyme_groups', 'index2rhyme_group', 'rhyme_group2index', 'total_rhyme_groups'])
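
# DatasetInfo caches the vocabulary statistics (and optional GloVe embeddings) computed in
# Dataset.__init__, and RhymeInfo caches the rhyme-group mappings computed in load_rhyme_info;
# both are plain namedtuples so they can be pickled and reloaded via args.dataset_info / args.rhyme_info.
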
def collate(batch):
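    """Pad and batch a list of per-example tuples.

    Each element of `batch` is the 9-tuple built in SplitLoader.__next__:
    (inp, length, future_word, word_log_prob, pad_id, classification_label,
    syllables_to_go, future_word_num_syllables, rhyme_group_index).
    Returns (inputs, lengths, future_words, log_probs, labels,
    classification_labels, syllables_to_go, future_word_num_syllables,
    rhyme_group_index) as batched tensors.
    """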
    pad_id = batch[0][4]
    inputs = [b[0] for b in batch]
    lengths = torch.LongTensor([b[1] for b in batch])
    max_length = lengths.max()
    for i in range(len(inputs)):
        if len(inputs[i]) < max_length:
            inputs[i] = torch.cat([inputs[i], torch.zeros(max_length - len(inputs[i])).long()], dim=0) # actually 0 is fine as pad since it's masked out
    inputs = torch.stack(inputs, dim=0)
    future_words = torch.LongTensor([b[2] for b in batch]).unsqueeze(0).expand(len(batch), -1).clone() # batch x N=batch
    labels = torch.zeros_like(future_words).long()
    labels = labels.scatter(1, torch.arange(len(batch)).unsqueeze(1), torch.ones(len(batch)).long().unsqueeze(1)).clone()
    log_probs = torch.Tensor([b[3] for b in batch])
    classification_labels = [b[5] for b in batch] # batch
    if type(classification_labels[0]) == list:
        for i in range(len(classification_labels)):
            assert len(classification_labels[i]) == lengths[i]
            if len(classification_labels[i]) < max_length:
                classification_labels[i] = torch.cat([torch.LongTensor(classification_labels[i]), -1 + torch.zeros(max_length - len(classification_labels[i])).long()], dim=0)
            else:
                classification_labels[i] = torch.LongTensor(classification_labels[i])
        classification_labels = torch.stack(classification_labels, dim=0) # batch x seq
    else:
        assert type(classification_labels[0]) == int
        classification_labels = torch.LongTensor(classification_labels) # they're just int labels
    syllables_to_go = torch.LongTensor([b[6] for b in batch])
    future_word_num_syllables = torch.LongTensor([b[7] for b in batch])
    rhyme_group_index = torch.LongTensor([b[8] for b in batch])
    return (inputs, lengths, future_words, log_probs, labels, classification_labels, syllables_to_go, future_word_num_syllables, rhyme_group_index)

def load_rhyme_info(index2word, vocab):
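    """Map each vocabulary word to its rhyme group via get_rhyme_group, counting group
    frequencies weighted by vocab counts; words with no known rhyme fall back to
    UNKNOWN_RHYME_GROUP."""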
    word2rhyme_group = defaultdict(lambda: UNKNOWN_RHYME_GROUP)
    rhyme_group_counts = defaultdict(lambda: 0)
    rhyme_groups = set()
    for word in index2word:
        try:
            rhyme_group = get_rhyme_group(word)
            word2rhyme_group[word] = rhyme_group
            rhyme_group_counts[rhyme_group] += (vocab[word] if word in vocab else 1) # for rare words not in vocab, just use 1
            rhyme_groups.add(rhyme_group)
        except Exception: # word has no entry in the rhyme dictionary
            rhyme_group_counts[UNKNOWN_RHYME_GROUP] += (vocab[word] if word in vocab else 1)
    index2rhyme_group = [UNKNOWN_RHYME_GROUP] + sorted(list(rhyme_groups))
    rhyme_group2index = {s: i for i, s in enumerate(index2rhyme_group)}
    total_rhyme_groups = sum(rhyme_group_counts.values())
    return RhymeInfo(word2rhyme_group=dict(word2rhyme_group),
                     rhyme_group_counts=dict(rhyme_group_counts),
                     rhyme_groups=rhyme_groups,
                     index2rhyme_group=index2rhyme_group,
                     rhyme_group2index=rhyme_group2index,
                     total_rhyme_groups=total_rhyme_groups)

class Dataset:
    def __init__(self, args):
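        """Load raw text and build train/val/test splits plus vocabulary for the given task.

        `args` is expected to provide: seed, batch_size, data_dir,
        task (one of 'topic', 'formality', 'iambic', 'rhyme', 'newline'),
        debug, dataset_info, rhyme_info, and glove_file.
        """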
        print('loading data')
        random.seed(args.seed)
        self.batch_size = args.batch_size
        self.data_dir = args.data_dir
        self.topic = args.task == 'topic'
        self.formality = args.task == 'formality'
        self.iambic = args.task == 'iambic'
        self.rhyme = args.task == 'rhyme'
        self.newline = args.task == 'newline'

        self.tokenizer = AutoTokenizer.from_pretrained(FORMALITY_MODEL_STRING if self.formality else TOPIC_MODEL_STRING)
        self.tokenizer.add_special_tokens({'pad_token': PAD_TOKEN})
        self.gpt_pad_id = self.tokenizer.encode(PAD_TOKEN)[0] # actually just the vocab size
        sentences = []
        self.vocab = defaultdict(lambda: 0)
        if self.formality:
            self.vocab['placeholder'] = 1 # anything so we don't crash
            train, val, test = [], [], []
            for category, label in [('formal', 1), ('informal', 0)]:
                with open(os.path.join(args.data_dir, 'train', category), 'r') as rf:
                    for i, line in enumerate(rf):
                        if len(line) > FORMALITY_MAX_LEN:
                            line = ' '.join(line.strip()[:FORMALITY_MAX_LEN].split()[:-1]) # cutoff words until below max len; chosen so only ~20 examples affected in dataset
                        if i < FORMALITY_VAL_SIZE // 2:
                            val.append((line.strip(), label))
                        else:
                            train.append((line.strip(), label))
                with open(os.path.join(args.data_dir, 'test', category), 'r') as rf:
                    for line in rf:
                        if len(line) > FORMALITY_MAX_LEN:
                            line = ' '.join(line.strip()[:FORMALITY_MAX_LEN].split()[:-1]) # cutoff words until below max len
                        test.append((line.strip(), label))
            self.splits = {}
            self.splits['train'], self.splits['val'], self.splits['test'] = train, val, test
        else: # topic / poetry
            for root, _, filenames in os.walk(args.data_dir):
                for fname in filenames:
                    with open(os.path.join(root, fname), 'r') as rf:
                        for line in rf:
                            sentences.append(line.strip())
                            for word in line.strip().split(' '):
                                self.vocab[word] += 1
            random.shuffle(sentences)
            self.splits = {}
            if args.debug:
                self.splits['val'] = sentences
                self.splits['test'] = sentences
                self.splits['train'] = sentences
            else:
                self.splits['val'] = sentences[:TOPIC_VAL_SIZE]
                self.splits['test'] = sentences[TOPIC_VAL_SIZE:2*TOPIC_VAL_SIZE]
                self.splits['train'] = sentences[2*TOPIC_VAL_SIZE:]
        if args.dataset_info is not None:
            print('loading dataset info from file')
            with open(args.dataset_info, 'rb') as rf:
                dataset_info = pickle.load(rf)
            self.vocab, self.total_words, self.index2word, self.word2index, self.glove_embeddings = \
                dataset_info.vocab, dataset_info.total_words, dataset_info.index2word, dataset_info.word2index, dataset_info.glove_embeddings
            self.dataset_info = dataset_info
        else:
            print('generating dataset info from scratch')
            words_values = list(self.vocab.items())
            words_values = sorted(words_values, key=lambda x: x[1], reverse=True)
            if args.glove_file is None:
                print('no glove embeddings given')
                for word, _ in words_values[VOCAB_SIZE:]: # only use somewhat common tokens
                    del self.vocab[word]
                glove_embeddings = None
            else:
                print('loading glove embeddings')
                glove_embeddings = {}
                with open(args.glove_file, 'r') as rf:
                    for i, line in enumerate(rf):
                        if i % GLOVE_PRINT_PROGRESS_FREQ == 0:
                            print(i)
                        line = line.strip().split()
                        if len(line) != GLOVE_DIM + 1:
                            continue # skip multi-word embeddings which are rare anyway
                        glove_embeddings[line[0]] = [float(x) for x in line[1:]]
                for word, _ in words_values:
                    if word not in glove_embeddings:
                        del self.vocab[word]
            self.total_words = sum(self.vocab.values())
            self.index2word = [PAD_TOKEN] + sorted(list(self.vocab.keys()))
            self.word2index = {s: i for i, s in enumerate(self.index2word)}
            self.vocab = dict(self.vocab) # so we can pickle later
            if glove_embeddings is None:
                self.glove_embeddings = None
            else:
                self.glove_embeddings = torch.stack([torch.zeros(GLOVE_DIM)] + [torch.Tensor(glove_embeddings[word]) for word in self.index2word[1:]], dim=0)

            self.dataset_info = DatasetInfo(index2word=self.index2word,
                                            word2index=self.word2index,
                                            total_words=self.total_words,
                                            vocab=self.vocab,
                                            glove_embeddings=self.glove_embeddings)
        if self.rhyme:
            if args.rhyme_info is not None:
                print('loading rhyme info from file')
                with open(args.rhyme_info, 'rb') as rf:
                    self.rhyme_info = pickle.load(rf)
            else:
                self.rhyme_info = load_rhyme_info(self.index2word, self.vocab)
            self.word2rhyme_group, self.rhyme_group_counts, self.rhyme_groups, self.index2rhyme_group, self.rhyme_group2index, self.total_rhyme_groups = \
                defaultdict(lambda: UNKNOWN_RHYME_GROUP, self.rhyme_info.word2rhyme_group), self.rhyme_info.rhyme_group_counts, self.rhyme_info.rhyme_groups, self.rhyme_info.index2rhyme_group, self.rhyme_info.rhyme_group2index, self.rhyme_info.total_rhyme_groups

        print('done loading data')
        print('split sizes:')
        for key in ['train', 'val', 'test']:
            print(key, len(self.splits[key]))
        if not self.formality:
            print('total words', self.total_words)
            print('vocab size', len(self.index2word))

    def shuffle(self, split, seed=None):
        assert split in ['train', 'val', 'test']
        if seed is not None:
            random.seed(seed)
        random.shuffle(self.splits[split])

    def loader(self, split, num_workers=20, indices=None):
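        """Return a DataLoader over the given split (optionally restricted to `indices`),
        batching examples from a SplitLoader with collate()."""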
        assert split in ['train', 'val', 'test']
        data = self.splits[split] if indices is None else [self.splits[split][i] for i in indices]
        return torch.utils.data.DataLoader(SplitLoader(data, self), batch_size=self.batch_size, pin_memory=True, collate_fn=collate, num_workers=num_workers)

class SplitLoader(torch.utils.data.IterableDataset):
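    """IterableDataset that builds one training example at a time from raw sentences,
    according to the parent Dataset's task (topic, formality, iambic, rhyme, or newline)."""
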
    def __init__(self, data, parent):
        super().__init__()
        self.data = data
        self.pos = 0
        self.parent = parent

    def __len__(self):
        return len(self.data)

    def __iter__(self):
        return self

    def __next__(self):
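        """Build the next valid example for the parent Dataset's task.

        When running inside a DataLoader worker, examples are interleaved by starting
        at the worker id and stepping by num_workers. Candidate sentences that are too
        short or fail the task-specific checks are skipped until a valid example is found.
        """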
        increment = 1
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is not None: # in a worker process
            increment = worker_info.num_workers
            worker_id = worker_info.id
            if self.pos == 0:
                self.pos = worker_id
        valid = False
        while not valid:
            if self.pos >= len(self):
                raise StopIteration
            if self.parent.topic:
                failed = False
                future_word_num_syllables, rhyme_group_index, syllables_to_go = -1, -1, -1
                raw_sentence, classification_label = self.data[self.pos], -1
                original_sentence = raw_sentence.split()
                sentence = self.parent.tokenizer.encode(raw_sentence, return_tensors='pt')[0]
                length = len(sentence)
                min_sentence_length = MIN_SENTENCE_LENGTH
                if len(sentence) > min_sentence_length: # set to 3. well, everything in data is > 3 for the bag of words task
                    pos_to_split = random.randint(1, length - 1) # for lm, learn all positions at once
                    inp = sentence[:pos_to_split]
                    length = len(inp)
                    num_words_in_input = len(self.parent.tokenizer.decode(inp).split())
                    if not failed and num_words_in_input < len(original_sentence):
                        future_word_position_max = len(original_sentence) - 1
                        future_word_position = random.randint(num_words_in_input-1, future_word_position_max) # allow the last possibly partial word though
                        future_word = original_sentence[future_word_position]
                        unstripped_future_word = future_word
                        future_word = future_word.strip().strip(string.punctuation) # NOTE: we didn't strip punctuation for the topic bag of words paper experiments for our method. it doesn't make much difference, though.
                        if not failed and future_word in self.parent.word2index.keys():
                            word_log_prob = math.log(self.parent.vocab[future_word] / self.parent.total_words) # roughly baseline prob of word under noise model
                            future_word = self.parent.word2index[future_word]
                            pad_id = self.parent.gpt_pad_id
                            example = (inp, length, future_word, word_log_prob, pad_id, classification_label, syllables_to_go, future_word_num_syllables, rhyme_group_index)
                            valid = not failed
            elif self.parent.formality:
                future_word_num_syllables, rhyme_group_index, syllables_to_go = -1, -1, -1
                raw_sentence, classification_label = self.data[self.pos]
                original_sentence = raw_sentence.split()
                sentence = self.parent.tokenizer.encode(raw_sentence, return_tensors='pt')[0]
                length = len(sentence)
                min_sentence_length = MIN_SENTENCE_LENGTH
                if len(sentence) > min_sentence_length: # set to 3. well, everything in data is > 3 for the bag of words task
                    pos_to_split = length # no need to split; we're going to train on all possible prefixes simultaneously for efficiency
                    inp = sentence[:pos_to_split]
                    length = len(inp)
                    num_words_in_input = len(self.parent.tokenizer.decode(inp).split())
                    # only look up to 10 words ahead if we're doing count syllables, since we'll filter out anything more than 10 syllables ahead anyway
                    future_word_position_max = len(original_sentence) - 1
                    future_word_position = 0
                    future_word = 'placeholder'
                    unstripped_future_word = future_word
                    future_word = future_word.strip().strip(string.punctuation) # NOTE: we didn't strip punctuation for the topic bag of words paper experiments for our method. it doesn't make much difference, though.
                    word_log_prob, future_word = 0, 0
                    pad_id = self.parent.gpt_pad_id
                    example = (inp, length, future_word, word_log_prob, pad_id, classification_label, syllables_to_go, future_word_num_syllables, rhyme_group_index)
                    valid = True
            elif self.parent.iambic:
                failed = False
                future_word_num_syllables, rhyme_group_index, syllables_to_go = -1, -1, -1
                raw_sentence, classification_label = self.data[self.pos], -1
                original_sentence = raw_sentence.split()
                sentence = self.parent.tokenizer.encode(raw_sentence, return_tensors='pt')[0]
                length = len(sentence)
                min_sentence_length = MIN_SENTENCE_LENGTH
                if len(sentence) > min_sentence_length: # set to 3. well, everything in data is > 3 for the bag of words task
                    pos_to_split = random.randint(0, length - 1)
                    # try to get a subseq of exactly 10 syllables
                    inp = sentence[pos_to_split:]
                    num_syllables = 0
                    checked = False
                    for i in range(1, len(inp)):
                        decoded = self.parent.tokenizer.decode(inp[:i])
                        num_syllables = count_syllables(decoded)
                        if num_syllables > POETRY_LINE_SYLLABLES:
                            inp = inp[:i-1] # might get a few data points where the split is in the middle of a word, but it should be ok for learning.
                            last_line_length = i-1
                            decoded = self.parent.tokenizer.decode(inp)
                            num_syllables = count_syllables(decoded)
                            checked = True
                            break
                    if not checked or num_syllables != POETRY_LINE_SYLLABLES:
                        failed = True
                    length = len(inp)
                    num_words_in_input = len(self.parent.tokenizer.decode(inp).split())
                    classification_label = [is_iambic(self.parent.tokenizer.decode(inp)) for _ in range(length)] # predict for whole seq including future
                    # only look up to 10 words ahead if we're doing count syllables, since we'll filter out anything more than 10 syllables ahead anyway
                    future_word_position_max = len(original_sentence) - 1
                    future_word_position = 0
                    future_word = 'placeholder'
                    unstripped_future_word = future_word
                    future_word = future_word.strip().strip(string.punctuation) # NOTE: we didn't strip punctuation for the topic bag of words paper experiments for our method. it doesn't make much difference, though.
                    if not failed:
                        word_log_prob, future_word = 0, 0
                        pad_id = self.parent.gpt_pad_id
                        example = (inp, length, future_word, word_log_prob, pad_id, classification_label, syllables_to_go, future_word_num_syllables, rhyme_group_index)
                        valid = not failed
            elif self.parent.rhyme:
                failed = False
                future_word_num_syllables, rhyme_group_index = -1, -1
                raw_sentence, classification_label = self.data[self.pos], -1
                original_sentence = raw_sentence.split()
                sentence = self.parent.tokenizer.encode(raw_sentence, return_tensors='pt')[0]
                length = len(sentence)
                min_sentence_length = MIN_SENTENCE_LENGTH
                if len(sentence) > min_sentence_length: # set to 3. well, everything in data is > 3 for the bag of words task
                    pos_to_split = random.randint(1, length - 1) # for lm, learn all positions at once
                    inp = sentence[:pos_to_split]
                    length = len(inp)
                    num_words_in_input = len(self.parent.tokenizer.decode(inp).split())
                    if not failed and num_words_in_input < len(original_sentence):
                        # only look up to 10 words ahead if we're doing count syllables, since we'll filter out anything more than 10 syllables ahead anyway
                        future_word_position_max = min(len(original_sentence) - 1, num_words_in_input + MAX_COUNT_SYLLABLE_DIST)
                        future_word_position = random.randint(num_words_in_input-1, future_word_position_max) # allow the last possibly partial word though
                        future_word = original_sentence[future_word_position]
                        unstripped_future_word = future_word
                        future_word = future_word.strip().strip(string.punctuation) # NOTE: we didn't strip punctuation for the topic bag of words paper experiments for our method. it doesn't make much difference, though.
                        words_in_between = original_sentence[num_words_in_input-1:future_word_position+1]
                        syllables_to_go = count_syllables(' '.join(words_in_between))
                        if syllables_to_go > MAX_COUNT_SYLLABLE_DIST:
                            failed = True
                        future_word_num_syllables = count_syllables(future_word)
                        rhyme_group = self.parent.word2rhyme_group[future_word]
                        rhyme_group_index = self.parent.rhyme_group2index[rhyme_group]
                        # truncate context a bit since we're just doing couplets. random length from 1 to max desired length for this purpose.
                        desired_length = random.randint(1, MAX_COUNT_SYLLABLE_INPUT_LENGTH)
                        inp = inp[-desired_length:]
                        length = len(inp)
                        if not failed and future_word in self.parent.word2index.keys():
                            word_log_prob = math.log(self.parent.rhyme_group_counts[rhyme_group] / self.parent.total_rhyme_groups)
                            future_word = rhyme_group_index # future conditioning is just the rhyme group in this case
                            pad_id = self.parent.gpt_pad_id
                            example = (inp, length, future_word, word_log_prob, pad_id, classification_label, syllables_to_go, future_word_num_syllables, rhyme_group_index)
                            valid = not failed
            elif self.parent.newline:
                failed = False
                future_word_num_syllables, rhyme_group_index = -1, -1
                raw_sentence, classification_label = self.data[self.pos], -1
                original_sentence = raw_sentence.split()
                sentence = self.parent.tokenizer.encode(raw_sentence, return_tensors='pt')[0]
                length = len(sentence)
                min_sentence_length = MIN_SENTENCE_LENGTH
                if len(sentence) > min_sentence_length: # set to 3. well, everything in data is > 3 for the bag of words task
                    pos_to_split = random.randint(1, length - 1) # for lm, learn all positions at once
                    inp = sentence[:pos_to_split]
                    while pos_to_split < len(sentence):
                        if len(self.parent.tokenizer.decode(inp).split()) == len(self.parent.tokenizer.decode(sentence[:pos_to_split + 1]).split()):
                            pos_to_split += 1
                            inp = sentence[:pos_to_split]
                        else:
                            break
                    length = len(inp)
                    num_words_in_input = len(self.parent.tokenizer.decode(inp).split())
                    if not failed and num_words_in_input < len(original_sentence):
                        # only look up to 10 words ahead if we're doing count syllables, since we'll filter out anything more than 10 syllables ahead anyway
                        future_word_position_max = len(original_sentence) - 1
                        future_word_position = random.randint(num_words_in_input-1, future_word_position_max) # allow the last possibly partial word though
                        future_word = original_sentence[future_word_position]
                        unstripped_future_word = future_word
                        future_word = future_word.strip().strip(string.punctuation) # NOTE: we didn't strip punctuation for the topic bag of words paper experiments for our method. it doesn't make much difference, though.
                        # future_word = original_sentence[-1] # useful for debugging
                        words_in_between = original_sentence[num_words_in_input-1:future_word_position+1]
                        syllables_to_go = count_syllables(' '.join(words_in_between))
                        if syllables_to_go > MAX_COUNT_SYLLABLE_DIST:
                            failed = True
                        # truncate context a bit since we're just doing couplets. random length from 1 to max desired length for this purpose.
                        desired_length = random.randint(1, MAX_COUNT_SYLLABLE_INPUT_LENGTH)
                        # desired_length = 10 # useful for debugging
                        inp = inp[-desired_length:]
                        length = len(inp)
                        true_label = 1 if unstripped_future_word.strip()[-1] in PHRASE_ENDS else 0 # common ways to end a phrase
                        classification_label = [-1 for _ in range(length)]
                        classification_label[-1] = true_label # only learn at the last position
                        if not failed and future_word in self.parent.word2index.keys():
                            word_log_prob = math.log(self.parent.vocab[future_word] / self.parent.total_words) # roughly baseline prob of word under noise model
                            future_word = self.parent.word2index[future_word]
                            pad_id = self.parent.gpt_pad_id
                            example = (inp, length, future_word, word_log_prob, pad_id, classification_label, syllables_to_go, future_word_num_syllables, rhyme_group_index)
                            valid = not failed
            else:
                raise NotImplementedError
            self.pos += increment
        return example
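
# A minimal usage sketch, assuming only the fields that Dataset.__init__ reads from
# `args`; the data_dir path below is hypothetical and real data files are required,
# so treat this as an illustration rather than a tested entry point.
if __name__ == '__main__':
    from argparse import Namespace
    args = Namespace(task='topic', data_dir='path/to/topic_data',  # hypothetical path
                     seed=0, batch_size=8, debug=True,
                     dataset_info=None, rhyme_info=None, glove_file=None)
    dataset = Dataset(args)
    dataset.shuffle('train', seed=args.seed)
    for batch in dataset.loader('train', num_workers=0):
        inputs, lengths, future_words, log_probs, labels, classification_labels, \
            syllables_to_go, future_word_num_syllables, rhyme_group_index = batch
        print(inputs.shape, future_words.shape, labels.shape)
        break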