# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Optional

import torch
from transformers import (
    AutoTokenizer,
    CLIPImageProcessor,
    CLIPVisionModelWithProjection,
    UMT5EncoderModel,
)

from diffusers.configuration_utils import FrozenDict
from diffusers.guiders import ClassifierFreeGuidance
from diffusers.image_processor import PipelineImageInput
from diffusers.modular_pipelines import (
    ComponentSpec,
    InputParam,
    ModularPipeline,
    ModularPipelineBlocks,
    OutputParam,
    PipelineState,
)
from diffusers.modular_pipelines.wan.encoders import WanTextEncoderStep
from diffusers.video_processor import VideoProcessor


class ChronoEditImageEncoderStep(ModularPipelineBlocks):
    """Encodes the conditioning image with a CLIP vision encoder."""

    model_name = "chronoedit"

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [
            ComponentSpec("image_processor", CLIPImageProcessor),
            ComponentSpec("image_encoder", CLIPVisionModelWithProjection),
        ]

    @property
    def inputs(self) -> List[InputParam]:
        return [InputParam("image", type_hint=PipelineImageInput)]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam(
                "image_embeds",
                type_hint=torch.Tensor,
                description="Image embeddings to use as conditions during the denoising process.",
            )
        ]

    @staticmethod
    def encode_image(components, image: PipelineImageInput, device: Optional[torch.device] = None):
        device = device or components.image_encoder.device
        image = components.image_processor(images=image, return_tensors="pt").to(device)
        image_embeds = components.image_encoder(**image, output_hidden_states=True)
        # The penultimate layer's hidden states are used as the conditioning
        # signal, not the final projected embedding.
        return image_embeds.hidden_states[-2]

    @torch.no_grad()
    def __call__(self, components: ModularPipeline, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)
        block_state.image_embeds = self.encode_image(components, block_state.image, components._execution_device)
        self.set_block_state(state, block_state)
        return components, state
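

# Illustrative sketch (not part of the pipeline definition): `encode_image` only
# needs an object exposing `image_processor` and `image_encoder`, so it can be
# exercised standalone. The checkpoint id below is an assumption for the sketch,
# not necessarily the encoder ChronoEdit ships with.
#
#   from types import SimpleNamespace
#   from PIL import Image
#
#   components = SimpleNamespace(
#       image_processor=CLIPImageProcessor(),
#       image_encoder=CLIPVisionModelWithProjection.from_pretrained(
#           "openai/clip-vit-large-patch14"  # assumed checkpoint
#       ),
#   )
#   embeds = ChronoEditImageEncoderStep.encode_image(components, Image.open("input.png"))
#   # -> penultimate hidden states, shape (batch, sequence_len, hidden_size)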


class ChronoEditProcessImageStep(ModularPipelineBlocks):
    """Preprocesses the input image and broadcasts image embeddings to the batch size."""

    model_name = "chronoedit"

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam("image", type_hint=PipelineImageInput),
            InputParam("image_embeds", type_hint=torch.Tensor, required=False),
            InputParam("batch_size", type_hint=int, required=False),
            InputParam("height", type_hint=int),
            InputParam("width", type_hint=int),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam("processed_image", type_hint=PipelineImageInput),
            OutputParam("image_embeds", type_hint=torch.Tensor),
        ]

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [
            ComponentSpec(
                "video_processor",
                VideoProcessor,
                config=FrozenDict({"vae_scale_factor": 8}),
                default_creation_method="from_config",
            )
        ]

    @torch.no_grad()
    def __call__(self, components: ModularPipeline, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)
        image = block_state.image
        device = components._execution_device

        block_state.processed_image = components.video_processor.preprocess(
            image, height=block_state.height, width=block_state.width
        ).to(device, dtype=torch.bfloat16)

        if block_state.image_embeds is not None:
            image_embeds = block_state.image_embeds
            # Broadcast the single-image embeddings across the batch; fall back
            # to a batch size of 1 when none was provided.
            batch_size = block_state.batch_size or 1
            block_state.image_embeds = image_embeds.repeat(batch_size, 1, 1).to(torch.bfloat16)

        self.set_block_state(state, block_state)
        return components, state


class ChronoEditTextEncoderStep(WanTextEncoderStep):
    """Text encoder step reusing the Wan encoder, with the guider configured at a
    guidance scale of 1.0 (with the standard CFG formulation, a scale of 1.0
    leaves the conditional prediction unchanged)."""

    model_name = "chronoedit"

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [
            ComponentSpec("text_encoder", UMT5EncoderModel),
            ComponentSpec("tokenizer", AutoTokenizer),
            ComponentSpec(
                "guider",
                ClassifierFreeGuidance,
                config=FrozenDict({"guidance_scale": 1.0}),
                default_creation_method="from_config",
            ),
        ]

    @torch.no_grad()
    def __call__(self, components: ModularPipeline, state: PipelineState) -> PipelineState:
        # Get inputs and intermediates
        block_state = self.get_block_state(state)
        self.check_inputs(block_state)

        block_state.prepare_unconditional_embeds = components.guider.num_conditions > 1
        block_state.device = components._execution_device
        block_state.negative_prompt_embeds = None

        # Encode input prompt
        (
            block_state.prompt_embeds,
            block_state.negative_prompt_embeds,
        ) = self.encode_prompt(
            components,
            block_state.prompt,
            block_state.device,
            1,  # num_videos_per_prompt
            block_state.prepare_unconditional_embeds,
            block_state.negative_prompt,
            prompt_embeds=None,
            negative_prompt_embeds=block_state.negative_prompt_embeds,
        )

        # Add outputs
        self.set_block_state(state, block_state)
        return components, state
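

# Illustrative usage sketch (an assumption about wiring, not shipped code): the
# three blocks above can be chained with diffusers' SequentialPipelineBlocks. A
# complete ChronoEdit pipeline would also need denoising and decoding blocks
# after these encoder steps.
if __name__ == "__main__":
    from diffusers.modular_pipelines import SequentialPipelineBlocks

    blocks = SequentialPipelineBlocks.from_blocks_dict(
        {
            "text_encoder": ChronoEditTextEncoderStep,
            "image_encoder": ChronoEditImageEncoderStep,
            "process_image": ChronoEditProcessImageStep,
        }
    )
    # Printing the composed blocks summarizes the chained sub-blocks and their
    # combined inputs/outputs.
    print(blocks)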