from typing import *
import copy
import torch
from torch.utils.data import DataLoader
import numpy as np
from easydict import EasyDict as edict
import utils3d.torch

from ..basic import BasicTrainer
from ...representations import Gaussian
from ...renderers import GaussianRenderer
from ...modules.sparse import SparseTensor
from ...utils.loss_utils import l1_loss, l2_loss, ssim, lpips


class SLatVaeGaussianTrainer(BasicTrainer):
    """
    Trainer for structured latent VAE.

    Args:
        models (dict[str, nn.Module]): Models to train.
        dataset (torch.utils.data.Dataset): Dataset.
        output_dir (str): Output directory.
        load_dir (str): Load directory.
        step (int): Step to load.
        batch_size (int): Batch size.
        batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored.
        batch_split (int): Split batch with gradient accumulation.
        max_steps (int): Max steps.
        optimizer (dict): Optimizer config.
        lr_scheduler (dict): Learning rate scheduler config.
        elastic (dict): Elastic memory management config.
        grad_clip (float or dict): Gradient clip config.
        ema_rate (float or list): Exponential moving average rates.
        fp16_mode (str): FP16 mode.
            - None: No FP16.
            - 'inflat_all': Hold an inflated fp32 master param for all params.
            - 'amp': Automatic mixed precision.
        fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation.
        finetune_ckpt (dict): Finetune checkpoint.
        log_param_stats (bool): Log parameter stats.
        i_print (int): Print interval.
        i_log (int): Log interval.
        i_sample (int): Sample interval.
        i_save (int): Save interval.
        i_ddpcheck (int): DDP check interval.

        loss_type (str): Loss type. Can be 'l1' or 'l2'.
        lambda_ssim (float): SSIM loss weight.
        lambda_lpips (float): LPIPS loss weight.
        lambda_kl (float): KL loss weight.
        regularizations (dict): Regularization config.
    """

    def __init__(
        self,
        *args,
        loss_type: str = 'l1',
        lambda_ssim: float = 0.2,
        lambda_lpips: float = 0.2,
        lambda_kl: float = 1e-6,
        regularizations: Dict = {},
        **kwargs
    ):
        super().__init__(*args, **kwargs)
        self.loss_type = loss_type
        self.lambda_ssim = lambda_ssim
        self.lambda_lpips = lambda_lpips
        self.lambda_kl = lambda_kl
        self.regularizations = regularizations
        self._init_renderer()

    def _init_renderer(self):
        rendering_options = {"near": 0.8,
                             "far": 1.6,
                             "bg_color": 'random'}
        self.renderer = GaussianRenderer(rendering_options)
        self.renderer.pipe.kernel_size = self.models['decoder'].rep_config['2d_filter_kernel_size']

    def _render_batch(self, reps: List[Gaussian], extrinsics: torch.Tensor, intrinsics: torch.Tensor) -> Dict:
        """
        Render a batch of representations.

        Args:
            reps: The list of representations.
            extrinsics: The [N x 4 x 4] tensor of extrinsics.
            intrinsics: The [N x 3 x 3] tensor of intrinsics.
        """
        ret = None
        for i, representation in enumerate(reps):
            render_pack = self.renderer.render(representation, extrinsics[i], intrinsics[i])
            if ret is None:
                ret = {k: [] for k in list(render_pack.keys()) + ['bg_color']}
            for k, v in render_pack.items():
                ret[k].append(v)
            ret['bg_color'].append(self.renderer.bg_color)
        for k, v in ret.items():
            ret[k] = torch.stack(v, dim=0)
        return ret

    def _get_status(self, z: SparseTensor, reps: List[Gaussian]) -> Dict:
        xyz = torch.cat([g.get_xyz for g in reps], dim=0)
        # Voxel centers in [-0.5, 0.5], recovered from the sparse coordinates.
        xyz_base = (z.coords[:, 1:].float() + 0.5) / self.models['decoder'].resolution - 0.5
        # Offset of each Gaussian from the center of the voxel it belongs to.
        offset = xyz - xyz_base.unsqueeze(1).expand(-1, self.models['decoder'].rep_config['num_gaussians'], -1).reshape(-1, 3)
        status = {
            'xyz': xyz,
            'offset': offset,
            'scale': torch.cat([g.get_scaling for g in reps], dim=0),
            'opacity': torch.cat([g.get_opacity for g in reps], dim=0),
        }
        for k in list(status.keys()):
            status[k] = {
                'mean': status[k].mean().item(),
                'max': status[k].max().item(),
                'min': status[k].min().item(),
            }
        return status

    def _get_regularization_loss(self, reps: List[Gaussian]) -> Tuple[torch.Tensor, Dict]:
        loss = 0.0
        terms = {}
        if 'lambda_vol' in self.regularizations:
            # Penalize the mean volume of the Gaussians to discourage oversized splats.
            scales = torch.cat([g.get_scaling for g in reps], dim=0)   # [N x 3]
            volume = torch.prod(scales, dim=1)                         # [N]
            terms['reg_vol'] = volume.mean()
            loss = loss + self.regularizations['lambda_vol'] * terms['reg_vol']
        if 'lambda_opacity' in self.regularizations:
            # Push opacities toward 1 (penalize translucent Gaussians).
            opacity = torch.cat([g.get_opacity for g in reps], dim=0)
            terms['reg_opacity'] = (opacity - 1).pow(2).mean()
            loss = loss + self.regularizations['lambda_opacity'] * terms['reg_opacity']
        return loss, terms

    def training_losses(
        self,
        feats: SparseTensor,
        image: torch.Tensor,
        alpha: torch.Tensor,
        extrinsics: torch.Tensor,
        intrinsics: torch.Tensor,
        return_aux: bool = False,
        **kwargs
    ) -> Tuple[Dict, Dict]:
        """
        Compute training losses.

        Args:
            feats: The [N x * x C] sparse tensor of features.
            image: The [N x 3 x H x W] tensor of images.
            alpha: The [N x H x W] tensor of alpha channels.
            extrinsics: The [N x 4 x 4] tensor of extrinsics.
            intrinsics: The [N x 3 x 3] tensor of intrinsics.
            return_aux: Whether to return auxiliary information.

        Returns:
            a dict with the key "loss" containing a scalar tensor (it may also
            contain other keys for individual terms), and a status dict with
            statistics of the decoded Gaussians. If return_aux is True, a third
            dict with the rendered and ground-truth images is also returned.
        """
        z, mean, logvar = self.training_models['encoder'](feats, sample_posterior=True, return_raw=True)
        reps = self.training_models['decoder'](z)
        self.renderer.rendering_options.resolution = image.shape[-1]
        render_results = self._render_batch(reps, extrinsics, intrinsics)

        terms = edict(loss=0.0, rec=0.0)

        rec_image = render_results['color']
        # Composite the ground truth over the per-sample background color the
        # renderer used, so the reconstruction and the target share backgrounds.
        gt_image = image * alpha[:, None] + (1 - alpha[:, None]) * render_results['bg_color'][..., None, None]
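
        # Photometric reconstruction loss: an L1 or L2 base term, optionally
        # augmented with SSIM and LPIPS terms.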
        if self.loss_type == 'l1':
            terms["l1"] = l1_loss(rec_image, gt_image)
            terms["rec"] = terms["rec"] + terms["l1"]
        elif self.loss_type == 'l2':
            terms["l2"] = l2_loss(rec_image, gt_image)
            terms["rec"] = terms["rec"] + terms["l2"]
        else:
            raise ValueError(f"Invalid loss type: {self.loss_type}")
        if self.lambda_ssim > 0:
            terms["ssim"] = 1 - ssim(rec_image, gt_image)
            terms["rec"] = terms["rec"] + self.lambda_ssim * terms["ssim"]
        if self.lambda_lpips > 0:
            terms["lpips"] = lpips(rec_image, gt_image)
            terms["rec"] = terms["rec"] + self.lambda_lpips * terms["lpips"]
        terms["loss"] = terms["loss"] + terms["rec"]
| terms["kl"] = 0.5 * torch.mean(mean.pow(2) + logvar.exp() - logvar - 1) | |
| terms["loss"] = terms["loss"] + self.lambda_kl * terms["kl"] | |
| reg_loss, reg_terms = self._get_regularization_loss(reps) | |
| terms.update(reg_terms) | |
| terms["loss"] = terms["loss"] + reg_loss | |
| status = self._get_status(z, reps) | |
| if return_aux: | |
| return terms, status, {'rec_image': rec_image, 'gt_image': gt_image} | |
| return terms, status | |

    @torch.no_grad()
    def run_snapshot(
        self,
        num_samples: int,
        batch_size: int,
        verbose: bool = False,
    ) -> Dict:
        dataloader = DataLoader(
            copy.deepcopy(self.dataset),
            batch_size=batch_size,
            shuffle=True,
            num_workers=0,
            collate_fn=self.dataset.collate_fn if hasattr(self.dataset, 'collate_fn') else None,
        )

        # inference
        ret_dict = {}
        gt_images = []
        exts = []
        ints = []
        reps = []
        for i in range(0, num_samples, batch_size):
            batch = min(batch_size, num_samples - i)
            data = next(iter(dataloader))
            args = {k: v[:batch].cuda() for k, v in data.items()}
            gt_images.append(args['image'] * args['alpha'][:, None])
            exts.append(args['extrinsics'])
            ints.append(args['intrinsics'])
            z = self.models['encoder'](args['feats'], sample_posterior=True, return_raw=False)
            reps.extend(self.models['decoder'](z))
        gt_images = torch.cat(gt_images, dim=0)
        ret_dict.update({'gt_image': {'value': gt_images, 'type': 'image'}})

        # render single view
        exts = torch.cat(exts, dim=0)
        ints = torch.cat(ints, dim=0)
        self.renderer.rendering_options.bg_color = (0, 0, 0)
        self.renderer.rendering_options.resolution = gt_images.shape[-1]
        render_results = self._render_batch(reps, exts, ints)
        ret_dict.update({'rec_image': {'value': render_results['color'], 'type': 'image'}})

        # render multiview
        self.renderer.rendering_options.resolution = 512
        ## Build camera
        yaws = [0, np.pi / 2, np.pi, 3 * np.pi / 2]
        yaws_offset = np.random.uniform(-np.pi / 4, np.pi / 4)
        yaws = [y + yaws_offset for y in yaws]
        pitches = [np.random.uniform(-np.pi / 4, np.pi / 4) for _ in range(4)]
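        # Four cameras at 90-degree yaw intervals (sharing a random yaw offset)
        # with random pitches, placed on a radius-2 sphere looking at the origin
        # with +z up.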
        ## render each view
        multiview_images = []
        for yaw, pitch in zip(yaws, pitches):
            orig = torch.tensor([
                np.sin(yaw) * np.cos(pitch),
                np.cos(yaw) * np.cos(pitch),
                np.sin(pitch),
            ]).float().cuda() * 2
            fov = torch.deg2rad(torch.tensor(30)).cuda()
            extrinsics = utils3d.torch.extrinsics_look_at(orig, torch.tensor([0, 0, 0]).float().cuda(), torch.tensor([0, 0, 1]).float().cuda())
            intrinsics = utils3d.torch.intrinsics_from_fov_xy(fov, fov)
            extrinsics = extrinsics.unsqueeze(0).expand(num_samples, -1, -1)
            intrinsics = intrinsics.unsqueeze(0).expand(num_samples, -1, -1)
            render_results = self._render_batch(reps, extrinsics, intrinsics)
            multiview_images.append(render_results['color'])

        ## Concatenate views into a 2x2 grid
        multiview_images = torch.cat([
            torch.cat(multiview_images[:2], dim=-2),
            torch.cat(multiview_images[2:], dim=-2),
        ], dim=-1)
        ret_dict.update({'multiview_image': {'value': multiview_images, 'type': 'image'}})

        self.renderer.rendering_options.bg_color = 'random'

        return ret_dict
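

# A minimal usage sketch (hypothetical; the real entry points, model builders,
# and trainer keyword arguments live elsewhere in this codebase and may differ):
#
#     trainer = SLatVaeGaussianTrainer(
#         models={'encoder': encoder, 'decoder': decoder},  # pre-built sparse VAE modules (assumed)
#         dataset=dataset,  # must yield feats / image / alpha / extrinsics / intrinsics
#         output_dir='outputs/slat_vae_gaussian',
#         loss_type='l1',
#         lambda_ssim=0.2,
#         lambda_lpips=0.2,
#         lambda_kl=1e-6,
#         regularizations={'lambda_vol': 1e-3, 'lambda_opacity': 1e-3},  # hypothetical weights
#     )
#     trainer.run()  # training-loop entry point assumed to be provided by BasicTrainer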