Spaces:
Runtime error
Runtime error
| import spaces | |
| import os | |
| import gradio as gr | |
| import json | |
| import logging | |
| logging.getLogger("diffusers").setLevel(logging.ERROR) | |
| import diffusers | |
| diffusers.utils.logging.set_verbosity(40) | |
| import warnings | |
| warnings.filterwarnings(action="ignore", category=FutureWarning, module="diffusers") | |
| warnings.filterwarnings(action="ignore", category=UserWarning, module="diffusers") | |
| warnings.filterwarnings(action="ignore", category=FutureWarning, module="transformers") | |
| from pathlib import Path | |
| from env import (hf_token, hf_read_token, # to use only for private repos | |
| CIVITAI_API_KEY, HF_LORA_PRIVATE_REPOS1, HF_LORA_PRIVATE_REPOS2, HF_LORA_ESSENTIAL_PRIVATE_REPO, | |
| HF_VAE_PRIVATE_REPO, directory_models, directory_loras, directory_vaes, | |
| download_model_list, download_lora_list, download_vae_list) | |
| from modutils import (to_list, list_uniq, list_sub, get_lora_model_list, download_private_repo, | |
| safe_float, escape_lora_basename, to_lora_key, to_lora_path, get_local_model_list, | |
| get_private_lora_model_lists, get_valid_lora_name, get_valid_lora_path, get_valid_lora_wt, | |
| get_lora_info, normalize_prompt_list, get_civitai_info, search_lora_on_civitai) | |
def download_things(directory, url, hf_token="", civitai_api_key=""):
    """Download ``url`` into ``directory`` using a host-appropriate CLI tool.

    - Google Drive links are fetched with ``gdown --fuzzy`` run inside ``directory``.
    - Hugging Face links are normalized (``?download=true`` dropped, ``/blob/``
      rewritten to ``/resolve/``) and fetched with ``aria2c`` under the URL's last
      path segment; ``hf_token`` is sent as a Bearer ``Authorization`` header
      when given.
    - Civitai links have their query string replaced by ``?token=<civitai_api_key>``;
      without a key nothing is downloaded and an error is printed.
    - Any other URL is passed straight to ``aria2c``.

    Commands are run through ``subprocess`` with an argument list (no shell), so
    characters such as ``;`` or ``$()`` in a user-supplied URL cannot inject extra
    shell commands. Failures are printed, never raised.

    Args:
        directory: Destination directory.
        url: URL to download; surrounding whitespace is ignored and blank input
            is a no-op.
        hf_token: Optional Hugging Face access token.
        civitai_api_key: Civitai API key, required for Civitai URLs.
    """
    import subprocess

    url = url.strip()
    if not url:
        return
    # Options shared by every aria2c call: quiet logging, resume (-c),
    # 16 connections per server (-x), 1M split size (-k), 16 splits (-s).
    aria2c_opts = ["--console-log-level=error", "--summary-interval=10",
                   "-c", "-x", "16", "-k", "1M", "-s", "16"]
    cwd = None
    if "drive.google.com" in url:
        # gdown saves into its working directory; set it for the child process
        # only instead of os.chdir()-ing the whole (multi-user) app.
        cwd = directory
        cmd = ["gdown", "--fuzzy", url]
    elif "huggingface.co" in url:
        url = url.replace("?download=true", "")
        # url = urllib.parse.quote(url, safe=':/') # fix encoding
        if "/blob/" in url:
            url = url.replace("/blob/", "/resolve/")
        out_name = url.split("/")[-1]
        if hf_token:
            cmd = ["aria2c", *aria2c_opts, f"--header=Authorization: Bearer {hf_token}",
                   url, "-d", directory, "-o", out_name]
        else:
            cmd = ["aria2c", "--optimize-concurrent-downloads", *aria2c_opts,
                   url, "-d", directory, "-o", out_name]
    elif "civitai.com" in url:
        if not civitai_api_key:
            print("\033[91mYou need an API key to download Civitai models.\033[0m")
            return
        url = url.split("?")[0] + f"?token={civitai_api_key}"
        cmd = ["aria2c", *aria2c_opts, "-d", directory, url]
    else:
        cmd = ["aria2c", *aria2c_opts, "-d", directory, url]
    try:
        subprocess.run(cmd, cwd=cwd, check=False)
    except OSError as e:
        # e.g. tool not installed or bad cwd; only cmd[0] is shown so tokens never leak.
        print(f"\033[91mFailed to run {cmd[0]}: {e}\033[0m")
