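"""Rhyme with AI: Streamlit front end.

Reads a first line from the user, looks up rhyme words (via Datamuse and
Mick's rijmwoordenboek), and uses a masked language model through
RhymeGenerator to turn them into rhyming suggestions.

Assuming this file is saved as app.py and the rhyme_with_ai package is
installed, it can be started with:

    streamlit run app.py
"""
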
import copy
import logging
from typing import List

import streamlit as st
from transformers import BertTokenizer, TFAutoModelForMaskedLM

from rhyme_with_ai.utils import color_new_words, sanitize
from rhyme_with_ai.rhyme import query_rhyme_words
from rhyme_with_ai.rhyme_generator import RhymeGenerator

DEFAULT_QUERY = "Machines will take over the world soon"
N_RHYMES = 10

LANGUAGE = st.sidebar.radio("Language", ["english", "dutch"], 0)
if LANGUAGE == "english":
    MODEL_PATH = "bert-large-cased-whole-word-masking"
    ITER_FACTOR = 5
elif LANGUAGE == "dutch":
    MODEL_PATH = "GroNLP/bert-base-dutch-cased"
    ITER_FACTOR = 10  # Faster model
else:
    raise NotImplementedError(
        f"Unsupported language ({LANGUAGE}): expected 'english' or 'dutch'."
    )


def main():
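    """Render the Rhyme with AI page: credits, title, query input, and rhyme suggestions."""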
    st.markdown(
        "<sup>Created with "
        "[Datamuse](https://www.datamuse.com/api/), "
        "[Mick's rijmwoordenboek](https://rijmwoordenboek.nl), "
        "[Hugging Face](https://huggingface.co/), "
        "[Streamlit](https://streamlit.io/) and "
        "[App Engine](https://cloud.google.com/appengine/)."
        " Read our [blog](https://blog.godatadriven.com/rhyme-with-ai) "
        "or check the "
        "[source](https://github.com/godatadriven/rhyme-with-ai).</sup>",
        unsafe_allow_html=True,
    )
    st.title("Rhyme with AI")
    query = get_query()
    if not query:
        query = DEFAULT_QUERY
    rhyme_words_options = query_rhyme_words(query, n_rhymes=N_RHYMES, language=LANGUAGE)
    if rhyme_words_options:
        logging.getLogger(__name__).info("Got rhyme words: %s", rhyme_words_options)
        start_rhyming(query, rhyme_words_options)
    else:
        st.write("No rhyme words found")


def get_query():
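    """Return the sanitized first line typed by the user, or DEFAULT_QUERY if it is empty."""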
    q = sanitize(
        st.text_input("Write your first line and press ENTER to rhyme:", DEFAULT_QUERY)
    )
    if not q:
        return DEFAULT_QUERY
    return q


def start_rhyming(query, rhyme_words_options):
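    """Iteratively generate rhyming lines for the query and stream them to the page."""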
    st.markdown("## My Suggestions:")

    progress_bar = st.progress(0)
    status_text = st.empty()
    max_iter = len(query.split()) * ITER_FACTOR

    rhyme_words = rhyme_words_options[:N_RHYMES]

    model, tokenizer = load_model(MODEL_PATH)
    sentence_generator = RhymeGenerator(model, tokenizer)
    sentence_generator.start(query, rhyme_words)

    current_sentences = [" " for _ in range(N_RHYMES)]
    for i in range(max_iter):
        previous_sentences = copy.deepcopy(current_sentences)
        current_sentences = sentence_generator.mutate()
        display_output(status_text, query, current_sentences, previous_sentences)
        progress_bar.progress(i / (max_iter - 1))

    st.balloons()


def load_model(model_path):
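    """Load the masked-language model and matching tokenizer for model_path."""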
    return (
        TFAutoModelForMaskedLM.from_pretrained(model_path),
        BertTokenizer.from_pretrained(model_path),
    )


def display_output(status_text, query, current_sentences, previous_sentences):
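    """Render the query plus the current suggestions, highlighting newly changed words."""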
    print_sentences = []
    for new, old in zip(current_sentences, previous_sentences):
        formatted = color_new_words(new, old)
        # Keep the part after the comma (the rhyme suggestion) and drop the last two characters.
        after_comma = "<li>" + formatted.split(",")[1][:-2] + "</li>"
        print_sentences.append(after_comma)
    status_text.markdown(
        query + ",<br>" + "".join(print_sentences), unsafe_allow_html=True
    )


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    main()