Spaces:
Running
Running
import base64
import imp
import io
import json
import os
import re
import time

import gradio as gr
import requests
from PIL import Image, PngImagePlugin, ImageChops

import ui_functions as uifn
from css_and_js import js, call_JS
# Base URL of the FlagStudio HTTP API used by every request in this file.
url_host = "https://flagstudio.baai.ac.cn"

# API token sent in the "token" header of every backend request.
# SECURITY NOTE(review): a credential committed to source is readable by
# anyone who can see this file. Prefer supplying it through the
# FLAGSTUDIO_TOKEN environment variable; the embedded value is kept only as a
# backward-compatible fallback and should be rotated and removed.
token = os.environ.get(
    "FLAGSTUDIO_TOKEN",
    "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJ1c2VyX2lkIjoiMGY4M2QxMDg3N2MzMTFlZGFiYzYwZmU5ZGFjMTI1ZDMiLCJhcHBfbmFtZSI6IndlYiIsImlkZW50aXR5X3R5cGUiOiIyIiwidXNlcl9yb2xlIjoiMiIsImp0aSI6ImE3YTE1N2I3LTllNTItNDllMS04YzA0LWEzZmI5YjZiZjNlYSIsIm5iZiI6MTY3MDU5MTcwMSwiZXhwIjoxOTg1OTUxNzAxLCJpYXQiOjE2NzA1OTE3MDF9.OcfGayna-wr_5mo4LT6OJHSCokna8vqKSmmCftFUsx8",
)
def read_content(file_path: str) -> str:
    """Return the entire text of ``file_path``, decoded as UTF-8."""
    with open(file_path, encoding="utf-8") as handle:
        return handle.read()
def filter_content(raw_style: str):
    """Return the part of ``raw_style`` that precedes its first ``"("``.

    The whole string is returned unchanged when it contains no ``"("``.
    """
    # partition() yields the full string as the head when the separator is
    # absent, which matches the "no parenthesis" case.
    head, _sep, _tail = raw_style.partition("(")
    return head
def upload_image(img, timeout=30):
    """Upload an image to the backend as PNG and return its id and URL.

    First asks ``/api/v1/image/get-upload-link`` for an upload slot, then PUTs
    the PNG-encoded bytes to the URL (with the headers) that the service
    returned.

    Args:
        img: Image object exposing ``save(fp, format)`` (callers in this file
            pass PIL images).
        timeout: Seconds to wait on each HTTP request before giving up, so a
            stalled server cannot block the worker forever.

    Returns:
        Tuple ``(image_id, image_url)`` taken from the service reply;
        ``image_url`` is the same URL the bytes were uploaded to.

    Raises:
        gr.Error: On a transport failure, a non-200 HTTP status, or a
            non-zero ``code`` in the service reply.
    """
    headers = {"token": token}
    link_url = url_host + "/api/v1/image/get-upload-link"
    try:
        r = requests.post(link_url, json={}, headers=headers, timeout=timeout)
    except requests.RequestException as exc:
        raise gr.Error(f"Request failed: {exc}") from exc
    if r.status_code != 200:
        raise gr.Error(r.reason)

    head_res = r.json()
    if head_res["code"] != 0:
        raise gr.Error("Unknown error")

    upload_info = head_res["data"]
    image_id = upload_info["image_id"]
    image_url = upload_info["url"]
    image_headers = upload_info["headers"]

    # Serialize to PNG in memory; nothing is written to disk.
    buffer = io.BytesIO()
    img.save(buffer, "PNG")

    try:
        r = requests.put(
            image_url,
            data=buffer.getvalue(),
            headers=image_headers,
            timeout=timeout,
        )
    except requests.RequestException as exc:
        raise gr.Error(f"Request failed: {exc}") from exc
    if r.status_code != 200:
        raise gr.Error(r.reason)

    return image_id, image_url