def get_model_list(directory_path):
    """Return paths of model weight files found directly inside ``directory_path``.

    A file counts as a model when its extension is exactly one of ``.ckpt``,
    ``.pt``, ``.pth``, ``.safetensors`` or ``.bin``. Each match is printed.
    Paths are ``os.path.join(directory_path, filename)``, returned in sorted
    order so dropdowns are stable across restarts. A missing directory yields an
    empty list instead of raising at import time.
    """
    valid_extensions = {'.ckpt', '.pt', '.pth', '.safetensors', '.bin'}
    if not os.path.isdir(directory_path):
        return []
    model_list = []
    for filename in sorted(os.listdir(directory_path)):
        if os.path.splitext(filename)[1] in valid_extensions:
            file_path = os.path.join(directory_path, filename)
            model_list.append(file_path)
            print('\033[34mFILE: ' + file_path + '\033[0m')
    return model_list
# Comma-separated URL strings built from the lists configured in env.py.
# - **Download Models**
download_model = ", ".join(download_model_list)
# - **Download VAEs**
download_vae = ", ".join(download_vae_list)
# - **Download LoRAs**
# NOTE: this module-level string is shadowed later in the module by the
# ``download_lora()`` function; it is only used during the startup downloads below.
download_lora = ", ".join(download_lora_list)
#download_private_repo(HF_LORA_ESSENTIAL_PRIVATE_REPO, directory_loras, True)
#download_private_repo(HF_VAE_PRIVATE_REPO, directory_vaes, False)
# Environment variables take precedence; otherwise keep the values imported from
# env.py instead of silently replacing them with None.
CIVITAI_API_KEY = os.environ.get("CIVITAI_API_KEY") or CIVITAI_API_KEY
hf_token = os.environ.get("HF_TOKEN") or hf_token


def _download_missing_files(directory, urls_csv):
    """Download each URL in comma-separated ``urls_csv`` not already in ``directory``.

    Blank entries (e.g. from an empty configured list) are skipped. The
    existence check looks for the URL's last path segment, without any query
    string, inside ``directory`` itself.
    """
    for url in (u.strip() for u in urls_csv.split(",")):
        if not url:
            continue
        filename = url.split("?")[0].split("/")[-1]
        if not os.path.exists(os.path.join(directory, filename)):
            download_things(directory, url, hf_token, CIVITAI_API_KEY)


# Download stuffs
_download_missing_files(directory_models, download_model)
_download_missing_files(directory_vaes, download_vae)
_download_missing_files(directory_loras, download_lora)
lora_model_list = get_lora_model_list()
vae_model_list = get_model_list(directory_vaes)
vae_model_list.insert(0, "None")
def get_t2i_model_info(repo_id: str):
    """Build a one-line Markdown summary of a public diffusers model repo.

    Returns "" when ``repo_id`` is empty or contains a space, when the repo does
    not exist or its info cannot be fetched, when it is private or gated, or when
    it is not tagged ``diffusers``. Otherwise returns ``gr.update(value=md)``.
    ``md`` lists the pipeline family (FLUX.1 / SDXL / SD1.5), the card tags minus
    generic ones, downloads, likes, the last-modified date and a repo link.
    """
    # Validate cheaply first: no need to import huggingface_hub or query the Hub
    # for ids that can never be valid.
    if not repo_id or " " in repo_id:
        return ""
    from huggingface_hub import HfApi
    api = HfApi()
    try:
        if not api.repo_exists(repo_id):
            return ""
        model = api.model_info(repo_id=repo_id)
    except Exception as e:
        print(f"Error: Failed to get {repo_id}'s info. ")
        print(e)
        return ""
    if model.private or model.gated:
        return ""
    tags = model.tags or []  # ModelInfo.tags is optional and may be None
    if 'diffusers' not in tags:
        return ""
    info = []
    if 'diffusers:FluxPipeline' in tags:
        info.append("FLUX.1")
    elif 'diffusers:StableDiffusionXLPipeline' in tags:
        info.append("SDXL")
    elif 'diffusers:StableDiffusionPipeline' in tags:
        info.append("SD1.5")
    if model.card_data and model.card_data.tags:
        info.extend(list_sub(model.card_data.tags, ['text-to-image', 'stable-diffusion', 'stable-diffusion-api', 'safetensors', 'stable-diffusion-xl']))
    info.append(f"DLs: {model.downloads}")
    info.append(f"likes: {model.likes}")
    if model.last_modified:  # optional in ModelInfo; skip rather than crash
        info.append(model.last_modified.strftime("lastmod: %Y-%m-%d"))
    url = f"https://huggingface.co/{repo_id}/"
    md = f"Model Info: {', '.join(info)}, [Model Repo]({url})"
    return gr.update(value=md)
