# matrix-game-2-modular / denoise.py
# Uploaded by dn6 (HF Staff) using huggingface_hub — commit 5178ef1 (verified)
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, List, Tuple
import torch
from diffusers.configuration_utils import FrozenDict
from diffusers.guiders import ClassifierFreeGuidance
from diffusers.models import AutoModel, WanTransformer3DModel
from diffusers.schedulers import UniPCMultistepScheduler
from diffusers.utils import logging
from diffusers.utils.torch_utils import randn_tensor
from diffusers.modular_pipelines import (
BlockState,
LoopSequentialPipelineBlocks,
ModularPipelineBlocks,
PipelineState,
ModularPipeline
)
from diffusers.modular_pipelines.modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class MatrixGameWanLoopDenoiser(ModularPipelineBlocks):
    """
    Loop sub-block that runs the Matrix-Game Wan transformer for one denoising step of one frame block.

    The guided prediction is stored in `block_state.noise_pred` for the following sub-block to consume.
    """

    model_name = "MatrixGameWan"
    # Tokens per latent frame (presumably); used to convert a frame index into the transformer's `current_start`.
    frame_seq_length = 880

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [
            ComponentSpec(
                "guider",
                ClassifierFreeGuidance,
                config=FrozenDict({"guidance_scale": 5.0}),
                default_creation_method="from_config",
            ),
            ComponentSpec("transformer", AutoModel),
        ]

    @property
    def description(self) -> str:
        return (
            "Step within the denoising loop that denoise the latents with guidance. "
            "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
            "object (e.g. `MatrixGameWanDenoiseLoopWrapper`)"
        )

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam("attention_kwargs"),
            InputParam(
                "latents",
                required=True,
                type_hint=torch.Tensor,
                description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.",
            ),
            InputParam(
                "image_mask_latents",
                required=True,
                type_hint=torch.Tensor,
            ),
            InputParam(
                "image_embeds",
                required=True,
                type_hint=torch.Tensor,
            ),
            InputParam(
                "keyboard_conditions",
                required=True,
                type_hint=torch.Tensor,
            ),
            InputParam(
                "mouse_conditions",
                required=True,
                type_hint=torch.Tensor,
            ),
            InputParam(
                "num_inference_steps",
                required=True,
                type_hint=int,
                default=4,
                description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.",
            ),
            InputParam(
                kwargs_type="guider_input_fields",
                description=(
                    "All conditional model inputs that need to be prepared with guider. "
                    "It should contain prompt_embeds/negative_prompt_embeds. "
                    "Please add `kwargs_type=guider_input_fields` to their parameter spec (`OutputParam`) when they are created and added to the pipeline state"
                ),
            ),
        ]

    @torch.no_grad()
    def __call__(
        self, components: ModularPipeline, block_state: BlockState, i: int, t: torch.Tensor
    ) -> Tuple[ModularPipeline, BlockState]:
        """
        Predict the model output for step `i` (timestep `t`) and store the guided result in `block_state.noise_pred`.

        Besides the declared inputs, this reads `num_frames_per_block`, `current_frame_idx` and the caches
        `kv_cache`, `kv_cache_mouse`, `kv_cache_keyboard` and `kv_cache_cross_attn` from `block_state`. These are
        expected to be set by the enclosing loop block.

        Returns:
            `(components, block_state)`.
        """
        transformer = components.transformer
        guider = components.guider
        transformer_dtype = transformer.dtype

        # These values do not depend on the guidance batch, so compute them once rather than per batch.
        latents = block_state.latents.to(transformer_dtype)
        timestep = t.expand(block_state.latents.shape[0], block_state.num_frames_per_block)
        visual_context = block_state.image_embeds.to(transformer_dtype)
        cond_concat = block_state.image_mask_latents.to(transformer_dtype)
        current_start = block_state.current_frame_idx * self.frame_seq_length

        guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t)
        # NOTE(review): `prepare_inputs` receives an empty field mapping, and no field of the resulting batches is
        # passed to the transformer below. Every guidance batch therefore runs the transformer on identical
        # inputs. Confirm whether classifier-free guidance is actually intended for this model.
        guider_state = guider.prepare_inputs(block_state, {})

        for guider_state_batch in guider_state:
            guider.prepare_models(transformer)
            # Store the prediction on the batch so the guider can combine predictions across batches.
            guider_state_batch.noise_pred = transformer(
                x=latents,
                t=timestep,
                visual_context=visual_context,
                cond_concat=cond_concat,
                keyboard_cond=block_state.keyboard_conditions,
                mouse_cond=block_state.mouse_conditions,
                kv_cache=block_state.kv_cache,
                kv_cache_mouse=block_state.kv_cache_mouse,
                kv_cache_keyboard=block_state.kv_cache_keyboard,
                crossattn_cache=block_state.kv_cache_cross_attn,
                current_start=current_start,
                num_frames_per_block=block_state.num_frames_per_block,
            )[0]
            guider.cleanup_models(transformer)

        # Combine the per-batch predictions according to the guidance method.
        block_state.noise_pred = guider(guider_state)[0]
        return components, block_state
