import gradio as gr
import torch
import spaces
from pathlib import Path
import gc
import subprocess
from PIL import Image

# Install flash-attn at runtime using the prebuilt wheel (skip the CUDA source build), then clear pip's cache.
subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
subprocess.run('pip cache purge', shell=True)

device = "cuda" if torch.cuda.is_available() else "cpu"
torch.set_grad_enabled(False)
models = [
    "camenduru/FLUX.1-dev-diffusers",
    "black-forest-labs/FLUX.1-schnell",
    "sayakpaul/FLUX.1-merged",
    "John6666/flux-dev2pro-bf16-flux",
    "John6666/flux1-dev-minus-v1-fp8-flux",
    "John6666/hyper-flux1-dev-fp8-flux",
    "John6666/blue-pencil-flux1-v001-fp8-flux",
    "John6666/copycat-flux-test-fp8-v11-fp8-flux",
    "John6666/flux-dev8-anime-nsfw-fp8-flux",
    "John6666/nepotism-fuxdevschnell-v3aio-fp8-flux",
    "John6666/niji-style-flux-devfp8-fp8-flux",
    "John6666/niji56-style-v3-fp8-flux",
    "John6666/lyh-dalle-anime-v12dalle-fp8-flux",
    "John6666/glimmerkin-flux-cute-v10-fp8-flux",
    "John6666/xe-anime-flux-03-fp8-flux",
    "John6666/xe-figure-flux-01-fp8-flux",
    "John6666/xe-pixel-flux-01-fp8-flux",
    "John6666/fluxunchained-artfulnsfw-fut516xfp8e4m3fnv11-fp8-flux",
    "John6666/fastflux-unchained-t5f16-fp8-flux",
    "John6666/iniverse-mix-xl-sfwnsfw-fluxdfp16nsfwv11-fp8-flux",
    "John6666/nsfw-master-flux-lora-merged-with-flux1-dev-fp16-v10-fp8-flux",
    "John6666/the-araminta-flux1a1-fp8-flux",
    "John6666/acorn-is-spinning-flux-v11-fp8-flux",
    "John6666/real-horny-pro-fp8-flux",
    "John6666/centerfold-flux-v20fp8e5m2-fp8-flux",
    "John6666/jib-mix-flux-v208stephyper-fp8-flux",
    "John6666/flux-asian-realistic-v10-fp8-flux",
    "John6666/fluxasiandoll-v10-fp8-flux",
    "John6666/xe-asian-flux-01-fp8-flux",
    "John6666/fluxescore-dev-v10fp16-fp8-flux",
    # "",
]
# Global state: number of LoRA slots and ControlNet Union slots, plus per-slot control inputs.
num_loras = 3
num_cns = 2
control_images = [None] * num_cns
control_modes = [-1] * num_cns
control_scales = [0] * num_cns
def is_repo_name(s):
    import re
    return re.fullmatch(r'^[^/,\s\"\']+/[^/,\s\"\']+$', s)

def is_repo_exists(repo_id):
    from huggingface_hub import HfApi
    api = HfApi()
    try:
        return api.repo_exists(repo_id=repo_id)
    except Exception as e:
        print(f"Error: Failed to connect to {repo_id}.")
        print(e)
        return True  # assume the repo exists if the check itself fails, to be safe
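# Illustration of what is_repo_name accepts (hypothetical ids in the usual "user/repo" form):
#   is_repo_name("John6666/some-flux-model")  -> truthy match object
#   is_repo_name("not a repo id")             -> None (falsy)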
def clear_cache():
    try:
        torch.cuda.empty_cache()
        torch.cuda.reset_max_memory_allocated()
        torch.cuda.reset_peak_memory_stats()
        gc.collect()
    except Exception as e:
        print(e)
def deselect_lora():
    selected_index = None
    new_placeholder = "Type a prompt"
    updated_text = ""
    width = 1024
    height = 1024
    return (
        gr.update(placeholder=new_placeholder),
        updated_text,
        selected_index,
        width,
        height,
    )
def get_repo_safetensors(repo_id: str):
    from huggingface_hub import HfApi
    api = HfApi()
    try:
        if not is_repo_name(repo_id) or not is_repo_exists(repo_id): return gr.update(value="", choices=[])
        files = api.list_repo_files(repo_id=repo_id)
    except Exception as e:
        print(f"Error: Failed to get {repo_id}'s info.")
        print(e)
        return gr.update(choices=[])
    files = [f for f in files if f.endswith(".safetensors")]
    if len(files) == 0: return gr.update(value="", choices=[])
    else: return gr.update(value=files[0], choices=files)
def expand2square(pil_img: Image.Image, background_color: tuple=(0, 0, 0)):
    width, height = pil_img.size
    if width == height:
        return pil_img
    elif width > height:
        result = Image.new(pil_img.mode, (width, width), background_color)
        result.paste(pil_img, (0, (width - height) // 2))
        return result
    else:
        result = Image.new(pil_img.mode, (height, height), background_color)
        result.paste(pil_img, ((height - width) // 2, 0))
        return result
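# Worked example (assumed sizes): a 768x512 landscape image is pasted onto a 768x768 canvas
# at (0, (768 - 512) // 2) = (0, 128), so the padding is split evenly above and below.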
# https://huggingface.co/spaces/DamarJati/FLUX.1-DEV-Canny/blob/main/app.py
def resize_image(image, target_width, target_height, crop=True):
    from image_datasets.canny_dataset import c_crop
    if crop:
        image = c_crop(image)  # Crop the image to square
        original_width, original_height = image.size
        # Resize to match the target size without stretching
        scale = max(target_width / original_width, target_height / original_height)
        resized_width = int(scale * original_width)
        resized_height = int(scale * original_height)
        image = image.resize((resized_width, resized_height), Image.LANCZOS)
        # Center crop to match the target dimensions
        left = (resized_width - target_width) // 2
        top = (resized_height - target_height) // 2
        image = image.crop((left, top, left + target_width, top + target_height))
    else:
        image = image.resize((target_width, target_height), Image.LANCZOS)
    return image
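# Worked example (assumed 1000x500 input, 768x768 target, crop=True): c_crop (assumed to center-crop
# to a square) gives 500x500, the scale is max(768/500, 768/500) = 1.536 yielding 768x768, and the
# final center crop is a no-op; with crop=False the image is simply stretched to 768x768.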
# https://huggingface.co/spaces/jiuface/FLUX.1-dev-Controlnet-Union/blob/main/app.py
# https://huggingface.co/InstantX/FLUX.1-dev-Controlnet-Union
controlnet_union_modes = {
    "None": -1,
    #"scribble_hed": 0,
    "canny": 0,  # supported
    "mlsd": 0,  # supported
    "tile": 1,  # supported
    "depth_midas": 2,  # supported
    "blur": 3,  # supported
    "openpose": 4,  # supported
    "gray": 5,  # supported
    "low_quality": 6,  # supported
}
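# These integers are the control_mode indices expected by the InstantX FLUX.1-dev ControlNet Union
# checkpoint (0 canny, 1 tile, 2 depth, 3 blur, 4 pose, 5 gray, 6 low quality); note that "mlsd"
# is routed onto the canny channel here, and -1 marks a disabled slot.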
# https://github.com/pytorch/pytorch/issues/123834
def get_control_params():
    from diffusers.utils import load_image
    modes = []
    images = []
    scales = []
    for i, mode in enumerate(control_modes):
        if mode == -1 or control_images[i] is None: continue
        modes.append(control_modes[i])
        images.append(load_image(control_images[i]))
        scales.append(control_scales[i])
    return modes, images, scales
from preprocessor import Preprocessor

def preprocess_image(image: Image.Image, control_mode: str, height: int, width: int,
                     preprocess_resolution: int):
    if control_mode == "None": return image
    image_resolution = max(width, height)
    image_before = resize_image(expand2square(image.convert("RGB")), image_resolution, image_resolution, False)
    # generate the control image
    print("start generating control image")
    preprocessor = Preprocessor()
    if control_mode == "depth_midas":
        preprocessor.load("Midas")
        control_image = preprocessor(
            image=image_before,
            image_resolution=image_resolution,
            detect_resolution=preprocess_resolution,
        )
    if control_mode == "openpose":
        preprocessor.load("Openpose")
        control_image = preprocessor(
            image=image_before,
            hand_and_face=True,
            image_resolution=image_resolution,
            detect_resolution=preprocess_resolution,
        )
    if control_mode == "canny":
        preprocessor.load("Canny")
        control_image = preprocessor(
            image=image_before,
            image_resolution=image_resolution,
            detect_resolution=preprocess_resolution,
        )
    if control_mode == "mlsd":
        preprocessor.load("MLSD")
        control_image = preprocessor(
            image=image_before,
            image_resolution=image_resolution,
            detect_resolution=preprocess_resolution,
        )
    if control_mode == "scribble_hed":
        preprocessor.load("HED")
        control_image = preprocessor(
            image=image_before,
            image_resolution=image_resolution,
            detect_resolution=preprocess_resolution,
        )
    if control_mode in ("low_quality", "gray", "blur", "tile"):
        # These modes use the resized input image directly as the control image.
        control_image = image_before
        image_width = 768
        image_height = 768
    else:
        # Make sure the control image size matches the resized image.
        image_width, image_height = control_image.size
    image_after = resize_image(control_image, width, height, False)
    ref_width, ref_height = image.size
    print(f"control image generated: {ref_width}x{ref_height} => {image_width}x{image_height}")
    return image_after
def get_control_union_mode():
    return list(controlnet_union_modes.keys())

def set_control_union_mode(i: int, mode: str, scale: float):
    global control_modes
    global control_scales
    control_modes[i] = controlnet_union_modes.get(mode, 0)
    control_scales[i] = scale
    if mode != "None": return True
    else: return gr.update(visible=True)

def set_control_union_image(i: int, mode: str, image: Image.Image | None, height: int, width: int, preprocess_resolution: int):
    global control_images
    if image is None: return None
    control_images[i] = preprocess_image(image, mode, height, width, preprocess_resolution)
    return control_images[i]
def compose_lora_json(lorajson: list[dict], i: int, name: str, scale: float, filename: str, trigger: str):
    lorajson[i]["name"] = str(name) if name != "None" else ""
    lorajson[i]["scale"] = float(scale)
    lorajson[i]["filename"] = str(filename)
    lorajson[i]["trigger"] = str(trigger)
    return lorajson

def is_valid_lora(lorajson: list[dict]):
    valid = False
    for d in lorajson:
        if "name" in d.keys() and d["name"] and d["name"] != "None": valid = True
    return valid

def get_trigger_word(lorajson: list[dict]):
    trigger = ""
    for d in lorajson:
        if "name" in d.keys() and d["name"] and d["name"] != "None" and d["trigger"]:
            trigger += ", " + d["trigger"]
    return trigger
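# The lorajson state handled above is assumed to be a list of per-slot dicts, e.g.
# [{"name": "user/flux-lora", "scale": 1.0, "filename": "lora.safetensors", "trigger": "style word"}, ...],
# where "name" may also be a local file path and ""/"None" marks an unused slot.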
# https://huggingface.co/docs/diffusers/v0.23.1/en/api/loaders#diffusers.loaders.LoraLoaderMixin.fuse_lora
# https://github.com/huggingface/diffusers/issues/4919
def fuse_loras(pipe, lorajson: list[dict]):
    if not lorajson or not isinstance(lorajson, list): return
    a_list = []
    w_list = []
    for d in lorajson:
        if not d or not isinstance(d, dict) or not d["name"] or d["name"] == "None": continue
        k = d["name"]
        if is_repo_name(k) and is_repo_exists(k):
            a_name = Path(k).stem
            pipe.load_lora_weights(k, weight_name=d["filename"], adapter_name=a_name)
        elif not Path(k).exists():
            print(f"LoRA not found: {k}")
            continue
        else:
            w_name = Path(k).name
            a_name = Path(k).stem
            pipe.load_lora_weights(k, weight_name=w_name, adapter_name=a_name)
        a_list.append(a_name)
        w_list.append(d["scale"])
    if not a_list: return
    pipe.set_adapters(a_list, adapter_weights=w_list)
    pipe.fuse_lora(adapter_names=a_list, lora_scale=1.0)
    #pipe.unload_lora_weights()
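# Minimal usage sketch (hypothetical values), assuming `pipe` is an already-loaded Flux pipeline:
#   lorajson = [{"name": "user/flux-lora", "scale": 0.8, "filename": "lora.safetensors", "trigger": ""}]
#   if is_valid_lora(lorajson): fuse_loras(pipe, lorajson)  # load adapters, weight them, fuse into base weights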
def description_ui():
    gr.Markdown(
        """
- Mod of [multimodalart/flux-lora-the-explorer](https://huggingface.co/spaces/multimodalart/flux-lora-the-explorer),
[jiuface/FLUX.1-dev-Controlnet-Union](https://huggingface.co/spaces/jiuface/FLUX.1-dev-Controlnet-Union),
[DamarJati/FLUX.1-DEV-Canny](https://huggingface.co/spaces/DamarJati/FLUX.1-DEV-Canny),
[gokaygokay/FLUX-Prompt-Generator](https://huggingface.co/spaces/gokaygokay/FLUX-Prompt-Generator).
"""
    )
from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM

def load_prompt_enhancer():
    try:
        model_checkpoint = "gokaygokay/Flux-Prompt-Enhance"
        tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
        model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint).eval().to(device=device)
        enhancer_flux = pipeline('text2text-generation', model=model, tokenizer=tokenizer, repetition_penalty=1.5, device=device)
    except Exception as e:
        print(e)
        enhancer_flux = None
    return enhancer_flux

enhancer_flux = load_prompt_enhancer()
def enhance_prompt(input_prompt):
    if enhancer_flux is None: return input_prompt  # fall back to the raw prompt if the enhancer failed to load
    result = enhancer_flux("enhance prompt: " + input_prompt, max_length=256)
    enhanced_text = result[0]['generated_text']
    return enhanced_text
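# Example (output varies by model): enhance_prompt("a cat") should return a longer descriptive prompt
# generated by gokaygokay/Flux-Prompt-Enhance, or the input unchanged if the enhancer is unavailable.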
load_prompt_enhancer.zerogpu = True
fuse_loras.zerogpu = True
preprocess_image.zerogpu = True
get_control_params.zerogpu = True
clear_cache.zerogpu = True