# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Optional, Tuple, Union

import torch

# One needs Wan anyway to run ChronoEdit (`AutoencoderKLWan`), so the Wan VAE,
# scheduler, and latent-retrieval helper are reused directly.
from diffusers import AutoencoderKLWan, UniPCMultistepScheduler
from diffusers.image_processor import PipelineImageInput
from diffusers.modular_pipelines import (
    ComponentSpec,
    InputParam,
    ModularPipeline,
    ModularPipelineBlocks,
    OutputParam,
    PipelineState,
)
from diffusers.modular_pipelines.wan.before_denoise import retrieve_timesteps
from diffusers.pipelines.wan.pipeline_wan_i2v import retrieve_latents
from diffusers.utils.torch_utils import randn_tensor


class ChronoEditSetTimestepsStep(ModularPipelineBlocks):
    """Modular block that configures the scheduler timesteps for ChronoEdit inference."""

    model_name = "chronoedit"

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [ComponentSpec("scheduler", UniPCMultistepScheduler)]

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam("num_inference_steps", default=50),
            InputParam("timesteps"),
            InputParam("sigmas"),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam("timesteps", type_hint=torch.Tensor, description="The timesteps to use for inference"),
            OutputParam(
                "num_inference_steps",
                type_hint=int,
                description="The number of denoising steps to perform at inference time",
            ),
        ]

    @torch.no_grad()
    def __call__(self, components: ModularPipeline, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)
        block_state.device = components._execution_device

        block_state.timesteps, block_state.num_inference_steps = retrieve_timesteps(
            components.scheduler,
            block_state.num_inference_steps,
            block_state.device,
            block_state.timesteps,
            block_state.sigmas,
        )

        self.set_block_state(state, block_state)
        return components, state


class ChronoEditPrepareLatentStep(ModularPipelineBlocks):
    """Modular block that prepares the initial noise latents and the image-conditioning latents."""

    model_name = "chronoedit"

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [ComponentSpec("vae", AutoencoderKLWan)]

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam("processed_image", type_hint=PipelineImageInput),
            InputParam("image_embeds", type_hint=torch.Tensor),
            InputParam("height", type_hint=int, default=480),
            InputParam("width", type_hint=int, default=832),
            InputParam("num_frames", type_hint=int, default=81),
            InputParam("batch_size"),
            InputParam("num_videos_per_prompt", type_hint=int, default=1),
            InputParam("latents", type_hint=Optional[torch.Tensor]),
            InputParam("generator"),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam(
                "latents",
                type_hint=torch.Tensor,
                description="The initial latents to use for the denoising process.",
            ),
            OutputParam(
                "condition",
                type_hint=torch.Tensor,
                description="Conditioning latents for the denoising process.",
            ),
        ]

    @staticmethod
    def check_inputs(height, width):
        if height % 16 != 0 or width % 16 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
    @staticmethod
    def prepare_latents(
        components,
        image: PipelineImageInput,
        batch_size: int,
        num_channels_latents: int = 16,
        height: int = 480,
        width: int = 832,
        num_frames: int = 81,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        num_latent_frames = (num_frames - 1) // components.vae_scale_factor_temporal + 1
        latent_height = height // components.vae_scale_factor_spatial
        latent_width = width // components.vae_scale_factor_spatial

        shape = (batch_size, num_channels_latents, num_latent_frames, latent_height, latent_width)
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        if latents is None:
            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        else:
            latents = latents.to(device=device, dtype=dtype)

        # Build the conditioning video: the input image as the first frame, followed by zero frames.
        image = image.unsqueeze(2)
        video_condition = torch.cat(
            [image, image.new_zeros(image.shape[0], image.shape[1], num_frames - 1, height, width)], dim=2
        )
        video_condition = video_condition.to(device=device, dtype=dtype)

        latents_mean = (
            torch.tensor(components.vae.config.latents_mean)
            .view(1, components.vae.config.z_dim, 1, 1, 1)
            .to(latents.device, latents.dtype)
        )
        latents_std = 1.0 / torch.tensor(components.vae.config.latents_std).view(
            1, components.vae.config.z_dim, 1, 1, 1
        ).to(latents.device, latents.dtype)

        # Encode the conditioning video with the VAE and normalize with the latent statistics.
        if isinstance(generator, list):
            latent_condition = [
                retrieve_latents(components.vae.encode(video_condition), sample_mode="argmax") for _ in generator
            ]
            latent_condition = torch.cat(latent_condition)
        else:
            latent_condition = retrieve_latents(components.vae.encode(video_condition), sample_mode="argmax")
            latent_condition = latent_condition.repeat(batch_size, 1, 1, 1, 1)
        latent_condition = (latent_condition - latents_mean) * latents_std

        # First-frame mask: 1 for the conditioning frame, 0 for the frames to be generated,
        # reshaped to the temporal layout of the latents.
        mask_lat_size = torch.ones(batch_size, 1, num_frames, latent_height, latent_width)
        mask_lat_size[:, :, list(range(1, num_frames))] = 0
        first_frame_mask = mask_lat_size[:, :, 0:1]
        first_frame_mask = torch.repeat_interleave(
            first_frame_mask, dim=2, repeats=components.vae_scale_factor_temporal
        )
        mask_lat_size = torch.concat([first_frame_mask, mask_lat_size[:, :, 1:, :]], dim=2)
        mask_lat_size = mask_lat_size.view(
            batch_size, -1, components.vae_scale_factor_temporal, latent_height, latent_width
        )
        mask_lat_size = mask_lat_size.transpose(1, 2)
        mask_lat_size = mask_lat_size.to(latent_condition.device)

        return latents, torch.concat([mask_lat_size, latent_condition], dim=1)

    @torch.no_grad()
    def __call__(self, components: ModularPipeline, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)

        self.check_inputs(block_state.height, block_state.width)

        block_state.device = components._execution_device
        block_state.num_channels_latents = components.num_channels_latents
        batch_size = block_state.batch_size * block_state.num_videos_per_prompt

        block_state.latents, block_state.condition = self.prepare_latents(
            components,
            block_state.processed_image,
            batch_size,
            num_channels_latents=block_state.num_channels_latents,
            height=block_state.height,
            width=block_state.width,
            num_frames=block_state.num_frames,
            dtype=torch.bfloat16,
            device=block_state.device,
            generator=block_state.generator,
            latents=block_state.latents,
        )

        self.set_block_state(state, block_state)
        return components, state
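

# ------------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the block definitions above): these two
# blocks cover only the timestep and latent-preparation stages; a complete
# ChronoEdit pipeline also needs text/image encoding, denoising, and decoding
# blocks that are not defined in this file. The block names below are arbitrary,
# and the snippet assumes `SequentialPipelineBlocks.from_blocks_dict` accepts
# block instances, as in the diffusers modular-pipelines guide.
if __name__ == "__main__":
    from diffusers.modular_pipelines import SequentialPipelineBlocks

    blocks = SequentialPipelineBlocks.from_blocks_dict(
        {
            "set_timesteps": ChronoEditSetTimestepsStep(),
            "prepare_latents": ChronoEditPrepareLatentStep(),
        }
    )
    # Printing the composed blocks summarizes each block's inputs, intermediate
    # outputs, and expected components (scheduler, VAE) before any model
    # components are loaded.
    print(blocks)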