# Hugging Face Spaces app (page status: Running)
| import gradio as gr | |
| import io | |
| from PIL import Image | |
| import base64 | |
| import requests | |
| import json | |
| from PIL import Image | |
def read_content(file_path: str) -> str:
    """Return the entire text of ``file_path``, decoded as UTF-8.

    Args:
        file_path: Path of the text file to load (e.g. a CSS or HTML fragment).

    Returns:
        The file's contents as a single string.
    """
    with open(file_path, encoding="utf-8") as handle:
        return handle.read()
def base2picture(resbase64):
    """Decode a base64-encoded image string into a PIL image.

    Accepts either a data URI such as ``"data:image/png;base64,<payload>"`` or a
    bare base64 payload with no header.

    Args:
        resbase64: Base64 image text, optionally prefixed by a data-URI header.

    Returns:
        The ``PIL.Image.Image`` opened from the decoded bytes.
    """
    # rpartition yields ("", "", s) when there is no comma, so a bare payload
    # passes through unchanged; split(',')[1] raised IndexError in that case.
    # Base64 never contains commas, so for data URIs this picks the same payload.
    _, _, payload = resbase64.rpartition(",")
    return Image.open(io.BytesIO(base64.b64decode(payload)))
def filter_content(raw_style: str):
    """Reduce a UI option label to the text before its first "(".

    Labels such as "国画(traditional Chinese painting)" become "国画".
    A label that contains no "(" is returned unchanged.

    Args:
        raw_style: The option label to shorten.

    Returns:
        The prefix of ``raw_style`` preceding the first "(".
    """
    head, _sep, _rest = raw_style.partition("(")
    return head
def request_images(raw_text, class_draw, style_draw, batch_size, timeout=300):
    """Generate images for a prompt through the FlagArt HTTP API.

    The prompt is extended according to the UI selections:

    * type "国画": ",国画" is appended (unless the prompt already ends with
      "国画") and the guohua endpoint is used; selected styles are not sent.
    * any other type: the type name (except "通用") and every selected style are
      appended as comma-separated suffixes and the general endpoint is used.

    Option labels are shortened with ``filter_content`` (text before the first "(").

    Args:
        raw_text: Prompt text entered by the user.
        class_draw: Label selected in the "type" dropdown.
        style_draw: Selected style labels; ``None`` is treated as no styles.
        batch_size: Number of images to request. Coerced with ``int()`` so a
            float slider value such as ``2.0`` still works.
        timeout: Seconds to wait for the API before ``requests`` gives up.

    Returns:
        A list of decoded PIL images, at most ``batch_size`` long (fewer if the
        API returned fewer).

    Raises:
        requests.HTTPError: The API answered with a 4xx/5xx status.
        requests.Timeout: No response arrived within ``timeout`` seconds.
    """
    category = filter_content(class_draw)
    batch_size = int(batch_size)

    if category == "国画":
        if not raw_text.endswith("国画"):
            raw_text = raw_text + ",国画"
        url = "http://flagart.baai.ac.cn/api/guohua/"
    else:
        parts = [raw_text]
        if category != "通用":
            parts.append(category)
        parts.extend(filter_content(sty) for sty in (style_draw or ()))
        raw_text = ",".join(parts)
        print(f"raw text is {raw_text}")
        url = "http://flagart.baai.ac.cn/api/general/"

    # No hand-written headers: json= already sets Content-Type, and requests
    # picks an Accept-Encoding it can actually decode. The previous hard-coded
    # "gzip, deflate, br" advertised Brotli even without a Brotli decoder installed.
    response = requests.post(url, json={"data": [raw_text, batch_size]}, timeout=timeout)
    # Surface HTTP failures directly rather than as a JSON/KeyError below.
    response.raise_for_status()
    content = response.json()["data"][0]
    # Slice instead of indexing 0..batch_size-1 so a short reply cannot raise IndexError.
    return [base2picture(item) for item in content[:batch_size]]
