Spaces:
Runtime error
Runtime error
| import streamlit as st | |
| import torch | |
| from transformers import pipeline, set_seed | |
| from transformers import AutoTokenizer | |
| from PIL import ( | |
| Image, | |
| ImageFont, | |
| ImageDraw | |
| ) | |
| import re | |
| import textwrap | |
| from examples import EXAMPLES | |
| import meta | |
| from utils import ( | |
| remote_css, | |
| local_css, | |
| load_image, | |
| pure_comma_separation | |
| ) | |
class TextGeneration:
    """Wrap the T5 recipe-generation model and render its output.

    Loads a Hugging Face ``text2text-generation`` pipeline for
    ``flax-community/t5-recipe-generation``, turns the decoded model text into
    a ``{"title", "ingredients", "directions"}`` dict, and can draw that dict
    onto the recipe-post background image. When ``debug`` is true no model is
    loaded and ``generate`` returns a fixed dummy recipe instead.
    """

    def __init__(self):
        self.debug = False
        self.dummy_output = self._build_dummy_output()
        self.tokenizer = None
        self.generator = None
        self.task = "text2text-generation"
        self.model_name_or_path = "flax-community/t5-recipe-generation"
        # Number of ingredients drawn in the left column of the image post;
        # the remaining ones go into the right column.
        self.list_division = 5
        # Bullet prefix for list items drawn on the image post.
        self.point = "-"
        self.h1_font = ImageFont.truetype("asset/fonts/PT_Serif/PTSerif-Bold.ttf", 100)
        self.h2_font = ImageFont.truetype("asset/fonts/PT_Serif/PTSerif-Bold.ttf", 50)
        self.p_font = ImageFont.truetype("asset/fonts/PT_Serif/PTSerif-Regular.ttf", 30)
        set_seed(42)

    @staticmethod
    def _build_dummy_output():
        """Return a fresh fixed recipe used instead of the model in debug mode."""
        # Every entry needs its own trailing comma: adjacent string literals
        # without one are silently concatenated into a single item.
        return {
            'directions': [
                'peel the potato and slice thinly.',
                'place in a microwave safe dish.',
                'cover with plastic wrap and microwave on high for 5 minutes.',
                'remove from the microwave and sprinkle with cheese.',
                'return to the microwave for 1 minute or until cheese is melted.',
                'return to the microwave for 1 minute or until cheese is melted. return to the microwave for 1 minute or until cheese is melted.',
                'return to the microwave for 1 minute or until cheese is melted.',
                'return to the microwave for 1 minute or until cheese is melted.',
                'return to the microwave for 1 minute or until cheese is melted.',
            ],
            'ingredients': [
                '1 potato',
                '1 slice cheese',
                '1 potato',
                '1 slice cheese',
                '1 potato',
                '1 slice cheese',
                '1 slice cheese',
                '1 potato',
                '1 slice cheese',
                '1 potato',
                '1 slice cheese',
            ],
            'title': 'Cheese Potatoes',
        }

    def _skip_special_tokens_and_prettify(self, text):
        """Strip special tokens from decoded model text and parse it into a recipe.

        ``<sep>`` separates list items and ``<section>`` separates sections.
        A section starting with ``title:``, ``ingredients:`` or ``directions:``
        fills the matching key; any other section is ignored.

        Args:
            text: decoded model output that still contains special tokens.

        Returns:
            dict with "title" (str, each word capitalised) and "ingredients" /
            "directions" (lists of stripped strings).
        """
        recipe_maps = {"<sep>": "--", "<section>": "\n"}
        recipe_map_pattern = "|".join(map(re.escape, recipe_maps))

        # Special tokens may contain regex metacharacters (e.g. "[PAD]"), so
        # they must be escaped before being joined into an alternation.
        special_tokens = [tok for tok in self.tokenizer.all_special_tokens if tok]
        if special_tokens:
            text = re.sub("|".join(map(re.escape, special_tokens)), "", text)
        text = re.sub(recipe_map_pattern, lambda m: recipe_maps[m.group()], text)

        data = {"title": "", "ingredients": [], "directions": []}
        for section in text.split("\n"):
            section = section.strip()
            if section.startswith("title:"):
                words = section.replace("title:", "").split()
                data["title"] = " ".join(word.capitalize() for word in words)
            elif section.startswith("ingredients:"):
                data["ingredients"] = [s.strip() for s in section.replace("ingredients:", "").split('--')]
            elif section.startswith("directions:"):
                data["directions"] = [s.strip() for s in section.replace("directions:", "").split('--')]
        return data

    def load(self):
        """Load the tokenizer and the generation pipeline (skipped in debug mode)."""
        if self.debug:
            return
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path)
        # Hand the already-loaded tokenizer to the pipeline instead of making
        # it load the same tokenizer a second time.
        self.generator = pipeline(self.task, model=self.model_name_or_path, tokenizer=self.tokenizer)

    def _bulleted(self, items):
        """Join ``items`` into newline-separated lines prefixed with ``self.point``."""
        return "\n".join(f"{self.point} {item}" for item in items)

    def prepare_frame(self, recipe, frame):
        """Draw ``recipe`` onto the image ``frame`` and return it.

        Positions are fixed pixel offsets. The layout is the title, then an
        "Ingredients" heading with two bullet columns (the first
        ``self.list_division`` items on the left), then a "Directions"
        heading with one bullet per step.

        Args:
            recipe: dict with "title" (str), "ingredients" and "directions"
                (lists of str).
            frame: PIL image to draw on; it is modified in place.

        Returns:
            The same ``frame`` object.
        """
        draw = ImageDraw.Draw(frame)
        color = (61, 61, 70)

        # Title
        ws, hs = 120, 500
        draw.text((ws, hs), recipe["title"], color, font=self.h1_font)

        # Ingredients, split across two columns.
        hs += 250
        draw.text((ws, hs), "Ingredients", color, font=self.h2_font)
        hs += 80
        ingredients = [textwrap.fill(item, 30) for item in recipe["ingredients"]]
        left_column = ingredients[:self.list_division]
        right_column = ingredients[self.list_division:]
        draw.text((ws + 10, hs), self._bulleted(left_column), color, font=self.p_font)
        draw.text((ws + 500, hs), self._bulleted(right_column), color, font=self.p_font)

        # Directions; wrapped continuation lines get a one-space indent.
        hs += 200
        draw.text((ws, hs), "Directions", color, font=self.h2_font)
        hs += 80
        directions = [textwrap.fill(item, 70).replace("\n", "\n ") for item in recipe["directions"]]
        draw.text((ws + 10, hs), self._bulleted(directions), color, font=self.p_font)
        return frame

    def generate(self, items, generation_kwargs):
        """Generate a recipe for ``items`` and return it as a parsed dict.

        Args:
            items: ingredient text passed to the pipeline.
            generation_kwargs: generation settings forwarded to the pipeline.
                The dict is copied and never modified, so shared module-level
                configs can be passed safely.

        Returns:
            ``{"title", "ingredients", "directions"}`` dict; in debug mode,
            the fixed dummy recipe.
        """
        print(generation_kwargs)
        if self.debug:
            return self.dummy_output

        # Work on a copy: callers pass shared config dicts that must not be mutated.
        kwargs = dict(generation_kwargs)
        kwargs["num_return_sequences"] = 1
        # Ask for token ids so special tokens can be post-processed here.
        kwargs["return_tensors"] = True
        kwargs["return_text"] = False
        generated_ids = self.generator(items, **kwargs)[0]["generated_token_ids"]
        recipe = self.tokenizer.decode(generated_ids, skip_special_tokens=False)
        return self._skip_special_tokens_and_prettify(recipe)

    def generate_frame(self, recipe):
        """Return the recipe-post background image with ``recipe`` drawn on it."""
        frame = load_image("asset/images/recipe-post.png")
        return self.prepare_frame(recipe, frame)
