from typing import Any, Dict, Optional, Tuple, List

import torch
import torch.nn as nn
import einops
from einops import repeat

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin
from diffusers.models.modeling_utils import ModelMixin
from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
from diffusers.utils.torch_utils import maybe_allow_in_graph
from diffusers.models.modeling_outputs import Transformer2DModelOutput

from ..embeddings import PatchEmbed, PooledEmbed, TimestepEmbed, EmbedND, OutEmbed
from ..attention import HiDreamAttention, FeedForwardSwiGLU
from ..attention_processor import HiDreamAttnProcessor_flashattn
from ..moe import MOEFeedForwardSwiGLU

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
class TextProjection(nn.Module):
    def __init__(self, in_features, hidden_size):
        super().__init__()
        self.linear = nn.Linear(in_features=in_features, out_features=hidden_size, bias=False)

    def forward(self, caption):
        hidden_states = self.linear(caption)
        return hidden_states
class BlockType:
    TransformerBlock = 1
    SingleTransformerBlock = 2
@maybe_allow_in_graph
class HiDreamImageSingleTransformerBlock(nn.Module):
    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        num_routed_experts: int = 4,
        num_activated_experts: int = 2,
    ):
        super().__init__()
        self.num_attention_heads = num_attention_heads
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(dim, 6 * dim, bias=True)
        )
        nn.init.zeros_(self.adaLN_modulation[1].weight)
        nn.init.zeros_(self.adaLN_modulation[1].bias)

        # 1. Attention
        self.norm1_i = nn.LayerNorm(dim, eps=1e-06, elementwise_affine=False)
        self.attn1 = HiDreamAttention(
            query_dim=dim,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            processor=HiDreamAttnProcessor_flashattn(),
            single=True,
        )

        # 3. Feed-forward
        self.norm3_i = nn.LayerNorm(dim, eps=1e-06, elementwise_affine=False)
        if num_routed_experts > 0:
            self.ff_i = MOEFeedForwardSwiGLU(
                dim=dim,
                hidden_dim=4 * dim,
                num_routed_experts=num_routed_experts,
                num_activated_experts=num_activated_experts,
            )
        else:
            self.ff_i = FeedForwardSwiGLU(dim=dim, hidden_dim=4 * dim)
    def forward(
        self,
        image_tokens: torch.FloatTensor,
        image_tokens_masks: Optional[torch.FloatTensor] = None,
        text_tokens: Optional[torch.FloatTensor] = None,
        adaln_input: Optional[torch.FloatTensor] = None,
        rope: torch.FloatTensor = None,
    ) -> torch.FloatTensor:
        wtype = image_tokens.dtype
        shift_msa_i, scale_msa_i, gate_msa_i, shift_mlp_i, scale_mlp_i, gate_mlp_i = \
            self.adaLN_modulation(adaln_input)[:, None].chunk(6, dim=-1)

        # 1. MM-Attention
        norm_image_tokens = self.norm1_i(image_tokens).to(dtype=wtype)
        norm_image_tokens = norm_image_tokens * (1 + scale_msa_i) + shift_msa_i
        attn_output_i = self.attn1(
            norm_image_tokens,
            image_tokens_masks,
            rope=rope,
        )
        image_tokens = gate_msa_i * attn_output_i + image_tokens

        # 2. Feed-forward
        norm_image_tokens = self.norm3_i(image_tokens).to(dtype=wtype)
        norm_image_tokens = norm_image_tokens * (1 + scale_mlp_i) + shift_mlp_i
        ff_output_i = gate_mlp_i * self.ff_i(norm_image_tokens.to(dtype=wtype))
        image_tokens = ff_output_i + image_tokens
        return image_tokens
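# Note on the adaLN modulation used in the block above (a shape sketch with
# illustrative numbers, assuming dim=2560 and batch size 2; these values are not
# read from any config): `adaln_input` has shape (B, dim), the zero-initialised
# Linear maps it to (B, 6 * dim), the `[:, None]` adds a sequence axis, and
# `.chunk(6, dim=-1)` yields six (B, 1, dim) tensors -- shift/scale/gate for the
# attention branch, then shift/scale/gate for the feed-forward branch:
#
#   adaln_input:                                  (2, 2560)
#   self.adaLN_modulation(adaln_input)[:, None]:  (2, 1, 15360)
#   each chunk:                                   (2, 1, 2560),
#   broadcast over the token axis of the (2, S, 2560) image tokens.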
@maybe_allow_in_graph
class HiDreamImageTransformerBlock(nn.Module):
    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        num_routed_experts: int = 4,
        num_activated_experts: int = 2,
    ):
        super().__init__()
        self.num_attention_heads = num_attention_heads
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(dim, 12 * dim, bias=True)
        )
        nn.init.zeros_(self.adaLN_modulation[1].weight)
        nn.init.zeros_(self.adaLN_modulation[1].bias)

        # 1. Attention
        self.norm1_i = nn.LayerNorm(dim, eps=1e-06, elementwise_affine=False)
        self.norm1_t = nn.LayerNorm(dim, eps=1e-06, elementwise_affine=False)
        self.attn1 = HiDreamAttention(
            query_dim=dim,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            processor=HiDreamAttnProcessor_flashattn(),
            single=False,
        )

        # 3. Feed-forward
        self.norm3_i = nn.LayerNorm(dim, eps=1e-06, elementwise_affine=False)
        if num_routed_experts > 0:
            self.ff_i = MOEFeedForwardSwiGLU(
                dim=dim,
                hidden_dim=4 * dim,
                num_routed_experts=num_routed_experts,
                num_activated_experts=num_activated_experts,
            )
        else:
            self.ff_i = FeedForwardSwiGLU(dim=dim, hidden_dim=4 * dim)
        self.norm3_t = nn.LayerNorm(dim, eps=1e-06, elementwise_affine=False)
        self.ff_t = FeedForwardSwiGLU(dim=dim, hidden_dim=4 * dim)
    def forward(
        self,
        image_tokens: torch.FloatTensor,
        image_tokens_masks: Optional[torch.FloatTensor] = None,
        text_tokens: Optional[torch.FloatTensor] = None,
        adaln_input: Optional[torch.FloatTensor] = None,
        rope: torch.FloatTensor = None,
    ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
        wtype = image_tokens.dtype
        shift_msa_i, scale_msa_i, gate_msa_i, shift_mlp_i, scale_mlp_i, gate_mlp_i, \
            shift_msa_t, scale_msa_t, gate_msa_t, shift_mlp_t, scale_mlp_t, gate_mlp_t = \
            self.adaLN_modulation(adaln_input)[:, None].chunk(12, dim=-1)

        # 1. MM-Attention
        norm_image_tokens = self.norm1_i(image_tokens).to(dtype=wtype)
        norm_image_tokens = norm_image_tokens * (1 + scale_msa_i) + shift_msa_i
        norm_text_tokens = self.norm1_t(text_tokens).to(dtype=wtype)
        norm_text_tokens = norm_text_tokens * (1 + scale_msa_t) + shift_msa_t
        attn_output_i, attn_output_t = self.attn1(
            norm_image_tokens,
            image_tokens_masks,
            norm_text_tokens,
            rope=rope,
        )
        image_tokens = gate_msa_i * attn_output_i + image_tokens
        text_tokens = gate_msa_t * attn_output_t + text_tokens

        # 2. Feed-forward
        norm_image_tokens = self.norm3_i(image_tokens).to(dtype=wtype)
        norm_image_tokens = norm_image_tokens * (1 + scale_mlp_i) + shift_mlp_i
        norm_text_tokens = self.norm3_t(text_tokens).to(dtype=wtype)
        norm_text_tokens = norm_text_tokens * (1 + scale_mlp_t) + shift_mlp_t
        ff_output_i = gate_mlp_i * self.ff_i(norm_image_tokens)
        ff_output_t = gate_mlp_t * self.ff_t(norm_text_tokens)
        image_tokens = ff_output_i + image_tokens
        text_tokens = ff_output_t + text_tokens
        return image_tokens, text_tokens
class HiDreamImageBlock(nn.Module):
    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        num_routed_experts: int = 4,
        num_activated_experts: int = 2,
        block_type: BlockType = BlockType.TransformerBlock,
    ):
        super().__init__()
        block_classes = {
            BlockType.TransformerBlock: HiDreamImageTransformerBlock,
            BlockType.SingleTransformerBlock: HiDreamImageSingleTransformerBlock,
        }
        self.block = block_classes[block_type](
            dim,
            num_attention_heads,
            attention_head_dim,
            num_routed_experts,
            num_activated_experts,
        )

    def forward(
        self,
        image_tokens: torch.FloatTensor,
        image_tokens_masks: Optional[torch.FloatTensor] = None,
        text_tokens: Optional[torch.FloatTensor] = None,
        adaln_input: torch.FloatTensor = None,
        rope: torch.FloatTensor = None,
    ) -> torch.FloatTensor:
        return self.block(
            image_tokens,
            image_tokens_masks,
            text_tokens,
            adaln_input,
            rope,
        )
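# Illustrative construction of the wrapper block above (the dimensions are only
# an example derived from the module defaults, not values read from a released
# checkpoint; `image_tokens`, `adaln_input` and `rope` are assumed to come from
# the surrounding model):
#
#   block = HiDreamImageBlock(
#       dim=2560,                       # 20 heads * 128 head dim
#       num_attention_heads=20,
#       attention_head_dim=128,
#       num_routed_experts=4,
#       num_activated_experts=2,
#       block_type=BlockType.SingleTransformerBlock,
#   )
#   image_tokens = block(image_tokens, image_tokens_masks, None, adaln_input, rope)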
class HiDreamImageTransformer2DModel(
    ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin
):
    _supports_gradient_checkpointing = True
    _no_split_modules = ["HiDreamImageBlock"]
    @register_to_config
    def __init__(
        self,
        patch_size: Optional[int] = None,
        in_channels: int = 64,
        out_channels: Optional[int] = None,
        num_layers: int = 16,
        num_single_layers: int = 32,
        attention_head_dim: int = 128,
        num_attention_heads: int = 20,
        caption_channels: List[int] = None,
        text_emb_dim: int = 2048,
        num_routed_experts: int = 4,
        num_activated_experts: int = 2,
        axes_dims_rope: Tuple[int, int] = (32, 32),
        max_resolution: Tuple[int, int] = (128, 128),
        llama_layers: List[int] = None,
    ):
        super().__init__()
        self.out_channels = out_channels or in_channels
        self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
        self.llama_layers = llama_layers

        self.t_embedder = TimestepEmbed(self.inner_dim)
        self.p_embedder = PooledEmbed(text_emb_dim, self.inner_dim)
        self.x_embedder = PatchEmbed(
            patch_size=patch_size,
            in_channels=in_channels,
            out_channels=self.inner_dim,
        )
        self.pe_embedder = EmbedND(theta=10000, axes_dim=axes_dims_rope)

        self.double_stream_blocks = nn.ModuleList(
            [
                HiDreamImageBlock(
                    dim=self.inner_dim,
                    num_attention_heads=self.config.num_attention_heads,
                    attention_head_dim=self.config.attention_head_dim,
                    num_routed_experts=num_routed_experts,
                    num_activated_experts=num_activated_experts,
                    block_type=BlockType.TransformerBlock,
                )
                for i in range(self.config.num_layers)
            ]
        )
        self.single_stream_blocks = nn.ModuleList(
            [
                HiDreamImageBlock(
                    dim=self.inner_dim,
                    num_attention_heads=self.config.num_attention_heads,
                    attention_head_dim=self.config.attention_head_dim,
                    num_routed_experts=num_routed_experts,
                    num_activated_experts=num_activated_experts,
                    block_type=BlockType.SingleTransformerBlock,
                )
                for i in range(self.config.num_single_layers)
            ]
        )

        self.final_layer = OutEmbed(self.inner_dim, patch_size, self.out_channels)

        caption_channels = [caption_channels[1], ] * (num_layers + num_single_layers) + [caption_channels[0], ]
        caption_projection = []
        for caption_channel in caption_channels:
            caption_projection.append(TextProjection(in_features=caption_channel, hidden_size=self.inner_dim))
        self.caption_projection = nn.ModuleList(caption_projection)
        self.max_seq = max_resolution[0] * max_resolution[1] // (patch_size * patch_size)

        self.gradient_checkpointing = False
    def _set_gradient_checkpointing(self, module, value=False):
        if hasattr(module, "gradient_checkpointing"):
            module.gradient_checkpointing = value

    def expand_timesteps(self, timesteps, batch_size, device):
        if not torch.is_tensor(timesteps):
            is_mps = device.type == "mps"
            if isinstance(timesteps, float):
                dtype = torch.float32 if is_mps else torch.float64
            else:
                dtype = torch.int32 if is_mps else torch.int64
            timesteps = torch.tensor([timesteps], dtype=dtype, device=device)
        elif len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(device)
        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        timesteps = timesteps.expand(batch_size)
        return timesteps
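    # Example of the helper above (a sketch; float64 follows the non-MPS branch):
    #
    #   expand_timesteps(0.5, batch_size=4, device=torch.device("cpu"))
    #   -> tensor([0.5000, 0.5000, 0.5000, 0.5000], dtype=torch.float64)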
    def unpatchify(self, x: torch.Tensor, img_sizes: List[Tuple[int, int]], is_training: bool) -> List[torch.Tensor]:
        if is_training:
            x = einops.rearrange(x, 'B S (p1 p2 C) -> B C S (p1 p2)', p1=self.config.patch_size, p2=self.config.patch_size)
        else:
            x_arr = []
            for i, img_size in enumerate(img_sizes):
                pH, pW = img_size
                x_arr.append(
                    einops.rearrange(
                        x[i, :pH * pW].reshape(1, pH, pW, -1),
                        'B H W (p1 p2 C) -> B C (H p1) (W p2)',
                        p1=self.config.patch_size, p2=self.config.patch_size,
                    )
                )
            x = torch.cat(x_arr, dim=0)
        return x
    def patchify(self, x, max_seq, img_sizes=None):
        pz2 = self.config.patch_size * self.config.patch_size
        if isinstance(x, torch.Tensor):
            B, C = x.shape[0], x.shape[1]
            device = x.device
            dtype = x.dtype
        else:
            B, C = len(x), x[0].shape[0]
            device = x[0].device
            dtype = x[0].dtype
        x_masks = torch.zeros((B, max_seq), dtype=dtype, device=device)

        if img_sizes is not None:
            for i, img_size in enumerate(img_sizes):
                x_masks[i, 0:img_size[0] * img_size[1]] = 1
            x = einops.rearrange(x, 'B C S p -> B S (p C)', p=pz2)
        elif isinstance(x, torch.Tensor):
            pH, pW = x.shape[-2] // self.config.patch_size, x.shape[-1] // self.config.patch_size
            x = einops.rearrange(x, 'B C (H p1) (W p2) -> B (H W) (p1 p2 C)', p1=self.config.patch_size, p2=self.config.patch_size)
            img_sizes = [[pH, pW]] * B
            x_masks = None
        else:
            raise NotImplementedError
        return x, x_masks, img_sizes
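    # Shape sketch for the two helpers above (illustrative numbers only, assuming
    # patch_size=2, C=16 latent channels and a 64x64 latent; not read from a config):
    #
    #   patchify:   (B, 16, 64, 64) -> tokens (B, 1024, 64), img_sizes=[[32, 32]],
    #               x_masks=None (padding masks are only built when img_sizes is passed in)
    #   unpatchify: tokens (B, 1024, 64) -> (B, 16, 64, 64) in the inference branch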
    def forward(
        self,
        hidden_states: torch.Tensor,
        timesteps: torch.LongTensor = None,
        encoder_hidden_states: torch.Tensor = None,
        pooled_embeds: torch.Tensor = None,
        img_sizes: Optional[List[Tuple[int, int]]] = None,
        img_ids: Optional[torch.Tensor] = None,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
        return_dict: bool = True,
    ):
        if joint_attention_kwargs is not None:
            joint_attention_kwargs = joint_attention_kwargs.copy()
            lora_scale = joint_attention_kwargs.pop("scale", 1.0)
        else:
            lora_scale = 1.0

        if USE_PEFT_BACKEND:
            # weight the lora layers by setting `lora_scale` for each PEFT layer
            scale_lora_layers(self, lora_scale)
        else:
            if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
                logger.warning(
                    "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
                )
        # spatial forward
        batch_size = hidden_states.shape[0]
        hidden_states_type = hidden_states.dtype

        # 0. time
        timesteps = self.expand_timesteps(timesteps, batch_size, hidden_states.device)
        timesteps = self.t_embedder(timesteps, hidden_states_type)
        p_embedder = self.p_embedder(pooled_embeds)
        adaln_input = timesteps + p_embedder

        hidden_states, image_tokens_masks, img_sizes = self.patchify(hidden_states, self.max_seq, img_sizes)
        if image_tokens_masks is None:
            pH, pW = img_sizes[0]
            img_ids = torch.zeros(pH, pW, 3, device=hidden_states.device)
            img_ids[..., 1] = img_ids[..., 1] + torch.arange(pH, device=hidden_states.device)[:, None]
            img_ids[..., 2] = img_ids[..., 2] + torch.arange(pW, device=hidden_states.device)[None, :]
            img_ids = repeat(img_ids, "h w c -> b (h w) c", b=batch_size)
        hidden_states = self.x_embedder(hidden_states)

        T5_encoder_hidden_states = encoder_hidden_states[0]
        encoder_hidden_states = encoder_hidden_states[-1]
        encoder_hidden_states = [encoder_hidden_states[k] for k in self.llama_layers]

        if self.caption_projection is not None:
            new_encoder_hidden_states = []
            for i, enc_hidden_state in enumerate(encoder_hidden_states):
                enc_hidden_state = self.caption_projection[i](enc_hidden_state)
                enc_hidden_state = enc_hidden_state.view(batch_size, -1, hidden_states.shape[-1])
                new_encoder_hidden_states.append(enc_hidden_state)
            encoder_hidden_states = new_encoder_hidden_states
            T5_encoder_hidden_states = self.caption_projection[-1](T5_encoder_hidden_states)
            T5_encoder_hidden_states = T5_encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
            encoder_hidden_states.append(T5_encoder_hidden_states)

        txt_ids = torch.zeros(
            batch_size,
            encoder_hidden_states[-1].shape[1] + encoder_hidden_states[-2].shape[1] + encoder_hidden_states[0].shape[1],
            3,
            device=img_ids.device, dtype=img_ids.dtype,
        )
        ids = torch.cat((img_ids, txt_ids), dim=1)
        rope = self.pe_embedder(ids)
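        # A note on the positional ids above (an informal reading of the code, not
        # extra behaviour): image tokens carry (0, row, col) coordinates while the
        # text tokens all get zero ids, so the RoPE rotation computed from `ids`
        # acts as an identity on the text portion of the joint sequence.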
        # 2. Blocks
        block_id = 0
        initial_encoder_hidden_states = torch.cat([encoder_hidden_states[-1], encoder_hidden_states[-2]], dim=1)
        initial_encoder_hidden_states_seq_len = initial_encoder_hidden_states.shape[1]
        for bid, block in enumerate(self.double_stream_blocks):
            cur_llama31_encoder_hidden_states = encoder_hidden_states[block_id]
            cur_encoder_hidden_states = torch.cat([initial_encoder_hidden_states, cur_llama31_encoder_hidden_states], dim=1)
            if self.training and self.gradient_checkpointing:
                def create_custom_forward(module, return_dict=None):
                    def custom_forward(*inputs):
                        if return_dict is not None:
                            return module(*inputs, return_dict=return_dict)
                        else:
                            return module(*inputs)
                    return custom_forward

                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                hidden_states, initial_encoder_hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    hidden_states,
                    image_tokens_masks,
                    cur_encoder_hidden_states,
                    adaln_input,
                    rope,
                    **ckpt_kwargs,
                )
            else:
                hidden_states, initial_encoder_hidden_states = block(
                    image_tokens=hidden_states,
                    image_tokens_masks=image_tokens_masks,
                    text_tokens=cur_encoder_hidden_states,
                    adaln_input=adaln_input,
                    rope=rope,
                )
            initial_encoder_hidden_states = initial_encoder_hidden_states[:, :initial_encoder_hidden_states_seq_len]
            block_id += 1

        image_tokens_seq_len = hidden_states.shape[1]
        hidden_states = torch.cat([hidden_states, initial_encoder_hidden_states], dim=1)
        hidden_states_seq_len = hidden_states.shape[1]
        if image_tokens_masks is not None:
            encoder_attention_mask_ones = torch.ones(
                (batch_size, initial_encoder_hidden_states.shape[1] + cur_llama31_encoder_hidden_states.shape[1]),
                device=image_tokens_masks.device, dtype=image_tokens_masks.dtype,
            )
            image_tokens_masks = torch.cat([image_tokens_masks, encoder_attention_mask_ones], dim=1)
        for bid, block in enumerate(self.single_stream_blocks):
            cur_llama31_encoder_hidden_states = encoder_hidden_states[block_id]
            hidden_states = torch.cat([hidden_states, cur_llama31_encoder_hidden_states], dim=1)
            if self.training and self.gradient_checkpointing:
                def create_custom_forward(module, return_dict=None):
                    def custom_forward(*inputs):
                        if return_dict is not None:
                            return module(*inputs, return_dict=return_dict)
                        else:
                            return module(*inputs)
                    return custom_forward

                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    hidden_states,
                    image_tokens_masks,
                    None,
                    adaln_input,
                    rope,
                    **ckpt_kwargs,
                )
            else:
                hidden_states = block(
                    image_tokens=hidden_states,
                    image_tokens_masks=image_tokens_masks,
                    text_tokens=None,
                    adaln_input=adaln_input,
                    rope=rope,
                )
            hidden_states = hidden_states[:, :hidden_states_seq_len]
            block_id += 1

        hidden_states = hidden_states[:, :image_tokens_seq_len, ...]
        output = self.final_layer(hidden_states, adaln_input)
        output = self.unpatchify(output, img_sizes, self.training)
        if image_tokens_masks is not None:
            image_tokens_masks = image_tokens_masks[:, :image_tokens_seq_len]

        if USE_PEFT_BACKEND:
            # remove `lora_scale` from each PEFT layer
            unscale_lora_layers(self, lora_scale)

        if not return_dict:
            return (output, image_tokens_masks)
        return Transformer2DModelOutput(sample=output, mask=image_tokens_masks)
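# Minimal usage sketch (the config values and input names below are assumptions
# for illustration; in practice they come from the pretrained config and the
# pipeline's text encoders):
#
#   model = HiDreamImageTransformer2DModel(
#       patch_size=2, in_channels=16, num_layers=16, num_single_layers=32,
#       attention_head_dim=128, num_attention_heads=20,
#       caption_channels=[4096, 4096], text_emb_dim=2048,
#       llama_layers=list(range(48)),   # one Llama layer index per block
#   )
#   # encoder_hidden_states is consumed as [T5_embeds, llama_hidden_states], where
#   # llama_hidden_states is indexable by the layer indices in `llama_layers`.
#   output, masks = model(
#       latents, timesteps=t,
#       encoder_hidden_states=[t5_embeds, llama_hidden_states],
#       pooled_embeds=pooled_clip_embeds, return_dict=False,
#   )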