from dataclasses import dataclass
import math
import torch
import numpy as np
import random
import time
import trimesh
import torch.nn as nn
import torch.nn.functional as F  # used by Siren.forward
from einops import repeat, rearrange
from tqdm import trange
from itertools import product
from diffusers.models.modeling_utils import ModelMixin

import step1x3d_geometry
from step1x3d_geometry.utils.checkpoint import checkpoint
from step1x3d_geometry.utils.base import BaseModule
from step1x3d_geometry.utils.typing import *
from step1x3d_geometry.utils.misc import get_world_size, get_device

from .transformers.perceiver_1d import Perceiver
from .transformers.attention import ResidualCrossAttentionBlock
from .volume_decoders import HierarchicalVolumeDecoder, VanillaVolumeDecoder
from .surface_extractors import MCSurfaceExtractor, DMCSurfaceExtractor
from ..pipelines.pipeline_utils import smart_load_model
from safetensors.torch import load_file

VALID_EMBED_TYPES = ["identity", "fourier", "learned_fourier", "siren"]


class FourierEmbedder(nn.Module):
    def __init__(
        self,
        num_freqs: int = 6,
        logspace: bool = True,
        input_dim: int = 3,
        include_input: bool = True,
        include_pi: bool = True,
    ) -> None:
        super().__init__()

        if logspace:
            frequencies = 2.0 ** torch.arange(num_freqs, dtype=torch.float32)
        else:
            frequencies = torch.linspace(
                1.0, 2.0 ** (num_freqs - 1), num_freqs, dtype=torch.float32
            )

        if include_pi:
            frequencies *= torch.pi

        self.register_buffer("frequencies", frequencies, persistent=False)
        self.include_input = include_input
        self.num_freqs = num_freqs
        self.out_dim = self.get_dims(input_dim)

    def get_dims(self, input_dim):
        temp = 1 if self.include_input or self.num_freqs == 0 else 0
        out_dim = input_dim * (self.num_freqs * 2 + temp)
        return out_dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.num_freqs > 0:
            embed = (x[..., None].contiguous() * self.frequencies).view(
                *x.shape[:-1], -1
            )
            if self.include_input:
                return torch.cat((x, embed.sin(), embed.cos()), dim=-1)
            else:
                return torch.cat((embed.sin(), embed.cos()), dim=-1)
        else:
            return x
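

# A minimal usage sketch (illustrative, not part of the module): with the
# defaults above, out_dim = input_dim * (2 * num_freqs + 1), e.g.
# 3 * (2 * 6 + 1) = 39.
#
#   embedder = FourierEmbedder(num_freqs=6, input_dim=3)
#   pts = torch.rand(2, 1024, 3)   # [B, N, 3] surface points
#   emb = embedder(pts)            # [B, N, 39]: (x, sin(f * x), cos(f * x))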


class LearnedFourierEmbedder(nn.Module):
    def __init__(self, input_dim, dim):
        super().__init__()
        assert (dim % 2) == 0
        half_dim = dim // 2
        per_channel_dim = half_dim // input_dim
        self.weights = nn.Parameter(torch.randn(per_channel_dim))

        self.out_dim = self.get_dims(input_dim)

    def forward(self, x):
        # [b, t, c, 1] * [1, d] = [b, t, c, d] -> [b, t, c * d]
        freqs = (x[..., None] * self.weights[None] * 2 * np.pi).view(*x.shape[:-1], -1)
        fouriered = torch.cat((x, freqs.sin(), freqs.cos()), dim=-1)
        return fouriered

    def get_dims(self, input_dim):
        return input_dim * (self.weights.shape[0] * 2 + 1)
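

# Shape note (illustrative): for input_dim=3 and dim=48, half_dim=24 gives
# per_channel_dim=8 learned frequencies per channel, so
# out_dim = 3 * (8 * 2 + 1) = 51 -- larger than `dim`, because the raw input
# is concatenated alongside the learned sin/cos features.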


class Sine(nn.Module):
    def __init__(self, w0=1.0):
        super().__init__()
        self.w0 = w0

    def forward(self, x):
        return torch.sin(self.w0 * x)


class Siren(nn.Module):
    def __init__(
        self,
        in_dim,
        out_dim,
        w0=1.0,
        c=6.0,
        is_first=False,
        use_bias=True,
        activation=None,
        dropout=0.0,
    ):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.is_first = is_first

        weight = torch.zeros(out_dim, in_dim)
        bias = torch.zeros(out_dim) if use_bias else None
        self.init_(weight, bias, c=c, w0=w0)

        self.weight = nn.Parameter(weight)
        self.bias = nn.Parameter(bias) if use_bias else None
        self.activation = Sine(w0) if activation is None else activation
        self.dropout = nn.Dropout(dropout)

    def init_(self, weight, bias, c, w0):
        dim = self.in_dim
        w_std = (1 / dim) if self.is_first else (math.sqrt(c / dim) / w0)
        weight.uniform_(-w_std, w_std)
        if bias is not None:
            bias.uniform_(-w_std, w_std)

    def forward(self, x):
        out = F.linear(x, self.weight, self.bias)
        out = self.activation(out)
        out = self.dropout(out)
        return out
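

# Initialization note: following the SIREN scheme (Sitzmann et al., 2020),
# the first layer is drawn from U(-1/in_dim, 1/in_dim) and later layers from
# U(-sqrt(c/in_dim)/w0, sqrt(c/in_dim)/w0), which keeps pre-activations in a
# range where sin(w0 * x) stays well-conditioned.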


def get_embedder(embed_type="fourier", num_freqs=-1, input_dim=3, include_pi=True):
    if embed_type == "identity" or (embed_type == "fourier" and num_freqs == -1):
        embedder_obj = nn.Identity()
        # Expose out_dim so the identity path matches the interface of the
        # other embedders (callers rely on `embedder.out_dim`).
        embedder_obj.out_dim = input_dim
    elif embed_type == "fourier":
        embedder_obj = FourierEmbedder(num_freqs=num_freqs, include_pi=include_pi)
    elif embed_type == "learned_fourier":
        embedder_obj = LearnedFourierEmbedder(input_dim=input_dim, dim=num_freqs)
    elif embed_type == "siren":
        embedder_obj = Siren(
            in_dim=input_dim, out_dim=num_freqs * input_dim * 2 + input_dim
        )
    else:
        raise ValueError(
            f"{embed_type} is not valid. Currently only supports {VALID_EMBED_TYPES}"
        )
    return embedder_obj
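

# A minimal usage sketch (illustrative): the "fourier" path is the one this
# module uses by default.
#
#   embedder = get_embedder(embed_type="fourier", num_freqs=8, input_dim=3)
#   embedder.out_dim  # 3 * (2 * 8 + 1) = 51, consumed by the Linear projections below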


###################### AutoEncoder
class DiagonalGaussianDistribution(ModelMixin, object):
    def __init__(
        self,
        parameters: Union[torch.Tensor, List[torch.Tensor]],
        deterministic=False,
        feat_dim=1,
    ):
        self.feat_dim = feat_dim
        self.parameters = parameters

        if isinstance(parameters, list):
            self.mean = parameters[0]
            self.logvar = parameters[1]
        else:
            self.mean, self.logvar = torch.chunk(parameters, 2, dim=feat_dim)

        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
        self.deterministic = deterministic
        self.std = torch.exp(0.5 * self.logvar)
        self.var = torch.exp(self.logvar)
        if self.deterministic:
            self.var = self.std = torch.zeros_like(self.mean)

    def sample(self):
        x = self.mean + self.std * torch.randn_like(self.mean)
        return x

    def kl(self, other=None, dims=(1, 2)):
        if self.deterministic:
            return torch.Tensor([0.0])
        else:
            if other is None:
                return 0.5 * torch.mean(
                    torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, dim=dims
                )
            else:
                return 0.5 * torch.mean(
                    torch.pow(self.mean - other.mean, 2) / other.var
                    + self.var / other.var
                    - 1.0
                    - self.logvar
                    + other.logvar,
                    dim=dims,
                )

    def nll(self, sample, dims=(1, 2)):
        if self.deterministic:
            return torch.Tensor([0.0])
        logtwopi = np.log(2.0 * np.pi)
        return 0.5 * torch.sum(
            logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
            dim=dims,
        )

    def mode(self):
        return self.mean
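

# A minimal usage sketch (illustrative): `parameters` packs mean and
# log-variance along `feat_dim`; sample() draws via the reparameterization
# trick mean + std * eps with eps ~ N(0, I).
#
#   moments = torch.randn(2, 256, 128)  # [B, num_latents, 2 * embed_dim]
#   posterior = DiagonalGaussianDistribution(moments, feat_dim=-1)
#   z = posterior.sample()              # [B, 256, 64]
#   kl = posterior.kl()                 # [B], mean over dims (1, 2)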


class PerceiverCrossAttentionEncoder(ModelMixin, nn.Module):
    def __init__(
        self,
        use_downsample: bool,
        num_latents: int,
        embedder: FourierEmbedder,
        point_feats: int,
        embed_point_feats: bool,
        width: int,
        heads: int,
        layers: int,
        init_scale: float = 0.25,
        qkv_bias: bool = True,
        qk_norm: bool = True,
        use_ln_post: bool = False,
        use_flash: bool = False,
        use_checkpoint: bool = False,
        use_multi_reso: bool = False,
        resolutions: list = [],
        sampling_prob: list = [],
        with_sharp_data: bool = False,
    ):
        super().__init__()

        self.use_checkpoint = use_checkpoint
        self.num_latents = num_latents
        self.use_downsample = use_downsample
        self.embed_point_feats = embed_point_feats
        self.use_multi_reso = use_multi_reso
        self.resolutions = resolutions
        self.sampling_prob = sampling_prob

        if not self.use_downsample:
            self.query = nn.Parameter(torch.randn((num_latents, width)) * 0.02)

        self.embedder = embedder
        if self.embed_point_feats:
            self.input_proj = nn.Linear(self.embedder.out_dim * 2, width)
        else:
            self.input_proj = nn.Linear(self.embedder.out_dim + point_feats, width)

        self.cross_attn = ResidualCrossAttentionBlock(
            width=width,
            heads=heads,
            init_scale=init_scale,
            qkv_bias=qkv_bias,
            qk_norm=qk_norm,
            use_flash=use_flash,
        )

        self.with_sharp_data = with_sharp_data
        if with_sharp_data:
            # Half the latents come from the uniform surface samples, half
            # from the sharp-edge samples.
            self.downsample_num_latents = num_latents // 2
            self.input_proj_sharp = nn.Linear(
                self.embedder.out_dim + point_feats, width
            )
            self.cross_attn_sharp = ResidualCrossAttentionBlock(
                width=width,
                heads=heads,
                init_scale=init_scale,
                qkv_bias=qkv_bias,
                qk_norm=qk_norm,
                use_flash=use_flash,
            )
        else:
            self.downsample_num_latents = num_latents

        self.self_attn = Perceiver(
            n_ctx=num_latents,
            width=width,
            layers=layers,
            heads=heads,
            init_scale=init_scale,
            qkv_bias=qkv_bias,
            qk_norm=qk_norm,
            use_flash=use_flash,
            use_checkpoint=use_checkpoint,
        )

        if use_ln_post:
            self.ln_post = nn.LayerNorm(width)
        else:
            self.ln_post = None

    def _forward(self, pc, feats, sharp_pc=None, sharp_feat=None):
        """
        Args:
            pc (torch.FloatTensor): [B, N, 3]
            feats (torch.FloatTensor or None): [B, N, C]

        Returns:
            latents (torch.FloatTensor): [B, num_latents, width]
        """
        bs, N, D = pc.shape
        data = self.embedder(pc)
        if feats is not None:
            if self.embed_point_feats:
                feats = self.embedder(feats)
            data = torch.cat([data, feats], dim=-1)
        data = self.input_proj(data)

        if self.with_sharp_data:
            sharp_data = self.embedder(sharp_pc)
            if sharp_feat is not None:
                if self.embed_point_feats:
                    sharp_feat = self.embedder(sharp_feat)
                sharp_data = torch.cat([sharp_data, sharp_feat], dim=-1)
            sharp_data = self.input_proj_sharp(sharp_data)

        if self.use_multi_reso:
            resolution = np.random.choice(self.resolutions, p=self.sampling_prob)

            if resolution != N:
                from torch_cluster import fps

                flattened = pc.view(bs * N, D)  # [bs * N, 3]
                batch = torch.arange(bs).to(pc.device)
                batch = torch.repeat_interleave(batch, N)  # [bs * N]
                pos = flattened.to(torch.float16)
                ratio = 1.0 * resolution / N
                idx = fps(pos, batch, ratio=ratio)
                pc = pc.view(bs * N, -1)[idx].view(bs, -1, D)

                bs, N, D = feats.shape
                flattened1 = feats.view(bs * N, D)
                feats = flattened1.view(bs * N, -1)[idx].view(bs, -1, D)
                bs, N, D = pc.shape

        if self.use_downsample:
            ###### fps
            from torch_cluster import fps

            flattened = pc.view(bs * N, D)  # [bs * N, 3]
            batch = torch.arange(bs).to(pc.device)
            batch = torch.repeat_interleave(batch, N)  # [bs * N]
            pos = flattened.to(torch.float16)
            ratio = 1.0 * self.downsample_num_latents / N
            idx = fps(pos, batch, ratio=ratio).detach()
            query = data.view(bs * N, -1)[idx].view(bs, -1, data.shape[-1])

            if self.with_sharp_data:
                bs, N, D = sharp_pc.shape
                flattened = sharp_pc.view(bs * N, D)  # [bs * N, 3]
                pos = flattened.to(torch.float16)
                ratio = 1.0 * self.downsample_num_latents / N
                idx = fps(pos, batch, ratio=ratio).detach()
                sharp_query = sharp_data.view(bs * N, -1)[idx].view(
                    bs, -1, sharp_data.shape[-1]
                )
                query = torch.cat([query, sharp_query], dim=1)
        else:
            query = self.query
            query = repeat(query, "m c -> b m c", b=bs)

        latents = self.cross_attn(query, data)
        if self.with_sharp_data:
            latents = latents + self.cross_attn_sharp(query, sharp_data)
        latents = self.self_attn(latents)

        if self.ln_post is not None:
            latents = self.ln_post(latents)

        return latents

    def forward(
        self,
        pc: torch.FloatTensor,
        feats: Optional[torch.FloatTensor] = None,
        sharp_pc: Optional[torch.FloatTensor] = None,
        sharp_feats: Optional[torch.FloatTensor] = None,
    ):
        """
        Args:
            pc (torch.FloatTensor): [B, N, 3]
            feats (torch.FloatTensor or None): [B, N, C]

        Returns:
            latents (torch.FloatTensor): [B, num_latents, width]
        """
        return checkpoint(
            self._forward,
            (pc, feats, sharp_pc, sharp_feats),
            self.parameters(),
            self.use_checkpoint,
        )
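

# A minimal shape sketch (illustrative; argument values are assumptions, not
# the shipped configuration). With use_downsample=False the learned query set
# cross-attends over the embedded point cloud:
#
#   encoder = PerceiverCrossAttentionEncoder(
#       use_downsample=False, num_latents=256, embedder=embedder,
#       point_feats=0, embed_point_feats=False, width=768, heads=12, layers=8,
#   )
#   latents = encoder(torch.rand(2, 4096, 3))   # [B, 256, 768]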


class PerceiverCrossAttentionDecoder(ModelMixin, nn.Module):
    def __init__(
        self,
        num_latents: int,
        out_dim: int,
        embedder: FourierEmbedder,
        width: int,
        heads: int,
        init_scale: float = 0.25,
        qkv_bias: bool = True,
        qk_norm: bool = True,
        use_flash: bool = False,
        use_checkpoint: bool = False,
    ):
        super().__init__()

        self.use_checkpoint = use_checkpoint
        self.embedder = embedder

        self.query_proj = nn.Linear(self.embedder.out_dim, width)

        self.cross_attn_decoder = ResidualCrossAttentionBlock(
            n_data=num_latents,
            width=width,
            heads=heads,
            init_scale=init_scale,
            qkv_bias=qkv_bias,
            qk_norm=qk_norm,
            use_flash=use_flash,
        )

        self.ln_post = nn.LayerNorm(width)
        self.output_proj = nn.Linear(width, out_dim)

    def _forward(self, queries: torch.FloatTensor, latents: torch.FloatTensor):
        queries = self.query_proj(self.embedder(queries))
        x = self.cross_attn_decoder(queries, latents)
        x = self.ln_post(x)
        x = self.output_proj(x)
        return x

    def forward(self, queries: torch.FloatTensor, latents: torch.FloatTensor):
        return checkpoint(
            self._forward, (queries, latents), self.parameters(), self.use_checkpoint
        )
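

# A minimal shape sketch (illustrative): the decoder cross-attends Fourier-
# embedded 3D query points against the latent set and projects each query to
# `out_dim` channels (occupancy/SDF logits when out_dim=1):
#
#   queries = torch.rand(2, 8192, 3)    # [B, N, 3] grid points
#   logits = decoder(queries, latents)  # [B, N, out_dim]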


class MichelangeloAutoencoder(BaseModule):
    r"""
    A VAE model for encoding shapes into latents and decoding latent
    representations into shapes.
    """

    # NOTE: the @dataclass decorator is assumed here; the `dataclass` import
    # at the top of the file and the typed-defaults Config pattern imply it.
    @dataclass
    class Config(BaseModule.Config):
        pretrained_model_name_or_path: str = ""
        subfolder: str = ""
        n_samples: int = 4096
        use_downsample: bool = False
        downsample_ratio: float = 0.0625
        num_latents: int = 256
        point_feats: int = 0
        embed_point_feats: bool = False
        out_dim: int = 1
        embed_dim: int = 64
        embed_type: str = "fourier"
        num_freqs: int = 8
        include_pi: bool = True
        width: int = 768
        heads: int = 12
        num_encoder_layers: int = 8
        num_decoder_layers: int = 16
        init_scale: float = 0.25
        qkv_bias: bool = True
        qk_norm: bool = False
        use_ln_post: bool = False
        use_flash: bool = False
        use_checkpoint: bool = True
        use_multi_reso: Optional[bool] = False
        resolutions: Optional[List[int]] = None
        sampling_prob: Optional[List[float]] = None
        with_sharp_data: Optional[bool] = True
        volume_decoder_type: str = "hierarchical"
        surface_extractor_type: str = "mc"
        z_scale_factor: float = 1.0

    cfg: Config

    def configure(self) -> None:
        super().configure()

        self.embedder = get_embedder(
            embed_type=self.cfg.embed_type,
            num_freqs=self.cfg.num_freqs,
            include_pi=self.cfg.include_pi,
        )

        # encoder
        self.cfg.init_scale = self.cfg.init_scale * math.sqrt(1.0 / self.cfg.width)
        self.encoder = PerceiverCrossAttentionEncoder(
            use_downsample=self.cfg.use_downsample,
            embedder=self.embedder,
            num_latents=self.cfg.num_latents,
            point_feats=self.cfg.point_feats,
            embed_point_feats=self.cfg.embed_point_feats,
            width=self.cfg.width,
            heads=self.cfg.heads,
            layers=self.cfg.num_encoder_layers,
            init_scale=self.cfg.init_scale,
            qkv_bias=self.cfg.qkv_bias,
            qk_norm=self.cfg.qk_norm,
            use_ln_post=self.cfg.use_ln_post,
            use_flash=self.cfg.use_flash,
            use_checkpoint=self.cfg.use_checkpoint,
            use_multi_reso=self.cfg.use_multi_reso,
            resolutions=self.cfg.resolutions,
            sampling_prob=self.cfg.sampling_prob,
            with_sharp_data=self.cfg.with_sharp_data,
        )

        if self.cfg.embed_dim > 0:
            # VAE embed
            self.pre_kl = nn.Linear(self.cfg.width, self.cfg.embed_dim * 2)
            self.post_kl = nn.Linear(self.cfg.embed_dim, self.cfg.width)
            self.latent_shape = (self.cfg.num_latents, self.cfg.embed_dim)
        else:
            self.latent_shape = (self.cfg.num_latents, self.cfg.width)

        self.transformer = Perceiver(
            n_ctx=self.cfg.num_latents,
            width=self.cfg.width,
            layers=self.cfg.num_decoder_layers,
            heads=self.cfg.heads,
            init_scale=self.cfg.init_scale,
            qkv_bias=self.cfg.qkv_bias,
            qk_norm=self.cfg.qk_norm,
            use_flash=self.cfg.use_flash,
            use_checkpoint=self.cfg.use_checkpoint,
        )

        # decoder
        self.decoder = PerceiverCrossAttentionDecoder(
            embedder=self.embedder,
            out_dim=self.cfg.out_dim,
            num_latents=self.cfg.num_latents,
            width=self.cfg.width,
            heads=self.cfg.heads,
            init_scale=self.cfg.init_scale,
            qkv_bias=self.cfg.qkv_bias,
            qk_norm=self.cfg.qk_norm,
            use_flash=self.cfg.use_flash,
            use_checkpoint=self.cfg.use_checkpoint,
        )

        # volume decoder
        if self.cfg.volume_decoder_type == "hierarchical":
            self.volume_decoder = HierarchicalVolumeDecoder()
        else:
            self.volume_decoder = VanillaVolumeDecoder()

        if self.cfg.pretrained_model_name_or_path != "":
            local_model_path = f"{smart_load_model(self.cfg.pretrained_model_name_or_path, self.cfg.subfolder)}/vae/diffusion_pytorch_model.safetensors"
            print(f"Loading pretrained VAE model from {local_model_path}")
            pretrain_safetensors = load_file(local_model_path)
            if "state_dict" in pretrain_safetensors:
                _pretrained_safetensors = {}
                for k, v in pretrain_safetensors["state_dict"].items():
                    if k.startswith("shape_model."):
                        # Strip the "shape_model." prefix and remap the old
                        # sharp-branch names onto the current module names.
                        if "proj1" in k:
                            _pretrained_safetensors[
                                k.replace("shape_model.", "").replace(
                                    "proj1", "proj_sharp"
                                )
                            ] = v
                        elif "attn1" in k:
                            _pretrained_safetensors[
                                k.replace("shape_model.", "").replace(
                                    "attn1", "attn_sharp"
                                )
                            ] = v
                        else:
                            _pretrained_safetensors[k.replace("shape_model.", "")] = v
                pretrain_safetensors = _pretrained_safetensors
                self.load_state_dict(pretrain_safetensors, strict=True)
            else:
                _pretrained_safetensors = {}
                for k, v in pretrain_safetensors.items():
                    if k.startswith("shape_model"):
                        # Zero-pad checkpoint tensors whose shapes are smaller
                        # than the corresponding module parameters.
                        final_module = self
                        for key in k.replace("shape_model.", "").split("."):
                            final_module = getattr(final_module, key)
                        data = final_module.data
                        if data.shape != v.shape:
                            data_zero = torch.zeros_like(data).to(v)
                            if data.ndim == 1:
                                data_zero[: v.shape[0]] = v
                            elif data.ndim == 2:
                                data_zero[: v.shape[0], : v.shape[1]] = v
                            v = data_zero
                        _pretrained_safetensors[k.replace("shape_model.", "")] = v
                    else:
                        _pretrained_safetensors[k] = v
                pretrain_safetensors = _pretrained_safetensors
                self.load_state_dict(pretrain_safetensors, strict=True)
            print("Successfully loaded the pretrained VAE model")

    def encode(
        self,
        surface: torch.FloatTensor,
        sample_posterior: bool = True,
        sharp_surface: Optional[torch.FloatTensor] = None,
    ):
        """
        Args:
            surface (torch.FloatTensor): [B, N, 3 + C]
            sample_posterior (bool): whether to sample from the posterior
                or take its mode

        Returns:
            shape_latents (torch.FloatTensor): [B, num_latents, width]
            kl_embed (torch.FloatTensor): [B, num_latents, embed_dim]
            posterior (DiagonalGaussianDistribution or None):
        """
        assert (
            surface.shape[-1] == 3 + self.cfg.point_feats
        ), f"Expected {3 + self.cfg.point_feats} channels, got {surface.shape[-1]}"

        pc, feats = surface[..., :3], surface[..., 3:]  # [B, n_samples, 3], [B, n_samples, C]
        if sharp_surface is not None:
            sharp_pc, sharp_feats = (
                sharp_surface[..., :3],
                sharp_surface[..., 3:],
            )
        else:
            sharp_pc, sharp_feats = None, None

        shape_embeds = self.encoder(
            pc, feats, sharp_pc, sharp_feats
        )  # [B, num_latents, width]
        kl_embed, posterior = self.encode_kl_embed(
            shape_embeds, sample_posterior
        )  # [B, num_latents, embed_dim]

        kl_embed = kl_embed * self.cfg.z_scale_factor  # encode with scale

        return shape_embeds, kl_embed, posterior

    def decode(self, latents: torch.FloatTensor):
        """
        Args:
            latents (torch.FloatTensor): [B, num_latents, embed_dim]

        Returns:
            latents (torch.FloatTensor): [B, num_latents, width]
        """
        latents = self.post_kl(
            latents / self.cfg.z_scale_factor
        )  # [B, num_latents, embed_dim] -> [B, num_latents, width]
        return self.transformer(latents)

    def query(self, queries: torch.FloatTensor, latents: torch.FloatTensor):
        """
        Args:
            queries (torch.FloatTensor): [B, N, 3]
            latents (torch.FloatTensor): [B, num_latents, width]

        Returns:
            features (torch.FloatTensor): [B, N, C], output features
        """
        features = self.decoder(queries, latents)
        return features

    def encode_kl_embed(
        self, latents: torch.FloatTensor, sample_posterior: bool = True
    ):
        posterior = None
        if self.cfg.embed_dim > 0:
            moments = self.pre_kl(latents)
            posterior = DiagonalGaussianDistribution(moments, feat_dim=-1)
            if sample_posterior:
                kl_embed = posterior.sample()
            else:
                kl_embed = posterior.mode()
        else:
            kl_embed = latents
        return kl_embed, posterior
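
    # Shape note (illustrative, for the default config): pre_kl maps
    # [B, num_latents, width=768] -> [B, num_latents, 2 * embed_dim = 128];
    # the posterior splits this into mean/logvar along the last dim, so
    # kl_embed is [B, num_latents, 64].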

    def forward(
        self,
        surface: torch.FloatTensor,
        sharp_surface: Optional[torch.FloatTensor] = None,
        rand_points: Optional[torch.FloatTensor] = None,
        sample_posterior: bool = True,
        **kwargs,
    ):
        shape_latents, kl_embed, posterior = self.encode(
            surface, sample_posterior=sample_posterior, sharp_surface=sharp_surface
        )

        latents = self.decode(kl_embed)  # [B, num_latents, width]

        meshes = self.extract_geometry(latents, **kwargs)

        return shape_latents, latents, posterior, meshes
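
    # A minimal end-to-end sketch (illustrative; shapes assume the default
    # config, and cfg construction follows BaseModule conventions):
    #
    #   vae = MichelangeloAutoencoder(cfg)
    #   surface = torch.rand(1, 4096, 3)  # [B, N, 3 + point_feats]
    #   shape_latents, latents, posterior, meshes = vae(surface)
    #   # shape_latents: [1, 256, 768], latents: [1, 256, 768]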

    def extract_geometry(self, latents: torch.FloatTensor, **kwargs):
        grid_logits_list = []
        for i in range(latents.shape[0]):
            grid_logits = self.volume_decoder(
                latents[i].unsqueeze(0), self.query, **kwargs
            )
            grid_logits_list.append(grid_logits)
        grid_logits = torch.cat(grid_logits_list, dim=0)

        # extract mesh
        surface_extractor_type = kwargs.get("surface_extractor_type")
        if surface_extractor_type is None:
            surface_extractor_type = self.cfg.surface_extractor_type
        if surface_extractor_type == "mc":
            surface_extractor = MCSurfaceExtractor()
            meshes = surface_extractor(grid_logits, **kwargs)
        elif surface_extractor_type == "dmc":
            surface_extractor = DMCSurfaceExtractor()
            meshes = surface_extractor(grid_logits, **kwargs)
        else:
            raise NotImplementedError(
                f"Unknown surface_extractor_type: {surface_extractor_type}"
            )
        return meshes
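

# A minimal usage sketch (illustrative): geometry can also be extracted from
# latents alone, e.g. after sampling a diffusion prior in the KL latent space;
# extra kwargs are forwarded to the volume decoder and the surface extractor.
#
#   meshes = vae.extract_geometry(vae.decode(kl_embed))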