Spaces:
Build error
Build error
Update app.py
Browse files
app.py
CHANGED
|
@@ -135,9 +135,13 @@ def initialize_models():
|
|
| 135 |
|
| 136 |
def infer(style_description, ref_style_file, caption):
|
| 137 |
try:
|
| 138 |
-
#
|
| 139 |
-
|
| 140 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 141 |
height = 1024
|
| 142 |
width = 1024
|
| 143 |
batch_size = 1
|
|
@@ -145,16 +149,6 @@ def infer(style_description, ref_style_file, caption):
|
|
| 145 |
|
| 146 |
stage_c_latent_shape, stage_b_latent_shape = calculate_latent_sizes(height, width, batch_size=batch_size)
|
| 147 |
|
| 148 |
-
extras.sampling_configs['cfg'] = 4
|
| 149 |
-
extras.sampling_configs['shift'] = 2
|
| 150 |
-
extras.sampling_configs['timesteps'] = 20
|
| 151 |
-
extras.sampling_configs['t_start'] = 1.0
|
| 152 |
-
|
| 153 |
-
extras_b.sampling_configs['cfg'] = 1.1
|
| 154 |
-
extras_b.sampling_configs['shift'] = 1
|
| 155 |
-
extras_b.sampling_configs['timesteps'] = 10
|
| 156 |
-
extras_b.sampling_configs['t_start'] = 1.0
|
| 157 |
-
|
| 158 |
ref_style = resize_image(PIL.Image.open(ref_style_file).convert("RGB")).unsqueeze(0).expand(batch_size, -1, -1, -1).to(device)
|
| 159 |
|
| 160 |
batch = {'captions': [caption] * batch_size}
|
|
@@ -189,6 +183,9 @@ def infer(style_description, ref_style_file, caption):
|
|
| 189 |
|
| 190 |
clear_gpu_cache() # Clear cache between stages
|
| 191 |
|
|
|
|
|
|
|
|
|
|
| 192 |
# Stage B reverse process
|
| 193 |
with torch.no_grad(), torch.cuda.amp.autocast(dtype=torch.bfloat16):
|
| 194 |
conditions_b['effnet'] = sampled_c
|
|
|
|
| 135 |
|
| 136 |
def infer(style_description, ref_style_file, caption):
|
| 137 |
try:
|
| 138 |
+
# Clear GPU cache before inference
|
| 139 |
+
clear_gpu_cache()
|
| 140 |
+
|
| 141 |
+
# Ensure models are on the correct device
|
| 142 |
+
models_rbm.to(device)
|
| 143 |
+
models_b.to(device)
|
| 144 |
+
|
| 145 |
height = 1024
|
| 146 |
width = 1024
|
| 147 |
batch_size = 1
|
|
|
|
| 149 |
|
| 150 |
stage_c_latent_shape, stage_b_latent_shape = calculate_latent_sizes(height, width, batch_size=batch_size)
|
| 151 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 152 |
ref_style = resize_image(PIL.Image.open(ref_style_file).convert("RGB")).unsqueeze(0).expand(batch_size, -1, -1, -1).to(device)
|
| 153 |
|
| 154 |
batch = {'captions': [caption] * batch_size}
|
|
|
|
| 183 |
|
| 184 |
clear_gpu_cache() # Clear cache between stages
|
| 185 |
|
| 186 |
+
# Ensure models_b is on the correct device
|
| 187 |
+
models_b.to(device)
|
| 188 |
+
|
| 189 |
# Stage B reverse process
|
| 190 |
with torch.no_grad(), torch.cuda.amp.autocast(dtype=torch.bfloat16):
|
| 191 |
conditions_b['effnet'] = sampled_c
|