import os
from math import sqrt
from typing import Optional, Tuple, Union, List
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.utils.accelerate_utils import apply_forward_hook
from diffusers.models.activations import get_activation
from diffusers.models.attention_processor import Attention
from diffusers.models.modeling_outputs import AutoencoderKLOutput
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.autoencoders.vae import (
DecoderOutput,
DiagonalGaussianDistribution,
)
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
os.environ["TORCHINDUCTOR_FX_GRAPH_CACHE"] = "1"
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.benchmark = True
OPT_TEMPORAL_TILING = {
1: (1, 1),
17: (17, 17),
21: (13, 8),
25: (17, 8),
29: (17, 12),
33: (21, 12),
37: (21, 16),
41: (17, 12),
45: (21, 12),
49: (17, 8),
53: (21, 16),
57: (21, 12),
61: (13, 8),
65: (17, 12),
69: (21, 16),
73: (17, 8),
77: (17, 12),
81: (21, 12),
85: (21, 16),
89: (17, 12),
93: (21, 12),
97: (17, 8),
101: (21, 16),
105: (21, 12),
109: (13, 8),
113: (17, 12),
117: (21, 16),
121: (17, 8),
125: (17, 12),
129: (21, 12),
133: (21, 16),
137: (17, 12),
141: (21, 12),
145: (17, 8),
149: (21, 16),
153: (21, 12),
157: (13, 8),
161: (17, 12),
165: (21, 16),
169: (17, 8),
173: (17, 12),
177: (21, 12),
181: (21, 16),
185: (17, 12),
189: (21, 12),
193: (17, 8),
197: (21, 16),
201: (21, 12),
205: (13, 8),
209: (17, 12),
213: (21, 16),
217: (17, 8),
221: (17, 12),
225: (21, 12),
229: (21, 16),
233: (17, 12),
237: (21, 12),
241: (17, 8),
}
OPT_SPATIAL_TILING = {
160: (160, 160),
192: (192, 192),
224: (224, 224),
256: (256, 256),
288: (288, 288),
320: (320, 320),
352: (352, 352),
384: (384, 384),
448: (448, 448),
512: (288, 224),
576: (320, 256),
640: (352, 288),
704: (384, 320),
768: (416, 352),
896: (480, 416),
1024: (544, 480),
1152: (608, 544),
1280: (672, 608),
1408: (736, 672),
}
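# The two lookup tables above map an input size to a (tile, stride) pair for
# overlapped tiling: OPT_TEMPORAL_TILING is keyed by the number of input frames,
# OPT_SPATIAL_TILING by the pixel height or width. Illustrative lookups (values
# copied from the tables):
#   OPT_TEMPORAL_TILING[49] == (17, 8)     # 17-frame tiles, advanced 8 frames per step
#   OPT_SPATIAL_TILING[768] == (416, 352)  # 416-px tiles, advanced 352 px per step
# Consecutive tiles therefore overlap, and the overlap is cross-faded by
# blend_t / blend_v / blend_h in AutoencoderKLHunyuanVideo below.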
def prepare_causal_attention_mask(
f: int, s: int, dtype: torch.dtype, device: torch.device, b: int
) -> torch.Tensor:
return (
torch.ones((f, f), dtype=dtype, device=device)
.tril_()
.log_()
.repeat_interleave(s, dim=0)
.repeat_interleave(s, dim=1)
.unsqueeze(0)
.expand(b, -1, -1)
.contiguous()
)
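# Worked example (illustrative): the function above builds an additive
# block-causal mask over f frames with s spatial tokens each. Entries are 0
# where attention is allowed and -inf where it is masked, so tokens of frame t
# can attend to every token of frames <= t but to none of the later frames.
#   prepare_causal_attention_mask(2, 2, torch.float32, torch.device("cpu"), 1)
#   # tensor([[[0., 0., -inf, -inf],
#   #          [0., 0., -inf, -inf],
#   #          [0., 0.,   0.,   0.],
#   #          [0., 0.,   0.,   0.]]])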
class HunyuanVideoCausalConv3d(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: Union[int, Tuple[int, int, int]] = 3,
stride: Union[int, Tuple[int, int, int]] = 1,
padding: Union[int, Tuple[int, int, int]] = 0,
dilation: Union[int, Tuple[int, int, int]] = 1,
bias: bool = True,
pad_mode: str = "replicate",
) -> None:
super().__init__()
kernel_size = (
(kernel_size, kernel_size, kernel_size)
if isinstance(kernel_size, int)
else kernel_size
)
self.pad_mode = pad_mode
self.time_causal_padding = (
kernel_size[0] // 2,
kernel_size[0] // 2,
kernel_size[1] // 2,
kernel_size[1] // 2,
kernel_size[2] - 1,
0,
)
self.conv = nn.Conv3d(
in_channels, out_channels, kernel_size, stride, padding, dilation, bias=bias
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = F.pad(
hidden_states, self.time_causal_padding, mode=self.pad_mode
)
return self.conv(hidden_states)
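# Shape note (illustrative): F.pad with a 6-tuple pads the last three dims of a
# (B, C, T, H, W) tensor as (W_left, W_right, H_top, H_bottom, T_front, T_back),
# so the padding above is symmetric in height/width but places all kernel_size-1
# replicated frames before the clip and none after it -- the convolution never
# sees future frames. With the defaults (kernel_size=3, stride=1) the output
# keeps the input shape:
#   conv = HunyuanVideoCausalConv3d(8, 8)
#   conv(torch.randn(1, 8, 5, 16, 16)).shape  # torch.Size([1, 8, 5, 16, 16])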
class HunyuanVideoUpsampleCausal3D(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: Optional[int] = None,
kernel_size: int = 3,
stride: int = 1,
bias: bool = True,
upsample_factor: Tuple[float, float, float] = (2, 2, 2),
) -> None:
super().__init__()
out_channels = out_channels or in_channels
self.upsample_factor = upsample_factor
self.conv = HunyuanVideoCausalConv3d(
in_channels, out_channels, kernel_size, stride, bias=bias
)
@torch.compile(dynamic=True)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
num_frames = hidden_states.size(2)
dtp = hidden_states.dtype
first_frame, other_frames = hidden_states.split((1, num_frames - 1), dim=2)
first_frame = F.interpolate(
first_frame.squeeze(2),
scale_factor=self.upsample_factor[1:],
mode="nearest",
).unsqueeze(2).to(dtp) #force cast
if num_frames > 1:
other_frames = other_frames.contiguous()
other_frames = F.interpolate(
other_frames, scale_factor=self.upsample_factor, mode="nearest"
).to(dtp) # force cast
hidden_states = torch.cat((first_frame, other_frames), dim=2)
del first_frame
del other_frames
torch.cuda.empty_cache()
else:
hidden_states = first_frame
hidden_states = self.conv(hidden_states)
return hidden_states
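# Note (illustrative): the first frame is interpolated only spatially while the
# remaining frames are also interpolated in time, so a clip with T input frames
# leaves this block with 1 + (T - 1) * upsample_factor[0] frames. With the
# default factor (2, 2, 2), two such blocks turn T latent frames into
# 4 * (T - 1) + 1 output frames, which is the relation used elsewhere in this
# file to convert between latent and sample frame counts.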
class HunyuanVideoDownsampleCausal3D(nn.Module):
def __init__(
self,
channels: int,
out_channels: Optional[int] = None,
padding: int = 1,
kernel_size: int = 3,
bias: bool = True,
stride=2,
) -> None:
super().__init__()
out_channels = out_channels or channels
self.conv = HunyuanVideoCausalConv3d(
channels, out_channels, kernel_size, stride, padding, bias=bias
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.conv(hidden_states)
return hidden_states
class HunyuanVideoResnetBlockCausal3D(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: Optional[int] = None,
dropout: float = 0.0,
groups: int = 32,
eps: float = 1e-6,
non_linearity: str = "swish",
) -> None:
super().__init__()
out_channels = out_channels or in_channels
self.nonlinearity = get_activation(non_linearity)
self.norm1 = nn.GroupNorm(groups, in_channels, eps=eps, affine=True)
self.conv1 = HunyuanVideoCausalConv3d(in_channels, out_channels, 3, 1, 0)
self.norm2 = nn.GroupNorm(groups, out_channels, eps=eps, affine=True)
self.dropout = nn.Dropout(dropout)
self.conv2 = HunyuanVideoCausalConv3d(out_channels, out_channels, 3, 1, 0)
self.conv_shortcut = None
if in_channels != out_channels:
self.conv_shortcut = HunyuanVideoCausalConv3d(
in_channels, out_channels, 1, 1, 0
)
@torch.compile(dynamic=True)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
dtp = hidden_states.dtype
hidden_states = hidden_states.contiguous()
residual = hidden_states
hidden_states = self.norm1(hidden_states).to(dtp) #force cast
hidden_states = self.nonlinearity(hidden_states)
hidden_states = self.conv1(hidden_states)
hidden_states = self.norm2(hidden_states).to(dtp) #force cast
hidden_states = self.nonlinearity(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.conv2(hidden_states)
if self.conv_shortcut is not None:
residual = self.conv_shortcut(residual)
hidden_states = hidden_states + residual
return hidden_states
class HunyuanVideoMidBlock3D(nn.Module):
def __init__(
self,
in_channels: int,
dropout: float = 0.0,
num_layers: int = 1,
resnet_eps: float = 1e-6,
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
add_attention: bool = True,
attention_head_dim: int = 1,
) -> None:
super().__init__()
resnet_groups = (
resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
)
self.add_attention = add_attention
# There is always at least one resnet
resnets = [
HunyuanVideoResnetBlockCausal3D(
in_channels=in_channels,
out_channels=in_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
non_linearity=resnet_act_fn,
)
]
attentions = []
for _ in range(num_layers):
if self.add_attention:
attentions.append(
Attention(
in_channels,
heads=in_channels // attention_head_dim,
dim_head=attention_head_dim,
eps=resnet_eps,
norm_num_groups=resnet_groups,
residual_connection=True,
bias=True,
upcast_softmax=True,
_from_deprecated_attn_block=True,
)
)
else:
attentions.append(None)
resnets.append(
HunyuanVideoResnetBlockCausal3D(
in_channels=in_channels,
out_channels=in_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
non_linearity=resnet_act_fn,
)
)
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.resnets[0](hidden_states)
for attn, resnet in zip(self.attentions, self.resnets[1:]):
if attn is not None:
batch_size, _, num_frames, height, width = hidden_states.shape
hidden_states = hidden_states.permute(0, 2, 3, 4, 1).flatten(1, 3)
mask = prepare_causal_attention_mask(
num_frames,
height * width,
hidden_states.dtype,
hidden_states.device,
batch_size,
)
hidden_states = attn(hidden_states, attention_mask=mask)
hidden_states = hidden_states.unflatten(
1, (num_frames, height, width)
).permute(0, 4, 1, 2, 3)
hidden_states = resnet(hidden_states)
return hidden_states
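# Note (illustrative): for the attention layers the (B, C, T, H, W) feature map
# is flattened to a (B, T*H*W, C) token sequence, and the block-causal mask from
# prepare_causal_attention_mask keeps every spatial token from attending to
# tokens of later frames. For a 2-frame, 8x8 feature map this is a 128-token
# sequence with a (B, 128, 128) additive mask.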
class HunyuanVideoDownBlock3D(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
dropout: float = 0.0,
num_layers: int = 1,
resnet_eps: float = 1e-6,
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
add_downsample: bool = True,
downsample_stride: int = 2,
downsample_padding: int = 1,
) -> None:
super().__init__()
resnets = []
for i in range(num_layers):
in_channels = in_channels if i == 0 else out_channels
resnets.append(
HunyuanVideoResnetBlockCausal3D(
in_channels=in_channels,
out_channels=out_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
non_linearity=resnet_act_fn,
)
)
self.resnets = nn.ModuleList(resnets)
if add_downsample:
self.downsamplers = nn.ModuleList(
[
HunyuanVideoDownsampleCausal3D(
out_channels,
out_channels=out_channels,
padding=downsample_padding,
stride=downsample_stride,
)
]
)
else:
self.downsamplers = None
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
for resnet in self.resnets:
hidden_states = resnet(hidden_states)
if self.downsamplers is not None:
for downsampler in self.downsamplers:
hidden_states = downsampler(hidden_states)
return hidden_states
class HunyuanVideoUpBlock3D(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
dropout: float = 0.0,
num_layers: int = 1,
resnet_eps: float = 1e-6,
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
add_upsample: bool = True,
upsample_scale_factor: Tuple[int, int, int] = (2, 2, 2),
) -> None:
super().__init__()
resnets = []
for i in range(num_layers):
input_channels = in_channels if i == 0 else out_channels
resnets.append(
HunyuanVideoResnetBlockCausal3D(
in_channels=input_channels,
out_channels=out_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
non_linearity=resnet_act_fn,
)
)
self.resnets = nn.ModuleList(resnets)
if add_upsample:
self.upsamplers = nn.ModuleList(
[
HunyuanVideoUpsampleCausal3D(
out_channels,
out_channels=out_channels,
upsample_factor=upsample_scale_factor,
)
]
)
else:
self.upsamplers = None
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
for resnet in self.resnets:
hidden_states = resnet(hidden_states)
if self.upsamplers is not None:
for upsampler in self.upsamplers:
hidden_states = upsampler(hidden_states)
return hidden_states
class HunyuanVideoEncoder3D(nn.Module):
r"""
Causal encoder for 3D video-like data introduced
in [Hunyuan Video](https://huggingface.co/papers/2412.03603).
"""
def __init__(
self,
in_channels: int = 3,
out_channels: int = 3,
down_block_types: Tuple[str, ...] = (
"HunyuanVideoDownBlock3D",
"HunyuanVideoDownBlock3D",
"HunyuanVideoDownBlock3D",
"HunyuanVideoDownBlock3D",
),
block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
layers_per_block: int = 2,
norm_num_groups: int = 32,
act_fn: str = "silu",
double_z: bool = True,
mid_block_add_attention=True,
temporal_compression_ratio: int = 4,
spatial_compression_ratio: int = 8,
) -> None:
super().__init__()
self.conv_in = HunyuanVideoCausalConv3d(
in_channels, block_out_channels[0], kernel_size=3, stride=1
)
self.mid_block = None
self.down_blocks = nn.ModuleList([])
output_channel = block_out_channels[0]
for i, down_block_type in enumerate(down_block_types):
if down_block_type != "HunyuanVideoDownBlock3D":
raise ValueError(f"Unsupported down_block_type: {down_block_type}")
input_channel = output_channel
output_channel = block_out_channels[i]
is_final_block = i == len(block_out_channels) - 1
num_spatial_downsample_layers = int(np.log2(spatial_compression_ratio))
num_time_downsample_layers = int(np.log2(temporal_compression_ratio))
if temporal_compression_ratio == 4:
add_spatial_downsample = bool(i < num_spatial_downsample_layers)
add_time_downsample = bool(
i >= (len(block_out_channels) - 1 - num_time_downsample_layers)
and not is_final_block
)
elif temporal_compression_ratio == 8:
add_spatial_downsample = bool(i < num_spatial_downsample_layers)
add_time_downsample = bool(i < num_time_downsample_layers)
else:
raise ValueError(
f"Unsupported time_compression_ratio: {temporal_compression_ratio}"
)
downsample_stride_HW = (2, 2) if add_spatial_downsample else (1, 1)
downsample_stride_T = (2,) if add_time_downsample else (1,)
downsample_stride = tuple(downsample_stride_T + downsample_stride_HW)
down_block = HunyuanVideoDownBlock3D(
num_layers=layers_per_block,
in_channels=input_channel,
out_channels=output_channel,
add_downsample=bool(add_spatial_downsample or add_time_downsample),
resnet_eps=1e-6,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
downsample_stride=downsample_stride,
downsample_padding=0,
)
self.down_blocks.append(down_block)
self.mid_block = HunyuanVideoMidBlock3D(
in_channels=block_out_channels[-1],
resnet_eps=1e-6,
resnet_act_fn=act_fn,
attention_head_dim=block_out_channels[-1],
resnet_groups=norm_num_groups,
add_attention=mid_block_add_attention,
)
self.conv_norm_out = nn.GroupNorm(
num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6
)
self.conv_act = nn.SiLU()
conv_out_channels = 2 * out_channels if double_z else out_channels
self.conv_out = HunyuanVideoCausalConv3d(
block_out_channels[-1], conv_out_channels, kernel_size=3
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.conv_in(hidden_states)
for down_block in self.down_blocks:
hidden_states = down_block(hidden_states)
hidden_states = self.mid_block(hidden_states)
hidden_states = self.conv_norm_out(hidden_states)
hidden_states = self.conv_act(hidden_states)
hidden_states = self.conv_out(hidden_states)
return hidden_states
class HunyuanVideoDecoder3D(nn.Module):
r"""
Causal decoder for 3D video-like data introduced
in [Hunyuan Video](https://huggingface.co/papers/2412.03603).
"""
def __init__(
self,
in_channels: int = 3,
out_channels: int = 3,
up_block_types: Tuple[str, ...] = (
"HunyuanVideoUpBlock3D",
"HunyuanVideoUpBlock3D",
"HunyuanVideoUpBlock3D",
"HunyuanVideoUpBlock3D",
),
block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
layers_per_block: int = 2,
norm_num_groups: int = 32,
act_fn: str = "silu",
mid_block_add_attention=True,
time_compression_ratio: int = 4,
spatial_compression_ratio: int = 8,
):
super().__init__()
self.layers_per_block = layers_per_block
self.conv_in = HunyuanVideoCausalConv3d(
in_channels, block_out_channels[-1], kernel_size=3, stride=1
)
self.up_blocks = nn.ModuleList([])
# mid
self.mid_block = HunyuanVideoMidBlock3D(
in_channels=block_out_channels[-1],
resnet_eps=1e-6,
resnet_act_fn=act_fn,
attention_head_dim=block_out_channels[-1],
resnet_groups=norm_num_groups,
add_attention=mid_block_add_attention,
)
# up
reversed_block_out_channels = list(reversed(block_out_channels))
output_channel = reversed_block_out_channels[0]
for i, up_block_type in enumerate(up_block_types):
if up_block_type != "HunyuanVideoUpBlock3D":
raise ValueError(f"Unsupported up_block_type: {up_block_type}")
prev_output_channel = output_channel
output_channel = reversed_block_out_channels[i]
is_final_block = i == len(block_out_channels) - 1
num_spatial_upsample_layers = int(np.log2(spatial_compression_ratio))
num_time_upsample_layers = int(np.log2(time_compression_ratio))
if time_compression_ratio == 4:
add_spatial_upsample = bool(i < num_spatial_upsample_layers)
add_time_upsample = bool(
i >= len(block_out_channels) - 1 - num_time_upsample_layers
and not is_final_block
)
else:
raise ValueError(
f"Unsupported time_compression_ratio: {time_compression_ratio}"
)
upsample_scale_factor_HW = (2, 2) if add_spatial_upsample else (1, 1)
upsample_scale_factor_T = (2,) if add_time_upsample else (1,)
upsample_scale_factor = tuple(
upsample_scale_factor_T + upsample_scale_factor_HW
)
up_block = HunyuanVideoUpBlock3D(
num_layers=self.layers_per_block + 1,
in_channels=prev_output_channel,
out_channels=output_channel,
add_upsample=bool(add_spatial_upsample or add_time_upsample),
upsample_scale_factor=upsample_scale_factor,
resnet_eps=1e-6,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
)
self.up_blocks.append(up_block)
prev_output_channel = output_channel
# out
self.conv_norm_out = nn.GroupNorm(
num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6
)
self.conv_act = nn.SiLU()
self.conv_out = HunyuanVideoCausalConv3d(
block_out_channels[0], out_channels, kernel_size=3
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
dtp = hidden_states.dtype
hidden_states = self.conv_in(hidden_states)
hidden_states = self.mid_block(hidden_states)
for up_block in self.up_blocks:
hidden_states = up_block(hidden_states)
hidden_states = self.conv_norm_out(hidden_states)
hidden_states = self.conv_act(hidden_states).to(dtp) # force cast
hidden_states = self.conv_out(hidden_states)
return hidden_states
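# Shape note (illustrative): with the default compression ratios the decoder
# maps a latent of shape (B, latent_channels, T, H, W) to a video of shape
# (B, out_channels, 4 * (T - 1) + 1, 8 * H, 8 * W); get_dec_optimal_tiling
# below relies on exactly this relation to reuse the encoder-side tiling tables.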
class AutoencoderKLHunyuanVideo(ModelMixin, ConfigMixin):
r"""
A VAE model with KL loss for encoding videos into latents
and decoding latent representations into videos.
Introduced in [HunyuanVideo](https://huggingface.co/papers/2412.03603).
    This model inherits from [`ModelMixin`]. Check the superclass
    documentation for its generic methods implemented for all models
    (such as downloading or saving).
"""
@register_to_config
def __init__(
self,
in_channels: int = 3,
out_channels: int = 3,
latent_channels: int = 16,
down_block_types: Tuple[str, ...] = (
"HunyuanVideoDownBlock3D",
"HunyuanVideoDownBlock3D",
"HunyuanVideoDownBlock3D",
"HunyuanVideoDownBlock3D",
),
up_block_types: Tuple[str, ...] = (
"HunyuanVideoUpBlock3D",
"HunyuanVideoUpBlock3D",
"HunyuanVideoUpBlock3D",
"HunyuanVideoUpBlock3D",
),
block_out_channels: Tuple[int] = (128, 256, 512, 512),
layers_per_block: int = 2,
act_fn: str = "silu",
norm_num_groups: int = 32,
scaling_factor: float = 0.476986,
spatial_compression_ratio: int = 8,
temporal_compression_ratio: int = 4,
mid_block_add_attention: bool = True,
) -> None:
super().__init__()
self.time_compression_ratio = temporal_compression_ratio
self.encoder = HunyuanVideoEncoder3D(
in_channels=in_channels,
out_channels=latent_channels,
down_block_types=down_block_types,
block_out_channels=block_out_channels,
layers_per_block=layers_per_block,
norm_num_groups=norm_num_groups,
act_fn=act_fn,
double_z=True,
mid_block_add_attention=mid_block_add_attention,
temporal_compression_ratio=temporal_compression_ratio,
spatial_compression_ratio=spatial_compression_ratio,
)
self.decoder = HunyuanVideoDecoder3D(
in_channels=latent_channels,
out_channels=out_channels,
up_block_types=up_block_types,
block_out_channels=block_out_channels,
layers_per_block=layers_per_block,
norm_num_groups=norm_num_groups,
act_fn=act_fn,
time_compression_ratio=temporal_compression_ratio,
spatial_compression_ratio=spatial_compression_ratio,
mid_block_add_attention=mid_block_add_attention,
)
self.quant_conv = nn.Conv3d(
2 * latent_channels, 2 * latent_channels, kernel_size=1
)
self.post_quant_conv = nn.Conv3d(
latent_channels, latent_channels, kernel_size=1
)
self.spatial_compression_ratio = spatial_compression_ratio
self.temporal_compression_ratio = temporal_compression_ratio
self.use_slicing = False
self.use_tiling = True
self.use_framewise_encoding = True
self.use_framewise_decoding = True
self.tile_sample_min_height = 256
self.tile_sample_min_width = 256
self.tile_sample_min_num_frames = 16
self.tile_sample_stride_height = 192
self.tile_sample_stride_width = 192
self.tile_sample_stride_num_frames = 12
self.tile_size = None
def _encode(self, x: torch.Tensor) -> torch.Tensor:
_, _, num_frames, height, width = x.shape
if self.use_framewise_decoding and num_frames > (
self.tile_sample_min_num_frames + 1
):
return self._temporal_tiled_encode(x)
if self.use_tiling and (
width > self.tile_sample_min_width or height > self.tile_sample_min_height
):
return self.tiled_encode(x)
x = self.encoder(x)
enc = self.quant_conv(x)
return enc
@apply_forward_hook
def encode(
self, x: torch.Tensor, opt_tiling: bool = True, return_dict: bool = True
) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
r"""
        Encode a batch of videos into latents.
        Args:
            x (`torch.Tensor`): Input batch of videos.
            opt_tiling (`bool`, *optional*, defaults to `True`):
                Whether to pick tile and stride sizes from the precomputed
                OPT_TEMPORAL_TILING / OPT_SPATIAL_TILING tables. If `False`,
                the whole input is treated as a single tile.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`]
                instead of a plain tuple.
Returns:
The latent representations of the encoded videos. If `return_dict` is True, a
[`~models.autoencoder_kl.AutoencoderKLOutput`] is returned,
otherwise a plain `tuple` is returned.
"""
if opt_tiling:
tile_size, tile_stride = self.get_enc_optimal_tiling(x.shape)
else:
b, _, f, h, w = x.shape
tile_size, tile_stride = (b, f, h, w), (f, h, w)
if tile_size != self.tile_size:
self.tile_size = tile_size
self.apply_tiling(tile_size, tile_stride)
h = self._encode(x)
posterior = DiagonalGaussianDistribution(h)
if not return_dict:
return (posterior,)
return AutoencoderKLOutput(latent_dist=posterior)
def _decode(
self, z: torch.Tensor, return_dict: bool = True
) -> Union[DecoderOutput, torch.Tensor]:
_, _, num_frames, height, width = z.shape
tile_latent_min_height = (
self.tile_sample_min_height // self.spatial_compression_ratio
)
        tile_latent_min_width = (
            self.tile_sample_min_width // self.spatial_compression_ratio
        )
tile_latent_min_num_frames = (
self.tile_sample_min_num_frames // self.temporal_compression_ratio
)
if self.use_framewise_decoding and num_frames > (
tile_latent_min_num_frames + 1
):
return self._temporal_tiled_decode(z, return_dict=return_dict)
if self.use_tiling and (
width > tile_latent_min_width or height > tile_latent_min_height
):
return self.tiled_decode(z, return_dict=return_dict)
z = self.post_quant_conv(z)
dec = self.decoder(z)
if not return_dict:
return (dec,)
return DecoderOutput(sample=dec)
@apply_forward_hook
def decode(
self, z: torch.Tensor, return_dict: bool = True
) -> Union[DecoderOutput, torch.Tensor]:
r"""
Decode a batch of images.
Args:
z (`torch.Tensor`): Input batch of latent vectors.
return_dict (`bool`, *optional*, defaults to `True`):
Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
Returns:
[`~models.vae.DecoderOutput`] or `tuple`:
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned,
otherwise a plain `tuple` is returned.
"""
tile_size, tile_stride = self.get_dec_optimal_tiling(z.shape)
if tile_size != self.tile_size:
self.tile_size = tile_size
self.apply_tiling(tile_size, tile_stride)
decoded = self._decode(z).sample
if not return_dict:
return (decoded,)
return DecoderOutput(sample=decoded)
def blend_v(
self, a: torch.Tensor, b: torch.Tensor, blend_extent: int
) -> torch.Tensor:
blend_extent = min(a.shape[-2], b.shape[-2], blend_extent)
for y in range(blend_extent):
b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (
1 - y / blend_extent
) + b[:, :, :, y, :] * (y / blend_extent)
return b
def blend_h(
self, a: torch.Tensor, b: torch.Tensor, blend_extent: int
) -> torch.Tensor:
blend_extent = min(a.shape[-1], b.shape[-1], blend_extent)
for x in range(blend_extent):
b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (
1 - x / blend_extent
) + b[:, :, :, :, x] * (x / blend_extent)
return b
def blend_t(
self, a: torch.Tensor, b: torch.Tensor, blend_extent: int
) -> torch.Tensor:
blend_extent = min(a.shape[-3], b.shape[-3], blend_extent)
for x in range(blend_extent):
b[:, :, x, :, :] = a[:, :, -blend_extent + x, :, :] * (
1 - x / blend_extent
) + b[:, :, x, :, :] * (x / blend_extent)
return b
def tiled_encode(self, x: torch.Tensor) -> AutoencoderKLOutput:
r"""Encode a batch of images using a tiled encoder.
Args:
x (`torch.Tensor`): Input batch of videos.
Returns:
`torch.Tensor`:
The latent representation of the encoded videos.
"""
_, _, _, height, width = x.shape
latent_height = height // self.spatial_compression_ratio
latent_width = width // self.spatial_compression_ratio
tile_latent_min_height = (
self.tile_sample_min_height // self.spatial_compression_ratio
)
tile_latent_min_width = (
self.tile_sample_min_width // self.spatial_compression_ratio
)
tile_latent_stride_height = (
self.tile_sample_stride_height // self.spatial_compression_ratio
)
tile_latent_stride_width = (
self.tile_sample_stride_width // self.spatial_compression_ratio
)
blend_height = tile_latent_min_height - tile_latent_stride_height
blend_width = tile_latent_min_width - tile_latent_stride_width
rows = []
for i in range(
0, height - self.tile_sample_min_height + 1, self.tile_sample_stride_height
):
row = []
for j in range(
0, width - self.tile_sample_min_width + 1, self.tile_sample_stride_width
):
tile = x[
:,
:,
:,
i : i + self.tile_sample_min_height,
j : j + self.tile_sample_min_width,
]
tile = self.encoder(tile).clone()
tile = self.quant_conv(tile)
row.append(tile)
rows.append(row)
result_rows = []
for i, row in enumerate(rows):
result_row = []
for j, tile in enumerate(row):
if i > 0:
tile = self.blend_v(rows[i - 1][j], tile, blend_height)
if j > 0:
tile = self.blend_h(row[j - 1], tile, blend_width)
height_lim = (
tile_latent_min_height
if i == len(rows) - 1
else tile_latent_stride_height
)
width_lim = (
tile_latent_min_width
if j == len(row) - 1
else tile_latent_stride_width
)
result_row.append(tile[:, :, :, :height_lim, :width_lim])
result_rows.append(torch.cat(result_row, dim=4))
enc = torch.cat(result_rows, dim=3)[:, :, :, :latent_height, :latent_width]
return enc
def tiled_decode(
self, z: torch.Tensor, return_dict: bool = True
) -> Union[DecoderOutput, torch.Tensor]:
r"""
Decode a batch of images using a tiled decoder.
Args:
z (`torch.Tensor`): Input batch of latent vectors.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
Returns:
[`~models.vae.DecoderOutput`] or `tuple`:
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned,
otherwise a plain `tuple` is returned.
"""
_, _, _, height, width = z.shape
sample_height = height * self.spatial_compression_ratio
sample_width = width * self.spatial_compression_ratio
tile_latent_min_height = (
self.tile_sample_min_height // self.spatial_compression_ratio
)
tile_latent_min_width = (
self.tile_sample_min_width // self.spatial_compression_ratio
)
tile_latent_stride_height = (
self.tile_sample_stride_height // self.spatial_compression_ratio
)
tile_latent_stride_width = (
self.tile_sample_stride_width // self.spatial_compression_ratio
)
blend_height = self.tile_sample_min_height - self.tile_sample_stride_height
blend_width = self.tile_sample_min_width - self.tile_sample_stride_width
rows = []
for i in range(
0, height - tile_latent_min_height + 1, tile_latent_stride_height
):
row = []
for j in range(
0, width - tile_latent_min_width + 1, tile_latent_stride_width
):
tile = z[
:,
:,
:,
i : i + tile_latent_min_height,
j : j + tile_latent_min_width,
]
tile = self.post_quant_conv(tile)
decoded = self.decoder(tile).clone()
row.append(decoded)
rows.append(row)
result_rows = []
for i, row in enumerate(rows):
result_row = []
for j, tile in enumerate(row):
if i > 0:
tile = self.blend_v(rows[i - 1][j], tile, blend_height)
if j > 0:
tile = self.blend_h(row[j - 1], tile, blend_width)
height_lim = (
self.tile_sample_min_height
if i == len(rows) - 1
else self.tile_sample_stride_height
)
width_lim = (
self.tile_sample_min_width
if j == len(row) - 1
else self.tile_sample_stride_width
)
result_row.append(tile[:, :, :, :height_lim, :width_lim])
result_rows.append(torch.cat(result_row, dim=-1))
dec = torch.cat(result_rows, dim=3)[:, :, :, :sample_height, :sample_width]
if not return_dict:
return (dec,)
return DecoderOutput(sample=dec)
def _temporal_tiled_encode(self, x: torch.Tensor) -> AutoencoderKLOutput:
_, _, num_frames, height, width = x.shape
latent_num_frames = (num_frames - 1) // self.temporal_compression_ratio + 1
tile_latent_min_num_frames = (
self.tile_sample_min_num_frames // self.temporal_compression_ratio
)
tile_latent_stride_num_frames = (
self.tile_sample_stride_num_frames // self.temporal_compression_ratio
)
blend_num_frames = tile_latent_min_num_frames - tile_latent_stride_num_frames
row = []
for i in range(
0,
num_frames - self.tile_sample_min_num_frames + 1,
self.tile_sample_stride_num_frames,
):
tile = x[:, :, i : i + self.tile_sample_min_num_frames + 1, :, :]
if self.use_tiling and (
height > self.tile_sample_min_height
or width > self.tile_sample_min_width
):
tile = self.tiled_encode(tile)
else:
tile = self.encoder(tile).clone()
tile = self.quant_conv(tile)
if i > 0:
tile = tile[:, :, 1:, :, :]
row.append(tile)
result_row = []
for i, tile in enumerate(row):
if i > 0:
tile = self.blend_t(row[i - 1], tile, blend_num_frames)
t_lim = (
tile_latent_min_num_frames
if i == len(row) - 1
else tile_latent_stride_num_frames
)
result_row.append(tile[:, :, :t_lim, :, :])
else:
result_row.append(tile[:, :, : tile_latent_stride_num_frames + 1, :, :])
enc = torch.cat(result_row, dim=2)[:, :, :latent_num_frames]
return enc
def _temporal_tiled_decode(
self, z: torch.Tensor, return_dict: bool = True
) -> Union[DecoderOutput, torch.Tensor]:
_, _, num_frames, _, _ = z.shape
num_sample_frames = (num_frames - 1) * self.temporal_compression_ratio + 1
tile_latent_min_height = (
self.tile_sample_min_height // self.spatial_compression_ratio
)
tile_latent_min_width = (
self.tile_sample_min_width // self.spatial_compression_ratio
)
tile_latent_min_num_frames = (
self.tile_sample_min_num_frames // self.temporal_compression_ratio
)
tile_latent_stride_num_frames = (
self.tile_sample_stride_num_frames // self.temporal_compression_ratio
)
blend_num_frames = (
self.tile_sample_min_num_frames - self.tile_sample_stride_num_frames
)
row = []
for i in range(
0,
num_frames - tile_latent_min_num_frames + 1,
tile_latent_stride_num_frames,
):
tile = z[:, :, i : i + tile_latent_min_num_frames + 1, :, :]
if self.use_tiling and (
tile.shape[-1] > tile_latent_min_width
or tile.shape[-2] > tile_latent_min_height
):
decoded = self.tiled_decode(tile, return_dict=True).sample
else:
tile = self.post_quant_conv(tile)
decoded = self.decoder(tile).clone()
if i > 0:
decoded = decoded[:, :, 1:, :, :]
row.append(decoded)
result_row = []
for i, tile in enumerate(row):
if i > 0:
tile = self.blend_t(row[i - 1], tile, blend_num_frames)
t_lim = (
self.tile_sample_min_num_frames
if i == len(row) - 1
else self.tile_sample_stride_num_frames
)
result_row.append(tile[:, :, :t_lim, :, :])
else:
result_row.append(
tile[:, :, : self.tile_sample_stride_num_frames + 1, :, :]
)
dec = torch.cat(result_row, dim=2)[:, :, :num_sample_frames]
if not return_dict:
return (dec,)
return DecoderOutput(sample=dec)
def forward(
self,
sample: torch.Tensor,
sample_posterior: bool = False,
return_dict: bool = True,
generator: Optional[torch.Generator] = None,
) -> Union[DecoderOutput, torch.Tensor]:
r"""
Args:
sample (`torch.Tensor`): Input sample.
sample_posterior (`bool`, *optional*, defaults to `False`):
Whether to sample from the posterior.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
"""
x = sample
posterior = self.encode(x).latent_dist
if sample_posterior:
z = posterior.sample(generator=generator)
else:
z = posterior.mode()
dec = self.decode(z, return_dict=return_dict)
return dec
def apply_tiling(
self, tile: Tuple[int, int, int, int], stride: Tuple[int, int, int]
):
"""Applies tiling."""
_, ft, ht, wt = tile
fs, hs, ws = stride
self.use_tiling = True
self.tile_sample_min_num_frames = ft - 1
self.tile_sample_stride_num_frames = fs
self.tile_sample_min_height = ht
self.tile_sample_min_width = wt
self.tile_sample_stride_height = hs
self.tile_sample_stride_width = ws
def get_enc_optimal_tiling(
self, shape: List[int]
) -> Tuple[Tuple[int, int, int, int], Tuple[int, int, int]]:
"""Returns optimal tiling for given shape."""
_, _, num_frames, height, width = shape
if (sqrt(height * width) < 450) and (num_frames <= 97):
ft, fs = num_frames, num_frames
else:
ft = OPT_TEMPORAL_TILING[num_frames][0]
fs = OPT_TEMPORAL_TILING[num_frames][1]
if sqrt(height * width) > 500:
ht = OPT_SPATIAL_TILING[height][0]
hs = OPT_SPATIAL_TILING[height][1]
wt = OPT_SPATIAL_TILING[width][0]
ws = OPT_SPATIAL_TILING[width][1]
else:
ht, hs, wt, ws = height, height, width, width
return (1, ft, ht, wt), (fs, hs, ws)
def get_dec_optimal_tiling(
self, shape: List[int]
) -> Tuple[Tuple[int, int, int, int], Tuple[int, int, int]]:
"""Returns optimal tiling for given shape."""
b, _, f, h, w = shape
enc_inp_shape = [b, 3, 4 * (f - 1) + 1, 8 * h, 8 * w]
return self.get_enc_optimal_tiling(enc_inp_shape)
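# Illustrative example of the tiling selection above (values follow from the
# lookup tables at the top of the file): for a 49-frame 768x1280 input,
#   vae.get_enc_optimal_tiling([1, 3, 49, 768, 1280])
#   # ((1, 17, 416, 672), (8, 352, 608))
# i.e. 17-frame, 416x672 tiles advanced by 8 frames and 352x608 pixels, while a
# small clip (sqrt(H * W) < 450 and at most 97 frames) is encoded as a single
# temporal tile.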
def build_vae(conf):
if conf.name == "hunyuan":
return AutoencoderKLHunyuanVideo.from_pretrained(
conf.checkpoint_path, subfolder="vae", torch_dtype=torch.float16
)
    else:
        raise ValueError(f"Unknown VAE name: {conf.name}")
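if __name__ == "__main__":
    # Minimal CPU smoke test (illustrative sketch, not the pretrained HunyuanVideo
    # VAE): a deliberately tiny, randomly initialised config so the shapes can be
    # checked quickly. Assumes a host where torch.compile can build its default
    # backend, since several forward methods above are compiled.
    vae = AutoencoderKLHunyuanVideo(
        block_out_channels=(32, 32, 32, 32),
        layers_per_block=1,
        latent_channels=4,
        norm_num_groups=8,
    )
    video = torch.randn(1, 3, 5, 64, 64)  # (batch, channels, frames, height, width)
    with torch.no_grad():
        latent = vae.encode(video).latent_dist.mode()
        recon = vae.decode(latent).sample
    print(latent.shape)  # torch.Size([1, 4, 2, 8, 8])
    print(recon.shape)   # torch.Size([1, 3, 5, 64, 64])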