Spaces:
Running
on
Zero
Running
on
Zero
| import torch | |
| from random import randint | |
| from tqdm.rich import trange | |
| from tqdm import tqdm as tqdm | |
| from source.networks import Warper3DGS | |
| import wandb | |
| import sys | |
| sys.path.append('./submodules/gaussian-splatting/') | |
| import lpips | |
| from source.losses import ssim, l1_loss, psnr | |
| from rich.console import Console | |
| from rich.theme import Theme | |
| # Named rich styles for the Console created in EDGSTrainer.__init__; the trainer prints with style="info". | |
| custom_theme = Theme({ | |
| "info": "dim cyan", | |
| "warning": "magenta", | |
| "danger": "bold red" | |
| }) | |
| #from source.corr_init import init_gaussians_with_corr | |
| from source.corr_init_new import init_gaussians_with_corr_profiled as init_gaussians_with_corr | |
| from source.utils_aux import log_samples | |
| from source.timer import Timer | |
class EDGSTrainer:
    """Training loop for a 3D Gaussian Splatting model wrapped in ``Warper3DGS``.

    Each step renders one training camera, optimizes a weighted
    L1 + (1 - SSIM) loss, and then either densifies/prunes or only prunes
    the Gaussians. The trainer also evaluates periodically, optionally logs
    to Weights & Biases, and saves checkpoints at configured iterations.
    """

    def __init__(self,
                 GS: "Warper3DGS",
                 training_config,
                 dataset_white_background=False,
                 device=torch.device('cuda'),
                 log_wandb=True,
                 ):
        """Bind the trainer to a model and its training configuration.

        Args:
            GS: model wrapper exposing ``scene``, ``viewpoint_stack`` and
                ``gaussians``; calling it with a camera returns a render
                package (dict). The annotation is quoted so it is not
                evaluated when the class is defined.
            training_config: optimization settings. This class reads
                ``save_iterations``, ``batch_size``, ``lambda_dssim``,
                ``densify_from_iter``, ``densify_until_iter``,
                ``densification_interval``, ``densify_grad_threshold``,
                ``opacity_reset_interval`` and ``TEST_CAM_IDX_TO_LOG``.
            dataset_white_background: when True, opacity is additionally
                reset at ``densify_from_iter``.
            device: device for ground-truth images and the LPIPS network.
            log_wandb: whether to send metrics to Weights & Biases.
        """
        self.GS = GS
        self.scene = GS.scene
        self.viewpoint_stack = GS.viewpoint_stack
        self.gaussians = GS.gaussians
        self.training_config = training_config
        self.GS_optimizer = GS.gaussians.optimizer
        self.dataset_white_background = dataset_white_background
        # `training_step` is the loop counter of `train`; `gs_step` counts
        # optimizer steps and drives the LR / densification schedules.
        self.training_step = 1
        self.gs_step = 0
        self.CONSOLE = Console(width=120, theme=custom_theme)
        self.saving_iterations = training_config.save_iterations
        # None selects the default evaluation schedule (see `_is_evaluation_step`).
        self.evaluate_iterations = None
        self.batch_size = training_config.batch_size
        self.ema_loss_for_log = 0.0
        # Logs in the format {step:{"loss1":loss1_value, "loss2":loss2_value}}
        self.logs_losses = {}
        self.lpips = lpips.LPIPS(net='vgg').to(device)
        self.device = device
        self.timer = Timer()
        self.log_wandb = log_wandb

    def load_checkpoints(self, load_cfg):
        """Restore the Gaussians from a saved checkpoint, if one is configured.

        Args:
            load_cfg: object with ``gs`` (checkpoint directory; falsy to skip
                loading) and ``gs_step`` (iteration of the checkpoint file
                ``chkpnt{gs_step}.pth`` inside that directory).
        """
        # Load 3DGS checkpoint
        if not load_cfg.gs:
            return
        checkpoint_path = f"{load_cfg.gs}/chkpnt{load_cfg.gs_step}.pth"
        # The checkpoint is the tuple written by `save_model`; element 0 is
        # the captured Gaussian state.
        model_state = torch.load(checkpoint_path)[0]
        # Fixed: the attribute is `self.GS`; `self.gs` is never assigned.
        self.GS.gaussians.restore(model_state, self.training_config)
        # Re-read the optimizer from the model after restoring.
        self.GS_optimizer = self.GS.gaussians.optimizer
        self.CONSOLE.print(f"3DGS loaded from checkpoint for iteration {load_cfg.gs_step}",
                           style="info")
        self.training_step += load_cfg.gs_step
        self.gs_step += load_cfg.gs_step

    def _is_evaluation_step(self):
        """Return True when `evaluate` should run at the current training step."""
        step = self.training_step
        if self.evaluate_iterations is not None:
            return step in self.evaluate_iterations
        # Default schedule: every 500 steps up to 3000, afterwards at steps
        # whose remainder modulo 1000 is 228.
        # NOTE(review): the 228 offset looks arbitrary - confirm it is intended.
        if step <= 3000:
            return step % 500 == 0
        return step % 1000 == 228

    def train(self, train_cfg):
        """Run ``train_cfg.gs_epochs`` optimization steps.

        Args:
            train_cfg: object with ``gs_epochs``, ``max_lr``, ``no_densify``
                and ``reduce_opacity``.
        """
        # 3DGS training
        self.CONSOLE.print("Train 3DGS for {} iterations".format(train_cfg.gs_epochs), style="info")
        first_step = self.training_step
        last_step = self.training_step + train_cfg.gs_epochs
        with trange(first_step, last_step, desc="[green]Train gaussians") as progress_bar:
            for self.training_step in progress_bar:
                radii = self.train_step_gs(max_lr=train_cfg.max_lr, no_densify=train_cfg.no_densify)
                with torch.no_grad():
                    if train_cfg.no_densify:
                        self.prune(radii)
                    else:
                        self.densify_and_prune(radii)
                    if train_cfg.reduce_opacity:
                        # Slightly reduce opacity every few steps. The stored
                        # values go through exp -> *0.99 -> log, i.e. log(0.99)
                        # is added to the raw opacity parameters.
                        if self.gs_step < self.training_config.densify_until_iter and self.gs_step % 10 == 0:
                            opacities_new = torch.log(torch.exp(self.GS.gaussians._opacity.data) * 0.99)
                            self.GS.gaussians._opacity.data = opacities_new
                    # Logging, saving and evaluation are excluded from the timer.
                    self.timer.pause()
                    # Progress bar
                    if self.training_step % 10 == 0:
                        progress_bar.set_postfix({"[red]Loss": f"{self.ema_loss_for_log:.{7}f}"}, refresh=True)
                    # Log and save
                    if self.training_step in self.saving_iterations:
                        self.save_model()
                    if self._is_evaluation_step():
                        self.evaluate()
                    self.timer.start()

    def evaluate(self):
        """Compute L1 / PSNR / SSIM / LPIPS on test cameras and a train subset.

        Averages are printed to the console and, when ``log_wandb`` is set,
        logged to Weights & Biases together with one real/rendered image
        pair and the elapsed training time.
        """
        torch.cuda.empty_cache()
        log_gen_images, log_real_images = [], []
        # Fetch the training cameras once; an empty set yields an empty
        # subset instead of a modulo-by-zero error.
        train_cameras = self.scene.getTrainCameras()
        if train_cameras:
            n_train = len(train_cameras)
            train_subset = [train_cameras[idx % n_train] for idx in range(0, 150, 5)]
        else:
            train_subset = []
        validation_configs = (
            {'name': 'test', 'cameras': self.scene.getTestCameras(),
             'cam_idx': self.training_config.TEST_CAM_IDX_TO_LOG},
            {'name': 'train', 'cameras': train_subset, 'cam_idx': 10},
        )
        if self.log_wandb:
            wandb.log({"Number of Gaussians": len(self.GS.gaussians._xyz)}, step=self.training_step)
        # Metrics need no gradients; disabling autograd here keeps a
        # standalone call from building graphs for every render.
        with torch.no_grad():
            for config in validation_configs:
                cameras = config['cameras']
                if not cameras:
                    continue
                l1_test = 0.0
                psnr_test = 0.0
                ssim_test = 0.0
                lpips_splat_test = 0.0
                for idx, viewpoint in enumerate(cameras):
                    image = torch.clamp(self.GS(viewpoint)["render"], 0.0, 1.0)
                    gt_image = torch.clamp(viewpoint.original_image.to(self.device), 0.0, 1.0)
                    l1_test += l1_loss(image, gt_image).double()
                    psnr_test += psnr(image.unsqueeze(0), gt_image.unsqueeze(0)).double()
                    ssim_test += ssim(image, gt_image).double()
                    # NOTE(review): images are clamped to [0, 1] and LPIPS is
                    # called without `normalize=True` - confirm this is the
                    # intended input range for the metric.
                    lpips_splat_test += self.lpips(image, gt_image).detach().double()
                    if idx == config['cam_idx']:
                        log_gen_images.append(image)
                        log_real_images.append(gt_image)
                n_cameras = len(cameras)
                psnr_test /= n_cameras
                l1_test /= n_cameras
                ssim_test /= n_cameras
                lpips_splat_test /= n_cameras
                if self.log_wandb:
                    wandb.log({f"{config['name']}/L1": l1_test.item(),
                               f"{config['name']}/PSNR": psnr_test.item(),
                               f"{config['name']}/SSIM": ssim_test.item(),
                               f"{config['name']}/LPIPS_splat": lpips_splat_test.item()},
                              step=self.training_step)
                self.CONSOLE.print("\n[ITER {}], #{} gaussians, Evaluating {}: L1={:.6f}, PSNR={:.6f}, SSIM={:.6f}, LPIPS_splat={:.6f} ".format(
                    self.training_step, len(self.GS.gaussians._xyz), config['name'],
                    l1_test.item(), psnr_test.item(), ssim_test.item(), lpips_splat_test.item()), style="info")
            if self.log_wandb:
                # No sample pair exists when no camera index matched
                # `cam_idx`; skip the image log instead of raising IndexError.
                if log_real_images and log_gen_images:
                    log_samples(torch.stack((log_real_images[0], log_gen_images[0])), [],
                                self.training_step, caption="Real and Generated Samples")
                wandb.log({"time": self.timer.get_elapsed_time()}, step=self.training_step)
        torch.cuda.empty_cache()

    def train_step_gs(self, max_lr=False, no_densify=False):
        """Perform one optimization step on a randomly drawn training camera.

        Args:
            max_lr: when True, the learning-rate schedule is queried with
                ``max(gs_step, 8000)`` instead of ``gs_step``.
                NOTE(review): despite the name this never passes a step
                below 8000 to the schedule - confirm the intended direction.
            no_densify: when True, densification statistics are not collected.

        Returns:
            The ``"radii"`` entry of the render package for the sampled camera.
        """
        self.gs_step += 1
        if max_lr:
            self.GS.gaussians.update_learning_rate(max(self.gs_step, 8_000))
        else:
            self.GS.gaussians.update_learning_rate(self.gs_step)
        # Every 1000 its we increase the levels of SH up to a maximum degree
        if self.gs_step % 1000 == 0:
            self.GS.gaussians.oneupSHdegree()

        # Pick a random Camera; refill the stack once every camera was used.
        if not self.viewpoint_stack:
            self.viewpoint_stack = self.scene.getTrainCameras().copy()
        viewpoint_cam = self.viewpoint_stack.pop(randint(0, len(self.viewpoint_stack) - 1))
        render_pkg = self.GS(viewpoint_cam=viewpoint_cam)
        image = render_pkg["render"]

        # Loss
        gt_image = viewpoint_cam.original_image.to(self.device)
        L1_loss = l1_loss(image, gt_image)
        ssim_loss = (1.0 - ssim(image, gt_image))
        lambda_dssim = self.training_config.lambda_dssim
        loss = (1.0 - lambda_dssim) * L1_loss + lambda_dssim * ssim_loss

        # Bookkeeping is excluded from the timer.
        self.timer.pause()
        step_losses = {"loss": loss.item(),
                       "L1_loss": L1_loss.item(),
                       "ssim_loss": ssim_loss.item()}
        self.logs_losses[self.training_step] = step_losses
        if self.log_wandb:
            # One call for all losses of this step.
            wandb.log({f"train/{k}": v for k, v in step_losses.items()}, step=self.training_step)
        self.ema_loss_for_log = 0.4 * step_losses["loss"] + 0.6 * self.ema_loss_for_log
        self.timer.start()

        self.GS_optimizer.zero_grad(set_to_none=True)
        loss.backward()
        with torch.no_grad():
            if self.gs_step < self.training_config.densify_until_iter and not no_densify:
                visible = render_pkg["visibility_filter"]
                # Track the largest screen-space radius seen per visible Gaussian.
                self.GS.gaussians.max_radii2D[visible] = torch.max(
                    self.GS.gaussians.max_radii2D[visible],
                    render_pkg["radii"][visible])
                self.GS.gaussians.add_densification_stats(render_pkg["viewspace_points"], visible)
        # Optimizer step
        self.GS_optimizer.step()
        self.GS_optimizer.zero_grad(set_to_none=True)
        return render_pkg["radii"]

    def densify_and_prune(self, radii=None):
        """Densify/prune the Gaussians and reset opacity on schedule.

        Nothing happens once ``gs_step`` reaches ``densify_until_iter``.

        Args:
            radii: radii from the latest render, forwarded to the model's
                ``densify_and_prune``.
        """
        cfg = self.training_config
        # Densification or pruning
        if self.gs_step >= cfg.densify_until_iter:
            return
        if self.gs_step > cfg.densify_from_iter and self.gs_step % cfg.densification_interval == 0:
            size_threshold = 20 if self.gs_step > cfg.opacity_reset_interval else None
            self.GS.gaussians.densify_and_prune(cfg.densify_grad_threshold,
                                                0.005,
                                                self.GS.scene.cameras_extent,
                                                size_threshold, radii)
        if self.gs_step % cfg.opacity_reset_interval == 0 or (
                self.dataset_white_background and self.gs_step == cfg.densify_from_iter):
            self.GS.gaussians.reset_opacity()

    def save_model(self):
        """Save the scene and a ``chkpnt{gs_step}.pth`` checkpoint.

        The checkpoint holds the tuple ``(gaussians.capture(), gs_step)`` and
        is written into ``scene.model_path``.
        """
        print("\n[ITER {}] Saving Gaussians".format(self.gs_step))
        self.scene.save(self.gs_step)
        print("\n[ITER {}] Saving Checkpoint".format(self.gs_step))
        torch.save((self.GS.gaussians.capture(), self.gs_step),
                   f"{self.scene.model_path}/chkpnt{self.gs_step}.pth")

    def init_with_corr(self, cfg, verbose=False, roma_model=None):
        """
        Initializes image with matchings. Also removes SfM init points.
        Args:
            cfg: configuration part named init_wC. Check train.yaml
            verbose: whether you want to print intermediate results. Useful for debug.
            roma_model: optionally you can pass here preinit RoMA model to avoid reinit
                it every time.
        Returns:
            The visualization dict produced by the initialization, or None
            when ``cfg.use`` is falsy.
        """
        if not cfg.use:
            return None
        gaussians = self.GS.gaussians
        N_splats_at_init = len(gaussians._xyz)
        print("N_splats_at_init:", N_splats_at_init)
        camera_set, selected_indices, visualization_dict = init_gaussians_with_corr(
            gaussians,
            self.scene,
            cfg,
            self.device,
            verbose=verbose,
            roma_model=roma_model)

        # Remove SfM points and leave only matchings inits
        if not cfg.add_SfM_init:
            with torch.no_grad():
                N_splats_after_init = len(gaussians._xyz)
                print("N_splats_after_init:", N_splats_after_init)
                gaussians.tmp_radii = torch.zeros(gaussians._xyz.shape[0], device=self.device)
                # True marks the first `N_splats_at_init` entries (those that
                # existed before the correspondence init) for removal. The
                # mask is built on the same device as the points.
                mask = torch.zeros(N_splats_after_init, dtype=torch.bool,
                                   device=gaussians._xyz.device)
                mask[:N_splats_at_init] = True
                gaussians.prune_points(mask)
        with torch.no_grad():
            # Halve the activated scales of all remaining Gaussians.
            # NOTE(review): this rebinds `_scaling` to a tensor created under
            # no_grad - confirm the optimizer still tracks the tensor in use.
            gaussians._scaling = gaussians.scaling_inverse_activation(
                gaussians.scaling_activation(gaussians._scaling) * 0.5)
        return visualization_dict

    def prune(self, radii, min_opacity=0.005):
        """Remove Gaussians whose opacity is below ``min_opacity``.

        Pruning only happens while ``gs_step`` is below ``densify_until_iter``.

        Args:
            radii: radii from the latest render; stored on the model as
                ``tmp_radii`` for the duration of the call.
            min_opacity: opacity threshold below which a Gaussian is removed.
        """
        gaussians = self.GS.gaussians
        gaussians.tmp_radii = radii
        if self.gs_step < self.training_config.densify_until_iter:
            prune_mask = (gaussians.get_opacity < min_opacity).squeeze()
            gaussians.prune_points(prune_mask)
            torch.cuda.empty_cache()
        gaussians.tmp_radii = None