import gradio as gr
from examples.story_examples import get_examples
import spaces
import numpy as np
import torch
import random
import os
import torch.nn.functional as F
from diffusers import StableDiffusionXLPipeline, DDIMScheduler
import copy
from huggingface_hub import hf_hub_download
from diffusers.utils import load_image
from storyDiffusion.utils.gradio_utils import AttnProcessor2_0 as AttnProcessor, cal_attn_mask_xl
from storyDiffusion.utils import PhotoMakerStableDiffusionXLPipeline
from storyDiffusion.utils.utils import get_comic
from storyDiffusion.utils.style_template import styles

# Constants
image_encoder_path = "./data/models/ip_adapter/sdxl_models/image_encoder"
ip_ckpt = "./data/models/ip_adapter/sdxl_models/ip-adapter_sdxl_vit-h.bin"
os.environ["no_proxy"] = "localhost,127.0.0.1,::1"
STYLE_NAMES = list(styles.keys())
DEFAULT_STYLE_NAME = "Japanese Anime"
MAX_SEED = np.iinfo(np.int32).max
# Global variables
global models_dict, use_va, photomaker_path, pipe2, pipe4, attn_count, total_count, id_length, total_length, cur_step, cur_model_type, write, sa32, sa64, height, width, attn_procs, unet, num_steps
models_dict = {
    "RealVision": "SG161222/RealVisXL_V4.0",
    "Unstable": "stablediffusionapi/sdxl-unstable-diffusers-y",
}
use_va = True
photomaker_path = hf_hub_download(
    repo_id="TencentARC/PhotoMaker", filename="photomaker-v1.bin", repo_type="model")
device = "cuda"
# Functions
def setup_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True


def set_text_unfinished():
    return gr.update(visible=True, value="<h3>(Not Finished) Generating ··· The intermediate results will be shown.</h3>")


def set_text_finished():
    return gr.update(visible=True, value="<h3>Generation Finished</h3>")
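
# Paired (consistent) self-attention processor. While `write` is True it caches the
# hidden states of the identity frames in `id_bank`; on later frames it concatenates
# those cached states into the attention context so every new frame attends to the
# identity frames, which keeps characters consistent across the story.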
class SpatialAttnProcessor2_0(torch.nn.Module):
    r"""
    Attention processor for IP-Adapter for PyTorch 2.0.

    Args:
        hidden_size (`int`):
            The hidden size of the attention layer.
        cross_attention_dim (`int`):
            The number of channels in the `encoder_hidden_states`.
        text_context_len (`int`, defaults to 77):
            The context length of the text features.
        scale (`float`, defaults to 1.0):
            The weight scale of the image prompt.
    """

    def __init__(self, hidden_size=None, cross_attention_dim=None, id_length=4, device="cuda", dtype=torch.float16):
        super().__init__()
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(
                "AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
        self.device = device
        self.dtype = dtype
        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim
        self.total_length = id_length + 1
        self.id_length = id_length
        self.id_bank = {}

    def __call__(
            self,
            attn,
            hidden_states,
            encoder_hidden_states=None,
            attention_mask=None,
            temb=None):
        # un_cond_hidden_states, cond_hidden_states = hidden_states.chunk(2)
        # un_cond_hidden_states = self.__call2__(attn, un_cond_hidden_states,encoder_hidden_states,attention_mask,temb)
        # Generate a random number between 0 and 1.
        global total_count, attn_count, cur_step, mask1024, mask4096
        global sa32, sa64
        global write
        global height, width
        global num_steps
        if write:
            # print(f"write:{cur_step}")
            self.id_bank[cur_step] = [
                hidden_states[:self.id_length], hidden_states[self.id_length:]]
        else:
            encoder_hidden_states = torch.cat((self.id_bank[cur_step][0].to(
                self.device), hidden_states[:1], self.id_bank[cur_step][1].to(self.device), hidden_states[1:]))
        # Check whether the random number exceeds the threshold.
        if cur_step <= 1:
            hidden_states = self.__call2__(
                attn, hidden_states, None, attention_mask, temb)
        else:  # 256 1024 4096
            random_number = random.random()
            if cur_step < 0.4 * num_steps:
                rand_num = 0.3
            else:
                rand_num = 0.1
            # print(f"hidden state shape {hidden_states.shape[1]}")
            if random_number > rand_num:
                # print("mask shape",mask1024.shape,mask4096.shape)
                if not write:
                    if hidden_states.shape[1] == (height // 32) * (width // 32):
                        attention_mask = mask1024[mask1024.shape[0] //
                                                  self.total_length * self.id_length:]
                    else:
                        attention_mask = mask4096[mask4096.shape[0] //
                                                  self.total_length * self.id_length:]
                else:
                    # print(self.total_length,self.id_length,hidden_states.shape,(height//32) * (width//32))
                    if hidden_states.shape[1] == (height // 32) * (width // 32):
                        attention_mask = mask1024[:mask1024.shape[0] // self.total_length *
                                                  self.id_length, :mask1024.shape[0] // self.total_length * self.id_length]
                    else:
                        attention_mask = mask4096[:mask4096.shape[0] // self.total_length *
                                                  self.id_length, :mask4096.shape[0] // self.total_length * self.id_length]
                # print(attention_mask.shape)
                # print("before attention",hidden_states.shape,attention_mask.shape,encoder_hidden_states.shape if encoder_hidden_states is not None else "None")
                hidden_states = self.__call1__(
                    attn, hidden_states, encoder_hidden_states, attention_mask, temb)
            else:
                hidden_states = self.__call2__(
                    attn, hidden_states, None, attention_mask, temb)
        attn_count += 1
        if attn_count == total_count:
            attn_count = 0
            cur_step += 1
            mask1024, mask4096 = cal_attn_mask_xl(
                self.total_length, self.id_length, sa32, sa64, height, width, device=self.device, dtype=self.dtype)
        return hidden_states
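
    # __call1__ performs the paired attention pass: the per-frame token sequences in
    # the batch are flattened into one long sequence so queries from each frame can
    # attend to the tokens of the other frames in its group (optionally restricted by
    # the precomputed random attention masks).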
    def __call1__(
            self,
            attn,
            hidden_states,
            encoder_hidden_states=None,
            attention_mask=None,
            temb=None,
    ):
        # print("hidden state shape",hidden_states.shape,self.id_length)
        residual = hidden_states
        # if encoder_hidden_states is not None:
        #     raise Exception("not implement")
        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)
        input_ndim = hidden_states.ndim
        if input_ndim == 4:
            total_batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(
                total_batch_size, channel, height * width).transpose(1, 2)
        total_batch_size, nums_token, channel = hidden_states.shape
        img_nums = total_batch_size // 2
        hidden_states = hidden_states.view(-1, img_nums, nums_token,
                                           channel).reshape(-1, img_nums * nums_token, channel)
        batch_size, sequence_length, _ = hidden_states.shape
        if attn.group_norm is not None:
            hidden_states = attn.group_norm(
                hidden_states.transpose(1, 2)).transpose(1, 2)
        query = attn.to_q(hidden_states)
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states  # B, N, C
        else:
            encoder_hidden_states = encoder_hidden_states.view(
                -1, self.id_length + 1, nums_token, channel).reshape(-1, (self.id_length + 1) * nums_token, channel)
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)
        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads
        query = query.view(batch_size, -1, attn.heads,
                           head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads,
                           head_dim).transpose(1, 2)
        # print(key.shape,value.shape,query.shape,attention_mask.shape)
        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        # print(query.shape,key.shape,value.shape,attention_mask.shape)
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )
        hidden_states = hidden_states.transpose(1, 2).reshape(
            total_batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)
        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)
        # if input_ndim == 4:
        #     tile_hidden_states = tile_hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
        # if attn.residual_connection:
        #     tile_hidden_states = tile_hidden_states + residual
        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(
                total_batch_size, channel, height, width)
        if attn.residual_connection:
            hidden_states = hidden_states + residual
        hidden_states = hidden_states / attn.rescale_output_factor
        # print(hidden_states.shape)
        return hidden_states
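
    # __call2__ is a plain scaled-dot-product self-attention pass over each frame's
    # own tokens, used for the earliest steps and whenever paired attention is skipped.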
    def __call2__(
            self,
            attn,
            hidden_states,
            encoder_hidden_states=None,
            attention_mask=None,
            temb=None):
        residual = hidden_states
        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)
        input_ndim = hidden_states.ndim
        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(
                batch_size, channel, height * width).transpose(1, 2)
        batch_size, sequence_length, channel = (
            hidden_states.shape
        )
        # print(hidden_states.shape)
        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(
                attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(
                batch_size, attn.heads, -1, attention_mask.shape[-1])
        if attn.group_norm is not None:
            hidden_states = attn.group_norm(
                hidden_states.transpose(1, 2)).transpose(1, 2)
        query = attn.to_q(hidden_states)
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states  # B, N, C
        else:
            encoder_hidden_states = encoder_hidden_states.view(
                -1, self.id_length + 1, sequence_length, channel).reshape(-1, (self.id_length + 1) * sequence_length, channel)
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)
        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads
        query = query.view(batch_size, -1, attn.heads,
                           head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads,
                           head_dim).transpose(1, 2)
        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )
        hidden_states = hidden_states.transpose(1, 2).reshape(
            batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)
        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)
        if input_ndim == 4:
            hidden_states = hidden_states.transpose(
                -1, -2).reshape(batch_size, channel, height, width)
        if attn.residual_connection:
            hidden_states = hidden_states + residual
        hidden_states = hidden_states / attn.rescale_output_factor
        return hidden_states
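
# Swap the self-attention (attn1) processors of the UNet's up_blocks for
# SpatialAttnProcessor2_0; every other attention layer keeps the default processor.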
def set_attention_processor(unet, id_length, is_ipadapter=False):
    global total_count
    total_count = 0
    attn_procs = {}
    for name in unet.attn_processors.keys():
        cross_attention_dim = None if name.endswith(
            "attn1.processor") else unet.config.cross_attention_dim
        if cross_attention_dim is None:
            if name.startswith("up_blocks"):
                attn_procs[name] = SpatialAttnProcessor2_0(id_length=id_length)
                total_count += 1
            else:
                attn_procs[name] = AttnProcessor()
        else:
            attn_procs[name] = AttnProcessor()
    unet.set_attn_processor(copy.deepcopy(attn_procs))
    print("Successfully loaded paired self-attention")
    print(f"Number of processors: {total_count}")
attn_count = 0
total_count = 0
cur_step = 0
id_length = 4
total_length = 5
cur_model_type = ""
device = "cuda"
attn_procs = {}
write = False
sa32 = 0.5
sa64 = 0.5
height = 768
width = 768
def swap_to_gallery(images):
    return gr.update(value=images, visible=True), gr.update(visible=True), gr.update(visible=False)


def upload_example_to_gallery(images, prompt, style, negative_prompt):
    return gr.update(value=images, visible=True), gr.update(visible=True), gr.update(visible=False)


def remove_back_to_files():
    return gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)


def remove_tips():
    return gr.update(visible=False)


def apply_style_positive(style_name: str, positive: str):
    p, n = styles.get(style_name, styles[DEFAULT_STYLE_NAME])
    return p.replace("{prompt}", positive)


def apply_style(style_name: str, positives: list, negative: str = ""):
    p, n = styles.get(style_name, styles[DEFAULT_STYLE_NAME])
    return [p.replace("{prompt}", positive) for positive in positives], n + ' ' + negative


def change_visiale_by_model_type(_model_type):
    if _model_type == "Only Using Textual Description":
        return gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)
    elif _model_type == "Using Ref Images":
        return gr.update(visible=True), gr.update(visible=True), gr.update(visible=False)
    else:
        raise ValueError("Invalid model type", _model_type)
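
# Main generation entry point. It selects or builds a pipeline for the chosen base
# model, first generates the identity frames with paired attention writing to the
# id_bank (write=True), then generates the remaining frames that read from it
# (write=False), and finally, if requested, typesets all frames into a comic page.
# It yields partial galleries so Gradio can stream intermediate results.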
def process_generation(_sd_type, _model_type, _upload_images, _num_steps, style_name, _Ip_Adapter_Strength, _style_strength_ratio, guidance_scale, seed_, sa32_, sa64_, id_length_, general_prompt, negative_prompt, prompt_array, G_height, G_width, _comic_type):
    global sa32, sa64, id_length, total_length, attn_procs, unet, cur_model_type, device, num_steps, write, cur_step, attn_count, height, width, pipe2, pipe4, sd_model_path, models_dict
    _model_type = "Photomaker" if _model_type == "Using Ref Images" else "original"
    if _model_type == "Photomaker" and "img" not in general_prompt:
        raise gr.Error(
            "Please add the trigger word 'img' after the class word you want to customize, e.g. 'man img' or 'woman img'.")
    if _upload_images is None and _model_type != "original":
        raise gr.Error("Cannot find any input face image!")
    if len(prompt_array.splitlines()) > 10:
        raise gr.Error(
            f"For speed, the Hugging Face demo allows at most 10 prompts, but {len(prompt_array.splitlines())} were given.")
    height = G_height
    width = G_width
    sd_model_path = models_dict[_sd_type]
    num_steps = _num_steps
    if style_name == "(No style)":
        sd_model_path = models_dict["RealVision"]
    if _model_type == "original":
        pipe = StableDiffusionXLPipeline.from_pretrained(
            sd_model_path, torch_dtype=torch.float16)
        pipe = pipe.to(device)
        pipe.enable_freeu(s1=0.6, s2=0.4, b1=1.1, b2=1.2)
        set_attention_processor(pipe.unet, id_length_, is_ipadapter=False)
    elif _model_type == "Photomaker":
        if _sd_type != "RealVision" and style_name != "(No style)":
            pipe = pipe2.to(device)
            pipe.id_encoder.to(device)
            set_attention_processor(pipe.unet, id_length_, is_ipadapter=False)
        else:
            pipe = pipe4.to(device)
            pipe.id_encoder.to(device)
            set_attention_processor(pipe.unet, id_length_, is_ipadapter=False)
    else:
        raise NotImplementedError(
            "You should choose between original and Photomaker!", f"But you chose {_model_type}")
    pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
    pipe.enable_freeu(s1=0.6, s2=0.4, b1=1.1, b2=1.2)
    cur_model_type = _sd_type + "-" + _model_type + str(id_length_)
    if _model_type != "original":
        input_id_images = [load_image(img) for img in _upload_images]
    prompts = prompt_array.splitlines()
    start_merge_step = int(float(_style_strength_ratio) / 100 * _num_steps)
    if start_merge_step > 30:
        start_merge_step = 30
    print(f"start_merge_step: {start_merge_step}")
    generator = torch.Generator(device="cuda").manual_seed(seed_)
    sa32, sa64 = sa32_, sa64_
    id_length = id_length_
    clipped_prompts = prompts[:]
    prompts = [general_prompt + "," + prompt if "[NC]" not in prompt else prompt.replace(
        "[NC]", "") for prompt in clipped_prompts]
    prompts = [prompt.rpartition(
        '#')[0] if "#" in prompt else prompt for prompt in prompts]
    print(prompts)
    id_prompts = prompts[:id_length]
    real_prompts = prompts[id_length:]
    torch.cuda.empty_cache()
    write = True
    cur_step = 0
    attn_count = 0
    id_prompts, negative_prompt = apply_style(
        style_name, id_prompts, negative_prompt)
    setup_seed(seed_)
    total_results = []
    if _model_type == "original":
        id_images = pipe(id_prompts, num_inference_steps=_num_steps, guidance_scale=guidance_scale,
                         height=height, width=width, negative_prompt=negative_prompt, generator=generator).images
    elif _model_type == "Photomaker":
        id_images = pipe(id_prompts, input_id_images=input_id_images, num_inference_steps=_num_steps, guidance_scale=guidance_scale,
                         start_merge_step=start_merge_step, height=height, width=width, negative_prompt=negative_prompt, generator=generator).images
    else:
        raise NotImplementedError(
            "You should choose between original and Photomaker!", f"But you chose {_model_type}")
    total_results = id_images + total_results
    yield total_results
    real_images = []
    write = False
    for real_prompt in real_prompts:
        setup_seed(seed_)
        cur_step = 0
        real_prompt = apply_style_positive(style_name, real_prompt)
        if _model_type == "original":
            real_images.append(pipe(real_prompt, num_inference_steps=_num_steps, guidance_scale=guidance_scale,
                                    height=height, width=width, negative_prompt=negative_prompt, generator=generator).images[0])
        elif _model_type == "Photomaker":
            real_images.append(pipe(real_prompt, input_id_images=input_id_images, num_inference_steps=_num_steps, guidance_scale=guidance_scale,
                                    start_merge_step=start_merge_step, height=height, width=width, negative_prompt=negative_prompt, generator=generator).images[0])
        else:
            raise NotImplementedError(
                "You should choose between original and Photomaker!", f"But you chose {_model_type}")
        total_results = [real_images[-1]] + total_results
        yield total_results
    if _comic_type != "No typesetting (default)":
        from PIL import ImageFont
        captions = prompt_array.splitlines()
        captions = [caption.replace("[NC]", "") for caption in captions]
        captions = [caption.split(
            '#')[-1] if "#" in caption else caption for caption in captions]
        total_results = get_comic(id_images + real_images, _comic_type, captions=captions,
                                  font=ImageFont.truetype("./storyDiffusion/fonts/Inkfree.ttf", int(45))) + total_results
    if _model_type == "Photomaker":
        pipe = pipe2.to("cpu")
        pipe.id_encoder.to("cpu")
        set_attention_processor(pipe.unet, id_length_, is_ipadapter=False)
    yield total_results
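
# Pre-load both PhotoMaker pipelines on the CPU so process_generation can move the
# selected one to the GPU on demand.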
# Initialize pipelines
pipe2 = PhotoMakerStableDiffusionXLPipeline.from_pretrained(
    models_dict["Unstable"], torch_dtype=torch.float16, use_safetensors=False)
pipe2 = pipe2.to("cpu")
pipe2.load_photomaker_adapter(
    os.path.dirname(photomaker_path),
    subfolder="",
    weight_name=os.path.basename(photomaker_path),
    trigger_word="img"
)
pipe2 = pipe2.to("cpu")
pipe2.enable_freeu(s1=0.6, s2=0.4, b1=1.1, b2=1.2)
pipe2.fuse_lora()

pipe4 = PhotoMakerStableDiffusionXLPipeline.from_pretrained(
    models_dict["RealVision"], torch_dtype=torch.float16, use_safetensors=True)
pipe4 = pipe4.to("cpu")
pipe4.load_photomaker_adapter(
    os.path.dirname(photomaker_path),
    subfolder="",
    weight_name=os.path.basename(photomaker_path),
    trigger_word="img"
)
pipe4 = pipe4.to("cpu")
pipe4.enable_freeu(s1=0.6, s2=0.4, b1=1.1, b2=1.2)
pipe4.fuse_lora()
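
# Build the Gradio interface (presumably called from within a gr.Blocks context by
# the Space's entry point).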
def story_generation_ui():
    with gr.Row():
        with gr.Group(elem_id="main-image"):
            prompts = []
            colors = []
            with gr.Column(visible=True) as gen_prompt_vis:
                sd_type = gr.Dropdown(choices=list(models_dict.keys(
                )), value="Unstable", label="sd_type", info="Select pretrained model")
                model_type = gr.Radio(["Only Using Textual Description", "Using Ref Images"], label="model_type",
                                      value="Only Using Textual Description", info="Control type of the Character")
                with gr.Group(visible=False) as control_image_input:
                    files = gr.Files(
                        label="Drag (Select) 1 or more photos of your face",
                        file_types=["image"],
                    )
                    uploaded_files = gr.Gallery(
                        label="Your images", visible=False, columns=5, rows=1, height=200)
                    with gr.Column(visible=False) as clear_button:
                        remove_and_reupload = gr.ClearButton(
                            value="Remove and upload new ones", components=files, size="sm")
                general_prompt = gr.Textbox(
                    value='', label="(1) Textual Description for Character", interactive=True)
                negative_prompt = gr.Textbox(
                    value='', label="(2) Negative prompt", interactive=True)
                style = gr.Dropdown(
                    label="Style template", choices=STYLE_NAMES, value=DEFAULT_STYLE_NAME)
                prompt_array = gr.Textbox(
                    lines=3, value='', label="(3) Comic Description (each line corresponds to a frame).", interactive=True)
                with gr.Accordion("(4) Tune the hyperparameters", open=False):
                    sa32_ = gr.Slider(label="(The degree of Paired Attention at 32 x 32 self-attention layers)",
                                      minimum=0, maximum=1., value=0.7, step=0.1)
                    sa64_ = gr.Slider(label="(The degree of Paired Attention at 64 x 64 self-attention layers)",
                                      minimum=0, maximum=1., value=0.7, step=0.1)
                    id_length_ = gr.Slider(
                        label="Number of id images in total images", minimum=2, maximum=4, value=3, step=1)
                    seed_ = gr.Slider(label="Seed", minimum=-1,
                                      maximum=MAX_SEED, value=0, step=1)
                    num_steps = gr.Slider(
                        label="Number of sample steps",
                        minimum=25,
                        maximum=50,
                        step=1,
                        value=50,
                    )
                    G_height = gr.Slider(
                        label="height",
                        minimum=256,
                        maximum=1024,
                        step=32,
                        value=1024,
                    )
                    G_width = gr.Slider(
                        label="width",
                        minimum=256,
                        maximum=1024,
                        step=32,
                        value=1024,
                    )
                    comic_type = gr.Radio(["No typesetting (default)", "Four Pannel", "Classic Comic Style"],
                                          value="Classic Comic Style", label="Typesetting Style", info="Select the typesetting style")
                    guidance_scale = gr.Slider(
                        label="Guidance scale",
                        minimum=0.1,
                        maximum=10.0,
                        step=0.1,
                        value=5,
                    )
                    style_strength_ratio = gr.Slider(
                        label="Style strength of Ref Image (%)",
                        minimum=15,
                        maximum=50,
                        step=1,
                        value=20,
                        visible=False,
                    )
                    Ip_Adapter_Strength = gr.Slider(
                        label="Ip_Adapter_Strength",
                        minimum=0,
                        maximum=1,
                        step=0.1,
                        value=0.5,
                        visible=False,
                    )
                final_run_btn = gr.Button("Generate ! 😺")
        with gr.Column():
            out_image = gr.Gallery(label="Result", columns=2, height='auto')
            generated_information = gr.Markdown(
                label="Generation Details", value="", visible=False)
    model_type.change(fn=change_visiale_by_model_type, inputs=model_type, outputs=[
        control_image_input, style_strength_ratio, Ip_Adapter_Strength])
    files.upload(fn=swap_to_gallery, inputs=files, outputs=[
        uploaded_files, clear_button, files])
    remove_and_reupload.click(fn=remove_back_to_files, outputs=[
        uploaded_files, clear_button, files])
    final_run_btn.click(fn=set_text_unfinished, outputs=generated_information
                        ).then(process_generation, inputs=[sd_type, model_type, files, num_steps, style, Ip_Adapter_Strength, style_strength_ratio, guidance_scale, seed_, sa32_, sa64_, id_length_, general_prompt, negative_prompt, prompt_array, G_height, G_width, comic_type], outputs=out_image
                               ).then(fn=set_text_finished, outputs=generated_information)
    gr.Examples(
        examples=get_examples(),
        inputs=[seed_, sa32_, sa64_, id_length_, general_prompt, negative_prompt,
                prompt_array, style, model_type, files, G_height, G_width],
        label='😺 Examples 😺',
    )
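
# Usage sketch (assumption: this file's real entry point is not shown here, so the
# exact launch code is a guess). A minimal way to mount the UI would be:
#
#     with gr.Blocks() as demo:
#         story_generation_ui()
#     demo.launch()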