from typing import Any, List, Tuple

import torch

from diffusers.configuration_utils import FrozenDict
from diffusers.guiders import ClassifierFreeGuidance
from diffusers.models import AutoModel, WanTransformer3DModel
from diffusers.schedulers import UniPCMultistepScheduler
from diffusers.utils import logging
from diffusers.utils.torch_utils import randn_tensor
from diffusers.modular_pipelines import (
    BlockState,
    LoopSequentialPipelineBlocks,
    ModularPipeline,
    ModularPipelineBlocks,
    PipelineState,
)
from diffusers.modular_pipelines.modular_pipeline_utils import ComponentSpec, InputParam, OutputParam


logger = logging.get_logger(__name__)


class MatrixGameWanLoopDenoiser(ModularPipelineBlocks):
    model_name = "MatrixGameWan"
    # Number of latent tokens per frame, used with `current_start` to index the KV cache.
    frame_seq_length = 880

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [
            ComponentSpec(
                "guider",
                ClassifierFreeGuidance,
                config=FrozenDict({"guidance_scale": 5.0}),
                default_creation_method="from_config",
            ),
            ComponentSpec("transformer", AutoModel),
        ]

    @property
    def description(self) -> str:
        return (
            "Step within the denoising loop that denoises the latents with guidance. "
            "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
            "object (e.g. `MatrixGameWanDenoiseLoopWrapper`)."
        )

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam("attention_kwargs"),
            InputParam(
                "latents",
                required=True,
                type_hint=torch.Tensor,
                description="The initial latents to use for the denoising process. Can be generated in the prepare_latents step.",
            ),
            InputParam(
                "image_mask_latents",
                required=True,
                type_hint=torch.Tensor,
            ),
            InputParam(
                "image_embeds",
                required=True,
                type_hint=torch.Tensor,
            ),
            InputParam(
                "keyboard_conditions",
                required=True,
                type_hint=torch.Tensor,
            ),
            InputParam(
                "mouse_conditions",
                required=True,
                type_hint=torch.Tensor,
            ),
            InputParam(
                "num_inference_steps",
                required=True,
                type_hint=int,
                default=4,
                description="The number of inference steps to use for the denoising process. Can be generated in the set_timesteps step.",
            ),
            InputParam(
                kwargs_type="guider_input_fields",
                description=(
                    "All conditional model inputs that need to be prepared with the guider. "
                    "It should contain prompt_embeds/negative_prompt_embeds. "
                    "Please add `kwargs_type=guider_input_fields` to their parameter spec (`OutputParam`) when they are created and added to the pipeline state."
                ),
            ),
        ]

    @torch.no_grad()
    def __call__(
        self, components: ModularPipeline, block_state: BlockState, i: int, t: torch.Tensor
    ) -> PipelineState:
        cond_concat = block_state.image_mask_latents
        keyboard_conditions = block_state.keyboard_conditions
        mouse_conditions = block_state.mouse_conditions
        visual_context = block_state.image_embeds

        transformer_dtype = components.transformer.dtype
        components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t)

        guider_state = components.guider.prepare_inputs(block_state, {})

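        # The guider expands the block state into one batch of model inputs per
        # guidance branch (e.g. conditional and unconditional for classifier-free
        # guidance); each branch is denoised separately and recombined below.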
        for guider_state_batch in guider_state:
            components.guider.prepare_models(components.transformer)

            guider_state_batch.noise_pred = components.transformer(
                x=block_state.latents.to(transformer_dtype),
                t=t.expand(block_state.latents.shape[0], block_state.num_frames_per_block),
                visual_context=visual_context.to(transformer_dtype),
                cond_concat=cond_concat.to(transformer_dtype),
                keyboard_cond=keyboard_conditions,
                mouse_cond=mouse_conditions,
                kv_cache=block_state.kv_cache,
                kv_cache_mouse=block_state.kv_cache_mouse,
                kv_cache_keyboard=block_state.kv_cache_keyboard,
                crossattn_cache=block_state.kv_cache_cross_attn,
                current_start=block_state.current_frame_idx * self.frame_seq_length,
                num_frames_per_block=block_state.num_frames_per_block,
            )[0]
            components.guider.cleanup_models(components.transformer)

        # Combine the per-branch predictions into a single guided noise prediction.
        block_state.noise_pred = components.guider(guider_state)[0]

        return components, block_state


class MatrixGameWanLoopAfterDenoiser(ModularPipelineBlocks):
    model_name = "MatrixGameWan"

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [
            ComponentSpec("scheduler", UniPCMultistepScheduler),
        ]

    @property
    def description(self) -> str:
        return (
            "Step within the denoising loop that updates the latents. "
            "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
            "object (e.g. `MatrixGameWanDenoiseLoopWrapper`)."
        )

    @property
    def inputs(self) -> List[InputParam]:
        return []

    @property
    def intermediate_inputs(self) -> List[InputParam]:
        return [
            InputParam("generator"),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")]

    @torch.no_grad()
    def __call__(self, components: ModularPipeline, block_state: BlockState, i: int, t: torch.Tensor):
        latents_dtype = block_state.latents.dtype

        step_index = components.scheduler.index_for_timestep(t)
        sigma_t = components.scheduler.sigmas[step_index]

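        # With flow-matching sigmas, the model output approximates the velocity
        # from data to noise, so `x_t - sigma_t * v` is the clean-sample estimate.
        # The update runs in float64 to limit accumulated rounding error.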
        latents = block_state.latents.double() - sigma_t.double() * block_state.noise_pred.double()
        block_state.latents = latents

        if block_state.latents.dtype != latents_dtype:
            block_state.latents = block_state.latents.to(latents_dtype)

        return components, block_state


class MatrixGameWanDenoiseLoopWrapper(LoopSequentialPipelineBlocks):
    model_name = "MatrixGameWan"
    # Number of latent tokens per frame, used to compute KV-cache offsets.
    frame_seq_length = 880
    # Local attention window in latent frames; -1 means attend globally.
    local_attn_size = 6
    num_transformer_blocks = 30

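    # The `_initialize_*` helpers below pre-allocate fixed-size rolling KV caches
    # so each new block of frames can attend to previously generated frames
    # without recomputing their keys and values.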
    def _initialize_kv_cache(self, batch_size, dtype, device):
        """
        Initialize a per-GPU KV cache for the self-attention layers of the Wan model.
        """
        cache = []
        if self.local_attn_size != -1:
            # The cache only needs to cover the local attention window.
            kv_cache_size = self.local_attn_size * self.frame_seq_length
        else:
            # Global attention: size the cache for the full clip (15 latent frames).
            kv_cache_size = 15 * 1 * self.frame_seq_length

        for _ in range(self.num_transformer_blocks):
            cache.append(
                {
                    "k": torch.zeros((batch_size, kv_cache_size, 12, 128), dtype=dtype, device=device),
                    "v": torch.zeros((batch_size, kv_cache_size, 12, 128), dtype=dtype, device=device),
                    "global_end_index": torch.tensor([0], dtype=torch.long, device=device),
                    "local_end_index": torch.tensor([0], dtype=torch.long, device=device),
                }
            )

        return cache

    def _initialize_kv_cache_mouse_and_keyboard(self, batch_size, dtype, device):
        """
        Initialize per-GPU KV caches for the mouse- and keyboard-conditioning attention layers.
        """
        kv_cache_mouse = []
        kv_cache_keyboard = []
        if self.local_attn_size != -1:
            kv_cache_size = self.local_attn_size
        else:
            kv_cache_size = 15 * 1
        for _ in range(self.num_transformer_blocks):
            kv_cache_keyboard.append(
                {
                    "k": torch.zeros([batch_size, kv_cache_size, 16, 64], dtype=dtype, device=device),
                    "v": torch.zeros([batch_size, kv_cache_size, 16, 64], dtype=dtype, device=device),
                    "global_end_index": torch.tensor([0], dtype=torch.long, device=device),
                    "local_end_index": torch.tensor([0], dtype=torch.long, device=device),
                }
            )
            # Note the mouse cache's batch dimension is expanded by `frame_seq_length`
            # (one cache row per spatial token).
            kv_cache_mouse.append(
                {
                    "k": torch.zeros(
                        [batch_size * self.frame_seq_length, kv_cache_size, 16, 64], dtype=dtype, device=device
                    ),
                    "v": torch.zeros(
                        [batch_size * self.frame_seq_length, kv_cache_size, 16, 64], dtype=dtype, device=device
                    ),
                    "global_end_index": torch.tensor([0], dtype=torch.long, device=device),
                    "local_end_index": torch.tensor([0], dtype=torch.long, device=device),
                }
            )
        return kv_cache_mouse, kv_cache_keyboard

    def _initialize_crossattn_cache(self, batch_size, dtype, device):
        """
        Initialize a per-GPU cross-attention cache for the Wan model.
        """
        crossattn_cache = []

        for _ in range(self.num_transformer_blocks):
            crossattn_cache.append(
                {
                    "k": torch.zeros([batch_size, 257, 12, 128], dtype=dtype, device=device),
                    "v": torch.zeros([batch_size, 257, 12, 128], dtype=dtype, device=device),
                    "is_init": False,
                }
            )

        return crossattn_cache

    @property
    def description(self) -> str:
        return (
            "Pipeline block that iteratively denoises the latents over `timesteps`. "
            "The specific steps within each iteration can be customized with the `sub_blocks` attribute."
        )

    @property
    def loop_expected_components(self) -> List[ComponentSpec]:
        return [
            ComponentSpec(
                "guider",
                ClassifierFreeGuidance,
                config=FrozenDict({"guidance_scale": 5.0}),
                default_creation_method="from_config",
            ),
            ComponentSpec("scheduler", UniPCMultistepScheduler),
            ComponentSpec("transformer", AutoModel),
        ]

    @property
    def loop_inputs(self) -> List[InputParam]:
        return [
            InputParam(
                "timesteps",
                required=True,
                type_hint=torch.Tensor,
                description="The timesteps to use for the denoising process. Can be generated in the set_timesteps step.",
            ),
            InputParam(
                "num_inference_steps",
                required=True,
                type_hint=int,
                description="The number of inference steps to use for the denoising process. Can be generated in the set_timesteps step.",
            ),
            InputParam(
                "num_frames_per_block",
                required=True,
                type_hint=int,
                default=3,
            ),
        ]

    @torch.no_grad()
    def __call__(self, components: ModularPipeline, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)
        transformer_dtype = components.transformer.dtype

        num_frames_per_block = block_state.num_frames_per_block
        latents = block_state.latents.to(transformer_dtype)
        image_mask_latents = block_state.image_mask_latents.to(transformer_dtype)
        mouse_conditions = block_state.mouse_conditions.unsqueeze(0).to(transformer_dtype)
        keyboard_conditions = block_state.keyboard_conditions.unsqueeze(0).to(transformer_dtype)
        visual_context = block_state.image_embeds

        batch_size, num_channels, num_frames, height, width = latents.shape
        output = torch.zeros(
            (batch_size, num_channels, num_frames, height, width),
            device=latents.device,
            dtype=latents.dtype,
        )

        current_frame_idx = 0
        num_blocks = num_frames // num_frames_per_block

        kv_cache = self._initialize_kv_cache(batch_size, latents.dtype, latents.device)
        kv_cache_mouse, kv_cache_keyboard = self._initialize_kv_cache_mouse_and_keyboard(
            batch_size, latents.dtype, latents.device
        )
        kv_cache_cross_attn = self._initialize_crossattn_cache(batch_size, latents.dtype, latents.device)

        block_state.kv_cache = kv_cache
        block_state.kv_cache_mouse = kv_cache_mouse
        block_state.kv_cache_keyboard = kv_cache_keyboard
        block_state.kv_cache_cross_attn = kv_cache_cross_attn

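        # Generate the video autoregressively, one block of `num_frames_per_block`
        # latent frames at a time; earlier frames are visible to later blocks only
        # through the KV caches.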
        for _ in range(num_blocks):
            block_state.current_frame_idx = current_frame_idx
            block_state.image_mask_latents = image_mask_latents[
                :, :, current_frame_idx : current_frame_idx + num_frames_per_block
            ]
            # Number of pixel frames covered up to the end of this block
            # (Wan's temporal VAE decodes 1 + 4 * (latent_frames - 1) pixel frames).
            cond_idx = 1 + 4 * (current_frame_idx + num_frames_per_block - 1)
            block_state.mouse_conditions = mouse_conditions[:, :cond_idx]
            block_state.keyboard_conditions = keyboard_conditions[:, :cond_idx]

            block_state.latents = latents[
                :, :, current_frame_idx : current_frame_idx + num_frames_per_block
            ]
            for i, t in enumerate(block_state.timesteps):
                components, block_state = self.loop_step(components, block_state, i=i, t=t)

                if i < (block_state.num_inference_steps - 1):
                    # Each step produces a clean-sample estimate; re-noise it to the
                    # next timestep so the following step starts from the correct
                    # noise level.
                    t1 = components.scheduler.timesteps[i + 1]
                    block_state.latents = components.scheduler.add_noise(
                        block_state.latents,
                        randn_tensor(
                            block_state.latents.shape,
                            device=block_state.latents.device,
                            dtype=block_state.latents.dtype,
                        ),
                        t1.expand(block_state.latents.shape[0]),
                    )

            output[:, :, current_frame_idx : current_frame_idx + num_frames_per_block] = block_state.latents

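            # Re-run the transformer on the finished block at timestep 0 so its
            # clean latents populate the KV caches for subsequent blocks; the
            # prediction itself is discarded.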
            components.transformer(
                x=block_state.latents,
                t=t.expand(block_state.latents.shape[0], block_state.num_frames_per_block) * 0.0,
                visual_context=visual_context,
                cond_concat=block_state.image_mask_latents,
                keyboard_cond=block_state.keyboard_conditions,
                mouse_cond=block_state.mouse_conditions,
                kv_cache=block_state.kv_cache,
                kv_cache_mouse=block_state.kv_cache_mouse,
                kv_cache_keyboard=block_state.kv_cache_keyboard,
                crossattn_cache=block_state.kv_cache_cross_attn,
                current_start=block_state.current_frame_idx * self.frame_seq_length,
                num_frames_per_block=block_state.num_frames_per_block,
            )
            current_frame_idx += num_frames_per_block

        block_state.latents = output
        self.set_block_state(state, block_state)

        return components, state


class MatrixGameWanDenoiseStep(MatrixGameWanDenoiseLoopWrapper):
    block_classes = [
        MatrixGameWanLoopDenoiser,
        MatrixGameWanLoopAfterDenoiser,
    ]
    block_names = ["denoiser", "after_denoiser"]

    @property
    def description(self) -> str:
        return (
            "Denoise step that iteratively denoises the latents. \n"
            "Its loop logic is defined in the `MatrixGameWanDenoiseLoopWrapper.__call__` method. \n"
            "At each iteration, it runs the blocks defined in `sub_blocks` sequentially:\n"
            " - `MatrixGameWanLoopDenoiser`\n"
            " - `MatrixGameWanLoopAfterDenoiser`\n"
            "This block supports image- and action-conditioned video generation."
        )
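

# A minimal composition sketch (hedged, not part of this module): the checkpoint
# path and the exact call signature below are illustrative assumptions; upstream
# blocks are expected to produce `latents`, `image_mask_latents`, `image_embeds`,
# and the keyboard/mouse conditions before this step runs.
#
#     denoise_step = MatrixGameWanDenoiseStep()
#     pipe = denoise_step.init_pipeline("path/to/matrix-game-wan")  # hypothetical repo
#     state = pipe(
#         latents=latents,
#         image_mask_latents=image_mask_latents,
#         image_embeds=image_embeds,
#         keyboard_conditions=keyboard_conditions,
#         mouse_conditions=mouse_conditions,
#         timesteps=timesteps,
#         num_inference_steps=4,
#         num_frames_per_block=3,
#     )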