Spaces: Build error

Update app.py

app.py CHANGED
@@ -1,7 +1,6 @@
 import sys
 import os
 from pathlib import Path
-import gc
 
 # Add the StableCascade and CSD directories to the Python path
 app_dir = Path(__file__).parent
@@ -28,29 +27,12 @@ from gdf.schedulers import CosineSchedule
 from gdf import VPScaler, CosineTNoiseCond, DDPMSampler, P2LossWeight, AdaptiveLossWeight
 from gdf.targets import EpsilonTarget
 
-# Enable mixed precision
-torch.backends.cuda.matmul.allow_tf32 = True
-torch.backends.cudnn.allow_tf32 = True
-
 # Device configuration
 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 print(device)
 
 # Flag for low VRAM usage
-low_vram =
-
-# Function to clear GPU cache
-def clear_gpu_cache():
-    torch.cuda.empty_cache()
-    gc.collect()
-
-# Function to move model to CPU
-def to_cpu(model):
-    return model.cpu()
-
-# Function to move model to GPU
-def to_gpu(model):
-    return model.cuda()
+low_vram = False
 
 # Function definition for low VRAM usage
 if low_vram:
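Note: the previous revision's bare `low_vram =` assignment (no right-hand side) is a Python SyntaxError, which likely explains the Space's "Build error" status; completing it as `low_vram = False` makes the module importable again. The commit also drops the `clear_gpu_cache()`, `to_cpu()`, and `to_gpu()` helpers in favor of inline `torch.cuda.empty_cache()` calls added in the next two hunks. One behavioral difference: `empty_cache()` only returns cached allocator blocks to the CUDA driver, while the removed helper also ran `gc.collect()`. A minimal sketch of the removed pattern, should it need to be restored (it requires keeping `import gc`):

import gc
import torch

def clear_gpu_cache():
    # Collect Python garbage first so tensors held only by dead references
    # are actually freed, then release PyTorch's cached CUDA blocks.
    gc.collect()
    torch.cuda.empty_cache()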
@@ -71,7 +53,7 @@ if low_vram:
         print(f"Change device of '{attr_name}' to {device}")
         attr_value.to(device)
 
-
+    torch.cuda.empty_cache()
 
 # Stage C model configuration
 config_file = 'third_party/StableCascade/configs/inference/stage_c_3b.yaml'
@@ -126,7 +108,7 @@ models_b.generator.bfloat16().eval().requires_grad_(False)
 # Off-load old generator (low VRAM mode)
 if low_vram:
     models.generator.to("cpu")
-
+    torch.cuda.empty_cache()
 
 # Load and configure new generator
 generator_rbm = StageCRBM()
@@ -149,7 +131,6 @@ models_rbm = core.Models(
 models_rbm.generator.eval().requires_grad_(False)
 
 def infer(style_description, ref_style_file, caption):
-    clear_gpu_cache() # Clear cache before inference
 
     height=1024
     width=1024
@@ -185,22 +166,19 @@ def infer(style_description, ref_style_file, caption):
     models_to(models_rbm, device="cpu", excepts=["generator", "previewer"])
 
     # Stage C reverse process.
-
-
-
-
-
-
-
-
-
-
-
-
-
-    sampled_c = sampled_c
-
-    clear_gpu_cache() # Clear cache between stages
+    sampling_c = extras.gdf.sample(
+        models_rbm.generator, conditions, stage_c_latent_shape,
+        unconditions, device=device,
+        **extras.sampling_configs,
+        x0_style_forward=x0_style_forward,
+        apply_pushforward=False, tau_pushforward=8,
+        num_iter=3, eta=0.1, tau=20, eval_csd=True,
+        extras=extras, models=models_rbm,
+        lam_style=1, lam_txt_alignment=1.0,
+        use_ddim_sampler=True,
+    )
+    for (sampled_c, _, _) in tqdm(sampling_c, total=extras.sampling_configs['timesteps']):
+        sampled_c = sampled_c
 
     # Stage B reverse process.
     with torch.no_grad(), torch.cuda.amp.autocast(dtype=torch.bfloat16):
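Note: the restored Stage C block follows StableCascade's sampling idiom. `extras.gdf.sample(...)` returns a generator that must be driven to completion, and the loop body `sampled_c = sampled_c` is a deliberate no-op that keeps the last yielded latent in scope; the RBM-specific keyword arguments (`x0_style_forward`, `num_iter`, `eta`, `tau`, `lam_style`, and so on) are passed through to that sampler. A slightly more explicit but equivalent drain loop, assuming each step yields a triple whose first element is the current denoised latent as in StableCascade's GDF sampler:

# Equivalent drain loop (sketch): keep only the final Stage C latent.
sampled_c = None
for sampled_c, _, _ in tqdm(sampling_c, total=extras.sampling_configs['timesteps']):
    pass  # intermediate denoising steps are discarded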
@@ -216,21 +194,14 @@ def infer(style_description, ref_style_file, caption):
         sampled = models_b.stage_a.decode(sampled_b).float()
 
     sampled = torch.cat([
-        torch.nn.functional.interpolate(ref_style.cpu(), size=
+        torch.nn.functional.interpolate(ref_style.cpu(), size=height),
         sampled.cpu(),
-
-
-    # Remove the batch dimension and keep only the generated image
-    sampled = sampled[1] # This selects the generated image, discarding the reference style image
-
-    # Ensure the tensor is in [C, H, W] format
-    if sampled.dim() == 3 and sampled.shape[0] == 3:
-        sampled_image = T.ToPILImage()(sampled) # Convert tensor to PIL image
-        sampled_image.save(output_file) # Save the image as a PNG
-    else:
-        raise ValueError(f"Expected tensor of shape [3, H, W] but got {sampled.shape}")
+    ],
+    dim=0)
 
-
+    # Save the sampled image to a file
+    sampled_image = T.ToPILImage()(sampled.squeeze(0)) # Convert tensor to PIL image
+    sampled_image.save(output_file) # Save the image
 
     return output_file # Return the path to the saved image
 
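Note: one caveat in the new save path. After `torch.cat([...], dim=0)` the tensor stacks two images along dim 0 (the resized reference style and the generated sample), giving shape [2, 3, H, W]. `squeeze(0)` only removes size-1 dimensions, so the batch dimension survives, and `T.ToPILImage()` rejects 4-D input; the previous revision's `sampled = sampled[1]` selection handled exactly this. A hedged sketch of a save path that keeps only the generated image, assuming the concatenation order above:

# Sketch: select the generated image (index 1 after [ref_style, sampled] concat).
generated = sampled[1]  # shape [3, H, W]
assert generated.dim() == 3 and generated.shape[0] == 3
T.ToPILImage()(generated.clamp(0, 1)).save(output_file)  # clamp to the valid float range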