diff --git a/__init__.py b/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/autoencoder_kl_3d.py b/autoencoder_kl_3d.py new file mode 100644 index 0000000000000000000000000000000000000000..6d82b75762c60d35d35a263caba21c71e39247a7 --- /dev/null +++ b/autoencoder_kl_3d.py @@ -0,0 +1,793 @@ +# Licensed under the TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://github.com/Tencent-Hunyuan/HunyuanImage-3.0/blob/main/LICENSE +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from dataclasses import dataclass +from typing import Tuple, Optional +import math +import random +import numpy as np +from einops import rearrange +import torch +from torch import Tensor, nn +import torch.nn.functional as F + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.models.modeling_outputs import AutoencoderKLOutput +from diffusers.models.modeling_utils import ModelMixin +from diffusers.utils.torch_utils import randn_tensor +from diffusers.utils import BaseOutput + + +class DiagonalGaussianDistribution(object): + def __init__(self, parameters: torch.Tensor, deterministic: bool = False): + if parameters.ndim == 3: + dim = 2 # (B, L, C) + elif parameters.ndim == 5 or parameters.ndim == 4: + dim = 1 # (B, C, T, H ,W) / (B, C, H, W) + else: + raise NotImplementedError + self.parameters = parameters + self.mean, self.logvar = torch.chunk(parameters, 2, dim=dim) + self.logvar = torch.clamp(self.logvar, -30.0, 20.0) + self.deterministic = deterministic + self.std = torch.exp(0.5 * self.logvar) + self.var = torch.exp(self.logvar) + if self.deterministic: + self.var = self.std = torch.zeros_like( + self.mean, device=self.parameters.device, dtype=self.parameters.dtype + ) + + def sample(self, generator: Optional[torch.Generator] = None) -> torch.FloatTensor: + # make sure sample is on the same device as the parameters and has same dtype + sample = randn_tensor( + self.mean.shape, + generator=generator, + device=self.parameters.device, + dtype=self.parameters.dtype, + ) + x = self.mean + self.std * sample + return x + + def kl(self, other: "DiagonalGaussianDistribution" = None) -> torch.Tensor: + if self.deterministic: + return torch.Tensor([0.0]) + else: + reduce_dim = list(range(1, self.mean.ndim)) + if other is None: + return 0.5 * torch.sum( + torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, + dim=reduce_dim, + ) + else: + return 0.5 * torch.sum( + torch.pow(self.mean - other.mean, 2) / other.var + + self.var / other.var - + 1.0 - + self.logvar + + other.logvar, + dim=reduce_dim, + ) + + def nll(self, sample: torch.Tensor, dims: Tuple[int, ...] 
= [1, 2, 3]) -> torch.Tensor: + if self.deterministic: + return torch.Tensor([0.0]) + logtwopi = np.log(2.0 * np.pi) + return 0.5 * torch.sum( + logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, + dim=dims, + ) + + def mode(self) -> torch.Tensor: + return self.mean + + +@dataclass +class DecoderOutput(BaseOutput): + sample: torch.FloatTensor + posterior: Optional[DiagonalGaussianDistribution] = None + + +def swish(x: Tensor) -> Tensor: + return x * torch.sigmoid(x) + + +def forward_with_checkpointing(module, *inputs, use_checkpointing=False): + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + return custom_forward + + if use_checkpointing: + return torch.utils.checkpoint.checkpoint(create_custom_forward(module), *inputs, use_reentrant=False) + else: + return module(*inputs) + + +class Conv3d(nn.Conv3d): + """ + Perform Conv3d on patches with numerical differences from nn.Conv3d within 1e-5. + Only symmetric padding is supported. + """ + + def forward(self, input): + B, C, T, H, W = input.shape + memory_count = (C * T * H * W) * 2 / 1024**3 + if memory_count > 2: + n_split = math.ceil(memory_count / 2) + assert n_split >= 2 + chunks = torch.chunk(input, chunks=n_split, dim=-3) + padded_chunks = [] + for i in range(len(chunks)): + if self.padding[0] > 0: + padded_chunk = F.pad( + chunks[i], + (0, 0, 0, 0, self.padding[0], self.padding[0]), + mode="constant" if self.padding_mode == "zeros" else self.padding_mode, + value=0, + ) + if i > 0: + padded_chunk[:, :, :self.padding[0]] = chunks[i - 1][:, :, -self.padding[0]:] + if i < len(chunks) - 1: + padded_chunk[:, :, -self.padding[0]:] = chunks[i + 1][:, :, :self.padding[0]] + else: + padded_chunk = chunks[i] + padded_chunks.append(padded_chunk) + padding_bak = self.padding + self.padding = (0, self.padding[1], self.padding[2]) + outputs = [] + for i in range(len(padded_chunks)): + outputs.append(super().forward(padded_chunks[i])) + self.padding = padding_bak + return torch.cat(outputs, dim=-3) + else: + return super().forward(input) + + +class AttnBlock(nn.Module): + """ Attention with torch sdpa implementation. 
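# Illustrative usage sketch for DiagonalGaussianDistribution defined above. The
# encoder's "moments" tensor stacks mean and logvar on the channel axis; the
# distribution splits them, samples via mean + std * eps, and exposes the KL term
# used by the VAE objective. The import path and the shapes are assumptions.
import torch
from autoencoder_kl_3d import DiagonalGaussianDistribution

moments = torch.randn(2, 8, 4, 16, 16)       # (B, 2 * z_channels, T, H, W)
posterior = DiagonalGaussianDistribution(moments)
z = posterior.sample()                       # (2, 4, 4, 16, 16)
kl = posterior.kl()                          # (2,), summed over all non-batch dims
z_det = posterior.mode()                     # deterministic latent (the mean)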
""" + def __init__(self, in_channels: int): + super().__init__() + self.in_channels = in_channels + + self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + + self.q = Conv3d(in_channels, in_channels, kernel_size=1) + self.k = Conv3d(in_channels, in_channels, kernel_size=1) + self.v = Conv3d(in_channels, in_channels, kernel_size=1) + self.proj_out = Conv3d(in_channels, in_channels, kernel_size=1) + + def attention(self, h_: Tensor) -> Tensor: + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + b, c, f, h, w = q.shape + q = rearrange(q, "b c f h w -> b 1 (f h w) c").contiguous() + k = rearrange(k, "b c f h w -> b 1 (f h w) c").contiguous() + v = rearrange(v, "b c f h w -> b 1 (f h w) c").contiguous() + h_ = nn.functional.scaled_dot_product_attention(q, k, v) + + return rearrange(h_, "b 1 (f h w) c -> b c f h w", f=f, h=h, w=w, c=c, b=b) + + def forward(self, x: Tensor) -> Tensor: + return x + self.proj_out(self.attention(x)) + + +class ResnetBlock(nn.Module): + def __init__(self, in_channels: int, out_channels: int): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + + self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + self.conv1 = Conv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True) + self.conv2 = Conv3d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) + if self.in_channels != self.out_channels: + self.nin_shortcut = Conv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, x): + h = x + h = self.norm1(h) + h = swish(h) + h = self.conv1(h) + + h = self.norm2(h) + h = swish(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + x = self.nin_shortcut(x) + return x + h + + +class Downsample(nn.Module): + def __init__(self, in_channels: int, add_temporal_downsample: bool = True): + super().__init__() + self.add_temporal_downsample = add_temporal_downsample + stride = (2, 2, 2) if add_temporal_downsample else (1, 2, 2) # THW + # no asymmetric padding in torch conv, must do it ourselves + self.conv = Conv3d(in_channels, in_channels, kernel_size=3, stride=stride, padding=0) + + def forward(self, x: Tensor): + spatial_pad = (0, 1, 0, 1, 0, 0) # WHT + x = nn.functional.pad(x, spatial_pad, mode="constant", value=0) + + temporal_pad = (0, 0, 0, 0, 0, 1) if self.add_temporal_downsample else (0, 0, 0, 0, 1, 1) + x = nn.functional.pad(x, temporal_pad, mode="replicate") + + x = self.conv(x) + return x + + +class DownsampleDCAE(nn.Module): + def __init__(self, in_channels: int, out_channels: int, add_temporal_downsample: bool = True): + super().__init__() + factor = 2 * 2 * 2 if add_temporal_downsample else 1 * 2 * 2 + assert out_channels % factor == 0 + self.conv = Conv3d(in_channels, out_channels // factor, kernel_size=3, stride=1, padding=1) + + self.add_temporal_downsample = add_temporal_downsample + self.group_size = factor * in_channels // out_channels + + def forward(self, x: Tensor): + r1 = 2 if self.add_temporal_downsample else 1 + h = self.conv(x) + h = rearrange(h, "b c (f r1) (h r2) (w r3) -> b (r1 r2 r3 c) f h w", r1=r1, r2=2, r3=2) + shortcut = rearrange(x, "b c (f r1) (h r2) (w r3) -> b (r1 r2 r3 c) f h w", r1=r1, r2=2, r3=2) + + B, C, T, H, W = shortcut.shape + shortcut = shortcut.view(B, h.shape[1], 
self.group_size, T, H, W).mean(dim=2) + return h + shortcut + + +class Upsample(nn.Module): + def __init__(self, in_channels: int, add_temporal_upsample: bool = True): + super().__init__() + self.add_temporal_upsample = add_temporal_upsample + self.scale_factor = (2, 2, 2) if add_temporal_upsample else (1, 2, 2) # THW + self.conv = Conv3d(in_channels, in_channels, kernel_size=3, stride=1, padding=1) + + def forward(self, x: Tensor): + x = nn.functional.interpolate(x, scale_factor=self.scale_factor, mode="nearest") + x = self.conv(x) + return x + + +class UpsampleDCAE(nn.Module): + def __init__(self, in_channels: int, out_channels: int, add_temporal_upsample: bool = True): + super().__init__() + factor = 2 * 2 * 2 if add_temporal_upsample else 1 * 2 * 2 + self.conv = Conv3d(in_channels, out_channels * factor, kernel_size=3, stride=1, padding=1) + + self.add_temporal_upsample = add_temporal_upsample + self.repeats = factor * out_channels // in_channels + + def forward(self, x: Tensor): + r1 = 2 if self.add_temporal_upsample else 1 + h = self.conv(x) + h = rearrange(h, "b (r1 r2 r3 c) f h w -> b c (f r1) (h r2) (w r3)", r1=r1, r2=2, r3=2) + shortcut = x.repeat_interleave(repeats=self.repeats, dim=1) + shortcut = rearrange(shortcut, "b (r1 r2 r3 c) f h w -> b c (f r1) (h r2) (w r3)", r1=r1, r2=2, r3=2) + return h + shortcut + + +class Encoder(nn.Module): + """ + The encoder network of AutoencoderKLConv3D. + """ + def __init__( + self, + in_channels: int, + z_channels: int, + block_out_channels: Tuple[int, ...], + num_res_blocks: int, + ffactor_spatial: int, + ffactor_temporal: int, + downsample_match_channel: bool = True, + ): + super().__init__() + assert block_out_channels[-1] % (2 * z_channels) == 0 + + self.z_channels = z_channels + self.block_out_channels = block_out_channels + self.num_res_blocks = num_res_blocks + + # downsampling + self.conv_in = Conv3d(in_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1) + + self.down = nn.ModuleList() + block_in = block_out_channels[0] + for i_level, ch in enumerate(block_out_channels): + block = nn.ModuleList() + block_out = ch + for _ in range(self.num_res_blocks): + block.append(ResnetBlock(in_channels=block_in, out_channels=block_out)) + block_in = block_out + down = nn.Module() + down.block = block + + add_spatial_downsample = bool(i_level < np.log2(ffactor_spatial)) + add_temporal_downsample = (add_spatial_downsample and + bool(i_level >= np.log2(ffactor_spatial // ffactor_temporal))) + if add_spatial_downsample or add_temporal_downsample: + assert i_level < len(block_out_channels) - 1 + block_out = block_out_channels[i_level + 1] if downsample_match_channel else block_in + down.downsample = DownsampleDCAE(block_in, block_out, add_temporal_downsample) + block_in = block_out + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in) + self.mid.attn_1 = AttnBlock(block_in) + self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in) + + # end + self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True) + self.conv_out = Conv3d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1) + + self.gradient_checkpointing = False + + def forward(self, x: Tensor) -> Tensor: + use_checkpointing = bool(self.training and self.gradient_checkpointing) + + # downsampling + h = self.conv_in(x) + for i_level in range(len(self.block_out_channels)): + for i_block in range(self.num_res_blocks): + h = 
forward_with_checkpointing( + self.down[i_level].block[i_block], h, use_checkpointing=use_checkpointing) + if hasattr(self.down[i_level], "downsample"): + h = forward_with_checkpointing(self.down[i_level].downsample, h, use_checkpointing=use_checkpointing) + + # middle + h = forward_with_checkpointing(self.mid.block_1, h, use_checkpointing=use_checkpointing) + h = forward_with_checkpointing(self.mid.attn_1, h, use_checkpointing=use_checkpointing) + h = forward_with_checkpointing(self.mid.block_2, h, use_checkpointing=use_checkpointing) + + # end + group_size = self.block_out_channels[-1] // (2 * self.z_channels) + shortcut = rearrange(h, "b (c r) f h w -> b c r f h w", r=group_size).mean(dim=2) + h = self.norm_out(h) + h = swish(h) + h = self.conv_out(h) + h += shortcut + return h + + +class Decoder(nn.Module): + """ + The decoder network of AutoencoderKLConv3D. + """ + def __init__( + self, + z_channels: int, + out_channels: int, + block_out_channels: Tuple[int, ...], + num_res_blocks: int, + ffactor_spatial: int, + ffactor_temporal: int, + upsample_match_channel: bool = True, + ): + super().__init__() + assert block_out_channels[0] % z_channels == 0 + + self.z_channels = z_channels + self.block_out_channels = block_out_channels + self.num_res_blocks = num_res_blocks + + # z to block_in + block_in = block_out_channels[0] + self.conv_in = Conv3d(z_channels, block_in, kernel_size=3, stride=1, padding=1) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in) + self.mid.attn_1 = AttnBlock(block_in) + self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in) + + # upsampling + self.up = nn.ModuleList() + for i_level, ch in enumerate(block_out_channels): + block = nn.ModuleList() + block_out = ch + for _ in range(self.num_res_blocks + 1): + block.append(ResnetBlock(in_channels=block_in, out_channels=block_out)) + block_in = block_out + up = nn.Module() + up.block = block + + add_spatial_upsample = bool(i_level < np.log2(ffactor_spatial)) + add_temporal_upsample = bool(i_level < np.log2(ffactor_temporal)) + if add_spatial_upsample or add_temporal_upsample: + assert i_level < len(block_out_channels) - 1 + block_out = block_out_channels[i_level + 1] if upsample_match_channel else block_in + up.upsample = UpsampleDCAE(block_in, block_out, add_temporal_upsample) + block_in = block_out + self.up.append(up) + + # end + self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True) + self.conv_out = Conv3d(block_in, out_channels, kernel_size=3, stride=1, padding=1) + + self.gradient_checkpointing = False + + def forward(self, z: Tensor) -> Tensor: + use_checkpointing = bool(self.training and self.gradient_checkpointing) + + # z to block_in + repeats = self.block_out_channels[0] // (self.z_channels) + h = self.conv_in(z) + z.repeat_interleave(repeats=repeats, dim=1) + + # middle + h = forward_with_checkpointing(self.mid.block_1, h, use_checkpointing=use_checkpointing) + h = forward_with_checkpointing(self.mid.attn_1, h, use_checkpointing=use_checkpointing) + h = forward_with_checkpointing(self.mid.block_2, h, use_checkpointing=use_checkpointing) + + # upsampling + for i_level in range(len(self.block_out_channels)): + for i_block in range(self.num_res_blocks + 1): + h = forward_with_checkpointing(self.up[i_level].block[i_block], h, use_checkpointing=use_checkpointing) + if hasattr(self.up[i_level], "upsample"): + h = forward_with_checkpointing(self.up[i_level].upsample, h, 
use_checkpointing=use_checkpointing) + + # end + h = self.norm_out(h) + h = swish(h) + h = self.conv_out(h) + return h + + +class AutoencoderKLConv3D(ModelMixin, ConfigMixin): + """ + Autoencoder model with KL-regularized latent space based on 3D convolutions. + """ + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + in_channels: int, + out_channels: int, + latent_channels: int, + block_out_channels: Tuple[int, ...], + layers_per_block: int, + ffactor_spatial: int, + ffactor_temporal: int, + sample_size: int, + sample_tsize: int, + scaling_factor: float = None, + shift_factor: Optional[float] = None, + downsample_match_channel: bool = True, + upsample_match_channel: bool = True, + only_encoder: bool = False, # only build encoder for saving memory + only_decoder: bool = False, # only build decoder for saving memory + ): + super().__init__() + self.ffactor_spatial = ffactor_spatial + self.ffactor_temporal = ffactor_temporal + self.scaling_factor = scaling_factor + self.shift_factor = shift_factor + + # build model + if not only_decoder: + self.encoder = Encoder( + in_channels=in_channels, + z_channels=latent_channels, + block_out_channels=block_out_channels, + num_res_blocks=layers_per_block, + ffactor_spatial=ffactor_spatial, + ffactor_temporal=ffactor_temporal, + downsample_match_channel=downsample_match_channel, + ) + if not only_encoder: + self.decoder = Decoder( + z_channels=latent_channels, + out_channels=out_channels, + block_out_channels=list(reversed(block_out_channels)), + num_res_blocks=layers_per_block, + ffactor_spatial=ffactor_spatial, + ffactor_temporal=ffactor_temporal, + upsample_match_channel=upsample_match_channel, + ) + + # slicing and tiling related + self.use_slicing = False + self.slicing_bsz = 1 + self.use_spatial_tiling = False + self.use_temporal_tiling = False + self.use_tiling_during_training = False + + # only relevant if vae tiling is enabled + self.tile_sample_min_size = sample_size + self.tile_latent_min_size = sample_size // ffactor_spatial + self.tile_sample_min_tsize = sample_tsize + self.tile_latent_min_tsize = sample_tsize // ffactor_temporal + self.tile_overlap_factor = 0.25 + + # use torch.compile for faster encode speed + self.use_compile = False + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, (Encoder, Decoder)): + module.gradient_checkpointing = value + + def enable_tiling_during_training(self, use_tiling: bool = True): + self.use_tiling_during_training = use_tiling + + def disable_tiling_during_training(self): + self.enable_tiling_during_training(False) + + def enable_temporal_tiling(self, use_tiling: bool = True): + self.use_temporal_tiling = use_tiling + + def disable_temporal_tiling(self): + self.enable_temporal_tiling(False) + + def enable_spatial_tiling(self, use_tiling: bool = True): + self.use_spatial_tiling = use_tiling + + def disable_spatial_tiling(self): + self.enable_spatial_tiling(False) + + def enable_tiling(self, use_tiling: bool = True): + self.enable_spatial_tiling(use_tiling) + + def disable_tiling(self): + self.disable_spatial_tiling() + + def enable_slicing(self): + self.use_slicing = True + + def disable_slicing(self): + self.use_slicing = False + + def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int): + blend_extent = min(a.shape[-1], b.shape[-1], blend_extent) + for x in range(blend_extent): + b[:, :, :, :, x] = \ + a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (x / blend_extent) + return b + + def 
blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int): + blend_extent = min(a.shape[-2], b.shape[-2], blend_extent) + for y in range(blend_extent): + b[:, :, :, y, :] = \ + a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (y / blend_extent) + return b + + def blend_t(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int): + blend_extent = min(a.shape[-3], b.shape[-3], blend_extent) + for x in range(blend_extent): + b[:, :, x, :, :] = \ + a[:, :, -blend_extent + x, :, :] * (1 - x / blend_extent) + b[:, :, x, :, :] * (x / blend_extent) + return b + + def spatial_tiled_encode(self, x: torch.Tensor): + """ spatial tailing for frames """ + B, C, T, H, W = x.shape + overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor)) # 256 * (1 - 0.25) = 192 + blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor) # 8 * 0.25 = 2 + row_limit = self.tile_latent_min_size - blend_extent # 8 - 2 = 6 + + rows = [] + for i in range(0, H, overlap_size): + row = [] + for j in range(0, W, overlap_size): + tile = x[:, :, :, i: i + self.tile_sample_min_size, j: j + self.tile_sample_min_size] + tile = self.encoder(tile) + row.append(tile) + rows.append(row) + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_extent) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_extent) + result_row.append(tile[:, :, :, :row_limit, :row_limit]) + result_rows.append(torch.cat(result_row, dim=-1)) + moments = torch.cat(result_rows, dim=-2) + return moments + + def temporal_tiled_encode(self, x: torch.Tensor): + """ temporal tailing for frames """ + B, C, T, H, W = x.shape + overlap_size = int(self.tile_sample_min_tsize * (1 - self.tile_overlap_factor)) # 64 * (1 - 0.25) = 48 + blend_extent = int(self.tile_latent_min_tsize * self.tile_overlap_factor) # 8 * 0.25 = 2 + t_limit = self.tile_latent_min_tsize - blend_extent # 8 - 2 = 6 + + row = [] + for i in range(0, T, overlap_size): + tile = x[:, :, i: i + self.tile_sample_min_tsize, :, :] + if self.use_spatial_tiling and ( + tile.shape[-1] > self.tile_sample_min_size or tile.shape[-2] > self.tile_sample_min_size): + tile = self.spatial_tiled_encode(tile) + else: + tile = self.encoder(tile) + row.append(tile) + result_row = [] + for i, tile in enumerate(row): + if i > 0: + tile = self.blend_t(row[i - 1], tile, blend_extent) + result_row.append(tile[:, :, :t_limit, :, :]) + moments = torch.cat(result_row, dim=-3) + return moments + + def spatial_tiled_decode(self, z: torch.Tensor): + """ spatial tailing for frames """ + B, C, T, H, W = z.shape + overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor)) # 8 * (1 - 0.25) = 6 + blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor) # 256 * 0.25 = 64 + row_limit = self.tile_sample_min_size - blend_extent # 256 - 64 = 192 + + rows = [] + for i in range(0, H, overlap_size): + row = [] + for j in range(0, W, overlap_size): + tile = z[:, :, :, i: i + self.tile_latent_min_size, j: j + self.tile_latent_min_size] + decoded = self.decoder(tile) + row.append(decoded) + rows.append(row) + + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_extent) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_extent) + result_row.append(tile[:, :, :, :row_limit, :row_limit]) + 
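# Standalone sketch of the linear crossfade performed by blend_h/blend_v/blend_t:
# over the first `blend_extent` positions of the right-hand tile, the weight ramps
# linearly from the left tile to the right tile, hiding seams between overlapping
# tiles. The tensors below are toy values chosen only to make the ramp visible.
import torch

blend_extent = 4
a = torch.zeros(1, 1, 1, 2, 8)    # "left" tile, (B, C, T, H, W)
b = torch.ones(1, 1, 1, 2, 8)     # "right" tile
for x in range(blend_extent):
    w = x / blend_extent
    b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - w) + b[:, :, :, :, x] * w
print(b[0, 0, 0, 0, :blend_extent])   # tensor([0.0000, 0.2500, 0.5000, 0.7500])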
result_rows.append(torch.cat(result_row, dim=-1)) + dec = torch.cat(result_rows, dim=-2) + return dec + + def temporal_tiled_decode(self, z: torch.Tensor): + """ temporal tailing for frames """ + B, C, T, H, W = z.shape + overlap_size = int(self.tile_latent_min_tsize * (1 - self.tile_overlap_factor)) # 8 * (1 - 0.25) = 6 + blend_extent = int(self.tile_sample_min_tsize * self.tile_overlap_factor) # 64 * 0.25 = 16 + t_limit = self.tile_sample_min_tsize - blend_extent # 64 - 16 = 48 + assert 0 < overlap_size < self.tile_latent_min_tsize + + row = [] + for i in range(0, T, overlap_size): + tile = z[:, :, i: i + self.tile_latent_min_tsize, :, :] + if self.use_spatial_tiling and ( + tile.shape[-1] > self.tile_latent_min_size or tile.shape[-2] > self.tile_latent_min_size): + decoded = self.spatial_tiled_decode(tile) + else: + decoded = self.decoder(tile) + row.append(decoded) + + result_row = [] + for i, tile in enumerate(row): + if i > 0: + tile = self.blend_t(row[i - 1], tile, blend_extent) + result_row.append(tile[:, :, :t_limit, :, :]) + dec = torch.cat(result_row, dim=-3) + return dec + + def encode(self, x: Tensor, return_dict: bool = True): + """ + Encodes the input by passing through the encoder network. + Support slicing and tiling for memory efficiency. + """ + def _encode(x): + if self.use_temporal_tiling and x.shape[-3] > self.tile_sample_min_tsize: + return self.temporal_tiled_encode(x) + if self.use_spatial_tiling and ( + x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size): + return self.spatial_tiled_encode(x) + + if self.use_compile: + @torch.compile + def encoder(x): + return self.encoder(x) + return encoder(x) + return self.encoder(x) + + if len(x.shape) != 5: # (B, C, T, H, W) + x = x[:, :, None] + assert len(x.shape) == 5 # (B, C, T, H, W) + if x.shape[2] == 1: + x = x.expand(-1, -1, self.ffactor_temporal, -1, -1) + else: + assert x.shape[2] != self.ffactor_temporal and x.shape[2] % self.ffactor_temporal == 0 + + if self.use_slicing and x.shape[0] > 1: + if self.slicing_bsz == 1: + encoded_slices = [_encode(x_slice) for x_slice in x.split(1)] + else: + sections = [self.slicing_bsz] * (x.shape[0] // self.slicing_bsz) + if x.shape[0] % self.slicing_bsz != 0: + sections.append(x.shape[0] % self.slicing_bsz) + encoded_slices = [_encode(x_slice) for x_slice in x.split(sections)] + h = torch.cat(encoded_slices) + else: + h = _encode(x) + posterior = DiagonalGaussianDistribution(h) + + if not return_dict: + return (posterior,) + + return AutoencoderKLOutput(latent_dist=posterior) + + def decode(self, z: Tensor, return_dict: bool = True, generator=None): + """ + Decodes the input by passing through the decoder network. + Support slicing and tiling for memory efficiency. 
+ """ + def _decode(z): + if self.use_temporal_tiling and z.shape[-3] > self.tile_latent_min_tsize: + return self.temporal_tiled_decode(z) + if self.use_spatial_tiling and ( + z.shape[-1] > self.tile_latent_min_size or z.shape[-2] > self.tile_latent_min_size): + return self.spatial_tiled_decode(z) + return self.decoder(z) + + if self.use_slicing and z.shape[0] > 1: + decoded_slices = [_decode(z_slice) for z_slice in z.split(1)] + decoded = torch.cat(decoded_slices) + else: + decoded = _decode(z) + + if z.shape[-3] == 1: + decoded = decoded[:, :, -1:] + + if not return_dict: + return (decoded,) + + return DecoderOutput(sample=decoded) + + def forward( + self, + sample: torch.Tensor, + sample_posterior: bool = False, + return_posterior: bool = True, + return_dict: bool = True + ): + posterior = self.encode(sample).latent_dist + z = posterior.sample() if sample_posterior else posterior.mode() + dec = self.decode(z).sample + return DecoderOutput(sample=dec, posterior=posterior) if return_dict else (dec, posterior) + + def random_reset_tiling(self, x: torch.Tensor): + if x.shape[-3] == 1: + self.disable_spatial_tiling() + self.disable_temporal_tiling() + return + + # Use fixed shape here + min_sample_size = int(1 / self.tile_overlap_factor) * self.ffactor_spatial + min_sample_tsize = int(1 / self.tile_overlap_factor) * self.ffactor_temporal + sample_size = random.choice([None, 1 * min_sample_size, 2 * min_sample_size, 3 * min_sample_size]) + if sample_size is None: + self.disable_spatial_tiling() + else: + self.tile_sample_min_size = sample_size + self.tile_latent_min_size = sample_size // self.ffactor_spatial + self.enable_spatial_tiling() + + sample_tsize = random.choice([None, 1 * min_sample_tsize, 2 * min_sample_tsize, 3 * min_sample_tsize]) + if sample_tsize is None: + self.disable_temporal_tiling() + else: + self.tile_sample_min_tsize = sample_tsize + self.tile_latent_min_tsize = sample_tsize // self.ffactor_temporal + self.enable_temporal_tiling() diff --git a/configuration_hunyuan.py b/configuration_hunyuan.py new file mode 100644 index 0000000000000000000000000000000000000000..e2c8acfa7ebe7b934c01ff3f1f37f19971762f61 --- /dev/null +++ b/configuration_hunyuan.py @@ -0,0 +1,285 @@ +# Licensed under the TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://github.com/Tencent-Hunyuan/HunyuanImage-3.0/blob/main/LICENSE +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import logging +from typing import List, Union + + +logger = logging.get_logger(__name__) + + +class HunyuanImage3Config(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`HunyuanImage3Model`]. It is used to instantiate + an Hunyuan model according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the Hunyuan-7B. 
+ + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 32000): + Vocabulary size of the Hunyuan Image 3 model. Defines the number of different tokens that can be + represented by the `inputs_ids` passed when calling [`HunyuanImage3Model`] + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 11008): + Dimension of the MLP representations or shared MLP representations. + moe_intermediate_size (`int` or `List`, *optional*, defaults to 11008): + Dimension of the MLP representations in MoE. Use a list if you want a different size per layer. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer decoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer decoder. + num_key_value_heads (`int`, *optional*): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to + `num_attention_heads`. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 2048): + The maximum sequence length that this model might ever be used with. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + pad_token_id (`int`, *optional*): + Padding token id. + bos_token_id (`int`, *optional*, defaults to 1): + Beginning of stream token id. + eos_token_id (`int`, *optional*, defaults to 2): + End of stream token id. + pretraining_tp (`int`, *optional*, defaults to 1): + Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this + document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value is + necessary to ensure exact reproducibility of the pretraining results. Please refer to [this + issue](https://github.com/pytorch/pytorch/issues/76232). + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether to tie weight embeddings + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. + rope_scaling (`Dict`, *optional*): + Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling + strategies: linear and dynamic. Their scaling factor must be a float greater than 1. 
The expected format is + `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update + `max_position_embeddings` to the expected new maximum. See the following thread for more information on how + these scaling strategies behave: + https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an + experimental feature, subject to breaking API changes in future versions. + attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`): + Whether to use a bias in the query, key, value and output projection layers during self-attention. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + use_qk_norm (`bool`, *optional*, defaults to `False`): + Whether query and key in attention use norm + use_cla (`bool`, *optional*, defaults to `False`): + Whether to use CLA in attention + cla_share_factor (`int`, *optional*, defaults to 1): + The share factor of CLA + num_experts (`int` or `List`, *optional*, defaults to 1): + The number of experts for moe. If it is a list, it will be used as the number of experts for each layer. + num_shared_expert (`int` or `List`, *optional*, defaults to 1): + The number of shared experts for moe. If it is a list, it will be used as the number of shared experts + for each layer. + moe_topk (`int` or `List`, *optional*, defaults to 1): + The topk value for moe. If it is a list, it will be used as the topk value for each layer. + capacity_factor (Not used) (`float` or `List`, *optional*, defaults to 1.0): + The capacity factor for moe. If it is a list, it will be used as the capacity factor for each layer. + moe_layer_num_skipped (`int`, *optional*, defaults to 0): + First moe_layer_num_skipped layers do not use MoE. 
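# Construction sketch for HunyuanImage3Config (the import path is an assumption).
# Every argument has a default, so only overridden fields need to be passed, and the
# MoE-related options accept either a scalar or a per-layer list as documented above.
# The values are examples, not a released configuration.
from configuration_hunyuan import HunyuanImage3Config

cfg = HunyuanImage3Config(
    hidden_size=1024,
    num_attention_heads=16,          # attention_head_dim defaults to 1024 // 16 = 64
    num_key_value_heads=8,           # GQA: fewer KV heads than query heads
    num_experts=8,
    moe_topk=2,
    moe_intermediate_size=[2048] * 32,   # one entry per hidden layer (32 by default)
)
print(cfg.attention_head_dim, cfg.model_type)    # 64 Hunyuan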
+ """ + + model_type = "Hunyuan" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=290943, + hidden_size=4096, + intermediate_size: int=11008, + moe_intermediate_size: Union[int, List]=None, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=None, + attention_head_dim=None, + hidden_act="silu", + max_position_embeddings=2048, + initializer_range=0.02, + rms_norm_eps=1e-5, + use_cache=True, + pad_token_id=0, + bos_token_id=1, + eos_token_id=2, + eod_token_id=3, + im_start_id=4, + im_end_id=5, + text_start_id=6, + text_end_id=7, + image_token_id=8, + video_start_id=9, + video_end_id=10, + im_newline_id=11, + mask_init_id=12, + pretraining_tp=1, + tie_word_embeddings=False, + rope_theta=10000.0, + rope_scaling=None, + attention_bias=False, + mlp_bias=False, + attention_dropout=0.0, + use_qk_norm=False, + use_rotary_pos_emb=True, + use_cla=False, + cla_share_factor=1, + norm_type="hf_rms", + num_experts: Union[int, List] = 1, + use_mixed_mlp_moe=False, + num_shared_expert: Union[int, List] = 1, + moe_topk: Union[int, List] = 1, + capacity_factor: int = 1.0, + moe_drop_tokens=False, + moe_random_routing_dropped_token=False, + use_mla=False, + kv_lora_rank=512, + q_lora_rank=1536, + qk_rope_head_dim=64, + v_head_dim=128, + qk_nope_head_dim=128, + moe_layer_num_skipped=0, + norm_topk_prob=True, + routed_scaling_factor=1.0, + group_limited_greedy=False, + n_group=None, + topk_group=None, + add_classification_head=False, + class_num=0, + pool_type="last", + pad_id=-1, + # Added + moe_impl="eager", + vae_downsample_factor=(16, 16), # (h, w) + img_proj_type="unet", + patch_size=1, + patch_embed_hidden_dim=1024, + image_base_size=1024, + vae=None, + vit=None, + vit_processor=None, + vit_aligner=None, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.moe_intermediate_size = moe_intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.moe_impl = moe_impl + self.num_experts = num_experts + self.use_mixed_mlp_moe = use_mixed_mlp_moe + self.num_shared_expert = num_shared_expert + self.moe_topk = moe_topk + self.capacity_factor = capacity_factor + self.moe_drop_tokens = moe_drop_tokens + self.moe_random_routing_dropped_token = moe_random_routing_dropped_token + + if attention_head_dim is not None: + self.attention_head_dim = attention_head_dim + else: + self.attention_head_dim = self.hidden_size // num_attention_heads + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.pretraining_tp = pretraining_tp + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self.attention_bias = attention_bias + self.mlp_bias = mlp_bias + self.attention_dropout = attention_dropout + self.use_qk_norm = use_qk_norm + self.use_rotary_pos_emb = use_rotary_pos_emb + self.use_cla = use_cla + self.cla_share_factor = cla_share_factor + self.norm_type = norm_type + # MLA args + self.use_mla = use_mla + self.kv_lora_rank = kv_lora_rank + self.q_lora_rank = q_lora_rank + self.qk_rope_head_dim = qk_rope_head_dim + self.qk_nope_head_dim = qk_nope_head_dim + self.v_head_dim = v_head_dim + + # DeepSeek related args + 
self.moe_layer_num_skipped = moe_layer_num_skipped + self.norm_topk_prob = norm_topk_prob + self.routed_scaling_factor = routed_scaling_factor + self.group_limited_greedy = group_limited_greedy + self.n_group = n_group + self.topk_group = topk_group + self.add_classification_head = add_classification_head + self.class_num = class_num + self.pool_type = pool_type + self.pad_id = pad_id + + if self.class_num is not None: + self.dense_list = [self.hidden_size, self.class_num] + + # ViT args + self.vit = vit + self.vit_processor = vit_processor + self.vit_aligner = vit_aligner + + # Image Gen args + self.vae = vae + self.vae_downsample_factor = vae_downsample_factor + self.img_proj_type = img_proj_type + self.patch_size = patch_size + self.patch_embed_hidden_dim = patch_embed_hidden_dim + self.image_base_size = image_base_size + + # token id + self.eod_token_id = eod_token_id + self.im_start_id = im_start_id + self.im_end_id = im_end_id + self.text_start_id = text_start_id + self.text_end_id = text_end_id + self.image_token_id = image_token_id + self.video_start_id = video_start_id + self.video_end_id = video_end_id + self.im_newline_id = im_newline_id + self.mask_init_id = mask_init_id + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) diff --git a/hunyuan.py b/hunyuan.py new file mode 100644 index 0000000000000000000000000000000000000000..ec823cb9977202dcb3f02679b7e2b70464d8f686 --- /dev/null +++ b/hunyuan.py @@ -0,0 +1,2654 @@ +# Licensed under the TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://github.com/Tencent-Hunyuan/HunyuanImage-3.0/blob/main/LICENSE +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
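# Serialization sketch: the custom attributes set in __init__ above (special token
# ids, VAE/ViT sub-configs, MoE options) are plain config attributes, so they survive
# a PretrainedConfig to_dict/from_dict round trip. The import path is an assumption.
from configuration_hunyuan import HunyuanImage3Config

cfg = HunyuanImage3Config(image_token_id=8, vae_downsample_factor=(16, 16))
restored = HunyuanImage3Config.from_dict(cfg.to_dict())
assert restored.image_token_id == 8
assert tuple(restored.vae_downsample_factor) == (16, 16)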
+# ============================================================================== + +import math +import random +import re +import warnings +from dataclasses import dataclass +from typing import TYPE_CHECKING, List, Union, Optional, Dict, Any, Tuple, Callable + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from einops import rearrange +from torch import Tensor +from torch import nn +from torch.cuda import nvtx +from transformers.activations import ACT2FN +from transformers.cache_utils import Cache, StaticCache +from transformers.generation.logits_process import LogitsProcessorList +from transformers.generation.stopping_criteria import StoppingCriteriaList +from transformers.generation.utils import GenerationMixin, GenerationConfig, ALL_CACHE_NAMES +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, +) +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import ( + ModelOutput, + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_flash_attn_2_available, + logging, +) + +if TYPE_CHECKING: + from transformers.generation.streamers import BaseStreamer + +try: + import flashinfer +except Exception as e: + flashinfer = None + +from .autoencoder_kl_3d import AutoencoderKLConv3D +from .configuration_hunyuan import HunyuanImage3Config +from .hunyuan_image_3_pipeline import HunyuanImage3Text2ImagePipeline, FlowMatchDiscreteScheduler +from .image_processor import HunyuanImage3ImageProcessor +from .siglip2 import Siglip2VisionTransformer, LightProjector +from .tokenizer_wrapper import TokenizerWrapper, ImageInfo, JointImageInfo +from .system_prompt import get_system_prompt, t2i_system_prompts + + +logger = logging.get_logger(__name__) + + +if is_flash_attn_2_available(): + from flash_attn import flash_attn_func + +# Type aliases +BatchRaggedImages = Union[torch.Tensor, List[Union[torch.Tensor, List[torch.Tensor]]]] +BatchRaggedTensor = Union[torch.Tensor, List[torch.Tensor]] + + +_CONFIG_FOR_DOC = "HunyuanImage3Config" + +Hunyuan_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`HunyuanImage3Config`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +# ======================================================= +# Helper Functions +# ======================================================= + +def default(val, d): + return val if val is not None else d + + +def to_device(data, device): + if device is None: + return data + if isinstance(data, torch.Tensor): + return data.to(device) + elif isinstance(data, list): + return [to_device(x, device) for x in data] + else: + return data + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
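# Standalone check of the equivalence stated in the repeat_kv docstring above: the
# expand + reshape trick matches torch.repeat_interleave along the head dimension.
# Shapes are arbitrary example values.
import torch

b, kv_heads, n_rep, s, d = 2, 4, 3, 5, 8
x = torch.randn(b, kv_heads, s, d)
expanded = x[:, :, None, :, :].expand(b, kv_heads, n_rep, s, d)
assert torch.equal(expanded.reshape(b, kv_heads * n_rep, s, d),
                   torch.repeat_interleave(x, repeats=n_rep, dim=1))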
The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def real_batched_index_select(t, dim, idx): + """ index_select for batched index and batched t """ + assert t.ndim >= 2 and idx.ndim >= 2, f"{t.ndim=} {idx.ndim=}" + assert len(t) == len(idx), f"{len(t)=} != {len(idx)=}" + return torch.stack([torch.index_select(t[i], dim - 1, idx[i]) for i in range(len(t))]) + + +# ======================================================= +# Module Functions +# ======================================================= + +def timestep_embedding(t, dim, max_period=10000): + """ + Create sinusoidal timestep embeddings. + + Args: + t (torch.Tensor): a 1-D Tensor of N indices, one per batch element. These may be fractional. + dim (int): the dimension of the output. + max_period (int): controls the minimum frequency of the embeddings. + + Returns: + embedding (torch.Tensor): An (N, D) Tensor of positional embeddings. + + .. ref_link: https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py + """ + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) + * torch.arange(start=0, end=half, dtype=torch.float32) + / half + ).to(device=t.device) + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat( + [embedding, torch.zeros_like(embedding[:, :1])], dim=-1 + ) + return embedding + + +def conv_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D convolution module. + """ + if dims == 1: + return nn.Conv1d(*args, **kwargs) + elif dims == 2: + return nn.Conv2d(*args, **kwargs) + elif dims == 3: + return nn.Conv3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +def linear(*args, **kwargs): + """ + Create a linear module. + """ + return nn.Linear(*args, **kwargs) + + +def avg_pool_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D average pooling module. + """ + if dims == 1: + return nn.AvgPool1d(*args, **kwargs) + elif dims == 2: + return nn.AvgPool2d(*args, **kwargs) + elif dims == 3: + return nn.AvgPool3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def normalization(channels, **kwargs): + """ + Make a standard normalization layer. + + :param channels: number of input channels. + :return: a nn.Module for normalization. 
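# Standalone restatement of the sinusoidal formula implemented by timestep_embedding
# above, for a 256-dim embedding of three (possibly fractional) timesteps. The first
# half of the output holds cosines, the second half sines; an odd `dim` is zero-padded
# by one column.
import math
import torch

dim, max_period = 256, 10000
half = dim // 2
freqs = torch.exp(-math.log(max_period) * torch.arange(half, dtype=torch.float32) / half)
t = torch.tensor([0.0, 250.5, 999.0])
emb = torch.cat([torch.cos(t[:, None] * freqs[None]),
                 torch.sin(t[:, None] * freqs[None])], dim=-1)   # shape (3, 256)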
+ """ + return nn.GroupNorm(32, channels, **kwargs) + + +def topkgating( + logits: Tensor, + topk: int, + group_limited_greedy: bool = False, + n_group: int = None, + topk_group: int = None, + norm_topk_prob: bool = True, + routed_scaling_factor: float = 1.0, + capacity_factor: float = 1.0, + drop_tokens: bool = False, +): + logits = logits.float() + gates = F.softmax(logits, dim=1) + + if group_limited_greedy: + group_shape = list(gates.shape[:-1]) + [n_group, gates.shape[-1] // n_group] + group_scores = ( + gates.reshape(group_shape).max(dim=-1).values + ) # [n, n_group] + group_idx = torch.topk( + group_scores, topk_group, dim=-1, sorted=False + )[ + 1 + ] # [n, top_k_group] + group_mask = torch.zeros_like(group_scores) # [n, n_group] + group_mask.scatter_(1, group_idx, 1) # [n, n_group] + score_mask = ( + group_mask.unsqueeze(-1) + .expand( + group_shape + ) + .reshape(list(gates.shape)) + ) # [n, e] + gates = gates.masked_fill(~score_mask.bool(), 0.0) + + num_experts = int(gates.shape[1]) + # Top-k router probability and corresponding expert indices for each token. + # Shape: [tokens_per_group, num_selected_experts]. + expert_gate, expert_index = torch.topk(gates, topk) + expert_mask = F.one_hot(expert_index, num_experts) + # For a given token, determine if it was routed to a given expert. + # Shape: [tokens_per_group, num_experts] + expert_mask_aux = expert_mask.max(dim=-2)[0] + tokens_per_group_and_expert = torch.mean(expert_mask_aux.float(), dim=-2) + router_prob_per_group_and_expert = torch.mean(gates.float(), dim=-2) + l_aux = num_experts ** 2 * torch.mean(tokens_per_group_and_expert * router_prob_per_group_and_expert) + + if drop_tokens: + expert_capacity = int(max(topk, topk * gates.shape[0] // gates.shape[1]) * capacity_factor) + else: + expert_index_flat = expert_index.flatten() + tokens_per_expert = torch.bincount(expert_index_flat, minlength=num_experts) + expert_capacity = torch.max(tokens_per_expert).item() + + if norm_topk_prob and topk > 1: + gates_s = torch.clamp( + torch.matmul(expert_mask.float(), gates.unsqueeze(-1)).sum(dim=1), min=torch.finfo(gates.dtype).eps + ) + router_probs = gates / gates_s + else: + router_probs = gates * routed_scaling_factor + # Make num_selected_experts the leading axis to ensure that top-1 choices + # have priority over top-2 choices, which have priority over top-3 choices, + # etc. + expert_index = torch.transpose(expert_index, 0, 1) + # Shape: [num_selected_experts * tokens_per_group] + expert_index = expert_index.reshape(-1) + + # Create mask out of indices. + # Shape: [tokens_per_group * num_selected_experts, num_experts]. + expert_mask = F.one_hot(expert_index, num_experts).to(torch.int32) + exp_counts = torch.sum(expert_mask, dim=0).detach() + + # Experts have a fixed capacity that we cannot exceed. A token's priority + # within the expert's buffer is given by the masked, cumulative capacity of + # its target expert. + # Shape: [tokens_per_group * num_selected_experts, num_experts]. + token_priority = torch.cumsum(expert_mask, dim=0) * expert_mask - 1 + # Shape: [num_selected_experts, tokens_per_group, num_experts]. + token_priority = token_priority.reshape((topk, -1, num_experts)) + # Shape: [tokens_per_group, num_selected_experts, num_experts]. + token_priority = torch.transpose(token_priority, 0, 1) + # For each token, across all selected experts, select the only non-negative + # (unmasked) priority. Now, for group G routing to expert E, token T has + # non-negative priority (i.e. 
token_priority[G,T,E] >= 0) if and only if E + # is its targeted expert. + # Shape: [tokens_per_group, num_experts]. + token_priority = torch.max(token_priority, dim=1)[0] + + # Token T can only be routed to expert E if its priority is positive and + # less than the expert capacity. One-hot matrix will ignore indices outside + # the range [0, expert_capacity). + # Shape: [tokens_per_group, num_experts, expert_capacity]. + valid_mask = torch.logical_and(token_priority >= 0, token_priority < expert_capacity) + token_priority = torch.masked_fill(token_priority, ~valid_mask, 0) + dispatch_mask = F.one_hot(token_priority, expert_capacity).to(torch.bool) + valid_mask = valid_mask.unsqueeze(-1).expand(-1, -1, expert_capacity) + dispatch_mask = torch.masked_fill(dispatch_mask, ~valid_mask, 0) + + # The combine array will be used for combining expert outputs, scaled by the + # router probabilities. Shape: [num_groups, tokens_per_group, num_experts, + # expert_capacity]. + combine_weights = torch.einsum("...te,...tec->...tec", router_probs, dispatch_mask) + exp_counts_capacity = torch.sum(dispatch_mask) + exp_capacity_rate = exp_counts_capacity / (logits.shape[0] * topk) + + return [l_aux, exp_capacity_rate], combine_weights, dispatch_mask, exp_counts + + +# ======================================================= +# Multi-Dimensional RoPE +# ======================================================= + +def _to_tuple(x, dim=2): + if isinstance(x, int): + return (x,) * dim + elif len(x) == dim: + return x + else: + raise ValueError(f"Expected length {dim} or int, but got {x}") + + +def get_meshgrid_nd(start, *args, dim=2): + """ + Get n-D meshgrid with start, stop and num. + + Args: + start (int or tuple): If len(args) == 0, start is num; If len(args) == 1, start is start, args[0] is stop, + step is 1; If len(args) == 2, start is start, args[0] is stop, args[1] is num. For n-dim, start/stop/num + should be int or n-tuple. If n-tuple is provided, the meshgrid will be stacked following the dim order in + n-tuples. + *args: See above. + dim (int): Dimension of the meshgrid. Defaults to 2. + + Returns: + grid (np.ndarray): [dim, ...] 
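# Usage sketch for topkgating above (the import assumes this file is loadable as the
# `hunyuan` module of the package). With the default drop_tokens=False, the expert
# capacity is taken from the busiest expert, so no token is dropped.
import torch
from hunyuan import topkgating

logits = torch.randn(16, 4)        # 16 tokens routed over 4 experts
(l_aux, cap_rate), combine_weights, dispatch_mask, exp_counts = topkgating(logits, topk=2)
# combine_weights: (16, 4, expert_capacity) router probabilities at dispatch slots
# dispatch_mask:   (16, 4, expert_capacity) boolean routing assignment
# exp_counts:      (4,) number of top-k assignments received per expert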
+ """ + if len(args) == 0: + # start is grid_size + num = _to_tuple(start, dim=dim) + start = (0,) * dim + stop = num + elif len(args) == 1: + # start is start, args[0] is stop, step is 1 + start = _to_tuple(start, dim=dim) + stop = _to_tuple(args[0], dim=dim) + num = [stop[i] - start[i] for i in range(dim)] + # assert num are all integers + num_int = [int(x) for x in num] + assert (torch.tensor(num) == torch.tensor(num_int)).all(), f"num should be int, but got {num}" + num = num_int + elif len(args) == 2: + # start is start, args[0] is stop, args[1] is num + start = _to_tuple(start, dim=dim) # Left-Top eg: 12,0 + stop = _to_tuple(args[0], dim=dim) # Right-Bottom eg: 20,32 + num = _to_tuple(args[1], dim=dim) # Target Size eg: 32,124 + else: + raise ValueError(f"len(args) should be 0, 1 or 2, but got {len(args)}") + + # PyTorch implement of np.linspace(start[i], stop[i], num[i], endpoint=False) + axis_grid = [] + for i in range(dim): + a, b, n = start[i], stop[i], num[i] + g = torch.linspace(a, b, n + 1, dtype=torch.float32)[:n] + axis_grid.append(g) + grid = torch.meshgrid(*axis_grid, indexing="ij") # dim x [H, W] + grid = torch.stack(grid, dim=0) # [dim, H, W] + + return grid + + +def build_2d_rope( + seq_len: int, n_elem: int, image_infos: Optional[List[Tuple[slice, Tuple[int, int]]]] = None, + device: Optional[torch.device] = None, base: int = 10000, base_rescale_factor: float = 1.0, + return_all_pos: bool = False, +): + """ + Reference: https://kexue.fm/archives/10352 + + Start from 1, we have + beta_y = L + (wh - h)/2 + beta_x = L + (wh - w)/2 + + Returns + ------- + cos: torch.Tensor with shape of [seq_len, n_elem] + sin: torch.Tensor with shape of [seq_len, n_elem] + """ + assert n_elem % 4 == 0, f"n_elem must be divisible by 4, but got {n_elem}." + + # theta + if base_rescale_factor != 1.0: + base *= base_rescale_factor ** (n_elem / (n_elem - 2)) + theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, device=device).float() / n_elem)) + theta = theta.reshape(1, n_elem // 4, 2) # [1, half_d, 2] + + # position indices + if image_infos is None: + image_infos = [] + + image_infos_list = [image_infos] + sample_seq_lens = [seq_len] + + # Prepare position indices for each sample + x_sections = [] + y_sections = [] + for sample_id, sample_image_infos in enumerate(image_infos_list): + last_pos = 0 + for sec_slice, (h, w) in sample_image_infos: + L = sec_slice.start # start from 0, so image_slice.start is just L + # previous text + if last_pos < L: + y_sections.append(torch.arange(last_pos, L)) + x_sections.append(torch.arange(last_pos, L)) + elif h is None: + # Interleave data has overlapped positions for tokens. + y_sections.append(torch.arange(sec_slice.start, sec_slice.stop)) + x_sections.append(torch.arange(sec_slice.start, sec_slice.stop)) + continue + else: + # Interleave data has overlapped positions for noised image and the successive clean image, + # leading to last_pos (= last text end L + noise w * h) > L (last text end L). 
+ pass + # current image + beta_y = L + (w * h - h) / 2 + beta_x = L + (w * h - w) / 2 + grid = get_meshgrid_nd((beta_y, beta_x), (beta_y + h, beta_x + w)) # [2, h, w] + grid = grid.reshape(2, -1) # (y, x) + y_sections.append(grid[0]) + x_sections.append(grid[1]) + # step + last_pos = L + w * h + # final text + y_sections.append(torch.arange(last_pos, sample_seq_lens[sample_id])) + x_sections.append(torch.arange(last_pos, sample_seq_lens[sample_id])) + + x_pos = torch.cat(x_sections).long() + y_pos = torch.cat(y_sections).long() + # If there are overlap positions, we need to remove them. + x_pos = x_pos[:seq_len] + y_pos = y_pos[:seq_len] + all_pos = torch.stack((y_pos, x_pos), dim=1).unsqueeze(1).to(device) # [seq_len, 1, 2] + + # calc rope + idx_theta = (all_pos * theta).reshape(all_pos.shape[0], n_elem // 2).repeat(1, 2) + + cos = torch.cos(idx_theta) + sin = torch.sin(idx_theta) + + if return_all_pos: + return cos, sin, all_pos + + return cos, sin + + +def build_batch_2d_rope( + seq_len: int, n_elem: int, image_infos: Optional[List[List[Tuple[slice, Tuple[int, int]]]]] = None, + device: Optional[torch.device] = None, base: int = 10000, base_rescale_factor: float = 1.0, + return_all_pos: bool = False, +): + cos_list, sin_list, all_pos_list = [], [], [] + if image_infos is None: + image_infos = [None] + for i, image_info in enumerate(image_infos): + res = build_2d_rope( + seq_len, n_elem, image_infos=image_info, device=device, + base=base, base_rescale_factor=base_rescale_factor, + return_all_pos=return_all_pos, + ) + if return_all_pos: + cos, sin, all_pos = res + else: + cos, sin = res + all_pos = None + cos_list.append(cos) + sin_list.append(sin) + all_pos_list.append(all_pos) + + stacked_cos = torch.stack(cos_list, dim=0) + stacked_sin = torch.stack(sin_list, dim=0) + + if return_all_pos: + return stacked_cos, stacked_sin, all_pos_list + + return stacked_cos, stacked_sin + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2:] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`): + The position indices of the tokens corresponding to the query and key tensors. For example, this can be + used to pass shifted position ids when working with a KV-cache. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. 
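# Sketch of the 2-D RoPE layout built above for a mixed text/image sequence (imports
# assume this repo's package layout; the sequence layout is an example). Text tokens
# get equal x/y positions, while a 3x4 image occupying positions 4..15 gets a 2-D grid
# offset by beta_x/beta_y. The resulting tables feed apply_rotary_pos_emb.
import torch
from hunyuan import build_batch_2d_rope, apply_rotary_pos_emb

seq_len, head_dim = 20, 64
image_infos = [[(slice(4, 16), (3, 4))]]     # one sample containing one 3x4 latent image
cos, sin = build_batch_2d_rope(seq_len, head_dim, image_infos=image_infos)
q = torch.randn(1, 8, seq_len, head_dim)     # (B, num_heads, seq_len, head_dim)
k = torch.randn(1, 8, seq_len, head_dim)
q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)   # same shapes as q and k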
+ """ + if position_ids is not None: + cos = cos[position_ids] + sin = sin[position_ids] + + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +# ======================================================= +# Modules for Image Generation +# ======================================================= + +class TimestepEmbedder(nn.Module): + """ + Embeds scalar timesteps into vector representations. + """ + def __init__(self, + hidden_size, + act_layer=nn.GELU, + frequency_embedding_size=256, + max_period=10000, + out_size=None, + dtype=None, + device=None + ): + factory_kwargs = {'dtype': dtype, 'device': device} + super().__init__() + self.frequency_embedding_size = frequency_embedding_size + self.max_period = max_period + if out_size is None: + out_size = hidden_size + + self.mlp = nn.Sequential( + nn.Linear(frequency_embedding_size, hidden_size, bias=True, **factory_kwargs), + act_layer(), + nn.Linear(hidden_size, out_size, bias=True, **factory_kwargs), + ) + nn.init.normal_(self.mlp[0].weight, std=0.02) + nn.init.normal_(self.mlp[2].weight, std=0.02) + + def forward(self, t): + t_freq = timestep_embedding(t, self.frequency_embedding_size, self.max_period).type(self.mlp[0].weight.dtype) + t_emb = self.mlp(t_freq) + return t_emb + + +class Upsample(nn.Module): + """ + An upsampling layer with an optional convolution. + + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + upsampling occurs in the inner-two dimensions. + """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None, device=None, dtype=None): + factory_kwargs = {'device': device, 'dtype': dtype} + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + if use_conv: + self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1, **factory_kwargs) + + def forward(self, x): + assert x.shape[1] == self.channels + if self.dims == 3: + x = F.interpolate( + x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest" + ) + else: + x = F.interpolate(x, scale_factor=2, mode="nearest") + if self.use_conv: + x = self.conv(x) + return x + + +class Downsample(nn.Module): + """ + A downsampling layer with an optional convolution. + + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + downsampling occurs in the inner-two dimensions. + """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None, device=None, dtype=None): + factory_kwargs = {'device': device, 'dtype': dtype} + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + stride = 2 if dims != 3 else (1, 2, 2) + if use_conv: + self.op = conv_nd( + dims, self.channels, self.out_channels, 3, stride=stride, padding=1, **factory_kwargs + ) + else: + assert self.channels == self.out_channels + self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride) + + def forward(self, x): + assert x.shape[1] == self.channels + return self.op(x) + + +class ResBlock(nn.Module): + """ + A residual block that can optionally change the number of channels. 
+ + :param in_channels: the number of input channels. + :param emb_channels: the number of timestep embedding channels. + :param dropout: the rate of dropout. + :param out_channels: if specified, the number of out channels. + :param use_conv: if True and out_channels is specified, use a spatial + convolution instead of a smaller 1x1 convolution to change the + channels in the skip connection. + :param dims: determines if the signal is 1D, 2D, or 3D. + :param up: if True, use this block for upsampling. + :param down: if True, use this block for downsampling. + """ + + def __init__( + self, + in_channels, + emb_channels, + out_channels=None, + dropout=0.0, + use_conv=False, + dims=2, + up=False, + down=False, + device=None, + dtype=None, + ): + factory_kwargs = {'dtype': dtype, 'device': device} + super().__init__() + self.in_channels = in_channels + self.dropout = dropout + self.out_channels = out_channels or self.in_channels + self.use_conv = use_conv + + self.in_layers = nn.Sequential( + normalization(self.in_channels, **factory_kwargs), + nn.SiLU(), + conv_nd(dims, self.in_channels, self.out_channels, 3, padding=1, **factory_kwargs), + ) + + self.updown = up or down + + if up: + self.h_upd = Upsample(self.in_channels, False, dims, **factory_kwargs) + self.x_upd = Upsample(self.in_channels, False, dims, **factory_kwargs) + elif down: + self.h_upd = Downsample(self.in_channels, False, dims, **factory_kwargs) + self.x_upd = Downsample(self.in_channels, False, dims, **factory_kwargs) + else: + self.h_upd = self.x_upd = nn.Identity() + + self.emb_layers = nn.Sequential( + nn.SiLU(), + linear(emb_channels, 2 * self.out_channels, **factory_kwargs) + ) + + self.out_layers = nn.Sequential( + normalization(self.out_channels, **factory_kwargs), + nn.SiLU(), + nn.Dropout(p=dropout), + zero_module( + conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1, **factory_kwargs) + ), + ) + + if self.out_channels == self.in_channels: + self.skip_connection = nn.Identity() + elif use_conv: + self.skip_connection = conv_nd( + dims, self.in_channels, self.out_channels, 3, padding=1, **factory_kwargs + ) + else: + self.skip_connection = conv_nd(dims, self.in_channels, self.out_channels, 1, **factory_kwargs) + + def forward(self, x, emb): + if self.updown: + in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] + h = in_rest(x) + h = self.h_upd(h) + x = self.x_upd(x) + h = in_conv(h) + else: + h = self.in_layers(x) + + emb_out = self.emb_layers(emb) + while len(emb_out.shape) < len(h.shape): + emb_out = emb_out[..., None] + + # Adaptive Group Normalization + out_norm, out_rest = self.out_layers[0], self.out_layers[1:] + scale, shift = torch.chunk(emb_out, 2, dim=1) + h = out_norm(h) * (1. 
+ scale) + shift + h = out_rest(h) + + return self.skip_connection(x) + h + + +class UNetDown(nn.Module): + """ + patch_size: one of [1, 2 ,4 ,8] + in_channels: vae latent dim + hidden_channels: hidden dim for reducing parameters + out_channels: transformer model dim + """ + def __init__(self, patch_size, in_channels, emb_channels, hidden_channels, out_channels, + dropout=0.0, device=None, dtype=None): + factory_kwargs = {'dtype': dtype, 'device': device} + super().__init__() + + self.patch_size = patch_size + assert self.patch_size in [1, 2, 4, 8] + + self.model = nn.ModuleList( + [conv_nd( + 2, + in_channels=in_channels, + out_channels=hidden_channels, + kernel_size=3, + padding=1, + **factory_kwargs + )] + ) + + if self.patch_size == 1: + self.model.append(ResBlock( + in_channels=hidden_channels, + emb_channels=emb_channels, + out_channels=out_channels, + dropout=dropout, + **factory_kwargs + )) + else: + for i in range(self.patch_size // 2): + self.model.append(ResBlock( + in_channels=hidden_channels, + emb_channels=emb_channels, + out_channels=hidden_channels if (i + 1) * 2 != self.patch_size else out_channels, + dropout=dropout, + down=True, + **factory_kwargs + )) + + def forward(self, x, t): + assert x.shape[2] % self.patch_size == 0 and x.shape[3] % self.patch_size == 0 + for module in self.model: + if isinstance(module, ResBlock): + x = module(x, t) + else: + x = module(x) + _, _, token_h, token_w = x.shape + x = rearrange(x, 'b c h w -> b (h w) c') + return x, token_h, token_w + + +class UNetUp(nn.Module): + """ + patch_size: one of [1, 2 ,4 ,8] + in_channels: transformer model dim + hidden_channels: hidden dim for reducing parameters + out_channels: vae latent dim + """ + def __init__(self, patch_size, in_channels, emb_channels, hidden_channels, out_channels, + dropout=0.0, device=None, dtype=None, out_norm=False): + factory_kwargs = {'dtype': dtype, 'device': device} + super().__init__() + + self.patch_size = patch_size + assert self.patch_size in [1, 2, 4, 8] + + self.model = nn.ModuleList() + + if self.patch_size == 1: + self.model.append(ResBlock( + in_channels=in_channels, + emb_channels=emb_channels, + out_channels=hidden_channels, + dropout=dropout, + **factory_kwargs + )) + else: + for i in range(self.patch_size // 2): + self.model.append(ResBlock( + in_channels=in_channels if i == 0 else hidden_channels, + emb_channels=emb_channels, + out_channels=hidden_channels, + dropout=dropout, + up=True, + **factory_kwargs + )) + + if out_norm: + self.model.append(nn.Sequential( + normalization(hidden_channels, **factory_kwargs), + nn.SiLU(), + conv_nd( + 2, + in_channels=hidden_channels, + out_channels=out_channels, + kernel_size=3, + padding=1, + **factory_kwargs + ), + )) + else: + self.model.append(conv_nd( + 2, + in_channels=hidden_channels, + out_channels=out_channels, + kernel_size=3, + padding=1, + **factory_kwargs + )) + + # batch_size, seq_len, model_dim + def forward(self, x, t, token_h, token_w): + x = rearrange(x, 'b (h w) c -> b c h w', h=token_h, w=token_w) + for module in self.model: + if isinstance(module, ResBlock): + x = module(x, t) + else: + x = module(x) + return x + + +# ======================================================= +# Modules for Transformer Backbone +# ======================================================= + +@dataclass +class CausalMMOutputWithPast(CausalLMOutputWithPast): + diffusion_prediction: Optional[torch.Tensor] = None + + +class HunyuanStaticCache(StaticCache): + """ + A custom static cache for multi-modal models that supports dynamic 
extension of the cache + and inplace updates of the cache. + + This cache supports batch cache_position updates. + """ + def __init__(self, *args, **kwargs): + self.dynamic = kwargs.pop("dynamic", False) + super().__init__(*args, **kwargs) + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. + It is VERY important to index using a tensor, otherwise you introduce a copy to the device. + + Parameters: + key_states (`torch.Tensor`): + The new key states to cache. + value_states (`torch.Tensor`): + The new value states to cache. + layer_idx (`int`): + The index of the layer to cache the states for. + cache_kwargs (`Dict[str, Any]`, `optional`): + Additional arguments for the cache subclass. The `StaticCache` needs the `cache_position` input + to know how where to write in the cache. + + Return: + A tuple containing the updated key and value states. + """ + cache_position = cache_kwargs.get("cache_position") + if hasattr(self, "key_cache") and hasattr(self, "value_cache"): + if self.key_cache[layer_idx].device != key_states.device: + self.key_cache[layer_idx] = self.key_cache[layer_idx].to(key_states.device) + self.value_cache[layer_idx] = self.value_cache[layer_idx].to(value_states.device) + k_out = self.key_cache[layer_idx] + v_out = self.value_cache[layer_idx] + key_states = key_states.to(k_out.dtype) + value_states = value_states.to(v_out.dtype) + else: + if self.layers[layer_idx].keys is None: + self.layers[layer_idx].lazy_initialization(key_states) + k_out = self.layers[layer_idx].keys + v_out = self.layers[layer_idx].values + + if cache_position is None: + k_out.copy_(key_states) + v_out.copy_(value_states) + else: + # Note: here we use `tensor.index_copy_(dim, index, tensor)` that is equivalent to + # `tensor[:, :, index] = tensor`, but the first one is compile-friendly and it does explicitly an in-place + # operation, that avoids copies and uses less memory. 
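+            # Illustrative example (shapes assumed for exposition): with a cache entry of shape
+            # (batch, num_kv_heads, max_cache_len, head_dim) and cache_position = tensor([5, 6, 7]),
+            # k_out.index_copy_(2, cache_position, key_states) writes the three new key vectors
+            # into sequence slots 5..7 in place. The 2-D cache_position branch below does the
+            # same per sample, so different batch elements can write to different slots.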
+ if cache_position.dim() == 1: + k_out.index_copy_(2, cache_position, key_states) + v_out.index_copy_(2, cache_position, value_states) + + if self.dynamic: + end = cache_position[-1].item() + 1 + k_out = k_out[:, :, :end] + v_out = v_out[:, :, :end] + else: + assert cache_position.dim() == 2, f"multiple batch dims not yet {cache_position.shape=}" + batch_size, idx_size = cache_position.shape + assert batch_size == k_out.size(0) + assert batch_size == v_out.size(0) + assert batch_size == key_states.size(0) + assert batch_size == value_states.size(0) + for i in range(batch_size): + unbatched_dim = 1 + k_out[i].index_copy_(unbatched_dim, cache_position[i], key_states[i]) + v_out[i].index_copy_(unbatched_dim, cache_position[i], value_states[i]) + + if self.dynamic: + assert len(cache_position) == 1 + end = cache_position[0, -1].item() + 1 + k_out = k_out[:, :, :end] + v_out = v_out[:, :, :end] + + return k_out, v_out + + +class HunyuanRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + HunyuanRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + +class HunyuanMLP(nn.Module): + def __init__(self, config: HunyuanImage3Config, layer_idx=None, is_shared_mlp=False, is_moe=False): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.hidden_size = config.hidden_size + self.hidden_act = config.hidden_act + + self.intermediate_size = config.intermediate_size + if is_shared_mlp or is_moe: + # 如果是 moe 的话,优先用 moe_intermediate_size + if config.moe_intermediate_size is not None: + self.intermediate_size = config.moe_intermediate_size \ + if isinstance(config.moe_intermediate_size, int) else config.moe_intermediate_size[layer_idx] + + if is_shared_mlp: + num_shared_expert = config.num_shared_expert \ + if isinstance(config.num_shared_expert, int) else config.num_shared_expert[layer_idx] + self.intermediate_size *= num_shared_expert + + self.act_fn = ACT2FN[config.hidden_act] + if self.hidden_act == "silu": + self.intermediate_size *= 2 # SwiGLU + self.gate_and_up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) + self.down_proj = nn.Linear(self.intermediate_size // 2, self.hidden_size, bias=config.mlp_bias) + elif self.hidden_act == "gelu": + self.gate_and_up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias) + else: + assert False, "other hidden_act are not supported" + + def forward(self, x): + if self.hidden_act == "silu": + gate_and_up_proj = self.gate_and_up_proj(x) + x1, x2 = gate_and_up_proj.chunk(2, dim=2) + down_proj = self.down_proj(x1 * self.act_fn(x2)) + return down_proj + elif self.hidden_act == "gelu": + intermediate = self.gate_and_up_proj(x) + intermediate = self.act_fn(intermediate) + output = self.down_proj(intermediate) + return output + else: + assert False, "other hidden_act are not supported" + + +class HunyuanTopKGate(nn.Module): + def __init__(self, config: HunyuanImage3Config, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = 
layer_idx + self.moe_topk = config.moe_topk if isinstance(config.moe_topk, int) else config.moe_topk[layer_idx] + self.drop_tokens = config.moe_drop_tokens + self.min_capacity = 8 + self.random_routing_dropped_token = config.moe_random_routing_dropped_token + num_experts = config.num_experts if isinstance(config.num_experts, int) else config.num_experts[layer_idx] + self.wg = nn.Linear(config.hidden_size, num_experts, bias=False, dtype=torch.float32) + + # DeepSeek gating args + self.routed_scaling_factor = config.routed_scaling_factor + self.n_group = config.n_group + self.topk_group = config.topk_group + self.norm_topk_prob = config.norm_topk_prob + self.group_limited_greedy = config.group_limited_greedy + + def forward(self, hidden_states, topk_impl='default'): + bsz, seq_len, hidden_size = hidden_states.shape + hidden_states = hidden_states.reshape(-1, hidden_size) + if self.wg.weight.dtype == torch.float32: + hidden_states = hidden_states.float() + logits = self.wg(hidden_states) + if topk_impl == 'default': + gate_output = topkgating(logits, self.moe_topk, group_limited_greedy=self.group_limited_greedy, + n_group=self.n_group, topk_group=self.topk_group, + norm_topk_prob=self.norm_topk_prob, + routed_scaling_factor=self.routed_scaling_factor, + capacity_factor=self.config.capacity_factor, + drop_tokens=self.drop_tokens) + elif topk_impl == 'easy': + gate_output = self.easy_topk(logits, self.moe_topk) + else: + raise ValueError(f"Unsupported topk_impl: {topk_impl}") + + return gate_output + + @staticmethod + def easy_topk(logits, moe_topk): + gates = F.softmax(logits, dim=1) + topk_weight_1, expert_index = torch.topk(gates, moe_topk) + weight_sums = topk_weight_1.sum(dim=1, keepdim=True) + weight_sums = torch.clamp(weight_sums, min=1e-8) + topk_weight = topk_weight_1 / weight_sums + + return topk_weight, expert_index + + +class HunyuanMoE(nn.Module): + def __init__(self, config: HunyuanImage3Config, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.moe_topk = config.moe_topk + self.num_experts = config.num_experts if isinstance(config.num_experts, int) else config.num_experts[layer_idx] + if config.use_mixed_mlp_moe: + self.shared_mlp = HunyuanMLP(config, layer_idx=layer_idx, is_shared_mlp=True) + self.gate = HunyuanTopKGate(config, layer_idx=layer_idx) + self.experts = nn.ModuleList( + [HunyuanMLP(config, layer_idx=layer_idx, is_shared_mlp=False, is_moe=True) for _ in range(self.num_experts)] + ) + + self._moe_impl = config.moe_impl + # For FlashInfer + self.moe_weight = None + self.moe_weight_2 = None + self._weights_initialized = False + + @property + def moe_impl(self): + return self._moe_impl + + @moe_impl.setter + def moe_impl(self, value): + self._moe_impl = value + if self._moe_impl == "flashinfer": + assert flashinfer is not None, "When using fused_moe, flashinfer must be installed." 
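+    # Routing sketch for forward() below (numbers are made up for illustration): with
+    # 4 experts and gate logits [2.0, 1.0, 0.5, 0.1], the softmax weights are roughly
+    # [0.57, 0.21, 0.13, 0.09]; for moe_topk = 2 the token is dispatched to experts 0 and 1,
+    # and easy_topk renormalizes their weights to roughly [0.73, 0.27] before the expert
+    # outputs are mixed. The shared MLP branch (use_mixed_mlp_moe) is then added on top.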
+ + def forward(self, hidden_states): + torch.cuda.set_device(hidden_states.device.index) + bsz, seq_len, hidden_size = hidden_states.shape + + if self.config.use_mixed_mlp_moe: + hidden_states_mlp = self.shared_mlp(hidden_states) + + reshaped_input = hidden_states.reshape(-1, hidden_size) # [bsz*seq_len, hidden_size] + + with nvtx.range("MoE"): + if self._moe_impl == "flashinfer": + # Get expert weights + if not self._weights_initialized: + self._initialize_weights_on_device(hidden_states.device) + topk_weight, topk_index = self.gate(hidden_states, topk_impl='easy') + + combined_output = torch.zeros_like(reshaped_input) + _ = flashinfer.fused_moe.cutlass_fused_moe( # noqa + reshaped_input.contiguous(), + topk_index.to(torch.int).contiguous(), + topk_weight.to(torch.float).contiguous(), + self.moe_weight, + self.moe_weight_2, + torch.bfloat16, + output=combined_output, + quant_scales=None, + ) + else: + # Original implementation - fallback for compatibility + l_moe, combine_weights, dispatch_mask, exp_counts = self.gate(hidden_states, topk_impl='default') + dispatched_input = torch.einsum("sec,sm->ecm", dispatch_mask.type_as(hidden_states), reshaped_input) + chunks = dispatched_input.chunk(self.num_experts, dim=0) + expert_outputs = [] + for chunk, expert in zip(chunks, self.experts): + expert_outputs.append(expert(chunk)) + + expert_output = torch.cat(expert_outputs, dim=0) + combined_output = torch.einsum("sec,ecm->sm", combine_weights.type_as(hidden_states), expert_output) + + combined_output = combined_output.reshape(bsz, seq_len, hidden_size) + + if self.config.use_mixed_mlp_moe: + output = hidden_states_mlp + combined_output # noqa + else: + output = combined_output + + return output + + def _initialize_weights_on_device(self, device): + expert_weights_gate_up = [] + expert_weights_down = [] + + for expert in self.experts: + expert.to(device) + expert_weights_gate_up.append(expert.gate_and_up_proj.weight.to(device)) + expert_weights_down.append(expert.down_proj.weight.to(device)) + + self.moe_weight = torch.stack(expert_weights_gate_up).contiguous() + self.moe_weight_2 = torch.stack(expert_weights_down).contiguous() + # empty the expert weights + for expert in self.experts: + expert.gate_and_up_proj.weight.data = torch.empty(0, device=device) + if expert.gate_and_up_proj.bias is not None: + expert.gate_and_up_proj.bias.data = torch.empty(0, device=device) + expert.down_proj.weight.data = torch.empty(0, device=device) + if expert.down_proj.bias is not None: + expert.down_proj.bias.data = torch.empty(0, device=device) + + self._weights_initialized = True + + +class HunyuanImage3SDPAAttention(nn.Module): + """PyTorch SDPA attention implementation using torch.nn.functional.scaled_dot_product_attention""" + + def __init__(self, config: HunyuanImage3Config, layer_idx: int): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.attention_type = 'self' + + self.attention_dropout = config.attention_dropout + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + # self.head_dim = self.hidden_size // self.num_heads + self.head_dim = config.attention_head_dim + self.num_key_value_heads = config.num_key_value_heads if config.num_key_value_heads else self.num_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.is_causal = True + self.use_qk_norm = config.use_qk_norm + self.use_rotary_pos_emb = 
config.use_rotary_pos_emb + self.hidden_size_q = self.head_dim * self.num_heads + self.hidden_size_kv = self.head_dim * self.num_key_value_heads + + # define layers + self.qkv_proj = nn.Linear( + self.hidden_size, + self.hidden_size_q + 2 * self.hidden_size_kv, + bias=config.attention_bias + ) + self.o_proj = nn.Linear(self.hidden_size_q, self.hidden_size, bias=config.attention_bias) + + if self.use_qk_norm: + self.query_layernorm = HunyuanRMSNorm(self.head_dim, eps=config.rms_norm_eps) + self.key_layernorm = HunyuanRMSNorm(self.head_dim, eps=config.rms_norm_eps) + + if self.use_rotary_pos_emb: + self._init_rope() + + def _init_rope(self): + scaling_type = self.config.rope_scaling["type"] + if scaling_type == "custom": + # Using custom rotary embedding + self.rotary_emb = None + else: + raise ValueError(f"Unknown RoPE scaling type {scaling_type}") + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.reshape(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: Optional[bool] = False, + custom_pos_emb: Optional[Tuple[torch.FloatTensor]] = None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]: + if output_attentions: + raise NotImplementedError( + 'HunyuanImage3Model is using HunyuanImage3SDPAAttention,' + 'but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`.' + ) + + bsz, q_len, _ = hidden_states.size() + + qkv_states = self.qkv_proj(hidden_states) + qkv_states = qkv_states.reshape(bsz, q_len, self.num_key_value_heads, self.num_key_value_groups + 2, + self.head_dim) + query_states, key_states, value_states = torch.split(qkv_states, [self.num_key_value_groups, 1, 1], dim=3) + + query_states = query_states.reshape(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.reshape(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.reshape(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + if self.use_rotary_pos_emb: + cos, sin = custom_pos_emb + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if self.use_qk_norm: + query_states = self.query_layernorm(query_states) + key_states = self.key_layernorm(key_states) + + query_states = query_states.to(value_states.dtype) + key_states = key_states.to(value_states.dtype) + + if past_key_value is not None: + cache_kwargs = {"cache_position": position_ids} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + query_states = query_states.to(key_states.dtype) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with + # custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. 
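+        # Shape note (head counts are illustrative, e.g. num_attention_heads=32 with
+        # num_key_value_heads=8 gives num_key_value_groups=4): qkv_proj packed 4 query heads
+        # plus one key and one value head per KV group, and repeat_kv has just expanded
+        # key/value back to the full head count, so q is (batch, heads, q_len, head_dim) and
+        # k/v are (batch, heads, kv_len, head_dim) for the scaled_dot_product_attention call below.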
+ if query_states.device.type == "cuda" and attention_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, key_states, value_states, attn_mask=attention_mask, dropout_p=0.0 + ) + attn_output = attn_output.transpose(1, 2).contiguous() + + attn_output = attn_output.reshape(bsz, q_len, -1) + + attn_output = self.o_proj(attn_output) + + return attn_output, None, past_key_value + + +class HunyuanImage3FlashAttention2(HunyuanImage3SDPAAttention): + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: Optional[bool] = False, + custom_pos_emb: Optional[Tuple[torch.FloatTensor]] = None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]: + if output_attentions: + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + ) + + bsz, q_len, _ = hidden_states.size() + + qkv_states = self.qkv_proj(hidden_states) + qkv_states = qkv_states.reshape(bsz, q_len, self.num_key_value_heads, self.num_key_value_groups + 2, + self.head_dim) + query_states, key_states, value_states = torch.split(qkv_states, [self.num_key_value_groups, 1, 1], dim=3) + + query_states = query_states.reshape(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.reshape(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.reshape(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + if self.use_rotary_pos_emb: + cos, sin = custom_pos_emb + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if self.use_qk_norm: + query_states = self.query_layernorm(query_states) + key_states = self.key_layernorm(key_states) + + query_states = query_states.to(value_states.dtype) + key_states = key_states.to(value_states.dtype) + + if past_key_value is not None: + cache_kwargs = {"cache_position": position_ids} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with + # custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. 
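+        # Note on the branches below (paraphrasing the code, not an exact mask spec): text
+        # generation uses the usual causal prefill / full single-token decode attention, while
+        # image generation keeps a causal mask over the text-and-timestep prefix and lets every
+        # image-latent query attend to the whole sequence, which is what the pairs of
+        # flash_attn_func calls with causal=True / causal=False implement.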
+ if query_states.device.type == "cuda" and attention_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + target_dtype = key_states.dtype if key_states.dtype in [torch.bfloat16, torch.float16] else torch.bfloat16 + + q_fa = query_states.to(target_dtype).transpose(1, 2).contiguous() + k_fa = key_states.to(target_dtype).transpose(1, 2).contiguous() + v_fa = value_states.to(target_dtype).transpose(1, 2).contiguous() + + mode = kwargs.get("mode", "gen_text") + # For gen_text and gen_image, we need to handle the attention differently + with nvtx.range("attention"): + if mode == "gen_text": + if attention_mask is None: + attn_output = flash_attn_func(q_fa, k_fa, v_fa, causal=False) # decode attention + else: + attn_output = flash_attn_func(q_fa, k_fa, v_fa, causal=True) # prefill attention + else: # image attention + gen_timestep_scatter_index: Optional[torch.Tensor] = kwargs.get("gen_timestep_scatter_index", None) + assert gen_timestep_scatter_index is not None, \ + "When gen_image, `gen_timestep_scatter_index` must be provided." + # TODO: batchify + timestep_index = gen_timestep_scatter_index[0, 0].item() + # When image generation, different attention implementations for the first step and the following steps + # help to improve the inference speed. + first_step = kwargs.get("first_step", None) + if first_step is None: + raise ValueError("When gen_image, `first_step` must be provided.") + if first_step: + casual_len = timestep_index + 1 + text_query_states = q_fa[:, :casual_len, :, :] + text_key_states = k_fa[:, :casual_len, :, :] + text_value_states = v_fa[:, :casual_len, :, :] + text_attn_output = flash_attn_func( + text_query_states, text_key_states, text_value_states, causal=True) + image_query_states = q_fa[:, casual_len:, :, :] + image_attn_output = flash_attn_func(image_query_states, k_fa, v_fa, causal=False) + attn_output = torch.cat((text_attn_output, image_attn_output), dim=1) + else: + casual_len = timestep_index + 1 + timestep_query_states = q_fa[:, 0:1, :, :] + timestep_key_states = k_fa[:, :casual_len, :, :] + timestep_value_states = v_fa[:, :casual_len, :, :] + timestep_attn_output = flash_attn_func( + timestep_query_states, timestep_key_states, timestep_value_states, causal=True) + image_query_states = q_fa[:, 1:, :, :] + image_attn_output = flash_attn_func(image_query_states, k_fa, v_fa, causal=False) + attn_output = torch.cat((timestep_attn_output, image_attn_output), dim=1) + + attn_output = attn_output.reshape(bsz, q_len, -1) + + attn_output = self.o_proj(attn_output) + + return attn_output, None, past_key_value + + +Hunyuan_ATTENTION_CLASSES = { + "eager": HunyuanImage3SDPAAttention, + "sdpa": HunyuanImage3SDPAAttention, + "flash_attention_2": HunyuanImage3FlashAttention2, +} + + +class HunyuanImage3DecoderLayer(nn.Module): + def __init__(self, config: HunyuanImage3Config, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + self.layer_idx = layer_idx + + attn_impl = config._attn_implementation # noqa + if attn_impl in Hunyuan_ATTENTION_CLASSES: + self.self_attn = Hunyuan_ATTENTION_CLASSES[attn_impl](config=config, layer_idx=layer_idx) + else: + raise ValueError(f"Unsupported attention implementation: {attn_impl}") + + if ((isinstance(config.num_experts, int) and config.num_experts > 1) or ( + isinstance(config.num_experts, list) and max( + config.num_experts) > 1)) and layer_idx >= config.moe_layer_num_skipped: + self.mlp = HunyuanMoE(config, 
layer_idx=layer_idx) + else: + self.mlp = HunyuanMLP(config, layer_idx=layer_idx, is_shared_mlp=False, is_moe=False) + if config.norm_type == 'hf_rms' or config.norm_type == 'rms': + self.input_layernorm = HunyuanRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = HunyuanRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + elif config.norm_type == 'fused' or config.norm_type == 'torch_nn': + self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps) + else: + assert False, "other norm_type are not supported" + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + custom_pos_emb: Optional[Tuple[torch.FloatTensor]] = None, + **kwargs, + ) -> Tuple[torch.FloatTensor | Any]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): + attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, + query_sequence_length, key_sequence_length)` if default attention is used. + position_ids (`torch.LongTensor`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + custom_pos_emb (`Tuple[torch.FloatTensor]`, *optional*): custom position embedding for rotary + position embedding + """ + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. 
Please make sure use " + "`attention_mask` instead.`" + ) + + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + custom_pos_emb=custom_pos_emb, + **kwargs, + ) + hidden_states = residual + hidden_states + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +@add_start_docstrings( + "The bare Hunyuan Image 3 Model outputting raw hidden-states without any specific head on top.", + Hunyuan_START_DOCSTRING, +) +class HunyuanImage3PreTrainedModel(PreTrainedModel): + config_class = HunyuanImage3Config + base_model_prefix = "" + supports_gradient_checkpointing = True + _no_split_modules = ["HunyuanImage3DecoderLayer"] + _skip_keys_device_placement = "past_key_values" + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_cache_class = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +Hunyuan_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. 
+ + [What are position IDs?](../glossary#position-ids) + past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` + returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. + + Two formats are allowed: + - a [`~cache_utils.Cache`] instance; + - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy + cache format. + + The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the + legacy cache format will be returned. + + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't + have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` + of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
+""" + + +@add_start_docstrings( + "The bare Hunyuan Model outputting raw hidden-states without any specific head on top.", + Hunyuan_START_DOCSTRING, +) +class HunyuanImage3Model(HunyuanImage3PreTrainedModel): + def __init__(self, config: HunyuanImage3Config): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + self.add_classification_head = config.add_classification_head + self.wte = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [HunyuanImage3DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + if not config.add_classification_head: + self.ln_f = HunyuanRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + # Initialize weights and apply final processing + self.post_init() + + self.shared_tensor = None + + @add_start_docstrings_to_model_forward(Hunyuan_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + custom_pos_emb: Optional[Tuple[torch.FloatTensor]] = None, + mode: str = "gen_text", + first_step: Optional[bool] = None, + gen_timestep_scatter_index: Optional[torch.Tensor] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if inputs_embeds is None: + inputs_embeds = self.wte(input_ids) + + # embed positions + hidden_states = inputs_embeds + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = None + + for layer_idx, decoder_layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states += (hidden_states,) + + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + custom_pos_emb=custom_pos_emb, + mode=mode, + first_step=first_step, + gen_timestep_scatter_index=gen_timestep_scatter_index, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if not self.add_classification_head: + # Do ln_f outside of the model for compatibility with image generation. 
+ pass + # hidden_states = self.ln_f(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = None + if use_cache: + next_cache = next_decoder_cache + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + +class HunyuanImage3ForCausalMM(HunyuanImage3PreTrainedModel, GenerationMixin): + def __init__(self, config: HunyuanImage3Config): + super().__init__(config) + self.config = config + self._tkwrapper: Optional[TokenizerWrapper] = None + + # Initialize image preprocessor (for conditional images) + self.image_processor = HunyuanImage3ImageProcessor(config) + + # vae and gen_image pipeline + self.vae = AutoencoderKLConv3D.from_config(config.vae) + self._pipeline = None + + # vit + self.vision_model = Siglip2VisionTransformer(config.vit) + self.vision_aligner = LightProjector(config.vit_aligner) + + # image generation related + self.timestep_emb = TimestepEmbedder(hidden_size=config.hidden_size) + if config.img_proj_type == "unet": + self.patch_embed = UNetDown( + patch_size=config.patch_size, + emb_channels=config.hidden_size, + in_channels=config.vae["latent_channels"], + hidden_channels=config.patch_embed_hidden_dim, + out_channels=config.hidden_size, + ) + self.time_embed = TimestepEmbedder(hidden_size=config.hidden_size) + + self.final_layer = UNetUp( + patch_size=config.patch_size, + emb_channels=config.hidden_size, + in_channels=config.hidden_size, + hidden_channels=config.patch_embed_hidden_dim, + out_channels=config.vae["latent_channels"], + out_norm=True, + ) + self.time_embed_2 = TimestepEmbedder(hidden_size=config.hidden_size) + else: + raise ValueError(f"Unknown img_proj_type {config.img_proj_type}") + + # transformer backbone + self.model = HunyuanImage3Model(config) + + self.pad_id = config.pad_id + self.vocab_size = config.vocab_size + + # linear head + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + @property + def tokenizer(self): + if self._tkwrapper is None: + raise ValueError("Attribute `tokenizer` has not been initialized yet. Please set it first.") + return self._tkwrapper + + def load_tokenizer(self, tokenizer): + self._tkwrapper = TokenizerWrapper(tokenizer) + + @property + def pipeline(self): + if self._pipeline is None: + self.scheduler = FlowMatchDiscreteScheduler( + shift=self.generation_config.flow_shift, reverse=True, solver="euler", + ) + self._pipeline = HunyuanImage3Text2ImagePipeline( + model=self, scheduler=self.scheduler, vae=self.vae, + ) + return self._pipeline + + @staticmethod + def get_pos_emb(custom_pos_emb, position_ids): + cos, sin = custom_pos_emb + cos = real_batched_index_select(cos, dim=1, idx=position_ids) + sin = real_batched_index_select(sin, dim=1, idx=position_ids) + return cos, sin + + def instantiate_vae_image_tokens( + self, + x: torch.Tensor, + images: BatchRaggedImages, + ts: BatchRaggedTensor, + image_mask: torch.Tensor, + ): + """ + Instantiate the VAE image embeddings into the input embedding sequence. + + Args: + x: input sequence, (batch_size, seq_len, n_embd) + images: BatchRaggedImages + images can be a 4-D tensor, or a list of 4-D tensors, or a list of lists of 3-D tensors. 
+ ts: BatchRaggedTensor + ts can be a 1-D tensor, or a list of 1-D tensors + image_mask: (batch_size, seq_len) + """ + batch_size, seq_len, n_embd = x.shape + + if isinstance(images, list): + index = torch.arange(seq_len, device=x.device).unsqueeze(0).repeat(batch_size, 1) + t_emb = [] + for i, (image_i, t_i) in enumerate(zip(images, ts)): + if isinstance(image_i, torch.Tensor): + # time_embed needs a 1-D tensor as input + t_i_emb = self.time_embed(t_i) + # n_{i} x one_image_seq_len x n_embd + image_i_seq, _, _ = self.patch_embed(image_i, t_i_emb) + # 1 x (n_{i} * one_image_seq_len) + image_i_scatter_index = index[i:i + 1].masked_select(image_mask[i:i + 1].bool()).reshape(1, -1) + x[i:i + 1].scatter_( + dim=1, + index=image_i_scatter_index.unsqueeze(-1).repeat(1, 1, n_embd), + # 1 x (n_{i} * one_image_seq_len) x n_embd + src=image_i_seq.reshape(1, -1, n_embd), # 1 x (n_{i} * one_image_seq_len) x n_embd + ) + t_emb.append(t_i_emb) + elif isinstance(image_i, list): + # time_embed needs a 1-D tensor as input + t_i_emb = self.time_embed(t_i) # n_{i} x d + image_i_seq_list = [], [] + for j in range(len(image_i)): + image_ij = image_i[j] + if image_ij.dim() == 4: + assert image_i[j].shape[0] == 1, "image_i[j] should have a batch dimension of 1" + elif image_ij.dim() == 3: + image_ij = image_ij.unsqueeze(0) + else: + raise ValueError(f"image_i[j] should have 3 or 4 dimensions, got {image_ij.dim()}") + # 1 x one_image_seq_len_{j} x n_embd + image_i_seq_j, _, _ = self.patch_embed(image_ij, t_i_emb[j:j + 1]) + image_i_seq_list.append(image_i_seq_j) + # 1 x sum_{j}(one_image_seq_len_{j}) x n_embd + image_i_seq = torch.cat(image_i_seq_list, dim=1) + # 1 x sum_{j}(one_image_seq_len_{j}) + image_i_scatter_index = index[i:i + 1].masked_select(image_mask[i:i + 1].bool()).reshape(1, -1) + x[i:i + 1].scatter_( + dim=1, + index=image_i_scatter_index.unsqueeze(-1).repeat(1, 1, n_embd), + # 1 x sum_{j}(one_image_seq_len_{j}) x n_embd + src=image_i_seq.reshape(1, -1, n_embd), # 1 x sum_{j}(one_image_seq_len_{j}) x n_embd + ) + t_emb.append(t_i_emb) + else: + raise TypeError(f"image_i should be a torch.Tensor or a list, got {type(image_i)}") + token_h, token_w = None, None + else: + # images is a 4-D tensor + batch_size, seq_len, n_embd = x.shape + index = torch.arange(seq_len, device=x.device).unsqueeze(0).repeat(batch_size, 1) + t_emb = self.time_embed(ts) + image_seq, token_h, token_w = self.patch_embed(images, t_emb) + image_scatter_index = index.masked_select(image_mask.bool()).reshape(batch_size, -1) + x.scatter_( + dim=1, + index=image_scatter_index.unsqueeze(-1).repeat(1, 1, n_embd), + src=image_seq, + ) + + return x, token_h, token_w + + def instantiate_timestep_tokens( + self, + x: torch.Tensor, + t: BatchRaggedTensor, + timestep_scatter_index: BatchRaggedTensor, + ): + batch_size, seq_len, n_embd = x.shape + # batch_size x n x n_embd + timestep_scatter_src = self.timestep_emb(t.reshape(-1)).reshape(batch_size, -1, n_embd) + x.scatter_( + dim=1, + index=timestep_scatter_index.unsqueeze(-1).repeat(1, 1, n_embd), + src=timestep_scatter_src, + ) + + return x + + def instantiate_vit_image_tokens( + self, + x: torch.Tensor, + cond_vit_images: Union[torch.Tensor, List[torch.Tensor]], + cond_vit_image_mask: torch.Tensor, + vit_kwargs: Dict[str, Any], + ): + # 1. 
Forward the vit encoder and vit aligner to get the vit image embeddings and align them to the + # transformer hidden size + cond_vit_image_embeds = [] + for batch_idx, image in enumerate(cond_vit_images): + cur_kwargs = {k: v[batch_idx] for k, v in vit_kwargs.items()} + image_embed = self.vision_model(image, **cur_kwargs).last_hidden_state + image_embed = self.vision_aligner(image_embed) + n, seq_len, dim = image_embed.shape + image_embed = image_embed.reshape(n * seq_len, dim) + cond_vit_image_embeds.append(image_embed) + + # 2. Instantiate the vit image embeddings into the input sequence + batch_size, seq_len, n_embd = x.shape + index = torch.arange(seq_len, device=x.device).unsqueeze(0).repeat(batch_size, 1) + + for i, (image_embed, mask) in enumerate(zip(cond_vit_image_embeds, cond_vit_image_mask)): + image_scatter_index = index[i:i+1].masked_select(mask.bool()).reshape(1, -1) + x[i:i+1].scatter_( + dim=1, + index=image_scatter_index.unsqueeze(-1).repeat(1, 1, n_embd), + src=image_embed.reshape(1, -1, n_embd), + ) + + return x + + def ragged_final_layer(self, x, image_mask, timestep, token_h, token_w, first_step): + bsz, seq_len, n_embd = x.shape + if first_step: + image_output = x.masked_select(image_mask.unsqueeze(-1).bool()).reshape(bsz, -1, n_embd) + else: + image_output = x[:, 1:, :] + timestep_emb = self.time_embed_2(timestep) + pred = self.final_layer(image_output, timestep_emb, token_h, token_w) + return pred + + @staticmethod + def _check_inputs(cond, target, check_list): + if cond: + for name, item in check_list: + assert item is not None, f"`{name}` should be provided when `{target}`." + + @add_start_docstrings_to_model_forward(Hunyuan_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + custom_pos_emb: Optional[Tuple[torch.FloatTensor]] = None, + mode: str = "gen_text", + first_step: Optional[bool] = None, + # for gen image + images: Optional[BatchRaggedImages] = None, + image_mask: Optional[torch.Tensor] = None, + timestep: Optional[BatchRaggedTensor] = None, + gen_timestep_scatter_index: Optional[torch.Tensor] = None, + # for cond image + cond_vae_images: Optional[BatchRaggedImages] = None, + cond_timestep: Optional[BatchRaggedTensor] = None, + cond_vae_image_mask: Optional[torch.Tensor] = None, + cond_vit_images: Optional[BatchRaggedImages] = None, + cond_vit_image_mask: Optional[torch.Tensor] = None, + vit_kwargs: Optional[Dict[str, Any]] = None, + cond_timestep_scatter_index: Optional[torch.Tensor] = None, + ) -> Union[Tuple, CausalMMOutputWithPast]: + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + # Sanity Check of Inputs + self._check_inputs(mode == "gen_image", "in `gen_image` mode", [ + ("images", images), ("timestep", timestep), ("gen_timestep_scatter_index", gen_timestep_scatter_index), + ]) + self._check_inputs(mode == "gen_image" and first_step, "in `gen_image` mode at the first step", [ + ("image_mask", image_mask), + ]) + self._check_inputs(cond_vae_images is not None, "`cond_vae_images` is provided", [ + ("cond_timestep", cond_timestep), ("cond_vae_image_mask", cond_vae_image_mask), + ("cond_timestep_scatter_index", cond_timestep_scatter_index), + ]) + 
self._check_inputs(cond_vit_images is not None, "`cond_vit_images` is provided", [ + ("cond_vit_image_mask", cond_vit_image_mask), ("vit_kwargs", vit_kwargs), + ]) + + custom_pos_emb = self.get_pos_emb(custom_pos_emb, position_ids) + + inputs_embeds = self.model.wte(input_ids) + bsz, seq_len, n_embd = inputs_embeds.shape + + # Instantiate placeholder tokens: , for the gen image + if mode == "gen_text": + # For gen_text, make sure gen_timestep_scatter_index is None + gen_timestep_scatter_index = None + token_h, token_w = None, None + else: + if first_step: + inputs_embeds, token_h, token_w = self.instantiate_vae_image_tokens( + inputs_embeds, images, timestep, image_mask) + inputs_embeds = self.instantiate_timestep_tokens( + inputs_embeds, timestep, gen_timestep_scatter_index) + else: + t_emb = self.time_embed(timestep) + image_emb, token_h, token_w = self.patch_embed(images, t_emb) + timestep_emb = self.timestep_emb(timestep).reshape(bsz, -1, n_embd) + inputs_embeds = torch.cat([timestep_emb, image_emb], dim=1) + + # Instantiate placeholder tokens: , for cond images + # Should only run once with kv-cache enabled. + if cond_vae_images is not None: + inputs_embeds, _, _ = self.instantiate_vae_image_tokens( + inputs_embeds, cond_vae_images, cond_timestep, cond_vae_image_mask) + inputs_embeds = self.instantiate_timestep_tokens( + inputs_embeds, cond_timestep, cond_timestep_scatter_index) + if cond_vit_images is not None: + inputs_embeds = self.instantiate_vit_image_tokens( + inputs_embeds, cond_vit_images, cond_vit_image_mask, vit_kwargs) + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + custom_pos_emb=custom_pos_emb, + mode=mode, + first_step=first_step, + gen_timestep_scatter_index=gen_timestep_scatter_index, + ) + hidden_states = outputs[0] + + if mode == "gen_text": + hidden_states = self.model.ln_f(hidden_states) + logits = self.lm_head(hidden_states) + logits = logits.float() + diffusion_prediction = None + else: + logits = None + hidden_states = hidden_states.to(input_ids.device) + diffusion_prediction = self.ragged_final_layer( + hidden_states, image_mask, timestep, token_h, token_w, first_step) + + if not return_dict: + output = (logits,) + outputs[1:] + (diffusion_prediction,) + return output + + output = CausalMMOutputWithPast( + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + diffusion_prediction=diffusion_prediction, + ) + + return output + + @staticmethod + def check_inputs(prompt=None, message_list=None): + if prompt is None and message_list is None: + raise ValueError("Either `prompt` or `message_list` should be provided.") + if prompt is not None and message_list is not None: + raise ValueError("Only one of `prompt` or `message_list` should be provided.") + if prompt is not None: + assert isinstance(prompt, str) or isinstance(prompt, list), \ + f"`prompt` should be a string or a list of strings, but got {type(prompt)}." + if isinstance(prompt, list): + assert len(prompt) > 0 and all(isinstance(p, str) for p in prompt), \ + "`prompt` should be a non-empty list of strings." 
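+        # Illustrative usage (not from the original code): check_inputs(prompt="a cat") and
+        # check_inputs(message_list=[{...}]) (a non-empty list of message dicts) are both
+        # valid, while passing neither or both raises a ValueError above.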
+ if message_list is not None: + if not isinstance(message_list, list): + raise ValueError(f"`message_list` should be a list of messages, but got {type(message_list)}.") + assert len(message_list) > 0, "`message_list` should be a non-empty list." + for message in message_list: + assert isinstance(message, list) or isinstance(message, dict), \ + f"Each message should be a list of dicts or a dict, but got {type(message)}." + + @staticmethod + def prepare_seed(seed, batch_size): + if isinstance(seed, torch.Tensor): + seed = seed.tolist() + if seed is None: + seeds = [random.randint(0, 10_000_000) for _ in range(batch_size)] + elif isinstance(seed, int): + seeds = [seed for _ in range(batch_size)] + elif isinstance(seed, (list, tuple)): + if len(seed) == batch_size: + seeds = [int(seed[i]) for i in range(batch_size)] + else: + raise ValueError(f"Length of seed must be equal to the batch_size({batch_size}), got {seed}.") + else: + raise ValueError(f"Seed must be an integer, a list of integers, or None, got {seed}.") + return seeds + + @staticmethod + def build_batch_rope_image_info(output, sections): + rope_image_info = [] + for image_slices, sections_i in zip(output.all_image_slices, sections): + image_shapes = [] + for section in sections_i: + if 'image' in section['type']: + if isinstance(section['token_height'], list): + assert len(section['token_height']) == len(section['token_height']), \ + (f"token_height and token_width should have the same length, " + f"but got {len(section['token_height'])} and {len(section['token_width'])}") + image_shapes.extend(list(zip(section['token_height'], section['token_width']))) + else: + image_shapes.append((section['token_height'], section['token_width'])) + assert len(image_slices) == len(image_shapes), ( + f"Size miss matching: Image slices({len(image_slices)}) != image shapes({len(image_shapes)})" + ) + rope_image_info.append(list(zip(image_slices, image_shapes))) + return rope_image_info + + def vae_encode(self, image, cfg_factor=1): + config = self.vae.config + + with torch.autocast(device_type="cuda", dtype=torch.float16, enabled=True): + vae_encode_result = self.vae.encode(image) + if isinstance(vae_encode_result, torch.Tensor): + latents = vae_encode_result + else: + latents = vae_encode_result.latent_dist.sample() + if hasattr(config, 'shift_factor') and config.shift_factor: + latents.sub_(config.shift_factor) + if hasattr(config, 'scaling_factor') and config.scaling_factor: + latents.mul_(config.scaling_factor) + + if hasattr(self.vae, "ffactor_temporal"): + assert latents.shape[2] == 1, "latents should have shape [B, C, T, H, W] and T should be 1" + latents = latents.squeeze(2) + + # Here we always use t=0 to declare it is a clean conditional image + t = torch.zeros((latents.shape[0],)) + + if cfg_factor > 1: + t = t.repeat(cfg_factor) + latents = latents.repeat(cfg_factor, 1, 1, 1) + + return t, latents + + def _encode_cond_image( + self, + batch_cond_image_info_list: List[List[JointImageInfo]], + cfg_factor: int = 1, + ): + # VAE encode one by one, as we assume cond images have different sizes + batch_cond_vae_images, batch_cond_t, batch_cond_vit_images = [], [], [] + for cond_image_info_list in batch_cond_image_info_list: + cond_vae_image_list, cond_t_list, cond_vit_image_list = [], [], [] + for image_info in cond_image_info_list: + cond_t_, cond_vae_image_ = self.vae_encode( + image_info.vae_image_info.image_tensor.to(self.device), + ) + cond_vit_image_list.append(image_info.vision_image_info.image_tensor) + 
cond_vae_image_list.append(cond_vae_image_.squeeze(0)) + cond_t_list.append(cond_t_) + batch_cond_vae_images.append(cond_vae_image_list) + batch_cond_t.append(cond_t_list) + batch_cond_vit_images.append(torch.cat(cond_vit_image_list, dim=0)) + + # If only one cond image for each sample and all have the same size, we can batch them together + # In this case, cond_vae_images is a 4-D tensor. + if all([len(items) == 1 for items in batch_cond_vae_images]) and all( + items[0].shape == batch_cond_vae_images[0][0].shape for items in batch_cond_vae_images): + cond_vae_images = torch.stack([items[0] for items in batch_cond_vae_images], dim=0) + cond_t = torch.cat([items[0] for items in batch_cond_t], dim=0) + if cfg_factor > 1: + cond_t = cond_t.repeat(cfg_factor) + cond_vae_images = cond_vae_images.repeat(cfg_factor, 1, 1, 1) + else: + # In this case, cond_vae_images is a list of 4-D tensors or a list of lists of 3-D tensors. + cond_t = [torch.cat(item, dim=0) for item in batch_cond_t] + cond_vae_images = [] + for items in batch_cond_vae_images: + if all(items[0].shape == item.shape for item in items): + cond_vae_images.append(torch.stack(items, dim=0)) + else: + cond_vae_images.append(items) + if cfg_factor > 1: + cond_t = cond_t * cfg_factor + cond_vae_images = cond_vae_images * cfg_factor + + if cfg_factor > 1: + batch_cond_vit_images = batch_cond_vit_images * cfg_factor + + return cond_vae_images, cond_t, batch_cond_vit_images + + def prepare_model_inputs( + self, + prompt=None, + mode="gen_text", + system_prompt=None, + cot_text=None, + image_size="auto", + message_list=None, + device=None, + max_new_tokens=None, + **kwargs, + ): + # 1. Sanity check + self.check_inputs(prompt, message_list) + device = default(device, self.device) + + # 2. Format inputs + batch_message_list = message_list + batch_prompt = prompt + batch_cot_text = cot_text + batch_system_prompt = system_prompt + batch_gen_image_info = None + # TODO: construct with user input images + batch_cond_image_info = None + + # -- 2.1 message_list + if batch_message_list is not None: + if isinstance(batch_message_list[0], dict): + batch_message_list = [batch_message_list] + batch_size = len(batch_message_list) + + batch_gen_image_info = [ + [message['content'] for message in message_list_ if message['type'] == 'gen_image'] + for message_list_ in batch_message_list + ] + # At most one gen_image is allowed for each message_list + batch_gen_image_info = [info[-1] if len(info) > 0 else None for info in batch_gen_image_info] + # Multiple cond images are allowed. + batch_cond_image_info = [ + [message['content'] for message in message_list_ if message['type'] == 'joint_image'] + for message_list_ in batch_message_list + ] + + # -- 2.2 Prompt, cot text, system prompt + else: + if isinstance(batch_prompt, str): + batch_prompt = [batch_prompt] + batch_size = len(batch_prompt) + + if batch_cot_text is not None: + if isinstance(batch_cot_text, str): + batch_cot_text = [batch_cot_text] + else: + assert isinstance(batch_cot_text, list) and len(batch_cot_text) == batch_size, \ + "`cot_text` should be a string or a list of strings with the same length as `prompt`." + + if batch_system_prompt is not None: + if isinstance(batch_system_prompt, str): + batch_system_prompt = [batch_system_prompt] + else: + assert isinstance(batch_system_prompt, list) and len(batch_system_prompt) == batch_size, \ + "`system_prompts` should be a string or a list of strings with the same length as `prompt`." 
+ + if mode == "gen_image": + batch_gen_image_info = [self.image_processor.build_image_info(image_size) for _ in range(batch_size)] + + # -- 2.3 seed + seeds = self.prepare_seed(seed=kwargs.get('seed'), batch_size=batch_size) + generator = [torch.Generator(self.device).manual_seed(seed) for seed in seeds] + + # 3. apply chat template + cfg_factor = {"gen_text": 1, "gen_image": 2} + bot_task = kwargs.pop("bot_task", "auto") + # If `drop_think` enabled, always drop parts in the context. + drop_think = kwargs.get('drop_think', self.generation_config.drop_think) + # Apply batched prompt or batched message_list to build input sequence with associated info. + out = self._tkwrapper.apply_chat_template( + batch_prompt=batch_prompt, + batch_message_list=batch_message_list, + mode=mode, + batch_gen_image_info=batch_gen_image_info, + batch_cond_image_info=batch_cond_image_info, + batch_system_prompt=batch_system_prompt, + batch_cot_text=batch_cot_text, + max_length=kwargs.get('max_length'), + bot_task=bot_task, + image_base_size=self.config.image_base_size, + sequence_template=self.generation_config.sequence_template, + cfg_factor=cfg_factor[mode], + drop_think=drop_think, + ) + output, sections = out['output'], out['sections'] + + # 4. Encode conditional images + if batch_cond_image_info is not None and len(batch_cond_image_info[0]) > 0: + cond_vae_images, cond_timestep, cond_vit_images = self._encode_cond_image( + batch_cond_image_info, cfg_factor[mode] + ) + # Pack vit kwargs. Siglip2-so requires spatial_shapes and attention_mask for inference. + vit_kwargs = {"spatial_shapes": [], "attention_mask": []} + for cond_image_info in batch_cond_image_info: + vit_kwargs["spatial_shapes"].append( + torch.stack([item.vision_encoder_kwargs["spatial_shapes"] for item in cond_image_info])) + vit_kwargs["attention_mask"].append( + torch.stack([item.vision_encoder_kwargs["pixel_attention_mask"] for item in cond_image_info])) + if cfg_factor[mode] > 1: + vit_kwargs["spatial_shapes"] = vit_kwargs["spatial_shapes"] * cfg_factor[mode] + vit_kwargs["attention_mask"] = vit_kwargs["attention_mask"] * cfg_factor[mode] + else: + cond_vae_images, cond_timestep, cond_vit_images = None, None, None + vit_kwargs = None + + # 5. Build position embeddings + rope_image_info = self.build_batch_rope_image_info(output, sections) + if mode == "gen_text": + seq_len = self.generation_config.max_length + else: + seq_len = output.tokens.shape[1] + cos, sin = build_batch_2d_rope( + image_infos=rope_image_info, + seq_len=seq_len, + n_elem=self.config.attention_head_dim, + device=device, + base=self.config.rope_theta, + ) + + # 6. Build kv cache + if bot_task == "img_ratio": + max_new_tokens = 1 + if mode == "gen_image": + # Image generation will not extend sequence length, using token length as max_cache_len is enough. + max_cache_len = output.tokens.shape[1] + else: + max_cache_len = output.tokens.shape[1] + default(max_new_tokens, self.generation_config.max_length) + cache = HunyuanStaticCache( + config=self.config, + batch_size=batch_size * cfg_factor[mode], + max_cache_len=max_cache_len, + dtype=torch.bfloat16, + dynamic=mode == "gen_text", + ) + + # 7. Build position ids + batch_input_pos = torch.arange( + 0, output.tokens.shape[1], dtype=torch.long, device=device)[None].expand( + batch_size * cfg_factor[mode], -1) # use expand to share indices to save memory + + # 8. 
Build model input kwargs + tkw = self._tkwrapper + if image_size == "auto": + extra_auto_stops = [tkw.special_token_map[f""] for i in range(33)] + else: + extra_auto_stops = [tkw.boi_token_id] + stop_token_id = dict( + auto=[tkw.eos_token_id] + extra_auto_stops, + recaption=[tkw.end_recaption_token_id, tkw.end_answer_token_id, tkw.eos_token_id], + think=[tkw.end_recaption_token_id, tkw.end_answer_token_id, tkw.eos_token_id], + img_ratio=extra_auto_stops, + ) + model_input_kwargs = dict( + input_ids=output.tokens.to(device), + position_ids=batch_input_pos, + past_key_values=cache, + custom_pos_emb=(cos, sin), + mode=mode, + image_mask=to_device(output.gen_image_mask, device), + gen_timestep_scatter_index=to_device(output.gen_timestep_scatter_index, device), + cond_vae_images=to_device(cond_vae_images, device), + cond_timestep=to_device(cond_timestep, device), + cond_vae_image_mask=to_device(output.cond_vae_image_mask, device), + cond_vit_images=to_device(cond_vit_images, device), + cond_vit_image_mask=to_device(output.cond_vit_image_mask, device), + vit_kwargs={ + k: to_device(v, self.device) for k, v in vit_kwargs.items() + } if vit_kwargs is not None else None, + cond_timestep_scatter_index=to_device(output.cond_timestep_scatter_index, device), + # for inner usage + tokenizer_output=output, + batch_gen_image_info=batch_gen_image_info, + generator=generator, + # generation config + eos_token_id=stop_token_id[bot_task], + max_new_tokens=max_new_tokens, + ) + + return model_input_kwargs + + def _prepare_attention_mask_for_generation( + self, + inputs_tensor: torch.Tensor, + generation_config: GenerationConfig, + model_kwargs: Dict[str, Any], + ) -> torch.Tensor: + # create `4d` bool attention mask (b, 1, seqlen, seqlen) using this implementation to bypass the 2d requirement + # in the `transformers.generation_utils.GenerationMixin.generate`. + # This implementation can handle sequences with text and image modalities, where text tokens use causal + # attention and image tokens use full attention. 
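+ # Illustrative example (assumed toy case, for clarity only): for a sequence [t0, t1, i0, i1, i2, t2]
+ # where i0..i2 form one image slice, the mask built below stays lower-triangular for the text tokens,
+ # while the block covering i0..i2 is set fully to True so those image tokens attend to each other
+ # bidirectionally.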
+ bsz, seq_len = inputs_tensor.shape + tokenizer_output = model_kwargs["tokenizer_output"] + batch_image_slices = [ + tokenizer_output.joint_image_slices[i] + tokenizer_output.gen_image_slices[i] + for i in range(bsz) + ] + attention_mask = torch.ones(seq_len, seq_len, dtype=torch.bool).tril(diagonal=0).repeat(bsz, 1, 1) + for i in range(bsz): + for j, image_slice in enumerate(batch_image_slices[i]): + attention_mask[i, image_slice, image_slice] = True + attention_mask = attention_mask.unsqueeze(1) + return attention_mask + + def prepare_inputs_for_generation( + self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, + tokenizer_output=None, batch_gen_image_info=None, generator=None, **kwargs + ): + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + if input_ids.shape[1] != kwargs["position_ids"].shape[1]: # in decode steps + input_ids = torch.gather(input_ids, dim=1, index=kwargs["position_ids"]) + model_inputs = {"input_ids": input_ids} + + model_inputs.update( + { + "attention_mask": attention_mask, + "position_ids": kwargs["position_ids"], + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "custom_pos_emb": kwargs["custom_pos_emb"], + "mode": kwargs["mode"], + "images": kwargs.get("images"), + "image_mask": kwargs.get("image_mask"), + "timestep": kwargs.get("timestep"), + "gen_timestep_scatter_index": kwargs.get("gen_timestep_scatter_index"), + "cond_vae_images": kwargs.get("cond_vae_images"), + "cond_timestep": kwargs.get("cond_timestep"), + "cond_vae_image_mask": kwargs.get("cond_vae_image_mask"), + "cond_vit_images": kwargs.get("cond_vit_images"), + "cond_vit_image_mask": kwargs.get("cond_vit_image_mask"), + "vit_kwargs": kwargs.get("vit_kwargs"), + "cond_timestep_scatter_index": kwargs.get("cond_timestep_scatter_index"), + } + ) + return model_inputs + + def _update_model_kwargs_for_generation( + self, + outputs: ModelOutput, + model_kwargs: Dict[str, Any], + is_encoder_decoder: bool = False, + num_new_tokens: int = 1, + ) -> Dict[str, Any]: + mode = model_kwargs["mode"] + + updated_model_kwargs = { + "mode": mode, + "custom_pos_emb": model_kwargs["custom_pos_emb"], + } + + # update past_key_values keeping its naming used in model code + for possible_cache_name in ALL_CACHE_NAMES: + if possible_cache_name in outputs: + # TODO (joao): remove output/input mismatch when these old models (xlnet, reformer) are deprecated + if possible_cache_name in ("past_buckets_states", "mems"): + cache_name = "past_key_values" + else: + cache_name = possible_cache_name + updated_model_kwargs[cache_name] = getattr(outputs, possible_cache_name) + break + + if "tokenizer_output" in model_kwargs: + if mode == "gen_text": + # When enable batching, we use right padding, which requires a real_pos to index the valid + # end position of the sequence. If tokenizer_output in model_kwargs, it means we are in the + # prefill step of generation. 
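+ # Worked example (assumed shapes, for clarity only): with right padding, two prompts of different
+ # lengths share one padded buffer, so real_pos carries a per-sample index of where each sequence's
+ # valid tokens end; subsequent steps index that position rather than the common padded length.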
+ real_pos = to_device(model_kwargs["tokenizer_output"].real_pos, self.device) + updated_model_kwargs["position_ids"] = real_pos + else: + # position ids + image_mask = model_kwargs["image_mask"] + bsz, seq_len = image_mask.shape + index = torch.arange(seq_len, device=image_mask.device).unsqueeze(0).repeat(bsz, 1) + position_ids = index.masked_select(image_mask.bool()).reshape(bsz, -1) + timestep_position_ids = \ + index[torch.arange(bsz), model_kwargs["gen_timestep_scatter_index"][:, -1]].unsqueeze(-1) + updated_model_kwargs["position_ids"] = torch.cat([timestep_position_ids, position_ids], dim=1) + + # attention mask + mask_list = [] + for attention_mask_i, position_ids_i in zip( + model_kwargs["attention_mask"], updated_model_kwargs["position_ids"]): + mask_list.append(torch.index_select(attention_mask_i, dim=1, index=position_ids_i.reshape(-1))) + attention_mask = torch.stack(mask_list, dim=0) + updated_model_kwargs["attention_mask"] = attention_mask + updated_model_kwargs["gen_timestep_scatter_index"] = model_kwargs["gen_timestep_scatter_index"] + + else: + if mode == "gen_text": + # Now we are in the decode steps. + updated_model_kwargs["position_ids"] = model_kwargs["position_ids"] + 1 + else: + updated_model_kwargs["position_ids"] = model_kwargs["position_ids"] + updated_model_kwargs["attention_mask"] = model_kwargs["attention_mask"] + updated_model_kwargs["gen_timestep_scatter_index"] = model_kwargs["gen_timestep_scatter_index"] + + return updated_model_kwargs + + def _generate( + self, + inputs: Optional[torch.Tensor] = None, + generation_config: Optional[GenerationConfig] = None, + logits_processor: Optional[LogitsProcessorList] = None, + stopping_criteria: Optional[StoppingCriteriaList] = None, + prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None, + synced_gpus: Optional[bool] = None, + assistant_model: Optional["PreTrainedModel"] = None, + streamer: Optional["BaseStreamer"] = None, + negative_prompt_ids: Optional[torch.Tensor] = None, + negative_prompt_attention_mask: Optional[torch.Tensor] = None, + use_model_defaults: Optional[bool] = None, + generator: Optional[List[torch.Generator]] = None, + verbose: int = 0, + **kwargs, + ): + mode = kwargs.get("mode", "gen_text") + + # Log info + if verbose >= 1: + output = kwargs["tokenizer_output"] + context = self._tkwrapper.tokenizer.decode(output.tokens[0], skip_special_tokens=False) + # Replace ... 
with []{number} + context = re.sub(r"()+", lambda m: f"[]{{{len(m.group(0)) // 5}}}", context) + info_list = [ + ("token shape", output.tokens.shape), + ("context[0]", context), + ] + gen_config = default(generation_config, self.generation_config) + if mode == "gen_image": + if generator is not None: + info_list.extend([ + ("seed", [g.initial_seed() for g in generator]), + ]) + info_list.extend([ + ("image_size", [f"{info.image_height}x{info.image_width}" for info in kwargs["batch_gen_image_info"]]), + ("infer_steps", kwargs.get("diff_infer_steps", gen_config.diff_infer_steps)), + ("guidance_scale", kwargs.get("diff_guidance_scale", gen_config.diff_guidance_scale)), + ("flow_shift", kwargs.get("flow_shift", gen_config.flow_shift)), + ]) + else: + info_list.extend([ + ("do_sample", kwargs.get("do_sample", gen_config.do_sample)), + ("max_new_tokens", kwargs.get("max_new_tokens", gen_config.max_new_tokens)), + ("top_k", kwargs.get("top_k", gen_config.top_k)), + ("top_p", kwargs.get("top_p", gen_config.top_p)), + ("temperature", kwargs.get("temperature", gen_config.temperature)), + ("repetition_penalty", kwargs.get("repetition_penalty", gen_config.repetition_penalty)), + ]) + max_key_len = max(len(k) for k, _ in info_list) + info_str = "=" * 50 + \ + "\nModel input info:\n" + \ + "\n".join([f" {k.rjust(max_key_len)}: {v}" for k, v in info_list]) + \ + "\n--------------------------------------------------" + print(info_str) + + if mode == "gen_text": + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + return super().generate( + inputs, + generation_config, + logits_processor, + stopping_criteria, + prefix_allowed_tokens_fn, + synced_gpus, + assistant_model, + streamer, + negative_prompt_ids, + negative_prompt_attention_mask, + use_model_defaults, + **kwargs, + ) + + elif mode == "gen_image": + batch_gen_image_info: List[ImageInfo] = kwargs.get("batch_gen_image_info") + if batch_gen_image_info is None: + raise ValueError("`batch_gen_image_info` should be provided when `mode` is `gen_image`.") + + results = self.pipeline( + batch_size=len(batch_gen_image_info), + image_size=[batch_gen_image_info[0].image_height, batch_gen_image_info[0].image_width], + num_inference_steps=kwargs.get("diff_infer_steps", self.generation_config.diff_infer_steps), + guidance_scale=kwargs.get("diff_guidance_scale", self.generation_config.diff_guidance_scale), + generator=generator, + model_kwargs=kwargs, + ) + samples = results[0] + return samples + + else: + raise ValueError(f"Unknown mode {mode}, only `gen_text` and `gen_image` are supported.") + + def get_cot_text(self, output: torch.Tensor): + if output.ndim == 2: + return [self.get_cot_text(output_i) for output_i in output] + elif output.ndim == 1: + if output[-1] == self._tkwrapper.eos_token_id: + output = output[:-1] + cot_text = self._tkwrapper.decode(output).split("Assistant: ")[1] + return cot_text + else: + raise ValueError(f"output should be 1D or 2D tensor, but got {output.ndim}D tensor.") + + def generate_image( + self, + prompt, + seed=None, + image_size="auto", + use_system_prompt=None, + system_prompt=None, + bot_task=None, + stream=False, + **kwargs, + ): + max_new_tokens = kwargs.pop("max_new_tokens", 8192) + verbose = kwargs.pop("verbose", 0) + + if stream: + from transformers import TextStreamer + streamer = TextStreamer(self._tkwrapper.tokenizer, skip_prompt=True, skip_special_tokens=False) + kwargs["streamer"] = streamer + + use_system_prompt = default(use_system_prompt, self.generation_config.use_system_prompt) + 
bot_task = default(bot_task, self.generation_config.bot_task) + system_prompt = get_system_prompt(use_system_prompt, bot_task, system_prompt) + + if bot_task in ["think", "recaption"]: + # Cot + model_inputs = self.prepare_model_inputs( + prompt=prompt, bot_task=bot_task, system_prompt=system_prompt, max_new_tokens=max_new_tokens) + print(f"<{bot_task}>", end="", flush=True) + outputs = self._generate(**model_inputs, **kwargs, verbose=verbose) + cot_text = self.get_cot_text(outputs[0]) + # Switch system_prompt to `en_recaption` if drop_think is enabled. + if self.generation_config.drop_think and system_prompt: + system_prompt = t2i_system_prompts["en_recaption"][0] + else: + cot_text = None + + # Image ratio + if image_size == "auto": + model_inputs = self.prepare_model_inputs( + prompt=prompt, cot_text=cot_text, bot_task="img_ratio", system_prompt=system_prompt, seed=seed) + outputs = self._generate(**model_inputs, **kwargs, verbose=verbose) + ratio_index = outputs[0, -1].item() - self._tkwrapper.ratio_token_offset + reso = self.image_processor.reso_group[ratio_index] + image_size = reso.height, reso.width + + # Generate image + model_inputs = self.prepare_model_inputs( + prompt=prompt, cot_text=cot_text, system_prompt=system_prompt, mode="gen_image", seed=seed, + image_size=image_size, + ) + outputs = self._generate(**model_inputs, **kwargs, verbose=verbose) + return outputs[0] diff --git a/hunyuan_image_3_pipeline.py b/hunyuan_image_3_pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..a119cab23ca9906fb5352f336a3f09a27ccc37e3 --- /dev/null +++ b/hunyuan_image_3_pipeline.py @@ -0,0 +1,879 @@ +# Licensed under the TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://github.com/Tencent-Hunyuan/HunyuanImage-3.0/blob/main/LICENSE +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================================== + +import inspect +import math +from dataclasses import dataclass +from typing import Any, Callable, Dict, List +from typing import Optional, Tuple, Union + +import numpy as np +import torch +from PIL import Image +from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.image_processor import VaeImageProcessor +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.schedulers.scheduling_utils import SchedulerMixin +from diffusers.utils import BaseOutput, logging +from diffusers.utils.torch_utils import randn_tensor + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." 
+ ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + r""" + Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on + Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://arxiv.org/pdf/2305.08891.pdf). + + Args: + noise_cfg (`torch.Tensor`): + The predicted noise tensor for the guided diffusion process. + noise_pred_text (`torch.Tensor`): + The predicted noise tensor for the text-guided diffusion process. + guidance_rescale (`float`, *optional*, defaults to 0.0): + A rescale factor applied to the noise predictions. + Returns: + noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor. + """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + + +@dataclass +class HunyuanImage3Text2ImagePipelineOutput(BaseOutput): + samples: Union[List[Any], np.ndarray] + + +@dataclass +class FlowMatchDiscreteSchedulerOutput(BaseOutput): + """ + Output class for the scheduler's `step` function output. + + Args: + prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): + Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the + denoising loop. + """ + + prev_sample: torch.FloatTensor + + +class FlowMatchDiscreteScheduler(SchedulerMixin, ConfigMixin): + """ + Euler scheduler. + + This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic + methods the library implements for all schedulers such as loading and saving. + + Args: + num_train_timesteps (`int`, defaults to 1000): + The number of diffusion steps to train the model. + timestep_spacing (`str`, defaults to `"linspace"`): + The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. + shift (`float`, defaults to 1.0): + The shift value for the timestep schedule. + reverse (`bool`, defaults to `True`): + Whether to reverse the timestep schedule. 
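+
+ Example (illustrative sketch; the parameter values below are arbitrary):
+
+ >>> scheduler = FlowMatchDiscreteScheduler(shift=3.0, solver="euler")
+ >>> scheduler.set_timesteps(num_inference_steps=50)
+ >>> len(scheduler.timesteps)  # one timestep per inference step
+ 50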
+ """ + + _compatibles = [] + order = 1 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + shift: float = 1.0, + reverse: bool = True, + solver: str = "euler", + use_flux_shift: bool = False, + flux_base_shift: float = 0.5, + flux_max_shift: float = 1.15, + n_tokens: Optional[int] = None, + ): + sigmas = torch.linspace(1, 0, num_train_timesteps + 1) + + if not reverse: + sigmas = sigmas.flip(0) + + self.sigmas = sigmas + # the value fed to model + self.timesteps = (sigmas[:-1] * num_train_timesteps).to(dtype=torch.float32) + self.timesteps_full = (sigmas * num_train_timesteps).to(dtype=torch.float32) + + self._step_index = None + self._begin_index = None + + self.supported_solver = [ + "euler", + "heun-2", "midpoint-2", + "kutta-4", + ] + if solver not in self.supported_solver: + raise ValueError(f"Solver {solver} not supported. Supported solvers: {self.supported_solver}") + + # empty dt and derivative (for heun) + self.derivative_1 = None + self.derivative_2 = None + self.derivative_3 = None + self.dt = None + + @property + def step_index(self): + """ + The index counter for current timestep. It will increase 1 after each scheduler step. + """ + return self._step_index + + @property + def begin_index(self): + """ + The index for the first timestep. It should be set from pipeline with `set_begin_index` method. + """ + return self._begin_index + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index + def set_begin_index(self, begin_index: int = 0): + """ + Sets the begin index for the scheduler. This function should be run from pipeline before the inference. + + Args: + begin_index (`int`): + The begin index for the scheduler. + """ + self._begin_index = begin_index + + def _sigma_to_t(self, sigma): + return sigma * self.config.num_train_timesteps + + @property + def state_in_first_order(self): + return self.derivative_1 is None + + @property + def state_in_second_order(self): + return self.derivative_2 is None + + @property + def state_in_third_order(self): + return self.derivative_3 is None + + def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None, + n_tokens: int = None): + """ + Sets the discrete timesteps used for the diffusion chain (to be run before inference). + + Args: + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + n_tokens (`int`, *optional*): + Number of tokens in the input sequence. 
+ """ + self.num_inference_steps = num_inference_steps + + sigmas = torch.linspace(1, 0, num_inference_steps + 1) + + # Apply timestep shift + if self.config.use_flux_shift: + assert isinstance(n_tokens, int), "n_tokens should be provided for flux shift" + mu = self.get_lin_function(y1=self.config.flux_base_shift, y2=self.config.flux_max_shift)(n_tokens) + sigmas = self.flux_time_shift(mu, 1.0, sigmas) + elif self.config.shift != 1.: + sigmas = self.sd3_time_shift(sigmas) + + if not self.config.reverse: + sigmas = 1 - sigmas + + self.sigmas = sigmas + self.timesteps = (sigmas[:-1] * self.config.num_train_timesteps).to(dtype=torch.float32, device=device) + self.timesteps_full = (sigmas * self.config.num_train_timesteps).to(dtype=torch.float32, device=device) + + # empty dt and derivative (for kutta) + self.derivative_1 = None + self.derivative_2 = None + self.derivative_3 = None + self.dt = None + + # Reset step index + self._step_index = None + + def index_for_timestep(self, timestep, schedule_timesteps=None): + if schedule_timesteps is None: + schedule_timesteps = self.timesteps + + indices = (schedule_timesteps == timestep).nonzero() + + # The sigma index that is taken for the **very** first `step` + # is always the second index (or the last index if there is only 1) + # This way we can ensure we don't accidentally skip a sigma in + # case we start in the middle of the denoising schedule (e.g. for image-to-image) + pos = 1 if len(indices) > 1 else 0 + + return indices[pos].item() + + def _init_step_index(self, timestep): + if self.begin_index is None: + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + self._step_index = self.index_for_timestep(timestep) + else: + self._step_index = self._begin_index + + def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor: + return sample + + @staticmethod + def get_lin_function(x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15): + m = (y2 - y1) / (x2 - x1) + b = y1 - m * x1 + return lambda x: m * x + b + + @staticmethod + def flux_time_shift(mu: float, sigma: float, t: torch.Tensor): + return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma) + + def sd3_time_shift(self, t: torch.Tensor): + return (self.config.shift * t) / (1 + (self.config.shift - 1) * t) + + def step( + self, + model_output: torch.FloatTensor, + timestep: Union[float, torch.FloatTensor], + sample: torch.FloatTensor, + pred_uncond: torch.FloatTensor = None, + generator: Optional[torch.Generator] = None, + n_tokens: Optional[int] = None, + return_dict: bool = True, + ) -> Union[FlowMatchDiscreteSchedulerOutput, Tuple]: + """ + Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + model_output (`torch.FloatTensor`): + The direct output from learned diffusion model. + timestep (`float`): + The current discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + A current instance of a sample created by the diffusion process. + generator (`torch.Generator`, *optional*): + A random number generator. + n_tokens (`int`, *optional*): + Number of tokens in the input sequence. + return_dict (`bool`): + Whether or not to return a [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or + tuple. 
+ + Returns: + [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or `tuple`: + If return_dict is `True`, [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] is + returned, otherwise a tuple is returned where the first element is the sample tensor. + """ + + if ( + isinstance(timestep, int) + or isinstance(timestep, torch.IntTensor) + or isinstance(timestep, torch.LongTensor) + ): + raise ValueError( + ( + "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to" + " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass" + " one of the `scheduler.timesteps` as a timestep." + ), + ) + + if self.step_index is None: + self._init_step_index(timestep) + + # Upcast to avoid precision issues when computing prev_sample + sample = sample.to(torch.float32) + model_output = model_output.to(torch.float32) + pred_uncond = pred_uncond.to(torch.float32) if pred_uncond is not None else None + + # dt = self.sigmas[self.step_index + 1] - self.sigmas[self.step_index] + sigma = self.sigmas[self.step_index] + sigma_next = self.sigmas[self.step_index + 1] + + last_inner_step = True + if self.config.solver == "euler": + derivative, dt, sample, last_inner_step = self.first_order_method(model_output, sigma, sigma_next, sample) + elif self.config.solver in ["heun-2", "midpoint-2"]: + derivative, dt, sample, last_inner_step = self.second_order_method(model_output, sigma, sigma_next, sample) + elif self.config.solver == "kutta-4": + derivative, dt, sample, last_inner_step = self.fourth_order_method(model_output, sigma, sigma_next, sample) + else: + raise ValueError(f"Solver {self.config.solver} not supported. Supported solvers: {self.supported_solver}") + + prev_sample = sample + derivative * dt + + # Cast sample back to model compatible dtype + # prev_sample = prev_sample.to(model_output.dtype) + + # upon completion increase step index by one + if last_inner_step: + self._step_index += 1 + + if not return_dict: + return (prev_sample,) + + return FlowMatchDiscreteSchedulerOutput(prev_sample=prev_sample) + + def first_order_method(self, model_output, sigma, sigma_next, sample): + derivative = model_output + dt = sigma_next - sigma + return derivative, dt, sample, True + + def second_order_method(self, model_output, sigma, sigma_next, sample): + if self.state_in_first_order: + # store for 2nd order step + self.derivative_1 = model_output + self.dt = sigma_next - sigma + self.sample = sample + + derivative = model_output + if self.config.solver == 'heun-2': + dt = self.dt + elif self.config.solver == 'midpoint-2': + dt = self.dt / 2 + else: + raise NotImplementedError(f"Solver {self.config.solver} not supported.") + last_inner_step = False + + else: + if self.config.solver == 'heun-2': + derivative = 0.5 * (self.derivative_1 + model_output) + elif self.config.solver == 'midpoint-2': + derivative = model_output + else: + raise NotImplementedError(f"Solver {self.config.solver} not supported.") + + # 3. 
take prev timestep & sample + dt = self.dt + sample = self.sample + last_inner_step = True + + # free dt and derivative + # Note, this puts the scheduler in "first order mode" + self.derivative_1 = None + self.dt = None + self.sample = None + + return derivative, dt, sample, last_inner_step + + def fourth_order_method(self, model_output, sigma, sigma_next, sample): + if self.state_in_first_order: + self.derivative_1 = model_output + self.dt = sigma_next - sigma + self.sample = sample + derivative = model_output + dt = self.dt / 2 + last_inner_step = False + + elif self.state_in_second_order: + self.derivative_2 = model_output + derivative = model_output + dt = self.dt / 2 + last_inner_step = False + + elif self.state_in_third_order: + self.derivative_3 = model_output + derivative = model_output + dt = self.dt + last_inner_step = False + + else: + derivative = (1/6 * self.derivative_1 + 1/3 * self.derivative_2 + 1/3 * self.derivative_3 + + 1/6 * model_output) + + # 3. take prev timestep & sample + dt = self.dt + sample = self.sample + last_inner_step = True + + # free dt and derivative + # Note, this puts the scheduler in "first order mode" + self.derivative_1 = None + self.derivative_2 = None + self.derivative_3 = None + self.dt = None + self.sample = None + + return derivative, dt, sample, last_inner_step + + def __len__(self): + return self.config.num_train_timesteps + + +class ClassifierFreeGuidance: + def __init__( + self, + use_original_formulation: bool = False, + start: float = 0.0, + stop: float = 1.0, + ): + super().__init__() + self.use_original_formulation = use_original_formulation + + def __call__( + self, + pred_cond: torch.Tensor, + pred_uncond: Optional[torch.Tensor], + guidance_scale: float, + step: int, + ) -> torch.Tensor: + + shift = pred_cond - pred_uncond + pred = pred_cond if self.use_original_formulation else pred_uncond + pred = pred + guidance_scale * shift + + return pred + + +class HunyuanImage3Text2ImagePipeline(DiffusionPipeline): + r""" + Pipeline for condition-to-sample generation using Stable Diffusion. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Args: + model ([`ModelMixin`]): + A model to denoise the diffused latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `diffusion_model` to denoise the diffused latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. 
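+
+ Example (illustrative sketch; `model` and `vae` stand in for already-loaded HunyuanImage-3
+ components and the `shift` value is arbitrary):
+
+ >>> scheduler = FlowMatchDiscreteScheduler(shift=5.0)
+ >>> pipe = HunyuanImage3Text2ImagePipeline(model=model, scheduler=scheduler, vae=vae)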
+ """ + + model_cpu_offload_seq = "" + _optional_components = [] + _exclude_from_cpu_offload = [] + _callback_tensor_inputs = ["latents"] + + def __init__( + self, + model, + scheduler: SchedulerMixin, + vae, + progress_bar_config: Dict[str, Any] = None, + ): + super().__init__() + + # ========================================================================================== + if progress_bar_config is None: + progress_bar_config = {} + if not hasattr(self, '_progress_bar_config'): + self._progress_bar_config = {} + self._progress_bar_config.update(progress_bar_config) + # ========================================================================================== + + self.register_modules( + model=model, + scheduler=scheduler, + vae=vae, + ) + + # should be a tuple or a list corresponding to the size of latents (batch_size, channel, *size) + # if None, will be treated as a tuple of 1 + self.latent_scale_factor = self.model.config.vae_downsample_factor + self.image_processor = VaeImageProcessor(vae_scale_factor=self.latent_scale_factor) + + # Must start with APG_mode_ + self.cfg_operator = ClassifierFreeGuidance() + + @staticmethod + def denormalize(images: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]: + """ + Denormalize an image array to [0,1]. + """ + return (images / 2 + 0.5).clamp(0, 1) + + @staticmethod + def pt_to_numpy(images: torch.Tensor) -> np.ndarray: + """ + Convert a PyTorch tensor to a NumPy image. + """ + images = images.cpu().permute(0, 2, 3, 1).float().numpy() + return images + + @staticmethod + def numpy_to_pil(images: np.ndarray): + """ + Convert a numpy image or a batch of images to a PIL image. + """ + if images.ndim == 3: + images = images[None, ...] + images = (images * 255).round().astype("uint8") + if images.shape[-1] == 1: + # special case for grayscale (single channel) images + pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images] + else: + pil_images = [Image.fromarray(image) for image in images] + + return pil_images + + def prepare_extra_func_kwargs(self, func, kwargs): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + extra_kwargs = {} + + for k, v in kwargs.items(): + accepts = k in set(inspect.signature(func).parameters.keys()) + if accepts: + extra_kwargs[k] = v + return extra_kwargs + + def prepare_latents(self, batch_size, latent_channel, image_size, dtype, device, generator, latents=None): + if self.latent_scale_factor is None: + latent_scale_factor = (1,) * len(image_size) + elif isinstance(self.latent_scale_factor, int): + latent_scale_factor = (self.latent_scale_factor,) * len(image_size) + elif isinstance(self.latent_scale_factor, tuple) or isinstance(self.latent_scale_factor, list): + assert len(self.latent_scale_factor) == len(image_size), \ + "len(latent_scale_factor) shoudl be the same as len(image_size)" + latent_scale_factor = self.latent_scale_factor + else: + raise ValueError( + f"latent_scale_factor should be either None, int, tuple of int, or list of int, " + f"but got {self.latent_scale_factor}" + ) + + latents_shape = ( + batch_size, + latent_channel, + *[int(s) // f for s, f in zip(image_size, latent_scale_factor)], + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(latents_shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # Check existence to make it compatible with FlowMatchEulerDiscreteScheduler + if hasattr(self.scheduler, "init_noise_sigma"): + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + + return latents + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def guidance_rescale(self): + return self._guidance_rescale + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1.0 + + @property + def num_timesteps(self): + return self._num_timesteps + + def set_scheduler(self, new_scheduler): + self.register_modules(scheduler=new_scheduler) + + @torch.no_grad() + def __call__( + self, + batch_size: int, + image_size: List[int], + num_inference_steps: int = 50, + timesteps: List[int] = None, + sigmas: List[float] = None, + guidance_scale: float = 7.5, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + guidance_rescale: float = 0.0, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + model_kwargs: Dict[str, Any] = None, + **kwargs, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `List[str]`): + The text to guide image generation. + image_size (`Tuple[int]` or `List[int]`): + The size (height, width) of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. 
More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate samples closely linked to the + `condition` at the expense of lower sample quality. Guidance scale is enabled when `guidance_scale > 1`. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for sample + generation. Can be used to tweak the same generation with different conditions. If not provided, + a latents tensor is generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated sample. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~DiffusionPipelineOutput`] instead of a + plain tuple. + guidance_rescale (`float`, *optional*, defaults to 0.0): + Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor should fix overexposure when + using zero terminal SNR. + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + + Examples: + + Returns: + [`~DiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~DiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated samples. 
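+
+ Example (illustrative sketch; `pipe` is an instantiated pipeline and `model_kwargs` is assumed
+ to be the dict built by the model's `prepare_model_inputs(..., mode="gen_image")`, mirroring
+ how the model invokes this pipeline internally):
+
+ >>> out = pipe(
+ ...     batch_size=1,
+ ...     image_size=[1024, 1024],
+ ...     num_inference_steps=50,
+ ...     guidance_scale=7.5,
+ ...     model_kwargs=model_kwargs,
+ ... )
+ >>> image = out.samples[0]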
+ """ + + callback_steps = kwargs.pop("callback_steps", None) + pbar_steps = kwargs.pop("pbar_steps", None) + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + self._guidance_scale = guidance_scale + self._guidance_rescale = guidance_rescale + + cfg_factor = 1 + self.do_classifier_free_guidance + + # Define call parameters + device = self._execution_device + + # Prepare timesteps + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, device, timesteps, sigmas, + ) + + # Prepare latent variables + latents = self.prepare_latents( + batch_size=batch_size, + latent_channel=self.model.config.vae["latent_channels"], + image_size=image_size, + dtype=torch.bfloat16, + device=device, + generator=generator, + latents=latents, + ) + + # Prepare extra step kwargs. + _scheduler_step_extra_kwargs = self.prepare_extra_func_kwargs( + self.scheduler.step, {"generator": generator} + ) + + # Prepare model kwargs + input_ids = model_kwargs.pop("input_ids") + attention_mask = self.model._prepare_attention_mask_for_generation( # noqa + input_ids, self.model.generation_config, model_kwargs=model_kwargs, + ) + model_kwargs["attention_mask"] = attention_mask.to(latents.device) + + # Sampling loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * cfg_factor) + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + t_expand = t.repeat(latent_model_input.shape[0]) + + model_inputs = self.model.prepare_inputs_for_generation( + input_ids, + images=latent_model_input, + timestep=t_expand, + **model_kwargs, + ) + + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + model_output = self.model(**model_inputs, first_step=(i == 0)) + pred = model_output["diffusion_prediction"] + pred = pred.to(dtype=torch.float32) + + # perform guidance + if self.do_classifier_free_guidance: + pred_cond, pred_uncond = pred.chunk(2) + pred = self.cfg_operator(pred_cond, pred_uncond, self.guidance_scale, step=i) + + if self.do_classifier_free_guidance and self.guidance_rescale > 0.0: + # Based on 3.4. 
in https://arxiv.org/pdf/2305.08891.pdf + pred = rescale_noise_cfg(pred, pred_cond, guidance_rescale=self.guidance_rescale) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(pred, t, latents, **_scheduler_step_extra_kwargs, return_dict=False)[0] + + if i != len(timesteps) - 1: + model_kwargs = self.model._update_model_kwargs_for_generation( # noqa + model_output, + model_kwargs, + ) + if input_ids.shape[1] != model_kwargs["position_ids"].shape[1]: + input_ids = torch.gather(input_ids, 1, index=model_kwargs["position_ids"]) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if hasattr(self.vae.config, 'scaling_factor') and self.vae.config.scaling_factor: + latents = latents / self.vae.config.scaling_factor + if hasattr(self.vae.config, 'shift_factor') and self.vae.config.shift_factor: + latents = latents + self.vae.config.shift_factor + + if hasattr(self.vae, "ffactor_temporal"): + latents = latents.unsqueeze(2) + + with torch.autocast(device_type="cuda", dtype=torch.float16, enabled=True): + image = self.vae.decode(latents, return_dict=False, generator=generator)[0] + + # b c t h w + if hasattr(self.vae, "ffactor_temporal"): + assert image.shape[2] == 1, "image should have shape [B, C, T, H, W] and T should be 1" + image = image.squeeze(2) + + do_denormalize = [True] * image.shape[0] + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + if not return_dict: + return (image,) + + return HunyuanImage3Text2ImagePipelineOutput(samples=image) diff --git a/image_processor.py b/image_processor.py new file mode 100644 index 0000000000000000000000000000000000000000..b099570c68bcfc09c3be3a6af9f09a3606c4f334 --- /dev/null +++ b/image_processor.py @@ -0,0 +1,125 @@ +# Licensed under the TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://github.com/Tencent-Hunyuan/HunyuanImage-3.0/blob/main/LICENSE +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +from typing import Tuple + +from PIL import Image +from torchvision import transforms +from transformers import Siglip2ImageProcessorFast + +from .tokenizer_wrapper import ImageInfo, JointImageInfo, ResolutionGroup + + +def resize_and_crop(image: Image.Image, target_size: Tuple[int, int]) -> Image.Image: + tw, th = target_size + w, h = image.size + + tr = th / tw + r = h / w + + # resize + if r < tr: + resize_height = th + resize_width = int(round(th / h * w)) + else: + resize_width = tw + resize_height = int(round(tw / w * h)) + + image = image.resize((resize_width, resize_height), resample=Image.Resampling.LANCZOS) + + # center crop + crop_top = int(round((resize_height - th) / 2.0)) + crop_left = int(round((resize_width - tw) / 2.0)) + + image = image.crop((crop_left, crop_top, crop_left + tw, crop_top + th)) + return image + + +class HunyuanImage3ImageProcessor(object): + def __init__(self, config): + self.config = config + + self.reso_group = ResolutionGroup(base_size=config.image_base_size) + self.vae_processor = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), # transform to [-1, 1] + ]) + self.vision_encoder_processor = Siglip2ImageProcessorFast.from_dict(config.vit_processor) + + def build_image_info(self, image_size): + # parse image size (HxW, H:W, or ) + if isinstance(image_size, str): + if image_size.startswith("")) + reso = self.reso_group[ratio_index] + image_size = reso.height, reso.width + elif 'x' in image_size: + image_size = [int(s) for s in image_size.split('x')] + elif ':' in image_size: + image_size = [int(s) for s in image_size.split(':')] + else: + raise ValueError( + f"`image_size` should be in the format of 'HxW', 'H:W' or , got {image_size}.") + assert len(image_size) == 2, f"`image_size` should be in the format of 'HxW', got {image_size}." + elif isinstance(image_size, (list, tuple)): + assert len(image_size) == 2 and all(isinstance(s, int) for s in image_size), \ + f"`image_size` should be a tuple of two integers or a string in the format of 'HxW', got {image_size}." 
+ else: + raise ValueError(f"`image_size` should be a tuple of two integers or a string in the format of 'WxH', " + f"got {image_size}.") + image_width, image_height = self.reso_group.get_target_size(image_size[1], image_size[0]) + token_height = image_height // (self.config.vae_downsample_factor[0] * self.config.patch_size) + token_width = image_width // (self.config.vae_downsample_factor[1] * self.config.patch_size) + base_size, ratio_idx = self.reso_group.get_base_size_and_ratio_index(image_size[1], image_size[0]) + image_info = ImageInfo( + image_type="gen_image", image_width=image_width, image_height=image_height, + token_width=token_width, token_height=token_height, base_size=base_size, ratio_index=ratio_idx, + ) + return image_info + + def preprocess(self, image: Image.Image): + # ==== VAE processor ==== + image_width, image_height = self.reso_group.get_target_size(image.width, image.height) + resized_image = resize_and_crop(image, (image_width, image_height)) + image_tensor = self.vae_processor(resized_image) + token_height = image_height // (self.config.vae_downsample_factor[0] * self.config.patch_size) + token_width = image_width // (self.config.vae_downsample_factor[1] * self.config.patch_size) + base_size, ratio_index = self.reso_group.get_base_size_and_ratio_index(width=image_width, height=image_height) + vae_image_info = ImageInfo( + image_type="vae", + image_tensor=image_tensor.unsqueeze(0), # include batch dim + image_width=image_width, image_height=image_height, + token_width=token_width, token_height=token_height, + base_size=base_size, ratio_index=ratio_index, + ) + + # ==== ViT processor ==== + inputs = self.vision_encoder_processor(image) + image = inputs["pixel_values"].squeeze(0) # seq_len x dim + pixel_attention_mask = inputs["pixel_attention_mask"].squeeze(0) # seq_len + spatial_shapes = inputs["spatial_shapes"].squeeze(0) # 2 (h, w) + vision_encoder_kwargs = dict( + pixel_attention_mask=pixel_attention_mask, + spatial_shapes=spatial_shapes, + ) + vision_image_info = ImageInfo( + image_type="vit", + image_tensor=image.unsqueeze(0), # 1 x seq_len x dim + image_width=spatial_shapes[1].item() * self.config.vit_processor["patch_size"], + image_height=spatial_shapes[0].item() * self.config.vit_processor["patch_size"], + token_width=spatial_shapes[1].item(), + token_height=spatial_shapes[0].item(), + image_token_length=self.config.vit_processor["max_num_patches"], + # may not equal to token_width * token_height + ) + return JointImageInfo(vae_image_info, vision_image_info, vision_encoder_kwargs) diff --git a/siglip2.py b/siglip2.py new file mode 100644 index 0000000000000000000000000000000000000000..d7a6d613e92c8217c04f8da8226249b7ade7b37a --- /dev/null +++ b/siglip2.py @@ -0,0 +1,564 @@ +# Licensed under the TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://github.com/Tencent-Hunyuan/HunyuanImage-3.0/blob/main/LICENSE +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# +# Copyright 2025 The HuggingFace Inc. team. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from typing import Optional, Tuple, Union +import warnings + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from transformers.activations import ACT2FN +from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling +from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask + + +class Config(object): + def __init__(self, config): + if config is not None: + for key, value in config.items(): + setattr(self, key, value) + + def __getitem__(self, key): + return getattr(self, key, None) + + def __setitem__(self, key, value): + return setattr(self, key, value) + + +class Siglip2VisionEmbeddings(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.patch_size = config.patch_size + + self.patch_embedding = nn.Linear( + in_features=config.num_channels * self.patch_size * self.patch_size, + out_features=self.embed_dim, + ) + + self.num_patches = config.num_patches + self.position_embedding_size = int(self.num_patches**0.5) + self.position_embedding = nn.Embedding(self.num_patches, self.embed_dim) + + @staticmethod + def resize_positional_embeddings( + positional_embeddings: torch.Tensor, + spatial_shapes: torch.LongTensor, + max_length: int, + ) -> torch.Tensor: + """ + Resize positional embeddings to image-specific size and pad to a fixed size. 
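+        For each sample, the square position-embedding grid is bilinearly interpolated to that image's
+        (height, width) patch grid, flattened to (height * width, embed_dim), and the remaining positions
+        up to `max_length` are filled so every sample shares the same padded length.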
+ + Args: + positional_embeddings (`torch.Tensor`): + Position embeddings of shape (height, width, embed_dim) + spatial_shapes (`torch.LongTensor`): + Spatial shapes of shape (batch_size, 2) to resize the positional embeddings to + max_length (`int`): + Maximum length of the positional embeddings to pad resized positional embeddings to + + Returns: + `torch.Tensor`: Embeddings of shape (batch_size, max_length, embed_dim) + """ + batch_size = spatial_shapes.shape[0] + embed_dim = positional_embeddings.shape[-1] + source_dtype = positional_embeddings.dtype + + resulted_positional_embeddings = torch.empty( + (batch_size, max_length, embed_dim), + device=positional_embeddings.device, + dtype=source_dtype, + ) + + # (height, width, embed_dim) -> (1, embed_dim, height, width) for interpolation + positional_embeddings = positional_embeddings.permute(2, 0, 1).unsqueeze(0) + + # Upcast to float32 on CPU because antialias is not supported for bfloat16/float16 on CPU + if positional_embeddings.device.type == "cpu": + positional_embeddings = positional_embeddings.to(torch.float32) + + for i in range(batch_size): + # (1, dim, height, width) -> (1, dim, target_height, target_width) + height, width = spatial_shapes[i] + resized_embeddings = F.interpolate( + positional_embeddings, + size=(height, width), + mode="bilinear", + align_corners=False, + antialias=True, + ) + + # (1, dim, target_height, target_width) -> (target_height * target_width, dim) + resized_embeddings = resized_embeddings.reshape(embed_dim, height * width).transpose(0, 1) + + # Cast to original dtype + resized_embeddings = resized_embeddings.to(source_dtype) + + resulted_positional_embeddings[i, : height * width] = resized_embeddings + resulted_positional_embeddings[i, height * width :] = resized_embeddings[0] + + return resulted_positional_embeddings + + def forward(self, pixel_values: torch.FloatTensor, spatial_shapes: torch.LongTensor) -> torch.Tensor: + """ + Args: + pixel_values (`torch.FloatTensor`): + Pixel values of shape (batch_size, max_num_patches, num_channels * patch_size * patch_size) + spatial_shapes (`List[Tuple[int, int]]`): + Spatial shapes of shape (batch_size, 2) to resize the positional embeddings to + """ + + # Apply patch embeddings to already patchified pixel values + target_dtype = self.patch_embedding.weight.dtype + patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) + + # Get positional resized and padded positional embeddings + positional_embeddings = self.position_embedding.weight.reshape( + self.position_embedding_size, self.position_embedding_size, -1 + ) + resized_positional_embeddings = self.resize_positional_embeddings( + positional_embeddings, spatial_shapes, max_length=pixel_values.shape[1] + ) + + # Add positional embeddings to patch embeddings + embeddings = patch_embeds + resized_positional_embeddings + return embeddings + + +class Siglip2Attention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" + f" {self.num_heads})." 
+ ) + self.scale = self.head_dim**-0.5 + self.dropout = config.attention_dropout + + self.k_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.v_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.q_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.out_proj = nn.Linear(self.embed_dim, self.embed_dim) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """Input shape: Batch x Time x Channel""" + + batch_size, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2) + + k_v_seq_len = key_states.shape[-2] + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scale + + if attn_weights.size() != (batch_size, self.num_heads, q_len, k_v_seq_len): + raise ValueError( + f"Attention weights should be of size {(batch_size, self.num_heads, q_len, k_v_seq_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (batch_size, 1, q_len, k_v_seq_len): + raise ValueError( + f"Attention mask should be of size {(batch_size, 1, q_len, k_v_seq_len)}, " + f"but is {attention_mask.size()}" + ) + attn_weights = attn_weights + attention_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (batch_size, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(batch_size, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim) + + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights + +class Siglip2SdpaAttention(Siglip2Attention): + """ + Siglip2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from + `Siglip2Attention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt + to SDPA API. + """ + + is_causal = False + + # Adapted from Siglip2Attention.forward and transformers.models.llama.modeling_llama.LlamaSdpaAttention.forward + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + if output_attentions: + # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` + # once this is implemented. + warnings.warn( + "Siglip2Model is using Siglip2SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` " + "does not support `output_attentions=True`. Falling back to the manual attention implementation, " + 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. 
' + 'This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + ) + + batch_size, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2) + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with + # custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. + if query_states.device.type == "cuda" and attention_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an + # inline conditional assignment in SDPA to support both torch.compile's dynamic shapes and full graph options. + # An inline conditional prevents dynamic shapes from compiling. + is_causal = True if self.is_causal and q_len > 1 else False + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=attention_mask, + dropout_p=self.dropout if self.training else 0.0, + is_causal=is_causal, + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(batch_size, q_len, self.embed_dim) + + attn_output = self.out_proj(attn_output) + + return attn_output, None + + +class Siglip2MLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.activation_fn = ACT2FN[config.hidden_act] + self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size) + self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states = self.fc2(hidden_states) + return hidden_states + + +class Siglip2EncoderLayer(nn.Module): + def __init__(self, config): + super().__init__() + self.embed_dim = config.hidden_size + self.self_attn = Siglip2Attention(config=config) + self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + self.mlp = Siglip2MLP(config) + self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + + # Ignore copy + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.FloatTensor]: + """ + Args: + hidden_states (`torch.FloatTensor`): + Input to the layer of shape `(batch, seq_len, embed_dim)`. + attention_mask (`torch.FloatTensor`): + Attention mask of shape `(batch, 1, q_len, k_v_seq_len)` where padding elements are indicated by very + large negative values. + output_attentions (`bool`, *optional*, defaults to `False`): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. 
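+
+        The layer uses a pre-norm residual layout: LayerNorm -> self-attention -> residual add, then
+        LayerNorm -> MLP -> residual add.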
+ """ + residual = hidden_states + + hidden_states = self.layer_norm1(hidden_states) + hidden_states, attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + ) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_weights,) + + return outputs + + +class Siglip2Encoder(nn.Module): + """ + Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a + [`Siglip2EncoderLayer`]. + + Args: + config: Siglip2Config + """ + + def __init__(self, config): + super().__init__() + self.config = config + self.layers = nn.ModuleList([Siglip2EncoderLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = True + + # Ignore copy + def forward( + self, + inputs_embeds, + attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutput]: + r""" + Args: + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
+ """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + hidden_states = inputs_embeds + for layer_index, encoder_layer in enumerate(self.layers): # len(self.layers): 27 + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + output_attentions=output_attentions, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions + ) + + +class Siglip2MultiheadAttentionPoolingHead(nn.Module): + """Multihead Attention Pooling.""" + + def __init__(self, config): + super().__init__() + + self.probe = nn.Parameter(torch.randn(1, 1, config.hidden_size)) + self.attention = torch.nn.MultiheadAttention(config.hidden_size, config.num_attention_heads, batch_first=True) + self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.mlp = Siglip2MLP(config) + self.num_heads = config.num_attention_heads + + def forward(self, hidden_state: torch.Tensor, attention_mask: Optional[torch.Tensor] = None): + batch_size = hidden_state.shape[0] + probe = self.probe.repeat(batch_size, 1, 1) + + if attention_mask is not None: + target_len, source_len = probe.shape[1], hidden_state.shape[1] + attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_state.dtype, target_len) + attention_mask = attention_mask.repeat(1, self.num_heads, target_len, 1) + attention_mask = attention_mask.reshape(-1, target_len, source_len) + + hidden_state = self.attention(probe, hidden_state, hidden_state, attn_mask=attention_mask)[0] + + residual = hidden_state + hidden_state = self.layernorm(hidden_state) + hidden_state = residual + self.mlp(hidden_state) + + return hidden_state[:, 0] + + +class Siglip2VisionTransformer(nn.Module): + def __init__(self, config): + super().__init__() + config = Config(config) + self.config = config + embed_dim = config.hidden_size + + self.embeddings = Siglip2VisionEmbeddings(config) + self.encoder = Siglip2Encoder(config) + self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) + self.use_head = True if not hasattr(config, "vision_use_head") else config.vision_use_head + if self.use_head: + self.head = Siglip2MultiheadAttentionPoolingHead(config) + self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + + def forward( + self, + pixel_values: torch.FloatTensor, + attention_mask: torch.Tensor, + spatial_shapes: torch.LongTensor, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPooling]: + r""" + Returns: + + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if 
output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + hidden_states = self.embeddings(pixel_values, spatial_shapes) + + if attention_mask is not None and not self._use_flash_attention_2: + # [batch_size, seq_len] -> [batch_size, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype) + else: + encoder_attention_mask = attention_mask + + encoder_outputs = self.encoder( + inputs_embeds=hidden_states, + attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + last_hidden_state = encoder_outputs[0] + last_hidden_state = self.post_layernorm(last_hidden_state) + + pooler_output = self.head(last_hidden_state, attention_mask) if self.use_head else None + if not return_dict: + return (last_hidden_state, pooler_output) + encoder_outputs[1:] + + return BaseModelOutputWithPooling( + last_hidden_state=last_hidden_state, + pooler_output=pooler_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +class LightProjector(nn.Module): + def __init__(self, config): + config = Config(config) + super().__init__() + + if config.projector_type == "linear": + modules = nn.Linear(config.input_dim, config.n_embed) + + elif config.projector_type == "mlp_gelu": + modules = [nn.Linear(config.input_dim, config.n_embed)] + for _ in range(1, config.depth): + modules.append(nn.GELU()) + modules.append(nn.Linear(config.n_embed, config.n_embed)) + modules = nn.Sequential(*modules) + + else: + raise ValueError(f"Unknown projector type: {config.projector_type}") + + self.layers = modules + + def forward(self, x): + return self.layers(x) diff --git a/system_prompt.py b/system_prompt.py new file mode 100644 index 0000000000000000000000000000000000000000..221a3b95f89ecce0c7e3bbfad13e4e61f26885e9 --- /dev/null +++ b/system_prompt.py @@ -0,0 +1,128 @@ +# Licensed under the TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://github.com/Tencent-Hunyuan/HunyuanImage-3.0/blob/main/LICENSE +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +t2i_system_prompt_en_vanilla = """ +You are an advanced AI text-to-image generation system. Given a detailed text prompt, your task is to create a high-quality, visually compelling image that accurately represents the described scene, characters, or objects. Pay careful attention to style, color, lighting, perspective, and any specific instructions provided. +""" + +# 775 +t2i_system_prompt_en_recaption = """ +You are a world-class image generation prompt expert. Your task is to rewrite a user's simple description into a **structured, objective, and detail-rich** professional-level prompt. + +The final output must be wrapped in `` tags. + +### **Universal Core Principles** + +When rewriting the prompt (inside the `` tags), you must adhere to the following principles: + +1. 
**Absolute Objectivity**: Describe only what is visually present. Avoid subjective words like "beautiful" or "sad". Convey aesthetic qualities through specific descriptions of color, light, shadow, and composition. +2. **Physical and Logical Consistency**: All scene elements (e.g., gravity, light, shadows, reflections, spatial relationships, object proportions) must strictly adhere to real-world physics and common sense. For example, tennis players must be on opposite sides of the net; objects cannot float without a cause. +3. **Structured Description**: Strictly follow a logical order: from general to specific, background to foreground, and primary to secondary elements. Use directional terms like "foreground," "mid-ground," "background," and "left side of the frame" to clearly define the spatial layout. +4. **Use Present Tense**: Describe the scene from an observer's perspective using the present tense, such as "A man stands..." or "Light shines on..." +5. **Use Rich and Specific Descriptive Language**: Use precise adjectives to describe the quantity, size, shape, color, and other attributes of objects, subjects, and text. Vague expressions are strictly prohibited. + +If the user specifies a style (e.g., oil painting, anime, UI design, text rendering), strictly adhere to that style. Otherwise, first infer a suitable style from the user's input. If there is no clear stylistic preference, default to an **ultra-realistic photographic style**. Then, generate the detailed rewritten prompt according to the **Style-Specific Creation Guide** below: + +### **Style-Specific Creation Guide** + +Based on the determined artistic style, apply the corresponding professional knowledge. + +**1. Photography and Realism Style** +* Utilize professional photography terms (e.g., lighting, lens, composition) and meticulously detail material textures, physical attributes of subjects, and environmental details. + +**2. Illustration and Painting Style** +* Clearly specify the artistic school (e.g., Japanese Cel Shading, Impasto Oil Painting) and focus on describing its unique medium characteristics, such as line quality, brushstroke texture, or paint properties. + +**3. Graphic/UI/APP Design Style** +* Objectively describe the final product, clearly defining the layout, elements, and color palette. All text on the interface must be enclosed in double quotes `""` to specify its exact content (e.g., "Login"). Vague descriptions are strictly forbidden. + +**4. Typographic Art** +* The text must be described as a complete physical object. The description must begin with the text itself. Use a straightforward front-on or top-down perspective to ensure the entire text is visible without cropping. + +### **Final Output Requirements** + +1. **Output the Final Prompt Only**: Do not show any thought process, Markdown formatting, or line breaks. +2. **Adhere to the Input**: You must retain the core concepts, attributes, and any specified text from the user's input. +3. **Style Reinforcement**: Mention the core style 3-5 times within the prompt and conclude with a style declaration sentence. +4. **Avoid Self-Reference**: Describe the image content directly. Remove redundant phrases like "This image shows..." or "The scene depicts..." +5. **The final output must be wrapped in `xxxx` tags.** + +The user will now provide an input prompt. You will provide the expanded prompt. +""" + +# 890 +t2i_system_prompt_en_think_recaption = """ +You will act as a top-tier Text-to-Image AI. 
Your core task is to deeply analyze the user's text input and transform it into a detailed, artistic, and fully user-intent-compliant image.
+
+Your workflow is divided into two phases:
+
+1. Thinking Phase (<think>): In the <think> tag, you need to conduct a structured thinking process, progressively breaking down and enriching the constituent elements of the image. This process must include, but is not limited to, the following dimensions:
+
+Subject: Clearly define the core character(s) or object(s) in the scene, including their appearance, posture, expression, and emotion.
+Composition: Set the camera angle and layout, such as close-up, long shot, bird's-eye view, golden ratio composition, etc.
+Environment/Background: Describe the scene where the subject is located, including the location, time of day, weather, and other elements in the background.
+Lighting: Define the type, direction, and quality of the light source, such as soft afternoon sunlight, cool tones of neon lights, dramatic Rembrandt lighting, etc., to create a specific atmosphere.
+Color Palette: Set the main color tone and color scheme of the image, such as vibrant and saturated, low-saturation Morandi colors, black and white, etc.
+Quality/Style: Determine the artistic style and technical details of the image. This includes user-specified styles (e.g., anime, oil painting) or the default realistic style, as well as camera parameters (e.g., focal length, aperture, depth of field).
+Details: Add minute elements that enhance the realism and narrative quality of the image, such as a character's accessories, the texture of a surface, dust particles in the air, etc.
+
+
+2. Recaption Phase (<recaption>): In the <recaption> tag, merge all the key details from the thinking process into a coherent, precise, and visually evocative final description. This description is the direct instruction for generating the image, so it must be clear, unambiguous, and organized in a way that is most suitable for an image generation engine to understand.
+
+Absolutely Objective: Describe only what is visually present. Avoid subjective words like "beautiful" or "sad." Convey aesthetic sense through concrete descriptions of colors, light, shadow, and composition.
+
+Physical and Logical Consistency: All scene elements (e.g., gravity, light and shadow, reflections, spatial relationships, object proportions) must strictly adhere to the physical laws of the real world and common sense. For example, in a tennis match, players must be on opposite sides of the net; objects cannot float without reason.
+
+Structured Description: Strictly follow a logical order: from whole to part, background to foreground, and primary to secondary. Use directional words like "foreground," "mid-ground," "background," "left side of the frame" to clearly define the spatial layout.
+
+Use Present Tense: Describe from an observer's perspective using the present tense, such as "a man stands," "light shines on..."
+Use Rich and Specific Descriptive Language: Use precise adjectives to describe the quantity, size, shape, color, and other attributes of objects/characters/text. Absolutely avoid any vague expressions.
+
+
+Output Format:
+<think>Thinking process</think><recaption>Refined image description</recaption>Generate Image
+
+
+You must strictly adhere to the following rules:
+
+1. Faithful to Intent, Reasonable Expansion: You can creatively add details to the user's description to enhance the image's realism and artistic quality.
However, all additions must be highly consistent with the user's core intent and never introduce irrelevant or conflicting elements. +2. Style Handling: When the user does not specify a style, you must default to an "Ultra-realistic, Photorealistic" style. If the user explicitly specifies a style (e.g., anime, watercolor, oil painting, cyberpunk, etc.), both your thinking process and final description must strictly follow and reflect that specified style. +3. Text Rendering: If specific text needs to appear in the image (such as words on a sign, a book title), you must enclose this text in English double quotes (""). Descriptive text must not use double quotes. +4. Design-related Images: You need to specify all text and graphical elements that appear in the image and clearly describe their design details, including font, color, size, position, arrangement, visual effects, etc. +""" + +t2i_system_prompts = { + "en_vanilla": [t2i_system_prompt_en_vanilla], + "en_recaption": [t2i_system_prompt_en_recaption], + "en_think_recaption": [t2i_system_prompt_en_think_recaption] +} + + +def get_system_prompt(sys_type, bot_task, system_prompt=None): + if sys_type == 'None': + return None + elif sys_type in ['en_vanilla', 'en_recaption', 'en_think_recaption']: + return t2i_system_prompts[sys_type][0] + elif sys_type == "dynamic": + if bot_task == "think": + return t2i_system_prompts["en_think_recaption"][0] + elif bot_task == "recaption": + return t2i_system_prompts["en_recaption"][0] + elif bot_task == "image": + return t2i_system_prompts["en_vanilla"][0].strip("\n") + else: + return system_prompt + elif sys_type == 'custom': + return system_prompt + else: + raise NotImplementedError(f"Unsupported system prompt type: {sys_type}") diff --git a/tokenizer_wrapper.py b/tokenizer_wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..4d4dbb6b68b6f6f598ba3c36635f44cfcc210fe7 --- /dev/null +++ b/tokenizer_wrapper.py @@ -0,0 +1,1425 @@ +# Licensed under the TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://github.com/Tencent-Hunyuan/HunyuanImage-3.0/blob/main/LICENSE +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +import warnings +import random +from typing import List, Optional, Union, Dict, Any +from collections import defaultdict +from copy import deepcopy + +import numpy as np +import torch +import torch.nn.functional as F +from transformers import AutoTokenizer +from diffusers.utils import BaseOutput + + +def default(value, default_value): + return value if value is not None else default_value + + +def ensure_list(value): + if value is None: + return [] + if isinstance(value, (list, tuple)): + return list(value) + return [value] + + +class Resolution(object): + def __init__(self, size, *args): + if isinstance(size, str): + if 'x' in size: + size = size.split('x') + size = (int(size[0]), int(size[1])) + else: + size = int(size) + if len(args) > 0: + size = (size, args[0]) + if isinstance(size, int): + size = (size, size) + + self.h = self.height = size[0] + self.w = self.width = size[1] + self.r = self.ratio = self.height / self.width + + def __getitem__(self, idx): + if idx == 0: + return self.h + elif idx == 1: + return self.w + else: + raise IndexError(f'Index {idx} out of range') + + def __str__(self): + return f'{self.h}x{self.w}' + + +class ResolutionGroup(object): + def __init__(self, base_size=None, step=None, align=1): + self.align = align + self.base_size = base_size + assert base_size % align == 0, f'base_size {base_size} is not divisible by align {align}' + if base_size is not None and not isinstance(base_size, int): + raise ValueError(f'base_size must be None or int, but got {type(base_size)}') + if step is None: + step = base_size // 16 + if step is not None and step > base_size // 2: + raise ValueError(f'step must be smaller than base_size // 2, but got {step} > {base_size // 2}') + + self.step = step + self.data = self._calc_by_step() + + self.ratio = np.array([x.ratio for x in self.data]) + self.attr = ['' for _ in range(len(self.data))] + self.prefix_space = 0 + + def __len__(self): + return len(self.data) + + def __getitem__(self, idx): + return self.data[idx] + + def __repr__(self): + prefix = self.prefix_space * ' ' + prefix_close = (self.prefix_space - 4) * ' ' + res_str = f'ResolutionGroup(base_size={self.base_size}, step={self.step}, data=' + attr_maxlen = max([len(x) for x in self.attr] + [5]) + res_str += \ + f'\n{prefix}ID: height width ratio {" " * max(0, attr_maxlen - 4)}count h/16 w/16 tokens\n{prefix}' + res_str += \ + ('\n' + prefix).join([f'{i:2d}: ({x.h:4d}, {x.w:4d}) {self.ratio[i]:.4f} {self.attr[i]:>{attr_maxlen}s} ' + f'({x.h // 16:3d}, {x.w // 16:3d}) {x.h // 16 * x.w // 16:6d}' + for i, x in enumerate(self.data)]) + res_str += f'\n{prefix_close})' + return res_str + + def _calc_by_step(self): + assert self.align <= self.step, f'align {self.align} must be smaller than step {self.step}' + + min_height = self.base_size // 2 + min_width = self.base_size // 2 + max_height = self.base_size * 2 + max_width = self.base_size * 2 + + resolutions = [Resolution(self.base_size, self.base_size)] + + cur_height, cur_width = self.base_size, self.base_size + while True: + if cur_height >= max_height and cur_width <= min_width: + break + + cur_height = min(cur_height + self.step, max_height) + cur_width = max(cur_width - self.step, min_width) + resolutions.append(Resolution(cur_height // self.align * self.align, cur_width // self.align * self.align)) + + cur_height, cur_width = self.base_size, self.base_size + while True: + if cur_height <= min_height and cur_width >= max_width: + break + + 
cur_height = max(cur_height - self.step, min_height) + cur_width = min(cur_width + self.step, max_width) + resolutions.append(Resolution(cur_height // self.align * self.align, cur_width // self.align * self.align)) + + resolutions = sorted(resolutions, key=lambda x: x.ratio) + + return resolutions + + def get_target_size(self, width, height): + ratio = height / width + idx = np.argmin(np.abs(self.ratio - ratio)) + reso = self.data[idx] + return reso.w, reso.h + + def get_base_size_and_ratio_index(self, width, height): + ratio = height / width + idx = np.argmin(np.abs(self.ratio - ratio)) + return self.base_size, idx + + +class ImageInfo: + """ Class to store image information for processing and generation. """ + + def __init__( + self, + image_type: str = None, + image_tensor: torch.Tensor = None, + image_width: int = None, + image_height: int = None, + token_width: int = None, + token_height: int = None, + image_token_length: int = None, + base_size: int = None, + ratio_index: int = None, + **kwargs, + ): + self.image_type = image_type + self.image_tensor = image_tensor + self.image_width = image_width + self.w = image_width + self.image_height = image_height + self.h = image_height + self.token_width = token_width + self.tk_w = token_width + self.token_height = token_height + self.tk_h = token_height + self.image_token_length = default( + image_token_length, + token_width * token_height if token_width is not None and token_height is not None else None + ) + self.base_size = base_size + self.ratio_index = ratio_index + + self.add_timestep_token = kwargs.get("add_timestep_token", True) + self.add_guidance_token = kwargs.get("add_guidance_token", False) + self.use_front_boi_token = kwargs.get("use_front_boi_token", True) + self.add_image_shape_token = kwargs.get("add_image_shape_token", True) + + def __getitem__(self, key: str) -> Any: + """Allow dictionary-like access to attributes.""" + if hasattr(self, key): + return getattr(self, key) + raise KeyError(f"Key '{key}' not found in ImageInfo") + + def __setitem__(self, key: str, value: Any) -> None: + """Allow dictionary-like assignment to attributes.""" + if hasattr(self, key): + setattr(self, key, value) + else: + raise KeyError(f"Key '{key}' not found in ImageInfo") + + def __contains__(self, key: str) -> bool: + """Check if the key exists in the ImageInfo object.""" + return hasattr(self, key) + + def __repr__(self): + return (f"ImageInfo(image_type={self.image_type}, image_tensor={self.image_tensor}, " + f"image_width={self.image_width}, image_height={self.image_height}, " + f"token_width={self.token_width}, token_height={self.token_height}, " + f"image_token_length={self.image_token_length}, " + f"base_size={self.base_size}, ratio_index={self.ratio_index}") + + @property + def meta_info(self): + # Used for image sections of tkwrapper.encode_general() + if self.image_type in ["vae", "gen_image"]: + return dict( + token_length=self.image_token_length, + add_timestep_token=self.add_timestep_token, + add_guidance_token=self.add_guidance_token, + use_front_boi_token=self.use_front_boi_token, + add_image_shape_token=self.add_image_shape_token, + base_size=self.base_size, + ratio_idx=self.ratio_index, + # for rope 2d + token_height=self.token_height, + token_width=self.token_width, + # for bc + image_height=self.image_height, + image_width=self.image_width, + ) + elif self.image_type in ["vit"]: + return dict( + token_length=self.image_token_length, + use_front_boi_token=self.use_front_boi_token, + 
add_image_shape_token=self.add_image_shape_token, + # for rope 2d + token_height=self.token_height, + token_width=self.token_width, + # for bc + image_height=self.image_height, + image_width=self.image_width, + ) + else: + raise ValueError(f"Unknown image type '{self.image_type}'") + + @property + def num_special_tokens(self): + if self.args is None: + raise ValueError("meta_info requires `args` attribute to be set.") + if self.image_type in ["vae", "src_image", "gen_image"]: + count = ( + 2 + # + or + + (1 if self.add_timestep_token else 0) + + (1 if self.add_guidance_token else 0) + + (2 if self.add_image_shape_token else 0) + ) + else: + raise ValueError(f"Unknown image_type: {self.image_type}") + return count + + def copy(self, copy_image_tensor=True): + if copy_image_tensor and self.image_tensor is None: + raise ValueError("image_tensor is None, cannot copy") + return ImageInfo( + image_type=self.image_type, + image_tensor=self.image_tensor.clone() if copy_image_tensor else None, + image_width=self.image_width, + image_height=self.image_height, + token_width=self.token_width, + token_height=self.token_height, + image_token_length=self.image_token_length, + base_size=self.base_size, + ratio_index=self.ratio_index, + ) + + def zeros_(self): + self.image_tensor = torch.zeros_like(self.image_tensor) + + +class ImageTensor(torch.Tensor): + # This class is just for type hinting purposes. Attribute `i` should be defined + # as an instance attribute of the torch.Tensor instance, like: tensor.i = ImageInfo(...) + i: ImageInfo + vision_encoder_kwargs: dict + + +class JointImageInfo(object): + def __init__(self, vae_image_info: ImageInfo, vision_image_info: ImageInfo, vision_encoder_kwargs: dict = None): + self.vae_image_info = vae_image_info + self.vision_image_info = vision_image_info + self.vision_encoder_kwargs = vision_encoder_kwargs + + # Define key attributes to align with ImageInfo for uniformity + self.image_type = "joint_image" + self.image_token_length = vae_image_info.image_token_length + vision_image_info.image_token_length + + self.add_timestep_token = vae_image_info.add_timestep_token + self.use_front_boi_token = vae_image_info.use_front_boi_token + self.add_image_shape_token = vae_image_info.add_image_shape_token + + def __repr__(self): + return f"JointImageInfo(vae_image={self.vae_image_info}, vision_image={self.vision_image_info})" + + @property + def meta_info(self): + # Used for image sections of tkwrapper.encode_general() + return dict( + token_length=[self.vae_image_info.image_token_length, self.vision_image_info.image_token_length], + add_timestep_token=self.add_timestep_token, + use_front_boi_token=self.use_front_boi_token, + add_image_shape_token=self.add_image_shape_token, + base_size=self.vae_image_info.base_size, + ratio_idx=self.vae_image_info.ratio_index, + # for rope 2d + token_height=[self.vae_image_info.token_height, self.vision_image_info.token_height], + token_width=[self.vae_image_info.token_width, self.vision_image_info.token_width], + # for bc + image_height=[self.vae_image_info.image_height, self.vision_image_info.image_height], + image_width=[self.vae_image_info.image_width, self.vision_image_info.image_width], + ) + + @property + def num_special_tokens(self): + return ( + 2 + # + + (1 if self.add_timestep_token else 0) + + (2 if self.add_image_shape_token else 0) + + 1 # + ) + + def copy(self, copy_image_tensor=True): + if copy_image_tensor and ( + self.vae_image_info.image_tensor is None or self.vision_image_info.image_tensor is None): + raise 
ValueError("image_tensor is None, cannot copy") + return JointImageInfo( + self.vae_image_info.copy(copy_image_tensor), + self.vision_image_info.copy(copy_image_tensor), + self.vision_encoder_kwargs, + ) + + def zeros_(self): + self.vae_image_info.zeros_() + self.vision_image_info.zeros_() + + +class JointImage(object): + def __init__(self, vae_image: ImageTensor, vision_image: ImageTensor): + self.vae_image = vae_image + self.vision_image = vision_image + self.i = JointImageInfo(vae_image.i, vision_image.i) + + +class TokenizerEncodeOutput(BaseOutput): + tokens: torch.Tensor = None + timestep_scatter_index: Optional[torch.Tensor] = None + guidance_scatter_index: Optional[torch.Tensor] = None + text_slices: Optional[List[slice]] = None + gen_image_slices: Optional[List[slice]] = None + joint_image_slices: Optional[List[slice]] = None + cond_vae_image_slices: Optional[List[slice]] = None + cond_vit_image_slices: Optional[List[slice]] = None + text_mask: Optional[torch.Tensor] = None + gen_image_mask: Optional[torch.Tensor] = None + cond_vae_image_mask: Optional[torch.Tensor] = None + cond_vit_image_mask: Optional[torch.Tensor] = None + real_pos: Optional[torch.Tensor] = None + all_image_slices: Optional[List[slice]] = None + cond_timestep_scatter_index: Optional[torch.Tensor] = None + gen_timestep_scatter_index: Optional[torch.Tensor] = None + + +class Conversation: + roles: List[str] = ["User", "Assistant"] + sep: str = "\n\n" + + +class TokenizerWrapper(object): + def __init__(self, tokenizer): + if isinstance(tokenizer, str): + self.tokenizer = AutoTokenizer.from_pretrained(tokenizer) + else: + self.tokenizer = tokenizer + + # Define short names + self.bos_token_id = self.tokenizer.bos_token_id + self.eos_token_id = self.tokenizer.eos_token_id + self.pad_token_id = self.tokenizer.pad_token_id + self.boi_token_id = self.tokenizer.convert_tokens_to_ids("") + self.eoi_token_id = self.tokenizer.convert_tokens_to_ids("") + self.img_token_id = self.tokenizer.convert_tokens_to_ids("") + self.cfg_token_id = self.tokenizer.convert_tokens_to_ids("") + self.end_answer_token_id = self.tokenizer.convert_tokens_to_ids("") + self.end_recaption_token_id = self.tokenizer.convert_tokens_to_ids("") + self.ratio_token_offset = self.tokenizer.convert_tokens_to_ids("") + self.special_token_map = self.tokenizer.added_tokens_encoder + + def pad(self, tensor_list, dim=0, pad_val=None): + if pad_val is None: + pad_val = self.pad_token_id + max_len = max([t.shape[dim] for t in tensor_list]) + padded_tensor_list = [] + for t in tensor_list: + if t.shape[dim] < max_len: + assert pad_val is not False, "Not allowed pad." + t = F.pad(t, (0, max_len - t.shape[dim]), value=pad_val) + padded_tensor_list.append(t) + return padded_tensor_list + + def encode(self, *args, **kwargs): + return self.tokenizer.encode(*args, **kwargs) + + def decode(self, *args, **kwargs): + return self.tokenizer.decode(*args, **kwargs) + + def encode_text( + self, + *texts, + uncond_enabled: Optional[Union[bool, List[bool]]] = None, + uncond_p: Optional[float] = None, + max_length: Optional[int] = None, + pad: Optional[str] = None, + return_lengths: bool = False, + ): + """ + Encode text and image for AR-like model training of the text-to-image/instruction tuning tasks. + Support encode multiple texts at once. Each text can be separately conditioned or unconditioned + based on the uncond_flags and a uniform uncond_p. 
+ ** token is always prepended to the text tokens.** + + Parameters + ---------- + texts: str or List[str] + List of texts to be encoded. + uncond_enabled: bool or List[bool] + List of flags to indicate whether the text should be unconditioned. + If False, the text will never be unconditioned. + If True, the text will be unconditioned with uncond_p. + uncond_p: float + Probability to the unconditional text. Only works when uncond_enabled is True. + max_length: int + Maximum length of the encoded text. + pad: Optional[str] + Padding method. Can be 'left' or 'right'. + return_lengths: bool + Whether to return the length of each encoded text. + """ + if pad is not None: + assert max_length is not None, "max_length should be provided when pad is not None." + + if uncond_enabled is None: + uncond_enabled = [True] * len(texts) + elif isinstance(uncond_enabled, bool): + uncond_enabled = [uncond_enabled] * len(texts) + if len(uncond_enabled) != len(texts): + print(uncond_enabled, texts) + assert len(uncond_enabled) == len(texts), ( + f"Length of uncond_flags should be equal to the number of texts, " + f"but got {len(uncond_enabled)} and {len(texts)}." + ) + + # Prepare text/uncond tokens + # TODO: If len(texts) > 1, such as instruction + prompt in inpainting, we need to determine how to do uncond. + # Now all texts will be cond or uncond at the same time. + do_uncond_drop = (uncond_p is not None) and (random.random() < uncond_p) + text_tokens, lengths = [], [] + cum_length = 0 + for text, uncond_flag in zip(texts, uncond_enabled): + # If reach the max_length and there still have unencoded texts, give a warning message and break the loop. + if max_length is not None and cum_length >= max_length: + warnings.warn( + f"Text length exceeds the max_length({max_length}). The remaining texts will be ignored: " + f"{text[:80]}..." + ) + break + # Set add_special_tokens=False to avoid adding token in some LLMs. + if isinstance(text, str): + text_token = self.tokenizer.encode(text, add_special_tokens=False) + else: + text_token = text + if uncond_flag and do_uncond_drop: + text_token = [self.cfg_token_id] * len(text_token) + # Cutoff the text by max_length if necessary + if max_length is not None and (cum_length + len(text_token)) > max_length: + text_token = text_token[:max_length - cum_length] + text_tokens.extend(text_token) + lengths.append(len(text_token)) + cum_length += len(text_token) + + # Prepend/Append tokens if applicable + if pad is not None and (pad_length := max_length - len(text_tokens)) > 0: + if pad == 'left': + text_tokens = [self.pad_token_id] * pad_length + text_tokens + elif pad == 'right': + text_tokens = text_tokens + [self.pad_token_id] * pad_length + else: + raise ValueError(f"Unsupported padding method: {pad}.") + + if return_lengths: + return text_tokens, lengths + return text_tokens + + @staticmethod + def _check_key_number_matched(keys, data): + # Assert keys and token_source are matched + assert set(keys) == set(data.keys()), ( + f"Keys in the template and token source should be matched, but got {set(keys)} and {list(data.keys())}." + ) + key_counts = {k: 0 for k in keys} + for key in keys: + key_counts[key] += 1 + for key, count in key_counts.items(): + assert len(data[key]) == count, ( + f"Number of `{key}` in the token source should be matched with the template, but got " + f"{data[key]}({len(data[key])}) and {count}." 
+ ) + + def _add_image_meta_info_token(self, token_seq, token_count, extra_token_pos, add_timestep_token=False, + add_image_shape_token=False, base_size=None, ratio_idx=None, image_type=None, + add_guidance_token=False): + if add_image_shape_token: + token_seq.extend([ + self.special_token_map[f""], + self.special_token_map[f""] + ]) + token_count += 2 + if add_timestep_token: + token_seq.extend([self.special_token_map[""]]) + extra_token_pos['timestep'].append(token_count) + if image_type is not None: + if image_type == "gen_image": + extra_token_pos['gen_timestep'].append(token_count) + elif image_type in ["joint_image"]: + extra_token_pos['cond_timestep'].append(token_count) + else: + raise ValueError(f"Unsupported image type: {image_type}.") + token_count += 1 + if add_guidance_token: + token_seq.extend([self.special_token_map[""]]) + extra_token_pos['guidance'].append(token_count) + token_count += 1 + return token_count + + @staticmethod + def _shorten_text(text): + import re + text = re.sub(r"()+", lambda m: f"[]{{{len(m.group(0)) // 5}}}", text) + text = re.sub(r"()+", lambda m: f"[]{{{len(m.group(0)) // 5}}}", text) + return text + + def encode_sequence( + self, + template: str, + token_source: Dict[str, List], + total_length=None, + add_timestep_token=False, + add_guidance_token=False, + last_key_only_prefix=False, + add_eos=True, + use_front_boi_token=True, + add_pad=True, + add_bos=True, + drop_last: Union[str, bool] = 'auto', + add_image_shape_token=False, + ): + """ + Encode a sequence based on the template (e.g., `text-image` for t2i, `text-image-image` for instruction tuning) + and token source. + + Parameters + ---------- + template: str + Template of the sequence. E.g., "text-gen_image" means the sequence is composed of text and an image. + "text-text-gen_image" means the sequence is composed of two sections of text and an image. + token_source: Dict[str, List] + Token source for each key in the template, in order. + - text: List[Dict]. + - gen_image: List[Dict]. + - joint_image: List[Dict]. + total_length: int + Total length of the encoded sequence, include padding tokens. + add_timestep_token: bool + Whether to add timestep token before the image tokens. + (Right after the tokens) + add_guidance_token: bool + Whether to add guidance token before the image tokens. + last_key_only_prefix: bool + Whether to only use the modal prefix in the last key. + add_eos: bool or 'auto' + Whether to add eos token at the end of the sequence. If True, always add eos token. If 'auto', + add eos token only when the total_length is not reached and the last token is not . + use_front_boi_token: bool: + Whether to put the token at the front of iw, ih and timestep tokens. + add_pad: bool or 'auto' + Whether to add padding tokens to the sequence. If True and total_length is not reached, add padding tokens. + add_bos: bool + Whether to add bos token at the beginning of the sequence. + drop_last: bool or 'auto' + - If auto, drop last tokens exceeding the total_length if the total_length is provided. If cut point is + in the middle of the image tokens, an error will raised. + - If True, drop last tokens exceeding the total_length. If cut point is in the middle of the image tokens, + all the successive image tokens will be dropped. + - If False, keep the last tokens exceeding the total_length, even if the total_length is reached. + add_image_shape_token: bool + Whether to add image shape token before the image tokens. 
(Right before the token) + + Returns + ------- + token_seq: list + Encoded token sequence. + extra_token_pos: dict + Positions of extra tokens. + """ + if last_key_only_prefix: + assert add_eos is not True, "add_eos should not be True when last_key_only_prefix is True." + if drop_last is True and total_length is None: + raise ValueError("total_length should be provided when drop_last is True.") + + keys = template.split('-') + modal_length = len(keys) + index_indicator = {k: 0 for k in token_source} + for k, v in token_source.items(): + assert isinstance(v, (list, tuple)), ( + f"Value of `{k}` in the token source should be a list or tuple, but got {type(v)}." + ) + self._check_key_number_matched(keys, token_source) + + token_seq = [] + token_count = 0 + extra_token_pos = defaultdict(list) + if add_bos: + token_seq.append(self.bos_token_id) + token_count += 1 + # If drop_last is True, we check the token_count on the fly and exit the loop if the total_length is reached. + # This check is only applied to the block tokens. Block tokens mean the tokens that are unsplittable, like + # image tokens. Text tokens are splittable, so we don't need to check the token_count for text. + # If the loop is broken by drop_last, we don't add the eos token at the end because the sequence is not + # complete. + drop_last_break = False + for i, key in enumerate(keys): + source = token_source[key][index_indicator[key]] + if key == "text": + token_seq.extend(source) # text token sequence + extra_token_pos["_start"].append(token_count) + token_count += len(source) + extra_token_pos["_end"].append(token_count - 1) + + elif key == "gen_image": + if isinstance(source, int): + source = {'length': source} + extra_count = 2 + ( + 1 if source.get('timestep', add_timestep_token) else 0) + ( + 1 if source.get('guidance', add_guidance_token) else 0) + ( + 2 if source.get('image_shape', add_image_shape_token) else 0 + ) + if drop_last is True and token_count + extra_count + source['length'] > total_length: + drop_last_break = True + break + if source.get('front_boi', use_front_boi_token): + token_seq.append(self.boi_token_id) + extra_token_pos["boi"].append(token_count) + token_count += 1 + token_count = self._add_image_meta_info_token( + token_seq=token_seq, + token_count=token_count, + extra_token_pos=extra_token_pos, + add_timestep_token=source.get('timestep', add_timestep_token), + add_guidance_token=source.get('guidance', add_guidance_token), + add_image_shape_token=source.get('image_shape', add_image_shape_token), + base_size=source.get('base_size'), + ratio_idx=source.get('ratio_idx'), + image_type=key, + ) + if not source.get('front_boi', use_front_boi_token): + token_seq.append(self.boi_token_id) + extra_token_pos["boi"].append(token_count) + token_count += 1 + if last_key_only_prefix and i == modal_length - 1: + pass # for AR inference + else: + token_seq.extend( + [self.img_token_id] * source['length'] + # token number + [self.eoi_token_id] + ) + extra_token_pos["_start"].append(token_count) + extra_token_pos["_start"].append(token_count) + token_count += source['length'] + extra_token_pos["_end"].append(token_count - 1) + extra_token_pos["_end"].append(token_count - 1) + extra_token_pos["eoi"].append(token_count) + token_count += 1 # + + elif key == "joint_image": + assert isinstance(source['length'], list) and len( + source['length']) == 2, "joint_image length should be a list of two integers" + extra_count = 2 + 1 + ( # boi, eoi, joint_img_sep + 1 if source.get('timestep', add_timestep_token) else 0) + ( + 2 if 
source.get('image_shape', add_image_shape_token) else 0 + ) + if drop_last is True and token_count + extra_count + sum(source['length']) > total_length: + drop_last_break = True + break + if source.get('front_boi', use_front_boi_token): + token_seq.append(self.boi_token_id) # Use patched boi for Janus, otherwise useing default + extra_token_pos["boi"].append(token_count) + token_count += 1 + token_count = self._add_image_meta_info_token( + token_seq=token_seq, + token_count=token_count, + extra_token_pos=extra_token_pos, + add_timestep_token=source.get('timestep', add_timestep_token), + add_image_shape_token=source.get('image_shape', add_image_shape_token), + base_size=source.get('base_size'), + ratio_idx=source.get('ratio_idx'), + image_type=key, + ) + if not source.get('front_boi', use_front_boi_token): + token_seq.append(self.boi_token_id) + extra_token_pos["boi"].append(token_count) + token_count += 1 + if last_key_only_prefix and i == modal_length - 1: + pass # for AR inference + else: + token_seq.extend( + [self.img_token_id] * source['length'][0] + ) + extra_token_pos["_start"].append(token_count) + extra_token_pos["_start"].append(token_count) + extra_token_pos["_start"].append(token_count) + token_count += source['length'][0] + extra_token_pos["_end"].append(token_count - 1) + extra_token_pos["_end"].append(token_count - 1) + + token_seq.extend( + [self.special_token_map[""]] + ) + extra_token_pos["joint_img_sep"].append(token_count) + token_count += 1 + + token_seq.extend( + [self.img_token_id] * source['length'][1] + ) + extra_token_pos["_start"].append(token_count) + extra_token_pos["_start"].append(token_count) + token_count += source['length'][1] + extra_token_pos["_end"].append(token_count - 1) + extra_token_pos["_end"].append(token_count - 1) + extra_token_pos["_end"].append(token_count - 1) + + token_seq.extend( + [self.eoi_token_id] + ) + extra_token_pos["eoi"].append(token_count) + token_count += 1 # + + else: + raise ValueError(f"Not supported key: {key}") + index_indicator[key] += 1 + + if add_eos is True and not drop_last_break: + # Typically used for t2i task. + token_seq.append(self.eos_token_id) + extra_token_pos["eos"].append(token_count) + token_count += 1 + elif add_eos == 'auto' and not drop_last_break: + # Typically used for lm and mmu task. 
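+            # Only append the eos token when the sequence does not already end with one and, if a
+            # total_length is given, there is still room left for it.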
+ if token_seq[-1] != self.eos_token_id and (total_length is None or token_count < total_length): + token_seq.append(self.eos_token_id) + extra_token_pos["eos"].append(token_count) + token_count += 1 + + if total_length: + # Check token count and clip sequence if necessary + if token_count > total_length and drop_last: + # Assert clip position is not in the middle of the block-wise tokens (gen_image, joint_image) + for start_key, end_key in [ + ("_start", "_end"), ("_start", "_end"), + ("_start", "_end"), ("_start", "_end"), + ]: + if start_key in extra_token_pos and end_key in extra_token_pos: + assert all( + (start > total_length or end + 1 < total_length) + for start, end in zip(extra_token_pos[start_key], extra_token_pos[end_key]) + ), ("Clip position should not be in the middle of the image tokens.\n" + f"Below is the text:\n{self._shorten_text(self.tokenizer.decode(token_seq))}") + token_seq = token_seq[:total_length] + + # Pad the sequence if necessary + pad_num = max(0, total_length - len(token_seq)) + if add_pad and pad_num: + token_seq.extend([self.pad_token_id] * pad_num) + extra_token_pos["first_pad"].append(token_count) + + return token_seq, extra_token_pos + + def batch_gen_infer( + self, + infer_fn, + prompt_list: list, + negative_prompt_list: list = None, + infer_fn_kwargs_list: List[Dict[str, int]] = None, + do_classifier_free_guidance=False, + condition_repeat_times: int = 1, + uncondition_repeat_times: int = 1, + ): + """ + Batch inference for the AR-like model training of the text-to-image/instruction tuning tasks. + + Parameters + ---------- + infer_fn: callable + Inference function to encode the prompt. + prompt_list: list + List of prompts. Each element can be a single prompt or a list of prompts passed to the infer_fn. + negative_prompt_list: list + List of negative prompts. Only used when do_classifier_free_guidance is True. If None, will use + token sequence as negative prompt. + infer_fn_kwargs_list: List[Dict[str, int]] + List of keyword arguments for the infer_fn. + do_classifier_free_guidance: bool + Whether to do classifier-free guidance. + condition_repeat_times: int + Support multi-condition. + uncondition_repeat_times: int + Support multi-uncondition. + """ + if infer_fn_kwargs_list is None: + infer_fn_kwargs_list = [{} for _ in prompt_list] + + # [n_output, bsz] + cond_results_list = None + uncond_results_list = None + output_type_list = [] + + for prompt_idx, (prompt, infer_fn_kwargs) in enumerate(zip(prompt_list, infer_fn_kwargs_list)): + if not isinstance(prompt, (list, tuple)): + prompt = [prompt] + cond_kwargs = {"uncond_p": 0.0} if do_classifier_free_guidance else {} + results = infer_fn( + *prompt, + **infer_fn_kwargs, + **cond_kwargs, + ) + output_type_list.append((type(results), len(results) if isinstance(results, (list, tuple)) else 1)) + if isinstance(results, dict): + raise ValueError("Make batch on dict is not supported. 
Please return list or tuple for infer_fn.") + if not isinstance(results, (list, tuple)): + results = (results,) + if cond_results_list is None: + cond_results_list = [[] for _ in results] + uncond_results_list = [[] for _ in results] + for i, result in enumerate(results): + cond_results_list[i].append(result) + + if do_classifier_free_guidance: + if negative_prompt_list is None: + uncond_kwargs = {"uncond_p": 1.0} + uncond_results = infer_fn( + *prompt, + **infer_fn_kwargs, + **uncond_kwargs, + ) + else: + negative_prompt = negative_prompt_list[prompt_idx] + if not isinstance(negative_prompt, (list, tuple)): + negative_prompt = [negative_prompt] + uncond_results = infer_fn( + *negative_prompt, + **infer_fn_kwargs, + ) + if isinstance(uncond_results, TokenizerEncodeOutput): + uncond_results_list.append(uncond_results) + else: + for i, result in enumerate(uncond_results): + uncond_results_list[i].append(result) + + assert all(output_type_list[0] == n for n in output_type_list), \ + f"Number of outputs should be equal for all samples, but got {output_type_list}." + output_type, output_num = output_type_list[0] + + def make_batch(batch_cond_item, batch_uncond_item): + # Process each output item to make batch + first = batch_cond_item[0] # The first element in the batch + if isinstance(first, torch.Tensor): + stacked_item = torch.stack(self.pad( + batch_cond_item * condition_repeat_times + batch_uncond_item * uncondition_repeat_times, + )) + + elif first is None: + assert all(item is None for item in batch_cond_item + batch_uncond_item), \ + (f"The first cond item is None, but some items are not None:\n\n" + f"condition: {batch_cond_item}\n\n" + f"uncondition: {batch_uncond_item}") + stacked_item = None + + elif isinstance(first, (list, tuple)): + # If the output item is a list or tuple, we treat it as a whole, and won't make nested batch any more. 
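# For example, with batch_cond_item = [a, b], batch_uncond_item = [c, d] and repeat times of
# 2 (condition) and 1 (uncondition), the result is the flat list [a, b, a, b, c, d]:
# list/tuple items are concatenated along the batch dimension rather than stacked.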
+ stacked_item = batch_cond_item * condition_repeat_times + batch_uncond_item * uncondition_repeat_times + + elif isinstance(first, TokenizerEncodeOutput): + stacked_item = {} + # Traverse not-None attributes + for key in list(first.keys()): + merged_list = [cond_item[key] for cond_item in batch_cond_item] * condition_repeat_times + \ + [uncond_item[key] for uncond_item in batch_uncond_item] * uncondition_repeat_times + if isinstance(first[key], torch.Tensor): + if 'mask' in key: + pad_val = 0.0 + elif key == 'tokens': + pad_val = self.special_token_map[""] + else: + pad_val = False # Should not pad for other tensors + stacked_item[key] = torch.stack(self.pad(merged_list, pad_val=pad_val), dim=0) + elif isinstance(first[key], list): + stacked_item[key] = merged_list + elif first[key] is None: + pass + else: + raise ValueError(f"Unsupported type of {key}: {type(first[key])}.") + stacked_item = TokenizerEncodeOutput(stacked_item) + + else: + raise TypeError(f"Making batch on type {type(first)} is not supported.") + + return stacked_item + + stacked_outputs = [] + for cond_results, uncond_results in zip(cond_results_list, uncond_results_list): + stacked_outputs.append(make_batch(cond_results, uncond_results)) + + if output_type == list: + return stacked_outputs + elif output_type == tuple: + return tuple(stacked_outputs) + elif output_num == 1: + return stacked_outputs[0] + else: + raise ValueError(f"Unsupported output type: {output_type}.") + + @staticmethod + def parse_extra_token_pos(extra_token_pos, prefix, tokens, rng=None): + if rng is None: + rng = slice(None) + image_slices = [ + slice(start, end + 1) + for start, end in zip(extra_token_pos[f'<{prefix}>_start'][rng], extra_token_pos[f'<{prefix}>_end'][rng]) + ] if f'<{prefix}>_start' in extra_token_pos and f'<{prefix}>_end' in extra_token_pos else [] + if image_slices: + image_mask = torch.zeros_like(tokens, dtype=torch.bool) + for image_slice in image_slices: + image_mask[image_slice] = True + else: + image_mask = None + return image_slices, image_mask + + def encode_general( + self, + sections: Optional[List[Dict[str, Any]]] = None, + max_token_length: Optional[int] = None, + add_eos='auto', + use_text_mask=True, + add_pad='auto', + add_bos=True, + drop_last='auto', + ): + """ + General encode function to encode a sequence with multiple sections of text and images. + Each section is a dict with a `type` key and other keys depending on the type. + Supported section types: + - text: dict with keys: + - text: str or List[int], text to be encoded. Either `text` or `tokens` should be provided. + - tokens: List[int], pre-encoded text tokens. Either `text` or `tokens` should be provided. + - uncond_enabled: bool, whether to enable uncondition for this text section. + - uncond_p: float, probability to drop the text section for uncondition. + - max_length: int, maximum length of the text section. + - ignore: bool, whether to ignore this text section in the text mask. + - start_offset: int, start offset of the text mask. + - end_offset: int, end offset of the text mask. + - gen_image: dict with keys: + - token_length: int, number of image tokens. + - add_timestep_token: bool, whether to add timestep token before the image tokens. + - add_guidance_token: bool, whether to add guidance token before the image tokens. + - use_front_boi_token: bool, whether to put the token at the front of size, ratio and timestep tokens. + - add_image_shape_token: bool, whether to add image shape token before the image tokens. 
+ - base_size: int, base size of the image. + - ratio_idx: int, ratio index of the image. + - joint_image: dict with keys: + - token_length: List[int], number of image tokens for the two images. + - add_timestep_token: bool, whether to add timestep token before the image tokens. + - use_front_boi_token: bool, whether to put the token at the front of size, ratio and timestep tokens. + - add_image_shape_token: bool, whether to add image shape token before the image tokens. + - base_size: int, base size of the image. + - ratio_idx: int, ratio index of the image. + + Parameters + ---------- + sections: List[Dict[str, Any]] + List of sections to be encoded. + max_token_length: int + Maximum length of the encoded token sequence. + add_eos: bool or 'auto' + Whether to add eos token at the end of the sequence. If True, always add eos + token. If 'auto', add eos token only when the total_length is not reached and the last token is not . + use_text_mask: bool + Whether to generate text mask. + add_pad: bool or 'auto' + Whether to add padding tokens to the sequence. If True and total_length is not reached, + add padding tokens. + add_bos: bool + Whether to add bos token at the beginning of the sequence. + drop_last: bool or 'auto' + - If auto, drop last tokens exceeding the total_length if the total_length is provided. + If cut point is in the middle of the image tokens, an error will raised. + - If True, drop last tokens exceeding the total_length. If cut point is in the + middle of the image tokens, all the successive image tokens will be dropped. + - If False, keep the last tokens exceeding the total_length, even if the total_length + is reached. + + Returns + ------- + TokenizerEncodeOutput + Encoded token sequence and extra information. + """ + if sections is None: + raise ValueError("sections must be provided.") + template = '-'.join([section['type'] for section in sections]) + + sections = deepcopy(sections) + token_source = defaultdict(list) + text_mask_specs = [] + for section in sections: + if section['type'] == 'text': + text = self.encode_text( + section['text'] if 'text' in section else section['tokens'], + uncond_enabled=section.get('uncond_enabled'), + uncond_p=section.get('uncond_p'), + max_length=section.get('max_length'), + ) + token_source['text'].append(text) + text_mask_specs.append(dict( + ignore=section.get('ignore', False), + start_offset=section.get('start_offset', 0), + end_offset=section.get('end_offset', 0), + )) + elif section['type'] == 'gen_image': + token_source['gen_image'].append(dict( + length=section['token_length'], + timestep=section.get('add_timestep_token', False), + guidance=section.get('add_guidance_token', False), + front_boi=section.get('use_front_boi_token', False), + image_shape=section.get('add_image_shape_token', False), + base_size=section.get('base_size'), + ratio_idx=section.get('ratio_idx'), + )) + elif section['type'] == 'joint_image': + token_source['joint_image'].append(dict( + length=section['token_length'], + timestep=section.get('add_timestep_token', False), + front_boi=section.get('use_front_boi_token', False), + image_shape=section.get('add_image_shape_token', False), + base_size=section.get('base_size'), + ratio_idx=section.get('ratio_idx'), + )) + else: + raise ValueError(f"Invalid section type: {section['type']}") + + # Combine text and image tokens + full_token_seq, extra_token_pos = self.encode_sequence( + template=template, + token_source=dict(token_source), + total_length=max_token_length, + add_eos=add_eos, + add_pad=add_pad, + 
add_bos=add_bos, + drop_last=drop_last, + ) + full_seq_token_tensor = torch.tensor(full_token_seq, dtype=torch.long) + + timestep_scatter_index = torch.tensor(extra_token_pos['timestep'], dtype=torch.long) \ + if 'timestep' in extra_token_pos else None + guidance_scatter_index = torch.tensor(extra_token_pos['guidance'], dtype=torch.long) \ + if 'guidance' in extra_token_pos else None + cond_timestep_scatter_index = torch.tensor(extra_token_pos['cond_timestep'], dtype=torch.long) \ + if 'cond_timestep' in extra_token_pos else None + gen_timestep_scatter_index = torch.tensor(extra_token_pos['gen_timestep'], dtype=torch.long) \ + if 'gen_timestep' in extra_token_pos else None + + # Gen image mask + gen_image_slices, gen_image_mask = self.parse_extra_token_pos(extra_token_pos, 'img', full_seq_token_tensor) + # Joint image + joint_image_slices, _ = self.parse_extra_token_pos(extra_token_pos, 'joint_img', full_seq_token_tensor) + # Conditional vae image + cond_vae_image_slices, cond_vae_image_mask = self.parse_extra_token_pos( + extra_token_pos, 'vae_img', full_seq_token_tensor) + # Conditional vit image + cond_vit_image_slices, cond_vit_image_mask = self.parse_extra_token_pos( + extra_token_pos, 'vit_img', full_seq_token_tensor) + # All image slices (gen_image, joint_image) + all_image_slices = [ + slice(start, end + 1) + for start, end in zip(extra_token_pos['_start'], extra_token_pos['_end']) + ] if '_start' in extra_token_pos and '_end' in extra_token_pos else [] + + # Text mask + text_slices = [ + slice(start, end + 1) + for start, end in zip(extra_token_pos['_start'], extra_token_pos['_end']) + ] if '_start' in extra_token_pos and '_end' in extra_token_pos else [] + assert len(text_slices) <= len(text_mask_specs), \ + (f"Number of text slices ({len(text_slices)}) should be less than or equal to " + f"number of text mask specs ({len(text_mask_specs)})") + if use_text_mask: + text_mask = torch.zeros_like(full_seq_token_tensor, dtype=torch.float32) + for text_slice, mask_spec in zip(text_slices, text_mask_specs): + if not mask_spec['ignore']: + real_slice = slice( + text_slice.start + mask_spec['start_offset'], + text_slice.stop + mask_spec['end_offset'] + ) + text_mask[real_slice] = 1.0 + else: + text_mask = None + + # real_pos is the first position of the <pad> token, i.e. the real (unpadded) sequence length + real_pos = torch.tensor(extra_token_pos.get('first_pad', [full_seq_token_tensor.shape[0]]), dtype=torch.long) + + return TokenizerEncodeOutput( + tokens=full_seq_token_tensor, + timestep_scatter_index=timestep_scatter_index, + guidance_scatter_index=guidance_scatter_index, + text_slices=text_slices, + gen_image_slices=gen_image_slices, + joint_image_slices=joint_image_slices, + cond_vae_image_slices=cond_vae_image_slices, + cond_vit_image_slices=cond_vit_image_slices, + text_mask=text_mask, + gen_image_mask=gen_image_mask, + cond_vae_image_mask=cond_vae_image_mask, + cond_vit_image_mask=cond_vit_image_mask, + real_pos=real_pos, + all_image_slices=all_image_slices, + cond_timestep_scatter_index=cond_timestep_scatter_index, + gen_timestep_scatter_index=gen_timestep_scatter_index, + ) + + def get_cot_sections(self, cot_text, uncond_kwargs, cot_max_length=None, drop_think=False): + if not cot_text: # None or empty + return [] + if '<think>' in cot_text and '</think>' in cot_text: + before_think_sec = cot_text.split('<think>')[0] + after_think_sec = cot_text.split('</think>')[1] + think_sec = cot_text.split('<think>')[1].split('</think>')[0] + return self.get_cot_sections(before_think_sec, uncond_kwargs, drop_think=drop_think) + \ + ([ + dict(type="text", text="<think>"), +
dict(type="text", text=think_sec, max_length=cot_max_length, **uncond_kwargs), + dict(type="text", text="</think>") + ] if not drop_think else []) + \ + self.get_cot_sections(after_think_sec, uncond_kwargs, drop_think=drop_think) + + if '<recaption>' in cot_text and '</recaption>' in cot_text: + before_recaption_sec = cot_text.split('<recaption>')[0] + after_recaption_sec = cot_text.split('</recaption>')[1] + recaption_sec = cot_text.split('<recaption>')[1].split('</recaption>')[0] + return self.get_cot_sections(before_recaption_sec, uncond_kwargs, drop_think=drop_think) + \ + [ + dict(type="text", text="<recaption>"), + dict(type="text", text=recaption_sec, max_length=cot_max_length, **uncond_kwargs), + dict(type="text", text="</recaption>") + ] + \ + self.get_cot_sections(after_recaption_sec, uncond_kwargs, drop_think=drop_think) + + return [ + dict(type="text", text=cot_text, **uncond_kwargs), + ] + + def apply_general_template( + self, + message_list, + max_length=None, + add_assistant_prefix=False, + answer="auto", + bot_task="auto", + sequence_template="instruct", + uncond_p=0.0, + cfg_factor=1, + batchify=False, + image_base_size=1024, + drop_think=False, + ): + # If cfg_factor > 1, we need to repeat the unconditioned part + if batchify: + assert isinstance(message_list[0], list), \ + f"When batchify is True, message_list should be a list of list, but got [{type(message_list[0])}, ...]." + return self.batch_gen_infer( + infer_fn=self.apply_general_template, + prompt_list=[[]], + infer_fn_kwargs_list=[dict( + message_list=message_list_i, + max_length=max_length, + add_assistant_prefix=add_assistant_prefix, + answer=answer, + bot_task=bot_task, + sequence_template=sequence_template, + image_base_size=image_base_size, + drop_think=drop_think, + ) for message_list_i in message_list], + do_classifier_free_guidance=cfg_factor > 1, + condition_repeat_times=1, + uncondition_repeat_times=cfg_factor - 1, + ) + + conv = Conversation() + uncond_kwargs = dict(uncond_enabled=uncond_p == 1.0, uncond_p=uncond_p) + + def process_successive_message(_message_list, _cur_message_idx, role, prefix, suffix, + answer_prefix="", answer_suffix=""): + _sub_sections = [] + while _cur_message_idx < len(message_list) and _message_list[_cur_message_idx]['role'] == role: + message = _message_list[_cur_message_idx] + if message['type'] == 'text': + text = message['content'] + if role == "system": + _sub_sections.append(dict(type="text", text=text)) + elif role == "assistant": + if ("<think>" in text and "</think>" in text) or ( + "<recaption>" in text and "</recaption>" in text): + _sub_sections.extend(self.get_cot_sections(text, uncond_kwargs, drop_think=drop_think)) + else: + _sub_sections.append(dict(type="text", text=text, **uncond_kwargs)) + else: + _sub_sections.append(dict( + type="text", text=f"{answer_prefix}{text}{answer_suffix}", **uncond_kwargs)) + elif message['type'] == 'gen_image': + info = message['content'] + assert isinstance(info, ImageInfo), f"Expected ImageInfo, but got {type(info)}" + if role == "assistant": + _sub_sections.append(dict(type="text", text=answer_prefix)) + _sub_sections.append(dict(type=message['type'], **info.meta_info)) + if role == "assistant": + _sub_sections.append(dict(type="text", text=answer_suffix)) + elif message['type'] == 'joint_image': + info = message['content'] + assert isinstance(info, JointImageInfo), f"Expected JointImageInfo, but got {type(info)}" + _sub_sections.append(dict(type=message['type'], **info.meta_info)) + else: + raise ValueError(f"Unknown message type: {message['type']}") + _cur_message_idx += 1 + if len(_sub_sections) > 0: + # Add role prefix and suffix + _sub_sections.insert(0,
dict(type='text', text=prefix)) + _sub_sections.append(dict(type='text', text=suffix)) + return _sub_sections, _cur_message_idx + + # Define assistant prefix and suffix + if (answer == "auto" and sequence_template == "instruct") or answer is True: + answer_prefix, answer_suffix = "", "" + else: + answer_prefix, answer_suffix = "", "" + if sequence_template == "pretrain": + system_suffix = "" + user_prefix = "" + user_suffix = "" + bot_prefix = "" + bot_suffix = "" + else: + system_suffix = f"{conv.sep}" + user_prefix = f"{conv.roles[0]}: " + user_suffix = f"{conv.sep}" + bot_prefix = f"{conv.roles[1]}: " + bot_suffix = f"{conv.sep}" + + # Process successive user and assistant messages + sections = [] + cur_message_idx = 0 + final_role = None + while cur_message_idx < len(message_list): + # Process successive system messages + sub_sections, cur_message_idx = process_successive_message( + message_list, cur_message_idx, role="system", prefix="", suffix=system_suffix) + # Add to the template and sections + sections.extend(sub_sections) + if len(sub_sections) > 0: + final_role = "system" + + # Process successive user messages + sub_sections, cur_message_idx = process_successive_message( + message_list, cur_message_idx, role="user", prefix=user_prefix, suffix=user_suffix) + # Add to the template and sections + sections.extend(sub_sections) + if len(sub_sections) > 0: + final_role = "user" + + # Process successive assistant messages + sub_sections, cur_message_idx = process_successive_message( + message_list, cur_message_idx, role="assistant", prefix=bot_prefix, suffix=bot_suffix, + answer_prefix=answer_prefix, answer_suffix=answer_suffix, + ) + # Add to the template and sections + sections.extend(sub_sections) + if len(sub_sections) > 0: + final_role = "assistant" + + if add_assistant_prefix: + if final_role == "assistant": + # Avoid adding prefix twice + _bot_prefix = "" + # Remove the final bot_suffix + if len(sections) > 0 and sections[-1]['type'] == 'text' and sections[-1]['text'] == bot_suffix: + sections = sections[:-1] + else: + _bot_prefix = bot_prefix + # We can add special tokens for the bot lastest message according to different tasks + bot_response_prefix = dict( + auto=_bot_prefix, + think=f"{_bot_prefix}", + recaption=f"{_bot_prefix}", + img_ratio=f"{_bot_prefix}{answer_prefix}", + )[bot_task] + sections.append(dict(type='text', text=bot_response_prefix)) + + output = self.encode_general( + sections=sections, + use_text_mask=False, + add_eos=False, + add_pad=False, + ) + + if max_length is not None: + if output.tokens.shape[-1] > max_length: + raise ValueError( + f"Encoded token length {output.tokens.shape[-1]} exceeds max_length {max_length}.\n" + f"Please set a larger max_length or check the input messages:\n{message_list}" + ) + + return output, sections + + def apply_chat_template( + self, + batch_prompt: Optional[List[str]] = None, + batch_message_list: Optional[List[List[Dict[str, Any]]]] = None, + mode: str = "gen_text", + batch_gen_image_info: Optional[List[ImageInfo]] = None, + batch_cond_image_info: Optional[Union[List[JointImageInfo], List[List[JointImageInfo]]]] = None, + batch_system_prompt: Optional[List[str]] = None, + batch_cot_text: Optional[List[str]] = None, + max_length: Optional[int] = None, + bot_task: str = "auto", # auto/think/recaption/img_ratio + image_base_size: int = 1024, + sequence_template: str = "pretrain", + cfg_factor: int = 1, + add_assistant_prefix: Optional[bool] = None, + drop_think: bool = False, + ) -> Dict[str, Any]: + assert bot_task in 
["auto", "think", "recaption", "img_ratio"], \ + f"bot_task should be one of ['auto', 'think', 'recaption', 'img_ratio'], but got {bot_task}." + + if batch_message_list is None: + # Simple text-to-image or text-cot-to-image task + batch_size = len(batch_prompt) + + # Batchify inputs + if not isinstance(batch_system_prompt, list): + batch_system_prompt = [batch_system_prompt] * batch_size + if not isinstance(batch_gen_image_info, list): + batch_gen_image_info = [batch_gen_image_info] * batch_size + if batch_cot_text is not None: + assert len(batch_cot_text) == batch_size, \ + (f"batch_cot_text should have the same length as batch_size ({batch_size}), " + f"but got {len(batch_cot_text)}.") + else: + batch_cot_text = [None] * batch_size + if batch_cond_image_info is not None: + assert len(batch_cond_image_info) == batch_size, \ + (f"batch_cond_image_info should have the same length as batch_size ({batch_size}), " + f"but got {len(batch_cond_image_info)}.") + batch_cond_image_info = [ + cond_image_info if isinstance(cond_image_info, list) else [cond_image_info] + for cond_image_info in batch_cond_image_info + ] + else: + batch_cond_image_info = [[] for _ in range(batch_size)] + + # Convert single round materials into standard message list + batch_message_list = [] + for prompt, system_prompt, cot_text, gen_image_info, cond_image_info_list in zip( + batch_prompt, batch_system_prompt, batch_cot_text, batch_gen_image_info, + batch_cond_image_info, + ): + message_list = [] + # 1. system prompt section + if system_prompt: + message_list.append(dict( + role="system", type="text", content=system_prompt, context_type="str")) + # 2. user inputs sections + # 2.1 image inputs + if len(cond_image_info_list) > 0: + message_list.extend([ + dict(role="user", type="joint_image", content=cond_image_info, context_type="image_info") + for cond_image_info in cond_image_info_list + ]) + # 2.2 text inputs + message_list.append(dict( + role="user", type="text", content=prompt, context_type="str")) + # 3. assistant answer sections + if cot_text is not None: + message_list.append(dict(role="assistant", type="text", content=cot_text, context_type="str")) + if mode == "gen_image": + message_list.append(dict( + role="assistant", type="gen_image", content=gen_image_info, context_type="image_info")) + # --- + batch_message_list.append(message_list) + + output, sections = self.apply_general_template( + message_list=batch_message_list, + max_length=max_length, + add_assistant_prefix=default(add_assistant_prefix, mode != "gen_image"), + bot_task=bot_task, + sequence_template=sequence_template, + cfg_factor=cfg_factor, + batchify=True, + image_base_size=image_base_size, + drop_think=drop_think, + ) + return dict(output=output, sections=sections)