Spaces:
Running
on
Zero
Running
on
Zero
| # Copyright (C) 2024-present Naver Corporation. All rights reserved. | |
| # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). | |
| # | |
| # -------------------------------------------------------- | |
| # modified from DUSt3R | |
| import numpy as np | |
| import torch | |
| from dust3r.utils.geometry import xy_grid | |
def estimate_focal_knowing_depth(
    pts3d, pp, focal_mode="median", min_focal=0.0, max_focal=np.inf
):
    """Estimate one focal length per batch item from a 3D point map.

    The focal is the scalar ``f`` that best explains
    ``pixel - pp ~= f * (x, y) / z`` over all pixels of an image, using one
    of two robust estimators.

    Args:
        pts3d: tensor whose shape unpacks as (B, H, W, 3); the last axis is
            read as (x, y, z).
        pp: principal point(s), reshaped internally to (-1, 1, 2) and
            subtracted from the pixel grid.
        focal_mode: ``"median"`` takes the NaN-ignoring median of the
            per-pixel votes ``u * z / x`` and ``v * z / y`` (computed under
            ``torch.no_grad()``); ``"weiszfeld"`` starts from the
            least-squares closed form and refines it with 10 rounds of
            re-weighted least squares.
        min_focal: lower clipping bound, as a multiple of the base focal.
        max_focal: upper clipping bound, as a multiple of the base focal.

    Returns:
        The estimated focal(s), clipped to
        ``[min_focal * focal_base, max_focal * focal_base]`` where
        ``focal_base`` is the focal giving a 60-degree field of view across
        ``max(H, W)``.

    Raises:
        ValueError: if ``focal_mode`` is not one of the two supported modes.
    """
    batch, height, width, channels = pts3d.shape
    assert channels == 3

    # Pixel grid expressed relative to the principal point -> (B, HW, 2).
    # NOTE(review): xy_grid is presumably an (H, W, 2) grid of (x, y) pixel
    # coordinates -- confirm against dust3r.utils.geometry.
    grid = xy_grid(width, height, device=pts3d.device)
    centered_px = grid.view(1, -1, 2) - pp.view(-1, 1, 2)
    points = pts3d.flatten(1, 2)  # (B, HW, 3)

    if focal_mode == "median":
        with torch.no_grad():
            px_u, px_v = centered_px.unbind(dim=-1)
            pt_x, pt_y, pt_z = points.unbind(dim=-1)
            votes_x = (px_u * pt_z) / pt_x
            votes_y = (px_v * pt_z) / pt_y
            # A single focal is shared by both axes, so pool the x and y
            # votes before taking the median.
            pooled_votes = torch.cat(
                (votes_x.view(batch, -1), votes_y.view(batch, -1)), dim=-1
            )
            focal = torch.nanmedian(pooled_votes, dim=-1).values
    elif focal_mode == "weiszfeld":
        # (x/z, y/z) per pixel; NaN and +/-inf from the division become 0.
        ratios = points[..., :2] / points[..., 2:3]
        normalized_xy = ratios.nan_to_num(posinf=0, neginf=0)

        numer = (normalized_xy * centered_px).sum(dim=-1)
        denom = normalized_xy.square().sum(dim=-1)

        # Closed-form least-squares initialisation.
        focal = numer.mean(dim=1) / denom.mean(dim=1)

        # Iteratively re-weighted least squares: each pixel is weighted by
        # the inverse of its current residual norm (floored at 1e-8).
        for _ in range(10):
            residual = centered_px - focal.view(-1, 1, 1) * normalized_xy
            weights = residual.norm(dim=-1).clip(min=1e-8).reciprocal()
            focal = (weights * numer).mean(dim=1) / (weights * denom).mean(dim=1)
    else:
        raise ValueError(f"bad {focal_mode=}")

    # Focal giving a 60-degree field of view over the larger image side,
    # i.e. size / 1.1547005383792515.
    half_fov = np.deg2rad(60) / 2
    focal_base = max(height, width) / (2 * np.tan(half_fov))
    lower_bound = min_focal * focal_base
    upper_bound = max_focal * focal_base
    return focal.clip(min=lower_bound, max=upper_bound)