Build error
Update app.py

app.py CHANGED
@@ -37,25 +37,24 @@ print(device)
 low_vram = True
 
 # Function definition for low VRAM usage
-
-
-
-
-
-
-
-
-
-
-
-    if
-
-
-
-
-
-
-    torch.cuda.empty_cache()
+def models_to(model, device="cpu", excepts=None):
+    """
+    Change the device of nn.Modules within a class, skipping specified attributes.
+    """
+    for attr_name in dir(model):
+        if attr_name.startswith('__') and attr_name.endswith('__'):
+            continue  # skip special attributes
+
+        attr_value = getattr(model, attr_name, None)
+
+        if isinstance(attr_value, torch.nn.Module):
+            if excepts and attr_name in excepts:
+                print(f"Except '{attr_name}'")
+                continue
+            print(f"Change device of '{attr_name}' to {device}")
+            attr_value.to(device)
+
+    torch.cuda.empty_cache()
 
 # Stage C model configuration
 config_file = 'third_party/StableCascade/configs/inference/stage_c_3b.yaml'
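For context, the new models_to helper walks an object's attributes and relocates every nn.Module it finds, skipping any attribute named in excepts. A minimal sketch of how it behaves, assuming the definition above is in scope; DummyModels and its attribute names are hypothetical and not part of the Space:

import torch

class DummyModels:
    # Hypothetical stand-in for core.Models: any object whose attributes
    # include nn.Module instances works the same way.
    def __init__(self):
        self.generator = torch.nn.Linear(4, 4)
        self.effnet = torch.nn.Linear(4, 4)
        self.tag = "not a module"  # non-module attributes are skipped

dummy = DummyModels()

# Moves effnet to CPU, leaves generator untouched, and logs each decision:
models_to(dummy, device="cpu", excepts=["generator"])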
@@ -128,9 +127,10 @@ models_rbm = core.Models(
     generator_ema=models.generator_ema,
     tokenizer=models.tokenizer,
     text_model=models.text_model,
-    image_model=models.image_model
+    image_model=models.image_model,
+    stage_a=models.stage_a,
+    stage_b=models.stage_b,
 )
-models_rbm.generator.eval().requires_grad_(False)
 
 def reset_inference_state():
     global models_rbm, models_b, extras, extras_b, device, core, core_b
@@ -146,29 +146,32 @@ def reset_inference_state():
     extras_b.sampling_configs['timesteps'] = 10
     extras_b.sampling_configs['t_start'] = 1.0
 
-    #
-    models_rbm =
-    models_b
-
-
-    )
+    # Move models to CPU to free up GPU memory
+    models_to(models_rbm, device="cpu")
+    models_b.generator.to("cpu")
+
+    # Clear CUDA cache
+    torch.cuda.empty_cache()
+    gc.collect()
 
-    # Move models to the correct device
+    # Move necessary models back to the correct device
     if low_vram:
         models_to(models_rbm, device="cpu", excepts=["generator", "previewer"])
-
+        models_rbm.generator.to(device)
+        models_rbm.previewer.to(device)
     else:
         models_to(models_rbm, device=device)
-
+
+    models_b.generator.to("cpu")  # Keep Stage B generator on CPU for now
 
     # Ensure effnet is on the correct device
     models_rbm.effnet.to(device)
 
-    #
+    # Reset model states
     models_rbm.generator.eval().requires_grad_(False)
     models_b.generator.bfloat16().eval().requires_grad_(False)
 
-    # Clear CUDA cache
+    # Clear CUDA cache again
    torch.cuda.empty_cache()
     gc.collect()
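The reworked reset path follows a common low-VRAM pattern: park everything on the CPU, release cached blocks, then bring back only the modules the next stage needs. A generic sketch of that pattern as a context manager, reusing models_to from above; the offloaded name and its hot parameter are illustrative, not part of this commit:

import gc
from contextlib import contextmanager

import torch

@contextmanager
def offloaded(container, device, hot=()):
    # Keep only the modules named in `hot` on `device`; park the rest on CPU.
    models_to(container, device="cpu", excepts=list(hot))
    for name in hot:
        getattr(container, name).to(device)
    torch.cuda.empty_cache()
    gc.collect()
    try:
        yield container
    finally:
        # Offload everything again once the caller is done.
        models_to(container, device="cpu")
        torch.cuda.empty_cache()
        gc.collect()

Used as with offloaded(models_rbm, device, hot=("generator", "previewer")), this mirrors the branch the commit takes when low_vram is set.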
@@ -197,6 +200,14 @@ def infer(style_description, ref_style_file, caption):
     batch = {'captions': [caption] * batch_size}
     batch['style'] = ref_style
 
+    # Ensure models are on the correct device before inference
+    if low_vram:
+        models_to(models_rbm, device=device, excepts=["generator", "previewer"])
+    else:
+        models_to(models_rbm, device=device)
+
+    models_b.generator.to(device)
+
     # Ensure effnet is on the correct device
     models_rbm.effnet.to(device)
     x0_style_forward = models_rbm.effnet(extras.effnet_preprocess(ref_style))
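Mismatched placements like the ones this hunk guards against are the usual source of "Expected all tensors to be on the same device" failures at inference time. A small diagnostic sketch, not part of the commit, that reports where each module of a container currently lives:

import torch

def report_devices(container):
    # Print the device of every nn.Module attribute on `container`.
    for name in dir(container):
        if name.startswith('__') and name.endswith('__'):
            continue
        value = getattr(container, name, None)
        if isinstance(value, torch.nn.Module):
            try:
                dev = next(value.parameters()).device
            except StopIteration:
                dev = "no parameters"
            print(f"{name}: {dev}")

report_devices(models_rbm)  # expect effnet, generator, and previewer on the GPU here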