# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, List, Tuple

import torch

from diffusers.configuration_utils import FrozenDict
from diffusers.guiders import ClassifierFreeGuidance
from diffusers.models import AutoModel, WanTransformer3DModel
from diffusers.schedulers import UniPCMultistepScheduler
from diffusers.utils import logging
from diffusers.utils.torch_utils import randn_tensor
from diffusers.modular_pipelines import (
    BlockState,
    LoopSequentialPipelineBlocks,
    ModularPipelineBlocks,
    PipelineState,
    ModularPipeline,
)
from diffusers.modular_pipelines.modular_pipeline_utils import ComponentSpec, InputParam, OutputParam


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


class MatrixGameWanLoopDenoiser(ModularPipelineBlocks):
    """Loop sub-block that runs the transformer on the current latent block and applies guidance.

    The transformer is called once per batch yielded by ``components.guider.prepare_inputs``; the guider then
    combines those predictions and the result is stored in ``block_state.noise_pred``. The KV caches,
    ``current_frame_idx`` and ``num_frames_per_block`` it reads are placed on ``block_state`` by
    ``MatrixGameWanDenoiseLoopWrapper.__call__``.
    """

    model_name = "MatrixGameWan"
    # Multiplied by the latent frame index to form the transformer's ``current_start`` offset
    # (presumably the number of tokens per latent frame).
    frame_seq_length = 880

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [
            ComponentSpec(
                "guider",
                ClassifierFreeGuidance,
                config=FrozenDict({"guidance_scale": 5.0}),
                default_creation_method="from_config",
            ),
            ComponentSpec("transformer", AutoModel),
        ]

    @property
    def description(self) -> str:
        return (
            "Step within the denoising loop that denoise the latents with guidance. "
            "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
            "object (e.g. `MatrixGameWanDenoiseLoopWrapper`)"
        )

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam("attention_kwargs"),
            InputParam(
                "latents",
                required=True,
                type_hint=torch.Tensor,
                description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.",
            ),
            InputParam(
                "image_mask_latents",
                required=True,
                type_hint=torch.Tensor,
            ),
            InputParam(
                "image_embeds",
                required=True,
                type_hint=torch.Tensor,
            ),
            InputParam(
                "keyboard_conditions",
                required=True,
                type_hint=torch.Tensor,
            ),
            InputParam(
                "mouse_conditions",
                required=True,
                type_hint=torch.Tensor,
            ),
            InputParam(
                "num_inference_steps",
                required=True,
                type_hint=int,
                default=4,
                description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.",
            ),
            InputParam(
                kwargs_type="guider_input_fields",
                description=(
                    "All conditional model inputs that need to be prepared with guider. "
                    "It should contain prompt_embeds/negative_prompt_embeds. "
                    "Please add `kwargs_type=guider_input_fields` to their parameter spec (`OutputParam`) when they are created and added to the pipeline state"
                ),
            ),
        ]

    @torch.no_grad()
    def __call__(
        self, components: ModularPipeline, block_state: BlockState, i: int, t: torch.Tensor
    ) -> Tuple[ModularPipeline, BlockState]:
        """Predict the model output for denoising step ``i`` at timestep ``t``.

        Args:
            components: Pipeline providing ``guider`` and ``transformer``.
            block_state: Loop state. Reads ``latents``, ``image_mask_latents``, ``image_embeds``,
                ``keyboard_conditions``, ``mouse_conditions``, ``num_inference_steps``, ``num_frames_per_block``,
                ``current_frame_idx`` and the four KV caches; writes ``noise_pred``.
            i: Index of the current denoising step.
            t: Current timestep, one element of the wrapper's ``timesteps``.

        Returns:
            ``(components, block_state)`` with ``block_state.noise_pred`` set to the guided prediction.
        """
        transformer_dtype = components.transformer.dtype

        # These inputs are identical for every guidance batch, so prepare them once.
        latents = block_state.latents.to(transformer_dtype)
        visual_context = block_state.image_embeds.to(transformer_dtype)
        cond_concat = block_state.image_mask_latents.to(transformer_dtype)
        timestep = t.expand(block_state.latents.shape[0], block_state.num_frames_per_block)
        current_start = block_state.current_frame_idx * self.frame_seq_length

        components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t)

        # No guider input fields are passed (empty mapping), so every batch produced here calls the transformer
        # with the same inputs below; the guider decides how many batches to run and how to combine them.
        guider_state = components.guider.prepare_inputs(block_state, {})

        for guider_state_batch in guider_state:
            components.guider.prepare_models(components.transformer)

            # Store the prediction on the batch so the guider can combine all batches afterwards.
            guider_state_batch.noise_pred = components.transformer(
                x=latents,
                t=timestep,
                visual_context=visual_context,
                cond_concat=cond_concat,
                keyboard_cond=block_state.keyboard_conditions,
                mouse_cond=block_state.mouse_conditions,
                kv_cache=block_state.kv_cache,
                kv_cache_mouse=block_state.kv_cache_mouse,
                kv_cache_keyboard=block_state.kv_cache_keyboard,
                crossattn_cache=block_state.kv_cache_cross_attn,
                current_start=current_start,
                num_frames_per_block=block_state.num_frames_per_block,
            )[0]
            components.guider.cleanup_models(components.transformer)

        # Combine the per-batch predictions.
        block_state.noise_pred = components.guider(guider_state)[0]

        return components, block_state


