| """ | |
| ----------------------------------------------------------------------------- | |
| Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. | |
| NVIDIA CORPORATION and its licensors retain all intellectual property | |
| and proprietary rights in and to this software, related documentation | |
| and any modifications thereto. Any use, reproduction, disclosure or | |
| distribution of this software and related documentation without an express | |
| license agreement from NVIDIA CORPORATION is strictly prohibited. | |
| ----------------------------------------------------------------------------- | |
| """ | |
from typing import Literal

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from vae.configs.schema import ModelConfig
from vae.modules.transformer import AttentionBlock, FlashQueryLayer
from vae.utils import (
    DiagonalGaussianDistribution,
    DummyLatent,
    calculate_iou,
    calculate_metrics,
    construct_grid_points,
    extract_mesh,
    sync_timer,
)
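
# Overview (inferred from the code below): a perceiver-style 3D shape VAE. The encoder
# cross-attends FPS-downsampled query points to the Fourier-embedded input point cloud,
# self-attention blocks refine the resulting latent set, an (optionally KL-regularized)
# bottleneck projects it to latent_dim, and arbitrary 3D points are decoded via
# cross-attention into a scalar field in [-1, 1] from which meshes are extracted.
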
class Model(nn.Module):
    def __init__(self, config: ModelConfig) -> None:
        super().__init__()
        self.config = config
        self.precision = torch.bfloat16  # manually handle low-precision training, always use bf16

        # point encoder
        self.proj_input = nn.Linear(3 + config.point_fourier_dim, config.hidden_dim)
        self.perceiver = AttentionBlock(
            config.hidden_dim,
            num_heads=config.num_heads,
            dim_context=config.hidden_dim,
            qknorm=config.qknorm,
            qknorm_type=config.qknorm_type,
        )
        if self.config.salient_attn_mode == "dual":
            self.perceiver_dorases = AttentionBlock(
                config.hidden_dim,
                num_heads=config.num_heads,
                dim_context=config.hidden_dim,
                qknorm=config.qknorm,
                qknorm_type=config.qknorm_type,
            )

        # self-attention encoder
        self.encoder = nn.ModuleList(
            [
                AttentionBlock(
                    config.hidden_dim, config.num_heads, qknorm=config.qknorm, qknorm_type=config.qknorm_type
                )
                for _ in range(config.num_enc_layers)
            ]
        )

        # vae bottleneck
        self.norm_down = nn.LayerNorm(config.hidden_dim)
        self.proj_down_mean = nn.Linear(config.hidden_dim, config.latent_dim)
        if not self.config.use_ae:
            self.proj_down_std = nn.Linear(config.hidden_dim, config.latent_dim)
        self.proj_up = nn.Linear(config.latent_dim, config.dec_hidden_dim)

        # self-attention decoder
        self.decoder = nn.ModuleList(
            [
                AttentionBlock(
                    config.dec_hidden_dim, config.dec_num_heads, qknorm=config.qknorm, qknorm_type=config.qknorm_type
                )
                for _ in range(config.num_dec_layers)
            ]
        )

        # cross-attention query
        self.proj_query = nn.Linear(3 + config.point_fourier_dim, config.query_hidden_dim)
        if self.config.use_flash_query:
            self.norm_query_context = nn.LayerNorm(config.hidden_dim, eps=1e-6, elementwise_affine=False)
            self.attn_query = FlashQueryLayer(
                config.query_hidden_dim,
                num_heads=config.query_num_heads,
                dim_context=config.hidden_dim,
                qknorm=config.qknorm,
                qknorm_type=config.qknorm_type,
            )
        else:
            self.attn_query = AttentionBlock(
                config.query_hidden_dim,
                num_heads=config.query_num_heads,
                dim_context=config.hidden_dim,
                qknorm=config.qknorm,
                qknorm_type=config.qknorm_type,
            )
        self.norm_out = nn.LayerNorm(config.query_hidden_dim)
        self.proj_out = nn.Linear(config.query_hidden_dim, 1)

        # preload from a checkpoint (NOTE: this happens BEFORE the checkpointer loads the latest checkpoint!)
        if self.config.pretrain_path is not None:
            try:
                ckpt = torch.load(self.config.pretrain_path)  # local path
                self.load_state_dict(ckpt["model"], strict=True)
                del ckpt
                print(f"Loaded VAE from {self.config.pretrain_path}")
            except Exception as e:
                print(
                    f"Failed to load VAE from {self.config.pretrain_path}: {e}, make sure you resumed from a valid checkpoint!"
                )

        # log
        n_params = sum(p.numel() for p in self.parameters())
        print(f"Number of parameters in VAE: {n_params / 1e6:.2f}M")
    # override to support tolerant loading (only copy entries with matching shapes);
    # `strict` and `assign` are accepted for signature compatibility but not used
    def load_state_dict(self, state_dict, strict=True, assign=False):
        local_state_dict = self.state_dict()
        seen_keys = {k: False for k in local_state_dict.keys()}
        for k, v in state_dict.items():
            if k in local_state_dict:
                seen_keys[k] = True
                if local_state_dict[k].shape == v.shape:
                    local_state_dict[k].copy_(v)
                else:
                    print(f"mismatching shape for key {k}: model has {local_state_dict[k].shape} but loaded {v.shape}")
            else:
                print(f"unexpected key {k} in loaded state dict")
        for k in seen_keys:
            if not seen_keys[k]:
                print(f"missing key {k} in loaded state dict")
    def fourier_encoding(self, points: torch.Tensor):
        # points: [B, N, 3], kept in float32 for precision
        # assert points.dtype == torch.float32, "Query points must be float32"
        num_freqs = self.config.point_fourier_dim // (2 * points.shape[-1])  # frequencies per coordinate (avoid shadowing F)
        if self.config.fourier_version == "v1":  # default
            exponent = torch.arange(1, num_freqs + 1, device=points.device, dtype=torch.float32) / num_freqs  # [F], in (0, 1]
            freq_band = 512**exponent  # [F], frequencies from 512^(1/F) to 512
            freq_band *= torch.pi
        elif self.config.fourier_version == "v2":
            exponent = torch.arange(num_freqs, device=points.device, dtype=torch.float32) / (num_freqs - 1)  # [F], in [0, 1]
            freq_band = 1024**exponent  # [F], frequencies from 1 to 1024
            freq_band *= torch.pi
        elif self.config.fourier_version == "v3":  # hunyuan3d-2
            freq_band = 2 ** torch.arange(num_freqs, device=points.device, dtype=torch.float32)  # [F], powers of 2
        else:
            raise ValueError(f"unknown fourier_version: {self.config.fourier_version}")
        spectrum = points.unsqueeze(-1) * freq_band  # [B, ..., 3, F]
        sin, cos = spectrum.sin(), spectrum.cos()  # [B, ..., 3, F]
        input_enc = torch.stack([sin, cos], dim=-2)  # [B, ..., 3, 2, F]
        input_enc = input_enc.view(*points.shape[:-1], -1)  # [B, ..., 6F] = [B, ..., dim]
        return torch.cat([input_enc, points], dim=-1).to(dtype=self.precision)  # [B, ..., dim+input_dim]
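
    # worked example (illustrative numbers, not from the config schema): with
    # point_fourier_dim = 48, num_freqs = 48 // (2 * 3) = 8 per coordinate; sin/cos over
    # 3 coordinates give 3 * 2 * 8 = 48 channels, and concatenating the raw xyz yields
    # 48 + 3 channels, matching the in_features of proj_input and proj_query.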
    def on_train_start(self, memory_format: torch.memory_format = torch.preserve_format) -> None:
        # nn.Module defines no on_train_start hook, so there is no super() call to forward to
        self.to(dtype=self.precision, memory_format=memory_format)  # use bfloat16 for training
    def encode(self, data: dict[str, torch.Tensor]):
        # uniform points
        pointcloud = data["pointcloud"]  # [B, N, 3]
        # fourier embed and project
        pointcloud = self.fourier_encoding(pointcloud)  # [B, N, 3+C]
        pointcloud = self.proj_input(pointcloud)  # [B, N, hidden_dim]

        # salient points
        if self.config.use_salient_point:
            pointcloud_dorases = data["pointcloud_dorases"]  # [B, M, 3]
            # fourier embed and project (shared weights)
            pointcloud_dorases = self.fourier_encoding(pointcloud_dorases)  # [B, M, 3+C]
            pointcloud_dorases = self.proj_input(pointcloud_dorases)  # [B, M, hidden_dim]

        # gather fps point
        fps_indices = data["fps_indices"]  # [B, N']
        pointcloud_query = torch.gather(pointcloud, 1, fps_indices.unsqueeze(-1).expand(-1, -1, pointcloud.shape[-1]))

        if self.config.use_salient_point:
            fps_indices_dorases = data["fps_indices_dorases"]  # [B, M']
            if fps_indices_dorases.shape[1] > 0:
                pointcloud_query_dorases = torch.gather(
                    pointcloud_dorases,
                    1,
                    fps_indices_dorases.unsqueeze(-1).expand(-1, -1, pointcloud_dorases.shape[-1]),
                )
                # combine both fps points as the query
                pointcloud_query = torch.cat([pointcloud_query, pointcloud_query_dorases], dim=1)  # [B, N'+M', hidden_dim]
            # dual cross-attention
            if self.config.salient_attn_mode == "dual_shared":
                hidden_states = self.perceiver(pointcloud_query, pointcloud) + self.perceiver(
                    pointcloud_query, pointcloud_dorases
                )  # [B, N'+M', hidden_dim]
            elif self.config.salient_attn_mode == "dual":
                hidden_states = self.perceiver(pointcloud_query, pointcloud) + self.perceiver_dorases(
                    pointcloud_query, pointcloud_dorases
                )
            else:  # single, hunyuan3d-2 style
                hidden_states = self.perceiver(pointcloud_query, torch.cat([pointcloud, pointcloud_dorases], dim=1))
        else:
            hidden_states = self.perceiver(pointcloud_query, pointcloud)  # [B, N', hidden_dim]

        # encoder
        for block in self.encoder:
            hidden_states = block(hidden_states)

        # bottleneck
        hidden_states = self.norm_down(hidden_states)
        latent_mean = self.proj_down_mean(hidden_states).float()
        if not self.config.use_ae:
            latent_std = self.proj_down_std(hidden_states).float()
            posterior = DiagonalGaussianDistribution(latent_mean, latent_std)
        else:
            posterior = DummyLatent(latent_mean)
        return posterior
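
    # NOTE: decode() below only runs the latent self-attention stack; the actual field
    # values are produced by query(), which cross-attends query points into its output.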
    def decode(self, latent: torch.Tensor):
        latent = latent.to(dtype=self.precision)
        hidden_states = self.proj_up(latent)
        for block in self.decoder:
            hidden_states = block(hidden_states)
        return hidden_states
    def query(self, query_points: torch.Tensor, hidden_states: torch.Tensor):
        # query_points: [B, N, 3], float32 to keep the precision
        query_points = self.fourier_encoding(query_points)  # [B, N, 3+C]
        query_points = self.proj_query(query_points)  # [B, N, hidden_dim]
        # cross attention
        query_output = self.attn_query(query_points, hidden_states)  # [B, N, hidden_dim]
        # output linear
        query_output = self.norm_out(query_output)
        pred = self.proj_out(query_output)  # [B, N, 1]
        return pred
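
    # the scalar field is trained against query_gt in [-1, 1] (see training_step), and
    # the hierarchical mode in forward() treats sign changes of this field (>= 0 vs. < 0)
    # as the surface boundary to refine.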
    def training_step(
        self,
        data: dict[str, torch.Tensor],
        iteration: int,
    ) -> tuple[dict[str, torch.Tensor], torch.Tensor]:
        output = {}

        # cut off fps points during training for progressive flow
        if self.training:
            # randomly choose from a set of cutoff candidates
            cutoff_index = np.random.choice(len(self.config.cutoff_fps_prob), p=self.config.cutoff_fps_prob)
            cutoff_fps_point = self.config.cutoff_fps_point[cutoff_index]
            cutoff_fps_salient_point = self.config.cutoff_fps_salient_point[cutoff_index]
            # a prefix of FPS points is still a valid FPS point set
            data["fps_indices"] = data["fps_indices"][:, :cutoff_fps_point]
            if self.config.use_salient_point:
                data["fps_indices_dorases"] = data["fps_indices_dorases"][:, :cutoff_fps_salient_point]

        loss = 0

        # encode
        posterior = self.encode(data)
        latent_geom = posterior.sample() if self.training else posterior.mode()

        # decode
        hidden_states = self.decode(latent_geom)

        # cross-attention query
        query_points = data["query_points"]  # [B, N, 3], float32
        # the context norm can be moved out to avoid repeated computation
        if self.config.use_flash_query:
            hidden_states = self.norm_query_context(hidden_states)
        pred = self.query(query_points, hidden_states).squeeze(-1).float()  # [B, N]
        gt = data["query_gt"].float()  # [B, N], in [-1, 1]

        # main loss
        loss_mse = F.mse_loss(pred, gt, reduction="mean")
        loss += loss_mse
        loss_l1 = F.l1_loss(pred, gt, reduction="mean")
        loss += loss_l1

        # kl loss
        loss_kl = posterior.kl().mean()
        loss += self.config.kl_weight * loss_kl

        # metrics
        with torch.no_grad():
            output["scalar"] = {}  # for wandb logging
            output["scalar"]["loss_mse"] = loss_mse.detach()
            output["scalar"]["loss_l1"] = loss_l1.detach()
            output["scalar"]["loss_kl"] = loss_kl.detach()
            output["scalar"]["iou_fg"] = calculate_iou(pred, gt, target_value=1)
            output["scalar"]["iou_bg"] = calculate_iou(pred, gt, target_value=0)
            output["scalar"]["precision"], output["scalar"]["recall"], output["scalar"]["f1"] = calculate_metrics(
                pred, gt, target_value=1
            )
        return output, loss
    def validation_step(
        self,
        data: dict[str, torch.Tensor],
        iteration: int,
    ) -> tuple[dict[str, torch.Tensor], torch.Tensor]:
        return self.training_step(data, iteration)
    def forward(
        self,
        data: dict[str, torch.Tensor],
        mode: Literal["dense", "hierarchical"] = "hierarchical",
        max_samples_per_iter: int = 512**2,
        resolution: int = 512,
        min_resolution: int = 64,  # for hierarchical
    ) -> dict[str, torch.Tensor]:
        output = {}

        # encode
        if "latent" in data:
            latent = data["latent"]
        else:
            posterior = self.encode(data)
            output["posterior"] = posterior
            latent = posterior.mode()
        output["latent"] = latent
        B = latent.shape[0]

        # decode
        hidden_states = self.decode(latent)
        output["hidden_states"] = hidden_states  # [B, N, hidden_dim] output of the last decoder layer, context for the cross-attention query

        # the context norm can be moved out to avoid repeated computation
        if self.config.use_flash_query:
            hidden_states = self.norm_query_context(hidden_states)

        # query
        def chunked_query(grid_points):
            if grid_points.shape[0] <= max_samples_per_iter:
                return self.query(grid_points.unsqueeze(0), hidden_states).squeeze(-1)  # [B, N]
            all_pred = []
            for i in range(0, grid_points.shape[0], max_samples_per_iter):
                grid_chunk = grid_points[i : i + max_samples_per_iter]
                pred_chunk = self.query(grid_chunk.unsqueeze(0), hidden_states)
                all_pred.append(pred_chunk)
            return torch.cat(all_pred, dim=1).squeeze(-1)  # [B, N]
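
        # chunking bounds peak memory: each chunk of at most max_samples_per_iter query
        # points cross-attends to the full hidden_states context, so memory scales with
        # the chunk size rather than with the total number of grid points.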
| if mode == "dense": | |
| grid_points = construct_grid_points(resolution).to(latent.device) | |
| grid_points = grid_points.contiguous().view(-1, 3) | |
| grid_vals = chunked_query(grid_points).float().view(B, resolution + 1, resolution + 1, resolution + 1) | |
| elif mode == "hierarchical": | |
| assert resolution >= min_resolution, "Resolution must be greater than or equal to min_resolution" | |
| assert B == 1, "Only one batch is supported for hierarchical mode" | |
| resolutions = [] | |
| res = resolution | |
| while res >= min_resolution: | |
| resolutions.append(res) | |
| res = res // 2 | |
| resolutions.reverse() # e.g., [64, 128, 256, 512] | |
| # dense-query the coarsest resolution | |
| res = resolutions[0] | |
| grid_points = construct_grid_points(res).to(latent.device) | |
| grid_points = grid_points.contiguous().view(-1, 3) | |
| grid_vals = chunked_query(grid_points).float().view(res + 1, res + 1, res + 1) | |
            # sparse-query finer resolutions
            dilate_kernel_3 = torch.ones(1, 1, 3, 3, 3, dtype=torch.float32, device=latent.device)
            dilate_kernel_5 = torch.ones(1, 1, 5, 5, 5, dtype=torch.float32, device=latent.device)
            for i in range(1, len(resolutions)):
                res = resolutions[i]
                # boundary mask on the coarser grid: cells whose grid_vals differ in sign from at least one neighbor
                grid_signs = grid_vals >= 0
                mask = torch.zeros_like(grid_signs)
                mask[1:, :, :] += grid_signs[1:, :, :] != grid_signs[:-1, :, :]
                mask[:-1, :, :] += grid_signs[:-1, :, :] != grid_signs[1:, :, :]
                mask[:, 1:, :] += grid_signs[:, 1:, :] != grid_signs[:, :-1, :]
                mask[:, :-1, :] += grid_signs[:, :-1, :] != grid_signs[:, 1:, :]
                mask[:, :, 1:] += grid_signs[:, :, 1:] != grid_signs[:, :, :-1]
                mask[:, :, :-1] += grid_signs[:, :, :-1] != grid_signs[:, :, 1:]
                # empirical: also include near-surface cells with abs(grid_vals) < 0.95
                mask += grid_vals.abs() < 0.95
                mask = (mask > 0).float()
                # empirical: dilate the coarse mask
                if res < 512:
                    mask = mask.unsqueeze(0).unsqueeze(0)
                    mask = F.conv3d(mask, weight=dilate_kernel_3, padding=1)
                    mask = mask.squeeze(0).squeeze(0)
                # get the coarse coordinates
                cidx_x, cidx_y, cidx_z = torch.nonzero(mask, as_tuple=True)
                # fill into the fine indices
                mask_fine = torch.zeros(res + 1, res + 1, res + 1, dtype=torch.float32, device=latent.device)
                mask_fine[cidx_x * 2, cidx_y * 2, cidx_z * 2] = 1
                # empirical: dilate the fine mask
                if res < 512:
                    mask_fine = mask_fine.unsqueeze(0).unsqueeze(0)
                    mask_fine = F.conv3d(mask_fine, weight=dilate_kernel_3, padding=1)
                    mask_fine = mask_fine.squeeze(0).squeeze(0)
                else:
                    mask_fine = mask_fine.unsqueeze(0).unsqueeze(0)
                    mask_fine = F.conv3d(mask_fine, weight=dilate_kernel_5, padding=2)
                    mask_fine = mask_fine.squeeze(0).squeeze(0)
                # get the fine coordinates
                fidx_x, fidx_y, fidx_z = torch.nonzero(mask_fine, as_tuple=True)
                # convert to float query points
                query_points = torch.stack([fidx_x, fidx_y, fidx_z], dim=-1)  # [N, 3]
                query_points = query_points * 2 / res - 1  # [N, 3], in [-1, 1]
                # query
                pred = chunked_query(query_points).float()
                # scatter predictions into the fine grid; never-queried cells keep the -100 sentinel
                grid_vals = torch.full((res + 1, res + 1, res + 1), -100.0, dtype=torch.float32, device=latent.device)
                grid_vals[fidx_x, fidx_y, fidx_z] = pred
                # print(f"[INFO] hierarchical: resolution: {res}, valid coarse points: {len(cidx_x)}, valid fine points: {len(fidx_x)}")
            grid_vals = grid_vals.unsqueeze(0)  # [1, res+1, res+1, res+1]
            grid_vals[grid_vals <= -100.0] = float("nan")  # use nans to ignore invalid regions
        # extract mesh
        meshes = []
        for b in range(B):
            vertices, faces = extract_mesh(grid_vals[b], resolution)
            meshes.append((vertices, faces))
        output["meshes"] = meshes
        return output
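

# A minimal inference sketch (assumptions: a populated `ModelConfig`, a point cloud
# normalized to [-1, 1], and precomputed FPS indices; key names follow the code above,
# and the shapes are illustrative, not prescribed):
#
#   model = Model(config).cuda().eval()
#   data = {
#       "pointcloud": points,        # [1, N, 3] float32 in [-1, 1]
#       "fps_indices": fps_indices,  # [1, N'] long, indices into pointcloud
#       # plus "pointcloud_dorases" / "fps_indices_dorases" if config.use_salient_point
#   }
#   with torch.no_grad():
#       out = model(data, mode="hierarchical", resolution=512)
#   vertices, faces = out["meshes"][0]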