File size: 17,442 Bytes
5178ef1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, List, Tuple

import torch

from diffusers.configuration_utils import FrozenDict
from diffusers.guiders import ClassifierFreeGuidance
from diffusers.models import AutoModel, WanTransformer3DModel
from diffusers.schedulers import UniPCMultistepScheduler
from diffusers.utils import logging
from diffusers.utils.torch_utils import randn_tensor
from diffusers.modular_pipelines import (
    BlockState,
    LoopSequentialPipelineBlocks,
    ModularPipelineBlocks,
    PipelineState,
    ModularPipeline
)
from diffusers.modular_pipelines.modular_pipeline_utils import ComponentSpec, InputParam, OutputParam


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


class MatrixGameWanLoopDenoiser(ModularPipelineBlocks):
    """Loop sub-block that runs the transformer for one frame chunk at one timestep.

    Meant to be composed into a ``LoopSequentialPipelineBlocks`` such as
    ``MatrixGameWanDenoiseLoopWrapper``. Besides its declared inputs, it reads
    ``kv_cache``, ``kv_cache_mouse``, ``kv_cache_keyboard``, ``kv_cache_cross_attn``,
    ``current_frame_idx`` and ``num_frames_per_block`` from ``block_state``. The wrapper
    must set those before this block is called.
    """

    model_name = "MatrixGameWan"
    # Per-frame sequence length (presumably transformer tokens per latent frame); it turns a
    # frame index into the token offset passed as ``current_start``.
    frame_seq_length = 880

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [
            ComponentSpec(
                "guider",
                ClassifierFreeGuidance,
                config=FrozenDict({"guidance_scale": 5.0}),
                default_creation_method="from_config",
            ),
            ComponentSpec("transformer", AutoModel),
        ]

    @property
    def description(self) -> str:
        return (
            "Step within the denoising loop that denoise the latents with guidance. "
            "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
            "object (e.g. `MatrixGameWanDenoiseLoopWrapper`)"
        )

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam("attention_kwargs"),
            InputParam(
                "latents",
                required=True,
                type_hint=torch.Tensor,
                description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.",
            ),
            InputParam(
                "image_mask_latents",
                required=True,
                type_hint=torch.Tensor,
            ),
            InputParam(
                "image_embeds",
                required=True,
                type_hint=torch.Tensor,
            ),
            InputParam(
                "keyboard_conditions",
                required=True,
                type_hint=torch.Tensor,
            ),
            InputParam(
                "mouse_conditions",
                required=True,
                type_hint=torch.Tensor,
            ),
            InputParam(
                "num_inference_steps",
                required=True,
                type_hint=int,
                default=4,
                description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.",
            ),
            InputParam(
                kwargs_type="guider_input_fields",
                description=(
                    "All conditional model inputs that need to be prepared with guider. "
                    "It should contain prompt_embeds/negative_prompt_embeds. "
                    "Please add `kwargs_type=guider_input_fields` to their parameter spec (`OutputParam`) when they are created and added to the pipeline state"
                ),
            ),
        ]

    @torch.no_grad()
    def __call__(
        self, components: ModularPipeline, block_state: BlockState, i: int, t: torch.Tensor
    ) -> Tuple[ModularPipeline, BlockState]:
        """Predict the model output for the current chunk and combine it through the guider.

        Args:
            components: Pipeline exposing the ``guider`` and ``transformer`` components.
            block_state: Shared loop state. ``block_state.noise_pred`` is written.
            i: Index of the current denoising step.
            t: Current timestep. It is expanded to ``(batch, num_frames_per_block)`` before
                being passed to the transformer.

        Returns:
            The ``(components, block_state)`` pair consumed by the loop wrapper.
        """
        cond_concat = block_state.image_mask_latents
        keyboard_conditions = block_state.keyboard_conditions
        mouse_conditions = block_state.mouse_conditions
        visual_context = block_state.image_embeds

        transformer_dtype = components.transformer.dtype
        components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t)

        # No guider input fields are registered (empty mapping), so the guidance batches
        # (e.g. cond/uncond for CFG) carry no batch-specific inputs. The transformer receives
        # identical arguments for every batch.
        # NOTE(review): with CFG enabled this repeats the same forward pass once per batch;
        # confirm whether guidance is meant to be active for this model.
        guider_state = components.guider.prepare_inputs(block_state, {})

        for guider_state_batch in guider_state:
            components.guider.prepare_models(components.transformer)

            # Keep the prediction on the batch so the guider can combine all batches below.
            guider_state_batch.noise_pred = components.transformer(
                x=block_state.latents.to(transformer_dtype),
                t=t.expand(block_state.latents.shape[0], block_state.num_frames_per_block),
                visual_context=visual_context.to(transformer_dtype),
                cond_concat=cond_concat.to(transformer_dtype),
                keyboard_cond=keyboard_conditions,
                mouse_cond=mouse_conditions,
                kv_cache=block_state.kv_cache,
                kv_cache_mouse=block_state.kv_cache_mouse,
                kv_cache_keyboard=block_state.kv_cache_keyboard,
                crossattn_cache=block_state.kv_cache_cross_attn,
                current_start=block_state.current_frame_idx * self.frame_seq_length,
                num_frames_per_block=block_state.num_frames_per_block,
            )[0]
            components.guider.cleanup_models(components.transformer)

        # Combine the per-batch predictions according to the guidance method.
        block_state.noise_pred = components.guider(guider_state)[0]

        return components, block_state


