import copy
import functools
import itertools
import logging
import random
import string
from typing import List, Optional

import requests
import numpy as np
import tensorflow as tf
import streamlit as st
from gazpacho import Soup, get
from transformers import BertTokenizer, TFAutoModelForMaskedLM

DEFAULT_QUERY = "Machines will take over the world soon"
N_RHYMES = 10
ITER_FACTOR = 5

LANGUAGE = st.sidebar.radio("Language", ["english", "dutch"], 0)
if LANGUAGE == "english":
    MODEL_PATH = "bert-large-cased-whole-word-masking"
elif LANGUAGE == "dutch":
    MODEL_PATH = "GroNLP/bert-base-dutch-cased"
else:
    raise NotImplementedError(f"Unsupported language ({LANGUAGE}), expected 'english' or 'dutch'.")

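# Streamlit entry point: read the user's first line, fetch rhyme words for its
# last word, then iteratively refine masked-LM suggestions on screen.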
def main():
    st.markdown(
        "<sup>Created with "
        "[Datamuse](https://www.datamuse.com/api/), "
        "[Mick's rijmwoordenboek](https://rijmwoordenboek.nl), "
        "[Hugging Face](https://huggingface.co/), "
        "[Streamlit](https://streamlit.io/) and "
        "[App Engine](https://cloud.google.com/appengine/)."
        " Read our [blog](https://blog.godatadriven.com/rhyme-with-ai) "
        "or check the "
        "[source](https://github.com/godatadriven/rhyme-with-ai).</sup>",
        unsafe_allow_html=True,
    )
    st.title("Rhyme with AI")

    query = get_query()
    if not query:
        query = DEFAULT_QUERY
    rhyme_words_options = query_rhyme_words(query, n_rhymes=N_RHYMES, language=LANGUAGE)
    if rhyme_words_options:
        logging.getLogger(__name__).info("Got rhyme words: %s", rhyme_words_options)
        start_rhyming(query, rhyme_words_options)
    else:
        st.write("No rhyme words found")

def get_query():
    q = sanitize(
        st.text_input("Write your first line and press ENTER to rhyme:", DEFAULT_QUERY)
    )
    if not q:
        return DEFAULT_QUERY
    return q

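# Refinement loop: each iteration redraws one token per candidate sentence,
# for ITER_FACTOR iterations per word in the query.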
def start_rhyming(query, rhyme_words_options):
    st.markdown("## My Suggestions:")

    progress_bar = st.progress(0)
    status_text = st.empty()
    max_iter = len(query.split()) * ITER_FACTOR

    rhyme_words = rhyme_words_options[:N_RHYMES]

    model, tokenizer = load_model(MODEL_PATH)
    sentence_generator = RhymeGenerator(model, tokenizer)
    sentence_generator.start(query, rhyme_words)

    current_sentences = [" " for _ in range(N_RHYMES)]
    for i in range(max_iter):
        previous_sentences = copy.deepcopy(current_sentences)
        current_sentences = sentence_generator.mutate()
        display_output(status_text, query, current_sentences, previous_sentences)
        progress_bar.progress(i / (max_iter - 1))
    st.balloons()

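# Note: the model and tokenizer are reloaded on every Streamlit rerun. Caching
# the call (e.g. with st.cache(allow_output_mutation=True) or
# functools.lru_cache) would be one way to avoid that; not done here.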
def load_model(model_path):
    return (
        TFAutoModelForMaskedLM.from_pretrained(model_path),
        BertTokenizer.from_pretrained(model_path),
    )

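# Each candidate string still starts with the query and a comma (see
# RhymeGenerator._initialize_rhymes), so only the part after the first comma is
# shown; the trailing [:-2] trims the " ." left by detokenization.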
def display_output(status_text, query, current_sentences, previous_sentences):
    print_sentences = []
    for new, old in zip(current_sentences, previous_sentences):
        formatted = color_new_words(new, old)
        after_comma = "<li>" + formatted.split(",")[1][:-2] + "</li>"
        print_sentences.append(after_comma)
    status_text.markdown(
        query + ",<br>" + "".join(print_sentences), unsafe_allow_html=True
    )

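# Builds a 0/1 weight per vocabulary id: single-character tokens and WordPiece
# continuations ("##...") get weight 0, so they are never drawn as replacements.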
class TokenWeighter:
    def __init__(self, tokenizer):
        self.tokenizer_ = tokenizer
        self.proba = self.get_token_proba()

    def get_token_proba(self):
        valid_token_mask = self._filter_short_partial(self.tokenizer_.vocab)
        return valid_token_mask

    def _filter_short_partial(self, vocab):
        valid_token_ids = [v for k, v in vocab.items() if len(k) > 1 and "#" not in k]
        is_valid = np.zeros(len(vocab.keys()))
        is_valid[valid_token_ids] = 1
        return is_valid

class RhymeGenerator:
    def __init__(
        self,
        model: TFAutoModelForMaskedLM,
        tokenizer: BertTokenizer,
        token_weighter: Optional[TokenWeighter] = None,
    ):
        """Generate rhymes.

        Parameters
        ----------
        model : Model for masked language modelling
        tokenizer : Tokenizer for model
        token_weighter : Class that weighs tokens
        """
        self.model = model
        self.tokenizer = tokenizer
        if token_weighter is None:
            token_weighter = TokenWeighter(tokenizer)
        self.token_weighter = token_weighter
        self._logger = logging.getLogger(__name__)

        self.tokenized_rhymes_ = None
        self.position_probas_ = None

        # Easy access.
        self.comma_token_id = self.tokenizer.encode(",", add_special_tokens=False)[0]
        self.period_token_id = self.tokenizer.encode(".", add_special_tokens=False)[0]
        self.mask_token_id = self.tokenizer.mask_token_id
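    # start() builds one padded row of token ids per rhyme word, plus a uniform
    # probability distribution over the [MASK] positions within each row.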
    def start(self, query: str, rhyme_words: List[str]) -> None:
        """Start the sentence generator.

        Parameters
        ----------
        query : Seed sentence
        rhyme_words : Rhyme words for next sentence
        """
        # TODO: What if no content?
        self._logger.info("Got sentence %s", query)
        tokenized_rhymes = [
            self._initialize_rhymes(query, rhyme_word) for rhyme_word in rhyme_words
        ]
        # Make same length.
        self.tokenized_rhymes_ = tf.keras.preprocessing.sequence.pad_sequences(
            tokenized_rhymes, padding="post", value=self.tokenizer.pad_token_id
        )
        p = self.tokenized_rhymes_ == self.tokenizer.mask_token_id
        self.position_probas_ = p / p.sum(1).reshape(-1, 1)
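    # Template per rhyme word (illustrative): "the cat sat" becomes
    # [the, cat, sat, ","] + [MASK] * k + rhyme_word tokens + ["."], with k
    # chosen so the second line has as many tokens as the first.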
    def _initialize_rhymes(self, query: str, rhyme_word: str) -> List[int]:
        """Initialize the rhymes.

        * Tokenize input
        * Append a comma if the sentence does not end in it (might give better
          predictions as it shows the two sentence parts are related)
        * Make second line as long as the original
        * Add a period

        Parameters
        ----------
        query : First line
        rhyme_word : Last word for second line

        Returns
        -------
        Tokenized rhyme lines
        """
        query_token_ids = self.tokenizer.encode(query, add_special_tokens=False)
        rhyme_word_token_ids = self.tokenizer.encode(
            rhyme_word, add_special_tokens=False
        )
        if query_token_ids[-1] != self.comma_token_id:
            query_token_ids.append(self.comma_token_id)
        magic_correction = len(rhyme_word_token_ids) + 1  # 1 for comma
        return (
            query_token_ids
            + [self.tokenizer.mask_token_id] * (len(query_token_ids) - magic_correction)
            + rhyme_word_token_ids
            + [self.period_token_id]
        )
    def mutate(self):
        """Mutate the current rhymes.

        Returns
        -------
        Mutated rhymes
        """
        self.tokenized_rhymes_ = self._mutate(
            self.tokenized_rhymes_, self.position_probas_, self.token_weighter.proba
        )

        rhymes = []
        for i in range(len(self.tokenized_rhymes_)):
            rhymes.append(
                self.tokenizer.convert_tokens_to_string(
                    self.tokenizer.convert_ids_to_tokens(
                        self.tokenized_rhymes_[i], skip_special_tokens=True
                    )
                )
            )
        return rhymes
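    # One Gibbs-sampling-style step per row: mask one position, run the masked
    # LM on the whole batch, then draw a replacement for that position from the
    # softmaxed logits reweighted by the TokenWeighter mask.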
    def _mutate(
        self,
        tokenized_rhymes: np.ndarray,
        position_probas: np.ndarray,
        token_id_probas: np.ndarray,
    ) -> np.ndarray:
        replacements = []
        for i in range(tokenized_rhymes.shape[0]):
            mask_idx, masked_token_ids = self._mask_token(
                tokenized_rhymes[i], position_probas[i]
            )
            tokenized_rhymes[i] = masked_token_ids
            replacements.append(mask_idx)

        predictions = self._predict_masked_tokens(tokenized_rhymes)

        for i, token_ids in enumerate(tokenized_rhymes):
            replace_ix = replacements[i]
            token_ids[replace_ix] = self._draw_replacement(
                predictions[i], token_id_probas, replace_ix
            )
            tokenized_rhymes[i] = token_ids
        return tokenized_rhymes
    def _mask_token(self, token_ids, position_probas):
        """Mask line and return index to update."""
        token_ids = self._mask_repeats(token_ids, position_probas)
        ix = self._locate_mask(token_ids, position_probas)
        token_ids[ix] = self.mask_token_id
        return ix, token_ids

    def _locate_mask(self, token_ids, position_probas):
        """Update masks or a random token."""
        if self.mask_token_id in token_ids:
            # Masks already present, just return the last.
            # We used to return the first, but that gave worse predictions.
            return np.where(token_ids == self.tokenizer.mask_token_id)[0][-1]
        return np.random.choice(range(len(position_probas)), p=position_probas)
    def _mask_repeats(self, token_ids, position_probas):
        """Repeated tokens are generally of less quality."""
        repeats = [
            ii for ii, ids in enumerate(pairwise(token_ids[:-2])) if ids[0] == ids[1]
        ]
        for ii in repeats:
            if position_probas[ii] > 0:
                token_ids[ii] = self.mask_token_id
            if position_probas[ii + 1] > 0:
                token_ids[ii + 1] = self.mask_token_id
        return token_ids
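    # The masked-LM output's first element is the logits tensor of shape
    # (batch_size, sequence_length, vocab_size).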
    def _predict_masked_tokens(self, tokenized_rhymes):
        return self.model(tf.constant(tokenized_rhymes))[0]

    def _draw_replacement(self, predictions, token_probas, replace_ix):
        """Get probability, weigh and draw."""
        # TODO (HG): Can't we softmax when calling the model?
        probas = tf.nn.softmax(predictions[replace_ix]).numpy() * token_probas
        probas /= probas.sum()
        return np.random.choice(range(len(probas)), p=probas)

def query_rhyme_words(sentence: str, n_rhymes: int, language: str = "english") -> List[str]:
    """Return a list of rhyme words for a sentence.

    Parameters
    ----------
    sentence : Sentence that may end with punctuation
    n_rhymes : Maximum number of rhymes to return
    language : Either "english" or "dutch"

    Returns
    -------
    List[str] -- List of words that rhyme with the final word
    """
    last_word = find_last_word(sentence)
    if language == "english":
        return query_datamuse_api(last_word, n_rhymes)
    elif language == "dutch":
        return mick_rijmwoordenboek(last_word, n_rhymes)
    else:
        raise NotImplementedError(f"Unsupported language ({language}), expected 'english' or 'dutch'.")

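# Datamuse's /words endpoint with rel_rhy returns JSON objects that include a
# "word" key (e.g. [{"word": "moon", ...}, ...]); only the words are kept.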
def query_datamuse_api(word: str, n_rhymes: Optional[int] = None) -> List[str]:
    """Query the Datamuse API.

    Parameters
    ----------
    word : Word to rhyme with
    n_rhymes : Max rhymes to return

    Returns
    -------
    Rhyme words
    """
    out = requests.get(
        "https://api.datamuse.com/words", params={"rel_rhy": word}
    ).json()
    words = [_["word"] for _ in out]
    if n_rhymes is None:
        return words
    return words[:n_rhymes]

def mick_rijmwoordenboek(word: str, n_words: int):
    url = f"https://rijmwoordenboek.nl/rijm/{word}"
    html = get(url)
    soup = Soup(html)

    results = soup.find("div", {"id": "rhymeResultsWords"}).html.split("<br>")

    # Clean up.
    results = [r.replace("\n", "").replace(" ", "") for r in results]

    # Filter HTML and empty strings.
    results = [r for r in results if ("<" not in r) and (len(r) > 0)]

    return random.sample(results, min(len(results), n_words))

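# Highlights the words that changed between two refinement steps: diff the word
# lists from the front and from the back, then wrap the changed middle span in
# a colored <span>.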
def color_new_words(new: str, old: str, color: str = "#eefa66") -> str:
    """Color new words in strings with a span."""

    def find_diff(new_, old_):
        return [ii for ii, (n, o) in enumerate(zip(new_, old_)) if n != o]

    new_words = new.split()
    old_words = old.split()

    forward = find_diff(new_words, old_words)
    backward = find_diff(new_words[::-1], old_words[::-1])

    if not forward or not backward:
        # No difference.
        return new

    start, end = forward[0], len(new_words) - backward[0]
    return (
        " ".join(new_words[:start])
        + " "
        + f'<span style="background-color: {color}">'
        + " ".join(new_words[start:end])
        + "</span>"
        + " "
        + " ".join(new_words[end:])
    )

def find_last_word(s):
    """Find the last word in a string."""
    # Note: will break on \n, \r, etc.
    alpha_only_sentence = "".join([c for c in s if (c.isalpha() or (c == " "))]).strip()
    return alpha_only_sentence.split()[-1]


def pairwise(iterable):
    """s -> (s0,s1), (s1,s2), (s2, s3), ..."""
    # https://stackoverflow.com/questions/5434891/iterate-a-list-as-pair-current-next-in-python
    a, b = itertools.tee(iterable)
    next(b, None)
    return zip(a, b)


def sanitize(s):
    """Remove punctuation from a string."""
    return s.translate(str.maketrans("", "", string.punctuation))


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    main()