Spaces: Runtime error

Commit: Upload 5 files

- app.py (+5 -2)
- mod.py (+31 -4)
- requirements.txt (+6 -1)
app.py
CHANGED
@@ -11,7 +11,8 @@ import random
 import time
 
 from mod import (models, clear_cache, get_repo_safetensors, change_base_model,
-                 description_ui, num_loras, compose_lora_json, is_valid_lora, fuse_loras, get_trigger_word, pipe)
+                 description_ui, num_loras, compose_lora_json, is_valid_lora, fuse_loras,
+                 get_trigger_word, pipe, enhance_prompt)
 from flux import (search_civitai_lora, select_civitai_lora, search_civitai_lora_json,
                   download_my_lora, get_all_lora_tupled_list, apply_lora_prompt,
                   update_loras)
@@ -241,6 +242,7 @@ with gr.Blocks(theme='Nymbo/Nymbo_Theme', fill_width=True, css=css) as app:
                     tagger_algorithms = gr.CheckboxGroup(["Use WD Tagger", "Use CogFlorence-2.1-Large", "Use Florence-2-Flux"], label="Algorithms", value=["Use WD Tagger"])
                     tagger_generate_from_image = gr.Button(value="Generate Prompt from Image")
                 prompt = gr.Textbox(label="Prompt", lines=1, max_lines=8, placeholder="Type a prompt")
+                prompt_enhance = gr.Button(value="Enhance your prompt", variant="secondary")
         with gr.Column(scale=1, elem_id="gen_column"):
             generate_button = gr.Button("Generate", variant="primary", elem_id="gen_btn")
     with gr.Row():
@@ -306,8 +308,8 @@ with gr.Blocks(theme='Nymbo/Nymbo_Theme', fill_width=True, css=css) as app:
             with gr.Accordion("From URL", open=True, visible=True):
                 with gr.Row():
                     lora_search_civitai_query = gr.Textbox(label="Query", placeholder="flux", lines=1)
-                    lora_search_civitai_basemodel = gr.CheckboxGroup(label="Search LoRA for", choices=["Flux.1 D", "Flux.1 S"], value=["Flux.1 D", "Flux.1 S"])
                     lora_search_civitai_submit = gr.Button("Search on Civitai")
+                    lora_search_civitai_basemodel = gr.CheckboxGroup(label="Search LoRA for", choices=["Flux.1 D", "Flux.1 S"], value=["Flux.1 D", "Flux.1 S"])
                 lora_search_civitai_result = gr.Dropdown(label="Search Results", choices=[("", "")], value="", allow_custom_value=True, visible=False)
                 lora_search_civitai_json = gr.JSON(value={}, visible=False)
                 lora_search_civitai_desc = gr.Markdown(value="", visible=False)
@@ -344,6 +346,7 @@ with gr.Blocks(theme='Nymbo/Nymbo_Theme', fill_width=True, css=css) as app:
     )
 
     model_name.change(change_base_model, [model_name], [result])
+    prompt_enhance.click(enhance_prompt, [prompt], [prompt])
 
     gr.on(
         triggers=[lora_search_civitai_submit.click, lora_search_civitai_query.submit],
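Net effect in app.py: an "Enhance your prompt" button is added under the prompt textbox and wired to the new enhance_prompt function, with the result written back into the same textbox. A minimal self-contained sketch of that wiring (the enhance_prompt stub below is a hypothetical stand-in for the function app.py imports from mod.py):

import gradio as gr

def enhance_prompt(input_prompt):
    # hypothetical stub; the Space's real function runs a text2text-generation model
    return input_prompt + ", highly detailed, best quality"

with gr.Blocks() as demo:
    prompt = gr.Textbox(label="Prompt", lines=1, max_lines=8, placeholder="Type a prompt")
    prompt_enhance = gr.Button(value="Enhance your prompt", variant="secondary")
    # input and output are the same component, so a click rewrites the prompt in place
    prompt_enhance.click(enhance_prompt, [prompt], [prompt])

demo.launch()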
mod.py
CHANGED
@@ -7,6 +7,7 @@ import gc
 import subprocess
 
 
+subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
 subprocess.run('pip cache purge', shell=True)
 device = "cuda" if torch.cuda.is_available() else "cpu"
 torch.set_grad_enabled(False)
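The FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE flag tells the flash-attn installer to skip compiling its CUDA extensions. Note that passing env={...} to subprocess.run replaces the child's entire environment rather than extending it, so variables such as PATH are dropped. A more conservative sketch (same command; the environment handling is my suggestion, not the committed code):

import os
import subprocess

# copy the current environment and add only the extra flag
env = {**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"}
subprocess.run("pip install flash-attn --no-build-isolation", env=env, shell=True)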
@@ -61,7 +62,7 @@ def get_repo_safetensors(repo_id: str):
         if not is_repo_name(repo_id) or not is_repo_exists(repo_id): return gr.update(value="", choices=[])
         files = api.list_repo_files(repo_id=repo_id)
     except Exception as e:
-        print(f"Error: Failed to get {repo_id}'s info.
+        print(f"Error: Failed to get {repo_id}'s info.")
         print(e)
         return gr.update(choices=[])
     files = [f for f in files if f.endswith(".safetensors")]
@@ -138,8 +139,7 @@ def fuse_loras(pipe, lorajson: list[dict]):
     #pipe.unload_lora_weights()
 
 
-
-fuse_loras.zerogpu = True
+
 
 
 def description_ui():
@@ -148,4 +148,31 @@ def description_ui():
     - Mod of [multimodalart/flux-lora-the-explorer](https://huggingface.co/spaces/multimodalart/flux-lora-the-explorer),
     [gokaygokay/FLUX-Prompt-Generator](https://huggingface.co/spaces/gokaygokay/FLUX-Prompt-Generator).
     """
-    )
+    )
+
+
+from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM
+def load_prompt_enhancer():
+    try:
+        model_checkpoint = "gokaygokay/Flux-Prompt-Enhance"
+        tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
+        model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint).eval().to(device=device)
+        enhancer_flux = pipeline('text2text-generation', model=model, tokenizer=tokenizer, repetition_penalty=1.5, device=device)
+    except Exception as e:
+        print(e)
+        enhancer_flux = None
+    return enhancer_flux
+
+
+enhancer_flux = load_prompt_enhancer()
+
+
+def enhance_prompt(input_prompt):
+    result = enhancer_flux("enhance prompt: " + input_prompt, max_length=256)
+    enhanced_text = result[0]['generated_text']
+    return enhanced_text
+
+
+load_prompt_enhancer.zerogpu = True
+change_base_model.zerogpu = True
+fuse_loras.zerogpu = True
requirements.txt
CHANGED
@@ -1,7 +1,12 @@
 torch
+torchvision
+huggingface_hub
+accelerate
 git+https://github.com/huggingface/diffusers
 spaces
 transformers
 peft
 sentencepiece
-timm
+timm
+xformers
+einops