Spaces:
Build error
Build error
Update app.py
Browse files
app.py
CHANGED
|
@@ -194,14 +194,19 @@ def infer(style_description, ref_style_file, caption):
|
|
| 194 |
sampled = models_b.stage_a.decode(sampled_b).float()
|
| 195 |
|
| 196 |
sampled = torch.cat([
|
| 197 |
-
torch.nn.functional.interpolate(ref_style.cpu(), size=height),
|
| 198 |
sampled.cpu(),
|
| 199 |
-
|
| 200 |
-
dim=0)
|
| 201 |
|
| 202 |
-
#
|
| 203 |
-
|
| 204 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 205 |
|
| 206 |
return output_file # Return the path to the saved image
|
| 207 |
|
|
|
|
| 194 |
sampled = models_b.stage_a.decode(sampled_b).float()
|
| 195 |
|
| 196 |
sampled = torch.cat([
|
| 197 |
+
torch.nn.functional.interpolate(ref_style.cpu(), size=(height, width)),
|
| 198 |
sampled.cpu(),
|
| 199 |
+
], dim=0)
|
|
|
|
| 200 |
|
| 201 |
+
# Remove the batch dimension and keep only the generated image
|
| 202 |
+
sampled = sampled[1] # This selects the generated image, discarding the reference style image
|
| 203 |
+
|
| 204 |
+
# Ensure the tensor is in [C, H, W] format
|
| 205 |
+
if sampled.dim() == 3 and sampled.shape[0] == 3:
|
| 206 |
+
sampled_image = T.ToPILImage()(sampled) # Convert tensor to PIL image
|
| 207 |
+
sampled_image.save(output_file) # Save the image as a PNG
|
| 208 |
+
else:
|
| 209 |
+
raise ValueError(f"Expected tensor of shape [3, H, W] but got {sampled.shape}")
|
| 210 |
|
| 211 |
return output_file # Return the path to the saved image
|
| 212 |
|