from typing import List, Optional, Tuple, Union

import torch

from diffusers import AutoencoderKLWan, UniPCMultistepScheduler
from diffusers.image_processor import PipelineImageInput
from diffusers.modular_pipelines import (
    ComponentSpec,
    InputParam,
    ModularPipeline,
    ModularPipelineBlocks,
    OutputParam,
    PipelineState,
)
from diffusers.modular_pipelines.wan.before_denoise import retrieve_timesteps
from diffusers.pipelines.wan.pipeline_wan_i2v import retrieve_latents
from diffusers.utils.torch_utils import randn_tensor


class ChronoEditSetTimestepsStep(ModularPipelineBlocks):
    model_name = "chronoedit"

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [ComponentSpec("scheduler", UniPCMultistepScheduler)]

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam("num_inference_steps", default=50),
            InputParam("timesteps"),
            InputParam("sigmas"),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam("timesteps", type_hint=torch.Tensor, description="The timesteps to use for inference"),
            OutputParam(
                "num_inference_steps",
                type_hint=int,
                description="The number of denoising steps to perform at inference time",
            ),
        ]

    @torch.no_grad()
    def __call__(self, components: ModularPipeline, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)
        block_state.device = components._execution_device

        block_state.timesteps, block_state.num_inference_steps = retrieve_timesteps(
            components.scheduler,
            block_state.num_inference_steps,
            block_state.device,
            block_state.timesteps,
            block_state.sigmas,
        )

        self.set_block_state(state, block_state)
        return components, state
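

# ChronoEditPrepareLatentStep prepares both the initial noise latents and the image
# conditioning for the denoiser. It follows the latent-preparation recipe of diffusers'
# WanImageToVideoPipeline (hence the `retrieve_latents` import above): the reference image
# is padded with zero frames, encoded with the Wan VAE, normalized with the VAE latent
# statistics, and concatenated with a first-frame mask along the channel dimension.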
class ChronoEditPrepareLatentStep(ModularPipelineBlocks):
    model_name = "chronoedit"

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [ComponentSpec("vae", AutoencoderKLWan)]

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam("processed_image", type_hint=PipelineImageInput),
            InputParam("image_embeds", type_hint=torch.Tensor),
            InputParam("height", type_hint=int, default=480),
            InputParam("width", type_hint=int, default=832),
            InputParam("num_frames", type_hint=int, default=81),
            InputParam("batch_size"),
            InputParam("num_videos_per_prompt", type_hint=int, default=1),
            InputParam("latents", type_hint=Optional[torch.Tensor]),
            InputParam("generator"),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam(
                "latents",
                type_hint=torch.Tensor,
                description="The initial latents to use for the denoising process.",
            ),
            OutputParam(
                "condition",
                type_hint=torch.Tensor,
                description="Conditioning latents for the denoising process.",
            ),
        ]

    @staticmethod
    def check_inputs(height, width):
        if height % 16 != 0 or width % 16 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")

    @staticmethod
    def prepare_latents(
        components,
        image: PipelineImageInput,
        batch_size: int,
        num_channels_latents: int = 16,
        height: int = 480,
        width: int = 832,
        num_frames: int = 81,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        num_latent_frames = (num_frames - 1) // components.vae_scale_factor_temporal + 1
        latent_height = height // components.vae_scale_factor_spatial
        latent_width = width // components.vae_scale_factor_spatial

        shape = (batch_size, num_channels_latents, num_latent_frames, latent_height, latent_width)
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        if latents is None:
            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        else:
            latents = latents.to(device=device, dtype=dtype)

        # Build the conditioning video: the reference image followed by zero frames.
        image = image.unsqueeze(2)
        video_condition = torch.cat(
            [image, image.new_zeros(image.shape[0], image.shape[1], num_frames - 1, height, width)], dim=2
        )
        video_condition = video_condition.to(device=device, dtype=dtype)

        latents_mean = (
            torch.tensor(components.vae.config.latents_mean)
            .view(1, components.vae.config.z_dim, 1, 1, 1)
            .to(latents.device, latents.dtype)
        )
        latents_std = 1.0 / torch.tensor(components.vae.config.latents_std).view(
            1, components.vae.config.z_dim, 1, 1, 1
        ).to(latents.device, latents.dtype)

        if isinstance(generator, list):
            latent_condition = [
                retrieve_latents(components.vae.encode(video_condition), sample_mode="argmax") for _ in generator
            ]
            latent_condition = torch.cat(latent_condition)
        else:
            latent_condition = retrieve_latents(components.vae.encode(video_condition), sample_mode="argmax")
            latent_condition = latent_condition.repeat(batch_size, 1, 1, 1, 1)

        # Normalize the encoded condition with the VAE latent statistics.
        latent_condition = (latent_condition - latents_mean) * latents_std

        # First-frame mask: 1 for the conditioning frame, 0 for the frames to be generated,
        # reshaped down to the latent temporal resolution.
        mask_lat_size = torch.ones(batch_size, 1, num_frames, latent_height, latent_width)
        mask_lat_size[:, :, list(range(1, num_frames))] = 0
        first_frame_mask = mask_lat_size[:, :, 0:1]
        first_frame_mask = torch.repeat_interleave(
            first_frame_mask, dim=2, repeats=components.vae_scale_factor_temporal
        )
        mask_lat_size = torch.concat([first_frame_mask, mask_lat_size[:, :, 1:, :]], dim=2)
        mask_lat_size = mask_lat_size.view(
            batch_size, -1, components.vae_scale_factor_temporal, latent_height, latent_width
        )
        mask_lat_size = mask_lat_size.transpose(1, 2)
        mask_lat_size = mask_lat_size.to(latent_condition.device)

        return latents, torch.concat([mask_lat_size, latent_condition], dim=1)

    @torch.no_grad()
    def __call__(self, components: ModularPipeline, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)

        self.check_inputs(block_state.height, block_state.width)

        block_state.device = components._execution_device
        block_state.num_channels_latents = components.num_channels_latents

        batch_size = block_state.batch_size * block_state.num_videos_per_prompt
        block_state.latents, block_state.condition = self.prepare_latents(
            components,
            block_state.processed_image,
            batch_size,
            block_state.num_channels_latents,
            block_state.height,
            block_state.width,
            block_state.num_frames,
            torch.bfloat16,
            block_state.device,
            block_state.generator,
            block_state.latents,
        )

        self.set_block_state(state, block_state)

        return components, state
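

# ---------------------------------------------------------------------------
# Minimal composition sketch (an illustration, not part of the blocks above):
# one possible way to chain the two steps with diffusers' modular pipelines
# API. The block names "set_timesteps" and "prepare_latents" are arbitrary,
# and a complete ChronoEdit pipeline would also need image/text encoding,
# denoising, and decoding blocks, which are omitted here.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from diffusers.modular_pipelines import SequentialPipelineBlocks

    # diffusers' docs usually build this mapping with InsertableDict; a plain
    # dict is used here for brevity.
    blocks = SequentialPipelineBlocks.from_blocks_dict(
        {
            "set_timesteps": ChronoEditSetTimestepsStep(),
            "prepare_latents": ChronoEditPrepareLatentStep(),
        }
    )

    # Printing the assembled blocks summarizes the combined inputs, intermediate
    # outputs, and expected components ("scheduler", "vae") declared by the two
    # steps above.
    print(blocks)

    # A runnable pipeline would then be created with `blocks.init_pipeline(...)`
    # pointing at a modular repo that provides the scheduler and VAE configs;
    # that part depends on the ChronoEdit checkpoint layout and is omitted here.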