class MatrixGameWanLoopAfterDenoiser(ModularPipelineBlocks):
    """Loop sub-block that turns ``noise_pred`` into a clean-latent estimate for the current block."""

    model_name = "MatrixGameWan"

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [
            ComponentSpec("scheduler", UniPCMultistepScheduler),
        ]

    @property
    def description(self) -> str:
        return (
            "step within the denoising loop that update the latents. "
            "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
            "object (e.g. `MatrixGameWanDenoiseLoopWrapper`)"
        )

    @property
    def inputs(self) -> List[InputParam]:
        return []

    @property
    def intermediate_inputs(self) -> List[InputParam]:
        return [
            InputParam("generator"),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")]

    @torch.no_grad()
    def __call__(
        self, components: ModularPipeline, block_state: BlockState, i: int, t: torch.Tensor
    ) -> Tuple[ModularPipeline, BlockState]:
        """Replace ``block_state.latents`` with ``latents - sigma_t * noise_pred``.

        ``sigma_t`` is ``components.scheduler.sigmas`` at the index of ``t`` in the scheduler's schedule.
        NOTE(review): this is the clean-sample estimate for a flow-matching velocity prediction, presumably;
        the wrapper re-noises it to the next timestep.

        Args:
            components: Pipeline providing ``scheduler``.
            block_state: Loop state. Reads ``latents`` and ``noise_pred``; writes ``latents``.
            i: Index of the current denoising step (unused here).
            t: Current timestep; must be present in ``components.scheduler.timesteps``.

        Returns:
            ``(components, block_state)`` with updated ``latents`` in their original dtype.
        """
        original_dtype = block_state.latents.dtype
        sigma_t = components.scheduler.sigmas[components.scheduler.index_for_timestep(t)]

        # Compute in float64 to limit rounding error, then restore the incoming dtype (no-op if unchanged).
        denoised = block_state.latents.double() - sigma_t.double() * block_state.noise_pred.double()
        block_state.latents = denoised.to(original_dtype)

        return components, block_state


class MatrixGameWanDenoiseLoopWrapper(LoopSequentialPipelineBlocks):
    """Block-autoregressive denoising loop.

    The latent video is processed ``num_frames_per_block`` frames at a time. Each block is denoised over all
    ``timesteps`` using the ``sub_blocks``. Between steps the block is re-noised to the next timestep. The clean
    block is then passed through the transformer once more with timestep 0, with the shared KV caches, before
    moving on.
    """

    model_name = "MatrixGameWan"
    # Multiplied by the latent frame index to form ``current_start``; also scales the self-attention cache length.
    frame_seq_length = 880
    # Cache length in latent frames; -1 selects the fixed 15-frame fallback.
    local_attn_size = 6
    # One cache entry is allocated per transformer block.
    num_transformer_blocks = 30

    def _initialize_kv_cache(self, batch_size, dtype, device):
        """Allocate zero-filled self-attention KV caches, one dict per transformer block.

        Each dict holds ``"k"``/``"v"`` buffers of shape ``(batch_size, kv_cache_size, 12, 128)`` plus int64
        ``"global_end_index"``/``"local_end_index"`` counters starting at 0.
        NOTE(review): 12 x 128 presumably matches the transformer's heads x head_dim — confirm.
        """
        if self.local_attn_size != -1:
            # local_attn_size latent frames' worth of positions.
            kv_cache_size = self.local_attn_size * self.frame_seq_length
        else:
            # Fallback: 15 latent frames' worth of positions (13200 with frame_seq_length=880).
            kv_cache_size = 15 * 1 * self.frame_seq_length

        cache = []
        for _ in range(self.num_transformer_blocks):
            cache.append({
                "k": torch.zeros((batch_size, kv_cache_size, 12, 128), dtype=dtype, device=device),
                "v": torch.zeros((batch_size, kv_cache_size, 12, 128), dtype=dtype, device=device),
                "global_end_index": torch.tensor([0], dtype=torch.long, device=device),
                "local_end_index": torch.tensor([0], dtype=torch.long, device=device),
            })
        return cache

    def _initialize_kv_cache_mouse_and_keyboard(self, batch_size, dtype, device):
        """Allocate zero-filled KV caches for the mouse and keyboard conditioning paths.

        Returns ``(kv_cache_mouse, kv_cache_keyboard)``. Each is a list with one dict per transformer block.
        Keyboard buffers have shape ``(batch_size, kv_cache_size, 16, 64)``. Mouse buffers have a leading
        dimension of ``batch_size * frame_seq_length``. Here ``kv_cache_size`` counts frames (no
        ``frame_seq_length`` factor).
        """
        if self.local_attn_size != -1:
            kv_cache_size = self.local_attn_size
        else:
            kv_cache_size = 15 * 1

        kv_cache_mouse = []
        kv_cache_keyboard = []
        for _ in range(self.num_transformer_blocks):
            kv_cache_keyboard.append({
                "k": torch.zeros([batch_size, kv_cache_size, 16, 64], dtype=dtype, device=device),
                "v": torch.zeros([batch_size, kv_cache_size, 16, 64], dtype=dtype, device=device),
                "global_end_index": torch.tensor([0], dtype=torch.long, device=device),
                "local_end_index": torch.tensor([0], dtype=torch.long, device=device),
            })
            kv_cache_mouse.append({
                "k": torch.zeros([batch_size * self.frame_seq_length, kv_cache_size, 16, 64], dtype=dtype, device=device),
                "v": torch.zeros([batch_size * self.frame_seq_length, kv_cache_size, 16, 64], dtype=dtype, device=device),
                "global_end_index": torch.tensor([0], dtype=torch.long, device=device),
                "local_end_index": torch.tensor([0], dtype=torch.long, device=device),
            })
        return kv_cache_mouse, kv_cache_keyboard

    def _initialize_crossattn_cache(self, batch_size, dtype, device):
        """Allocate zero-filled cross-attention caches, one dict per transformer block.

        Each dict holds ``"k"``/``"v"`` of shape ``(batch_size, 257, 12, 128)`` and ``"is_init": False``.
        NOTE(review): 257 presumably matches the image-embedding token count — confirm.
        """
        crossattn_cache = []
        for _ in range(self.num_transformer_blocks):
            crossattn_cache.append({
                "k": torch.zeros([batch_size, 257, 12, 128], dtype=dtype, device=device),
                "v": torch.zeros([batch_size, 257, 12, 128], dtype=dtype, device=device),
                "is_init": False,
            })
        return crossattn_cache

    @property
    def description(self) -> str:
        return (
            "Pipeline block that iteratively denoise the latents over `timesteps`. "
            "The specific steps with each iteration can be customized with `sub_blocks` attributes"
        )

    @property
    def loop_expected_components(self) -> List[ComponentSpec]:
        return [
            ComponentSpec(
                "guider",
                ClassifierFreeGuidance,
                config=FrozenDict({"guidance_scale": 5.0}),
                default_creation_method="from_config",
            ),
            ComponentSpec("scheduler", UniPCMultistepScheduler),
            ComponentSpec("transformer", AutoModel),
        ]

    @property
    def loop_inputs(self) -> List[InputParam]:
        return [
            InputParam(
                "timesteps",
                required=True,
                type_hint=torch.Tensor,
                description="The timesteps to use for the denoising process. Can be generated in set_timesteps step.",
            ),
            InputParam(
                "num_inference_steps",
                required=True,
                type_hint=int,
                description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.",
            ),
            InputParam(
                "num_frames_per_block",
                required=True,
                type_hint=int,
                default=3,
            ),
        ]

    @torch.no_grad()
    def __call__(
        self, components: ModularPipeline, state: PipelineState
    ) -> Tuple[ModularPipeline, PipelineState]:
        """Denoise ``latents`` block by block and write the result back to ``state.latents``.

        Args:
            components: Pipeline providing ``transformer``, ``scheduler`` and ``guider``.
            state: Pipeline state. Reads ``latents`` (unpacked as ``(batch, channels, frames, height, width)``),
                ``image_mask_latents``, ``image_embeds``, ``mouse_conditions``, ``keyboard_conditions``,
                ``timesteps``, ``num_frames_per_block`` and, if present, ``generator``.

        Returns:
            ``(components, state)`` with ``state.latents`` replaced by the denoised video latents. Latent frames
            past the last full block are left as zeros (a warning is logged).
        """
        block_state = self.get_block_state(state)

        transformer_dtype = components.transformer.dtype
        num_frames_per_block = block_state.num_frames_per_block
        timesteps = block_state.timesteps
        num_timesteps = len(timesteps)
        # Used for the re-noising draws so results are reproducible from the user's generator.
        generator = getattr(block_state, "generator", None)

        latents = block_state.latents.to(transformer_dtype)
        image_mask_latents = block_state.image_mask_latents.to(transformer_dtype)
        # unsqueeze(0) adds a leading batch dimension (assumes unbatched conditions — TODO confirm).
        mouse_conditions = block_state.mouse_conditions.unsqueeze(0).to(transformer_dtype)
        keyboard_conditions = block_state.keyboard_conditions.unsqueeze(0).to(transformer_dtype)
        # Cast like the denoiser sub-block does, so the clean-context pass below sees the same dtype.
        visual_context = block_state.image_embeds.to(transformer_dtype)

        batch_size, num_channels, num_frames, height, width = latents.shape
        output = torch.zeros(
            (batch_size, num_channels, num_frames, height, width),
            device=latents.device,
            dtype=latents.dtype,
        )

        num_blocks, leftover_frames = divmod(num_frames, num_frames_per_block)
        if leftover_frames:
            logger.warning(
                "num_frames (%d) is not divisible by num_frames_per_block (%d); the last %d latent frame(s) "
                "will not be denoised and are returned as zeros.",
                num_frames,
                num_frames_per_block,
                leftover_frames,
            )

        block_state.kv_cache = self._initialize_kv_cache(batch_size, latents.dtype, latents.device)
        block_state.kv_cache_mouse, block_state.kv_cache_keyboard = self._initialize_kv_cache_mouse_and_keyboard(
            batch_size, latents.dtype, latents.device
        )
        block_state.kv_cache_cross_attn = self._initialize_crossattn_cache(batch_size, latents.dtype, latents.device)

        current_frame_idx = 0
        for _ in range(num_blocks):
            block_end = current_frame_idx + num_frames_per_block

            block_state.current_frame_idx = current_frame_idx
            block_state.image_mask_latents = image_mask_latents[:, :, current_frame_idx:block_end]
            # Conditions are sliced up to 1 + 4 * (last latent frame index of this block).
            # NOTE(review): presumably one condition per pixel frame with 4x temporal VAE compression — confirm.
            cond_idx = 1 + 4 * (block_end - 1)
            block_state.mouse_conditions = mouse_conditions[:, :cond_idx]
            block_state.keyboard_conditions = keyboard_conditions[:, :cond_idx]
            block_state.latents = latents[:, :, current_frame_idx:block_end]

            for i, t in enumerate(timesteps):
                components, block_state = self.loop_step(components, block_state, i=i, t=t)
                if i + 1 < num_timesteps:
                    # Re-noise the clean estimate to the next timestep of the same schedule being iterated.
                    next_t = timesteps[i + 1]
                    noise = randn_tensor(
                        block_state.latents.shape,
                        generator=generator,
                        device=block_state.latents.device,
                        dtype=block_state.latents.dtype,
                    )
                    block_state.latents = components.scheduler.add_noise(
                        block_state.latents, noise, next_t.expand(block_state.latents.shape[0])
                    )

            output[:, :, current_frame_idx:block_end] = block_state.latents

            # Pass the clean block through the transformer at timestep 0 and discard the output; the call is made
            # for its effect on the caches passed in (presumably storing this block as context for later blocks).
            components.transformer(
                x=block_state.latents,
                t=t.expand(block_state.latents.shape[0], block_state.num_frames_per_block) * 0.0,
                visual_context=visual_context,
                cond_concat=block_state.image_mask_latents,
                keyboard_cond=block_state.keyboard_conditions,
                mouse_cond=block_state.mouse_conditions,
                kv_cache=block_state.kv_cache,
                kv_cache_mouse=block_state.kv_cache_mouse,
                kv_cache_keyboard=block_state.kv_cache_keyboard,
                crossattn_cache=block_state.kv_cache_cross_attn,
                current_start=block_state.current_frame_idx * self.frame_seq_length,
                num_frames_per_block=block_state.num_frames_per_block,
            )[0]

            current_frame_idx = block_end

        block_state.latents = output
        self.set_block_state(state, block_state)
        return components, state


class MatrixGameWanDenoiseStep(MatrixGameWanDenoiseLoopWrapper):
    """Denoise step: the block-wise loop wrapper composed with the denoiser and after-denoiser sub-blocks."""

    block_classes = [
        MatrixGameWanLoopDenoiser,
        MatrixGameWanLoopAfterDenoiser,
    ]
    block_names = ["denoiser", "after_denoiser"]

    @property
    def description(self) -> str:
        return (
            "Denoise step that iteratively denoise the latents. \n"
            "Its loop logic is defined in `MatrixGameWanDenoiseLoopWrapper.__call__` method \n"
            "At each iteration, it runs blocks defined in `sub_blocks` sequencially:\n"
            " - `MatrixGameWanLoopDenoiser`\n"
            " - `MatrixGameWanLoopAfterDenoiser`\n"
            "This block supports both text2vid tasks."
        )