Update inference/flovd_demo.py
Browse files
- inference/flovd_demo.py +8 -6
inference/flovd_demo.py
CHANGED
|
@@ -265,8 +265,6 @@ def save_flow_warped_video(image, flow, filename, fps=16):
|
|
| 265 |
|
| 266 |
export_to_video(frame_list, filename, fps=fps)
|
| 267 |
|
| 268 |
-
|
| 269 |
-
from diffusers.pipelines.cogvideo.pipeline_cogvideox_image2video import CogVideoXImageToVideoPipeline
|
| 270 |
def patch_prepare_latents_safe():
|
| 271 |
def new_prepare_latents(
|
| 272 |
self,
|
|
@@ -281,16 +279,20 @@ def patch_prepare_latents_safe():
|
|
| 281 |
generator,
|
| 282 |
latents=None,
|
| 283 |
):
|
| 284 |
-
# Ensure 5D input: [B, C, F=1, H, W]
|
| 285 |
image_5d = image.unsqueeze(2) if image.ndim == 4 else image
|
| 286 |
-
|
| 287 |
image_latents = self.vae.encode(image_5d.to(device, dtype=dtype)).latent_dist.sample()
|
| 288 |
image_latents = image_latents * self.vae.config.scaling_factor
|
| 289 |
|
| 290 |
-
# Pad
|
| 291 |
if image_latents.shape[2] != num_frames:
|
| 292 |
latent_padding = torch.zeros(
|
| 293 |
-
(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 294 |
device=image_latents.device,
|
| 295 |
dtype=image_latents.dtype
|
| 296 |
)
|
|
|
|
| 265 |
|
| 266 |
export_to_video(frame_list, filename, fps=fps)
|
| 267 |
|
|
|
|
|
|
|
| 268 |
def patch_prepare_latents_safe():
|
| 269 |
def new_prepare_latents(
|
| 270 |
self,
|
|
|
|
| 279 |
generator,
|
| 280 |
latents=None,
|
| 281 |
):
|
|
|
|
| 282 |
image_5d = image.unsqueeze(2) if image.ndim == 4 else image
|
|
|
|
| 283 |
image_latents = self.vae.encode(image_5d.to(device, dtype=dtype)).latent_dist.sample()
|
| 284 |
image_latents = image_latents * self.vae.config.scaling_factor
|
| 285 |
|
| 286 |
+
# Pad frame dim if needed
|
| 287 |
if image_latents.shape[2] != num_frames:
|
| 288 |
latent_padding = torch.zeros(
|
| 289 |
+
(
|
| 290 |
+
image_latents.shape[0],
|
| 291 |
+
image_latents.shape[1],
|
| 292 |
+
num_frames - image_latents.shape[2],
|
| 293 |
+
image_latents.shape[3],
|
| 294 |
+
image_latents.shape[4],
|
| 295 |
+
),
|
| 296 |
device=image_latents.device,
|
| 297 |
dtype=image_latents.dtype
|
| 298 |
)
|