def load_text_generator():
    """Create a ``TextGeneration`` instance, load its model, and return it."""
    text_generator = TextGeneration()
    text_generator.load()
    return text_generator
# Generation settings used for "Chef Scheherazade": top-k / top-p sampling.
chef_top = dict(
    max_length=512,
    min_length=64,
    no_repeat_ngram_size=3,
    do_sample=True,
    top_k=60,
    top_p=0.95,
    num_return_sequences=1,
)

# Generation settings used for the other chef ("Chef Giovanni"): beam search.
chef_beam = dict(
    max_length=512,
    min_length=64,
    no_repeat_ngram_size=3,
    early_stopping=True,
    num_beams=5,
    length_penalty=1.5,
    num_return_sequences=1,
)
def _recipe_to_html(recipe):
    """Render a recipe dict as the HTML snippet shown in the text view.

    Every model-generated string is HTML-escaped. The snippet is displayed
    with ``unsafe_allow_html=True`` and the model output is driven by the
    user's ingredient input, so raw markup must not pass through.

    Args:
        recipe: dict with "title" (str), "ingredients" and "directions"
            (lists of str).

    Returns:
        A single HTML string.
    """
    import html

    title = html.escape(recipe["title"])
    ingredients = [html.escape(item) for item in recipe["ingredients"]]
    # Wrap before escaping so entity expansion (e.g. "&" -> "&amp;") does not
    # shift where lines break.
    directions = [
        html.escape(textwrap.fill(item, 70).replace("\n", "\n "))
        for item in recipe["directions"]
    ]
    return " ".join([
        f"<h2>{title}</h2>",
        "<h3>Ingredient</h3>",
        "<ul class='ingredients-list'>",
        " ".join([f'<li>{item}</li>' for item in ingredients]),
        "</ul>",
        "<h3>Direction</h3>",
        "<ul class='ingredients-list'>",
        " ".join([f'<li>{item}</li>' for item in directions]),
        "</ul>",
    ])


