# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# modified from DUSt3R
import torch
import torch.nn.functional as F


def postprocess(out, depth_mode, conf_mode, pos_z=False):
    """
    extract 3D points/confidence from prediction head output
    """
    fmap = out.permute(0, 2, 3, 1)  # B,H,W,C with C = 3 (xyz) + 1 optional confidence channel
    res = dict(pts3d=reg_dense_depth(fmap[:, :, :, 0:3], mode=depth_mode, pos_z=pos_z))
    if conf_mode is not None:
        res["conf"] = reg_dense_conf(fmap[:, :, :, 3], mode=conf_mode)
    return res
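

# Illustrative sketch, not part of the original module: how `postprocess` can be
# called on a raw head output. The 4-channel layout (xyz + conf) and the mode
# tuples below are assumptions mirroring the defaults used elsewhere in this file.
def _example_postprocess():
    out = torch.randn(2, 4, 32, 32)  # B, C, H, W with C = 3 (xyz) + 1 (conf)
    res = postprocess(
        out,
        depth_mode=("exp", -float("inf"), float("inf")),
        conf_mode=("exp", 1, float("inf")),
    )
    return res["pts3d"].shape, res["conf"].shape  # (2, 32, 32, 3), (2, 32, 32)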


def postprocess_rgb(out, eps=1e-6):
    fmap = out.permute(0, 2, 3, 1)  # B,H,W,3
    res = torch.sigmoid(fmap) * (1 - 2 * eps) + eps  # squash into (eps, 1 - eps)
    res = (res - 0.5) * 2  # rescale to (-1, 1)
    return dict(rgb=res)
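

# Illustrative sketch (assumption, not in the original file): `postprocess_rgb`
# squashes raw logits into the open interval (-1, 1), matching images normalized
# to [-1, 1]. A quick shape/range check:
def _example_postprocess_rgb():
    out = torch.randn(2, 3, 32, 32)  # B, 3, H, W raw head output
    rgb = postprocess_rgb(out)["rgb"]  # B, H, W, 3 in (-1, 1)
    assert rgb.min() > -1 and rgb.max() < 1
    return rgb.shape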


def postprocess_pose(out, mode, inverse=False):
    """
    extract pose (3D translation + unit quaternion) from prediction head output
    """
    mode, vmin, vmax = mode
    no_bounds = (vmin == -float("inf")) and (vmax == float("inf"))
    assert no_bounds
    trans = out[..., 0:3]
    quats = out[..., 3:7]
    if mode == "linear":
        if no_bounds:
            return trans  # [-inf, +inf]; note: linear mode returns the translation only
        return trans.clip(min=vmin, max=vmax)
    # remap the translation norm d while keeping its direction
    d = trans.norm(dim=-1, keepdim=True)
    if mode == "square":
        if inverse:
            scale = d / d.square().clip(min=1e-8)
        else:
            scale = d.square() / d.clip(min=1e-8)
    elif mode == "exp":
        if inverse:
            scale = d / torch.expm1(d).clip(min=1e-8)
        else:
            scale = torch.expm1(d) / d.clip(min=1e-8)
    else:
        raise ValueError(f"bad {mode=}")
    trans = trans * scale
    quats = standardize_quaternion(quats)
    return torch.cat([trans, quats], dim=-1)
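

# Illustrative sketch, not part of the original module: decoding a raw 7-D pose
# (3 translation + 4 quaternion components) with the "exp" mode. The mode tuple
# and the unbounded (-inf, +inf) range are assumptions mirroring the depth modes
# used above.
def _example_postprocess_pose():
    out = torch.randn(2, 7)  # B, 7: tx, ty, tz, qw, qx, qy, qz
    pose = postprocess_pose(out, mode=("exp", -float("inf"), float("inf")))
    return pose.shape  # (2, 7): rescaled translation + unit quaternion with qw >= 0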


def postprocess_pose_conf(out):
    fmap = out.permute(0, 2, 3, 1)  # B,H,W,1
    return dict(pose_conf=torch.sigmoid(fmap))


def postprocess_desc(out, depth_mode, conf_mode, desc_dim, double_channel=False):
    """
    extract 3D points / confidence / descriptors from prediction head output
    """
    fmap = out.permute(0, 2, 3, 1)  # B,H,W,C
    # channel layout: 3 (pts3d) [+ 1 conf] [+ 3 pts3d_self + 1 conf_self] + desc_dim + 1 (desc_conf)
    res = dict(pts3d=reg_dense_depth(fmap[:, :, :, 0:3], mode=depth_mode))
    if conf_mode is not None:
        res["conf"] = reg_dense_conf(fmap[:, :, :, 3], mode=conf_mode)
    if double_channel:
        res["pts3d_self"] = reg_dense_depth(
            fmap[
                :, :, :, 3 + int(conf_mode is not None) : 6 + int(conf_mode is not None)
            ],
            mode=depth_mode,
        )
        if conf_mode is not None:
            res["conf_self"] = reg_dense_conf(
                fmap[:, :, :, 6 + int(conf_mode is not None)], mode=conf_mode
            )
    start = (
        3
        + int(conf_mode is not None)
        + int(double_channel) * (3 + int(conf_mode is not None))
    )
    res["desc"] = reg_desc(fmap[:, :, :, start : start + desc_dim], mode="norm")
    res["desc_conf"] = reg_dense_conf(fmap[:, :, :, start + desc_dim], mode=conf_mode)
    assert start + desc_dim + 1 == fmap.shape[-1]
    return res
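

# Illustrative sketch, not part of the original module: the channel layout that
# `postprocess_desc` expects, here with confidence enabled, double_channel=True,
# and a hypothetical 24-D descriptor: 3 (pts3d) + 1 (conf) + 3 (pts3d_self)
# + 1 (conf_self) + 24 (desc) + 1 (desc_conf) = 33 channels.
def _example_postprocess_desc():
    out = torch.randn(2, 33, 32, 32)  # B, C, H, W
    res = postprocess_desc(
        out,
        depth_mode=("exp", -float("inf"), float("inf")),
        conf_mode=("exp", 1, float("inf")),
        desc_dim=24,
        double_channel=True,
    )
    return sorted(res)  # ['conf', 'conf_self', 'desc', 'desc_conf', 'pts3d', 'pts3d_self']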


def reg_desc(desc, mode="norm"):
    if "norm" in mode:
        desc = desc / desc.norm(dim=-1, keepdim=True)
    else:
        raise ValueError(f"Unknown desc mode {mode}")
    return desc


def reg_dense_depth(xyz, mode, pos_z=False):
    """
    extract 3D points from prediction head output
    """
    mode, vmin, vmax = mode
    no_bounds = (vmin == -float("inf")) and (vmax == float("inf"))
    assert no_bounds
    if mode == "linear":
        if no_bounds:
            return xyz  # [-inf, +inf]
        return xyz.clip(min=vmin, max=vmax)
    if pos_z:
        # flip points predicted behind the camera so that z is always positive
        sign = torch.sign(xyz[..., -1:])
        xyz = xyz * sign
    # keep the direction of each prediction and remap its norm d
    d = xyz.norm(dim=-1, keepdim=True)
    xyz = xyz / d.clip(min=1e-8)
    if mode == "square":
        return xyz * d.square()
    if mode == "exp":
        return xyz * torch.expm1(d)
    raise ValueError(f"bad {mode=}")
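

# Illustrative check (assumption: the "exp" depth mode): reg_dense_depth keeps the
# direction of each prediction and remaps its norm r to expm1(r), so small raw
# norms stay small while large ones can grow without bound.
def _example_reg_dense_depth():
    xyz = torch.tensor([[[[0.0, 0.0, 2.0]]]])  # B=1, H=1, W=1, 3
    pts = reg_dense_depth(xyz, mode=("exp", -float("inf"), float("inf")))
    # norm 2.0 -> expm1(2.0) ~= 6.389, direction (0, 0, 1) unchanged
    return pts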


def reg_dense_conf(x, mode):
    """
    extract confidence from prediction head output
    """
    mode, vmin, vmax = mode
    if mode == "exp":
        return vmin + x.exp().clip(max=vmax - vmin)
    if mode == "sigmoid":
        return (vmax - vmin) * torch.sigmoid(x) + vmin
    raise ValueError(f"bad {mode=}")
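

# Illustrative check (assumption: the usual ("exp", 1, inf) confidence mode):
# reg_dense_conf maps a raw score x to vmin + exp(x), so confidences are always
# at least vmin (here 1.0).
def _example_reg_dense_conf():
    x = torch.tensor([-2.0, 0.0, 3.0])
    conf = reg_dense_conf(x, mode=("exp", 1, float("inf")))
    return conf  # 1 + exp(x): approximately [1.14, 2.00, 21.09]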


def standardize_quaternion(quaternions: torch.Tensor) -> torch.Tensor:
    """
    Convert a unit quaternion to a standard form: one in which the real
    part is non-negative.

    Args:
        quaternions: Quaternions with real part first,
            as tensor of shape (..., 4).

    Returns:
        Standardized quaternions as tensor of shape (..., 4).
    """
    quaternions = F.normalize(quaternions, p=2, dim=-1)
    return torch.where(quaternions[..., 0:1] < 0, -quaternions, quaternions)
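

# Illustrative check, not part of the original module: a quaternion and its
# negation describe the same rotation; standardize_quaternion normalizes and
# picks the representative with a non-negative real part.
def _example_standardize_quaternion():
    q = torch.tensor([[-1.0, 0.0, 0.0, 0.0], [0.0, 2.0, 0.0, 0.0]])
    return standardize_quaternion(q)  # [[1, 0, 0, 0], [0, 1, 0, 0]]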