class MatrixGameWanLoopAfterDenoiser(ModularPipelineBlocks):
    """Loop sub-block that turns the transformer prediction into updated latents.

    Reads ``block_state.latents`` and ``block_state.noise_pred`` (the latter written by
    ``MatrixGameWanLoopDenoiser``) and overwrites ``block_state.latents``.
    """

    model_name = "MatrixGameWan"

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [
            ComponentSpec("scheduler", UniPCMultistepScheduler),
        ]

    @property
    def description(self) -> str:
        return (
            "step within the denoising loop that update the latents. "
            "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
            "object (e.g. `MatrixGameWanDenoiseLoopWrapper`)"
        )

    @property
    def inputs(self) -> List[InputParam]:
        return []

    @property
    def intermediate_inputs(self) -> List[InputParam]:
        # NOTE(review): `generator` is declared here but not consumed by this block's __call__.
        return [
            InputParam("generator"),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")]

    @torch.no_grad()
    def __call__(
        self, components: ModularPipeline, block_state: BlockState, i: int, t: torch.Tensor
    ) -> Tuple[ModularPipeline, BlockState]:
        """Compute ``latents - sigma_t * noise_pred`` for the current timestep.

        ``sigma_t`` is looked up in ``components.scheduler.sigmas`` at the scheduler index
        of ``t``. (Presumably this is the flow-matching clean-sample estimate; confirm it
        matches the transformer's output parameterization.)

        Args:
            components: Pipeline exposing the ``scheduler`` component.
            block_state: Shared loop state. ``latents`` is read and overwritten and
                ``noise_pred`` is read.
            i: Index of the current denoising step (unused here).
            t: Current timestep. It must be present in ``components.scheduler.timesteps``.

        Returns:
            The ``(components, block_state)`` pair consumed by the loop wrapper.
        """
        latents_dtype = block_state.latents.dtype

        step_index = components.scheduler.index_for_timestep(t)
        sigma_t = components.scheduler.sigmas[step_index]

        # Do the update in float64 to limit rounding error, then return to the incoming dtype.
        # `.to()` is a no-op when the dtype already matches.
        latents = block_state.latents.double() - sigma_t.double() * block_state.noise_pred.double()
        block_state.latents = latents.to(latents_dtype)

        return components, block_state


class MatrixGameWanDenoiseLoopWrapper(LoopSequentialPipelineBlocks):
    """Chunk-by-chunk (autoregressive) denoising loop for MatrixGame-Wan.

    The latent video is split along its frame axis into chunks of ``num_frames_per_block``
    frames. Each chunk is denoised by running ``sub_blocks`` for every entry of
    ``timesteps``; between steps the result is re-noised to the next timestep. Afterwards
    the transformer is run once more on the finished chunk with a zero timestep. Its output
    is discarded, so this pass presumably exists to fill the KV caches with context for the
    following chunks.
    """

    model_name = "MatrixGameWan"
    # Per-frame sequence length (presumably transformer tokens per latent frame).
    frame_seq_length = 880
    # Cache capacity in frames; -1 selects the 15-frame fallback used by the initializers.
    local_attn_size = 6
    # Number of cache entries allocated (one per transformer block).
    num_transformer_blocks = 30

    def _initialize_kv_cache(self, batch_size, dtype, device):
        """
        Allocate a zero-filled self-attention KV cache with one entry per transformer block.

        Each entry holds ``k``/``v`` buffers of shape ``(batch_size, kv_cache_size, 12, 128)``
        and ``global_end_index``/``local_end_index`` counters initialised to 0.
        ``kv_cache_size`` is ``local_attn_size * frame_seq_length``, or
        ``15 * frame_seq_length`` when ``local_attn_size == -1``. (12/128 are presumably the
        attention head count and head dim; confirm against the transformer config.)
        """
        if self.local_attn_size != -1:
            # Only the local attention window is cached.
            kv_cache_size = self.local_attn_size * self.frame_seq_length
        else:
            # Fallback: 15 frames, i.e. 13200 entries with frame_seq_length = 880.
            kv_cache_size = 15 * 1 * self.frame_seq_length

        # Each entry gets freshly allocated tensors; blocks must not share buffers.
        return [
            {
                "k": torch.zeros((batch_size, kv_cache_size, 12, 128), dtype=dtype, device=device),
                "v": torch.zeros((batch_size, kv_cache_size, 12, 128), dtype=dtype, device=device),
                "global_end_index": torch.tensor([0], dtype=torch.long, device=device),
                "local_end_index": torch.tensor([0], dtype=torch.long, device=device),
            }
            for _ in range(self.num_transformer_blocks)
        ]

    def _initialize_kv_cache_mouse_and_keyboard(self, batch_size, dtype, device):
        """
        Allocate zero-filled KV caches for the mouse and keyboard conditioning branches.

        Returns ``(kv_cache_mouse, kv_cache_keyboard)``. Each is a list with one entry per
        transformer block, holding ``k``/``v`` buffers plus ``global_end_index`` and
        ``local_end_index`` counters initialised to 0.
        - Keyboard buffers have shape ``(batch_size, kv_cache_size, 16, 64)``.
        - Mouse buffers have shape ``(batch_size * frame_seq_length, kv_cache_size, 16, 64)``.

        ``kv_cache_size`` is ``local_attn_size``, or 15 when ``local_attn_size == -1``.
        """
        kv_cache_size = self.local_attn_size if self.local_attn_size != -1 else 15 * 1

        def _make_entry(leading_dim):
            # One fresh cache entry; `leading_dim` differs between the mouse and keyboard branches.
            return {
                "k": torch.zeros([leading_dim, kv_cache_size, 16, 64], dtype=dtype, device=device),
                "v": torch.zeros([leading_dim, kv_cache_size, 16, 64], dtype=dtype, device=device),
                "global_end_index": torch.tensor([0], dtype=torch.long, device=device),
                "local_end_index": torch.tensor([0], dtype=torch.long, device=device),
            }

        kv_cache_keyboard = [_make_entry(batch_size) for _ in range(self.num_transformer_blocks)]
        kv_cache_mouse = [
            _make_entry(batch_size * self.frame_seq_length) for _ in range(self.num_transformer_blocks)
        ]
        return kv_cache_mouse, kv_cache_keyboard

    def _initialize_crossattn_cache(self, batch_size, dtype, device):
        """
        Allocate a zero-filled cross-attention cache with one entry per transformer block.

        Each entry holds ``k``/``v`` buffers of shape ``(batch_size, 257, 12, 128)`` and
        ``is_init = False``. (257 is presumably the number of image-context tokens; confirm.)
        """
        return [
            {
                "k": torch.zeros([batch_size, 257, 12, 128], dtype=dtype, device=device),
                "v": torch.zeros([batch_size, 257, 12, 128], dtype=dtype, device=device),
                "is_init": False,
            }
            for _ in range(self.num_transformer_blocks)
        ]

    @property
    def description(self) -> str:
        return (
            "Pipeline block that iteratively denoise the latents over `timesteps`. "
            "The specific steps with each iteration can be customized with `sub_blocks` attributes"
        )

    @property
    def loop_expected_components(self) -> List[ComponentSpec]:
        return [
            ComponentSpec(
                "guider",
                ClassifierFreeGuidance,
                config=FrozenDict({"guidance_scale": 5.0}),
                default_creation_method="from_config",
            ),
            ComponentSpec("scheduler", UniPCMultistepScheduler),
            ComponentSpec("transformer", AutoModel),
        ]

    @property
    def loop_inputs(self) -> List[InputParam]:
        return [
            InputParam(
                "timesteps",
                required=True,
                type_hint=torch.Tensor,
                description="The timesteps to use for the denoising process. Can be generated in set_timesteps step.",
            ),
            InputParam(
                "num_inference_steps",
                required=True,
                type_hint=int,
                description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.",
            ),
            InputParam(
                "num_frames_per_block",
                required=True,
                type_hint=int,
                default=3,
            ),
        ]

    @torch.no_grad()
    def __call__(
        self, components: ModularPipeline, state: PipelineState
    ) -> Tuple[ModularPipeline, PipelineState]:
        """Denoise all latent frames chunk by chunk and store the result in ``state.latents``.

        Args:
            components: Pipeline exposing ``transformer`` and ``scheduler`` (plus whatever
                the sub-blocks require).
            state: Pipeline state. Only ``latents`` is replaced; the full-length conditioning
                inputs are left as they were passed in.

        Returns:
            The ``(components, state)`` pair.
        """
        block_state = self.get_block_state(state)
        transformer_dtype = components.transformer.dtype

        # Keep the caller's full-length inputs. Per-chunk slices are written onto block_state
        # inside the loop and must not be propagated back into the pipeline state.
        original_image_mask_latents = block_state.image_mask_latents
        original_mouse_conditions = block_state.mouse_conditions
        original_keyboard_conditions = block_state.keyboard_conditions

        num_frames_per_block = block_state.num_frames_per_block
        timesteps = block_state.timesteps
        latents = block_state.latents.to(transformer_dtype)
        image_mask_latents = block_state.image_mask_latents.to(transformer_dtype)
        mouse_conditions = block_state.mouse_conditions.unsqueeze(0).to(transformer_dtype)
        keyboard_conditions = block_state.keyboard_conditions.unsqueeze(0).to(transformer_dtype)
        # Cast to the transformer dtype, as the denoiser sub-block does for its own call.
        visual_context = block_state.image_embeds.to(transformer_dtype)

        batch_size, _, num_frames, _, _ = latents.shape
        output = torch.zeros_like(latents)

        num_blocks = num_frames // num_frames_per_block
        remainder = num_frames % num_frames_per_block
        if remainder:
            logger.warning(
                f"`num_frames` ({num_frames}) is not divisible by `num_frames_per_block` "
                f"({num_frames_per_block}); the last {remainder} latent frame(s) will not be "
                f"denoised and are returned as zeros."
            )

        block_state.kv_cache = self._initialize_kv_cache(batch_size, latents.dtype, latents.device)
        block_state.kv_cache_mouse, block_state.kv_cache_keyboard = self._initialize_kv_cache_mouse_and_keyboard(
            batch_size, latents.dtype, latents.device
        )
        block_state.kv_cache_cross_attn = self._initialize_crossattn_cache(batch_size, latents.dtype, latents.device)

        # Use the pipeline generator (if one was supplied) so the re-noising is reproducible.
        generator = getattr(block_state, "generator", None)

        for current_frame_idx in range(0, num_blocks * num_frames_per_block, num_frames_per_block):
            frame_slice = slice(current_frame_idx, current_frame_idx + num_frames_per_block)

            block_state.current_frame_idx = current_frame_idx
            block_state.image_mask_latents = image_mask_latents[:, :, frame_slice]
            # Conditions are cumulative up to the last frame of this chunk. The formula
            # presumably gives 1 condition for the first latent frame and 4 per later frame.
            cond_idx = 1 + 4 * (current_frame_idx + num_frames_per_block - 1)
            block_state.mouse_conditions = mouse_conditions[:, :cond_idx]
            block_state.keyboard_conditions = keyboard_conditions[:, :cond_idx]
            block_state.latents = latents[:, :, frame_slice]

            for i, t in enumerate(timesteps):
                components, block_state = self.loop_step(components, block_state, i=i, t=t)

                # Re-noise to the next timestep of the schedule being iterated, so the noise
                # level matches the `t` the transformer receives on the next step.
                if i + 1 < len(timesteps):
                    next_t = timesteps[i + 1]
                    noise = randn_tensor(
                        block_state.latents.shape,
                        generator=generator,
                        device=block_state.latents.device,
                        dtype=block_state.latents.dtype,
                    )
                    block_state.latents = components.scheduler.add_noise(
                        block_state.latents,
                        noise,
                        next_t.expand(block_state.latents.shape[0]),
                    )

            output[:, :, frame_slice] = block_state.latents

            # Extra pass on the finished chunk with a zero timestep. The output is discarded;
            # the call presumably exists to write this chunk into the KV caches.
            components.transformer(
                x=block_state.latents,
                t=t.expand(block_state.latents.shape[0], num_frames_per_block) * 0.0,
                visual_context=visual_context,
                cond_concat=block_state.image_mask_latents,
                keyboard_cond=block_state.keyboard_conditions,
                mouse_cond=block_state.mouse_conditions,
                kv_cache=block_state.kv_cache,
                kv_cache_mouse=block_state.kv_cache_mouse,
                kv_cache_keyboard=block_state.kv_cache_keyboard,
                crossattn_cache=block_state.kv_cache_cross_attn,
                current_start=block_state.current_frame_idx * self.frame_seq_length,
                num_frames_per_block=num_frames_per_block,
            )

        block_state.latents = output
        # Restore the untouched inputs so only `latents` changes in the pipeline state.
        block_state.image_mask_latents = original_image_mask_latents
        block_state.mouse_conditions = original_mouse_conditions
        block_state.keyboard_conditions = original_keyboard_conditions
        self.set_block_state(state, block_state)

        return components, state


class MatrixGameWanDenoiseStep(MatrixGameWanDenoiseLoopWrapper):
    """Ready-to-use MatrixGame-Wan denoise step.

    Reuses the chunked loop of ``MatrixGameWanDenoiseLoopWrapper`` and runs two
    sub-blocks per timestep: the transformer prediction step (``denoiser``) followed by
    the latent update step (``after_denoiser``).
    """

    block_names = ["denoiser", "after_denoiser"]
    block_classes = [MatrixGameWanLoopDenoiser, MatrixGameWanLoopAfterDenoiser]

    @property
    def description(self) -> str:
        summary_lines = (
            "Denoise step that iteratively denoise the latents. ",
            "Its loop logic is defined in `MatrixGameWanDenoiseLoopWrapper.__call__` method ",
            "At each iteration, it runs blocks defined in `sub_blocks` sequencially:",
            " - `MatrixGameWanLoopDenoiser`",
            " - `MatrixGameWanLoopAfterDenoiser`",
            "This block supports both text2vid tasks.",
        )
        return "\n".join(summary_lines)