# Sample prompts handed to gr.Examples (with the prompt textbox as its input) in the UI.
examples = [
    # Chinese prompts
    "水墨蝴蝶和牡丹花,国画",
    "苍劲有力的墨竹,国画",
    "暴风雨中的灯塔",
    "机械小松鼠,科学幻想",
    "中国水墨山水画,国画",
    # English and mixed-language prompts
    "Lighthouse in the storm",
    "A dog",
    "Landscape by 张大千",
    "A tiger 长了兔子耳朵",
    "A baby bird 铅笔素描",
]
if __name__ == "__main__":
    # Build the Gradio Blocks UI. style.css, header.html and footer.html are read
    # with relative paths, i.e. from the current working directory.
    # NOTE(review): confirm the gallery/Examples/event wiring belong inside
    # gr.Group but outside gr.Box — the rendered layout depends on this nesting.
    block = gr.Blocks(css=read_content('style.css'))
    with block:
        gr.HTML(read_content("header.html"))
        with gr.Group():
            with gr.Box():
                # Prompt textbox and "Generate image" button share one row.
                with gr.Row().style(mobile_collapse=False, equal_height=True):
                    text = gr.Textbox(
                        label="Prompt",
                        show_label=False,
                        max_lines=1,
                        placeholder="Input text(输入文字)",
                        interactive=True,
                    ).style(
                        border=(True, False, True, True),
                        rounded=(True, False, False, True),
                        container=False,
                    )
                    btn = gr.Button("Generate image").style(
                        margin=False,
                        rounded=(True, True, True, True),
                    )
                # Generation type. request_images keeps only the text before "(";
                # "国画" goes to the guohua endpoint, every other type to the general one.
                # NOTE(review): "图生图(img2img)" has no image input here — request_images
                # just appends "图生图" to the prompt text; confirm this is intended.
                with gr.Row().style(mobile_collapse=False, equal_height=True):
                    class_draw = gr.Dropdown(["通用(general)", "国画(traditional Chinese painting)",
                                              "照片,摄影(picture photography)", "油画(oil painting)",
                                              "铅笔素描(pencil sketch)", "CG",
                                              "水彩画(watercolor painting)", "水墨画(ink and wash)",
                                              "插画(illustrations)", "3D", "图生图(img2img)"],
                                             label="生成类型(type)",
                                             show_label=True,
                                             value="通用(general)")
                # Optional style modifiers; request_images appends them to the prompt
                # only for non-"国画" types.
                with gr.Row().style(mobile_collapse=False, equal_height=True):
                    style_draw = gr.CheckboxGroup(["蒸汽朋克(steampunk)", "电影摄影风格(film photography)",
                                                   "概念艺术(concept art)", "Warming lighting",
                                                   "Dramatic lighting", "Natural lighting",
                                                   "虚幻引擎(unreal engine)", "4k", "8k",
                                                   "充满细节(full details)"],
                                                  label="画面风格(style)",
                                                  show_label=True,
                                                  )
                # Number of images to request per generation (1 to 4).
                with gr.Row().style(mobile_collapse=False, equal_height=True):
                    sample_size = gr.Slider(minimum=1,
                                            maximum=4,
                                            step=1,
                                            label="生成数量(number)",
                                            show_label=True,
                                            interactive=True,
                                            )
            gallery = gr.Gallery(
                label="Generated images", show_label=False, elem_id="gallery"
            ).style(grid=[2], height="auto")
            # NOTE(review): inputs=text supplies only one of request_images' four
            # required arguments; presumably fn/outputs are only invoked if example
            # caching (or run-on-click) is enabled — confirm before enabling it.
            gr.Examples(examples=examples, fn=request_images, inputs=text, outputs=gallery, examples_per_page=100)
            # Pressing Enter in the textbox and clicking the button run the same generation.
            text.submit(request_images, inputs=[text, class_draw, style_draw, sample_size], outputs=gallery)
            btn.click(request_images, inputs=[text, class_draw, style_draw, sample_size], outputs=gallery)
        gr.HTML(read_content("footer.html"))
        # Disabled: contributors image.
        # gr.Image('./contributors.png')
    # Queue requests (max_size=50) with concurrency_count=20 workers; concurrency_count
    # is a gradio 3.x queue parameter — confirm against the pinned gradio version.
    block.queue(max_size=50, concurrency_count=20).launch()