|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from typing import Dict, List |
|
|
|
|
|
import random |
|
|
import torch |
|
|
from torchvision.transforms import v2 |
|
|
|
|
|
from diffusers.utils import logging |
|
|
from diffusers import ModularPipeline, ModularPipelineBlocks |
|
|
from diffusers.modular_pipelines import PipelineState |
|
|
from diffusers.modular_pipelines.modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam |
|
|
from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor |
|
|
|
|
|
logger = logging.get_logger(__name__) |
|
|
|
|
|
|
|
|
class MatrixGameWanImageEncoderStep(ModularPipelineBlocks):
    """Modular pipeline block that encodes a conditioning image into CLIP
    image embeddings used to guide MatrixGame-Wan video generation.

    Runs the input image through a CLIP image processor and a
    ``CLIPVisionModelWithProjection`` encoder, and stores the penultimate
    hidden state on the pipeline state as ``image_embeds``.
    """

    model_name = "MatrixGameWan"

    @property
    def description(self) -> str:
        # NOTE: fixed grammar ("generate" -> "generates") in the user-facing text.
        return "Image Encoder step that generates image_embeddings to guide the video generation"

    @property
    def expected_components(self) -> List[ComponentSpec]:
        """Components this block needs: the CLIP vision encoder and its
        matching image processor."""
        return [
            ComponentSpec(
                "image_encoder",
                CLIPVisionModelWithProjection,
                repo="laion/CLIP-ViT-H-14-laion2B-s32B-b79K",
            ),
            ComponentSpec(
                "image_processor",
                CLIPImageProcessor,
                repo="Wan-AI/Wan2.1-I2V-14B-720P-Diffusers",
                subfolder="image_processor",
            ),
        ]

    @property
    def expected_configs(self) -> List[ConfigSpec]:
        """This block requires no pipeline-level configuration."""
        return []

    @property
    def inputs(self) -> List[InputParam]:
        """The only input is the conditioning image to encode."""
        return [
            InputParam("image"),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        # NOTE: description fixed to say "video generation" — it previously read
        # "image generation", contradicting this block's own description.
        return [
            OutputParam(
                "image_embeds",
                type_hint=torch.Tensor,
                description="image embeddings used to guide the video generation",
            )
        ]

    def encode_image(self, components, image):
        """Encode ``image`` with the CLIP vision encoder.

        Args:
            components: The modular pipeline holding `image_processor` and
                `image_encoder` components (and the execution device).
            image: The conditioning image accepted by ``CLIPImageProcessor``
                (e.g. a PIL image or array) — TODO confirm accepted types.

        Returns:
            torch.Tensor: The penultimate hidden state
            (``hidden_states[-2]``) of the vision encoder.
        """
        device = components._execution_device
        # Preprocess to model-ready pixel values and move them to the device.
        image = components.image_processor(images=image, return_tensors="pt").to(device)
        # output_hidden_states=True exposes all layers; Wan-style conditioning
        # uses the second-to-last hidden state rather than the projected output.
        image_embeds = components.image_encoder(**image, output_hidden_states=True)
        return image_embeds.hidden_states[-2]

    @torch.no_grad()
    def __call__(self, components: ModularPipeline, state: PipelineState) -> PipelineState:
        """Execute the step: read ``image`` from state, write ``image_embeds``."""
        block_state = self.get_block_state(state)
        # Record the execution device on the block state so downstream blocks
        # reading the pipeline state can reuse it.
        block_state.device = components._execution_device

        block_state.image_embeds = self.encode_image(components, block_state.image)

        self.set_block_state(state, block_state)
        return components, state
|
|
|