# Open Source Model Licensed under the Apache License Version 2.0
# and Other Licenses of the Third-Party Components therein:
# The below Model in this distribution may have been modified by THL A29 Limited
# ("Tencent Modifications"). All Tencent Modifications are Copyright (C) 2024 THL A29 Limited.

# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
# The below software and/or models in this distribution may have been
# modified by THL A29 Limited ("Tencent Modifications").
# All Tencent Modifications are Copyright (C) THL A29 Limited.

# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
# except for the third-party components listed below.
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
# in the repsective licenses of these third-party components.
# Users must comply with all terms and conditions of original licenses of these third-party
# components and must ensure that the usage of the third party components adheres to
# all relevant laws and regulations.

# For avoidance of doubts, Hunyuan 3D means the large language models and
# their software and algorithms, including trained model weights, parameters (including
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
# fine-tuning enabling code and other elements of the foregoing made publicly available
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.

import os
from typing import Optional, Union, List

import torch
import torch.nn as nn
from einops import rearrange
from torch import Tensor

from .attention_processors import CrossAttentionProcessor
from ...utils.misc import logger

# Attention kernel used by the self-attention blocks below. It defaults to the
# PyTorch implementation and is swapped for `sageattn` when the environment
# variable USE_SAGEATTN is set to "1".
scaled_dot_product_attention = nn.functional.scaled_dot_product_attention
if os.environ.get("USE_SAGEATTN", "0") == "1":
    try:
        from sageattention import sageattn
    except ImportError as err:
        # Chain the original error so the real import failure stays visible.
        raise ImportError(
            'Please install the package "sageattention" to use this USE_SAGEATTN.'
        ) from err
    scaled_dot_product_attention = sageattn


