import base64
import io
import re
import time
from datetime import date
from pathlib import Path

import gradio as gr
import requests
import torch
from PIL import Image

import modules.shared as shared
from modules.models import reload_model, unload_model
from modules.ui import create_refresh_button

torch._C._jit_set_profiling_mode(False)

# parameters which can be customized in settings.json of webui
params = {
    'address': 'http://127.0.0.1:7860',
    'mode': 0,  # modes of operation: 0 (Manual only), 1 (Immersive/Interactive - looks for words to trigger), 2 (Picturebook Adventure - Always on)
    'manage_VRAM': False,
    'save_img': False,
    'SD_model': 'NeverEndingDream',  # not used right now
    'prompt_prefix': '(Masterpiece:1.1), detailed, intricate, colorful',
    'negative_prompt': '(worst quality, low quality:1.3)',
    'width': 512,
    'height': 512,
    'denoising_strength': 0.61,
    'restore_faces': False,
    'enable_hr': False,
    'hr_upscaler': 'ESRGAN_4x',
    'hr_scale': '1.0',
    'seed': -1,
    'sampler_name': 'DPM++ 2M Karras',
    'steps': 32,
    'cfg_scale': 7,
    'textgen_prefix': 'Please provide a detailed and vivid description of [subject]',
    'sd_checkpoint': ' ',
    'checkpoint_list': [" "]
}

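# Hands VRAM priority to one side or the other: 'SD' unloads the local LLM and has
# Auto1111 reload its checkpoint, 'LLM' does the reverse; 'set'/'reset' are used when
# the VRAM management option itself is toggled on/off.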
def give_VRAM_priority(actor):
    global shared, params

    if actor == 'SD':
        unload_model()
        print("Requesting Auto1111 to re-load last checkpoint used...")
        response = requests.post(url=f'{params["address"]}/sdapi/v1/reload-checkpoint', json='')
        response.raise_for_status()

    elif actor == 'LLM':
        print("Requesting Auto1111 to vacate VRAM...")
        response = requests.post(url=f'{params["address"]}/sdapi/v1/unload-checkpoint', json='')
        response.raise_for_status()
        reload_model()

    elif actor == 'set':
        print("VRAM management activated -- requesting Auto1111 to vacate VRAM...")
        response = requests.post(url=f'{params["address"]}/sdapi/v1/unload-checkpoint', json='')
        response.raise_for_status()

    elif actor == 'reset':
        print("VRAM management deactivated -- requesting Auto1111 to reload checkpoint")
        response = requests.post(url=f'{params["address"]}/sdapi/v1/reload-checkpoint', json='')
        response.raise_for_status()

    else:
        raise RuntimeError(f'Managing VRAM: "{actor}" is not a known state!')

    response.raise_for_status()
    del response


if params['manage_VRAM']:
    give_VRAM_priority('set')

SD_models = ['NeverEndingDream']  # TODO: get with http://{address}/sdapi/v1/sd-models and allow user to select

picture_response = False  # specifies if the next model response should appear as a picture

def remove_surrounded_chars(string):
    # this expression matches to 'as few symbols as possible (0 upwards) between any asterisks' OR
    # 'as few symbols as possible (0 upwards) between an asterisk and the end of the string'
    return re.sub(r'\*[^\*]*?(\*|$)', '', string)


def triggers_are_in(string):
    string = remove_surrounded_chars(string)
    # regex searches for send|mail|message|me (at the end of the word) followed by
    # a whole word of image|pic|picture|photo|snap|snapshot|selfie|meme(s),
    # (?aims) are regex parser flags
    return bool(re.search(r'(?aims)(send|mail|message|me)\b.+?\b(image|pic(ture)?|photo|snap(shot)?|selfie|meme)s?\b', string))

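# Turns off streaming for the upcoming reply whenever a picture response has been
# triggered, so the complete text is available to build the image prompt.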
def state_modifier(state):
    if picture_response:
        state['stream'] = False

    return state

def input_modifier(string):
    """
    This function is applied to your text inputs before
    they are fed into the model.
    """

    global params

    if not params['mode'] == 1:  # if not in immersive/interactive mode, do nothing
        return string

    if triggers_are_in(string):  # if we're in it, check for trigger words
        toggle_generation(True)
        string = string.lower()
        if "of" in string:
            subject = string.split('of', 1)[1]  # subdivide the string once by the first 'of' instance and get what's coming after it
            string = params['textgen_prefix'].replace("[subject]", subject)
        else:
            string = params['textgen_prefix'].replace("[subject]", "your appearance, your surroundings and what you are doing right now")

    return string

# Get and save the Stable Diffusion-generated picture
def get_SD_pictures(description):

    global params

    if params['manage_VRAM']:
        give_VRAM_priority('SD')

    payload = {
        "prompt": params['prompt_prefix'] + description,
        "seed": params['seed'],
        "sampler_name": params['sampler_name'],
        "enable_hr": params['enable_hr'],
        "hr_scale": params['hr_scale'],
        "hr_upscaler": params['hr_upscaler'],
        "denoising_strength": params['denoising_strength'],
        "steps": params['steps'],
        "cfg_scale": params['cfg_scale'],
        "width": params['width'],
        "height": params['height'],
        "restore_faces": params['restore_faces'],
        "override_settings_restore_afterwards": True,
        "negative_prompt": params['negative_prompt']
    }

    print(f'Prompting the image generator via the API on {params["address"]}...')
    response = requests.post(url=f'{params["address"]}/sdapi/v1/txt2img', json=payload)
    response.raise_for_status()
    r = response.json()

    visible_result = ""
    for img_str in r['images']:
        if params['save_img']:
            img_data = base64.b64decode(img_str)
            variadic = f'{date.today().strftime("%Y_%m_%d")}/{shared.character}_{int(time.time())}'
            output_file = Path(f'extensions/sd_api_pictures/outputs/{variadic}.png')
            output_file.parent.mkdir(parents=True, exist_ok=True)
            with open(output_file.as_posix(), 'wb') as f:
                f.write(img_data)

            visible_result = visible_result + f'<img src="/file/extensions/sd_api_pictures/outputs/{variadic}.png" alt="{description}" style="max-width: unset; max-height: unset;">\n'
        else:
            image = Image.open(io.BytesIO(base64.b64decode(img_str.split(",", 1)[0])))
            # lower the resolution of received images for the chat, otherwise the log size gets out of control quickly with all the base64 values in visible history
            image.thumbnail((300, 300))
            buffered = io.BytesIO()
            image.save(buffered, format="JPEG")
            buffered.seek(0)
            image_bytes = buffered.getvalue()
            img_str = "data:image/jpeg;base64," + base64.b64encode(image_bytes).decode()
            visible_result = visible_result + f'<img src="{img_str}" alt="{description}">\n'

    if params['manage_VRAM']:
        give_VRAM_priority('LLM')

    return visible_result

# TODO: how do I make the UI history ignore the resulting pictures (I don't want HTML to appear in history)
# and replace it with 'text' for the purposes of logging?
def output_modifier(string):
    """
    This function is applied to the model outputs.
    """

    global picture_response, params

    if not picture_response:
        return string

    string = remove_surrounded_chars(string)
    string = string.replace('"', '')
    string = string.replace('“', '')
    string = string.replace('\n', ' ')
    string = string.strip()

    if string == '':
        string = 'no viable description in reply, try regenerating'
        return string

    text = ""
    if params['mode'] < 2:
        toggle_generation(False)
        text = f'*Sends a picture which portrays: “{string}”*'
    else:
        text = string

    string = get_SD_pictures(string) + "\n" + text

    return string

def bot_prefix_modifier(string):
    """
    This function is only applied in chat mode. It modifies
    the prefix text for the Bot and can be used to bias its
    behavior.
    """

    return string

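# Sets (or, when called with no argument, flips) the global picture_response flag
# and updates the chat "typing" indicator accordingly.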
def toggle_generation(*args):
    global picture_response, shared

    if not args:
        picture_response = not picture_response
    else:
        picture_response = args[0]

    shared.processing_message = "*Is sending a picture...*" if picture_response else "*Is typing...*"

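# Normalizes a user-supplied API address: strips whitespace and a trailing slash,
# and prepends http:// when no scheme is given.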
def filter_address(address):
    address = address.strip()
    # address = re.sub('http(s)?:\/\/|\/$','',address) # remove starting http:// OR https:// OR trailing slash
    address = re.sub(r'/$', '', address)  # remove trailing /s
    if not address.startswith('http'):
        address = 'http://' + address
    return address

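# Called when the address textbox is submitted: stores the normalized address and
# probes /sdapi/v1/sd-models to report whether an SD API responds there.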
def SD_api_address_update(address):
    global params

    msg = "✔️ SD API is found on:"
    address = filter_address(address)
    params.update({"address": address})
    try:
        response = requests.get(url=f'{params["address"]}/sdapi/v1/sd-models')
        response.raise_for_status()
        # r = response.json()
    except Exception:
        msg = "❌ No SD API endpoint on:"

    return gr.Textbox.update(label=msg)

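# Provides the extension's style.css to the main UI.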
def custom_css():
    path_to_css = Path(__file__).parent.resolve() / 'style.css'
    return path_to_css.read_text()

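# Queries Auto1111 for the available checkpoints and the one currently loaded,
# falling back to empty values if the API is unreachable.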
def get_checkpoints():
    global params

    try:
        models = requests.get(url=f'{params["address"]}/sdapi/v1/sd-models')
        options = requests.get(url=f'{params["address"]}/sdapi/v1/options')
        options_json = options.json()
        params['sd_checkpoint'] = options_json['sd_model_checkpoint']
        params['checkpoint_list'] = [result["title"] for result in models.json()]
    except Exception:
        params['sd_checkpoint'] = ""
        params['checkpoint_list'] = []

    return gr.update(choices=params['checkpoint_list'], value=params['sd_checkpoint'])

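# Asks Auto1111 to switch to the selected checkpoint via the options endpoint.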
def load_checkpoint(checkpoint):
    payload = {
        "sd_model_checkpoint": checkpoint
    }

    try:
        requests.post(url=f'{params["address"]}/sdapi/v1/options', json=payload)
    except Exception:
        pass

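# Returns the sampler names exposed by the Auto1111 API (empty list on failure).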
def get_samplers():
    try:
        response = requests.get(url=f'{params["address"]}/sdapi/v1/samplers')
        response.raise_for_status()
        samplers = [x["name"] for x in response.json()]
    except Exception:
        samplers = []

    return samplers

def ui():

    # Gradio elements
    # gr.Markdown('### Stable Diffusion API Pictures') # Currently the name of extension is shown as the title
    with gr.Accordion("Parameters", open=True, elem_classes="SDAP"):
        with gr.Row():
            address = gr.Textbox(placeholder=params['address'], value=params['address'], label='Auto1111\'s WebUI address')
            modes_list = ["Manual", "Immersive/Interactive", "Picturebook/Adventure"]
            mode = gr.Dropdown(modes_list, value=modes_list[params['mode']], label="Mode of operation", type="index")
            with gr.Column(scale=1, min_width=300):
                manage_VRAM = gr.Checkbox(value=params['manage_VRAM'], label='Manage VRAM')
                save_img = gr.Checkbox(value=params['save_img'], label='Keep original images and use them in chat')

            force_pic = gr.Button("Force the picture response")
            suppr_pic = gr.Button("Suppress the picture response")

        with gr.Row():
            checkpoint = gr.Dropdown(params['checkpoint_list'], value=params['sd_checkpoint'], label="Checkpoint", type="value")
            update_checkpoints = gr.Button("Get list of checkpoints")

        with gr.Accordion("Generation parameters", open=False):
            prompt_prefix = gr.Textbox(placeholder=params['prompt_prefix'], value=params['prompt_prefix'], label='Prompt Prefix (best used to describe the look of the character)')
            textgen_prefix = gr.Textbox(placeholder=params['textgen_prefix'], value=params['textgen_prefix'], label='textgen prefix (type [subject] where the subject should be placed)')
            negative_prompt = gr.Textbox(placeholder=params['negative_prompt'], value=params['negative_prompt'], label='Negative Prompt')
            with gr.Row():
                with gr.Column():
                    width = gr.Slider(256, 768, value=params['width'], step=64, label='Width')
                    height = gr.Slider(256, 768, value=params['height'], step=64, label='Height')
                with gr.Column(variant="compact", elem_id="sampler_col"):
                    with gr.Row(elem_id="sampler_row"):
                        sampler_name = gr.Dropdown(value=params['sampler_name'], label='Sampling method', elem_id="sampler_box")
                        create_refresh_button(sampler_name, lambda: None, lambda: {'choices': get_samplers()}, 'refresh-button')
                    steps = gr.Slider(1, 150, value=params['steps'], step=1, label="Sampling steps", elem_id="steps_box")
                    with gr.Row():
                        seed = gr.Number(label="Seed", value=params['seed'], elem_id="seed_box")
                        cfg_scale = gr.Number(label="CFG Scale", value=params['cfg_scale'], elem_id="cfg_box")
            with gr.Column() as hr_options:
                restore_faces = gr.Checkbox(value=params['restore_faces'], label='Restore faces')
                enable_hr = gr.Checkbox(value=params['enable_hr'], label='Hires. fix')
                with gr.Row(visible=params['enable_hr'], elem_classes="hires_opts") as hr_options:
                    hr_scale = gr.Slider(1, 4, value=params['hr_scale'], step=0.1, label='Upscale by')
                    denoising_strength = gr.Slider(0, 1, value=params['denoising_strength'], step=0.01, label='Denoising strength')
                    hr_upscaler = gr.Textbox(placeholder=params['hr_upscaler'], value=params['hr_upscaler'], label='Upscaler')

    # Event functions to update the parameters in the backend
    address.change(lambda x: params.update({"address": filter_address(x)}), address, None)
    mode.select(lambda x: params.update({"mode": x}), mode, None)
    mode.select(lambda x: toggle_generation(x > 1), inputs=mode, outputs=None)
    manage_VRAM.change(lambda x: params.update({"manage_VRAM": x}), manage_VRAM, None)
    manage_VRAM.change(lambda x: give_VRAM_priority('set' if x else 'reset'), inputs=manage_VRAM, outputs=None)
    save_img.change(lambda x: params.update({"save_img": x}), save_img, None)

    address.submit(fn=SD_api_address_update, inputs=address, outputs=address)
    prompt_prefix.change(lambda x: params.update({"prompt_prefix": x}), prompt_prefix, None)
    textgen_prefix.change(lambda x: params.update({"textgen_prefix": x}), textgen_prefix, None)
    negative_prompt.change(lambda x: params.update({"negative_prompt": x}), negative_prompt, None)
    width.change(lambda x: params.update({"width": x}), width, None)
    height.change(lambda x: params.update({"height": x}), height, None)
    hr_scale.change(lambda x: params.update({"hr_scale": x}), hr_scale, None)
    denoising_strength.change(lambda x: params.update({"denoising_strength": x}), denoising_strength, None)
    restore_faces.change(lambda x: params.update({"restore_faces": x}), restore_faces, None)
    hr_upscaler.change(lambda x: params.update({"hr_upscaler": x}), hr_upscaler, None)
    enable_hr.change(lambda x: params.update({"enable_hr": x}), enable_hr, None)
    enable_hr.change(lambda x: hr_options.update(visible=params["enable_hr"]), enable_hr, hr_options)
    update_checkpoints.click(get_checkpoints, None, checkpoint)
    checkpoint.change(lambda x: params.update({"sd_checkpoint": x}), checkpoint, None)
    checkpoint.change(load_checkpoint, checkpoint, None)

    sampler_name.change(lambda x: params.update({"sampler_name": x}), sampler_name, None)
    steps.change(lambda x: params.update({"steps": x}), steps, None)
    seed.change(lambda x: params.update({"seed": x}), seed, None)
    cfg_scale.change(lambda x: params.update({"cfg_scale": x}), cfg_scale, None)

    force_pic.click(lambda x: toggle_generation(True), inputs=force_pic, outputs=None)
    suppr_pic.click(lambda x: toggle_generation(False), inputs=suppr_pic, outputs=None)