Update inference/flovd_demo.py

inference/flovd_demo.py: 1 file changed, +24 −7

The single hunk, `@@ -264,28 +264,45 @@ def save_flow_warped_video(image, flow, filename, fps=16)`, covers the tail of that function and the `patch_prepare_latents_safe()` helper defined after it.
The seven removed lines were the remains of an earlier, incomplete draft of the helper: a bare `def new_prepare_latents(` header with no body, a `torch.zeros(...)` padding call that passed `device=image_latents.device,` with no matching dtype, and a truncated `return latents, image_latents.to(` statement, plus surrounding blank lines.
The twenty-four added lines fill the helper in. Below is the section as it reads after the change, with two repairs so the added code actually runs: the padding shape is written as a full five-tuple and concatenated on the frame axis, and the noise draw uses diffusers' `randn_tensor()`, since `torch.randn_like()` accepts no `generator` argument.

```python
        frame_list.append(Image.fromarray(frame))

    export_to_video(frame_list, filename, fps=fps)


from diffusers.pipelines.cogvideo.pipeline_cogvideox_image2video import CogVideoXImageToVideoPipeline
from diffusers.utils.torch_utils import randn_tensor


def patch_prepare_latents_safe():
    def new_prepare_latents(
        self,
        image,
        batch_size,
        latent_channels,
        num_frames,
        height,
        width,
        dtype,
        device,
        generator,
        latents=None,
    ):
        # Encode the conditioning image and scale it into the VAE latent space.
        image_latents = self.vae.encode(image.to(device, dtype=dtype)).latent_dist.sample()
        image_latents = image_latents * self.vae.config.scaling_factor

        # Pad the temporal dimension with zeros if needed (this assumes the VAE
        # returns latents laid out [B, C, F, H, W], so frames live on dim 2).
        if image_latents.shape[2] != num_frames:
            latent_padding = torch.zeros(
                (
                    image_latents.shape[0],
                    image_latents.shape[1],
                    num_frames - image_latents.shape[2],
                    image_latents.shape[3],
                    image_latents.shape[4],
                ),
                device=image_latents.device,
                dtype=image_latents.dtype,
            )
            image_latents = torch.cat([image_latents, latent_padding], dim=2)

        # torch.randn_like() takes no generator argument; diffusers'
        # randn_tensor() draws seeded noise and moves it to the target device.
        if latents is None:
            latents = randn_tensor(
                image_latents.shape, generator=generator, device=device, dtype=dtype
            )

        return latents, image_latents.to(device, dtype=dtype)

    CogVideoXImageToVideoPipeline.prepare_latents = new_prepare_latents


def generate_video(
    prompt: str,
    fvsm_path: str,
```