from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

from croco.models.blocks import Mlp

inf = float("inf")


class PoseDecoder(nn.Module):
    def __init__(
        self,
        hidden_size=768,
        mlp_ratio=4,
        pose_encoding_type="absT_quaR",
    ):
        super().__init__()

        self.pose_encoding_type = pose_encoding_type
        if self.pose_encoding_type == "absT_quaR":
            self.target_dim = 7

        self.mlp = Mlp(
            in_features=hidden_size,
            hidden_features=int(hidden_size * mlp_ratio),
            out_features=self.target_dim,
            drop=0,
        )

    def forward(
        self,
        pose_feat,
    ):
        """
        Args:
            pose_feat: BxC pose features.

        Returns:
            Bx7 pose encodings (3 for absT, 4 for quaR) for cameras in
            OpenCV convention.
        """
        pred_cameras = self.mlp(pose_feat)  # Bx7, 3 for absT, 4 for quaR
        return pred_cameras
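

# Illustrative usage sketch (not part of the original API; assumes the `croco`
# dependency imported above is available):
def _example_pose_decoder():
    decoder = PoseDecoder(hidden_size=768)
    pose_feat = torch.randn(4, 768)  # BxC pose features
    pred = decoder(pose_feat)
    assert pred.shape == (4, 7)  # 3 translation + 4 quaternion components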


class PoseEncoder(nn.Module):
    def __init__(
        self,
        hidden_size=768,
        mlp_ratio=4,
        pose_mode=("exp", -inf, inf),
        pose_encoding_type="absT_quaR",
    ):
        super().__init__()

        self.pose_encoding_type = pose_encoding_type
        self.pose_mode = pose_mode
        if self.pose_encoding_type == "absT_quaR":
            self.target_dim = 7

        self.embed_pose = PoseEmbedding(
            target_dim=self.target_dim,
            out_dim=hidden_size,
            n_harmonic_functions=10,
            append_input=True,
        )
        self.pose_encoder = Mlp(
            in_features=self.embed_pose.out_dim,
            hidden_features=int(hidden_size * mlp_ratio),
            out_features=hidden_size,
            drop=0,
        )

    def forward(self, camera):
        # Imported locally, likely to avoid a circular import with dust3r.
        from dust3r.heads.postprocess import postprocess_pose

        # Bx4x4 cam2world matrices -> Bx7 absT_quaR encodings.
        pose_enc = camera_to_pose_encoding(
            camera,
            pose_encoding_type=self.pose_encoding_type,
        ).to(camera.dtype)
        # Undo the pose normalization (e.g. "exp" mode) before embedding.
        pose_enc = postprocess_pose(pose_enc, self.pose_mode, inverse=True)
        pose_feat = self.embed_pose(pose_enc)
        pose_feat = self.pose_encoder(pose_feat)
        return pose_feat
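

# Illustrative usage sketch (assumes the `croco` and `dust3r` dependencies are
# available; tensor values are arbitrary):
def _example_pose_encoder():
    encoder = PoseEncoder(hidden_size=768)
    cameras = torch.eye(4).repeat(4, 1, 1)  # Bx4x4 cam2world matrices
    cameras[:, :3, 3] = torch.tensor([1.0, 2.0, 3.0])
    pose_feat = encoder(cameras)
    assert pose_feat.shape == (4, 768)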


class HarmonicEmbedding(torch.nn.Module):
    def __init__(
        self,
        n_harmonic_functions: int = 6,
        omega_0: float = 1.0,
        logspace: bool = True,
        append_input: bool = True,
    ) -> None:
| """ | |
| The harmonic embedding layer supports the classical | |
| Nerf positional encoding described in | |
| `NeRF <https://arxiv.org/abs/2003.08934>`_ | |
| and the integrated position encoding in | |
| `MIP-NeRF <https://arxiv.org/abs/2103.13415>`_. | |
| During the inference you can provide the extra argument `diag_cov`. | |
| If `diag_cov is None`, it converts | |
| rays parametrized with a `ray_bundle` to 3D points by | |
| extending each ray according to the corresponding length. | |
| Then it converts each feature | |
| (i.e. vector along the last dimension) in `x` | |
| into a series of harmonic features `embedding`, | |
| where for each i in range(dim) the following are present | |
| in embedding[...]:: | |
| [ | |
| sin(f_1*x[..., i]), | |
| sin(f_2*x[..., i]), | |
| ... | |
| sin(f_N * x[..., i]), | |
| cos(f_1*x[..., i]), | |
| cos(f_2*x[..., i]), | |
| ... | |
| cos(f_N * x[..., i]), | |
| x[..., i], # only present if append_input is True. | |
| ] | |
| where N corresponds to `n_harmonic_functions-1`, and f_i is a scalar | |
| denoting the i-th frequency of the harmonic embedding. | |
| If `diag_cov is not None`, it approximates | |
| conical frustums following a ray bundle as gaussians, | |
| defined by x, the means of the gaussians and diag_cov, | |
| the diagonal covariances. | |
| Then it converts each gaussian | |
| into a series of harmonic features `embedding`, | |
| where for each i in range(dim) the following are present | |
| in embedding[...]:: | |
| [ | |
| sin(f_1*x[..., i]) * exp(0.5 * f_1**2 * diag_cov[..., i,]), | |
| sin(f_2*x[..., i]) * exp(0.5 * f_2**2 * diag_cov[..., i,]), | |
| ... | |
| sin(f_N * x[..., i]) * exp(0.5 * f_N**2 * diag_cov[..., i,]), | |
| cos(f_1*x[..., i]) * exp(0.5 * f_1**2 * diag_cov[..., i,]), | |
| cos(f_2*x[..., i]) * exp(0.5 * f_2**2 * diag_cov[..., i,]),, | |
| ... | |
| cos(f_N * x[..., i]) * exp(0.5 * f_N**2 * diag_cov[..., i,]), | |
| x[..., i], # only present if append_input is True. | |
| ] | |
| where N equals `n_harmonic_functions-1`, and f_i is a scalar | |
| denoting the i-th frequency of the harmonic embedding. | |
| If `logspace==True`, the frequencies `[f_1, ..., f_N]` are | |
| powers of 2: | |
| `f_1, ..., f_N = 2**torch.arange(n_harmonic_functions)` | |
| If `logspace==False`, frequencies are linearly spaced between | |
| `1.0` and `2**(n_harmonic_functions-1)`: | |
| `f_1, ..., f_N = torch.linspace( | |
| 1.0, 2**(n_harmonic_functions-1), n_harmonic_functions | |
| )` | |
| Note that `x` is also premultiplied by the base frequency `omega_0` | |
| before evaluating the harmonic functions. | |
| Args: | |
| n_harmonic_functions: int, number of harmonic | |
| features | |
| omega_0: float, base frequency | |
| logspace: bool, Whether to space the frequencies in | |
| logspace or linear space | |
| append_input: bool, whether to concat the original | |
| input to the harmonic embedding. If true the | |
| output is of the form (embed.sin(), embed.cos(), x) | |
| """ | |
        super().__init__()

        if logspace:
            frequencies = 2.0 ** torch.arange(n_harmonic_functions, dtype=torch.float32)
        else:
            frequencies = torch.linspace(
                1.0,
                2.0 ** (n_harmonic_functions - 1),
                n_harmonic_functions,
                dtype=torch.float32,
            )
        self.register_buffer("_frequencies", frequencies * omega_0, persistent=False)
        # Phase offsets that turn a single sin() evaluation into (sin, cos)
        # pairs, since sin(t + pi/2) == cos(t).
        self.register_buffer(
            "_zero_half_pi",
            torch.tensor([0.0, 0.5 * torch.pi]),
            persistent=False,
        )
        self.append_input = append_input

    def forward(
        self, x: torch.Tensor, diag_cov: Optional[torch.Tensor] = None, **kwargs
    ) -> torch.Tensor:
        """
        Args:
            x: tensor of shape [..., dim]
            diag_cov: An optional tensor of shape `(..., dim)`
                representing the diagonal covariance matrices of our Gaussians,
                joined with x as means of the Gaussians.

        Returns:
            embedding: a harmonic embedding of `x` of shape
            [..., (n_harmonic_functions * 2 + int(append_input)) * dim]
        """
        # [..., dim, n_harmonic_functions]
        embed = x[..., None] * self._frequencies
        # Broadcast against the [0, pi/2] phase offsets so that one sin() call
        # yields both sines and cosines: [..., 2, dim, n_harmonic_functions].
        embed = embed[..., None, :, :] + self._zero_half_pi[..., None, None]
        embed = embed.sin()
        if diag_cov is not None:
            # MIP-NeRF integrated positional encoding: attenuate each frequency
            # by the variance of the Gaussian.
            x_var = diag_cov[..., None] * torch.pow(self._frequencies, 2)
            exp_var = torch.exp(-0.5 * x_var)
            embed = embed * exp_var[..., None, :, :]

        embed = embed.reshape(*x.shape[:-1], -1)

        if self.append_input:
            return torch.cat([embed, x], dim=-1)
        return embed

    @staticmethod
    def get_output_dim_static(
        input_dims: int, n_harmonic_functions: int, append_input: bool
    ) -> int:
        """
        Utility to help predict the shape of the output of `forward`.

        Args:
            input_dims: length of the last dimension of the input tensor
            n_harmonic_functions: number of embedding frequencies
            append_input: whether or not to concat the original
                input to the harmonic embedding

        Returns:
            int: the length of the last dimension of the output tensor
        """
        return input_dims * (2 * n_harmonic_functions + int(append_input))

    def get_output_dim(self, input_dims: int = 3) -> int:
        """
        Same as above. The default for input_dims is 3 for 3D applications
        which use harmonic embedding for positional encoding,
        so the input might be xyz.
        """
        return self.get_output_dim_static(
            input_dims, len(self._frequencies), self.append_input
        )
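

# Illustrative shape check (not part of the original API; only needs `torch`):
def _example_harmonic_embedding():
    emb = HarmonicEmbedding(n_harmonic_functions=10, append_input=True)
    x = torch.randn(2, 7)
    out = emb(x)
    # 7 input dims * (2 * 10 harmonics + 1 appended input) = 147 features
    assert emb.get_output_dim(7) == 147
    assert out.shape == (2, 147)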


class PoseEmbedding(nn.Module):
    def __init__(self, target_dim, out_dim, n_harmonic_functions=10, append_input=True):
        super().__init__()

        self._emb_pose = HarmonicEmbedding(
            n_harmonic_functions=n_harmonic_functions, append_input=append_input
        )
        # Note: the `out_dim` argument is unused; the actual output width is
        # determined by the harmonic embedding and exposed as `self.out_dim`.
        self.out_dim = self._emb_pose.get_output_dim(target_dim)

    def forward(self, pose_encoding):
        e_pose_encoding = self._emb_pose(pose_encoding)
        return e_pose_encoding
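

# Illustrative check: for the 7-D absT_quaR encoding with 10 harmonics the
# embedding width is 7 * (2 * 10 + 1) = 147 (the `out_dim` argument does not
# change it).
def _example_pose_embedding():
    emb = PoseEmbedding(target_dim=7, out_dim=768)
    assert emb.out_dim == 147
    assert emb(torch.randn(4, 7)).shape == (4, 147)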


def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor:
    """
    Returns torch.sqrt(torch.max(0, x))
    but with a zero subgradient where x is 0.
    """
    ret = torch.zeros_like(x)
    positive_mask = x > 0
    ret[positive_mask] = torch.sqrt(x[positive_mask])
    return ret
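

# Illustrative check (not part of the original API): values and gradients are
# zero wherever x <= 0, avoiding the infinite gradient of sqrt at 0.
def _example_sqrt_positive_part():
    x = torch.tensor([-1.0, 0.0, 4.0], requires_grad=True)
    y = _sqrt_positive_part(x)
    y.sum().backward()
    assert torch.allclose(y, torch.tensor([0.0, 0.0, 2.0]))
    assert torch.allclose(x.grad, torch.tensor([0.0, 0.0, 0.25]))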


def matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor:
    """
    Convert rotations given as rotation matrices to quaternions.

    Args:
        matrix: Rotation matrices as tensor of shape (..., 3, 3).

    Returns:
        Quaternions with real part first, as tensor of shape (..., 4).
    """
    if matrix.size(-1) != 3 or matrix.size(-2) != 3:
        raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")

    batch_dim = matrix.shape[:-2]
    m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(
        matrix.reshape(batch_dim + (9,)), dim=-1
    )

    # |r|, |i|, |j|, |k| of the quaternion, computed from the matrix diagonal.
    q_abs = _sqrt_positive_part(
        torch.stack(
            [
                1.0 + m00 + m11 + m22,
                1.0 + m00 - m11 - m22,
                1.0 - m00 + m11 - m22,
                1.0 - m00 - m11 + m22,
            ],
            dim=-1,
        )
    )

    # Four candidate quaternions, each numerically reliable when the
    # corresponding entry of q_abs is large.
    quat_by_rijk = torch.stack(
        [
            torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1),
            torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1),
            torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1),
            torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1),
        ],
        dim=-2,
    )

    # Clamp the denominator away from zero; ill-conditioned candidates are
    # never selected by the argmax below anyway.
    flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device)
    quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr))

    # For each batch element, pick the best-conditioned candidate.
    out = quat_candidates[
        F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, :
    ].reshape(batch_dim + (4,))
    return standardize_quaternion(out)
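

# Illustrative check: the identity rotation maps to the identity quaternion
# (real part first).
def _example_matrix_to_quaternion():
    q = matrix_to_quaternion(torch.eye(3)[None])
    assert torch.allclose(q, torch.tensor([[1.0, 0.0, 0.0, 0.0]]))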


def standardize_quaternion(quaternions: torch.Tensor) -> torch.Tensor:
    """
    Convert a unit quaternion to a standard form: one in which the real
    part is non-negative.

    Args:
        quaternions: Quaternions with real part first,
            as tensor of shape (..., 4).

    Returns:
        Standardized quaternions as tensor of shape (..., 4).
    """
    quaternions = F.normalize(quaternions, p=2, dim=-1)
    return torch.where(quaternions[..., 0:1] < 0, -quaternions, quaternions)
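

# Illustrative check: q and -q represent the same rotation; standardization
# normalizes and picks the representative with a non-negative real part.
def _example_standardize_quaternion():
    q = standardize_quaternion(torch.tensor([[-2.0, 0.0, 0.0, 0.0]]))
    assert torch.allclose(q, torch.tensor([[1.0, 0.0, 0.0, 0.0]]))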


def camera_to_pose_encoding(
    camera,
    pose_encoding_type="absT_quaR",
):
    """
    Inverse of pose_encoding_to_camera.

    Args:
        camera: Bx4x4 cam2world matrices in OpenCV convention.

    Returns:
        Bx7 pose encodings: 3 for absT followed by 4 for quaR.
    """
    if pose_encoding_type == "absT_quaR":
        quaternion_R = matrix_to_quaternion(camera[:, :3, :3])
        pose_encoding = torch.cat([camera[:, :3, 3], quaternion_R], dim=-1)
    else:
        raise ValueError(f"Unknown pose encoding {pose_encoding_type}")
    return pose_encoding
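

# Illustrative check: an identity cam2world matrix encodes to zero translation
# plus the identity quaternion.
def _example_camera_to_pose_encoding():
    enc = camera_to_pose_encoding(torch.eye(4)[None])
    assert torch.allclose(enc, torch.tensor([[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0]]))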


def quaternion_to_matrix(quaternions: torch.Tensor) -> torch.Tensor:
    """
    Convert rotations given as quaternions to rotation matrices.

    Args:
        quaternions: quaternions with real part first,
            as tensor of shape (..., 4).

    Returns:
        Rotation matrices as tensor of shape (..., 3, 3).
    """
    r, i, j, k = torch.unbind(quaternions, -1)
    # Dividing by the squared norm keeps the formula valid for
    # non-normalized quaternions.
    two_s = 2.0 / (quaternions * quaternions).sum(-1)

    o = torch.stack(
        (
            1 - two_s * (j * j + k * k),
            two_s * (i * j - k * r),
            two_s * (i * k + j * r),
            two_s * (i * j + k * r),
            1 - two_s * (i * i + k * k),
            two_s * (j * k - i * r),
            two_s * (i * k - j * r),
            two_s * (j * k + i * r),
            1 - two_s * (i * i + j * j),
        ),
        -1,
    )
    return o.reshape(quaternions.shape[:-1] + (3, 3))
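

# Illustrative round trip: matrix -> quaternion -> matrix recovers the input
# for a valid rotation (here a 90-degree rotation about z).
def _example_quaternion_round_trip():
    R = torch.tensor([[[0.0, -1.0, 0.0],
                       [1.0, 0.0, 0.0],
                       [0.0, 0.0, 1.0]]])
    assert torch.allclose(quaternion_to_matrix(matrix_to_quaternion(R)), R, atol=1e-6)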


def pose_encoding_to_camera(
    pose_encoding,
    pose_encoding_type="absT_quaR",
):
    """
    Args:
        pose_encoding: A tensor of shape `BxC`, containing a batch of
            `B` `C`-dimensional pose encodings.
        pose_encoding_type: The type of pose encoding.

    Returns:
        Bx4x4 cam2world matrices.
    """
    if pose_encoding_type == "absT_quaR":
        abs_T = pose_encoding[:, :3]
        quaternion_R = pose_encoding[:, 3:7]
        R = quaternion_to_matrix(quaternion_R)
    else:
        raise ValueError(f"Unknown pose encoding {pose_encoding_type}")

    c2w_mats = torch.eye(4, 4).to(R.dtype).to(R.device)
    c2w_mats = c2w_mats[None].repeat(len(R), 1, 1)
    c2w_mats[:, :3, :3] = R
    c2w_mats[:, :3, 3] = abs_T

    return c2w_mats
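

# Illustrative round trip: camera -> encoding -> camera reproduces the
# original cam2world matrix.
def _example_pose_encoding_round_trip():
    c2w = torch.eye(4)[None]
    c2w[:, :3, 3] = torch.tensor([1.0, 2.0, 3.0])
    enc = camera_to_pose_encoding(c2w)
    assert torch.allclose(pose_encoding_to_camera(enc), c2w, atol=1e-6)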


def quaternion_conjugate(q):
    """Compute the conjugate of quaternion q (w, x, y, z)."""
    # For unit quaternions the conjugate equals the inverse.
    q_conj = torch.cat([q[..., :1], -q[..., 1:]], dim=-1)
    return q_conj


def quaternion_multiply(q1, q2):
    """Multiply two quaternions q1 and q2 (Hamilton product, w-first)."""
    w1, x1, y1, z1 = q1.unbind(dim=-1)
    w2, x2, y2, z2 = q2.unbind(dim=-1)
    w = w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2
    x = w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2
    y = w1 * y2 - x1 * z2 + y1 * w2 + z1 * x2
    z = w1 * z2 + x1 * y2 - y1 * x2 + z1 * w2
    return torch.stack((w, x, y, z), dim=-1)


def rotate_vector(q, v):
    """Rotate vector v by unit quaternion q (w, x, y, z)."""
    q_vec = q[..., 1:]
    q_w = q[..., :1]
    # Efficient form of q * v * q^{-1} for unit quaternions:
    # v' = v + 2 * q_w * (q_vec x v) + 2 * q_vec x (q_vec x v)
    t = 2.0 * torch.cross(q_vec, v, dim=-1)
    v_rot = v + q_w * t + torch.cross(q_vec, t, dim=-1)
    return v_rot
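

# Illustrative check: a 90-degree rotation about z maps the x axis to the
# y axis (q must be a unit quaternion).
def _example_rotate_vector():
    q = torch.tensor([[0.5**0.5, 0.0, 0.0, 0.5**0.5]])  # w, x, y, z
    v = torch.tensor([[1.0, 0.0, 0.0]])
    assert torch.allclose(rotate_vector(q, v), torch.tensor([[0.0, 1.0, 0.0]]), atol=1e-6)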


def relative_pose_absT_quatR(t1, q1, t2, q2):
    """Compute the relative translation and quaternion between two poses,
    expressed in the frame of pose 1 (assumes unit quaternions)."""
    q1_inv = quaternion_conjugate(q1)
    q_rel = quaternion_multiply(q1_inv, q2)
    delta_t = t2 - t1
    t_rel = rotate_vector(q1_inv, delta_t)
    return t_rel, q_rel
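

# Illustrative check: the relative pose of a pose with respect to itself is
# zero translation and the identity quaternion.
def _example_relative_pose():
    t = torch.tensor([[1.0, 2.0, 3.0]])
    q = torch.tensor([[0.5**0.5, 0.0, 0.0, 0.5**0.5]])
    t_rel, q_rel = relative_pose_absT_quatR(t, q, t, q)
    assert torch.allclose(t_rel, torch.zeros(1, 3), atol=1e-6)
    assert torch.allclose(q_rel, torch.tensor([[1.0, 0.0, 0.0, 0.0]]), atol=1e-6)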