# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
TODO: need to implement temporal reasoning:
https://huggingface.co/spaces/nvidia/ChronoEdit/blob/main/chronoedit_diffusers/pipeline_chronoedit.py
"""

from typing import List

import torch

from diffusers import AutoModel, UniPCMultistepScheduler
from diffusers.configuration_utils import FrozenDict
from diffusers.guiders import ClassifierFreeGuidance
from diffusers.modular_pipelines import (
    BlockState,
    ComponentSpec,
    InputParam,
    LoopSequentialPipelineBlocks,
    ModularPipeline,
    ModularPipelineBlocks,
    PipelineState,
)
from diffusers.modular_pipelines.wan.denoise import WanDenoiseLoopWrapper, WanLoopAfterDenoiser


class ChronoEditLoopBeforeDenoiser(ModularPipelineBlocks):
    model_name = "chronoedit"

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam(
                "latents",
                required=True,
                type_hint=torch.Tensor,
                description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.",
            ),
            InputParam(
                "condition",
                required=True,
                type_hint=torch.Tensor,
                description="The conditioning latents to use for the denoising process. Can be generated in prepare_latent step.",
            ),
        ]

    @torch.no_grad()
    def __call__(self, components: ModularPipeline, block_state: BlockState, i: int, t: torch.Tensor):
        latent_model_input = torch.cat([block_state.latents, block_state.condition], dim=1)
        block_state.latent_model_input = latent_model_input.to(block_state.latents.dtype)
        block_state.timestep = t.expand(block_state.latents.shape[0])
        return components, block_state


class ChronoEditLoopDenoiser(ModularPipelineBlocks):
    model_name = "chronoedit"

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [
            ComponentSpec(
                "guider",
                ClassifierFreeGuidance,
                config=FrozenDict({"guidance_scale": 1.0}),
                default_creation_method="from_config",
            ),
            ComponentSpec("transformer", AutoModel),
        ]

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam("attention_kwargs"),
            InputParam(
                "latents",
                required=True,
                type_hint=torch.Tensor,
                description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.",
            ),
            InputParam(
                "condition",
                required=True,
                type_hint=torch.Tensor,
                description="The conditioning latents to use for the denoising process. Can be generated in prepare_latent step.",
            ),
            InputParam(
                "image_embeds",
                required=True,
                type_hint=torch.Tensor,
                description="The conditioning image embeddings to use for the denoising process. Can be generated in prepare_latent step.",
            ),
            InputParam(
                "num_inference_steps",
                required=True,
                type_hint=int,
                description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.",
            ),
            InputParam(
                kwargs_type="denoiser_input_fields",
                description=(
                    "All conditional model inputs that need to be prepared with guider. "
                    "It should contain prompt_embeds/negative_prompt_embeds. "
                    "Please add `kwargs_type=denoiser_input_fields` to their parameter spec (`OutputParam`) when they are created and added to the pipeline state"
                ),
            ),
        ]

    @torch.no_grad()
    def __call__(self, components: ModularPipeline, block_state: BlockState, i: int, t: torch.Tensor) -> PipelineState:
        # Map the keys we'll see on each `guider_state_batch` (e.g. guider_state_batch.prompt_embeds)
        # to the corresponding (cond, uncond) fields on block_state (e.g. block_state.prompt_embeds, block_state.negative_prompt_embeds).
        guider_inputs = {
            "prompt_embeds": (
                getattr(block_state, "prompt_embeds", None),
                getattr(block_state, "negative_prompt_embeds", None),
            ),
        }

        components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t)
        guider_state = components.guider.prepare_inputs(guider_inputs)

        # run the denoiser for each guidance batch
        for guider_state_batch in guider_state:
            components.guider.prepare_models(components.transformer)
            cond_kwargs = {input_name: getattr(guider_state_batch, input_name) for input_name in guider_inputs.keys()}
            prompt_embeds = cond_kwargs.pop("prompt_embeds")

            # Predict the noise residual
            # store the noise_pred in guider_state_batch so that we can apply guidance across all batches
            guider_state_batch.noise_pred = components.transformer(
                hidden_states=block_state.latent_model_input,
                timestep=block_state.timestep,
                encoder_hidden_states=prompt_embeds,
                encoder_hidden_states_image=block_state.image_embeds,
                attention_kwargs=block_state.attention_kwargs,
                return_dict=False,
            )[0]
            components.guider.cleanup_models(components.transformer)

        # Perform guidance
        block_state.noise_pred = components.guider(guider_state)[0]

        return components, block_state


class ChronoEditDenoiseLoopWrapper(LoopSequentialPipelineBlocks):
    model_name = "chronoedit"

    @property
    def loop_expected_components(self) -> List[ComponentSpec]:
        return [
            ComponentSpec(
                "guider",
                ClassifierFreeGuidance,
                config=FrozenDict({"guidance_scale": 1.0}),
                default_creation_method="from_config",
            ),
            ComponentSpec("scheduler", UniPCMultistepScheduler),
            ComponentSpec("transformer", AutoModel),
        ]

    @property
    def loop_inputs(self) -> List[InputParam]:
        return [
            InputParam(
                "timesteps",
                required=True,
                type_hint=torch.Tensor,
                description="The timesteps to use for the denoising process. Can be generated in set_timesteps step.",
            ),
            InputParam(
                "num_inference_steps",
                required=True,
                type_hint=int,
                description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.",
            ),
        ]

    @torch.no_grad()
    def __call__(self, components: ModularPipeline, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)

        block_state.num_warmup_steps = max(
            len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0
        )

        with self.progress_bar(total=block_state.num_inference_steps) as progress_bar:
            for i, t in enumerate(block_state.timesteps):
                components, block_state = self.loop_step(components, block_state, i=i, t=t)
                if i == len(block_state.timesteps) - 1 or (
                    (i + 1) > block_state.num_warmup_steps and (i + 1) % components.scheduler.order == 0
                ):
                    progress_bar.update()

        self.set_block_state(state, block_state)
        return components, state


class ChronoEditLoopAfterDenoiser(WanLoopAfterDenoiser):
    model_name = "chronoedit"


class ChronoEditDenoiseStep(ChronoEditDenoiseLoopWrapper):
    block_classes = [ChronoEditLoopBeforeDenoiser, ChronoEditLoopDenoiser, ChronoEditLoopAfterDenoiser]
    block_names = ["before_denoiser", "denoiser", "after_denoiser"]