Spaces:
Build error
Build error
add use low vram option
Browse files
app.py
CHANGED
|
@@ -34,7 +34,7 @@ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
|
| 34 |
print(device)
|
| 35 |
|
| 36 |
# Flag for low VRAM usage
|
| 37 |
-
low_vram = False
|
| 38 |
|
| 39 |
# Function definition for low VRAM usage
|
| 40 |
def models_to(model, device="cpu", excepts=None):
|
|
@@ -107,11 +107,13 @@ models_b = WurstCoreB.Models(
|
|
| 107 |
)
|
| 108 |
models_b.generator.bfloat16().eval().requires_grad_(False)
|
| 109 |
|
|
|
|
| 110 |
if low_vram:
|
| 111 |
# Off-load old generator (which is not used in models_rbm)
|
| 112 |
models.generator.to("cpu")
|
| 113 |
torch.cuda.empty_cache()
|
| 114 |
gc.collect()
|
|
|
|
| 115 |
|
| 116 |
generator_rbm = StageCRBM()
|
| 117 |
for param_name, param in load_or_fail(core.config.generator_checkpoint_path).items():
|
|
@@ -128,10 +130,10 @@ models_rbm.generator.eval().requires_grad_(False)
|
|
| 128 |
|
| 129 |
|
| 130 |
|
| 131 |
-
def infer(ref_style_file, style_description, caption, progress):
|
| 132 |
global models_rbm, models_b, device
|
| 133 |
|
| 134 |
-
if
|
| 135 |
models_to(models_rbm, device=device, excepts=["generator", "previewer"])
|
| 136 |
try:
|
| 137 |
|
|
@@ -167,7 +169,7 @@ def infer(ref_style_file, style_description, caption, progress):
|
|
| 167 |
conditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=False)
|
| 168 |
unconditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=True)
|
| 169 |
|
| 170 |
-
if
|
| 171 |
# The sampling process uses more vram, so we offload everything except two modules to the cpu.
|
| 172 |
models_to(models_rbm, device="cpu", excepts=["generator", "previewer"])
|
| 173 |
|
|
@@ -236,10 +238,10 @@ def infer(ref_style_file, style_description, caption, progress):
|
|
| 236 |
torch.cuda.empty_cache()
|
| 237 |
gc.collect()
|
| 238 |
|
| 239 |
-
def infer_compo(style_description, ref_style_file, caption, ref_sub_file, progress):
|
| 240 |
global models_rbm, models_b, device
|
| 241 |
sam_model = LangSAM()
|
| 242 |
-
if
|
| 243 |
models_to(models_rbm, device=device, excepts=["generator", "previewer"])
|
| 244 |
models_to(sam_model, device=device)
|
| 245 |
models_to(sam_model.sam, device=device)
|
|
@@ -288,7 +290,7 @@ def infer_compo(style_description, ref_style_file, caption, ref_sub_file, progre
|
|
| 288 |
conditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=False)
|
| 289 |
unconditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=True)
|
| 290 |
|
| 291 |
-
if
|
| 292 |
models_to(models_rbm, device="cpu", excepts=["generator", "previewer"])
|
| 293 |
models_to(sam_model, device="cpu")
|
| 294 |
models_to(sam_model.sam, device="cpu")
|
|
@@ -363,13 +365,13 @@ def infer_compo(style_description, ref_style_file, caption, ref_sub_file, progre
|
|
| 363 |
torch.cuda.empty_cache()
|
| 364 |
gc.collect()
|
| 365 |
|
| 366 |
-
def run(style_reference_image, style_description, subject_prompt, subject_reference, use_subject_ref):
|
| 367 |
result = None
|
| 368 |
progress = gr.Progress(track_tqdm=True)
|
| 369 |
if use_subject_ref is True:
|
| 370 |
-
result = infer_compo(style_description, style_reference_image, subject_prompt, subject_reference, progress)
|
| 371 |
else:
|
| 372 |
-
result = infer(style_reference_image, style_description, subject_prompt, progress)
|
| 373 |
return result
|
| 374 |
|
| 375 |
def show_hide_subject_image_component(use_subject_ref):
|
|
@@ -406,7 +408,9 @@ with gr.Blocks(analytics_enabled=False) as demo:
|
|
| 406 |
subject_prompt = gr.Textbox(
|
| 407 |
label = "Subject Prompt"
|
| 408 |
)
|
| 409 |
-
|
|
|
|
|
|
|
| 410 |
|
| 411 |
with gr.Accordion("Advanced Settings", open=False) as sub_img_panel:
|
| 412 |
subject_reference = gr.Image(label="Subject Reference", type="filepath")
|
|
@@ -418,13 +422,13 @@ with gr.Blocks(analytics_enabled=False) as demo:
|
|
| 418 |
output_image = gr.Image(label="Output Image")
|
| 419 |
gr.Examples(
|
| 420 |
examples = [
|
| 421 |
-
["./data/cyberpunk.png", "cyberpunk art style", "a car", None, False],
|
| 422 |
-
["./data/mosaic.png", "mosaic art style", "a lighthouse", None, False],
|
| 423 |
-
["./data/glowing.png", "glowing style", "a dwarf", None, False],
|
| 424 |
-
["./data/melting_gold.png", "melting golden 3D rendering style", "a dog", "./data/dog.jpg", True]
|
| 425 |
],
|
| 426 |
fn=run,
|
| 427 |
-
inputs=[style_reference_image, style_description, subject_prompt, subject_reference, use_subject_ref],
|
| 428 |
outputs=[output_image],
|
| 429 |
cache_examples=False
|
| 430 |
|
|
@@ -439,7 +443,7 @@ with gr.Blocks(analytics_enabled=False) as demo:
|
|
| 439 |
|
| 440 |
submit_btn.click(
|
| 441 |
fn = run,
|
| 442 |
-
inputs = [style_reference_image, style_description, subject_prompt, subject_reference, use_subject_ref],
|
| 443 |
outputs = [output_image],
|
| 444 |
show_api = False
|
| 445 |
)
|
|
|
|
| 34 |
print(device)
|
| 35 |
|
| 36 |
# Flag for low VRAM usage
|
| 37 |
+
# low_vram = False
|
| 38 |
|
| 39 |
# Function definition for low VRAM usage
|
| 40 |
def models_to(model, device="cpu", excepts=None):
|
|
|
|
| 107 |
)
|
| 108 |
models_b.generator.bfloat16().eval().requires_grad_(False)
|
| 109 |
|
| 110 |
+
"""
|
| 111 |
if low_vram:
|
| 112 |
# Off-load old generator (which is not used in models_rbm)
|
| 113 |
models.generator.to("cpu")
|
| 114 |
torch.cuda.empty_cache()
|
| 115 |
gc.collect()
|
| 116 |
+
"""
|
| 117 |
|
| 118 |
generator_rbm = StageCRBM()
|
| 119 |
for param_name, param in load_or_fail(core.config.generator_checkpoint_path).items():
|
|
|
|
| 130 |
|
| 131 |
|
| 132 |
|
| 133 |
+
def infer(ref_style_file, style_description, caption, use_low_vram, progress):
|
| 134 |
global models_rbm, models_b, device
|
| 135 |
|
| 136 |
+
if use_low_vram:
|
| 137 |
models_to(models_rbm, device=device, excepts=["generator", "previewer"])
|
| 138 |
try:
|
| 139 |
|
|
|
|
| 169 |
conditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=False)
|
| 170 |
unconditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=True)
|
| 171 |
|
| 172 |
+
if use_low_vram:
|
| 173 |
# The sampling process uses more vram, so we offload everything except two modules to the cpu.
|
| 174 |
models_to(models_rbm, device="cpu", excepts=["generator", "previewer"])
|
| 175 |
|
|
|
|
| 238 |
torch.cuda.empty_cache()
|
| 239 |
gc.collect()
|
| 240 |
|
| 241 |
+
def infer_compo(style_description, ref_style_file, caption, ref_sub_file, use_low_vram, progress):
|
| 242 |
global models_rbm, models_b, device
|
| 243 |
sam_model = LangSAM()
|
| 244 |
+
if use_low_vram:
|
| 245 |
models_to(models_rbm, device=device, excepts=["generator", "previewer"])
|
| 246 |
models_to(sam_model, device=device)
|
| 247 |
models_to(sam_model.sam, device=device)
|
|
|
|
| 290 |
conditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=False)
|
| 291 |
unconditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=True)
|
| 292 |
|
| 293 |
+
if use_low_vram:
|
| 294 |
models_to(models_rbm, device="cpu", excepts=["generator", "previewer"])
|
| 295 |
models_to(sam_model, device="cpu")
|
| 296 |
models_to(sam_model.sam, device="cpu")
|
|
|
|
| 365 |
torch.cuda.empty_cache()
|
| 366 |
gc.collect()
|
| 367 |
|
| 368 |
+
def run(style_reference_image, style_description, subject_prompt, subject_reference, use_subject_ref, use_low_vram):
|
| 369 |
result = None
|
| 370 |
progress = gr.Progress(track_tqdm=True)
|
| 371 |
if use_subject_ref is True:
|
| 372 |
+
result = infer_compo(style_description, style_reference_image, subject_prompt, subject_reference, use_low_vram, progress)
|
| 373 |
else:
|
| 374 |
+
result = infer(style_reference_image, style_description, subject_prompt, use_low_vram, progress)
|
| 375 |
return result
|
| 376 |
|
| 377 |
def show_hide_subject_image_component(use_subject_ref):
|
|
|
|
| 408 |
subject_prompt = gr.Textbox(
|
| 409 |
label = "Subject Prompt"
|
| 410 |
)
|
| 411 |
+
with gr.Row():
|
| 412 |
+
use_subject_ref = gr.Checkbox(label="Use Subject Image as Reference", value=False)
|
| 413 |
+
use_low_vram = gr.Checkbox(label="Use Low-VRAM", value=False)
|
| 414 |
|
| 415 |
with gr.Accordion("Advanced Settings", open=False) as sub_img_panel:
|
| 416 |
subject_reference = gr.Image(label="Subject Reference", type="filepath")
|
|
|
|
| 422 |
output_image = gr.Image(label="Output Image")
|
| 423 |
gr.Examples(
|
| 424 |
examples = [
|
| 425 |
+
["./data/cyberpunk.png", "cyberpunk art style", "a car", None, False, False],
|
| 426 |
+
["./data/mosaic.png", "mosaic art style", "a lighthouse", None, False, False],
|
| 427 |
+
["./data/glowing.png", "glowing style", "a dwarf", None, False, False],
|
| 428 |
+
["./data/melting_gold.png", "melting golden 3D rendering style", "a dog", "./data/dog.jpg", True, False]
|
| 429 |
],
|
| 430 |
fn=run,
|
| 431 |
+
inputs=[style_reference_image, style_description, subject_prompt, subject_reference, use_subject_ref, use_low_vram],
|
| 432 |
outputs=[output_image],
|
| 433 |
cache_examples=False
|
| 434 |
|
|
|
|
| 443 |
|
| 444 |
submit_btn.click(
|
| 445 |
fn = run,
|
| 446 |
+
inputs = [style_reference_image, style_description, subject_prompt, subject_reference, use_subject_ref, use_low_vram],
|
| 447 |
outputs = [output_image],
|
| 448 |
show_api = False
|
| 449 |
)
|