def _api_post(url, payload, headers, timeout):
    """POST ``payload`` as JSON and return the decoded JSON reply."""
    try:
        reply = requests.post(url, json=payload, headers=headers, timeout=timeout)
    except requests.RequestException as exc:
        raise gr.Error(f"Request failed: {exc}") from exc
    if reply.status_code != 200:
        raise gr.Error(reply.reason)
    return reply.json()


def _download_rgb_image(url, timeout):
    """Download one generated image and decode it as an RGB PIL image."""
    try:
        reply = requests.get(url, timeout=timeout)
    except requests.RequestException as exc:
        raise gr.Error(f"Request failed: {exc}") from exc
    # Without this check a failed download would be handed to PIL and surface
    # as an obscure decoding error instead of the HTTP reason.
    if reply.status_code != 200:
        raise gr.Error(reply.reason)
    return Image.open(io.BytesIO(reply.content)).convert("RGB")


def post_reqest(seed, prompt, width, height, image_num, img=None, mask=None,
                timeout=30):
    """Create ``image_num`` generation tasks and wait for their images.

    Each task is created through ``/api/v1/task/create`` with the same
    parameters, then ``/api/v1/task/status`` is polled (about once per second)
    until every task has finished.

    Args:
        seed: Seed value; converted with ``int()`` before sending.
        prompt: Text prompt, sent as the single element of ``prompts``.
        width: Requested output image width.
        height: Requested output image height.
        image_num: Number of tasks to create.
        img: Optional init image; uploaded and referenced as ``init_image``.
        mask: Optional mask image; uploaded as ``mask_image`` only when it
            contains at least one non-zero pixel after conversion to "L".
        timeout: Seconds to wait on each HTTP request made by this function.

    Returns:
        List of RGB PIL images collected from all finished tasks.

    Raises:
        gr.Error: On transport failures, non-200 HTTP statuses, rejected
            prompts (code 3002), NSFW results (code 6005), or any other
            non-zero service code.
    """
    data = {
        "type": "gen-image",
        "parameters": {
            "width": width,  # output image width
            "height": height,  # output image height
            "prompts": [prompt],
            "seed": int(seed),
        },
    }

    if img is not None:
        # Upload the init image and reference it in the task parameters.
        image_id, image_url = upload_image(img)
        data["parameters"]["init_image"] = {
            "image_id": image_id,
            "url": image_url,
            "width": img.width,
            "height": img.height,
        }

    # NOTE(review): the listing this was restored from had no indentation, so
    # it is unclear whether the mask branch was nested under the image branch;
    # both callers in this file pass either both values or neither.
    if mask is not None:
        # getextrema() on an "L" image is (min, max); a max of 0 means the
        # mask is completely black, in which case it is not uploaded.
        extrema = mask.convert("L").getextrema()
        if extrema[1] > 0:
            mask_id, mask_url = upload_image(mask)
            data["parameters"]["mask_image"] = {
                "image_id": mask_id,
                "url": mask_url,
                "width": mask.width,
                "height": mask.height,
            }

    headers = {"token": token}

    # Send one create-task request per requested image.
    create_url = url_host + "/api/v1/task/create"
    all_task_data = []
    for _ in range(image_num):
        create_res = _api_post(create_url, data, headers, timeout)
        if create_res["code"] == 3002:
            raise gr.Error("Inappropriate prompt detected.")
        if create_res["code"] != 0:
            raise gr.Error("Unknown error")
        all_task_data.append(create_res["data"])

    # Poll until every task has been collected.
    status_url = url_host + "/api/v1/task/status"
    images = []
    while all_task_data:
        # Walk backwards so finished tasks can be deleted by index safely.
        for i in range(len(all_task_data) - 1, -1, -1):
            res = _api_post(status_url, all_task_data[i], headers, timeout)
            code = res["code"]
            if code == 6002:
                # Still running; check again on the next pass.
                continue
            if code == 6005:
                raise gr.Error("NSFW image detected.")
            if code != 0:
                raise gr.Error(f"Error code: {code}")
            # Finished: fetch every image produced by this task.
            for img_info in res["data"]["images"]:
                images.append(_download_rgb_image(img_info["url"], timeout))
            del all_task_data[i]
        # Only pause when something is still pending.
        if all_task_data:
            time.sleep(1)
    return images