# LoRA metadata loaded from the optional lora_dict.json, keyed by escaped basename.
# The "" key is a placeholder entry of five empty strings.
private_lora_dict = {"": ["", "", "", "", ""]}
try:
    with open('lora_dict.json', encoding='utf-8') as f:
        d = json.load(f)
    for k, v in d.items():
        private_lora_dict[escape_lora_basename(k)] = v
except FileNotFoundError:
    pass  # lora_dict.json is optional
except Exception as e:
    # A malformed file must not stop the app, but should not fail silently either.
    print(f"\033[91mFailed to load lora_dict.json: {e}\033[0m")
private_lora_model_list = get_private_lora_model_lists()
# "None" and "" map to empty metadata; private entries override them (dict union makes a new dict).
loras_dict = {"None": ["", "", "", "", ""], "": ["", "", "", "", ""]} | private_lora_dict
loras_url_to_path_dict = {}  # {"URL to download": "local filepath", ...}
civitai_lora_last_results = {}  # {"URL to download": {search results}, ...}
all_lora_list = []
def get_all_lora_list():
    """Return the current LoRA file list and cache a copy in ``all_lora_list``.

    The module-level cache receives a shallow copy, so callers may mutate the
    returned list without affecting it.
    """
    global all_lora_list
    found = get_lora_model_list()
    all_lora_list = found.copy()
    return found
def get_all_lora_tupled_list():
    """Return ``(label, path)`` dropdown choices for every LoRA file.

    The list is refreshed via ``get_all_lora_list()``. Metadata for each file is
    looked up in ``loras_dict`` by ``to_lora_key(path)``. On a miss it is fetched
    with ``get_civitai_info()`` and cached in ``loras_dict`` when not None. If
    the metadata's third entry is non-empty, the label is
    ``"<stem> (for <items[1]>, <items[2]>)"`` (with a 🐴 after "Pony").
    Otherwise the label is just the file stem.
    """
    models = get_all_lora_list()
    if not models:
        return []
    tupled_list = []
    for model in models:
        basename = Path(model).stem
        key = to_lora_key(model)
        if key in loras_dict:
            items = loras_dict.get(key)
        else:
            items = get_civitai_info(model)
            if items is not None:
                loras_dict[key] = items  # item assignment: no ``global`` needed
        name = basename
        if items and items[2] != "":
            base = f"{items[1]}🐴" if items[1] == "Pony" else items[1]
            name = f"{basename} (for {base}, {items[2]})"
        tupled_list.append((name, model))
    return tupled_list
def update_lora_dict(path: str):
    """Cache Civitai metadata for the LoRA at ``path`` in ``loras_dict``.

    No-op when the key is already cached or ``get_civitai_info`` returns None.
    """
    key = to_lora_key(path)
    if key in loras_dict:
        return
    items = get_civitai_info(path)
    if items is not None:
        loras_dict[key] = items
def download_lora(dl_urls: str):
    """Download comma-separated LoRA URLs into ``directory_loras``.

    URLs whose last path segment already exists in ``directory_loras`` are skipped,
    as are blank entries. Each newly downloaded file is renamed in place to its
    ``escape_lora_basename`` form. It is recorded in ``loras_url_to_path_dict``
    under the URL that produced it and registered via ``update_lora_dict``.

    Returns:
        The path of the last file downloaded by this call, or "" if none.
    """
    dl_path = ""
    for url in (u.strip() for u in dl_urls.split(",")):
        if not url:
            continue
        if Path(directory_loras, url.split("/")[-1]).exists():
            continue
        # Snapshot the directory around each download so new files are attributed
        # to the URL that produced them. Diffing once for all URLs and pairing by
        # index mis-maps (or IndexErrors) when a download fails or yields extra files.
        before = get_local_model_list(directory_loras)
        download_things(directory_loras, url, hf_token, CIVITAI_API_KEY)
        new_files = list_sub(get_local_model_list(directory_loras), before)
        for file in new_files:
            path = Path(file)
            if not path.exists():
                continue
            # Rename within the same directory (works for nested directory_loras too).
            new_path = path.with_name(f"{escape_lora_basename(path.stem)}{path.suffix}")
            path.resolve().rename(new_path.resolve())
            loras_url_to_path_dict[url] = str(new_path)
            update_lora_dict(str(new_path))
            dl_path = str(new_path)
    return dl_path
