# matrix-game-2-modular / encoders.py
# Uploaded with huggingface_hub by dn6 (HF Staff) — commit 5178ef1 (verified)
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, List
import random
import torch
from torchvision.transforms import v2
from diffusers.utils import logging
from diffusers import ModularPipeline, ModularPipelineBlocks
from diffusers.modular_pipelines import PipelineState
from diffusers.modular_pipelines.modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam
from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class MatrixGameWanImageEncoderStep(ModularPipelineBlocks):
    """Modular pipeline block that encodes a conditioning image into CLIP
    image embeddings used to guide MatrixGameWan video generation.

    Reads the ``image`` input from the pipeline state, runs it through a
    CLIP vision encoder, and writes the resulting ``image_embeds`` tensor
    back into the state as an intermediate output.
    """

    model_name = "MatrixGameWan"

    @property
    def description(self) -> str:
        return "Image Encoder step that generates image_embeddings to guide the video generation"

    @property
    def expected_components(self) -> List[ComponentSpec]:
        # CLIP vision tower plus its matching preprocessor; the modular
        # pipeline machinery loads them from the given Hub repos.
        return [
            ComponentSpec(
                "image_encoder",
                CLIPVisionModelWithProjection,
                repo="laion/CLIP-ViT-H-14-laion2B-s32B-b79K",
            ),
            ComponentSpec(
                "image_processor",
                CLIPImageProcessor,
                repo="Wan-AI/Wan2.1-I2V-14B-720P-Diffusers",
                subfolder="image_processor",
            ),
        ]

    @property
    def expected_configs(self) -> List[ConfigSpec]:
        # This block needs no pipeline-level configuration.
        return []

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam("image"),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam(
                "image_embeds",
                type_hint=torch.Tensor,
                description="image embeddings used to guide the video generation",
            )
        ]

    def encode_image(self, components, image):
        """Encode ``image`` with the CLIP vision encoder.

        Args:
            components: Pipeline components holder providing
                ``image_processor``, ``image_encoder`` and
                ``_execution_device``.
            image: Conditioning image accepted by ``CLIPImageProcessor``
                (e.g. a PIL image or array).

        Returns:
            torch.Tensor: The encoder's second-to-last hidden state, the
            conventional layer choice for CLIP image conditioning.
        """
        device = components._execution_device
        # Keep the parameter and the processed batch as distinct names so
        # the raw input is not shadowed mid-function.
        pixel_inputs = components.image_processor(images=image, return_tensors="pt").to(device)
        encoder_output = components.image_encoder(**pixel_inputs, output_hidden_states=True)
        return encoder_output.hidden_states[-2]

    @torch.no_grad()
    def __call__(self, components: ModularPipeline, state: PipelineState):
        """Execute the block: encode ``state.image`` into ``image_embeds``.

        Returns:
            tuple: ``(components, state)`` — the modular-pipeline block
            contract (note: a tuple, not a bare ``PipelineState``).
        """
        # Pull this block's declared inputs/intermediates out of the state.
        block_state = self.get_block_state(state)
        # Stored on the block state so it is written back for any
        # downstream block that reads it.
        block_state.device = components._execution_device
        block_state.image_embeds = self.encode_image(components, block_state.image)
        # Write the computed intermediates back into the pipeline state.
        self.set_block_state(state, block_state)
        return components, state