def request_images(raw_text, class_draw, style_draw, batch_size, w, h, seed):
    """Extend the prompt with the selected type/style keywords and generate.

    For the "国画" type a fixed keyword suffix is appended and the style
    selection is ignored. For every other type, the type keyword (unless it is
    "通用") and each selected style keyword are appended, each preceded by a
    comma. Keywords are the labels with any "(...)" part removed.

    Returns the list of images produced by ``post_reqest``.
    """
    category = filter_content(class_draw)

    if category == "国画":
        raw_text = raw_text + ",国画,水墨画,大作,黑白,高清,传统"
    else:
        keywords = []
        if category != "通用":
            keywords.append(category)
        keywords.extend(filter_content(sty) for sty in style_draw)
        for keyword in keywords:
            raw_text = raw_text + f",{keyword}"

    print(f"raw text is {raw_text}")
    return post_reqest(seed, raw_text, w, h, int(batch_size))
def img2img(prompt, image_and_mask):
    """Run image-to-image generation for an uploaded image and its mask.

    The output size keeps the input aspect ratio with the shorter side set to
    512 pixels. A single image is requested with seed 0.

    Args:
        prompt: Text prompt describing the desired result.
        image_and_mask: Mapping with an ``"image"`` entry and an optional
            ``"mask"`` entry (the value produced by the sketch image input),
            or ``None`` when nothing has been uploaded.

    Returns:
        The list of images returned by ``post_reqest``.

    Raises:
        gr.Error: If no image has been uploaded yet.
    """
    image = image_and_mask.get("image") if image_and_mask else None
    if image is None:
        # The image input is empty until the user uploads something; report
        # that clearly instead of failing on attribute access.
        raise gr.Error("Please upload an image first.")
    mask = image_and_mask.get("mask")

    if image.width <= image.height:
        width = 512
        height = int((width / image.width) * image.height)
    else:
        height = 512
        width = int((height / image.height) * image.width)

    return post_reqest(0, prompt, width, height, 1, image, mask)