def main():
    """Render the Chef Transformer Streamlit app.

    Builds the sidebar (chef choice, text / Instagram output toggles), an
    examples picker and an ingredient input box. When "Get Recipe!" is
    pressed, it generates a recipe with the chosen chef's settings and shows
    it as HTML, as an image post, or both.
    """
    st.set_page_config(
        page_title="Chef Transformer",
        page_icon="🍲",
        layout="wide",
        initial_sidebar_state="expanded",
    )
    # NOTE(review): the model is (re)loaded every time main() runs; consider
    # caching load_text_generator with the installed Streamlit's cache decorator.
    generator = load_text_generator()
    local_css("asset/css/style.css")

    # Sidebar: chef choice and output formats.
    st.sidebar.image(load_image("asset/images/chef-transformer-transparent.png"), width=310)
    st.sidebar.title("Welcome to our lovely restaurant, what can I do for you?")
    chef = st.sidebar.selectbox("Choose your chef", index=0, options=["Chef Scheherazade", "Chef Giovanni"])
    is_text = st.sidebar.selectbox(
        label='Recipe',
        options=(True, False),
        help="Will generate your recipe as a text post",
    )
    is_frame = st.sidebar.selectbox(
        label='Recipe for Instagram?',
        options=(True, False),
        help="Will generate your recipe as an Instagram post",
    )

    # Main area: choosing an example pre-fills the ingredient box.
    st.markdown(meta.HEADER_INFO)
    prompts = list(EXAMPLES.keys()) + ["Custom"]
    prompt = st.selectbox('Examples', prompts, index=len(prompts) - 1)
    prompt_box = "" if prompt == "Custom" else EXAMPLES[prompt]
    items = st.text_input(
        'Add custom ingredients here (separated by `,`): ',
        pure_comma_separation(prompt_box, return_list=False),
        key="custom_keywords",
        max_chars=1000,
    )
    items = pure_comma_separation(items, return_list=False)
    entered_items = st.empty()

    if not st.button('Get Recipe!'):
        return

    # Don't run the model on an empty prompt.
    if not items.strip():
        entered_items.markdown("Please add at least one ingredient (separated by `,`).")
        return

    entered_items.markdown("**Generate recipe for:** " + items)
    with st.spinner("Generating recipe..."):
        gen_kw = chef_top if chef == "Chef Scheherazade" else chef_beam
        generated_recipe = generator.generate(items, gen_kw)

    if not generated_recipe:
        return

    if is_text:
        st.markdown(_recipe_to_html(generated_recipe), unsafe_allow_html=True)

    if is_frame:
        recipe_post = generator.generate_frame(generated_recipe)
        # st.beta_columns was superseded by st.columns and dropped from newer
        # Streamlit releases; prefer the new name, fall back for old versions.
        columns = getattr(st, "columns", None) or st.beta_columns
        col1, col2, col3 = columns([1, 6, 1])
        with col1:
            st.write("")
        with col2:
            st.image(
                recipe_post,
                caption="Your recipe",
                use_column_width="auto",
                output_format="PNG",
            )
        with col3:
            st.write("")
# Entry point: start the app only when this file is executed as the main script.
if __name__ == "__main__":
    main()