class FourierEmbedder(nn.Module):
    """The sin/cosine positional embedding.

    Given an input tensor `x` of shape [..., input_dim], every feature
    `x[..., i]` is multiplied by each frequency `f_k` and the result is laid
    out along the last dimension as:

        [
            x,                      # only present if include_input is True
            sin(f_k * x[..., i]),   # for every channel i, then every k
            cos(f_k * x[..., i]),   # same ordering as the sin block
        ]

    When `num_freqs` is 0 the module is the identity and returns `x` unchanged.

    Args:
        num_freqs (int): the number of frequencies, default is 6;
        logspace (bool): if True the frequencies are
            [2^0, 2^1, ..., 2^(num_freqs - 1)]; otherwise they are linearly
            spaced between 1.0 and 2^(num_freqs - 1);
        input_dim (int): the input dimension, default is 3;
        include_input (bool): prepend the raw input to the embedding or not,
            default is True;
        include_pi (bool): multiply every frequency by pi, default is True.

    Attributes:
        frequencies (torch.Tensor): non-persistent float32 buffer holding the
            frequencies described above (already scaled by pi when
            `include_pi` is True);
        out_dim (int): the embedding size,
            input_dim * (num_freqs * 2 + 1) if include_input is True or
            num_freqs is 0, otherwise input_dim * num_freqs * 2.
    """

    def __init__(
        self,
        num_freqs: int = 6,
        logspace: bool = True,
        input_dim: int = 3,
        include_input: bool = True,
        include_pi: bool = True,
    ) -> None:
        """Build the frequency buffer and compute the output width."""
        super().__init__()

        if logspace:
            frequencies = 2.0 ** torch.arange(num_freqs, dtype=torch.float32)
        else:
            frequencies = torch.linspace(
                1.0, 2.0 ** (num_freqs - 1), num_freqs, dtype=torch.float32
            )

        if include_pi:
            frequencies *= torch.pi

        # Non-persistent: follows the module across devices but is not saved
        # in the state dict.
        self.register_buffer("frequencies", frequencies, persistent=False)
        self.include_input = include_input
        self.num_freqs = num_freqs

        self.out_dim = self.get_dims(input_dim)

    def get_dims(self, input_dim):
        """Return the embedding width produced for `input_dim` input channels."""
        # With zero frequencies forward() returns the input itself, so the
        # input always counts towards the width in that case.
        temp = 1 if self.include_input or self.num_freqs == 0 else 0
        out_dim = input_dim * (self.num_freqs * 2 + temp)
        return out_dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward process.

        Args:
            x: tensor of shape [..., dim]

        Returns:
            embedding: an embedding of `x` of shape
                [..., dim * (num_freqs * 2 + temp)] where temp is 1 if
                include_input is True (or num_freqs is 0) and 0 otherwise.
        """
        if self.num_freqs > 0:
            # [..., dim, 1] * [num_freqs] -> [..., dim, num_freqs], then the
            # last two axes are flattened into one.
            embed = (x[..., None].contiguous() * self.frequencies).view(
                *x.shape[:-1], -1
            )
            if self.include_input:
                return torch.cat((x, embed.sin(), embed.cos()), dim=-1)
            else:
                return torch.cat((embed.sin(), embed.cos()), dim=-1)
        else:
            return x


class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""

    def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True):
        """Store the drop probability and whether survivors are rescaled."""
        super().__init__()
        self.drop_prob = drop_prob
        self.scale_by_keep = scale_by_keep

    def forward(self, x):
        """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

        This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
        the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
        See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
        changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
        'survival rate' as the argument.

        Returns `x` unchanged when `drop_prob` is 0 or the module is in eval
        mode; otherwise each sample along dim 0 is zeroed with probability
        `drop_prob` (and survivors are divided by the keep probability when
        `scale_by_keep` is True).
        """
        if self.drop_prob == 0.0 or not self.training:
            return x
        keep_prob = 1 - self.drop_prob
        # One Bernoulli draw per sample, broadcast over all remaining dims, so
        # this works with tensors of any rank, not just 2D ConvNets.
        shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
        if keep_prob > 0.0 and self.scale_by_keep:
            random_tensor.div_(keep_prob)
        return x * random_tensor

    def extra_repr(self):
        """Show the drop probability in the module repr."""
        return f"drop_prob={round(self.drop_prob, 3):0.3f}"


class MLP(nn.Module):
    """Two-layer feed-forward block: Linear -> GELU -> Linear -> optional DropPath."""

    def __init__(
        self,
        *,
        width: int,
        expand_ratio: int = 4,
        output_width: Optional[int] = None,
        drop_path_rate: float = 0.0,
    ):
        """
        Args:
            width: input feature width.
            expand_ratio: the hidden width is `width * expand_ratio`.
            output_width: output feature width; defaults to `width` when None.
            drop_path_rate: stochastic-depth rate applied to the output; no
                DropPath layer is created when it is 0.
        """
        super().__init__()
        self.width = width
        self.c_fc = nn.Linear(width, width * expand_ratio)
        self.c_proj = nn.Linear(
            width * expand_ratio, output_width if output_width is not None else width
        )
        self.gelu = nn.GELU()
        self.drop_path = (
            DropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()
        )

    def forward(self, x):
        """Apply the feed-forward block to `x` along its last dimension."""
        return self.drop_path(self.c_proj(self.gelu(self.c_fc(x))))


class QKVMultiheadCrossAttention(nn.Module):
    """Multi-head cross-attention core operating on projected q and packed kv."""

    def __init__(
        self,
        *,
        heads: int,
        width=None,
        qk_norm=False,
        norm_layer=nn.LayerNorm,
    ):
        """
        Args:
            heads: number of attention heads.
            width: model width; only read when `qk_norm` is True, to size the
                per-head normalisation layers (`width // heads`).
            qk_norm: normalise q and k per head before attention.
            norm_layer: normalisation layer factory used when `qk_norm` is True.
        """
        super().__init__()
        self.heads = heads
        self.q_norm = (
            norm_layer(width // heads, elementwise_affine=True, eps=1e-6)
            if qk_norm
            else nn.Identity()
        )
        self.k_norm = (
            norm_layer(width // heads, elementwise_affine=True, eps=1e-6)
            if qk_norm
            else nn.Identity()
        )

        # Pluggable attention implementation; it is invoked as
        # `attn_processor(self, q, k, v)` in forward().
        self.attn_processor = CrossAttentionProcessor()

    def forward(self, q, kv):
        """Attend from `q` to the keys/values packed in `kv`.

        Args:
            q: 3-D tensor [bs, n_ctx, q_width] of projected queries.
            kv: 3-D tensor [bs, n_data, kv_width]; for every head the last
                dimension holds that head's key channels followed by its
                value channels.

        Returns:
            Tensor of shape [bs, n_ctx, -1] with the heads merged back.
        """
        _, n_ctx, _ = q.shape
        bs, n_data, width = kv.shape
        # kv packs both k and v, hence the extra division by 2.
        attn_ch = width // self.heads // 2
        q = q.view(bs, n_ctx, self.heads, -1)
        kv = kv.view(bs, n_data, self.heads, -1)
        k, v = torch.split(kv, attn_ch, dim=-1)

        q = self.q_norm(q)
        k = self.k_norm(k)
        q, k, v = map(
            lambda t: rearrange(t, "b n h d -> b h n d", h=self.heads), (q, k, v)
        )
        # The processor is defined outside this file; its output is presumably
        # laid out as [bs, heads, n_ctx, head_dim] -- TODO confirm.
        out = self.attn_processor(self, q, k, v)
        out = out.transpose(1, 2).reshape(bs, n_ctx, -1)
        return out


class MultiheadCrossAttention(nn.Module):
    """Cross-attention layer: q/kv projections, attention core, output projection."""

    def __init__(
        self,
        *,
        width: int,
        heads: int,
        qkv_bias: bool = True,
        data_width: Optional[int] = None,
        norm_layer=nn.LayerNorm,
        qk_norm: bool = False,
        kv_cache: bool = False,
    ):
        """
        Args:
            width: query/model width.
            heads: number of attention heads.
            qkv_bias: add a bias to the q and kv projections.
            data_width: width of the attended data; defaults to `width`.
            norm_layer: normalisation layer factory for optional qk-norm.
            qk_norm: normalise q and k per head before attention.
            kv_cache: when True, the kv projection of the first `data` seen is
                stored and reused on every later call.
        """
        super().__init__()
        self.width = width
        self.heads = heads
        self.data_width = width if data_width is None else data_width
        self.c_q = nn.Linear(width, width, bias=qkv_bias)
        self.c_kv = nn.Linear(self.data_width, width * 2, bias=qkv_bias)
        self.c_proj = nn.Linear(width, width)
        self.attention = QKVMultiheadCrossAttention(
            heads=heads,
            width=width,
            norm_layer=norm_layer,
            qk_norm=qk_norm,
        )
        self.kv_cache = kv_cache
        # Cached kv projection (only used when kv_cache is True).
        self.data = None

    def forward(self, x, data):
        """Project `x` to queries and `data` to keys/values, then attend."""
        x = self.c_q(x)
        if self.kv_cache:
            # NOTE(review): once populated the cache is never cleared in this
            # class, so later `data` arguments are ignored -- callers must
            # reset `self.data` themselves when the data changes.
            if self.data is None:
                self.data = self.c_kv(data)
                logger.info(
                    "Save kv cache,this should be called only once for one mesh"
                )
            data = self.data
        else:
            data = self.c_kv(data)
        x = self.attention(x, data)
        x = self.c_proj(x)
        return x


class ResidualCrossAttentionBlock(nn.Module):
    """Pre-norm residual block: cross-attention followed by an MLP."""

    def __init__(
        self,
        *,
        width: int,
        heads: int,
        mlp_expand_ratio: int = 4,
        data_width: Optional[int] = None,
        qkv_bias: bool = True,
        norm_layer=nn.LayerNorm,
        qk_norm: bool = False,
    ):
        """
        Args:
            width: query/model width.
            heads: number of attention heads.
            mlp_expand_ratio: hidden expansion factor of the MLP.
            data_width: width of the attended data; defaults to `width`.
            qkv_bias: add a bias to the q and kv projections.
            norm_layer: normalisation layer factory.
            qk_norm: normalise q and k per head before attention.
        """
        super().__init__()

        if data_width is None:
            data_width = width

        self.attn = MultiheadCrossAttention(
            width=width,
            heads=heads,
            data_width=data_width,
            qkv_bias=qkv_bias,
            norm_layer=norm_layer,
            qk_norm=qk_norm,
        )
        self.ln_1 = norm_layer(width, elementwise_affine=True, eps=1e-6)
        self.ln_2 = norm_layer(data_width, elementwise_affine=True, eps=1e-6)
        self.ln_3 = norm_layer(width, elementwise_affine=True, eps=1e-6)
        self.mlp = MLP(width=width, expand_ratio=mlp_expand_ratio)

    def forward(self, x: torch.Tensor, data: torch.Tensor):
        """Update `x` by attending to `data`, then apply the MLP (both residual)."""
        x = x + self.attn(self.ln_1(x), self.ln_2(data))
        x = x + self.mlp(self.ln_3(x))
        return x


class QKVMultiheadAttention(nn.Module):
    """Multi-head self-attention core operating on a packed qkv tensor."""

    def __init__(
        self, *, heads: int, width=None, qk_norm=False, norm_layer=nn.LayerNorm
    ):
        """
        Args:
            heads: number of attention heads.
            width: model width; only read when `qk_norm` is True, to size the
                per-head normalisation layers (`width // heads`).
            qk_norm: normalise q and k per head before attention.
            norm_layer: normalisation layer factory used when `qk_norm` is True.
        """
        super().__init__()
        self.heads = heads
        self.q_norm = (
            norm_layer(width // heads, elementwise_affine=True, eps=1e-6)
            if qk_norm
            else nn.Identity()
        )
        self.k_norm = (
            norm_layer(width // heads, elementwise_affine=True, eps=1e-6)
            if qk_norm
            else nn.Identity()
        )

    def forward(self, qkv):
        """Run self-attention on a packed qkv tensor.

        Args:
            qkv: 3-D tensor [bs, n_ctx, packed_width]; for every head the last
                dimension holds that head's q, k and v channels in that order.

        Returns:
            Tensor of shape [bs, n_ctx, -1] with the heads merged back.
        """
        bs, n_ctx, width = qkv.shape
        # qkv packs q, k and v, hence the extra division by 3.
        attn_ch = width // self.heads // 3
        qkv = qkv.view(bs, n_ctx, self.heads, -1)
        q, k, v = torch.split(qkv, attn_ch, dim=-1)

        q = self.q_norm(q)
        k = self.k_norm(k)

        q, k, v = map(
            lambda t: rearrange(t, "b n h d -> b h n d", h=self.heads), (q, k, v)
        )
        out = (
            scaled_dot_product_attention(q, k, v).transpose(1, 2).reshape(bs, n_ctx, -1)
        )
        return out


class MultiheadAttention(nn.Module):
    """Self-attention layer: qkv projection, attention core, output projection."""

    def __init__(
        self,
        *,
        width: int,
        heads: int,
        qkv_bias: bool,
        norm_layer=nn.LayerNorm,
        qk_norm: bool = False,
        drop_path_rate: float = 0.0,
    ):
        """
        Args:
            width: model width.
            heads: number of attention heads.
            qkv_bias: add a bias to the qkv projection.
            norm_layer: normalisation layer factory for optional qk-norm.
            qk_norm: normalise q and k per head before attention.
            drop_path_rate: stochastic-depth rate applied to the output; no
                DropPath layer is created when it is 0.
        """
        super().__init__()
        self.width = width
        self.heads = heads
        self.c_qkv = nn.Linear(width, width * 3, bias=qkv_bias)
        self.c_proj = nn.Linear(width, width)
        self.attention = QKVMultiheadAttention(
            heads=heads,
            width=width,
            norm_layer=norm_layer,
            qk_norm=qk_norm,
        )
        self.drop_path = (
            DropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()
        )

    def forward(self, x):
        """Apply self-attention to `x`."""
        x = self.c_qkv(x)
        x = self.attention(x)
        x = self.drop_path(self.c_proj(x))
        return x


class ResidualAttentionBlock(nn.Module):
    """Pre-norm residual block: self-attention followed by an MLP."""

    def __init__(
        self,
        *,
        width: int,
        heads: int,
        qkv_bias: bool = True,
        norm_layer=nn.LayerNorm,
        qk_norm: bool = False,
        drop_path_rate: float = 0.0,
    ):
        """
        Args:
            width: model width.
            heads: number of attention heads.
            qkv_bias: add a bias to the qkv projection.
            norm_layer: normalisation layer factory.
            qk_norm: normalise q and k per head before attention.
            drop_path_rate: stochastic-depth rate used by both the attention
                and the MLP branch.
        """
        super().__init__()
        self.attn = MultiheadAttention(
            width=width,
            heads=heads,
            qkv_bias=qkv_bias,
            norm_layer=norm_layer,
            qk_norm=qk_norm,
            drop_path_rate=drop_path_rate,
        )
        self.ln_1 = norm_layer(width, elementwise_affine=True, eps=1e-6)
        self.mlp = MLP(width=width, drop_path_rate=drop_path_rate)
        self.ln_2 = norm_layer(width, elementwise_affine=True, eps=1e-6)

    def forward(self, x: torch.Tensor):
        """Apply the attention and MLP branches, each with a residual connection."""
        x = x + self.attn(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x


class Transformer(nn.Module):
    """A stack of `layers` ResidualAttentionBlock modules applied in sequence."""

    def __init__(
        self,
        *,
        width: int,
        layers: int,
        heads: int,
        qkv_bias: bool = True,
        norm_layer=nn.LayerNorm,
        qk_norm: bool = False,
        drop_path_rate: float = 0.0,
    ):
        """
        Args:
            width: model width.
            layers: number of residual attention blocks.
            heads: number of attention heads per block.
            qkv_bias: add a bias to the qkv projections.
            norm_layer: normalisation layer factory.
            qk_norm: normalise q and k per head before attention.
            drop_path_rate: stochastic-depth rate shared by every block.
        """
        super().__init__()
        self.width = width
        self.layers = layers
        self.resblocks = nn.ModuleList([
            ResidualAttentionBlock(
                width=width,
                heads=heads,
                qkv_bias=qkv_bias,
                norm_layer=norm_layer,
                qk_norm=qk_norm,
                drop_path_rate=drop_path_rate,
            )
            for _ in range(layers)
        ])

    def forward(self, x: torch.Tensor):
        """Pass `x` through every block in order."""
        for block in self.resblocks:
            x = block(x)
        return x


class CrossAttentionDecoder(nn.Module):
    """Decode per-query outputs by cross-attending query points to latents."""

    def __init__(
        self,
        *,
        out_channels: int,
        fourier_embedder: FourierEmbedder,
        width: int,
        heads: int,
        mlp_expand_ratio: int = 4,
        downsample_ratio: int = 1,
        enable_ln_post: bool = True,
        qkv_bias: bool = True,
        qk_norm: bool = False,
        label_type: str = "binary",
    ):
        """
        Args:
            out_channels: width of the final output projection.
            fourier_embedder: positional embedder applied to raw queries.
            width: model width.
            heads: number of attention heads.
            mlp_expand_ratio: hidden expansion factor of the block's MLP.
            downsample_ratio: when not 1, latents of width
                `width * downsample_ratio` are first projected to `width`.
            enable_ln_post: apply a LayerNorm before the output projection;
                when False, `qk_norm` is forced to False as well.
            qkv_bias: add a bias to the q and kv projections.
            qk_norm: normalise q and k per head before attention.
            label_type: stored on the instance; not used inside this class.
        """
        super().__init__()

        self.enable_ln_post = enable_ln_post
        self.fourier_embedder = fourier_embedder
        self.downsample_ratio = downsample_ratio
        self.query_proj = nn.Linear(self.fourier_embedder.out_dim, width)
        if self.downsample_ratio != 1:
            self.latents_proj = nn.Linear(width * downsample_ratio, width)
        if not self.enable_ln_post:
            qk_norm = False
        self.cross_attn_decoder = ResidualCrossAttentionBlock(
            width=width,
            mlp_expand_ratio=mlp_expand_ratio,
            heads=heads,
            qkv_bias=qkv_bias,
            qk_norm=qk_norm,
        )

        if self.enable_ln_post:
            self.ln_post = nn.LayerNorm(width)
        self.output_proj = nn.Linear(width, out_channels)
        self.label_type = label_type
        # Running total of query positions processed by forward().
        self.count = 0

    def set_cross_attention_processor(self, processor):
        """Replace the attention processor used by the cross-attention core."""
        self.cross_attn_decoder.attn.attention.attn_processor = processor

    # def set_default_cross_attention_processor(self):
    #     self.cross_attn_decoder.attn.attention.attn_processor = CrossAttentionProcessor

    def forward(self, queries=None, query_embeddings=None, latents=None):
        """Decode outputs for the given queries.

        Args:
            queries: raw query coordinates; embedded and projected here when
                `query_embeddings` is not supplied.
            query_embeddings: pre-computed query embeddings; when given,
                `queries` is ignored.
            latents: latent tokens to attend to (required).

        Returns:
            The output of `output_proj` for every query.

        Raises:
            ValueError: if `latents` is None, or if both `queries` and
                `query_embeddings` are None.
        """
        if latents is None:
            raise ValueError("latents must be provided")
        if query_embeddings is None:
            if queries is None:
                raise ValueError(
                    "either queries or query_embeddings must be provided"
                )
            query_embeddings = self.query_proj(
                self.fourier_embedder(queries).to(latents.dtype)
            )
        self.count += query_embeddings.shape[1]
        if self.downsample_ratio != 1:
            latents = self.latents_proj(latents)
        x = self.cross_attn_decoder(query_embeddings, latents)
        if self.enable_ln_post:
            x = self.ln_post(x)
        occ = self.output_proj(x)
        return occ


def fps(
    src: torch.Tensor,
    batch: Optional[Tensor] = None,
    ratio: Optional[Union[Tensor, float]] = None,
    random_start: bool = True,
    batch_size: Optional[int] = None,
    ptr: Optional[Union[Tensor, List[int]]] = None,
):
    """Thin wrapper around `torch_cluster.fps` that casts `src` to float32 first.

    `torch_cluster` is imported lazily so the module can be loaded without it.
    All arguments are forwarded positionally; the callers in this file use the
    result to index into `src` (i.e. it is presumably a tensor of selected
    point indices -- see the torch_cluster documentation).
    """
    src = src.float()
    from torch_cluster import fps as fps_fn

    output = fps_fn(src, batch, ratio, random_start, batch_size, ptr)
    return output


class PointCrossAttentionEncoder(nn.Module):
    """Encode a point cloud into latent tokens.

    Query points are chosen by farthest point sampling from a random subset of
    the input points; the queries cross-attend to that subset and the result
    is optionally refined by a self-attention Transformer.
    """

    def __init__(
        self,
        *,
        num_latents: int,
        downsample_ratio: float,
        pc_size: int,
        pc_sharpedge_size: int,
        fourier_embedder: FourierEmbedder,
        point_feats: int,
        width: int,
        heads: int,
        layers: int,
        normal_pe: bool = False,
        qkv_bias: bool = True,
        use_ln_post: bool = False,
        use_checkpoint: bool = False,
        qk_norm: bool = False,
    ):
        """
        Args:
            num_latents: number of latent tokens (query points) to produce.
            downsample_ratio: input points per query point.
            pc_size: number of leading points along dim 1 of the input that
                are treated as "random" surface points.
            pc_sharpedge_size: number of points following them that are
                treated as "sharp edge" points (may be 0).
            fourier_embedder: positional embedder for point coordinates.
            point_feats: number of extra per-point feature channels (0 for none).
            width: model width.
            heads: number of attention heads.
            layers: number of self-attention blocks (0 disables them).
            normal_pe: also Fourier-embed the first three feature channels.
            qkv_bias: add a bias to the attention projections.
            use_ln_post: apply a final LayerNorm to the latents.
            use_checkpoint: stored on the instance; not used inside this class.
            qk_norm: normalise q and k per head before attention.
        """
        super().__init__()

        self.use_checkpoint = use_checkpoint
        self.num_latents = num_latents
        self.downsample_ratio = downsample_ratio
        self.point_feats = point_feats
        self.normal_pe = normal_pe

        # NOTE(review): the first message says pc_size is used as
        # pc_sharpedge_size, but the value stored below stays 0 -- confirm the
        # intended wording.
        if pc_sharpedge_size == 0:
            print(
                f"PointCrossAttentionEncoder INFO: pc_sharpedge_size is not given,"
                f" using pc_size as pc_sharpedge_size"
            )
        else:
            print(
                "PointCrossAttentionEncoder INFO: pc_sharpedge_size is given, using"
                f" pc_size={pc_size}, pc_sharpedge_size={pc_sharpedge_size}"
            )

        self.pc_size = pc_size
        self.pc_sharpedge_size = pc_sharpedge_size

        self.fourier_embedder = fourier_embedder

        # NOTE(review): the input width assumes the raw `point_feats` channels;
        # with normal_pe=True the concatenated features are wider -- verify
        # that callers configure `point_feats` accordingly.
        self.input_proj = nn.Linear(self.fourier_embedder.out_dim + point_feats, width)
        self.cross_attn = ResidualCrossAttentionBlock(
            width=width, heads=heads, qkv_bias=qkv_bias, qk_norm=qk_norm
        )

        self.self_attn = None
        if layers > 0:
            self.self_attn = Transformer(
                width=width,
                layers=layers,
                heads=heads,
                qkv_bias=qkv_bias,
                qk_norm=qk_norm,
            )

        if use_ln_post:
            self.ln_post = nn.LayerNorm(width)
        else:
            self.ln_post = None

    def sample_points_and_latents(
        self, pc: torch.FloatTensor, feats: Optional[torch.FloatTensor] = None
    ):
        """Sample input points and query points and embed them.

        Args:
            pc: tensor [B, N, D] whose dim 1 holds `pc_size` random surface
                points followed by `pc_sharpedge_size` sharp-edge points.
            feats: per-point features split the same way along dim 1; required
                when `point_feats` is not 0.

        Returns:
            A tuple `(query, data, pc_infos)` where `query` and `data` are the
            embedded query/input points (with features appended when
            `point_feats` is not 0) and `pc_infos` is the list
            `[query_pc, input_pc, query_random_pc, input_random_pc,
            query_sharpedge_pc, input_sharpedge_pc]`.

        Raises:
            ValueError: if `point_feats` is not 0 and `feats` is None.
        """
        B, _, D = pc.shape
        device = pc.device
        num_pts = self.num_latents * self.downsample_ratio

        # Compute number of latents
        num_latents = int(num_pts / self.downsample_ratio)

        # Compute the number of random and sharpedge latents (kept as floats;
        # they are only turned into integer sizes below).
        num_random_query = (
            self.pc_size / (self.pc_size + self.pc_sharpedge_size) * num_latents
        )
        num_sharpedge_query = num_latents - num_random_query

        # Split random and sharpedge surface points
        random_pc, sharpedge_pc = torch.split(
            pc, [self.pc_size, self.pc_sharpedge_size], dim=1
        )
        assert (
            random_pc.shape[1] <= self.pc_size
        ), "Random surface points size must be less than or equal to pc_size"
        assert sharpedge_pc.shape[1] <= self.pc_sharpedge_size, (
            "Sharpedge surface points size must be less than or equal to"
            " pc_sharpedge_size"
        )

        # Randomly select random surface points and random query points
        input_random_pc_size = int(num_random_query * self.downsample_ratio)
        random_query_ratio = num_random_query / input_random_pc_size
        idx_random_pc = torch.randperm(random_pc.shape[1], device=random_pc.device)[
            :input_random_pc_size
        ]
        input_random_pc = random_pc[:, idx_random_pc, :]
        # fps works on a flat [B * n, D] point list plus a batch-index vector.
        flatten_input_random_pc = input_random_pc.view(B * input_random_pc_size, D)
        N_down = int(flatten_input_random_pc.shape[0] / B)
        batch_down = torch.arange(B, device=device)
        batch_down = torch.repeat_interleave(batch_down, N_down)
        idx_query_random = fps(
            flatten_input_random_pc, batch_down, ratio=random_query_ratio
        )
        query_random_pc = flatten_input_random_pc[idx_query_random].view(B, -1, D)

        # Randomly select sharpedge surface points and sharpedge query points
        input_sharpedge_pc_size = int(num_sharpedge_query * self.downsample_ratio)
        if input_sharpedge_pc_size == 0:
            # No sharp-edge points: use empty placeholders so the
            # concatenations below are no-ops.
            input_sharpedge_pc = torch.zeros(
                B, 0, D, dtype=input_random_pc.dtype, device=device
            )
            query_sharpedge_pc = torch.zeros(
                B, 0, D, dtype=query_random_pc.dtype, device=device
            )
        else:
            sharpedge_query_ratio = num_sharpedge_query / input_sharpedge_pc_size
            idx_sharpedge_pc = torch.randperm(
                sharpedge_pc.shape[1], device=sharpedge_pc.device
            )[:input_sharpedge_pc_size]
            input_sharpedge_pc = sharpedge_pc[:, idx_sharpedge_pc, :]
            flatten_input_sharpedge_surface_points = input_sharpedge_pc.view(
                B * input_sharpedge_pc_size, D
            )
            N_down = int(flatten_input_sharpedge_surface_points.shape[0] / B)
            batch_down = torch.arange(B, device=device)
            batch_down = torch.repeat_interleave(batch_down, N_down)
            idx_query_sharpedge = fps(
                flatten_input_sharpedge_surface_points,
                batch_down,
                ratio=sharpedge_query_ratio,
            )
            query_sharpedge_pc = flatten_input_sharpedge_surface_points[
                idx_query_sharpedge
            ].view(B, -1, D)

        # Concatenate random and sharpedge surface points and query points
        query_pc = torch.cat([query_random_pc, query_sharpedge_pc], dim=1)
        input_pc = torch.cat([input_random_pc, input_sharpedge_pc], dim=1)

        # PE
        query = self.fourier_embedder(query_pc)
        data = self.fourier_embedder(input_pc)

        # Concat normal if given
        if self.point_feats != 0:
            if feats is None:
                raise ValueError(
                    "feats must be provided when point_feats is not 0"
                )
            random_surface_feats, sharpedge_surface_feats = torch.split(
                feats, [self.pc_size, self.pc_sharpedge_size], dim=1
            )
            # Reuse the indices chosen for the coordinates so features stay
            # aligned with their points.
            input_random_surface_feats = random_surface_feats[:, idx_random_pc, :]
            flatten_input_random_surface_feats = input_random_surface_feats.view(
                B * input_random_pc_size, -1
            )
            query_random_feats = flatten_input_random_surface_feats[
                idx_query_random
            ].view(B, -1, flatten_input_random_surface_feats.shape[-1])

            if input_sharpedge_pc_size == 0:
                input_sharpedge_surface_feats = torch.zeros(
                    B,
                    0,
                    self.point_feats,
                    dtype=input_random_surface_feats.dtype,
                    device=device,
                )
                query_sharpedge_feats = torch.zeros(
                    B,
                    0,
                    self.point_feats,
                    dtype=query_random_feats.dtype,
                    device=device,
                )
            else:
                input_sharpedge_surface_feats = sharpedge_surface_feats[
                    :, idx_sharpedge_pc, :
                ]
                flatten_input_sharpedge_surface_feats = (
                    input_sharpedge_surface_feats.view(B * input_sharpedge_pc_size, -1)
                )
                query_sharpedge_feats = flatten_input_sharpedge_surface_feats[
                    idx_query_sharpedge
                ].view(B, -1, flatten_input_sharpedge_surface_feats.shape[-1])

            query_feats = torch.cat([query_random_feats, query_sharpedge_feats], dim=1)
            input_feats = torch.cat(
                [input_random_surface_feats, input_sharpedge_surface_feats], dim=1
            )

            if self.normal_pe:
                # Fourier-embed the first three feature channels and keep the
                # remaining channels as they are.
                query_normal_pe = self.fourier_embedder(query_feats[..., :3])
                input_normal_pe = self.fourier_embedder(input_feats[..., :3])
                query_feats = torch.cat([query_normal_pe, query_feats[..., 3:]], dim=-1)
                input_feats = torch.cat([input_normal_pe, input_feats[..., 3:]], dim=-1)

            query = torch.cat([query, query_feats], dim=-1)
            data = torch.cat([data, input_feats], dim=-1)

        if input_sharpedge_pc_size == 0:
            # The returned sharp-edge entries are single all-zero points
            # instead of empty tensors.
            query_sharpedge_pc = torch.zeros(B, 1, D, device=device)
            input_sharpedge_pc = torch.zeros(B, 1, D, device=device)

        # print(f'query_pc: {query_pc.shape}')
        # print(f'input_pc: {input_pc.shape}')
        # print(f'query_random_pc: {query_random_pc.shape}')
        # print(f'input_random_pc: {input_random_pc.shape}')
        # print(f'query_sharpedge_pc: {query_sharpedge_pc.shape}')
        # print(f'input_sharpedge_pc: {input_sharpedge_pc.shape}')

        return (
            query.view(B, -1, query.shape[-1]),
            data.view(B, -1, data.shape[-1]),
            [
                query_pc,
                input_pc,
                query_random_pc,
                input_random_pc,
                query_sharpedge_pc,
                input_sharpedge_pc,
            ],
        )

    def forward(self, pc, feats):
        """Encode a point cloud into latents.

        Args:
            pc (torch.FloatTensor): [B, N, 3]
            feats (torch.FloatTensor or None): [B, N, C]

        Returns:
            A tuple `(latents, pc_infos)`: the latent tokens produced by the
            cross-attention (and optional self-attention / final LayerNorm),
            and the list of sampled point sets returned by
            `sample_points_and_latents`.
        """
        query, data, pc_infos = self.sample_points_and_latents(pc, feats)

        # Queries and data share the same input projection.
        query = self.input_proj(query)
        data = self.input_proj(data)

        latents = self.cross_attn(query, data)
        if self.self_attn is not None:
            latents = self.self_attn(latents)

        if self.ln_post is not None:
            latents = self.ln_post(latents)

        return latents, pc_infos