# Sample prompts (Chinese, English and mixed) offered as clickable examples
# under the text-to-image prompt box.
examples = [
    "水墨蝴蝶和牡丹花,国画",
    "苍劲有力的墨竹,国画",
    "暴风雨中的灯塔",
    "机械小松鼠,科学幻想",
    "中国水墨山水画,国画",
    "Lighthouse in the storm",
    "A dog",
    "Landscape by 张大千",
    "A tiger 长了兔子耳朵",
    "A baby bird 铅笔素描",
]
if __name__ == "__main__":
    # Script entry point: build the Gradio Blocks UI (a text-to-image tab and
    # an image-to-image tab), wire the event handlers, then launch the app
    # with request queueing enabled.
    # NOTE(review): the listing this was restored from had no indentation, so
    # the nesting of the layout containers below is a reconstruction; confirm
    # it against the deployed app.
    block = gr.Blocks(css=read_content('style.css'))
    with block:
        # Page header loaded from a static HTML file.
        gr.HTML(read_content("header.html"))
        with gr.Tabs(elem_id='tabss') as tabs:
            # ---- Tab 1: text-to-image ----
            with gr.TabItem("文生图(Text-to-img)", id='txt2img_tab'):
                with gr.Group():
                    with gr.Box():
                        # Prompt box and generate button.
                        with gr.Row().style(mobile_collapse=False, equal_height=True):
                            text = gr.Textbox(
                                label="Prompt",
                                show_label=False,
                                max_lines=1,
                                placeholder="Input text(输入文字)",
                                interactive=True,
                            ).style(
                                border=(True, False, True, True),
                                rounded=(True, False, False, True),
                                container=False,
                            )
                            btn = gr.Button("Generate image").style(
                                margin=False,
                                rounded=(True, True, True, True),
                            )
                        # Generation type; request_images() strips the "(...)"
                        # part of the label before using it.
                        with gr.Row().style(mobile_collapse=False, equal_height=True):
                            class_draw = gr.Radio(choices=["通用(general)","国画(traditional Chinese painting)",], value="通用(general)", show_label=True, label='生成类型(type)')
                            # Earlier dropdown variant, kept disabled for reference:
                            # class_draw = gr.Dropdown(["通用(general)", "国画(traditional Chinese painting)",
                            # "照片,摄影(picture photography)", "油画(oil painting)",
                            # "铅笔素描(pencil sketch)", "CG",
                            # "水彩画(watercolor painting)", "水墨画(ink and wash)",
                            # "插画(illustrations)", "3D", "图生图(img2img)"],
                            # label="生成类型(type)",
                            # show_label=True,
                            # value="通用(general)")
                        # Optional style keywords appended to the prompt.
                        with gr.Row().style(mobile_collapse=False, equal_height=True):
                            style_draw = gr.CheckboxGroup(["蒸汽朋克(steampunk)", "电影摄影风格(film photography)",
                                                           "概念艺术(concept art)", "Warming lighting",
                                                           "Dramatic lighting", "Natural lighting",
                                                           "虚幻引擎(unreal engine)", "4k", "8k",
                                                           "充满细节(full details)"],
                                                          label="画面风格(style)",
                                                          show_label=True,
                                                          )
                        # Number of images and seed.
                        with gr.Row().style(mobile_collapse=False, equal_height=True):
                            # Earlier slider variant, kept disabled for reference:
                            # sample_size = gr.Slider(minimum=1,
                            # maximum=4,
                            # step=1,
                            # label="生成数量(number)",
                            # show_label=True,
                            # interactive=True,
                            # )
                            # Radio values are strings; request_images() converts with int().
                            sample_size = gr.Radio(choices=["1","2","3","4"], value="1", show_label=True, label='生成数量(number)')
                            seed = gr.Number(0, label='seed', interactive=True)
                        # Output size, 512-1024 in steps of 64.
                        with gr.Row().style(mobile_collapse=False, equal_height=True):
                            w = gr.Slider(512,1024,value=512, step=64, label="width")
                            h = gr.Slider(512,1024,value=512, step=64, label="height")
                    gallery = gr.Gallery(
                        label="Generated images", show_label=False, elem_id="gallery"
                    ).style(grid=[2,2])
                    # NOTE(review): request_images takes seven arguments but only
                    # `text` is listed as input here; this looks safe only while the
                    # examples merely fill the prompt box and never call `fn`
                    # (e.g. example caching stays off) — confirm.
                    gr.Examples(examples=examples, fn=request_images, inputs=text, outputs=gallery, examples_per_page=100)
                    # Controls for sending a generated image to the img2img tab.
                    with gr.Row().style(mobile_collapse=False, equal_height=True):
                        img_choices = gr.Dropdown(["图片1(img1)"],label='请选择一张图片发送到图生图',show_label=True,value="图片1(img1)")
                    with gr.Row().style(mobile_collapse=False, equal_height=True):
                        output_txt2img_copy_to_input_btn = gr.Button("发送图片到图生图(Sent the image to img2img)").style(
                            margin=False,
                            rounded=(True, True, True, True),
                        )
                    # Hint texts created hidden; they are outputs of the copy
                    # handler wired further below.
                    with gr.Row():
                        prompt = gr.Markdown("提示(Prompt):", visible=False)
                    with gr.Row():
                        move_prompt_zh = gr.Markdown("请移至图生图部分进行编辑(拉到顶部)", visible=False)
                    with gr.Row():
                        move_prompt_en = gr.Markdown("Please move to the img2img section for editing(Pull to the top)", visible=False)
                # Pressing Enter in the prompt box and clicking the button both
                # run the same generation call.
                text.submit(request_images, inputs=[text, class_draw, style_draw, sample_size, w, h, seed], outputs=gallery)
                btn.click(request_images, inputs=[text, class_draw, style_draw, sample_size, w, h, seed], outputs=gallery)
                # Presumably keeps the image-choice dropdown in sync with the
                # number of generated images; see ui_functions.change_img_choices.
                sample_size.change(
                    fn=uifn.change_img_choices,
                    inputs=[sample_size],
                    outputs=[img_choices]
                )
            # ---- Tab 2: image-to-image ----
            with gr.TabItem("图生图(Img-to-Img)", id="img2img_tab"):
                with gr.Row(elem_id="prompt_row"):
                    img2img_prompt = gr.Textbox(label="Prompt",
                                                elem_id='img2img_prompt_input',
                                                placeholder="神奇的森林,流淌的河流.",
                                                lines=1,
                                                max_lines=1,
                                                value="",
                                                show_label=False).style()
                    # Two generate buttons bound to the same handler below; the
                    # mask button is created hidden.
                    img2img_btn_mask = gr.Button("Generate", variant="primary", visible=False,
                                                 elem_id="img2img_mask_btn")
                    img2img_btn_editor = gr.Button("Generate", variant="primary", elem_id="img2img_edit_btn")
                gr.Markdown('#### 输入图像')
                with gr.Row().style(equal_height=False):
                    #with gr.Column():
                    # Upload widget with the sketch tool; img2img() reads the
                    # "image" and "mask" entries of its value.
                    img2img_image_mask = gr.Image(
                        value=None,
                        source="upload",
                        interactive=True,
                        tool="sketch",
                        type='pil',
                        elem_id="img2img_mask",
                        image_mode="RGBA"
                    )
                gr.Markdown('#### 编辑后的图片')
                with gr.Row():
                    output_img2img_gallery = gr.Gallery(label="Images", elem_id="img2img_gallery_output").style(
                        grid=[4,4,4] )
                # Static usage hints.
                with gr.Row():
                    gr.Markdown('提示(prompt):')
                with gr.Row():
                    gr.Markdown('请选择一张图像掩盖掉一部分区域,并输入文本描述')
                with gr.Row():
                    gr.Markdown('Please select an image to cover up a part of the area and enter a text description.')
                gr.Markdown('# 编辑设置',visible=False)
                # Copy the chosen gallery image into the img2img input; the
                # handler's outputs are the tabs, the image input and the three
                # hint texts.
                output_txt2img_copy_to_input_btn.click(
                    uifn.copy_img_to_input,
                    [gallery, img_choices],
                    [tabs, img2img_image_mask, move_prompt_zh, move_prompt_en, prompt]
                )
                # Shared handler/inputs/outputs for both generate buttons.
                img2img_func = img2img
                img2img_inputs = [img2img_prompt, img2img_image_mask]
                img2img_outputs = [output_img2img_gallery]
                img2img_btn_mask.click(
                    img2img_func,
                    img2img_inputs,
                    img2img_outputs
                )
                def img2img_submit_params():
                    # Bundle the click() arguments so they can be star-unpacked.
                    return (img2img_func,
                            img2img_inputs,
                            img2img_outputs)
                img2img_btn_editor.click(*img2img_submit_params())
                # GENERATE ON ENTER
                # No Python handler: submitting the prompt only runs the JS
                # snippet built by call_JS (presumably clicks the first visible
                # button in the "prompt_row" row).
                img2img_prompt.submit(None, None, None,
                                      _js=call_JS("clickFirstVisibleButton",
                                                  rowId="prompt_row"))
        # Page footer loaded from a static HTML file.
        gr.HTML(read_content("footer.html"))
        # gr.Image('./contributors.png')
    # Enable the request queue with the configured limits, then start serving.
    block.queue(max_size=512, concurrency_count=256).launch()