Update inference/flovd_demo.py
Browse files- inference/flovd_demo.py +2 -0
inference/flovd_demo.py
CHANGED
|
@@ -300,6 +300,8 @@ def patch_prepare_latents_safe():
|
|
| 300 |
|
| 301 |
if latents is None:
|
| 302 |
# noise = torch.randn_like(image_latents, generator=generator)
|
|
|
|
|
|
|
| 303 |
noise = torch.randn(
|
| 304 |
image_latents.shape,
|
| 305 |
dtype=image_latents.dtype,
|
|
|
|
| 300 |
|
| 301 |
if latents is None:
|
| 302 |
# noise = torch.randn_like(image_latents, generator=generator)
|
| 303 |
+
if generator.device != image_latents.device:
|
| 304 |
+
generator = generator.to(image_latents.device)
|
| 305 |
noise = torch.randn(
|
| 306 |
image_latents.shape,
|
| 307 |
dtype=image_latents.dtype,
|