# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List

import torch
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection

from diffusers import ModularPipeline, ModularPipelineBlocks
from diffusers.modular_pipelines import PipelineState
from diffusers.modular_pipelines.modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam
from diffusers.utils import logging

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


class MatrixGameWanImageEncoderStep(ModularPipelineBlocks):
    model_name = "MatrixGameWan"

    @property
    def description(self) -> str:
        return "Image Encoder step that generate image_embeddings to guide the video generation"

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [
            ComponentSpec(
                "image_encoder",
                CLIPVisionModelWithProjection,
                repo="laion/CLIP-ViT-H-14-laion2B-s32B-b79K",
            ),
            ComponentSpec(
                "image_processor",
                CLIPImageProcessor,
                repo="Wan-AI/Wan2.1-I2V-14B-720P-Diffusers",
                subfolder="image_processor"
            ),
        ]

    @property
    def expected_configs(self) -> List[ConfigSpec]:
        return []

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam("image"),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam(
                "image_embeds",
                type_hint=torch.Tensor,
                description="image embeddings used to guide the image generation",
            )
        ]

    def encode_image(self, components, image):
        device = components._execution_device
        # Preprocess the input image into pixel values and move them to the execution device
        image = components.image_processor(images=image, return_tensors="pt").to(device)
        # Use the penultimate hidden state of the CLIP vision encoder as the image embedding
        image_embeds = components.image_encoder(**image, output_hidden_states=True)
        return image_embeds.hidden_states[-2]

    @torch.no_grad()
    def __call__(self, components: ModularPipeline, state: PipelineState) -> PipelineState:
        # Get inputs and intermediates
        block_state = self.get_block_state(state)
        # Encode the conditioning image into CLIP image embeddings
        block_state.image_embeds = self.encode_image(components, block_state.image)

        # Add outputs
        self.set_block_state(state, block_state)
        return components, state
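

# Minimal standalone usage sketch. Assumptions: the modular-pipelines API exposes
# `init_pipeline()` on blocks and `load_components()` / `to()` on the resulting
# ModularPipeline (names may differ across diffusers versions), and the image URL
# below is purely illustrative.
if __name__ == "__main__":
    from diffusers.utils import load_image

    blocks = MatrixGameWanImageEncoderStep()
    # Component repos/subfolders come from the ComponentSpec defaults declared above
    pipe = blocks.init_pipeline()
    pipe.load_components(torch_dtype=torch.float16)
    pipe.to("cuda")

    # Hypothetical conditioning frame used to produce the guidance embeddings
    image = load_image("https://example.com/first_frame.png")
    image_embeds = pipe(image=image, output="image_embeds")
    print(image_embeds.shape)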