Spaces:
Running
on
Zero
Running
on
Zero
| import torch | |
| from gsplat import rasterization | |
| from dust3r.utils.geometry import inv, geotrf | |
def render(
    intrinsics: torch.Tensor,
    pts3d: torch.Tensor,
    rgbs: torch.Tensor | None = None,
    scale: float = 0.002,
    opacity: float = 0.95,
):
    """Splat per-pixel 3D points into each view with gsplat; return RGB, depth and alpha.

    Every point becomes an isotropic Gaussian with scale ``scale`` on all three
    axes and constant ``opacity``. Rasterization uses an identity view matrix,
    so the points must already be expressed in the frame of the camera that
    ``intrinsics`` describes.

    Args:
        intrinsics: per-view camera intrinsics, indexed as ``intrinsics[[i]]``;
            presumably (B, 3, 3) pinhole matrices -- TODO confirm.
        pts3d: point map. It is rendered at H x W with ``H, W = pts3d.shape[1:3]``
            and flattened to (B, H*W, 3), so it appears to be (B, H, W, 3).
        rgbs: optional per-point colors with 3 channels at dim 1. They are
            reshaped to (B, 3, H*W), presumably from (B, 3, H, W). When omitted,
            every point is white (all ones).
        scale: scale of every Gaussian on each axis.
        opacity: opacity of every Gaussian.

    Returns:
        ``(rendered_rgbs, rendered_depths, accs)``. Each is the per-view gsplat
        output concatenated along dim 0: the RGB channels (``[..., :3]``), the
        depth channel (``[..., 3]``) and the accumulated alpha.

    Raises:
        ValueError: if ``rgbs`` does not have 3 channels at dim 1.
    """
    device = pts3d.device
    batch_size = len(intrinsics)
    height, width = pts3d.shape[1:3]
    pts3d = pts3d.reshape(batch_size, -1, 3)
    num_pts = pts3d.shape[1]

    # The scales are isotropic, so a Gaussian's orientation cannot change the
    # render. A fixed unit quaternion replaces the old random one: it keeps
    # runs reproducible and does not advance the global RNG.
    quats = torch.zeros((num_pts, 4), device=device)
    quats[:, 0] = 1.0
    scales = torch.full((num_pts, 3), scale, device=device)
    opacities = torch.full((num_pts,), opacity, device=device)

    if rgbs is not None:
        if rgbs.shape[1] != 3:
            raise ValueError(
                f"rgbs must have 3 channels at dim 1, got shape {tuple(rgbs.shape)}"
            )
        # (B, 3, ...) -> (B, N, 3): one color per flattened point.
        rgbs = rgbs.reshape(batch_size, 3, -1).transpose(1, 2)
    else:
        rgbs = torch.ones_like(pts3d[:, :, :3])

    # Identity view matrix, because the points are already in camera space.
    # Built once, since it is the same for every view.
    viewmat = torch.eye(4, device=device)[None]

    rendered_rgbs = []
    rendered_depths = []
    accs = []
    for i in range(batch_size):
        rgbd, acc, _ = rasterization(
            pts3d[i],
            quats,
            scales,
            opacities,
            rgbs[i],
            viewmat,
            intrinsics[[i]],
            width=width,
            height=height,
            packed=False,
            render_mode="RGB+D",
        )
        # In "RGB+D" mode the last dimension holds RGB followed by depth.
        rendered_rgbs.append(rgbd[..., :3])
        rendered_depths.append(rgbd[..., 3])
        accs.append(acc)

    rendered_rgbs = torch.cat(rendered_rgbs, dim=0)
    rendered_depths = torch.cat(rendered_depths, dim=0)
    accs = torch.cat(accs, dim=0)
    return rendered_rgbs, rendered_depths, accs
def get_render_results(gts, preds, self_view=False):
    """Render predicted and ground-truth point maps to depth maps, view by view.

    For each ``(gt, pred)`` pair, both point sets are splatted with ``render``,
    using the ground-truth image as colors. The ground-truth points are first
    transformed by ``inv(camera_pose)`` of the reference view. The prediction's
    ``"pts3d_in_other_view"`` is rendered without any transform. Gradients are
    disabled throughout.

    Args:
        gts: per-view dicts that provide "camera_pose", "camera_intrinsics",
            "img" and "pts3d".
        preds: per-view dicts that provide "pts3d_in_other_view". The device of
            the first entry is used for all rendering.
        self_view: if True, each view uses its own pose and intrinsics.
            Otherwise every view uses those of ``gts[0]``.

    Returns:
        ``(depths, gt_depths)``: lists with one rendered depth tensor per view.
        They stop at the shorter of ``gts`` and ``preds``.
    """
    device = preds[0]["pts3d_in_other_view"].device
    depths = []
    gt_depths = []
    with torch.no_grad():
        for gt, pred in zip(gts, preds):
            # The view whose camera both point maps are rendered from.
            ref = gt if self_view else gts[0]
            # NOTE(review): when self_view=True, the GT is moved into this
            # view's frame, but the prediction is still "pts3d_in_other_view".
            # Confirm that the prediction is meant to be rendered in that frame.
            camera = inv(ref["camera_pose"]).to(device)
            intrinsics = ref["camera_intrinsics"].to(device)
            pred_pts3d = pred["pts3d_in_other_view"]
            gt_img = gt["img"].to(device)
            gt_pts3d = gt["pts3d"].to(device)
            _, depth, _ = render(intrinsics, pred_pts3d, gt_img)
            _, gt_depth, _ = render(intrinsics, geotrf(camera, gt_pts3d), gt_img)
            depths.append(depth)
            gt_depths.append(gt_depth)
    return depths, gt_depths