class MatrixGameWanLoopAfterDenoiser(ModularPipelineBlocks):
    """
    Loop sub-block that turns the denoiser's prediction into updated latents for one denoising step.
    """

    model_name = "MatrixGameWan"

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [
            ComponentSpec("scheduler", UniPCMultistepScheduler),
        ]

    @property
    def description(self) -> str:
        return (
            "step within the denoising loop that update the latents. "
            "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
            "object (e.g. `MatrixGameWanDenoiseLoopWrapper`)"
        )

    @property
    def inputs(self) -> List[InputParam]:
        return []

    @property
    def intermediate_inputs(self) -> List[InputParam]:
        # NOTE(review): `generator` is declared here but not read by `__call__` in this block.
        return [
            InputParam("generator"),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")]

    @torch.no_grad()
    def __call__(
        self, components: ModularPipeline, block_state: BlockState, i: int, t: torch.Tensor
    ) -> Tuple[ModularPipeline, BlockState]:
        """
        Update `block_state.latents` to `latents - sigma_t * noise_pred`.

        `sigma_t` is the scheduler sigma at the index of timestep `t`. The arithmetic is done in float64 and the
        result is cast back to the latents' original dtype.

        Returns:
            `(components, block_state)`.
        """
        original_dtype = block_state.latents.dtype
        scheduler = components.scheduler
        sigma_t = scheduler.sigmas[scheduler.index_for_timestep(t)]
        # Compute in float64 to limit rounding error, then return to the caller's dtype. `.to()` is a no-op when
        # the original dtype is already float64.
        updated = block_state.latents.double() - sigma_t.double() * block_state.noise_pred.double()
        block_state.latents = updated.to(original_dtype)
        return components, block_state
class MatrixGameWanDenoiseLoopWrapper(LoopSequentialPipelineBlocks):
    """
    Loop wrapper that denoises Matrix-Game Wan latents autoregressively, one block of frames at a time.

    For each block of `num_frames_per_block` latent frames it runs `sub_blocks` once per timestep. It also keeps
    self-attention, action (mouse/keyboard) and cross-attention caches alive across blocks.
    """

    model_name = "MatrixGameWan"
    # Tokens per latent frame (presumably); used to size the caches and to compute `current_start`.
    frame_seq_length = 880
    # Number of frames held by the self-attention and action caches; -1 selects the default size of 15 frames.
    local_attn_size = 6
    # Number of per-layer cache entries to allocate; presumably must equal the transformer's block count.
    num_transformer_blocks = 30

    def _initialize_kv_cache(self, batch_size, dtype, device):
        """
        Initialize a per-GPU self-attention KV cache for the Wan model.

        Returns one dict per transformer block. Each dict has zero-filled `k`/`v` tensors of shape
        `(batch_size, kv_cache_size, 12, 128)` and `global_end_index`/`local_end_index` counters set to 0.
        """
        # NOTE(review): 12 heads x 128 head dim are hard-coded; presumably they must match the transformer config.
        if self.local_attn_size != -1:
            # Local attention: the cache only spans `local_attn_size` frames of tokens.
            kv_cache_size = self.local_attn_size * self.frame_seq_length
        else:
            # Default: 15 frames of tokens (13200 with the default frame_seq_length of 880).
            kv_cache_size = 15 * self.frame_seq_length
        return [
            {
                "k": torch.zeros((batch_size, kv_cache_size, 12, 128), dtype=dtype, device=device),
                "v": torch.zeros((batch_size, kv_cache_size, 12, 128), dtype=dtype, device=device),
                "global_end_index": torch.tensor([0], dtype=torch.long, device=device),
                "local_end_index": torch.tensor([0], dtype=torch.long, device=device),
            }
            for _ in range(self.num_transformer_blocks)
        ]

    def _initialize_kv_cache_mouse_and_keyboard(self, batch_size, dtype, device):
        """
        Initialize per-GPU KV caches for the mouse and keyboard conditioning attention.

        Returns `(kv_cache_mouse, kv_cache_keyboard)`. Each is a list with one dict per transformer block, holding
        zero-filled `k`/`v` tensors and `global_end_index`/`local_end_index` counters set to 0.

        - Keyboard tensors have shape `(batch_size, kv_cache_size, 16, 64)`.
        - Mouse tensors have shape `(batch_size * frame_seq_length, kv_cache_size, 16, 64)`, presumably one cache
          row per spatial token (TODO confirm).
        """
        # Here the cache size counts frames, not tokens.
        kv_cache_size = self.local_attn_size if self.local_attn_size != -1 else 15

        def _new_entry(leading_dim):
            # One zero-initialised cache entry whose first dimension is `leading_dim`.
            return {
                "k": torch.zeros([leading_dim, kv_cache_size, 16, 64], dtype=dtype, device=device),
                "v": torch.zeros([leading_dim, kv_cache_size, 16, 64], dtype=dtype, device=device),
                "global_end_index": torch.tensor([0], dtype=torch.long, device=device),
                "local_end_index": torch.tensor([0], dtype=torch.long, device=device),
            }

        kv_cache_mouse = [_new_entry(batch_size * self.frame_seq_length) for _ in range(self.num_transformer_blocks)]
        kv_cache_keyboard = [_new_entry(batch_size) for _ in range(self.num_transformer_blocks)]
        return kv_cache_mouse, kv_cache_keyboard

    def _initialize_crossattn_cache(self, batch_size, dtype, device):
        """
        Initialize a per-GPU cross-attention cache for the Wan model.

        Returns one dict per transformer block. Each dict has zero-filled `k`/`v` tensors of shape
        `(batch_size, 257, 12, 128)` and `is_init=False`. The 257 context tokens are hard-coded; presumably they
        match the image-embedding length (TODO confirm).
        """
        return [
            {
                "k": torch.zeros([batch_size, 257, 12, 128], dtype=dtype, device=device),
                "v": torch.zeros([batch_size, 257, 12, 128], dtype=dtype, device=device),
                "is_init": False,
            }
            for _ in range(self.num_transformer_blocks)
        ]

    @property
    def description(self) -> str:
        return (
            "Pipeline block that iteratively denoise the latents over `timesteps`. "
            "The specific steps with each iteration can be customized with `sub_blocks` attributes"
        )

    @property
    def loop_expected_components(self) -> List[ComponentSpec]:
        return [
            ComponentSpec(
                "guider",
                ClassifierFreeGuidance,
                config=FrozenDict({"guidance_scale": 5.0}),
                default_creation_method="from_config",
            ),
            ComponentSpec("scheduler", UniPCMultistepScheduler),
            ComponentSpec("transformer", AutoModel),
        ]

    @property
    def loop_inputs(self) -> List[InputParam]:
        return [
            InputParam(
                "timesteps",
                required=True,
                type_hint=torch.Tensor,
                description="The timesteps to use for the denoising process. Can be generated in set_timesteps step.",
            ),
            InputParam(
                "num_inference_steps",
                required=True,
                type_hint=int,
                description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.",
            ),
            InputParam(
                "num_frames_per_block",
                required=True,
                type_hint=int,
                default=3,
            ),
        ]

    @torch.no_grad()
    def __call__(
        self, components: ModularPipeline, state: PipelineState
    ) -> Tuple[ModularPipeline, PipelineState]:
        """
        Denoise the latent video block by block and write the result to `latents` in `state`.

        Each block of `num_frames_per_block` latent frames goes through the following steps:

        1. Run `sub_blocks` (via `loop_step`) for every timestep. Between steps, the latents are re-noised to the
           next timestep with `scheduler.add_noise`.
        2. Copy the finished block into the output.
        3. Run the transformer once more on the finished block at timestep 0. Its output is discarded; the call is
           made for its effect on the caches passed in.

        Returns:
            `(components, state)`.
        """
        block_state = self.get_block_state(state)
        transformer_dtype = components.transformer.dtype
        num_frames_per_block = block_state.num_frames_per_block

        # The loop below overwrites these fields with per-block slices. Keep the originals so they can be restored
        # before `set_block_state`, in case it writes modified inputs back to the pipeline state.
        full_image_mask_latents = block_state.image_mask_latents
        full_mouse_conditions = block_state.mouse_conditions
        full_keyboard_conditions = block_state.keyboard_conditions

        latents = block_state.latents.to(transformer_dtype)
        image_mask_latents = full_image_mask_latents.to(transformer_dtype)
        # Add a leading batch dimension (assumes the action conditions arrive unbatched; TODO confirm).
        mouse_conditions = full_mouse_conditions.unsqueeze(0).to(transformer_dtype)
        keyboard_conditions = full_keyboard_conditions.unsqueeze(0).to(transformer_dtype)
        # Cast to the transformer dtype like the other model inputs, so the cache-refresh call gets matching dtypes.
        visual_context = block_state.image_embeds.to(transformer_dtype)
        # Used for the re-noising draws so runs are reproducible when a generator is supplied.
        generator = getattr(block_state, "generator", None)

        batch_size, num_channels, num_frames, height, width = latents.shape
        num_blocks = num_frames // num_frames_per_block
        leftover_frames = num_frames % num_frames_per_block
        if leftover_frames:
            logger.warning(
                "Number of latent frames (%d) is not divisible by num_frames_per_block (%d); "
                "the last %d frame(s) will not be denoised and are left as zeros.",
                num_frames,
                num_frames_per_block,
                leftover_frames,
            )
        output = torch.zeros_like(latents)

        block_state.kv_cache = self._initialize_kv_cache(batch_size, latents.dtype, latents.device)
        block_state.kv_cache_mouse, block_state.kv_cache_keyboard = self._initialize_kv_cache_mouse_and_keyboard(
            batch_size, latents.dtype, latents.device
        )
        block_state.kv_cache_cross_attn = self._initialize_crossattn_cache(batch_size, latents.dtype, latents.device)

        for block_idx in range(num_blocks):
            current_frame_idx = block_idx * num_frames_per_block
            frame_slice = slice(current_frame_idx, current_frame_idx + num_frames_per_block)

            block_state.current_frame_idx = current_frame_idx
            block_state.image_mask_latents = image_mask_latents[:, :, frame_slice]
            # Action conditions are sliced up to 1 + 4 * (last latent frame of this block). This is presumably
            # pixel-frame indexing under 4x temporal VAE compression with one leading frame (TODO confirm).
            cond_idx = 1 + 4 * (current_frame_idx + num_frames_per_block - 1)
            block_state.mouse_conditions = mouse_conditions[:, :cond_idx]
            block_state.keyboard_conditions = keyboard_conditions[:, :cond_idx]
            block_state.latents = latents[:, :, frame_slice]

            for i, t in enumerate(block_state.timesteps):
                components, block_state = self.loop_step(components, block_state, i=i, t=t)
                if i < block_state.num_inference_steps - 1:
                    # Re-noise the sub-blocks' output to the next timestep before the next step.
                    next_t = components.scheduler.timesteps[i + 1]
                    noise = randn_tensor(
                        block_state.latents.shape,
                        generator=generator,
                        device=block_state.latents.device,
                        dtype=block_state.latents.dtype,
                    )
                    block_state.latents = components.scheduler.add_noise(
                        block_state.latents,
                        noise,
                        next_t.expand(block_state.latents.shape[0]),
                    )

            output[:, :, frame_slice] = block_state.latents

            # Cache-refresh pass on the finished block with a zero timestep. The output is discarded; presumably
            # this stores the clean block in the caches for the following blocks.
            components.transformer(
                x=block_state.latents,
                t=t.expand(block_state.latents.shape[0], block_state.num_frames_per_block) * 0.0,
                visual_context=visual_context,
                cond_concat=block_state.image_mask_latents,
                keyboard_cond=block_state.keyboard_conditions,
                mouse_cond=block_state.mouse_conditions,
                kv_cache=block_state.kv_cache,
                kv_cache_mouse=block_state.kv_cache_mouse,
                kv_cache_keyboard=block_state.kv_cache_keyboard,
                crossattn_cache=block_state.kv_cache_cross_attn,
                current_start=block_state.current_frame_idx * self.frame_seq_length,
                num_frames_per_block=block_state.num_frames_per_block,
            )

        block_state.latents = output
        # Put the full-length inputs back so per-block slices are not propagated to the pipeline state.
        block_state.image_mask_latents = full_image_mask_latents
        block_state.mouse_conditions = full_mouse_conditions
        block_state.keyboard_conditions = full_keyboard_conditions
        self.set_block_state(state, block_state)
        return components, state
class MatrixGameWanDenoiseStep(MatrixGameWanDenoiseLoopWrapper):
    """
    Complete Matrix-Game Wan denoise step: the block-wise loop with its two per-step sub-blocks.

    The sub-blocks are the transformer prediction (`MatrixGameWanLoopDenoiser`) and the latent update
    (`MatrixGameWanLoopAfterDenoiser`).
    """

    block_classes = [
        MatrixGameWanLoopDenoiser,
        MatrixGameWanLoopAfterDenoiser,
    ]
    block_names = ["denoiser", "after_denoiser"]

    @property
    def description(self) -> str:
        return (
            "Denoise step that iteratively denoises the latents. \n"
            "Its loop logic is defined in `MatrixGameWanDenoiseLoopWrapper.__call__` method \n"
            "At each iteration, it runs blocks defined in `sub_blocks` sequentially:\n"
            " - `MatrixGameWanLoopDenoiser`\n"
            " - `MatrixGameWanLoopAfterDenoiser`\n"
            "This block supports image-to-video tasks conditioned on keyboard and mouse actions."
        )