Spaces:
Runtime error
Runtime error
add multiprocessing
Browse files- app_dialogue.py +43 -14
app_dialogue.py
CHANGED
|
@@ -10,6 +10,7 @@ from typing import List, Optional, Tuple
|
|
| 10 |
from urllib.parse import urlparse
|
| 11 |
from PIL import Image, ImageDraw, ImageFont
|
| 12 |
|
|
|
|
| 13 |
import random
|
| 14 |
import gradio as gr
|
| 15 |
import PIL
|
|
@@ -777,6 +778,28 @@ with gr.Blocks(title="AI Meme Generator", theme=gr.themes.Base(), css=css) as de
|
|
| 777 |
with gr.Row():
|
| 778 |
chatbot.render()
|
| 779 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 780 |
def model_inference(
|
| 781 |
model_selector,
|
| 782 |
system_prompt,
|
|
@@ -849,21 +872,27 @@ with gr.Blocks(title="AI Meme Generator", theme=gr.themes.Base(), css=css) as de
|
|
| 849 |
|
| 850 |
query = prompt_list_to_tgi_input(formated_prompt_list)
|
| 851 |
all_meme_images = []
|
| 852 |
-
|
| 853 |
-
|
| 854 |
-
|
| 855 |
-
|
| 856 |
-
|
| 857 |
-
|
| 858 |
-
|
| 859 |
-
|
| 860 |
-
|
|
|
|
|
|
|
|
|
|
| 861 |
)
|
| 862 |
-
|
| 863 |
-
|
| 864 |
-
|
| 865 |
-
|
| 866 |
-
|
|
|
|
|
|
|
|
|
|
| 867 |
|
| 868 |
gr.on(
|
| 869 |
triggers=[textbox.submit, imagebox.upload, submit_btn.click],
|
|
|
|
| 10 |
from urllib.parse import urlparse
|
| 11 |
from PIL import Image, ImageDraw, ImageFont
|
| 12 |
|
| 13 |
+
import concurrent.futures
|
| 14 |
import random
|
| 15 |
import gradio as gr
|
| 16 |
import PIL
|
|
|
|
| 778 |
with gr.Row():
|
| 779 |
chatbot.render()
|
| 780 |
|
| 781 |
+
def generate_meme(
|
| 782 |
+
i,
|
| 783 |
+
client,
|
| 784 |
+
query,
|
| 785 |
+
image,
|
| 786 |
+
font_meme_text,
|
| 787 |
+
all_caps_meme_text,
|
| 788 |
+
text_at_the_top,
|
| 789 |
+
generation_args,
|
| 790 |
+
):
|
| 791 |
+
text = client.generate(prompt=query, **generation_args).generated_text
|
| 792 |
+
if image is not None and text != "":
|
| 793 |
+
meme_image = make_meme_image(
|
| 794 |
+
image=image,
|
| 795 |
+
text=text,
|
| 796 |
+
font_meme_text=font_meme_text,
|
| 797 |
+
all_caps_meme_text=all_caps_meme_text,
|
| 798 |
+
text_at_the_top=text_at_the_top,
|
| 799 |
+
)
|
| 800 |
+
meme_image = pil_to_temp_file(meme_image)
|
| 801 |
+
return meme_image
|
| 802 |
+
|
| 803 |
def model_inference(
|
| 804 |
model_selector,
|
| 805 |
system_prompt,
|
|
|
|
| 872 |
|
| 873 |
query = prompt_list_to_tgi_input(formated_prompt_list)
|
| 874 |
all_meme_images = []
|
| 875 |
+
with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
|
| 876 |
+
futures = [
|
| 877 |
+
executor.submit(
|
| 878 |
+
generate_meme,
|
| 879 |
+
i,
|
| 880 |
+
client,
|
| 881 |
+
query,
|
| 882 |
+
image,
|
| 883 |
+
font_meme_text,
|
| 884 |
+
all_caps_meme_text,
|
| 885 |
+
text_at_the_top,
|
| 886 |
+
generation_args,
|
| 887 |
)
|
| 888 |
+
for i in range(4)
|
| 889 |
+
]
|
| 890 |
+
|
| 891 |
+
for future in concurrent.futures.as_completed(futures):
|
| 892 |
+
meme_image = future.result()
|
| 893 |
+
if meme_image:
|
| 894 |
+
all_meme_images.append(meme_image)
|
| 895 |
+
return user_prompt_str, all_meme_images, chat_history
|
| 896 |
|
| 897 |
gr.on(
|
| 898 |
triggers=[textbox.submit, imagebox.upload, submit_btn.click],
|