def copy_lora(path: str, new_path: str):
    """Copy the LoRA file ``path`` to ``new_path`` and register its metadata.

    Returns ``new_path`` on success, or immediately when both paths are equal.
    Returns None when ``path`` does not exist or the copy fails; a failure is
    printed rather than raised.
    """
    import shutil
    if path == new_path:
        return new_path
    cpath = Path(path)
    npath = Path(new_path)
    if not cpath.exists():
        return None
    try:
        shutil.copy(str(cpath.resolve()), str(npath.resolve()))
    except OSError as e:
        # Covers missing destination dirs, permissions and shutil.SameFileError.
        print(f"\033[91mFailed to copy LoRA {path} to {new_path}: {e}\033[0m")
        return None
    update_lora_dict(str(npath))
    return new_path
def download_my_lora(dl_urls: str, lora):
    """Download ``dl_urls`` and refresh the LoRA dropdown.

    The file downloaded last becomes the selected value. If nothing new was
    downloaded, the current selection ``lora`` is kept.
    """
    downloaded = download_lora(dl_urls)
    selected = downloaded if downloaded else lora
    return gr.update(value=selected, choices=get_all_lora_tupled_list())
def apply_lora_prompt(lora_info: str):
    """Turn a LoRA tag string into a de-duplicated, comma-separated prompt.

    Both ``/`` and ``,`` act as separators, and tags are normalized with
    ``normalize_prompt_list``. "None", an empty string, or None yield "".
    """
    if not lora_info or lora_info == "None":
        return ""
    lora_tags = lora_info.replace("/", ",").split(",")
    lora_prompts = normalize_prompt_list(lora_tags)
    return ", ".join(list_uniq(lora_prompts))
def update_loras(prompt, lora, lora_wt):
    """Refresh the LoRA widgets for the selected ``lora``.

    Returns five updates: the prompt (unchanged), the LoRA dropdown with
    re-populated choices, the weight (unchanged), and the tag box and Markdown
    info box. The visibility of the last two follows the ``on`` flag returned by
    ``get_lora_info``.
    """
    on, label, tag, md = get_lora_info(lora)
    choices = get_all_lora_tupled_list()
    prompt_update = gr.update(value=prompt)
    lora_update = gr.update(value=lora, choices=choices)
    weight_update = gr.update(value=lora_wt)
    tag_update = gr.update(value=tag, label=label, visible=on)
    info_update = gr.update(value=md, visible=on)
    return prompt_update, lora_update, weight_update, tag_update, info_update
def search_civitai_lora(query, base_model):
    """Search Civitai for LoRAs and populate the result widgets.

    On hits, ``civitai_lora_last_results`` is replaced with
    ``{download URL: item}``. Four updates are returned: the results dropdown
    (first hit selected), its Markdown description, and two components that are
    made visible. With no hits, the dropdown and description are hidden and the
    previous results are left untouched.
    """
    global civitai_lora_last_results
    items = search_lora_on_civitai(query, base_model)
    choices = []
    if items:
        civitai_lora_last_results = {}
        for item in items:
            base_model_name = "Pony🐴" if item['base_model'] == "Pony" else item['base_model']
            name = f"{item['name']} (for {base_model_name} / By: {item['creator']} / Tags: {', '.join(item['tags'])})"
            value = item['dl_url']
            choices.append((name, value))
            civitai_lora_last_results[value] = item
    if not choices:
        return gr.update(choices=[("", "")], value="", visible=False),\
            gr.update(value="", visible=False), gr.update(visible=True), gr.update(visible=True)
    # Default must be None: the old "None" default made result['md'] raise TypeError.
    result = civitai_lora_last_results.get(choices[0][1])
    md = result.get('md', "") if result else ""
    return gr.update(choices=choices, value=choices[0][1], visible=True), gr.update(value=md, visible=True),\
        gr.update(visible=True), gr.update(visible=True)
def select_civitai_lora(search_result):
    """Show details for the Civitai search hit whose download URL is ``search_result``.

    Returns updates for the URL textbox and the Markdown description. Values
    without "http" (including None) clear the textbox and show "None". A URL
    missing from ``civitai_lora_last_results`` (e.g. stale after a new search)
    shows an empty description instead of crashing.
    """
    if not search_result or "http" not in search_result:
        return gr.update(value=""), gr.update(value="None", visible=True)
    result = civitai_lora_last_results.get(search_result)
    md = result.get('md', "") if result else ""
    return gr.update(value=search_result), gr.update(value=md, visible=True)
def search_civitai_lora_json(query, base_model):
    """Return a ``gr.update`` mapping each Civitai hit's download URL to its raw item.

    The value is an empty dict when the search yields nothing.
    """
    items = search_lora_on_civitai(query, base_model)
    results = {item['dl_url']: item for item in items} if items else {}
    return gr.update(value=results)