Update inference/flovd_demo.py
Browse files- inference/flovd_demo.py +8 -1
inference/flovd_demo.py
CHANGED
|
@@ -299,7 +299,14 @@ def patch_prepare_latents_safe():
|
|
| 299 |
image_latents = torch.cat([image_latents, latent_padding], dim=2)
|
| 300 |
|
| 301 |
if latents is None:
|
| 302 |
-
noise = torch.randn_like(image_latents, generator=generator)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 303 |
latents = noise.to(device=device, dtype=dtype)
|
| 304 |
|
| 305 |
return latents, image_latents.to(device, dtype=dtype)
|
|
|
|
| 299 |
image_latents = torch.cat([image_latents, latent_padding], dim=2)
|
| 300 |
|
| 301 |
if latents is None:
|
| 302 |
+
# noise = torch.randn_like(image_latents, generator=generator)
|
| 303 |
+
noise = torch.randn(
|
| 304 |
+
image_latents.shape,
|
| 305 |
+
dtype=image_latents.dtype,
|
| 306 |
+
device=image_latents.device,
|
| 307 |
+
generator=generator
|
| 308 |
+
)
|
| 309 |
+
|
| 310 |
latents = noise.to(device=device, dtype=dtype)
|
| 311 |
|
| 312 |
return latents, image_latents.to(device, dtype=dtype)
|