Upload 19 files
update paint-turbo weights

hunyuan3d-paint-v2-0-turbo/unet/diffusion_pytorch_model.bin  (CHANGED)
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:24e7f1aea8a7c94cee627eb06f5265f19eeff4e19568636c5eaef050cc19ba3d
+size 7325432923
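The new Git LFS pointer identifies the updated UNet weights only by their SHA-256 and byte size. A minimal sketch for checking a downloaded copy against that pointer (the local path is an assumption; only the oid and size values are taken from the diff above):

    import hashlib, os

    path = "hunyuan3d-paint-v2-0-turbo/unet/diffusion_pytorch_model.bin"  # assumed local download location
    expected_oid = "24e7f1aea8a7c94cee627eb06f5265f19eeff4e19568636c5eaef050cc19ba3d"
    expected_size = 7325432923

    h = hashlib.sha256()
    with open(path, "rb") as f:
        # hash the file in 1 MiB chunks to keep memory flat
        for chunk in iter(lambda: f.read(1024 * 1024), b""):
            h.update(chunk)

    assert os.path.getsize(path) == expected_size, "size mismatch"
    assert h.hexdigest() == expected_oid, "sha256 mismatch"
    print("local file matches the LFS pointer")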
    	
hunyuan3d-paint-v2-0-turbo/unet/modules.py  (CHANGED)
@@ -22,7 +22,6 @@
 # fine-tuning enabling code and other elements of the foregoing made publicly available
 # by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
 
-
 import copy
 import json
 import os
@@ -41,7 +40,9 @@ def _chunked_feed_forward(ff: nn.Module, hidden_states: torch.Tensor, chunk_dim:
     # "feed_forward_chunk_size" can be used to save memory
     if hidden_states.shape[chunk_dim] % chunk_size != 0:
         raise ValueError(
-            f"`hidden_states` dimension to be chunked: {hidden_states.shape[chunk_dim]}
         )
 
     num_chunks = hidden_states.shape[chunk_dim] // chunk_size
@@ -51,329 +52,16 @@ def _chunked_feed_forward(ff: nn.Module, hidden_states: torch.Tensor, chunk_dim:
     )
     return ff_output
 
-class PoseRoPEAttnProcessor2_0:
-    r"""
-    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
-    """
-
-    def __init__(self):
-        if not hasattr(F, "scaled_dot_product_attention"):
-            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
-
-    def get_1d_rotary_pos_embed(
-            self,
-            dim: int,
-            pos: torch.Tensor,
-            theta: float = 10000.0,
-            linear_factor=1.0,
-            ntk_factor=1.0,
-    ):
-        assert dim % 2 == 0
-
-        theta = theta * ntk_factor
-        freqs = (
-            1.0
-            / (theta ** (torch.arange(0, dim, 2, dtype=pos.dtype, device=pos.device)[: (dim // 2)] / dim))
-            / linear_factor
-        )  # [D/2]
-        freqs = torch.outer(pos, freqs)  # type: ignore   # [S, D/2]
-        # flux, hunyuan-dit, cogvideox
-        freqs_cos = freqs.cos().repeat_interleave(2, dim=1).float()  # [S, D]
-        freqs_sin = freqs.sin().repeat_interleave(2, dim=1).float()  # [S, D]
-        return freqs_cos, freqs_sin
-
-
-    def get_3d_rotary_pos_embed(
-            self,
-            position,
-            embed_dim,
-            voxel_resolution,
-            theta: int = 10000,
-    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
-        """
-        RoPE for video tokens with 3D structure.
-
-        Args:
-        voxel_resolution (`int`):
-            The grid size of the spatial positional embedding (height, width).
-        theta (`float`):
-            Scaling factor for frequency computation.
-
-        Returns:
-            `torch.Tensor`: positional embedding with shape `(temporal_size * grid_size[0] * grid_size[1], embed_dim/2)`.
-        """
-        assert position.shape[-1]==3
-
-        # Compute dimensions for each axis
-        dim_xy = embed_dim // 8 * 3
-        dim_z = embed_dim // 8 * 2
-
-        # Temporal frequencies
-        grid = torch.arange(voxel_resolution, dtype=torch.float32, device=position.device)
-        freqs_xy = self.get_1d_rotary_pos_embed(dim_xy, grid, theta=theta)
-        freqs_z = self.get_1d_rotary_pos_embed(dim_z, grid, theta=theta)
-
-        xy_cos, xy_sin = freqs_xy  # both t_cos and t_sin has shape: voxel_resolution, dim_xy
-        z_cos, z_sin = freqs_z  # both w_cos and w_sin has shape: voxel_resolution, dim_z
-
-        embed_flattn = position.view(-1, position.shape[-1])
-        x_cos = xy_cos[embed_flattn[:,0], :]
-        x_sin = xy_sin[embed_flattn[:,0], :]
-        y_cos = xy_cos[embed_flattn[:,1], :]
-        y_sin = xy_sin[embed_flattn[:,1], :]
-        z_cos = z_cos[embed_flattn[:,2], :]
-        z_sin = z_sin[embed_flattn[:,2], :]
-
-        cos = torch.cat((x_cos, y_cos, z_cos), dim=-1)
-        sin = torch.cat((x_sin, y_sin, z_sin), dim=-1)
-
-        cos = cos.view(*position.shape[:-1], embed_dim)
-        sin = sin.view(*position.shape[:-1], embed_dim)
-        return cos, sin
-
-    def apply_rotary_emb(
-            self,
-            x: torch.Tensor,
-            freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]]
-        ):
-        cos, sin = freqs_cis  # [S, D]
-        cos, sin = cos.to(x.device), sin.to(x.device)
-        cos = cos.unsqueeze(1)
-        sin = sin.unsqueeze(1)
-
-        x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)  # [B, S, H, D//2]
-        x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
-
-        out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
-
-        return out
-
-    def __call__(
-        self,
-        attn: Attention,
-        hidden_states: torch.Tensor,
-        encoder_hidden_states: Optional[torch.Tensor] = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_indices: Dict = None,
-        temb: Optional[torch.Tensor] = None,
-        *args,
-        **kwargs,
-    ) -> torch.Tensor:
-        if len(args) > 0 or kwargs.get("scale", None) is not None:
-            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
-            deprecate("scale", "1.0.0", deprecation_message)
-
-        residual = hidden_states
-        if attn.spatial_norm is not None:
-            hidden_states = attn.spatial_norm(hidden_states, temb)
-
-        input_ndim = hidden_states.ndim
-
-        if input_ndim == 4:
-            batch_size, channel, height, width = hidden_states.shape
-            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
-
-        batch_size, sequence_length, _ = (
-            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
-        )
-
-        if attention_mask is not None:
-            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
-            # scaled_dot_product_attention expects attention_mask shape to be
-            # (batch, heads, source_length, target_length)
-            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
-
-        if attn.group_norm is not None:
-            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
-
-        query = attn.to_q(hidden_states)
-
-        if encoder_hidden_states is None:
-            encoder_hidden_states = hidden_states
-        elif attn.norm_cross:
-            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
-
-        key = attn.to_k(encoder_hidden_states)
-        value = attn.to_v(encoder_hidden_states)
-
-        inner_dim = key.shape[-1]
-        head_dim = inner_dim // attn.heads
-
-        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
-        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
-        if attn.norm_q is not None:
-            query = attn.norm_q(query)
-        if attn.norm_k is not None:
-            key = attn.norm_k(key)
-
-        if position_indices is not None:
-            if head_dim in position_indices:
-                image_rotary_emb = position_indices[head_dim]
-            else:
-                image_rotary_emb = self.get_3d_rotary_pos_embed(position_indices['voxel_indices'], head_dim, voxel_resolution=position_indices['voxel_resolution'])
-                position_indices[head_dim] = image_rotary_emb
-            query = self.apply_rotary_emb(query, image_rotary_emb)
-            key = self.apply_rotary_emb(key, image_rotary_emb)
-
-        # the output of sdp = (batch, num_heads, seq_len, head_dim)
-        # TODO: add support for attn.scale when we move to Torch 2.1
-        hidden_states = F.scaled_dot_product_attention(
-            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
-        )
-
-        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
-        hidden_states = hidden_states.to(query.dtype)
-
-        # linear proj
-        hidden_states = attn.to_out[0](hidden_states)
-        # dropout
-        hidden_states = attn.to_out[1](hidden_states)
-
-        if input_ndim == 4:
-            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
-
-        if attn.residual_connection:
-            hidden_states = hidden_states + residual
-
-        hidden_states = hidden_states / attn.rescale_output_factor
-
-        return hidden_states
-
-class IPAttnProcessor2_0:
-    r"""
-    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
-    """
-
-    def __init__(self, scale=0.0):
-        if not hasattr(F, "scaled_dot_product_attention"):
-            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
-
-        self.scale = scale
-
-    def __call__(
-        self,
-        attn: Attention,
-        hidden_states: torch.Tensor,
-        encoder_hidden_states: Optional[torch.Tensor] = None,
-        ip_hidden_states: Optional[torch.Tensor] = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        temb: Optional[torch.Tensor] = None,
-        *args,
-        **kwargs,
-    ) -> torch.Tensor:
-        if len(args) > 0 or kwargs.get("scale", None) is not None:
-            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
-            deprecate("scale", "1.0.0", deprecation_message)
-
-        residual = hidden_states
-        if attn.spatial_norm is not None:
-            hidden_states = attn.spatial_norm(hidden_states, temb)
-
-        input_ndim = hidden_states.ndim
-
-        if input_ndim == 4:
-            batch_size, channel, height, width = hidden_states.shape
-            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
-
-        batch_size, sequence_length, _ = (
-            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
-        )
-
-        if attention_mask is not None:
-            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
-            # scaled_dot_product_attention expects attention_mask shape to be
-            # (batch, heads, source_length, target_length)
-            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
-
-        if attn.group_norm is not None:
-            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
-
-        query = attn.to_q(hidden_states)
-
-        if encoder_hidden_states is None:
-            encoder_hidden_states = hidden_states
-        elif attn.norm_cross:
-            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
-
-        key = attn.to_k(encoder_hidden_states)
-        value = attn.to_v(encoder_hidden_states)
-
-        inner_dim = key.shape[-1]
-        head_dim = inner_dim // attn.heads
-
-        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
-        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
-        if attn.norm_q is not None:
-            query = attn.norm_q(query)
-        if attn.norm_k is not None:
-            key = attn.norm_k(key)
-
-
-        # the output of sdp = (batch, num_heads, seq_len, head_dim)
-        # TODO: add support for attn.scale when we move to Torch 2.1
-        hidden_states = F.scaled_dot_product_attention(
-            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
-        )
-
-        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
-        hidden_states = hidden_states.to(query.dtype)
-
-        # for ip adapter
-        if ip_hidden_states is not None:
-
-            ip_key = attn.to_k_ip(ip_hidden_states)
-            ip_value = attn.to_v_ip(ip_hidden_states)
-
-            ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-            ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
-            # the output of sdp = (batch, num_heads, seq_len, head_dim)
-            ip_hidden_states = F.scaled_dot_product_attention(
-                query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
-            )
-
-            ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
-            ip_hidden_states = ip_hidden_states.to(query.dtype)
-
-            hidden_states = hidden_states + self.scale * ip_hidden_states
-
-        # linear proj
-        hidden_states = attn.to_out[0](hidden_states)
-        # dropout
-        hidden_states = attn.to_out[1](hidden_states)
-
-        if input_ndim == 4:
-            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
-
-        if attn.residual_connection:
-            hidden_states = hidden_states + residual
-
-        hidden_states = hidden_states / attn.rescale_output_factor
-
-        return hidden_states
-
 
 class Basic2p5DTransformerBlock(torch.nn.Module):
-    def __init__(self, transformer: BasicTransformerBlock, layer_name,
         super().__init__()
         self.transformer = transformer
         self.layer_name = layer_name
-        self.use_ipa = use_ipa
         self.use_ma = use_ma
         self.use_ra = use_ra
 
-        if use_ipa:
-            self.attn2.set_processor(IPAttnProcessor2_0())
-            cross_attention_dim = 1024
-            self.attn2.to_k_ip = nn.Linear(cross_attention_dim, self.dim, bias=False)
-            self.attn2.to_v_ip = nn.Linear(cross_attention_dim, self.dim, bias=False)
-
         # multiview attn
         if self.use_ma:
             self.attn_multiview = Attention(
@@ -385,7 +73,6 @@ class Basic2p5DTransformerBlock(torch.nn.Module):
                 cross_attention_dim=None,
                 upcast_attention=self.attn1.upcast_attention,
                 out_bias=True,
-                processor=PoseRoPEAttnProcessor2_0(),
             )
 
         # ref attn
@@ -400,8 +87,8 @@ class Basic2p5DTransformerBlock(torch.nn.Module):
                 upcast_attention=self.attn1.upcast_attention,
                 out_bias=True,
             )
-
-
 
     def _initialize_attn_weights(self):
 
@@ -418,10 +105,6 @@ class Basic2p5DTransformerBlock(torch.nn.Module):
                     for param in layer.parameters():
                         param.zero_()
 
-        if self.use_ipa:
-            self.attn2.to_k_ip.load_state_dict(self.attn2.to_k.state_dict())
-            self.attn2.to_v_ip.load_state_dict(self.attn2.to_v.state_dict())
-
     def __getattr__(self, name: str):
         try:
             return super().__getattr__(name)
@@ -447,10 +130,16 @@ class Basic2p5DTransformerBlock(torch.nn.Module):
         cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
         num_in_batch = cross_attention_kwargs.pop('num_in_batch', 1)
         mode = cross_attention_kwargs.pop('mode', None)
         condition_embed_dict = cross_attention_kwargs.pop("condition_embed_dict", None)
-        ip_hidden_states = cross_attention_kwargs.pop("ip_hidden_states", None)
-        position_attn_mask = cross_attention_kwargs.pop("position_attn_mask", None)
-        position_voxel_indices = cross_attention_kwargs.pop("position_voxel_indices", None)
 
         if self.norm_type == "ada_norm":
             norm_hidden_states = self.norm1(hidden_states, timestep)
@@ -470,10 +159,10 @@ class Basic2p5DTransformerBlock(torch.nn.Module):
             norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
         else:
             raise ValueError("Incorrect norm used")
-
         if self.pos_embed is not None:
             norm_hidden_states = self.pos_embed(norm_hidden_states)
-
         # 1. Prepare GLIGEN inputs
         cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
         gligen_kwargs = cross_attention_kwargs.pop("gligen", None)
@@ -484,6 +173,7 @@ class Basic2p5DTransformerBlock(torch.nn.Module):
             attention_mask=attention_mask,
             **cross_attention_kwargs,
         )
         if self.norm_type == "ada_norm_zero":
             attn_output = gate_msa.unsqueeze(1) * attn_output
         elif self.norm_type == "ada_norm_single":
@@ -492,13 +182,17 @@ class Basic2p5DTransformerBlock(torch.nn.Module):
         hidden_states = attn_output + hidden_states
         if hidden_states.ndim == 4:
             hidden_states = hidden_states.squeeze(1)
-
         # 1.2 Reference Attention
         if 'w' in mode:
-            condition_embed_dict[self.layer_name] = rearrange(
-
-
-
             condition_embed = rearrange(condition_embed, 'b n l c -> (b n) l c')
 
             attn_output = self.attn_refview(
@@ -507,35 +201,48 @@ class Basic2p5DTransformerBlock(torch.nn.Module):
                 attention_mask=None,
                 **cross_attention_kwargs
             )
 
-            hidden_states = attn_output + hidden_states
             if hidden_states.ndim == 4:
                 hidden_states = hidden_states.squeeze(1)
-
 
         # 1.3 Multiview Attention
         if num_in_batch > 1 and self.use_ma:
             multivew_hidden_states = rearrange(norm_hidden_states, '(b n) l c -> b (n l) c', n=num_in_batch)
-            position_mask = None
-            if position_attn_mask is not None:
-                if multivew_hidden_states.shape[1] in position_attn_mask:
-                    position_mask = position_attn_mask[multivew_hidden_states.shape[1]]
-            position_indices = None
-            if position_voxel_indices is not None:
-                if multivew_hidden_states.shape[1] in position_voxel_indices:
-                    position_indices = position_voxel_indices[multivew_hidden_states.shape[1]]
-
-            attn_output = self.attn_multiview(
-                multivew_hidden_states,
-                encoder_hidden_states=multivew_hidden_states,
-                attention_mask=position_mask,
-                position_indices=position_indices,
-                **cross_attention_kwargs
-            )
 
-
 
-
             if hidden_states.ndim == 4:
                 hidden_states = hidden_states.squeeze(1)
 
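The removed multiview path above folds all views of one object into a single attention sequence before self-attention. A small sketch of just that reshaping step, using the rearrange pattern from the diff with toy dimensions (6 views per object, 96 tokens of width 320 are assumptions for illustration):

    import torch
    from einops import rearrange

    num_in_batch = 6                                    # views per object (assumed)
    hidden = torch.randn(2 * num_in_batch, 96, 320)     # (b n) l c, as fed to the block
    multiview = rearrange(hidden, '(b n) l c -> b (n l) c', n=num_in_batch)
    print(multiview.shape)                              # torch.Size([2, 576, 320]) -> one joint sequence per object
    back = rearrange(multiview, 'b (n l) c -> (b n) l c', n=num_in_batch)  # undo after attention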
@@ -561,25 +268,12 @@ class Basic2p5DTransformerBlock(torch.nn.Module):
             if self.pos_embed is not None and self.norm_type != "ada_norm_single":
                 norm_hidden_states = self.pos_embed(norm_hidden_states)
 
-
-
-
-
-
-
-                    norm_hidden_states,
-                    encoder_hidden_states=encoder_hidden_states,
-                    ip_hidden_states=ip_hidden_states,
-                    attention_mask=encoder_attention_mask,
-                    **cross_attention_kwargs,
-                )
-            else:
-                attn_output = self.attn2(
-                    norm_hidden_states,
-                    encoder_hidden_states=encoder_hidden_states,
-                    attention_mask=encoder_attention_mask,
-                    **cross_attention_kwargs,
-                )
 
             hidden_states = attn_output + hidden_states
 
@@ -626,8 +320,16 @@ def compute_voxel_grid_mask(position, grid_resolution=8):
     position[valid_mask==False] = 0
 
 
-    position = rearrange(
-
 
     grid_position = position.sum(dim=(-2, -1))
     count_masked = valid_mask.sum(dim=(-2, -1))
@@ -674,8 +376,16 @@ def compute_discrete_voxel_indice(position, grid_resolution=8, voxel_resolution=
     valid_mask = valid_mask.expand_as(position)
     position[valid_mask==False] = 0
 
-    position = rearrange(
-
 
     grid_position = position.sum(dim=(-2, -1))
     count_masked = valid_mask.sum(dim=(-2, -1))
@@ -688,45 +398,36 @@ def compute_discrete_voxel_indice(position, grid_resolution=8, voxel_resolution=
     voxel_indices = torch.round(voxel_indices).long()
     return voxel_indices
 
-def compute_multi_resolution_discrete_voxel_indice(
     voxel_indices = {}
     with torch.no_grad():
         for grid_resolution, voxel_resolution in zip(grid_resolutions, voxel_resolutions):
             voxel_indice = compute_discrete_voxel_indice(position_maps, grid_resolution, voxel_resolution)
             voxel_indice = rearrange(voxel_indice, 'b n c h w -> b (n h w) c')
             voxel_indices[voxel_indice.shape[1]] = {'voxel_indices':voxel_indice, 'voxel_resolution':voxel_resolution}
-    return voxel_indices
-
-class ImageProjModel(torch.nn.Module):
-    """Projection Model"""
-
-    def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4):
-        super().__init__()
-
-        self.generator = None
-        self.cross_attention_dim = cross_attention_dim
-        self.clip_extra_context_tokens = clip_extra_context_tokens
-        self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim)
-        self.norm = torch.nn.LayerNorm(cross_attention_dim)
 
-    def forward(self, image_embeds):
-        embeds = image_embeds
-        clip_extra_context_tokens = self.proj(embeds).reshape(
-            -1, self.clip_extra_context_tokens, self.cross_attention_dim
-        )
-        clip_extra_context_tokens = self.norm(clip_extra_context_tokens)
-        return clip_extra_context_tokens
-
 class UNet2p5DConditionModel(torch.nn.Module):
     def __init__(self, unet: UNet2DConditionModel) -> None:
         super().__init__()
         self.unet = unet
-        self.unet_dual = copy.deepcopy(unet)
 
-        self.
-        self.
-        self.
         self.init_condition()
 
     @staticmethod
     def from_pretrained(pretrained_model_name_or_path, **kwargs):
| @@ -737,170 +438,158 @@ class UNet2p5DConditionModel(torch.nn.Module): | |
| 737 | 
             
                        config = json.load(file)
         | 
| 738 | 
             
                    unet = UNet2DConditionModel(**config)
         | 
| 739 | 
             
                    unet = UNet2p5DConditionModel(unet)
         | 
| 740 | 
            -
             | 
| 741 | 
            -
                    unet.unet.conv_in = torch.nn.Conv2d(
         | 
| 742 | 
            -
                        12,
         | 
| 743 | 
            -
                        unet.unet.conv_in.out_channels,
         | 
| 744 | 
            -
                        kernel_size=unet.unet.conv_in.kernel_size,
         | 
| 745 | 
            -
                        stride=unet.unet.conv_in.stride,
         | 
| 746 | 
            -
                        padding=unet.unet.conv_in.padding,
         | 
| 747 | 
            -
                        dilation=unet.unet.conv_in.dilation,
         | 
| 748 | 
            -
                        groups=unet.unet.conv_in.groups,
         | 
| 749 | 
            -
                        bias=unet.unet.conv_in.bias is not None)
         | 
| 750 | 
            -
                    
         | 
| 751 | 
             
                    unet_ckpt = torch.load(unet_ckpt_path, map_location='cpu', weights_only=True)
         | 
| 752 | 
             
                    unet.load_state_dict(unet_ckpt, strict=True)
         | 
| 753 | 
             
                    unet = unet.to(torch_dtype)
         | 
| 754 | 
             
                    return unet
         | 
| 755 | 
            -
                    
         | 
| 756 | 
            -
                def init_condition(self):
         | 
| 757 | 
            -
                    self.unet.learned_text_clip_gen = nn.Parameter(torch.randn(1,77,1024))
         | 
| 758 | 
            -
                    self.unet.learned_text_clip_ref = nn.Parameter(torch.randn(1,77,1024))
         | 
| 759 |  | 
| 760 | 
            -
             | 
| 761 | 
            -
             | 
| 762 | 
            -
                         | 
| 763 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 764 |  | 
|  | |
|  | |
| 765 |  | 
| 766 | 
             
                def init_camera_embedding(self):
         | 
| 767 | 
            -
                    self.max_num_ref_image = 5
         | 
| 768 | 
            -
                    self.max_num_gen_image = 12*3+4*2
         | 
| 769 |  | 
| 770 | 
            -
                     | 
| 771 | 
            -
             | 
| 772 | 
            -
             | 
| 773 | 
            -
             | 
| 774 | 
            -
             | 
| 775 | 
            -
             | 
|  | |
| 776 |  | 
| 777 | 
             
                    for down_block_i, down_block in enumerate(unet.down_blocks):
         | 
| 778 | 
             
                        if hasattr(down_block, "has_cross_attention") and down_block.has_cross_attention:
         | 
| 779 | 
             
                            for attn_i, attn in enumerate(down_block.attentions):
         | 
| 780 | 
             
                                for transformer_i, transformer in enumerate(attn.transformer_blocks):
         | 
| 781 | 
             
                                    if isinstance(transformer, BasicTransformerBlock):
         | 
| 782 | 
            -
                                        attn.transformer_blocks[transformer_i] = Basic2p5DTransformerBlock( | 
|  | |
|  | |
|  | |
|  | |
| 783 |  | 
| 784 | 
             
                    if hasattr(unet.mid_block, "has_cross_attention") and unet.mid_block.has_cross_attention:
         | 
| 785 | 
             
                        for attn_i, attn in enumerate(unet.mid_block.attentions):
         | 
| 786 | 
             
                            for transformer_i, transformer in enumerate(attn.transformer_blocks):
         | 
| 787 | 
             
                                if isinstance(transformer, BasicTransformerBlock):
         | 
| 788 | 
            -
                                    attn.transformer_blocks[transformer_i] = Basic2p5DTransformerBlock( | 
|  | |
|  | |
|  | |
|  | |
| 789 |  | 
| 790 | 
             
                    for up_block_i, up_block in enumerate(unet.up_blocks):
         | 
| 791 | 
             
                        if hasattr(up_block, "has_cross_attention") and up_block.has_cross_attention:
         | 
| 792 | 
             
                            for attn_i, attn in enumerate(up_block.attentions):
         | 
| 793 | 
             
                                for transformer_i, transformer in enumerate(attn.transformer_blocks):
         | 
| 794 | 
             
                                    if isinstance(transformer, BasicTransformerBlock):
         | 
| 795 | 
            -
                                        attn.transformer_blocks[transformer_i] = Basic2p5DTransformerBlock( | 
| 796 | 
            -
             | 
|  | |
|  | |
|  | |
| 797 |  | 
| 798 | 
             
                def __getattr__(self, name: str):
         | 
| 799 | 
             
                    try:
         | 
| 800 | 
             
                        return super().__getattr__(name)
         | 
| 801 | 
             
                    except AttributeError:
         | 
| 802 | 
             
                        return getattr(self.unet, name)
         | 
| 803 | 
            -
             | 
| 804 | 
             
                def forward(
         | 
| 805 | 
            -
                    self, sample, timestep, encoder_hidden_states, | 
| 806 | 
            -
                    *args,  | 
| 807 | 
             
                    down_block_res_samples=None, mid_block_res_sample=None,
         | 
| 808 | 
             
                    **cached_condition,
         | 
| 809 | 
             
                ):
         | 
| 810 | 
             
                    B, N_gen, _, H, W = sample.shape
         | 
| 811 | 
            -
                     | 
| 812 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 813 | 
             
                    sample = [sample]
         | 
| 814 | 
            -
                    
         | 
| 815 | 
             
                    if 'normal_imgs' in cached_condition:
         | 
| 816 | 
             
                        sample.append(cached_condition["normal_imgs"])
         | 
| 817 | 
             
                    if 'position_imgs' in cached_condition:
         | 
| 818 | 
             
                        sample.append(cached_condition["position_imgs"])
         | 
| 819 | 
            -
             | 
| 820 | 
             
                    sample = torch.cat(sample, dim=2)
         | 
|  | |
| 821 | 
             
                    sample = rearrange(sample, 'b n c h w -> (b n) c h w')
         | 
| 822 |  | 
| 823 | 
             
                    encoder_hidden_states_gen = encoder_hidden_states.unsqueeze(1).repeat(1, N_gen, 1, 1)
         | 
| 824 | 
             
                    encoder_hidden_states_gen = rearrange(encoder_hidden_states_gen, 'b n l c -> (b n) l c')
         | 
| 825 | 
            -
                    
         | 
| 826 | 
            -
                    
         | 
| 827 | 
            -
                    use_position_mask = False
         | 
| 828 | 
            -
                    use_position_rope = True
         | 
| 829 | 
            -
             | 
| 830 | 
            -
                    position_attn_mask = None
         | 
| 831 | 
            -
                    if use_position_mask:
         | 
| 832 | 
            -
                        if 'position_attn_mask' in cached_condition:
         | 
| 833 | 
            -
                            position_attn_mask = cached_condition['position_attn_mask']
         | 
| 834 | 
            -
                        else:
         | 
| 835 | 
            -
                            if 'position_maps' in cached_condition:
         | 
| 836 | 
            -
                                position_attn_mask = compute_multi_resolution_mask(cached_condition['position_maps'])
         | 
| 837 | 
            -
                    
         | 
| 838 | 
            -
                    position_voxel_indices = None
         | 
| 839 | 
            -
                    if use_position_rope:
         | 
| 840 | 
            -
                        if 'position_voxel_indices' in cached_condition:
         | 
| 841 | 
            -
                            position_voxel_indices = cached_condition['position_voxel_indices']
         | 
| 842 | 
            -
                        else:
         | 
| 843 | 
            -
                            if 'position_maps' in cached_condition:
         | 
| 844 | 
            -
                                position_voxel_indices = compute_multi_resolution_discrete_voxel_indice(cached_condition['position_maps'])
         | 
| 845 |  | 
| 846 | 
            -
                    if  | 
| 847 | 
            -
                         | 
| 848 | 
            -
             | 
| 849 | 
            -
                        if 'clip_embeds' in cached_condition:
         | 
| 850 | 
            -
                            ip_hidden_states = self.image_proj_model(cached_condition['clip_embeds'])
         | 
| 851 | 
             
                        else:
         | 
| 852 | 
            -
                             | 
| 853 | 
            -
             | 
| 854 | 
            -
             | 
| 855 | 
            -
             | 
| 856 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 857 | 
             
                    else:
         | 
| 858 | 
            -
                        condition_embed_dict =  | 
| 859 | 
            -
                        ref_latents = cached_condition['ref_latents']
         | 
| 860 | 
            -
                        N_ref = ref_latents.shape[1]
         | 
| 861 | 
            -
                        camera_info_ref = cached_condition['camera_info_ref']
         | 
| 862 | 
            -
                        camera_info_ref = rearrange(camera_info_ref, 'b n -> (b n)')
         | 
| 863 | 
            -
                        
         | 
| 864 | 
            -
                        #ref_latents = [ref_latents]
         | 
| 865 | 
            -
                        #if 'normal_imgs' in cached_condition:
         | 
| 866 | 
            -
                        #    ref_latents.append(torch.zeros_like(ref_latents[0]))
         | 
| 867 | 
            -
                        #if 'position_imgs' in cached_condition:
         | 
| 868 | 
            -
                        #    ref_latents.append(torch.zeros_like(ref_latents[0]))
         | 
| 869 | 
            -
                        #ref_latents = torch.cat(ref_latents, dim=2)
         | 
| 870 | 
            -
                        
         | 
| 871 | 
            -
                        ref_latents = rearrange(ref_latents, 'b n c h w -> (b n) c h w')
         | 
| 872 |  | 
| 873 | 
            -
             | 
| 874 | 
            -
             | 
| 875 | 
            -
             | 
| 876 | 
            -
                        noisy_ref_latents = ref_latents
         | 
| 877 | 
            -
                        timestep_ref = 0
         | 
| 878 | 
            -
                        '''
         | 
| 879 | 
            -
                        if timestep.dim()>0:
         | 
| 880 | 
            -
                            timestep_ref = rearrange(timestep, '(b n) -> b n', b=B)[:,:1].repeat(1, N_ref)
         | 
| 881 | 
            -
                            timestep_ref = rearrange(timestep_ref, 'b n -> (b n)')
         | 
| 882 | 
            -
                        else:
         | 
| 883 | 
            -
                            timestep_ref = timestep
         | 
| 884 | 
            -
                        noise = torch.randn_like(noisy_ref_latents[:,:4,...])
         | 
| 885 | 
            -
                        if self.training:
         | 
| 886 | 
            -
                            noisy_ref_latents[:,:4,...] = self.train_sched.add_noise(noisy_ref_latents[:,:4,...], noise, timestep_ref)
         | 
| 887 | 
            -
                            noisy_ref_latents[:,:4,...] = self.train_sched.scale_model_input(noisy_ref_latents[:,:4,...], timestep_ref)
         | 
| 888 | 
            -
                        else:
         | 
| 889 | 
            -
                            noisy_ref_latents[:,:4,...] = self.val_sched.add_noise(noisy_ref_latents[:,:4,...], noise, timestep_ref.reshape(-1))
         | 
| 890 | 
            -
                            noisy_ref_latents[:,:4,...] = self.val_sched.scale_model_input(noisy_ref_latents[:,:4,...], timestep_ref.reshape(-1))
         | 
| 891 | 
            -
                        '''
         | 
| 892 | 
            -
                        self.unet_dual(
         | 
| 893 | 
            -
                            noisy_ref_latents, timestep_ref,
         | 
| 894 | 
            -
                            encoder_hidden_states=encoder_hidden_states_ref,
         | 
| 895 | 
            -
                            #class_labels=camera_info_ref,
         | 
| 896 | 
            -
                            # **kwargs
         | 
| 897 | 
            -
                            return_dict=False,
         | 
| 898 | 
            -
                            cross_attention_kwargs={
         | 
| 899 | 
            -
                                'mode':'w', 'num_in_batch':N_ref, 
         | 
| 900 | 
            -
                                'condition_embed_dict':condition_embed_dict},
         | 
| 901 | 
            -
                        )
         | 
| 902 | 
            -
                        cached_condition['condition_embed_dict'] = condition_embed_dict
         | 
| 903 |  | 
|  | 
| 904 | 
             
                    return self.unet(
         | 
| 905 | 
             
                        sample, timestep,
         | 
| 906 | 
             
                        encoder_hidden_states_gen, *args,
         | 
| @@ -916,11 +605,6 @@ class UNet2p5DConditionModel(torch.nn.Module): | |
| 916 | 
             
                            if mid_block_res_sample is not None else None
         | 
| 917 | 
             
                        ),
         | 
| 918 | 
             
                        return_dict=False,
         | 
| 919 | 
            -
                        cross_attention_kwargs={
         | 
| 920 | 
            -
             | 
| 921 | 
            -
             | 
| 922 | 
            -
                            'condition_embed_dict':condition_embed_dict, 
         | 
| 923 | 
            -
                            'position_attn_mask':position_attn_mask, 
         | 
| 924 | 
            -
                            'position_voxel_indices':position_voxel_indices
         | 
| 925 | 
            -
                        },
         | 
| 926 | 
            -
                    ) 
         | 
|  | |
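Note: the hunk above drops the pre-turbo forward() code that lazily rebuilt position_attn_mask and position_voxel_indices from cached_condition['position_maps'] and ran the reference views through self.unet_dual inline with hard-coded mode='w' kwargs. A minimal caller-side sketch of precomputing the voxel indices once and handing them in via cached_condition — names, shapes and sizes are assumptions, and compute_multi_resolution_mask was the matching mask helper in the pre-turbo file:

    import torch

    # position_maps: B x N x C x H x W rasterised position images (assumed layout/sizes)
    position_maps = torch.rand(1, 6, 3, 512, 512)

    cached_condition = {
        'position_maps': position_maps,
        # compute_multi_resolution_discrete_voxel_indice is defined further down in this
        # file; its keys are flattened multiview token counts per UNet resolution.
        'position_voxel_indices': compute_multi_resolution_discrete_voxel_indice(position_maps),
    }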
| 22 | 
             
            # fine-tuning enabling code and other elements of the foregoing made publicly available
         | 
| 23 | 
             
            # by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
         | 
| 24 |  | 
|  | |
| 25 | 
             
            import copy
         | 
| 26 | 
             
            import json
         | 
| 27 | 
             
            import os
         | 
|  | |
| 40 | 
             
                # "feed_forward_chunk_size" can be used to save memory
         | 
| 41 | 
             
                if hidden_states.shape[chunk_dim] % chunk_size != 0:
         | 
| 42 | 
             
                    raise ValueError(
         | 
| 43 | 
            +
                        f"`hidden_states` dimension to be chunked: {hidden_states.shape[chunk_dim]}"
         | 
| 44 | 
            +
                        f"has to be divisible by chunk size: {chunk_size}."
         | 
| 45 | 
            +
                        f" Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
         | 
| 46 | 
             
                    )
         | 
| 47 |  | 
| 48 | 
             
                num_chunks = hidden_states.shape[chunk_dim] // chunk_size
         | 
|  | |
| 52 | 
             
                )
         | 
| 53 | 
             
                return ff_output
         | 
| 54 |  | 
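Note: the `+` lines above only re-wrap the divisibility error message; the check itself is unchanged. A small self-contained sketch of what _chunked_feed_forward relies on, with illustrative sizes that are not taken from the diff:

    import torch
    import torch.nn as nn

    ff = nn.Sequential(nn.Linear(8, 32), nn.GELU(), nn.Linear(32, 8))
    hidden_states = torch.randn(2, 4096, 8)      # (batch, seq_len, dim)
    chunk_size, chunk_dim = 512, 1               # 4096 % 512 == 0 -> 8 chunks
    num_chunks = hidden_states.shape[chunk_dim] // chunk_size
    ff_output = torch.cat(
        [ff(chunk) for chunk in hidden_states.chunk(num_chunks, dim=chunk_dim)],
        dim=chunk_dim,
    )
    # same result as ff(hidden_states), just with a smaller activation peak
    assert torch.allclose(ff_output, ff(hidden_states), atol=1e-5)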
|  | 
| 55 |  | 
| 56 | 
             
            class Basic2p5DTransformerBlock(torch.nn.Module):
         | 
| 57 | 
            +
                def __init__(self, transformer: BasicTransformerBlock, layer_name, use_ma=True, use_ra=True, is_turbo=False) -> None:
         | 
| 58 | 
             
                    super().__init__()
         | 
| 59 | 
             
                    self.transformer = transformer
         | 
| 60 | 
             
                    self.layer_name = layer_name
         | 
|  | |
| 61 | 
             
                    self.use_ma = use_ma
         | 
| 62 | 
             
                    self.use_ra = use_ra
         | 
| 63 | 
            +
                    self.is_turbo = is_turbo
         | 
| 64 |  | 
|  | 
| 65 | 
             
                    # multiview attn
         | 
| 66 | 
             
                    if self.use_ma:
         | 
| 67 | 
             
                        self.attn_multiview = Attention(
         | 
|  | |
| 73 | 
             
                            cross_attention_dim=None,
         | 
| 74 | 
             
                            upcast_attention=self.attn1.upcast_attention,
         | 
| 75 | 
             
                            out_bias=True,
         | 
|  | |
| 76 | 
             
                        )
         | 
| 77 |  | 
| 78 | 
             
                    # ref attn
         | 
|  | |
| 87 | 
             
                            upcast_attention=self.attn1.upcast_attention,
         | 
| 88 | 
             
                            out_bias=True,
         | 
| 89 | 
             
                        )
         | 
| 90 | 
            +
                    if self.is_turbo:
         | 
| 91 | 
            +
                        self._initialize_attn_weights()
         | 
| 92 |  | 
| 93 | 
             
                def _initialize_attn_weights(self):
         | 
| 94 |  | 
|  | |
| 105 | 
             
                                for param in layer.parameters():
         | 
| 106 | 
             
                                    param.zero_()
         | 
| 107 |  | 
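Note: _initialize_attn_weights is only called on the turbo path (guarded by `if self.is_turbo:` above), and its visible tail zeroes `layer.parameters()` — presumably the output projections of the newly added attention modules, so that right after wrapping the multiview/reference branches contribute nothing and the block still behaves like the original BasicTransformerBlock. A generic zero-init sketch under that assumption (not the upstream implementation):

    import torch
    import torch.nn as nn

    def zero_out_(module: nn.Module) -> nn.Module:
        # zero every parameter so the module's contribution starts at exactly 0
        with torch.no_grad():
            for param in module.parameters():
                param.zero_()
        return module

    proj = zero_out_(nn.Linear(320, 320))        # hypothetical out-projection
    assert torch.all(proj(torch.randn(4, 320)) == 0)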
|  | 
| 108 | 
             
                def __getattr__(self, name: str):
         | 
| 109 | 
             
                    try:
         | 
| 110 | 
             
                        return super().__getattr__(name)
         | 
|  | |
| 130 | 
             
                    cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
         | 
| 131 | 
             
                    num_in_batch = cross_attention_kwargs.pop('num_in_batch', 1)
         | 
| 132 | 
             
                    mode = cross_attention_kwargs.pop('mode', None)
         | 
| 133 | 
            +
                    if not self.is_turbo:
         | 
| 134 | 
            +
                        mva_scale = cross_attention_kwargs.pop('mva_scale', 1.0)
         | 
| 135 | 
            +
                        ref_scale = cross_attention_kwargs.pop('ref_scale', 1.0)
         | 
| 136 | 
            +
                    else:
         | 
| 137 | 
            +
                        position_attn_mask = cross_attention_kwargs.pop("position_attn_mask", None)
         | 
| 138 | 
            +
                        position_voxel_indices = cross_attention_kwargs.pop("position_voxel_indices", None)
         | 
| 139 | 
            +
                        mva_scale = 1.0
         | 
| 140 | 
            +
                        ref_scale = 1.0
         | 
| 141 | 
            +
                        
         | 
| 142 | 
             
                    condition_embed_dict = cross_attention_kwargs.pop("condition_embed_dict", None)
         | 
|  | 
| 143 |  | 
| 144 | 
             
                    if self.norm_type == "ada_norm":
         | 
| 145 | 
             
                        norm_hidden_states = self.norm1(hidden_states, timestep)
         | 
|  | |
| 159 | 
             
                        norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
         | 
| 160 | 
             
                    else:
         | 
| 161 | 
             
                        raise ValueError("Incorrect norm used")
         | 
| 162 | 
            +
             | 
| 163 | 
             
                    if self.pos_embed is not None:
         | 
| 164 | 
             
                        norm_hidden_states = self.pos_embed(norm_hidden_states)
         | 
| 165 | 
            +
             | 
| 166 | 
             
                    # 1. Prepare GLIGEN inputs
         | 
| 167 | 
             
                    cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
         | 
| 168 | 
             
                    gligen_kwargs = cross_attention_kwargs.pop("gligen", None)
         | 
|  | |
| 173 | 
             
                        attention_mask=attention_mask,
         | 
| 174 | 
             
                        **cross_attention_kwargs,
         | 
| 175 | 
             
                    )
         | 
| 176 | 
            +
             | 
| 177 | 
             
                    if self.norm_type == "ada_norm_zero":
         | 
| 178 | 
             
                        attn_output = gate_msa.unsqueeze(1) * attn_output
         | 
| 179 | 
             
                    elif self.norm_type == "ada_norm_single":
         | 
|  | |
| 182 | 
             
                    hidden_states = attn_output + hidden_states
         | 
| 183 | 
             
                    if hidden_states.ndim == 4:
         | 
| 184 | 
             
                        hidden_states = hidden_states.squeeze(1)
         | 
| 185 | 
            +
             | 
| 186 | 
             
                    # 1.2 Reference Attention
         | 
| 187 | 
             
                    if 'w' in mode:
         | 
| 188 | 
            +
                        condition_embed_dict[self.layer_name] = rearrange(
         | 
| 189 | 
            +
                            norm_hidden_states, '(b n) l c -> b (n l) c',
         | 
| 190 | 
            +
                            n=num_in_batch
         | 
| 191 | 
            +
                        )  # B, (N L), C
         | 
| 192 | 
            +
             | 
| 193 | 
            +
                    if 'r' in mode and self.use_ra:
         | 
| 194 | 
            +
                        condition_embed = condition_embed_dict[self.layer_name].unsqueeze(1).repeat(1, num_in_batch, 1,
         | 
| 195 | 
            +
                                                                                                    1)  # B N L C
         | 
| 196 | 
             
                        condition_embed = rearrange(condition_embed, 'b n l c -> (b n) l c')
         | 
| 197 |  | 
| 198 | 
             
                        attn_output = self.attn_refview(
         | 
|  | |
| 201 | 
             
                            attention_mask=None,
         | 
| 202 | 
             
                            **cross_attention_kwargs
         | 
| 203 | 
             
                        )
         | 
| 204 | 
            +
                        if not self.is_turbo:
         | 
| 205 | 
            +
                            ref_scale_timing = ref_scale
         | 
| 206 | 
            +
                            if isinstance(ref_scale, torch.Tensor):
         | 
| 207 | 
            +
                                ref_scale_timing = ref_scale.unsqueeze(1).repeat(1, num_in_batch).view(-1)
         | 
| 208 | 
            +
                                for _ in range(attn_output.ndim - 1):
         | 
| 209 | 
            +
                                    ref_scale_timing = ref_scale_timing.unsqueeze(-1)
         | 
| 210 | 
            +
                                    
         | 
| 211 | 
            +
                        hidden_states = ref_scale_timing * attn_output + hidden_states
         | 
| 212 |  | 
|  | |
| 213 | 
             
                        if hidden_states.ndim == 4:
         | 
| 214 | 
             
                            hidden_states = hidden_states.squeeze(1)
         | 
|  | |
| 215 |  | 
| 216 | 
             
                    # 1.3 Multiview Attention
         | 
| 217 | 
             
                    if num_in_batch > 1 and self.use_ma:
         | 
| 218 | 
             
                        multivew_hidden_states = rearrange(norm_hidden_states, '(b n) l c -> b (n l) c', n=num_in_batch)
         | 
|  | 
| 219 |  | 
| 220 | 
            +
                        if self.is_turbo:
         | 
| 221 | 
            +
                            position_mask = None
         | 
| 222 | 
            +
                            if position_attn_mask is not None:
         | 
| 223 | 
            +
                                if multivew_hidden_states.shape[1] in position_attn_mask:
         | 
| 224 | 
            +
                                    position_mask = position_attn_mask[multivew_hidden_states.shape[1]]
         | 
| 225 | 
            +
                            position_indices = None
         | 
| 226 | 
            +
                            if position_voxel_indices is not None:
         | 
| 227 | 
            +
                                if multivew_hidden_states.shape[1] in position_voxel_indices:
         | 
| 228 | 
            +
                                    position_indices = position_voxel_indices[multivew_hidden_states.shape[1]]
         | 
| 229 | 
            +
                            attn_output = self.attn_multiview(
         | 
| 230 | 
            +
                                multivew_hidden_states,
         | 
| 231 | 
            +
                                encoder_hidden_states=multivew_hidden_states,
         | 
| 232 | 
            +
                                attention_mask=position_mask,
         | 
| 233 | 
            +
                                position_indices=position_indices,
         | 
| 234 | 
            +
                                **cross_attention_kwargs
         | 
| 235 | 
            +
                            )
         | 
| 236 | 
            +
                        else:
         | 
| 237 | 
            +
                            attn_output = self.attn_multiview(
         | 
| 238 | 
            +
                                multivew_hidden_states,
         | 
| 239 | 
            +
                                encoder_hidden_states=multivew_hidden_states,
         | 
| 240 | 
            +
                                **cross_attention_kwargs
         | 
| 241 | 
            +
                            )
         | 
| 242 |  | 
| 243 | 
            +
                        attn_output = rearrange(attn_output, 'b (n l) c -> (b n) l c', n=num_in_batch)
         | 
| 244 | 
            +
                        
         | 
| 245 | 
            +
                        hidden_states = mva_scale * attn_output + hidden_states
         | 
| 246 | 
             
                        if hidden_states.ndim == 4:
         | 
| 247 | 
             
                            hidden_states = hidden_states.squeeze(1)
         | 
| 248 |  | 
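Note: in the turbo branch above, position_attn_mask and position_voxel_indices are plain dicts keyed by the flattened multiview sequence length (multivew_hidden_states.shape[1]), so each UNet resolution fetches the entry precomputed for its own token count. A tiny sketch with assumed sizes:

    import torch

    n_views, tokens_per_view = 6, 16 * 16
    seq_len = n_views * tokens_per_view                  # multivew_hidden_states.shape[1]
    position_attn_mask = {seq_len: torch.ones(1, seq_len, seq_len, dtype=torch.bool)}

    position_mask = None
    if seq_len in position_attn_mask:                    # same membership check as above
        position_mask = position_attn_mask[seq_len]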
|  | |
| 268 | 
             
                        if self.pos_embed is not None and self.norm_type != "ada_norm_single":
         | 
| 269 | 
             
                            norm_hidden_states = self.pos_embed(norm_hidden_states)
         | 
| 270 |  | 
| 271 | 
            +
                        attn_output = self.attn2(
         | 
| 272 | 
            +
                            norm_hidden_states,
         | 
| 273 | 
            +
                            encoder_hidden_states=encoder_hidden_states,
         | 
| 274 | 
            +
                            attention_mask=encoder_attention_mask,
         | 
| 275 | 
            +
                            **cross_attention_kwargs,
         | 
| 276 | 
            +
                        )
         | 
|  | 
| 277 |  | 
| 278 | 
             
                        hidden_states = attn_output + hidden_states
         | 
| 279 |  | 
|  | |
| 320 | 
             
                position[valid_mask==False] = 0
         | 
| 321 |  | 
| 322 |  | 
| 323 | 
            +
                position = rearrange(
         | 
| 324 | 
            +
                    position,
         | 
| 325 | 
            +
                    'b n c (num_h grid_h) (num_w grid_w) -> b n num_h num_w c grid_h grid_w', 
         | 
| 326 | 
            +
                    num_h=grid_resolution, num_w=grid_resolution
         | 
| 327 | 
            +
                )
         | 
| 328 | 
            +
                valid_mask = rearrange(
         | 
| 329 | 
            +
                    valid_mask, 
         | 
| 330 | 
            +
                    'b n c (num_h grid_h) (num_w grid_w) -> b n num_h num_w c grid_h grid_w', 
         | 
| 331 | 
            +
                    num_h=grid_resolution, num_w=grid_resolution
         | 
| 332 | 
            +
                )
         | 
| 333 |  | 
| 334 | 
             
                grid_position = position.sum(dim=(-2, -1))
         | 
| 335 | 
             
                count_masked = valid_mask.sum(dim=(-2, -1))
         | 
|  | |
| 376 | 
             
                valid_mask = valid_mask.expand_as(position)
         | 
| 377 | 
             
                position[valid_mask==False] = 0
         | 
| 378 |  | 
| 379 | 
            +
                position = rearrange(
         | 
| 380 | 
            +
                    position, 
         | 
| 381 | 
            +
                    'b n c (num_h grid_h) (num_w grid_w) -> b n num_h num_w c grid_h grid_w', 
         | 
| 382 | 
            +
                    num_h=grid_resolution, num_w=grid_resolution
         | 
| 383 | 
            +
                )
         | 
| 384 | 
            +
                valid_mask = rearrange(
         | 
| 385 | 
            +
                    valid_mask, 
         | 
| 386 | 
            +
                    'b n c (num_h grid_h) (num_w grid_w) -> b n num_h num_w c grid_h grid_w', 
         | 
| 387 | 
            +
                    num_h=grid_resolution, num_w=grid_resolution
         | 
| 388 | 
            +
                )
         | 
| 389 |  | 
| 390 | 
             
                grid_position = position.sum(dim=(-2, -1))
         | 
| 391 | 
             
                count_masked = valid_mask.sum(dim=(-2, -1))
         | 
|  | |
| 398 | 
             
                voxel_indices = torch.round(voxel_indices).long()
         | 
| 399 | 
             
                return voxel_indices
         | 
| 400 |  | 
| 401 | 
            +
            def compute_multi_resolution_discrete_voxel_indice(
         | 
| 402 | 
            +
                position_maps, 
         | 
| 403 | 
            +
                grid_resolutions=[64, 32, 16, 8], 
         | 
| 404 | 
            +
                voxel_resolutions=[512, 256, 128, 64]
         | 
| 405 | 
            +
            ):
         | 
| 406 | 
             
                voxel_indices = {}
         | 
| 407 | 
             
                with torch.no_grad():
         | 
| 408 | 
             
                    for grid_resolution, voxel_resolution in zip(grid_resolutions, voxel_resolutions):
         | 
| 409 | 
             
                        voxel_indice = compute_discrete_voxel_indice(position_maps, grid_resolution, voxel_resolution)
         | 
| 410 | 
             
                        voxel_indice = rearrange(voxel_indice, 'b n c h w -> b (n h w) c')
         | 
| 411 | 
             
                        voxel_indices[voxel_indice.shape[1]] = {'voxel_indices':voxel_indice, 'voxel_resolution':voxel_resolution}
         | 
| 412 | 
            +
                return voxel_indices
         | 
|  | 
| 413 |  | 
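Note: shape bookkeeping for the helper added above, assuming B = 1 batch and N = 6 views; each grid resolution g contributes one entry keyed by the flattened multiview token count N*g*g, with the matching voxel resolution attached:

    #   g = 64 -> key 6*64*64 = 24576, voxel_resolution 512
    #   g = 32 -> key 6*32*32 =  6144, voxel_resolution 256
    #   g = 16 -> key 6*16*16 =  1536, voxel_resolution 128
    #   g =  8 -> key 6* 8* 8 =   384, voxel_resolution  64
    # so the multiview attention at a given level can look itself up by sequence length:
    #   voxel_indices[1536]['voxel_indices']    -> (1, 1536, C) long tensor (C = position-map channels)
    #   voxel_indices[1536]['voxel_resolution'] -> 128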
|  | 
| 414 | 
             
            class UNet2p5DConditionModel(torch.nn.Module):
         | 
| 415 | 
             
                def __init__(self, unet: UNet2DConditionModel) -> None:
         | 
| 416 | 
             
                    super().__init__()
         | 
| 417 | 
             
                    self.unet = unet
         | 
|  | |
| 418 |  | 
| 419 | 
            +
                    self.use_ma = True
         | 
| 420 | 
            +
                    self.use_ra = True
         | 
| 421 | 
            +
                    self.use_camera_embedding = True
         | 
| 422 | 
            +
                    self.use_dual_stream = True
         | 
| 423 | 
            +
                    self.is_turbo = False
         | 
| 424 | 
            +
             | 
| 425 | 
            +
                    if self.use_dual_stream:
         | 
| 426 | 
            +
                        self.unet_dual = copy.deepcopy(unet)
         | 
| 427 | 
            +
                        self.init_attention(self.unet_dual)
         | 
| 428 | 
            +
                    self.init_attention(self.unet, use_ma=self.use_ma, use_ra=self.use_ra, is_turbo=self.is_turbo)
         | 
| 429 | 
             
                    self.init_condition()
         | 
| 430 | 
            +
                    self.init_camera_embedding()
         | 
| 431 |  | 
| 432 | 
             
                @staticmethod
         | 
| 433 | 
             
                def from_pretrained(pretrained_model_name_or_path, **kwargs):
         | 
|  | |
| 438 | 
             
                        config = json.load(file)
         | 
| 439 | 
             
                    unet = UNet2DConditionModel(**config)
         | 
| 440 | 
             
                    unet = UNet2p5DConditionModel(unet)
         | 
|  | 
| 441 | 
             
                    unet_ckpt = torch.load(unet_ckpt_path, map_location='cpu', weights_only=True)
         | 
| 442 | 
             
                    unet.load_state_dict(unet_ckpt, strict=True)
         | 
| 443 | 
             
                    unet = unet.to(torch_dtype)
         | 
| 444 | 
             
                    return unet
         | 
|  | 
| 445 |  | 
| 446 | 
            +
                def init_condition(self):
         | 
| 447 | 
            +
                    self.unet.conv_in = torch.nn.Conv2d(
         | 
| 448 | 
            +
                        12,
         | 
| 449 | 
            +
                        self.unet.conv_in.out_channels,
         | 
| 450 | 
            +
                        kernel_size=self.unet.conv_in.kernel_size,
         | 
| 451 | 
            +
                        stride=self.unet.conv_in.stride,
         | 
| 452 | 
            +
                        padding=self.unet.conv_in.padding,
         | 
| 453 | 
            +
                        dilation=self.unet.conv_in.dilation,
         | 
| 454 | 
            +
                        groups=self.unet.conv_in.groups,
         | 
| 455 | 
            +
                        bias=self.unet.conv_in.bias is not None)
         | 
| 456 |  | 
| 457 | 
            +
                    self.unet.learned_text_clip_gen = nn.Parameter(torch.randn(1, 77, 1024))
         | 
| 458 | 
            +
                    self.unet.learned_text_clip_ref = nn.Parameter(torch.randn(1, 77, 1024))
         | 
| 459 |  | 
| 460 | 
             
                def init_camera_embedding(self):
         | 
|  | |
|  | |
| 461 |  | 
| 462 | 
            +
                    if self.use_camera_embedding:
         | 
| 463 | 
            +
                        time_embed_dim = 1280
         | 
| 464 | 
            +
                        self.max_num_ref_image = 5
         | 
| 465 | 
            +
                        self.max_num_gen_image = 12 * 3 + 4 * 2
         | 
| 466 | 
            +
                        self.unet.class_embedding = nn.Embedding(self.max_num_ref_image + self.max_num_gen_image, time_embed_dim)
         | 
| 467 | 
            +
             | 
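Note: with use_camera_embedding enabled, the table added above holds 5 + (12 * 3 + 4 * 2) = 5 + 44 = 49 camera ids of width 1280 (the UNet time-embedding dim); forward() later offsets camera_info_gen by max_num_ref_image, so reference views use ids 0-4 and generated views ids 5-48. Equivalent construction, for orientation only:

    import torch.nn as nn

    max_num_ref_image = 5
    max_num_gen_image = 12 * 3 + 4 * 2                   # 44
    class_embedding = nn.Embedding(max_num_ref_image + max_num_gen_image, 1280)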
| 468 | 
            +
                def init_attention(self, unet, use_ma=False, use_ra=False, is_turbo=False):
         | 
| 469 |  | 
| 470 | 
             
                    for down_block_i, down_block in enumerate(unet.down_blocks):
         | 
| 471 | 
             
                        if hasattr(down_block, "has_cross_attention") and down_block.has_cross_attention:
         | 
| 472 | 
             
                            for attn_i, attn in enumerate(down_block.attentions):
         | 
| 473 | 
             
                                for transformer_i, transformer in enumerate(attn.transformer_blocks):
         | 
| 474 | 
             
                                    if isinstance(transformer, BasicTransformerBlock):
         | 
| 475 | 
            +
                                        attn.transformer_blocks[transformer_i] = Basic2p5DTransformerBlock(
         | 
| 476 | 
            +
                                            transformer,
         | 
| 477 | 
            +
                                            f'down_{down_block_i}_{attn_i}_{transformer_i}',
         | 
| 478 | 
            +
                                            use_ma, use_ra, is_turbo
         | 
| 479 | 
            +
                                        )
         | 
| 480 |  | 
| 481 | 
             
                    if hasattr(unet.mid_block, "has_cross_attention") and unet.mid_block.has_cross_attention:
         | 
| 482 | 
             
                        for attn_i, attn in enumerate(unet.mid_block.attentions):
         | 
| 483 | 
             
                            for transformer_i, transformer in enumerate(attn.transformer_blocks):
         | 
| 484 | 
             
                                if isinstance(transformer, BasicTransformerBlock):
         | 
| 485 | 
            +
                                    attn.transformer_blocks[transformer_i] = Basic2p5DTransformerBlock(
         | 
| 486 | 
            +
                                        transformer,
         | 
| 487 | 
            +
                                        f'mid_{attn_i}_{transformer_i}',
         | 
| 488 | 
            +
                                        use_ma, use_ra, is_turbo
         | 
| 489 | 
            +
                                    )
         | 
| 490 |  | 
| 491 | 
             
                    for up_block_i, up_block in enumerate(unet.up_blocks):
         | 
| 492 | 
             
                        if hasattr(up_block, "has_cross_attention") and up_block.has_cross_attention:
         | 
| 493 | 
             
                            for attn_i, attn in enumerate(up_block.attentions):
         | 
| 494 | 
             
                                for transformer_i, transformer in enumerate(attn.transformer_blocks):
         | 
| 495 | 
             
                                    if isinstance(transformer, BasicTransformerBlock):
         | 
| 496 | 
            +
                                        attn.transformer_blocks[transformer_i] = Basic2p5DTransformerBlock(
         | 
| 497 | 
            +
                                            transformer,
         | 
| 498 | 
            +
                                            f'up_{up_block_i}_{attn_i}_{transformer_i}',
         | 
| 499 | 
            +
                                            use_ma, use_ra, is_turbo
         | 
| 500 | 
            +
                                        )
         | 
| 501 |  | 
| 502 | 
             
                def __getattr__(self, name: str):
         | 
| 503 | 
             
                    try:
         | 
| 504 | 
             
                        return super().__getattr__(name)
         | 
| 505 | 
             
                    except AttributeError:
         | 
| 506 | 
             
                        return getattr(self.unet, name)
         | 
| 507 | 
            +
             | 
| 508 | 
             
                def forward(
         | 
| 509 | 
            +
                    self, sample, timestep, encoder_hidden_states,
         | 
| 510 | 
            +
                    *args, down_intrablock_additional_residuals=None,
         | 
| 511 | 
             
                    down_block_res_samples=None, mid_block_res_sample=None,
         | 
| 512 | 
             
                    **cached_condition,
         | 
| 513 | 
             
                ):
         | 
| 514 | 
             
                    B, N_gen, _, H, W = sample.shape
         | 
| 515 | 
            +
                    assert H == W
         | 
| 516 | 
            +
             | 
| 517 | 
            +
                    if self.use_camera_embedding:
         | 
| 518 | 
            +
                        camera_info_gen = cached_condition['camera_info_gen'] + self.max_num_ref_image
         | 
| 519 | 
            +
                        camera_info_gen = rearrange(camera_info_gen, 'b n -> (b n)')
         | 
| 520 | 
            +
                    else:
         | 
| 521 | 
            +
                        camera_info_gen = None
         | 
| 522 | 
            +
             | 
| 523 | 
             
                    sample = [sample]
         | 
|  | |
| 524 | 
             
                    if 'normal_imgs' in cached_condition:
         | 
| 525 | 
             
                        sample.append(cached_condition["normal_imgs"])
         | 
| 526 | 
             
                    if 'position_imgs' in cached_condition:
         | 
| 527 | 
             
                        sample.append(cached_condition["position_imgs"])
         | 
|  | |
| 528 | 
             
                    sample = torch.cat(sample, dim=2)
         | 
| 529 | 
            +
             | 
| 530 | 
             
                    sample = rearrange(sample, 'b n c h w -> (b n) c h w')
         | 
| 531 |  | 
| 532 | 
             
                    encoder_hidden_states_gen = encoder_hidden_states.unsqueeze(1).repeat(1, N_gen, 1, 1)
         | 
| 533 | 
             
                    encoder_hidden_states_gen = rearrange(encoder_hidden_states_gen, 'b n l c -> (b n) l c')
         | 
|  | 
| 534 |  | 
| 535 | 
            +
                    if self.use_ra:
         | 
| 536 | 
            +
                        if 'condition_embed_dict' in cached_condition:
         | 
| 537 | 
            +
                            condition_embed_dict = cached_condition['condition_embed_dict']
         | 
|  | |
|  | |
| 538 | 
             
                        else:
         | 
| 539 | 
            +
                            condition_embed_dict = {}
         | 
| 540 | 
            +
                            ref_latents = cached_condition['ref_latents']
         | 
| 541 | 
            +
                            N_ref = ref_latents.shape[1]
         | 
| 542 | 
            +
                            if self.use_camera_embedding:
         | 
| 543 | 
            +
                                camera_info_ref = cached_condition['camera_info_ref']
         | 
| 544 | 
            +
                                camera_info_ref = rearrange(camera_info_ref, 'b n -> (b n)')
         | 
| 545 | 
            +
                            else:
         | 
| 546 | 
            +
                                camera_info_ref = None
         | 
| 547 | 
            +
             | 
| 548 | 
            +
                            ref_latents = rearrange(ref_latents, 'b n c h w -> (b n) c h w')
         | 
| 549 | 
            +
             | 
| 550 | 
            +
                            encoder_hidden_states_ref = self.unet.learned_text_clip_ref.unsqueeze(1).repeat(B, N_ref, 1, 1)
         | 
| 551 | 
            +
                            encoder_hidden_states_ref = rearrange(encoder_hidden_states_ref, 'b n l c -> (b n) l c')
         | 
| 552 | 
            +
             | 
| 553 | 
            +
                            noisy_ref_latents = ref_latents
         | 
| 554 | 
            +
                            timestep_ref = 0
         | 
| 555 | 
            +
             | 
| 556 | 
            +
                            if self.use_dual_stream:
         | 
| 557 | 
            +
                                unet_ref = self.unet_dual
         | 
| 558 | 
            +
                            else:
         | 
| 559 | 
            +
                                unet_ref = self.unet
         | 
| 560 | 
            +
                            unet_ref(
         | 
| 561 | 
            +
                                noisy_ref_latents, timestep_ref,
         | 
| 562 | 
            +
                                encoder_hidden_states=encoder_hidden_states_ref,
         | 
| 563 | 
            +
                                class_labels=camera_info_ref,
         | 
| 564 | 
            +
                                # **kwargs
         | 
| 565 | 
            +
                                return_dict=False,
         | 
| 566 | 
            +
                                cross_attention_kwargs={
         | 
| 567 | 
            +
                                    'mode': 'w', 'num_in_batch': N_ref,
         | 
| 568 | 
            +
                                    'condition_embed_dict': condition_embed_dict},
         | 
| 569 | 
            +
                            )
         | 
| 570 | 
            +
                            cached_condition['condition_embed_dict'] = condition_embed_dict
         | 
| 571 | 
             
                    else:
         | 
| 572 | 
            +
                        condition_embed_dict = None
         | 
|  | 
| 573 |  | 
| 574 | 
            +
                    mva_scale = cached_condition.get('mva_scale', 1.0)
         | 
| 575 | 
            +
                    ref_scale = cached_condition.get('ref_scale', 1.0)
         | 
|  | 
| 576 |  | 
| 577 | 
            +
                    if self.is_turbo:
         | 
| 578 | 
            +
                        cross_attention_kwargs_ = {
         | 
| 579 | 
            +
                            'mode': 'r', 'num_in_batch': N_gen,
         | 
| 580 | 
            +
                            'condition_embed_dict': condition_embed_dict,
         | 
| 581 | 
            +
                            'position_attn_mask':position_attn_mask, 
         | 
| 582 | 
            +
                            'position_voxel_indices':position_voxel_indices,
         | 
| 583 | 
            +
                            'mva_scale': mva_scale,
         | 
| 584 | 
            +
                            'ref_scale': ref_scale,
         | 
| 585 | 
            +
                        }
         | 
| 586 | 
            +
                    else:
         | 
| 587 | 
            +
                        cross_attention_kwargs_ = {
         | 
| 588 | 
            +
                            'mode': 'r', 'num_in_batch': N_gen,
         | 
| 589 | 
            +
                            'condition_embed_dict': condition_embed_dict,
         | 
| 590 | 
            +
                            'mva_scale': mva_scale,
         | 
| 591 | 
            +
                            'ref_scale': ref_scale,
         | 
| 592 | 
            +
                        }
         | 
| 593 | 
             
                    return self.unet(
         | 
| 594 | 
             
                        sample, timestep,
         | 
| 595 | 
             
                        encoder_hidden_states_gen, *args,
         | 
|  | |
| 605 | 
             
                            if mid_block_res_sample is not None else None
         | 
| 606 | 
             
                        ),
         | 
| 607 | 
             
                        return_dict=False,
         | 
| 608 | 
            +
                        cross_attention_kwargs=cross_attention_kwargs_,
         | 
| 609 | 
            +
                    )
         | 
| 610 | 
            +
             | 
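Note: taken together, the rewritten forward() expects the per-view conditions as keyword arguments (collected into cached_condition). A rough calling sketch under assumed shapes — only keys visible in this diff are used, and the 4+4+4 latent channels line up with the new 12-channel conv_in; the model path and ref_latents channel count are placeholders:

    import torch

    # model = UNet2p5DConditionModel.from_pretrained("<local hunyuan3d-paint-v2-0-turbo/unet dir>")
    B, N_gen, N_ref = 1, 6, 1
    out = model(
        torch.randn(B, N_gen, 4, 64, 64),                # sample latents per generated view
        torch.tensor(999),                               # timestep
        torch.randn(B, 77, 1024),                        # encoder_hidden_states
        camera_info_gen=torch.zeros(B, N_gen, dtype=torch.long),
        camera_info_ref=torch.zeros(B, N_ref, dtype=torch.long),
        normal_imgs=torch.randn(B, N_gen, 4, 64, 64),    # concatenated onto sample -> 12 channels
        position_imgs=torch.randn(B, N_gen, 4, 64, 64),
        ref_latents=torch.randn(B, N_ref, 4, 64, 64),    # must match unet_dual.conv_in channels
        mva_scale=1.0, ref_scale=1.0,
    )[0]
    # passing a previously built 'condition_embed_dict' keyword skips the reference
    # U-Net pass; otherwise it is rebuilt from 'ref_latents' on every call.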
