| """ | |
| Multi-view geometric losses for training 3D reconstruction models. | |
| References: DUSt3R & MASt3R | |
| """ | |
| import math | |
| from copy import copy, deepcopy | |
| import einops as ein | |
| import torch | |
| import torch.nn as nn | |
| from mapanything.utils.geometry import ( | |
| angle_diff_vec3, | |
| apply_log_to_norm, | |
| closed_form_pose_inverse, | |
| convert_ray_dirs_depth_along_ray_pose_trans_quats_to_pointmap, | |
| geotrf, | |
| normalize_multiple_pointclouds, | |
| quaternion_inverse, | |
| quaternion_multiply, | |
| quaternion_to_rotation_matrix, | |
| transform_pose_using_quats_and_trans_2_to_1, | |
| ) | |


def get_loss_terms_and_details(
    losses_dict, valid_masks, self_name, n_views, flatten_across_image_only
):
    """
    Helper function to generate loss terms and details for different loss types.

    Args:
        losses_dict (dict): Dictionary mapping loss types to their values.
            Format: {
                'loss_type': {
                    'values': list_of_loss_tensors or single_tensor,
                    'use_mask': bool,
                    'is_multi_view': bool
                }
            }
        valid_masks (list): List of valid masks for each view.
        self_name (str): Name of the loss class.
        n_views (int): Number of views.
        flatten_across_image_only (bool): Whether flattening was done across image only.

    Returns:
        tuple: (loss_terms, details) where loss_terms is a list of tuples (loss, mask, type)
               and details is a dictionary of loss details.
    """
    loss_terms = []
    details = {}
    for loss_type, loss_info in losses_dict.items():
        values = loss_info["values"]
        use_mask = loss_info["use_mask"]
        is_multi_view = loss_info["is_multi_view"]
        if is_multi_view:
            # Handle multi-view losses (list of tensors)
            view_loss_details = []
            for i in range(n_views):
                mask = valid_masks[i] if use_mask else None
                loss_terms.append((values[i], mask, loss_type))
                # Add details for individual view
                if not flatten_across_image_only or not use_mask:
                    values_after_masking = values[i]
                else:
                    values_after_masking = values[i][mask]
                if values_after_masking.numel() > 0:
                    view_loss_detail = float(values_after_masking.mean())
                    if view_loss_detail > 0:
                        details[f"{self_name}_{loss_type}_view{i + 1}"] = view_loss_detail
                    view_loss_details.append(view_loss_detail)
            # Add average across views
            if len(view_loss_details) > 0:
                details[f"{self_name}_{loss_type}_avg"] = sum(view_loss_details) / len(
                    view_loss_details
                )
        else:
            # Handle single tensor losses
            if values is not None:
                loss_terms.append((values, None, loss_type))
                if values.numel() > 0:
                    loss_detail = float(values.mean())
                    if loss_detail > 0:
                        details[f"{self_name}_{loss_type}"] = loss_detail
    return loss_terms, details
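

# Illustrative sketch (not part of the original module): a minimal example of the
# `losses_dict` format consumed by get_loss_terms_and_details. The tensor shapes and the
# "pts3d"/"scale" keys below are arbitrary assumptions chosen for illustration only.
def _example_get_loss_terms_and_details():
    n_views = 2
    per_view_losses = [torch.rand(2, 8, 8) for _ in range(n_views)]  # per-pixel losses
    valid_masks = [torch.ones(2, 8, 8, dtype=torch.bool) for _ in range(n_views)]
    scale_loss = torch.rand(2).mean()  # a single scalar-style loss term
    losses_dict = {
        "pts3d": {"values": per_view_losses, "use_mask": True, "is_multi_view": True},
        "scale": {"values": scale_loss, "use_mask": False, "is_multi_view": False},
    }
    loss_terms, details = get_loss_terms_and_details(
        losses_dict, valid_masks, "ExampleLoss", n_views, flatten_across_image_only=False
    )
    return loss_terms, details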


def _smooth(err: torch.FloatTensor, beta: float = 0.0) -> torch.FloatTensor:
    if beta == 0:
        return err
    else:
        return torch.where(err < beta, 0.5 * err.square() / beta, err - 0.5 * beta)


def compute_normal_loss(points, gt_points, mask):
    """
    Compute the normal loss between the predicted and ground truth points.

    References:
        https://github.com/microsoft/MoGe/blob/a8c37341bc0325ca99b9d57981cc3bb2bd3e255b/moge/train/losses.py#L205

    Args:
        points (torch.Tensor): Predicted points. Shape: (..., H, W, 3).
        gt_points (torch.Tensor): Ground truth points. Shape: (..., H, W, 3).
        mask (torch.Tensor): Mask indicating valid points. Shape: (..., H, W).

    Returns:
        torch.Tensor: Normal loss.
    """
    height, width = points.shape[-3:-1]
    leftup, rightup, leftdown, rightdown = (
        points[..., :-1, :-1, :],
        points[..., :-1, 1:, :],
        points[..., 1:, :-1, :],
        points[..., 1:, 1:, :],
    )
    upxleft = torch.cross(rightup - rightdown, leftdown - rightdown, dim=-1)
    leftxdown = torch.cross(leftup - rightup, rightdown - rightup, dim=-1)
    downxright = torch.cross(leftdown - leftup, rightup - leftup, dim=-1)
    rightxup = torch.cross(rightdown - leftdown, leftup - leftdown, dim=-1)

    gt_leftup, gt_rightup, gt_leftdown, gt_rightdown = (
        gt_points[..., :-1, :-1, :],
        gt_points[..., :-1, 1:, :],
        gt_points[..., 1:, :-1, :],
        gt_points[..., 1:, 1:, :],
    )
    gt_upxleft = torch.cross(gt_rightup - gt_rightdown, gt_leftdown - gt_rightdown, dim=-1)
    gt_leftxdown = torch.cross(gt_leftup - gt_rightup, gt_rightdown - gt_rightup, dim=-1)
    gt_downxright = torch.cross(gt_leftdown - gt_leftup, gt_rightup - gt_leftup, dim=-1)
    gt_rightxup = torch.cross(gt_rightdown - gt_leftdown, gt_leftup - gt_leftdown, dim=-1)

    mask_leftup, mask_rightup, mask_leftdown, mask_rightdown = (
        mask[..., :-1, :-1],
        mask[..., :-1, 1:],
        mask[..., 1:, :-1],
        mask[..., 1:, 1:],
    )
    mask_upxleft = mask_rightup & mask_leftdown & mask_rightdown
    mask_leftxdown = mask_leftup & mask_rightdown & mask_rightup
    mask_downxright = mask_leftdown & mask_rightup & mask_leftup
    mask_rightxup = mask_rightdown & mask_leftup & mask_leftdown

    MIN_ANGLE, MAX_ANGLE, BETA_RAD = math.radians(1), math.radians(90), math.radians(3)
    loss = (
        mask_upxleft
        * _smooth(
            angle_diff_vec3(upxleft, gt_upxleft).clamp(MIN_ANGLE, MAX_ANGLE),
            beta=BETA_RAD,
        )
        + mask_leftxdown
        * _smooth(
            angle_diff_vec3(leftxdown, gt_leftxdown).clamp(MIN_ANGLE, MAX_ANGLE),
            beta=BETA_RAD,
        )
        + mask_downxright
        * _smooth(
            angle_diff_vec3(downxright, gt_downxright).clamp(MIN_ANGLE, MAX_ANGLE),
            beta=BETA_RAD,
        )
        + mask_rightxup
        * _smooth(
            angle_diff_vec3(rightxup, gt_rightxup).clamp(MIN_ANGLE, MAX_ANGLE),
            beta=BETA_RAD,
        )
    )
    total_valid_mask = mask_upxleft | mask_leftxdown | mask_downxright | mask_rightxup
    valid_count = total_valid_mask.sum()
    if valid_count > 0:
        loss = loss.sum() / (valid_count * (4 * max(points.shape[-3:-1])))
    else:
        loss = 0 * loss.sum()
    return loss
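

# Illustrative sketch (not part of the original module): calling compute_normal_loss on
# random pointmaps. The batch size and resolution below are arbitrary assumptions; the
# only requirement from the function above is (..., H, W, 3) points and a (..., H, W) mask.
def _example_compute_normal_loss():
    points = torch.rand(2, 16, 16, 3)
    gt_points = torch.rand(2, 16, 16, 3)
    mask = torch.rand(2, 16, 16) > 0.2  # boolean validity mask
    return compute_normal_loss(points, gt_points, mask)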


def compute_gradient_loss(prediction, gt_target, mask):
    """
    Compute the gradient loss between the prediction and GT target at valid points.

    References:
        https://docs.nerf.studio/_modules/nerfstudio/model_components/losses.html#GradientLoss
        https://github.com/autonomousvision/monosdf/blob/main/code/model/loss.py

    Args:
        prediction (torch.Tensor): Predicted scene representation. Shape: (B, H, W, C).
        gt_target (torch.Tensor): Ground truth scene representation. Shape: (B, H, W, C).
        mask (torch.Tensor): Mask indicating valid points. Shape: (B, H, W).
    """
    # Expand mask to match number of channels in prediction
    mask = mask[..., None].expand(-1, -1, -1, prediction.shape[-1])
    summed_mask = torch.sum(mask, (1, 2, 3))

    # Compute the gradient of the prediction and GT target
    diff = prediction - gt_target
    diff = torch.mul(mask, diff)

    # Gradient in x direction
    grad_x = torch.abs(diff[:, :, 1:] - diff[:, :, :-1])
    mask_x = torch.mul(mask[:, :, 1:], mask[:, :, :-1])
    grad_x = torch.mul(mask_x, grad_x)

    # Gradient in y direction
    grad_y = torch.abs(diff[:, 1:, :] - diff[:, :-1, :])
    mask_y = torch.mul(mask[:, 1:, :], mask[:, :-1, :])
    grad_y = torch.mul(mask_y, grad_y)

    # Clamp the outlier gradients
    grad_x = grad_x.clamp(max=100)
    grad_y = grad_y.clamp(max=100)

    # Compute the total loss
    image_loss = torch.sum(grad_x, (1, 2, 3)) + torch.sum(grad_y, (1, 2, 3))
    num_valid_pixels = torch.sum(summed_mask)
    if num_valid_pixels > 0:
        image_loss = torch.sum(image_loss) / num_valid_pixels
    else:
        image_loss = 0 * torch.sum(image_loss)
    return image_loss


def compute_gradient_matching_loss(prediction, gt_target, mask, scales=4):
    """
    Compute the multi-scale gradient matching loss between the prediction and GT target at valid points.
    This loss biases discontinuities to be sharp and to coincide with discontinuities in the ground truth.
    More info in MiDaS: https://arxiv.org/pdf/1907.01341.pdf; Equation 11

    References:
        https://docs.nerf.studio/_modules/nerfstudio/model_components/losses.html#GradientLoss
        https://github.com/autonomousvision/monosdf/blob/main/code/model/loss.py

    Args:
        prediction (torch.Tensor): Predicted scene representation. Shape: (B, H, W, C).
        gt_target (torch.Tensor): Ground truth scene representation. Shape: (B, H, W, C).
        mask (torch.Tensor): Mask indicating valid points. Shape: (B, H, W).
        scales (int): Number of scales to compute the loss at. Default: 4.
    """
    # Define total loss
    total_loss = 0.0
    # Compute the gradient loss at different scales
    for scale in range(scales):
        step = pow(2, scale)
        grad_loss = compute_gradient_loss(
            prediction[:, ::step, ::step],
            gt_target[:, ::step, ::step],
            mask[:, ::step, ::step],
        )
        total_loss += grad_loss
    return total_loss
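

# Illustrative sketch (not part of the original module): applying the multi-scale gradient
# matching loss to single-channel depth maps. The resolution and channel count are
# arbitrary assumptions; any (B, H, W, C) representation works the same way.
def _example_gradient_matching_loss():
    prediction = torch.rand(2, 32, 32, 1)
    gt_target = torch.rand(2, 32, 32, 1)
    mask = torch.rand(2, 32, 32) > 0.1  # boolean validity mask
    return compute_gradient_matching_loss(prediction, gt_target, mask, scales=4)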


def Sum(*losses_and_masks):
    """
    Aggregates multiple losses into a single loss value or returns the original losses.

    Args:
        *losses_and_masks: Variable number of tuples, each containing (loss, mask, rep_type)
            - loss: Tensor containing loss values
            - mask: Mask indicating valid pixels/regions
            - rep_type: String indicating the type of representation (e.g., 'pts3d', 'depth')

    Returns:
        If the first loss has dimensions > 0:
            Returns the original list of (loss, mask, rep_type) tuples
        Otherwise:
            Returns a scalar tensor that is the sum of all loss values
    """
    loss, mask, rep_type = losses_and_masks[0]
    if loss.ndim > 0:
        # We are actually returning the loss for every pixel
        return losses_and_masks
    else:
        # We are returning the global loss
        for loss2, mask2, rep_type2 in losses_and_masks[1:]:
            loss = loss + loss2
        return loss


class BaseCriterion(nn.Module):
    "Base Criterion to support different reduction methods"

    def __init__(self, reduction="mean"):
        super().__init__()
        self.reduction = reduction


class LLoss(BaseCriterion):
    "L-norm loss"

    def forward(self, a, b, **kwargs):
        assert a.shape == b.shape and a.ndim >= 2 and 1 <= a.shape[-1] <= 4, (
            f"Bad shape = {a.shape}"
        )
        dist = self.distance(a, b, **kwargs)
        assert dist.ndim == a.ndim - 1  # one dimension less
        if self.reduction == "none":
            return dist
        if self.reduction == "sum":
            return dist.sum()
        if self.reduction == "mean":
            return dist.mean() if dist.numel() > 0 else dist.new_zeros(())
        raise ValueError(f"bad {self.reduction=} mode")

    def distance(self, a, b, **kwargs):
        raise NotImplementedError()


class L1Loss(LLoss):
    "L1 distance"

    def distance(self, a, b, **kwargs):
        return torch.abs(a - b).sum(dim=-1)


class L2Loss(LLoss):
    "Euclidean (L2 Norm) distance"

    def distance(self, a, b, **kwargs):
        return torch.norm(a - b, dim=-1)


class GenericLLoss(LLoss):
    "Criterion that supports different L-norms"

    def distance(self, a, b, loss_type, **kwargs):
        if loss_type == "l1":
            # L1 distance
            return torch.abs(a - b).sum(dim=-1)
        elif loss_type == "l2":
            # Euclidean (L2 norm) distance
            return torch.norm(a - b, dim=-1)
        else:
            raise ValueError(
                f"Unsupported loss type: {loss_type}. Supported types are 'l1' and 'l2'."
            )


class FactoredLLoss(LLoss):
    "Criterion that supports different L-norms for the factored loss functions"

    def __init__(
        self,
        reduction="mean",
        points_loss_type="l2",
        depth_loss_type="l1",
        ray_directions_loss_type="l1",
        pose_quats_loss_type="l1",
        pose_trans_loss_type="l1",
        scale_loss_type="l1",
    ):
        super().__init__(reduction)
        self.points_loss_type = points_loss_type
        self.depth_loss_type = depth_loss_type
        self.ray_directions_loss_type = ray_directions_loss_type
        self.pose_quats_loss_type = pose_quats_loss_type
        self.pose_trans_loss_type = pose_trans_loss_type
        self.scale_loss_type = scale_loss_type

    def _distance(self, a, b, loss_type):
        if loss_type == "l1":
            # L1 distance
            return torch.abs(a - b).sum(dim=-1)
        elif loss_type == "l2":
            # Euclidean (L2 norm) distance
            return torch.norm(a - b, dim=-1)
        else:
            raise ValueError(f"Unsupported loss type: {loss_type}.")

    def distance(self, a, b, factor, **kwargs):
        if factor == "points":
            return self._distance(a, b, self.points_loss_type)
        elif factor == "depth":
            return self._distance(a, b, self.depth_loss_type)
        elif factor == "ray_directions":
            return self._distance(a, b, self.ray_directions_loss_type)
        elif factor == "pose_quats":
            return self._distance(a, b, self.pose_quats_loss_type)
        elif factor == "pose_trans":
            return self._distance(a, b, self.pose_trans_loss_type)
        elif factor == "scale":
            return self._distance(a, b, self.scale_loss_type)
        else:
            raise ValueError(f"Unsupported factor type: {factor}.")


class RobustRegressionLoss(LLoss):
    """
    Generalized Robust Loss introduced in https://arxiv.org/abs/1701.03077.
    """

    def __init__(self, alpha=0.5, scaling_c=0.25, reduction="mean"):
        """
        Initialize the Robust Regression Loss.

        Args:
            alpha (float): Shape parameter controlling the robustness of the loss.
                Lower values make the loss more robust to outliers. Default: 0.5.
            scaling_c (float): Scale parameter controlling the transition between
                quadratic and robust behavior. Default: 0.25.
            reduction (str): Specifies the reduction to apply to the output:
                'none' | 'mean' | 'sum'. Default: 'mean'.
        """
        super().__init__(reduction)
        self.alpha = alpha
        self.scaling_c = scaling_c

    def distance(self, a, b, **kwargs):
        error_scaled = torch.sum(((a - b) / self.scaling_c) ** 2, dim=-1)
        robust_loss = (abs(self.alpha - 2) / self.alpha) * (
            torch.pow((error_scaled / abs(self.alpha - 2)) + 1, self.alpha / 2) - 1
        )
        return robust_loss
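

# Illustrative sketch (not part of the original module): comparing the robust loss against
# a plain L2 distance on the same residuals. The shapes and values are arbitrary
# assumptions; with alpha < 2 large residuals are penalized sub-quadratically, which is
# the outlier-robust behavior intended by the class above.
def _example_robust_vs_l2():
    a = torch.zeros(4, 3)
    b = torch.tensor([[0.01, 0.0, 0.0], [0.1, 0.0, 0.0], [1.0, 0.0, 0.0], [10.0, 0.0, 0.0]])
    robust = RobustRegressionLoss(alpha=0.5, scaling_c=0.25, reduction="none")(a, b)
    l2 = L2Loss(reduction="none")(a, b)
    return robust, l2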


class BCELoss(BaseCriterion):
    """Binary Cross Entropy loss"""

    def forward(self, predicted_logits, reference_mask):
        """
        Args:
            predicted_logits: (B, H, W) tensor of predicted logits for the mask
            reference_mask: (B, H, W) tensor of reference mask

        Returns:
            loss: scalar tensor of the BCE loss
        """
        bce_loss = torch.nn.functional.binary_cross_entropy_with_logits(
            predicted_logits, reference_mask.float()
        )
        return bce_loss


class Criterion(nn.Module):
    """
    Base class for all criterion modules that wrap a BaseCriterion.
    This class serves as a wrapper around BaseCriterion objects, providing
    additional functionality like naming and reduction mode control.

    Args:
        criterion (BaseCriterion): The base criterion to wrap.
    """

    def __init__(self, criterion=None):
        super().__init__()
        assert isinstance(criterion, BaseCriterion), (
            f"{criterion} is not a proper criterion!"
        )
        self.criterion = copy(criterion)

    def get_name(self):
        """
        Returns a string representation of this criterion.

        Returns:
            str: A string containing the class name and the wrapped criterion.
        """
        return f"{type(self).__name__}({self.criterion})"

    def with_reduction(self, mode="none"):
        """
        Creates a deep copy of this criterion with the specified reduction mode.
        This method recursively sets the reduction mode for this criterion and
        any chained MultiLoss criteria.

        Args:
            mode (str): The reduction mode to set. Default: "none".

        Returns:
            Criterion: A new criterion with the specified reduction mode.
        """
        res = loss = deepcopy(self)
        while loss is not None:
            assert isinstance(loss, Criterion)
            loss.criterion.reduction = mode  # make it return the loss for each sample
            loss = loss._loss2  # we assume loss is a Multiloss
        return res


class MultiLoss(nn.Module):
    """
    Base class for combinable loss functions with automatic tracking of individual loss values.
    This class enables easy combination of multiple loss functions through arithmetic operations:
        loss = MyLoss1() + 0.1*MyLoss2()
    The combined loss functions maintain their individual weights and the forward pass
    automatically computes and aggregates all losses while tracking individual loss values.

    Usage:
        Inherit from this class and override get_name() and compute_loss() methods.

    Attributes:
        _alpha (float): Weight multiplier for this loss component.
        _loss2 (MultiLoss): Reference to the next loss in the chain, if any.
    """

    def __init__(self):
        """Initialize the MultiLoss with default weight of 1 and no chained loss."""
        super().__init__()
        self._alpha = 1
        self._loss2 = None

    def compute_loss(self, *args, **kwargs):
        """
        Compute the loss value for this specific loss component.

        Args:
            *args: Variable length argument list.
            **kwargs: Arbitrary keyword arguments.

        Returns:
            torch.Tensor or tuple: Either the loss tensor or a tuple of (loss, details_dict).

        Raises:
            NotImplementedError: This method must be implemented by subclasses.
        """
        raise NotImplementedError()

    def get_name(self):
        """
        Get the name of this loss component.

        Returns:
            str: The name of the loss.

        Raises:
            NotImplementedError: This method must be implemented by subclasses.
        """
        raise NotImplementedError()

    def __mul__(self, alpha):
        """
        Multiply the loss by a scalar weight.

        Args:
            alpha (int or float): The weight to multiply the loss by.

        Returns:
            MultiLoss: A new loss object with the updated weight.

        Raises:
            AssertionError: If alpha is not a number.
        """
        assert isinstance(alpha, (int, float))
        res = copy(self)
        res._alpha = alpha
        return res

    __rmul__ = __mul__  # Support both loss*alpha and alpha*loss

    def __add__(self, loss2):
        """
        Add another loss to this loss, creating a chain of losses.

        Args:
            loss2 (MultiLoss): Another loss to add to this one.

        Returns:
            MultiLoss: A new loss object representing the combined losses.

        Raises:
            AssertionError: If loss2 is not a MultiLoss.
        """
        assert isinstance(loss2, MultiLoss)
        res = cur = copy(self)
        # Find the end of the chain
        while cur._loss2 is not None:
            cur = cur._loss2
        cur._loss2 = loss2
        return res

    def __repr__(self):
        """
        Create a string representation of the loss, including weights and chained losses.

        Returns:
            str: String representation of the loss.
        """
        name = self.get_name()
        if self._alpha != 1:
            name = f"{self._alpha:g}*{name}"
        if self._loss2:
            name = f"{name} + {self._loss2}"
        return name

    def forward(self, *args, **kwargs):
        """
        Compute the weighted loss and aggregate with any chained losses.

        Args:
            *args: Variable length argument list.
            **kwargs: Arbitrary keyword arguments.

        Returns:
            tuple: A tuple containing:
                - torch.Tensor: The total weighted loss.
                - dict: Details about individual loss components.
        """
        loss = self.compute_loss(*args, **kwargs)
        if isinstance(loss, tuple):
            loss, details = loss
        elif loss.ndim == 0:
            details = {self.get_name(): float(loss)}
        else:
            details = {}
        loss = loss * self._alpha
        if self._loss2:
            loss2, details2 = self._loss2(*args, **kwargs)
            loss = loss + loss2
            details |= details2
        return loss, details
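

# Illustrative sketch (not part of the original module): composing criteria with the
# arithmetic operators provided by MultiLoss. The weights and the alpha value are
# arbitrary assumptions; `batch` and `preds` are the per-view lists of dicts expected
# by the compute_loss implementations defined later in this file.
def _example_multiloss_composition(batch, preds):
    criterion = ConfLoss(Regr3D(L2Loss()), alpha=0.2) + 0.1 * NonAmbiguousMaskLoss(BCELoss())
    total_loss, details = criterion(batch, preds)
    return total_loss, details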


class NonAmbiguousMaskLoss(Criterion, MultiLoss):
    """
    Loss on non-ambiguous mask prediction logits.
    """

    def __init__(self, criterion):
        super().__init__(criterion)

    def compute_loss(self, batch, preds, **kw):
        """
        Args:
            batch: list of dicts with the gt data
            preds: list of dicts with the predictions

        Returns:
            loss: Sum class of the losses for N-views and the loss details
        """
        # Init loss list to keep track of individual losses for each view
        loss_list = []
        mask_loss_details = {}
        mask_loss_total = 0
        self_name = type(self).__name__
        # Loop over the views
        for view_idx, (gt, pred) in enumerate(zip(batch, preds)):
            # Get the GT non-ambiguous masks
            gt_non_ambiguous_mask = gt["non_ambiguous_mask"]
            # Get the predicted non-ambiguous mask logits
            pred_non_ambiguous_mask_logits = pred["non_ambiguous_mask_logits"]
            # Compute the loss for the current view
            loss = self.criterion(pred_non_ambiguous_mask_logits, gt_non_ambiguous_mask)
            # Add the loss to the list
            loss_list.append((loss, None, "non_ambiguous_mask"))
            # Add the loss details to the dictionary
            mask_loss_details[f"{self_name}_mask_view{view_idx + 1}"] = float(loss)
            mask_loss_total += float(loss)
        # Compute the average loss across all views
        mask_loss_details[f"{self_name}_mask_avg"] = mask_loss_total / len(batch)
        return Sum(*loss_list), (mask_loss_details | {})


class ConfLoss(MultiLoss):
    """
    Applies confidence-weighted regression loss using model-predicted confidence values.
    The confidence-weighted loss has the form:
        conf_loss = raw_loss * conf - alpha * log(conf)
    Where:
        - raw_loss is the original per-pixel loss
        - conf is the predicted confidence (higher values = higher confidence)
        - alpha is a hyperparameter controlling the regularization strength
    This loss can be selectively applied to specific loss components in factored and multi-view settings.
    """

    def __init__(self, pixel_loss, alpha=1, loss_set_indices=None):
        """
        Args:
            pixel_loss (MultiLoss): The pixel-level regression loss to be used.
            alpha (float): Hyperparameter controlling the confidence regularization strength.
            loss_set_indices (list or None): Indices of the loss sets to apply confidence weighting to.
                Each index selects a specific loss set across all views (with the same rep_type).
                If None, defaults to [0] which applies to the first loss set only.
        """
        super().__init__()
        assert alpha > 0
        self.alpha = alpha
        self.pixel_loss = pixel_loss.with_reduction("none")
        self.loss_set_indices = [0] if loss_set_indices is None else loss_set_indices

    def get_name(self):
        return f"ConfLoss({self.pixel_loss})"

    def get_conf_log(self, x):
        return x, torch.log(x)

    def compute_loss(self, batch, preds, **kw):
        # Init loss list and details
        total_loss = 0
        conf_loss_details = {}
        running_avg_dict = {}
        self_name = type(self.pixel_loss).__name__
        n_views = len(batch)

        # Compute per-pixel loss for each view
        losses, pixel_loss_details = self.pixel_loss(batch, preds, **kw)

        # Select specific loss sets based on indices
        selected_losses = []
        processed_indices = set()
        for idx in self.loss_set_indices:
            start_idx = idx * n_views
            end_idx = min((idx + 1) * n_views, len(losses))
            selected_losses.extend(losses[start_idx:end_idx])
            processed_indices.update(range(start_idx, end_idx))

        # Process selected losses with confidence weighting
        for loss_idx, (loss, msk, rep_type) in enumerate(selected_losses):
            view_idx = loss_idx % n_views  # Map to corresponding view index
            if loss.numel() == 0:
                # print(f"NO VALID VALUES in loss idx {loss_idx} (Rep Type: {rep_type}, Num Views: {n_views})", force=True)
                continue
            # Get the confidence and log confidence
            if (
                hasattr(self.pixel_loss, "flatten_across_image_only")
                and self.pixel_loss.flatten_across_image_only
            ):
                # Reshape confidence to match the flattened dimensions
                conf_reshaped = preds[view_idx]["conf"].view(
                    preds[view_idx]["conf"].shape[0], -1
                )
                conf, log_conf = self.get_conf_log(conf_reshaped[msk])
                loss = loss[msk]
            else:
                conf, log_conf = self.get_conf_log(preds[view_idx]["conf"][msk])
            # Weight the loss by the confidence
            conf_loss = loss * conf - self.alpha * log_conf
            # Only add to total loss and store details if there are valid elements
            if conf_loss.numel() > 0:
                conf_loss = conf_loss.mean()
                total_loss = total_loss + conf_loss
                # Store details
                conf_loss_details[
                    f"{self_name}_{rep_type}_conf_loss_view{view_idx + 1}"
                ] = float(conf_loss)
                # Initialize or update running average directly
                avg_key = f"{self_name}_{rep_type}_conf_loss_avg"
                if avg_key not in conf_loss_details:
                    conf_loss_details[avg_key] = float(conf_loss)
                    running_avg_dict[f"{self_name}_{rep_type}_conf_loss_valid_views"] = 1
                else:
                    valid_views = (
                        running_avg_dict[f"{self_name}_{rep_type}_conf_loss_valid_views"] + 1
                    )
                    running_avg_dict[f"{self_name}_{rep_type}_conf_loss_valid_views"] = valid_views
                    conf_loss_details[avg_key] += (
                        float(conf_loss) - conf_loss_details[avg_key]
                    ) / valid_views

        # Add unmodified losses for sets not in selected_losses
        for idx, (loss, msk, rep_type) in enumerate(losses):
            if idx not in processed_indices:
                if msk is not None:
                    loss_after_masking = loss[msk]
                else:
                    loss_after_masking = loss
                if loss_after_masking.numel() > 0:
                    loss_mean = loss_after_masking.mean()
                else:
                    # print(f"NO VALID VALUES in loss idx {idx} (Rep Type: {rep_type}, Num Views: {n_views})", force=True)
                    loss_mean = 0
                total_loss = total_loss + loss_mean

        return total_loss, dict(**conf_loss_details, **pixel_loss_details)
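

# Illustrative sketch (not part of the original module): the per-pixel confidence weighting
# used above, conf_loss = raw_loss * conf - alpha * log(conf), evaluated on toy numbers.
# The values are arbitrary assumptions; raising the confidence scales up the raw loss,
# while the -alpha*log(conf) term rewards the model for being confident.
def _example_conf_weighting():
    raw_loss = torch.tensor([0.5, 0.5])
    conf = torch.tensor([1.0, 2.0])
    alpha = 1.0
    return raw_loss * conf - alpha * torch.log(conf)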


class ExcludeTopNPercentPixelLoss(MultiLoss):
    """
    Pixel-level regression loss where for each instance in a batch the top N% of per-pixel loss values are ignored
    for the mean loss computation.
    Allows selecting which pixel-level regression loss sets to apply the exclusion to.
    """

    def __init__(
        self,
        pixel_loss,
        top_n_percent=5,
        apply_to_real_data_only=True,
        loss_set_indices=None,
    ):
        """
        Args:
            pixel_loss (MultiLoss): The pixel-level regression loss to be used.
            top_n_percent (float): The percentage of top per-pixel loss values to ignore. Range: [0, 100]. Default: 5.
            apply_to_real_data_only (bool): Whether to apply the loss only to real world data. Default: True.
            loss_set_indices (list or None): Indices of the loss sets to apply the exclusion to.
                Each index selects a specific loss set across all views (with the same rep_type).
                If None, defaults to [0] which applies to the first loss set only.
        """
        super().__init__()
        self.pixel_loss = pixel_loss.with_reduction("none")
        self.top_n_percent = top_n_percent
        self.bottom_n_percent = 100 - top_n_percent
        self.apply_to_real_data_only = apply_to_real_data_only
        self.loss_set_indices = [0] if loss_set_indices is None else loss_set_indices

    def get_name(self):
        return f"ExcludeTopNPercentPixelLoss({self.pixel_loss})"

    def keep_bottom_n_percent(self, tensor, mask, bottom_n_percent):
        """
        Function to compute the mask for keeping the bottom n percent of per-pixel loss values.

        Args:
            tensor (torch.Tensor): The tensor containing the per-pixel loss values.
                Shape: (B, N) where B is the batch size and N is the number of total pixels.
            mask (torch.Tensor): The mask indicating valid pixels. Shape: (B, N).

        Returns:
            torch.Tensor: Flattened tensor containing the bottom n percent of per-pixel loss values.
        """
        B, N = tensor.shape
        # Calculate the number of valid elements (where mask is True)
        num_valid = mask.sum(dim=1)
        # Calculate the number of elements to keep (n% of valid elements)
        num_keep = (num_valid * bottom_n_percent / 100).long()
        # Create a mask for the bottom n% elements
        keep_mask = torch.arange(N, device=tensor.device).unsqueeze(0) < num_keep.unsqueeze(1)
        # Create a tensor with inf where mask is False
        masked_tensor = torch.where(
            mask, tensor, torch.tensor(float("inf"), device=tensor.device)
        )
        # Sort the masked tensor along the N dimension
        sorted_tensor, _ = torch.sort(masked_tensor, dim=1, descending=False)
        # Get the bottom n% elements
        bottom_n_percent_elements = sorted_tensor[keep_mask]
        return bottom_n_percent_elements

    def compute_loss(self, batch, preds, **kw):
        # Compute per-pixel loss
        losses, details = self.pixel_loss(batch, preds, **kw)
        n_views = len(batch)

        # Select specific loss sets based on indices
        selected_losses = []
        processed_indices = set()
        for idx in self.loss_set_indices:
            start_idx = idx * n_views
            end_idx = min((idx + 1) * n_views, len(losses))
            selected_losses.extend(losses[start_idx:end_idx])
            processed_indices.update(range(start_idx, end_idx))

        # Initialize total loss
        total_loss = 0.0
        loss_details = {}
        running_avg_dict = {}
        self_name = type(self.pixel_loss).__name__

        # Process selected losses with top N percent exclusion
        for loss_idx, (loss, msk, rep_type) in enumerate(selected_losses):
            view_idx = loss_idx % n_views  # Map to corresponding view index
            if loss.numel() == 0:
                # print(f"NO VALID VALUES in loss idx {loss_idx} (Rep Type: {rep_type}, Num Views: {n_views})", force=True)
                continue
            # Create empty list for current view's aggregated tensors
            aggregated_losses = []
            if self.apply_to_real_data_only:
                # Get the synthetic and real world data mask
                synthetic_mask = batch[view_idx]["is_synthetic"]
                real_data_mask = ~batch[view_idx]["is_synthetic"]
            else:
                # Apply the filtering to all data
                synthetic_mask = torch.zeros_like(batch[view_idx]["is_synthetic"])
                real_data_mask = torch.ones_like(batch[view_idx]["is_synthetic"])
            # Process synthetic data
            if synthetic_mask.any():
                synthetic_loss = loss[synthetic_mask]
                synthetic_msk = msk[synthetic_mask]
                aggregated_losses.append(synthetic_loss[synthetic_msk])
            # Process real data
            if real_data_mask.any():
                real_loss = loss[real_data_mask]
                real_msk = msk[real_data_mask]
                real_bottom_n_percent_loss = self.keep_bottom_n_percent(
                    real_loss, real_msk, self.bottom_n_percent
                )
                aggregated_losses.append(real_bottom_n_percent_loss)
            # Compute view loss
            view_loss = torch.cat(aggregated_losses, dim=0)
            # Only add to total loss and store details if there are valid elements
            if view_loss.numel() > 0:
                view_loss = view_loss.mean()
                total_loss = total_loss + view_loss
                # Store details
                loss_details[
                    f"{self_name}_{rep_type}_bot{self.bottom_n_percent}%_loss_view{view_idx + 1}"
                ] = float(view_loss)
                # Initialize or update running average directly
                avg_key = f"{self_name}_{rep_type}_bot{self.bottom_n_percent}%_loss_avg"
                if avg_key not in loss_details:
                    loss_details[avg_key] = float(view_loss)
                    running_avg_dict[
                        f"{self_name}_{rep_type}_bot{self.bottom_n_percent}%_valid_views"
                    ] = 1
                else:
                    valid_views = (
                        running_avg_dict[
                            f"{self_name}_{rep_type}_bot{self.bottom_n_percent}%_valid_views"
                        ]
                        + 1
                    )
                    running_avg_dict[
                        f"{self_name}_{rep_type}_bot{self.bottom_n_percent}%_valid_views"
                    ] = valid_views
                    loss_details[avg_key] += (
                        float(view_loss) - loss_details[avg_key]
                    ) / valid_views

        # Add unmodified losses for sets not in selected_losses
        for idx, (loss, msk, rep_type) in enumerate(losses):
            if idx not in processed_indices:
                if msk is not None:
                    loss_after_masking = loss[msk]
                else:
                    loss_after_masking = loss
                if loss_after_masking.numel() > 0:
                    loss_mean = loss_after_masking.mean()
                else:
                    # print(f"NO VALID VALUES in loss idx {idx} (Rep Type: {rep_type}, Num Views: {n_views})", force=True)
                    loss_mean = 0
                total_loss = total_loss + loss_mean

        return total_loss, dict(**loss_details, **details)
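

# Illustrative sketch (not part of the original module): how keep_bottom_n_percent trims
# the largest per-pixel losses. The wrapped pixel loss and the numbers below are arbitrary
# assumptions; with top_n_percent=25 on four valid pixels, the single largest value (9.0)
# is dropped before averaging.
def _example_keep_bottom_n_percent():
    excluder = ExcludeTopNPercentPixelLoss(Regr3D(L2Loss()), top_n_percent=25)
    per_pixel = torch.tensor([[0.1, 0.2, 0.3, 9.0]])
    mask = torch.ones(1, 4, dtype=torch.bool)
    return excluder.keep_bottom_n_percent(per_pixel, mask, excluder.bottom_n_percent)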


class ConfAndExcludeTopNPercentPixelLoss(MultiLoss):
    """
    Combined loss that applies ConfLoss to one set of pixel-level regression losses
    and ExcludeTopNPercentPixelLoss to another set of pixel-level regression losses.
    """

    def __init__(
        self,
        pixel_loss,
        conf_alpha=1,
        top_n_percent=5,
        apply_to_real_data_only=True,
        conf_loss_set_indices=None,
        exclude_loss_set_indices=None,
    ):
        """
        Args:
            pixel_loss (MultiLoss): The pixel-level regression loss to be used.
            conf_alpha (float): Alpha parameter for ConfLoss. Default: 1.
            top_n_percent (float): Percentage of top per-pixel loss values to ignore. Range: [0, 100]. Default: 5.
            apply_to_real_data_only (bool): Whether to apply the exclude loss only to real world data. Default: True.
            conf_loss_set_indices (list or None): Indices of the loss sets to apply confidence weighting to.
                Each index selects a specific loss set across all views (with the same rep_type).
                If None, defaults to [0] which applies to the first loss set only.
            exclude_loss_set_indices (list or None): Indices of the loss sets to apply top N percent exclusion to.
                Each index selects a specific loss set across all views (with the same rep_type).
                If None, defaults to [1] which applies to the second loss set only.
        """
        super().__init__()
        self.pixel_loss = pixel_loss.with_reduction("none")
        assert conf_alpha > 0
        self.conf_alpha = conf_alpha
        self.top_n_percent = top_n_percent
        self.bottom_n_percent = 100 - top_n_percent
        self.apply_to_real_data_only = apply_to_real_data_only
        self.conf_loss_set_indices = (
            [0] if conf_loss_set_indices is None else conf_loss_set_indices
        )
        self.exclude_loss_set_indices = (
            [1] if exclude_loss_set_indices is None else exclude_loss_set_indices
        )

    def get_name(self):
        return f"ConfAndExcludeTopNPercentPixelLoss({self.pixel_loss})"

    def get_conf_log(self, x):
        return x, torch.log(x)

    def keep_bottom_n_percent(self, tensor, mask, bottom_n_percent):
        """
        Function to compute the mask for keeping the bottom n percent of per-pixel loss values.
        """
        B, N = tensor.shape
        # Calculate the number of valid elements (where mask is True)
        num_valid = mask.sum(dim=1)
        # Calculate the number of elements to keep (n% of valid elements)
        num_keep = (num_valid * bottom_n_percent / 100).long()
        # Create a mask for the bottom n% elements
        keep_mask = torch.arange(N, device=tensor.device).unsqueeze(0) < num_keep.unsqueeze(1)
        # Create a tensor with inf where mask is False
        masked_tensor = torch.where(
            mask, tensor, torch.tensor(float("inf"), device=tensor.device)
        )
        # Sort the masked tensor along the N dimension
        sorted_tensor, _ = torch.sort(masked_tensor, dim=1, descending=False)
        # Get the bottom n% elements
        bottom_n_percent_elements = sorted_tensor[keep_mask]
        return bottom_n_percent_elements

    def compute_loss(self, batch, preds, **kw):
        # Compute per-pixel loss
        losses, pixel_loss_details = self.pixel_loss(batch, preds, **kw)
        n_views = len(batch)

        # Select specific loss sets for confidence weighting
        conf_selected_losses = []
        conf_processed_indices = set()
        for idx in self.conf_loss_set_indices:
            start_idx = idx * n_views
            end_idx = min((idx + 1) * n_views, len(losses))
            conf_selected_losses.extend(losses[start_idx:end_idx])
            conf_processed_indices.update(range(start_idx, end_idx))

        # Select specific loss sets for top N percent exclusion
        exclude_selected_losses = []
        exclude_processed_indices = set()
        for idx in self.exclude_loss_set_indices:
            start_idx = idx * n_views
            end_idx = min((idx + 1) * n_views, len(losses))
            exclude_selected_losses.extend(losses[start_idx:end_idx])
            exclude_processed_indices.update(range(start_idx, end_idx))

        # Initialize total loss and details
        total_loss = 0
        loss_details = {}
        running_avg_dict = {}
        self_name = type(self.pixel_loss).__name__

        # Process selected losses with confidence weighting
        for loss_idx, (loss, msk, rep_type) in enumerate(conf_selected_losses):
            view_idx = loss_idx % n_views  # Map to corresponding view index
            if loss.numel() == 0:
                # print(f"NO VALID VALUES in loss idx {loss_idx} (Rep Type: {rep_type}, Num Views: {n_views}) for conf loss", force=True)
                continue
            # Get the confidence and log confidence
            if (
                hasattr(self.pixel_loss, "flatten_across_image_only")
                and self.pixel_loss.flatten_across_image_only
            ):
                # Reshape confidence to match the flattened dimensions
                conf_reshaped = preds[view_idx]["conf"].view(
                    preds[view_idx]["conf"].shape[0], -1
                )
                conf, log_conf = self.get_conf_log(conf_reshaped[msk])
                loss = loss[msk]
            else:
                conf, log_conf = self.get_conf_log(preds[view_idx]["conf"][msk])
            # Weight the loss by the confidence
            conf_loss = loss * conf - self.conf_alpha * log_conf
            # Only add to total loss and store details if there are valid elements
            if conf_loss.numel() > 0:
                conf_loss = conf_loss.mean()
                total_loss = total_loss + conf_loss
                # Store details
                loss_details[f"{self_name}_{rep_type}_conf_loss_view{view_idx + 1}"] = (
                    float(conf_loss)
                )
                # Initialize or update running average directly
                avg_key = f"{self_name}_{rep_type}_conf_loss_avg"
                if avg_key not in loss_details:
                    loss_details[avg_key] = float(conf_loss)
                    running_avg_dict[f"{self_name}_{rep_type}_conf_loss_valid_views"] = 1
                else:
                    valid_views = (
                        running_avg_dict[f"{self_name}_{rep_type}_conf_loss_valid_views"] + 1
                    )
                    running_avg_dict[f"{self_name}_{rep_type}_conf_loss_valid_views"] = valid_views
                    loss_details[avg_key] += (
                        float(conf_loss) - loss_details[avg_key]
                    ) / valid_views

        # Process selected losses with top N percent exclusion
        for loss_idx, (loss, msk, rep_type) in enumerate(exclude_selected_losses):
            view_idx = loss_idx % n_views  # Map to corresponding view index
            if loss.numel() == 0:
                # print(f"NO VALID VALUES in loss idx {loss_idx} (Rep Type: {rep_type}, Num Views: {n_views}) for exclude loss", force=True)
                continue
            # Create empty list for current view's aggregated tensors
            aggregated_losses = []
            if self.apply_to_real_data_only:
                # Get the synthetic and real world data mask
                synthetic_mask = batch[view_idx]["is_synthetic"]
                real_data_mask = ~batch[view_idx]["is_synthetic"]
            else:
                # Apply the filtering to all data
                synthetic_mask = torch.zeros_like(batch[view_idx]["is_synthetic"])
                real_data_mask = torch.ones_like(batch[view_idx]["is_synthetic"])
            # Process synthetic data
            if synthetic_mask.any():
                synthetic_loss = loss[synthetic_mask]
                synthetic_msk = msk[synthetic_mask]
                aggregated_losses.append(synthetic_loss[synthetic_msk])
            # Process real data
            if real_data_mask.any():
                real_loss = loss[real_data_mask]
                real_msk = msk[real_data_mask]
                real_bottom_n_percent_loss = self.keep_bottom_n_percent(
                    real_loss, real_msk, self.bottom_n_percent
                )
                aggregated_losses.append(real_bottom_n_percent_loss)
            # Compute view loss
            view_loss = torch.cat(aggregated_losses, dim=0)
            # Only add to total loss and store details if there are valid elements
            if view_loss.numel() > 0:
                view_loss = view_loss.mean()
                total_loss = total_loss + view_loss
                # Store details
                loss_details[
                    f"{self_name}_{rep_type}_bot{self.bottom_n_percent}%_loss_view{view_idx + 1}"
                ] = float(view_loss)
                # Initialize or update running average directly
                avg_key = f"{self_name}_{rep_type}_bot{self.bottom_n_percent}%_loss_avg"
                if avg_key not in loss_details:
                    loss_details[avg_key] = float(view_loss)
                    running_avg_dict[
                        f"{self_name}_{rep_type}_bot{self.bottom_n_percent}%_valid_views"
                    ] = 1
                else:
                    valid_views = (
                        running_avg_dict[
                            f"{self_name}_{rep_type}_bot{self.bottom_n_percent}%_valid_views"
                        ]
                        + 1
                    )
                    running_avg_dict[
                        f"{self_name}_{rep_type}_bot{self.bottom_n_percent}%_valid_views"
                    ] = valid_views
                    loss_details[avg_key] += (
                        float(view_loss) - loss_details[avg_key]
                    ) / valid_views

        # Add unmodified losses for sets not processed with either confidence or exclusion
        all_processed_indices = conf_processed_indices.union(exclude_processed_indices)
        for idx, (loss, msk, rep_type) in enumerate(losses):
            if idx not in all_processed_indices:
                if msk is not None:
                    loss_after_masking = loss[msk]
                else:
                    loss_after_masking = loss
                if loss_after_masking.numel() > 0:
                    loss_mean = loss_after_masking.mean()
                else:
                    # print(f"NO VALID VALUES in loss idx {idx} (Rep Type: {rep_type}, Num Views: {n_views})", force=True)
                    loss_mean = 0
                total_loss = total_loss + loss_mean

        return total_loss, dict(**loss_details, **pixel_loss_details)


class Regr3D(Criterion, MultiLoss):
    """
    Regression Loss for World Frame Pointmaps.
    Asymmetric loss where view 1 is supposed to be the anchor.
    For each view i:
        Pi = RTi @ Di
        lossi = (RTi1 @ pred_Di) - (RT1^-1 @ RTi @ Di)
    where RT1 is the anchor view camera pose.
    """

    def __init__(
        self,
        criterion,
        norm_mode="?avg_dis",
        gt_scale=False,
        ambiguous_loss_value=0,
        max_metric_scale=False,
        loss_in_log=True,
        flatten_across_image_only=False,
    ):
        """
        Initialize the loss criterion for World Frame Pointmaps.

        Args:
            criterion (BaseCriterion): The base criterion to use for computing the loss.
            norm_mode (str): Normalization mode for scene representation. Default: "?avg_dis".
                If prefixed with "?", normalization is only applied to non-metric scale data.
            gt_scale (bool): If True, enforce predictions to have the same scale as ground truth.
                If False, both GT and predictions are normalized independently. Default: False.
            ambiguous_loss_value (float): Value to use for ambiguous pixels in the loss.
                If 0, ambiguous pixels are ignored. Default: 0.
            max_metric_scale (float): Maximum scale for metric scale data. If data exceeds this
                value, it will be treated as non-metric. Default: False (no limit).
            loss_in_log (bool): If True, apply logarithmic transformation to input before
                computing the loss for pointmaps. Default: True.
            flatten_across_image_only (bool): If True, flatten H x W dimensions only when computing
                the loss. If False, flatten across batch and spatial dimensions. Default: False.
        """
        super().__init__(criterion)
        if norm_mode.startswith("?"):
            # Do not normalize pts from metric scale datasets
            self.norm_all = False
            self.norm_mode = norm_mode[1:]
        else:
            self.norm_all = True
            self.norm_mode = norm_mode
        self.gt_scale = gt_scale
        self.ambiguous_loss_value = ambiguous_loss_value
        self.max_metric_scale = max_metric_scale
        self.loss_in_log = loss_in_log
        self.flatten_across_image_only = flatten_across_image_only

    def get_all_info(self, batch, preds, dist_clip=None):
        n_views = len(batch)
        in_camera0 = closed_form_pose_inverse(batch[0]["camera_pose"])

        # Initialize lists to store points and masks
        no_norm_gt_pts = []
        valid_masks = []

        # Process ground truth points and valid masks
        for view_idx in range(n_views):
            no_norm_gt_pts.append(geotrf(in_camera0, batch[view_idx]["pts3d"]))  # B,H,W,3
            valid_masks.append(batch[view_idx]["valid_mask"].clone())

        if dist_clip is not None:
            # Points that are too far-away == invalid
            for view_idx in range(n_views):
                dis = no_norm_gt_pts[view_idx].norm(dim=-1)  # (B, H, W)
                valid_masks[view_idx] = valid_masks[view_idx] & (dis <= dist_clip)

        # Get predicted points
        no_norm_pr_pts = []
        for view_idx in range(n_views):
            no_norm_pr_pts.append(preds[view_idx]["pts3d"])

        if not self.norm_all:
            if self.max_metric_scale:
                B = valid_masks[0].shape[0]
                # Calculate distances to camera for all views
                dists_to_cam1 = []
                for view_idx in range(n_views):
                    dist = torch.where(
                        valid_masks[view_idx],
                        torch.norm(no_norm_gt_pts[view_idx], dim=-1),
                        0,
                    ).view(B, -1)
                    dists_to_cam1.append(dist)
                # Update metric scale flags
                metric_scale_mask = batch[0]["is_metric_scale"]
                for dist in dists_to_cam1:
                    metric_scale_mask = metric_scale_mask & (
                        dist.max(dim=-1).values < self.max_metric_scale
                    )
                for view_idx in range(n_views):
                    batch[view_idx]["is_metric_scale"] = metric_scale_mask
            non_metric_scale_mask = ~batch[0]["is_metric_scale"]
        else:
            non_metric_scale_mask = torch.ones_like(batch[0]["is_metric_scale"])

        # Initialize normalized points
        gt_pts = [torch.zeros_like(pts) for pts in no_norm_gt_pts]
        pr_pts = [torch.zeros_like(pts) for pts in no_norm_pr_pts]

        # Normalize 3d points
        if self.norm_mode and non_metric_scale_mask.any():
            normalized_pr_pts = normalize_multiple_pointclouds(
                [pts[non_metric_scale_mask] for pts in no_norm_pr_pts],
                [mask[non_metric_scale_mask] for mask in valid_masks],
                self.norm_mode,
            )
            for i in range(n_views):
                pr_pts[i][non_metric_scale_mask] = normalized_pr_pts[i]
        elif non_metric_scale_mask.any():
            for i in range(n_views):
                pr_pts[i][non_metric_scale_mask] = no_norm_pr_pts[i][non_metric_scale_mask]

        if self.norm_mode and not self.gt_scale:
            gt_normalization_output = normalize_multiple_pointclouds(
                no_norm_gt_pts, valid_masks, self.norm_mode, ret_factor=True
            )
            normalized_gt_pts = gt_normalization_output[:-1]
            norm_factor = gt_normalization_output[-1]
            for i in range(n_views):
                gt_pts[i] = normalized_gt_pts[i]
                pr_pts[i][~non_metric_scale_mask] = (
                    no_norm_pr_pts[i][~non_metric_scale_mask]
                    / norm_factor[~non_metric_scale_mask]
                )
        elif ~non_metric_scale_mask.any():
            for i in range(n_views):
                gt_pts[i] = no_norm_gt_pts[i]
                pr_pts[i][~non_metric_scale_mask] = no_norm_pr_pts[i][~non_metric_scale_mask]
        else:
            for i in range(n_views):
                gt_pts[i] = no_norm_gt_pts[i]

        # Get ambiguous masks
        ambiguous_masks = []
        for view_idx in range(n_views):
            ambiguous_masks.append(
                (~batch[view_idx]["non_ambiguous_mask"]) & (~valid_masks[view_idx])
            )

        return gt_pts, pr_pts, valid_masks, ambiguous_masks, {}

    def compute_loss(self, batch, preds, **kw):
        gt_pts, pred_pts, masks, ambiguous_masks, monitoring = self.get_all_info(
            batch, preds, **kw
        )
        n_views = len(batch)

        if self.ambiguous_loss_value > 0:
            assert self.criterion.reduction == "none", (
                "ambiguous_loss_value should be 0 if no conf loss"
            )
            # Add the ambiguous pixels as "valid" pixels
            masks = [mask | amb_mask for mask, amb_mask in zip(masks, ambiguous_masks)]

        losses = []
        details = {}
        running_avg_dict = {}
        self_name = type(self).__name__

        if not self.flatten_across_image_only:
            for view_idx in range(n_views):
                pred = pred_pts[view_idx][masks[view_idx]]
                gt = gt_pts[view_idx][masks[view_idx]]
                if self.loss_in_log:
                    pred = apply_log_to_norm(pred)
                    gt = apply_log_to_norm(gt)
                loss = self.criterion(pred, gt)
                if self.ambiguous_loss_value > 0:
                    loss = torch.where(
                        ambiguous_masks[view_idx][masks[view_idx]],
                        self.ambiguous_loss_value,
                        loss,
                    )
                losses.append((loss, masks[view_idx], "pts3d"))
                if loss.numel() > 0:
                    loss_mean = float(loss.mean())
                    details[f"{self_name}_pts3d_view{view_idx + 1}"] = loss_mean
                    # Initialize or update running average directly
                    avg_key = f"{self_name}_pts3d_avg"
                    if avg_key not in details:
                        details[avg_key] = loss_mean
                        running_avg_dict[f"{self_name}_pts3d_valid_views"] = 1
                    else:
                        valid_views = running_avg_dict[f"{self_name}_pts3d_valid_views"] + 1
                        running_avg_dict[f"{self_name}_pts3d_valid_views"] = valid_views
                        details[avg_key] += (loss_mean - details[avg_key]) / valid_views
        else:
            batch_size, _, _, dim = gt_pts[0].shape
            for view_idx in range(n_views):
                gt = gt_pts[view_idx].view(batch_size, -1, dim)
                pred = pred_pts[view_idx].view(batch_size, -1, dim)
                view_mask = masks[view_idx].view(batch_size, -1)
                amb_mask = ambiguous_masks[view_idx].view(batch_size, -1)
                if self.loss_in_log:
                    pred = apply_log_to_norm(pred)
                    gt = apply_log_to_norm(gt)
                loss = self.criterion(pred, gt)
                if self.ambiguous_loss_value > 0:
                    loss = torch.where(amb_mask, self.ambiguous_loss_value, loss)
                losses.append((loss, view_mask, "pts3d"))
                loss_after_masking = loss[view_mask]
                if loss_after_masking.numel() > 0:
                    loss_mean = float(loss_after_masking.mean())
                    details[f"{self_name}_pts3d_view{view_idx + 1}"] = loss_mean
                    # Initialize or update running average directly
                    avg_key = f"{self_name}_pts3d_avg"
                    if avg_key not in details:
                        details[avg_key] = loss_mean
                        running_avg_dict[f"{self_name}_pts3d_valid_views"] = 1
                    else:
                        valid_views = running_avg_dict[f"{self_name}_pts3d_valid_views"] + 1
                        running_avg_dict[f"{self_name}_pts3d_valid_views"] = valid_views
                        details[avg_key] += (loss_mean - details[avg_key]) / valid_views

        return Sum(*losses), (details | monitoring)
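

# Illustrative sketch (not part of the original module): instantiating the world-frame
# pointmap regression loss above with confidence weighting. The specific criterion,
# norm_mode, and alpha below are arbitrary assumptions, not the original training recipes.
def _example_regr3d_objective():
    pixel_loss = Regr3D(L2Loss(), norm_mode="?avg_dis", gt_scale=False, loss_in_log=True)
    return ConfLoss(pixel_loss, alpha=0.2)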
| class PointsPlusScaleRegr3D(Criterion, MultiLoss): | |
| """ | |
| Regression Loss for World Frame Pointmaps & Scale. | |
| """ | |
| def __init__( | |
| self, | |
| criterion, | |
| norm_predictions=True, | |
| norm_mode="avg_dis", | |
| ambiguous_loss_value=0, | |
| loss_in_log=True, | |
| flatten_across_image_only=False, | |
| world_frame_points_loss_weight=1, | |
| scale_loss_weight=1, | |
| ): | |
| """ | |
| Initialize the loss criterion for World Frame Pointmaps & Scale. | |
| The predicited scene representation is always normalized w.r.t. the frame of view0. | |
| Loss is applied between the predicted metric scale and the ground truth metric scale. | |
| Args: | |
| criterion (BaseCriterion): The base criterion to use for computing the loss. | |
| norm_predictions (bool): If True, normalize the predictions before computing the loss. | |
| norm_mode (str): Normalization mode for the gt and predicted (optional) scene representation. Default: "avg_dis". | |
| ambiguous_loss_value (float): Value to use for ambiguous pixels in the loss. | |
| If 0, ambiguous pixels are ignored. Default: 0. | |
| loss_in_log (bool): If True, apply logarithmic transformation to input before | |
| computing the loss for depth, pointmaps and scale. Default: True. | |
| flatten_across_image_only (bool): If True, flatten H x W dimensions only when computing | |
| the loss. If False, flatten across batch and spatial dimensions. Default: False. | |
| world_frame_points_loss_weight (float): Weight to use for the world frame pointmap loss. Default: 1. | |
| scale_loss_weight (float): Weight to use for the scale loss. Default: 1. | |
| """ | |
| super().__init__(criterion) | |
| self.norm_predictions = norm_predictions | |
| self.norm_mode = norm_mode | |
| self.ambiguous_loss_value = ambiguous_loss_value | |
| self.loss_in_log = loss_in_log | |
| self.flatten_across_image_only = flatten_across_image_only | |
| self.world_frame_points_loss_weight = world_frame_points_loss_weight | |
| self.scale_loss_weight = scale_loss_weight | |
| def get_all_info(self, batch, preds, dist_clip=None): | |
| """ | |
| Function to get all the information needed to compute the loss. | |
| Returns all quantities normalized w.r.t. camera of view0. | |
| """ | |
| n_views = len(batch) | |
| # Everything is normalized w.r.t. camera of view0 | |
| # Intialize lists to store data for all views | |
| # Ground truth quantities | |
| in_camera0 = closed_form_pose_inverse(batch[0]["camera_pose"]) | |
| no_norm_gt_pts = [] | |
| valid_masks = [] | |
| # Predicted quantities | |
| no_norm_pr_pts = [] | |
| metric_pr_pts_to_compute_scale = [] | |
| # Get ground truth & prediction info for all views | |
| for i in range(n_views): | |
| # Get the ground truth | |
| no_norm_gt_pts.append(geotrf(in_camera0, batch[i]["pts3d"])) | |
| valid_masks.append(batch[i]["valid_mask"].clone()) | |
| # Get predictions for normalized loss | |
| if "metric_scaling_factor" in preds[i].keys(): | |
| # Divide by the predicted metric scaling factor to get the raw predicted points, depth_along_ray, and pose_trans | |
| # This detaches the predicted metric scaling factor from the geometry based loss | |
| curr_view_no_norm_pr_pts = preds[i]["pts3d"] / preds[i][ | |
| "metric_scaling_factor" | |
| ].unsqueeze(-1).unsqueeze(-1) | |
| else: | |
| curr_view_no_norm_pr_pts = preds[i]["pts3d"] | |
| no_norm_pr_pts.append(curr_view_no_norm_pr_pts) | |
| # Get the predicted metric scale points | |
| if "metric_scaling_factor" in preds[i].keys(): | |
| # Detach the raw predicted points so that the scale loss is only applied to the scaling factor | |
| curr_view_metric_pr_pts_to_compute_scale = ( | |
| curr_view_no_norm_pr_pts.detach() | |
| * preds[i]["metric_scaling_factor"].unsqueeze(-1).unsqueeze(-1) | |
| ) | |
| else: | |
| curr_view_metric_pr_pts_to_compute_scale = ( | |
| curr_view_no_norm_pr_pts.clone() | |
| ) | |
| metric_pr_pts_to_compute_scale.append( | |
| curr_view_metric_pr_pts_to_compute_scale | |
| ) | |
| if dist_clip is not None: | |
| # Points that are too far-away == invalid | |
| for i in range(n_views): | |
| dis = no_norm_gt_pts[i].norm(dim=-1) | |
| valid_masks[i] = valid_masks[i] & (dis <= dist_clip) | |
| # Initialize normalized tensors | |
| gt_pts = [torch.zeros_like(pts) for pts in no_norm_gt_pts] | |
| pr_pts = [torch.zeros_like(pts) for pts in no_norm_pr_pts] | |
| # Normalize the predicted points if specified | |
| if self.norm_predictions: | |
| pr_normalization_output = normalize_multiple_pointclouds( | |
| no_norm_pr_pts, | |
| valid_masks, | |
| self.norm_mode, | |
| ret_factor=True, | |
| ) | |
| pr_pts_norm = pr_normalization_output[:-1] | |
| # Normalize the ground truth points | |
| gt_normalization_output = normalize_multiple_pointclouds( | |
| no_norm_gt_pts, valid_masks, self.norm_mode, ret_factor=True | |
| ) | |
| gt_pts_norm = gt_normalization_output[:-1] | |
| gt_norm_factor = gt_normalization_output[-1] | |
| for i in range(n_views): | |
| if self.norm_predictions: | |
| # Assign the normalized predictions | |
| pr_pts[i] = pr_pts_norm[i] | |
| else: | |
| pr_pts[i] = no_norm_pr_pts[i] | |
| # Assign the normalized ground truth quantities | |
| gt_pts[i] = gt_pts_norm[i] | |
| # Get the mask indicating ground truth metric scale quantities | |
| metric_scale_mask = batch[0]["is_metric_scale"] | |
| valid_gt_norm_factor_mask = ( | |
| gt_norm_factor[:, 0, 0, 0] > 1e-8 | |
| ) # Mask out cases where depth for all views is invalid | |
| valid_metric_scale_mask = metric_scale_mask & valid_gt_norm_factor_mask | |
| if valid_metric_scale_mask.any(): | |
| # Compute the scale norm factor using the predicted metric scale points | |
| metric_pr_normalization_output = normalize_multiple_pointclouds( | |
| metric_pr_pts_to_compute_scale, | |
| valid_masks, | |
| self.norm_mode, | |
| ret_factor=True, | |
| ) | |
| pr_metric_norm_factor = metric_pr_normalization_output[-1] | |
| # Get the valid ground truth and predicted scale norm factors for the metric ground truth quantities | |
| gt_metric_norm_factor = gt_norm_factor[valid_metric_scale_mask] | |
| pr_metric_norm_factor = pr_metric_norm_factor[valid_metric_scale_mask] | |
| else: | |
| gt_metric_norm_factor = None | |
| pr_metric_norm_factor = None | |
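| # gt_metric_norm_factor / pr_metric_norm_factor are the per-sample scene scales | |
| # (normalization factors of the metric-scale point clouds); compute_loss supervises | |
| # their agreement via the scale loss. Both are None when no sample in the batch has | |
| # metric-scale ground truth with valid depth. | |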
| # Get ambiguous masks | |
| ambiguous_masks = [] | |
| for i in range(n_views): | |
| ambiguous_masks.append( | |
| (~batch[i]["non_ambiguous_mask"]) & (~valid_masks[i]) | |
| ) | |
| # Pack into info dicts | |
| gt_info = [] | |
| pred_info = [] | |
| for i in range(n_views): | |
| gt_info.append( | |
| { | |
| "pts3d": gt_pts[i], | |
| } | |
| ) | |
| pred_info.append( | |
| { | |
| "pts3d": pr_pts[i], | |
| } | |
| ) | |
| return ( | |
| gt_info, | |
| pred_info, | |
| valid_masks, | |
| ambiguous_masks, | |
| gt_metric_norm_factor, | |
| pr_metric_norm_factor, | |
| ) | |
| def compute_loss(self, batch, preds, **kw): | |
| ( | |
| gt_info, | |
| pred_info, | |
| valid_masks, | |
| ambiguous_masks, | |
| gt_metric_norm_factor, | |
| pr_metric_norm_factor, | |
| ) = self.get_all_info(batch, preds, **kw) | |
| n_views = len(batch) | |
| if self.ambiguous_loss_value > 0: | |
| assert self.criterion.reduction == "none", ( | |
| "ambiguous_loss_value > 0 requires a criterion with reduction='none' (e.g. a confidence-weighted loss)" | |
| ) | |
| # Add the ambiguous pixels as "valid" pixels... | |
| valid_masks = [ | |
| mask | ambig_mask | |
| for mask, ambig_mask in zip(valid_masks, ambiguous_masks) | |
| ] | |
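| # Ambiguous pixels are now included in the masks; the torch.where calls below replace | |
| # their per-pixel loss with the constant self.ambiguous_loss_value instead of a | |
| # regression target. | |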
| pts3d_losses = [] | |
| for i in range(n_views): | |
| # Get the predicted dense quantities | |
| if not self.flatten_across_image_only: | |
| # Flatten the points across the entire batch with the masks | |
| pred_pts3d = pred_info[i]["pts3d"][valid_masks[i]] | |
| gt_pts3d = gt_info[i]["pts3d"][valid_masks[i]] | |
| else: | |
| # Flatten the H x W dimensions to H*W | |
| batch_size, _, _, pts_dim = gt_info[i]["pts3d"].shape | |
| gt_pts3d = gt_info[i]["pts3d"].view(batch_size, -1, pts_dim) | |
| pred_pts3d = pred_info[i]["pts3d"].view(batch_size, -1, pts_dim) | |
| valid_masks[i] = valid_masks[i].view(batch_size, -1) | |
| # Apply loss in log space if specified | |
| if self.loss_in_log: | |
| gt_pts3d = apply_log_to_norm(gt_pts3d) | |
| pred_pts3d = apply_log_to_norm(pred_pts3d) | |
| # Compute point loss | |
| pts3d_loss = self.criterion(pred_pts3d, gt_pts3d, factor="points") | |
| pts3d_loss = pts3d_loss * self.world_frame_points_loss_weight | |
| pts3d_losses.append(pts3d_loss) | |
| # Handle ambiguous pixels | |
| if self.ambiguous_loss_value > 0: | |
| if not self.flatten_across_image_only: | |
| pts3d_losses[i] = torch.where( | |
| ambiguous_masks[i][valid_masks[i]], | |
| self.ambiguous_loss_value, | |
| pts3d_losses[i], | |
| ) | |
| else: | |
| pts3d_losses[i] = torch.where( | |
| ambiguous_masks[i].view(ambiguous_masks[i].shape[0], -1), | |
| self.ambiguous_loss_value, | |
| pts3d_losses[i], | |
| ) | |
| # Compute the scale loss | |
| if gt_metric_norm_factor is not None: | |
| if self.loss_in_log: | |
| gt_metric_norm_factor = apply_log_to_norm(gt_metric_norm_factor) | |
| pr_metric_norm_factor = apply_log_to_norm(pr_metric_norm_factor) | |
| scale_loss = ( | |
| self.criterion( | |
| pr_metric_norm_factor, gt_metric_norm_factor, factor="scale" | |
| ) | |
| * self.scale_loss_weight | |
| ) | |
| else: | |
| scale_loss = None | |
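| # The scale loss supervises only the predicted metric scaling factor, since the | |
| # underlying geometry was detached when building pr_metric_norm_factor in get_all_info. | |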
| # Use helper function to generate loss terms and details | |
| losses_dict = { | |
| "pts3d": { | |
| "values": pts3d_losses, | |
| "use_mask": True, | |
| "is_multi_view": True, | |
| }, | |
| "scale": { | |
| "values": scale_loss, | |
| "use_mask": False, | |
| "is_multi_view": False, | |
| }, | |
| } | |
| loss_terms, details = get_loss_terms_and_details( | |
| losses_dict, | |
| valid_masks, | |
| type(self).__name__, | |
| n_views, | |
| self.flatten_across_image_only, | |
| ) | |
| losses = Sum(*loss_terms) | |
| return losses, (details | {}) | |
| class NormalGMLoss(MultiLoss): | |
| """ | |
| Normal & Gradient Matching Loss for Monocular Depth Training. | |
| """ | |
| def __init__( | |
| self, | |
| norm_predictions=True, | |
| norm_mode="avg_dis", | |
| apply_normal_and_gm_loss_to_synthetic_data_only=True, | |
| ): | |
| """ | |
| Initialize the loss criterion for Normal & Gradient Matching Loss (currently only valid for 1 view). | |
| Computes: | |
| (1) Normal Loss over the PointMap (naturally will be in local frame) in euclidean coordinates, | |
| (2) Gradient Matching (GM) Loss over the Depth Z in log space. (MiDAS applied GM loss in disparity space) | |
| Args: | |
| norm_predictions (bool): If True, normalize the predictions before computing the loss. | |
| norm_mode (str): Normalization mode for the gt and predicted (optional) scene representation. Default: "avg_dis". | |
| apply_normal_and_gm_loss_to_synthetic_data_only (bool): If True, apply the normal and gm loss only to synthetic data. | |
| If False, apply the normal and gm loss to all data. Default: True. | |
| """ | |
| super().__init__() | |
| self.norm_predictions = norm_predictions | |
| self.norm_mode = norm_mode | |
| self.apply_normal_and_gm_loss_to_synthetic_data_only = ( | |
| apply_normal_and_gm_loss_to_synthetic_data_only | |
| ) | |
| def get_all_info(self, batch, preds, dist_clip=None): | |
| """ | |
| Function to get all the information needed to compute the loss. | |
| Returns all quantities normalized. | |
| """ | |
| n_views = len(batch) | |
| assert n_views == 1, ( | |
| "Normal & Gradient Matching Loss Class only supports 1 view" | |
| ) | |
| # Everything is normalized w.r.t. camera of view1 | |
| in_camera1 = closed_form_pose_inverse(batch[0]["camera_pose"]) | |
| # Initialize lists to store data for all views | |
| no_norm_gt_pts = [] | |
| valid_masks = [] | |
| no_norm_pr_pts = [] | |
| # Get ground truth & prediction info for all views | |
| for i in range(n_views): | |
| # Get ground truth | |
| no_norm_gt_pts.append(geotrf(in_camera1, batch[i]["pts3d"])) | |
| valid_masks.append(batch[i]["valid_mask"].clone()) | |
| # Get predictions for normalized loss | |
| if "metric_scaling_factor" in preds[i].keys(): | |
| # Divide by the predicted metric scaling factor to get the raw predicted points | |
| # This detaches the predicted metric scaling factor from the geometry based loss | |
| curr_view_no_norm_pr_pts = preds[i]["pts3d"] / preds[i][ | |
| "metric_scaling_factor" | |
| ].unsqueeze(-1).unsqueeze(-1) | |
| else: | |
| curr_view_no_norm_pr_pts = preds[i]["pts3d"] | |
| no_norm_pr_pts.append(curr_view_no_norm_pr_pts) | |
| if dist_clip is not None: | |
| # Points that are too far-away == invalid | |
| for i in range(n_views): | |
| dis = no_norm_gt_pts[i].norm(dim=-1) | |
| valid_masks[i] = valid_masks[i] & (dis <= dist_clip) | |
| # Initialize normalized tensors | |
| gt_pts = [torch.zeros_like(pts) for pts in no_norm_gt_pts] | |
| pr_pts = [torch.zeros_like(pts) for pts in no_norm_pr_pts] | |
| # Normalize the predicted points if specified | |
| if self.norm_predictions: | |
| pr_normalization_output = normalize_multiple_pointclouds( | |
| no_norm_pr_pts, | |
| valid_masks, | |
| self.norm_mode, | |
| ret_factor=True, | |
| ) | |
| pr_pts_norm = pr_normalization_output[:-1] | |
| # Normalize the ground truth points | |
| gt_normalization_output = normalize_multiple_pointclouds( | |
| no_norm_gt_pts, valid_masks, self.norm_mode, ret_factor=True | |
| ) | |
| gt_pts_norm = gt_normalization_output[:-1] | |
| for i in range(n_views): | |
| if self.norm_predictions: | |
| # Assign the normalized predictions | |
| pr_pts[i] = pr_pts_norm[i] | |
| else: | |
| # Assign the raw predicted points | |
| pr_pts[i] = no_norm_pr_pts[i] | |
| # Assign the normalized ground truth | |
| gt_pts[i] = gt_pts_norm[i] | |
| return gt_pts, pr_pts, valid_masks | |
| def compute_loss(self, batch, preds, **kw): | |
| gt_pts, pred_pts, valid_masks = self.get_all_info(batch, preds, **kw) | |
| n_views = len(batch) | |
| assert n_views == 1, ( | |
| "Normal & Gradient Matching Loss Class only supports 1 view" | |
| ) | |
| normal_losses = [] | |
| gradient_matching_losses = [] | |
| details = {} | |
| running_avg_dict = {} | |
| self_name = type(self).__name__ | |
| for i in range(n_views): | |
| # Get the local frame points, log space depth_z & valid masks | |
| pred_local_pts3d = pred_pts[i] | |
| pred_depth_z = pred_local_pts3d[..., 2:] | |
| pred_depth_z = apply_log_to_norm(pred_depth_z) | |
| gt_local_pts3d = gt_pts[i] | |
| gt_depth_z = gt_local_pts3d[..., 2:] | |
| gt_depth_z = apply_log_to_norm(gt_depth_z) | |
| valid_mask_for_normal_gm_loss = valid_masks[i].clone() | |
| # Update the validity mask for normal & gm loss based on the synthetic data mask if required | |
| if self.apply_normal_and_gm_loss_to_synthetic_data_only: | |
| synthetic_mask = batch[i]["is_synthetic"] # (B, ) | |
| synthetic_mask = synthetic_mask.unsqueeze(-1).unsqueeze(-1) # (B, 1, 1) | |
| synthetic_mask = synthetic_mask.expand( | |
| -1, pred_depth_z.shape[1], pred_depth_z.shape[2] | |
| ) # (B, H, W) | |
| valid_mask_for_normal_gm_loss = ( | |
| valid_mask_for_normal_gm_loss & synthetic_mask | |
| ) | |
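| # valid_mask_for_normal_gm_loss now keeps only pixels that are geometrically valid and, | |
| # if the flag above is set, belong to synthetic samples (presumably because rendered | |
| # depth yields cleaner normals and depth gradients than real sensor depth; this | |
| # rationale is an assumption, not stated in this file). | |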
| # Compute the normal loss | |
| normal_loss = compute_normal_loss( | |
| pred_local_pts3d, gt_local_pts3d, valid_mask_for_normal_gm_loss.clone() | |
| ) | |
| normal_losses.append(normal_loss) | |
| # Compute the gradient matching loss | |
| gradient_matching_loss = compute_gradient_matching_loss( | |
| pred_depth_z, gt_depth_z, valid_mask_for_normal_gm_loss.clone() | |
| ) | |
| gradient_matching_losses.append(gradient_matching_loss) | |
| # Add loss details only if valid values are present | |
| # Initialize or update running average directly | |
| # Normal loss details | |
| if float(normal_loss) > 0: | |
| details[f"{self_name}_normal_view{i + 1}"] = float(normal_loss) | |
| normal_avg_key = f"{self_name}_normal_avg" | |
| if normal_avg_key not in details: | |
| details[normal_avg_key] = float(normal_losses[i]) | |
| running_avg_dict[f"{self_name}_normal_valid_views"] = 1 | |
| else: | |
| normal_valid_views = ( | |
| running_avg_dict[f"{self_name}_normal_valid_views"] + 1 | |
| ) | |
| running_avg_dict[f"{self_name}_normal_valid_views"] = ( | |
| normal_valid_views | |
| ) | |
| details[normal_avg_key] += ( | |
| float(normal_losses[i]) - details[normal_avg_key] | |
| ) / normal_valid_views | |
| # Gradient Matching loss details | |
| if float(gradient_matching_loss) > 0: | |
| details[f"{self_name}_gradient_matching_view{i + 1}"] = float( | |
| gradient_matching_loss | |
| ) | |
| # For gradient matching loss | |
| gm_avg_key = f"{self_name}_gradient_matching_avg" | |
| if gm_avg_key not in details: | |
| details[gm_avg_key] = float(gradient_matching_losses[i]) | |
| running_avg_dict[f"{self_name}_gm_valid_views"] = 1 | |
| else: | |
| gm_valid_views = running_avg_dict[f"{self_name}_gm_valid_views"] + 1 | |
| running_avg_dict[f"{self_name}_gm_valid_views"] = gm_valid_views | |
| details[gm_avg_key] += ( | |
| float(gradient_matching_losses[i]) - details[gm_avg_key] | |
| ) / gm_valid_views | |
| # Put the losses together | |
| loss_terms = [] | |
| for i in range(n_views): | |
| loss_terms.append((normal_losses[i], None, "normal")) | |
| loss_terms.append((gradient_matching_losses[i], None, "gradient_matching")) | |
| losses = Sum(*loss_terms) | |
| return losses, details | |
| class FactoredGeometryRegr3D(Criterion, MultiLoss): | |
| """ | |
| Regression Loss for Factored Geometry. | |
| """ | |
| def __init__( | |
| self, | |
| criterion, | |
| norm_mode="?avg_dis", | |
| gt_scale=False, | |
| ambiguous_loss_value=0, | |
| max_metric_scale=False, | |
| loss_in_log=True, | |
| flatten_across_image_only=False, | |
| depth_type_for_loss="depth_along_ray", | |
| cam_frame_points_loss_weight=1, | |
| depth_loss_weight=1, | |
| ray_directions_loss_weight=1, | |
| pose_quats_loss_weight=1, | |
| pose_trans_loss_weight=1, | |
| compute_pairwise_relative_pose_loss=False, | |
| convert_predictions_to_view0_frame=False, | |
| compute_world_frame_points_loss=True, | |
| world_frame_points_loss_weight=1, | |
| ): | |
| """ | |
| Initialize the loss criterion for Factored Geometry (Ray Directions, Depth, Pose), | |
| and the Collective Geometry i.e. Local Frame Pointmaps & optionally World Frame Pointmaps. | |
| If world-frame pointmap loss is computed, the pixel-level losses are computed in the following order: | |
| (1) world points, (2) cam points, (3) depth, (4) ray directions, (5) pose quats, (6) pose trans. | |
| Else, the pixel-level losses are returned in the following order: | |
| (1) cam points, (2) depth, (3) ray directions, (4) pose quats, (5) pose trans. | |
| Args: | |
| criterion (BaseCriterion): The base criterion to use for computing the loss. | |
| norm_mode (str): Normalization mode for scene representation. Default: "?avg_dis". | |
| If prefixed with "?", normalization is only applied to non-metric scale data. | |
| gt_scale (bool): If True, enforce predictions to have the same scale as ground truth. | |
| If False, both GT and predictions are normalized independently. Default: False. | |
| ambiguous_loss_value (float): Value to use for ambiguous pixels in the loss. | |
| If 0, ambiguous pixels are ignored. Default: 0. | |
| max_metric_scale (float): Maximum scale for metric scale data. If data exceeds this | |
| value, it will be treated as non-metric. Default: False (no limit). | |
| loss_in_log (bool): If True, apply logarithmic transformation to input before | |
| computing the loss for depth and pointmaps. Default: True. | |
| flatten_across_image_only (bool): If True, flatten H x W dimensions only when computing | |
| the loss. If False, flatten across batch and spatial dimensions. Default: False. | |
| depth_type_for_loss (str): Type of depth to use for loss computation. Default: "depth_along_ray". | |
| Options: "depth_along_ray", "depth_z" | |
| cam_frame_points_loss_weight (float): Weight to use for the camera frame pointmap loss. Default: 1. | |
| depth_loss_weight (float): Weight to use for the depth loss. Default: 1. | |
| ray_directions_loss_weight (float): Weight to use for the ray directions loss. Default: 1. | |
| pose_quats_loss_weight (float): Weight to use for the pose quats loss. Default: 1. | |
| pose_trans_loss_weight (float): Weight to use for the pose trans loss. Default: 1. | |
| compute_pairwise_relative_pose_loss (bool): If True, the pose loss is computed on the | |
| exhaustive pairwise relative poses. Default: False. | |
| convert_predictions_to_view0_frame (bool): If True, convert predictions to view0 frame. | |
| Use this if the predictions are not already in the view0 frame. Default: False. | |
| compute_world_frame_points_loss (bool): If True, compute the world frame pointmap loss. Default: True. | |
| world_frame_points_loss_weight (float): Weight to use for the world frame pointmap loss. Default: 1. | |
| """ | |
| super().__init__(criterion) | |
| if norm_mode.startswith("?"): | |
| # Do not normalize points from metric scale datasets | |
| self.norm_all = False | |
| self.norm_mode = norm_mode[1:] | |
| else: | |
| self.norm_all = True | |
| self.norm_mode = norm_mode | |
| self.gt_scale = gt_scale | |
| self.ambiguous_loss_value = ambiguous_loss_value | |
| self.max_metric_scale = max_metric_scale | |
| self.loss_in_log = loss_in_log | |
| self.flatten_across_image_only = flatten_across_image_only | |
| self.depth_type_for_loss = depth_type_for_loss | |
| assert self.depth_type_for_loss in ["depth_along_ray", "depth_z"], ( | |
| "depth_type_for_loss must be one of ['depth_along_ray', 'depth_z']" | |
| ) | |
| self.cam_frame_points_loss_weight = cam_frame_points_loss_weight | |
| self.depth_loss_weight = depth_loss_weight | |
| self.ray_directions_loss_weight = ray_directions_loss_weight | |
| self.pose_quats_loss_weight = pose_quats_loss_weight | |
| self.pose_trans_loss_weight = pose_trans_loss_weight | |
| self.compute_pairwise_relative_pose_loss = compute_pairwise_relative_pose_loss | |
| self.convert_predictions_to_view0_frame = convert_predictions_to_view0_frame | |
| self.compute_world_frame_points_loss = compute_world_frame_points_loss | |
| self.world_frame_points_loss_weight = world_frame_points_loss_weight | |
| def get_all_info(self, batch, preds, dist_clip=None): | |
| """ | |
| Function to get all the information needed to compute the loss. | |
| Returns all quantities normalized w.r.t. camera of view0. | |
| """ | |
| n_views = len(batch) | |
| # Everything is normalized w.r.t. camera of view0 | |
| # Initialize lists to store data for all views | |
| # Ground truth quantities | |
| in_camera0 = closed_form_pose_inverse(batch[0]["camera_pose"]) | |
| no_norm_gt_pts = [] | |
| no_norm_gt_pts_cam = [] | |
| no_norm_gt_depth = [] | |
| no_norm_gt_pose_trans = [] | |
| valid_masks = [] | |
| gt_ray_directions = [] | |
| gt_pose_quats = [] | |
| # Predicted quantities | |
| if self.convert_predictions_to_view0_frame: | |
| # Get the camera transform to convert quantities to view0 frame | |
| pred_camera0 = torch.eye(4, device=preds[0]["cam_quats"].device).unsqueeze( | |
| 0 | |
| ) | |
| batch_size = preds[0]["cam_quats"].shape[0] | |
| pred_camera0 = pred_camera0.repeat(batch_size, 1, 1) | |
| pred_camera0_rot = quaternion_to_rotation_matrix( | |
| preds[0]["cam_quats"].clone() | |
| ) | |
| pred_camera0[..., :3, :3] = pred_camera0_rot | |
| pred_camera0[..., :3, 3] = preds[0]["cam_trans"].clone() | |
| pred_in_camera0 = closed_form_pose_inverse(pred_camera0) | |
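| # pred_in_camera0 maps predicted world-frame quantities into the frame of the predicted | |
| # view0 camera, so that predictions and ground truth share view0 as the reference frame. | |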
| no_norm_pr_pts = [] | |
| no_norm_pr_pts_cam = [] | |
| no_norm_pr_depth = [] | |
| no_norm_pr_pose_trans = [] | |
| pr_ray_directions = [] | |
| pr_pose_quats = [] | |
| # Get ground truth & prediction info for all views | |
| for i in range(n_views): | |
| # Get ground truth | |
| no_norm_gt_pts.append(geotrf(in_camera0, batch[i]["pts3d"])) | |
| valid_masks.append(batch[i]["valid_mask"].clone()) | |
| no_norm_gt_pts_cam.append(batch[i]["pts3d_cam"]) | |
| gt_ray_directions.append(batch[i]["ray_directions_cam"]) | |
| if self.depth_type_for_loss == "depth_along_ray": | |
| no_norm_gt_depth.append(batch[i]["depth_along_ray"]) | |
| elif self.depth_type_for_loss == "depth_z": | |
| no_norm_gt_depth.append(batch[i]["pts3d_cam"][..., 2:]) | |
| if i == 0: | |
| # For view0, initialize identity pose | |
| gt_pose_quats.append( | |
| torch.tensor( | |
| [0, 0, 0, 1], | |
| dtype=gt_ray_directions[0].dtype, | |
| device=gt_ray_directions[0].device, | |
| ) | |
| .unsqueeze(0) | |
| .repeat(gt_ray_directions[0].shape[0], 1) | |
| ) | |
| no_norm_gt_pose_trans.append( | |
| torch.tensor( | |
| [0, 0, 0], | |
| dtype=gt_ray_directions[0].dtype, | |
| device=gt_ray_directions[0].device, | |
| ) | |
| .unsqueeze(0) | |
| .repeat(gt_ray_directions[0].shape[0], 1) | |
| ) | |
| else: | |
| # For other views, transform pose to view0's frame | |
| gt_pose_quats_world = batch[i]["camera_pose_quats"] | |
| no_norm_gt_pose_trans_world = batch[i]["camera_pose_trans"] | |
| gt_pose_quats_in_view0, no_norm_gt_pose_trans_in_view0 = ( | |
| transform_pose_using_quats_and_trans_2_to_1( | |
| batch[0]["camera_pose_quats"], | |
| batch[0]["camera_pose_trans"], | |
| gt_pose_quats_world, | |
| no_norm_gt_pose_trans_world, | |
| ) | |
| ) | |
| gt_pose_quats.append(gt_pose_quats_in_view0) | |
| no_norm_gt_pose_trans.append(no_norm_gt_pose_trans_in_view0) | |
| # Get the local predictions | |
| no_norm_pr_pts_cam.append(preds[i]["pts3d_cam"]) | |
| pr_ray_directions.append(preds[i]["ray_directions"]) | |
| if self.depth_type_for_loss == "depth_along_ray": | |
| no_norm_pr_depth.append(preds[i]["depth_along_ray"]) | |
| elif self.depth_type_for_loss == "depth_z": | |
| no_norm_pr_depth.append(preds[i]["pts3d_cam"][..., 2:]) | |
| # Get the predicted global quantities in view0's frame | |
| if self.convert_predictions_to_view0_frame: | |
| # Convert predictions to view0 frame | |
| pr_pts3d_in_view0 = geotrf(pred_in_camera0, preds[i]["pts3d"]) | |
| pr_pose_quats_in_view0, pr_pose_trans_in_view0 = ( | |
| transform_pose_using_quats_and_trans_2_to_1( | |
| preds[0]["cam_quats"], | |
| preds[0]["cam_trans"], | |
| preds[i]["cam_quats"], | |
| preds[i]["cam_trans"], | |
| ) | |
| ) | |
| no_norm_pr_pts.append(pr_pts3d_in_view0) | |
| no_norm_pr_pose_trans.append(pr_pose_trans_in_view0) | |
| pr_pose_quats.append(pr_pose_quats_in_view0) | |
| else: | |
| # Predictions are already in view0 frame | |
| no_norm_pr_pts.append(preds[i]["pts3d"]) | |
| no_norm_pr_pose_trans.append(preds[i]["cam_trans"]) | |
| pr_pose_quats.append(preds[i]["cam_quats"]) | |
| if dist_clip is not None: | |
| # Points that are too far-away == invalid | |
| for i in range(n_views): | |
| dis = no_norm_gt_pts[i].norm(dim=-1) | |
| valid_masks[i] = valid_masks[i] & (dis <= dist_clip) | |
| # Handle metric scale | |
| if not self.norm_all: | |
| if self.max_metric_scale: | |
| B = valid_masks[0].shape[0] | |
| dists_to_cam1 = [] | |
| for i in range(n_views): | |
| dists_to_cam1.append( | |
| torch.where( | |
| valid_masks[i], torch.norm(no_norm_gt_pts[i], dim=-1), 0 | |
| ).view(B, -1) | |
| ) | |
| batch[0]["is_metric_scale"] = batch[0]["is_metric_scale"] | |
| for dist in dists_to_cam1: | |
| batch[0]["is_metric_scale"] &= ( | |
| dist.max(dim=-1).values < self.max_metric_scale | |
| ) | |
| for i in range(1, n_views): | |
| batch[i]["is_metric_scale"] = batch[0]["is_metric_scale"] | |
| non_metric_scale_mask = ~batch[0]["is_metric_scale"] | |
| else: | |
| non_metric_scale_mask = torch.ones_like(batch[0]["is_metric_scale"]) | |
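| # non_metric_scale_mask selects the samples whose predictions are normalized | |
| # independently of the ground truth below; the remaining (metric-scale) samples are | |
| # handled in the ground-truth normalization block further down. | |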
| # Initialize normalized tensors | |
| gt_pts = [torch.zeros_like(pts) for pts in no_norm_gt_pts] | |
| gt_pts_cam = [torch.zeros_like(pts_cam) for pts_cam in no_norm_gt_pts_cam] | |
| gt_depth = [torch.zeros_like(depth) for depth in no_norm_gt_depth] | |
| gt_pose_trans = [torch.zeros_like(trans) for trans in no_norm_gt_pose_trans] | |
| pr_pts = [torch.zeros_like(pts) for pts in no_norm_pr_pts] | |
| pr_pts_cam = [torch.zeros_like(pts_cam) for pts_cam in no_norm_pr_pts_cam] | |
| pr_depth = [torch.zeros_like(depth) for depth in no_norm_pr_depth] | |
| pr_pose_trans = [torch.zeros_like(trans) for trans in no_norm_pr_pose_trans] | |
| # Normalize points | |
| if self.norm_mode and non_metric_scale_mask.any(): | |
| pr_normalization_output = normalize_multiple_pointclouds( | |
| [pts[non_metric_scale_mask] for pts in no_norm_pr_pts], | |
| [mask[non_metric_scale_mask] for mask in valid_masks], | |
| self.norm_mode, | |
| ret_factor=True, | |
| ) | |
| pr_pts_norm = pr_normalization_output[:-1] | |
| pr_norm_factor = pr_normalization_output[-1] | |
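| # The returned normalization factor has one value per selected sample and is | |
| # broadcastable over points (shape (N, 1, 1, 1)); the [:, :, 0, 0] indexing below | |
| # reduces it to (N, 1) so it also broadcasts against the (N, 3) pose translations. | |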
| for i in range(n_views): | |
| pr_pts[i][non_metric_scale_mask] = pr_pts_norm[i] | |
| pr_pts_cam[i][non_metric_scale_mask] = ( | |
| no_norm_pr_pts_cam[i][non_metric_scale_mask] / pr_norm_factor | |
| ) | |
| pr_depth[i][non_metric_scale_mask] = ( | |
| no_norm_pr_depth[i][non_metric_scale_mask] / pr_norm_factor | |
| ) | |
| pr_pose_trans[i][non_metric_scale_mask] = ( | |
| no_norm_pr_pose_trans[i][non_metric_scale_mask] | |
| / pr_norm_factor[:, :, 0, 0] | |
| ) | |
| elif non_metric_scale_mask.any(): | |
| for i in range(n_views): | |
| pr_pts[i][non_metric_scale_mask] = no_norm_pr_pts[i][ | |
| non_metric_scale_mask | |
| ] | |
| pr_pts_cam[i][non_metric_scale_mask] = no_norm_pr_pts_cam[i][ | |
| non_metric_scale_mask | |
| ] | |
| pr_depth[i][non_metric_scale_mask] = no_norm_pr_depth[i][ | |
| non_metric_scale_mask | |
| ] | |
| pr_pose_trans[i][non_metric_scale_mask] = no_norm_pr_pose_trans[i][ | |
| non_metric_scale_mask | |
| ] | |
| if self.norm_mode and not self.gt_scale: | |
| gt_normalization_output = normalize_multiple_pointclouds( | |
| no_norm_gt_pts, valid_masks, self.norm_mode, ret_factor=True | |
| ) | |
| gt_pts_norm = gt_normalization_output[:-1] | |
| norm_factor = gt_normalization_output[-1] | |
| for i in range(n_views): | |
| gt_pts[i] = gt_pts_norm[i] | |
| gt_pts_cam[i] = no_norm_gt_pts_cam[i] / norm_factor | |
| gt_depth[i] = no_norm_gt_depth[i] / norm_factor | |
| gt_pose_trans[i] = no_norm_gt_pose_trans[i] / norm_factor[:, :, 0, 0] | |
| pr_pts[i][~non_metric_scale_mask] = ( | |
| no_norm_pr_pts[i][~non_metric_scale_mask] | |
| / norm_factor[~non_metric_scale_mask] | |
| ) | |
| pr_pts_cam[i][~non_metric_scale_mask] = ( | |
| no_norm_pr_pts_cam[i][~non_metric_scale_mask] | |
| / norm_factor[~non_metric_scale_mask] | |
| ) | |
| pr_depth[i][~non_metric_scale_mask] = ( | |
| no_norm_pr_depth[i][~non_metric_scale_mask] | |
| / norm_factor[~non_metric_scale_mask] | |
| ) | |
| pr_pose_trans[i][~non_metric_scale_mask] = ( | |
| no_norm_pr_pose_trans[i][~non_metric_scale_mask] | |
| / norm_factor[~non_metric_scale_mask][:, :, 0, 0] | |
| ) | |
| elif ~non_metric_scale_mask.any(): | |
| for i in range(n_views): | |
| gt_pts[i] = no_norm_gt_pts[i] | |
| gt_pts_cam[i] = no_norm_gt_pts_cam[i] | |
| gt_depth[i] = no_norm_gt_depth[i] | |
| gt_pose_trans[i] = no_norm_gt_pose_trans[i] | |
| pr_pts[i][~non_metric_scale_mask] = no_norm_pr_pts[i][ | |
| ~non_metric_scale_mask | |
| ] | |
| pr_pts_cam[i][~non_metric_scale_mask] = no_norm_pr_pts_cam[i][ | |
| ~non_metric_scale_mask | |
| ] | |
| pr_depth[i][~non_metric_scale_mask] = no_norm_pr_depth[i][ | |
| ~non_metric_scale_mask | |
| ] | |
| pr_pose_trans[i][~non_metric_scale_mask] = no_norm_pr_pose_trans[i][ | |
| ~non_metric_scale_mask | |
| ] | |
| else: | |
| for i in range(n_views): | |
| gt_pts[i] = no_norm_gt_pts[i] | |
| gt_pts_cam[i] = no_norm_gt_pts_cam[i] | |
| gt_depth[i] = no_norm_gt_depth[i] | |
| gt_pose_trans[i] = no_norm_gt_pose_trans[i] | |
| # Get ambiguous masks | |
| ambiguous_masks = [] | |
| for i in range(n_views): | |
| ambiguous_masks.append( | |
| (~batch[i]["non_ambiguous_mask"]) & (~valid_masks[i]) | |
| ) | |
| # Pack into info dicts | |
| gt_info = [] | |
| pred_info = [] | |
| for i in range(n_views): | |
| gt_info.append( | |
| { | |
| "ray_directions": gt_ray_directions[i], | |
| self.depth_type_for_loss: gt_depth[i], | |
| "pose_trans": gt_pose_trans[i], | |
| "pose_quats": gt_pose_quats[i], | |
| "pts3d": gt_pts[i], | |
| "pts3d_cam": gt_pts_cam[i], | |
| } | |
| ) | |
| pred_info.append( | |
| { | |
| "ray_directions": pr_ray_directions[i], | |
| self.depth_type_for_loss: pr_depth[i], | |
| "pose_trans": pr_pose_trans[i], | |
| "pose_quats": pr_pose_quats[i], | |
| "pts3d": pr_pts[i], | |
| "pts3d_cam": pr_pts_cam[i], | |
| } | |
| ) | |
| return gt_info, pred_info, valid_masks, ambiguous_masks | |
| def compute_loss(self, batch, preds, **kw): | |
| gt_info, pred_info, valid_masks, ambiguous_masks = self.get_all_info( | |
| batch, preds, **kw | |
| ) | |
| n_views = len(batch) | |
| # Mask out samples in the batch where the gt depth validity mask is entirely zero | |
| valid_norm_factor_masks = [ | |
| mask.sum(dim=(1, 2)) > 0 for mask in valid_masks | |
| ] # List of (B,) | |
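| # Pose translations were normalized by a per-sample scene scale in get_all_info; samples | |
| # without any valid depth have an undefined scale, so these masks exclude them from the | |
| # translation loss. | |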
| if self.ambiguous_loss_value > 0: | |
| assert self.criterion.reduction == "none", ( | |
| "ambiguous_loss_value > 0 requires a criterion with reduction='none' (e.g. a confidence-weighted loss)" | |
| ) | |
| # Add the ambiguous pixels as "valid" pixels... | |
| valid_masks = [ | |
| mask | ambig_mask | |
| for mask, ambig_mask in zip(valid_masks, ambiguous_masks) | |
| ] | |
| pose_trans_losses = [] | |
| pose_quats_losses = [] | |
| ray_directions_losses = [] | |
| depth_losses = [] | |
| cam_pts3d_losses = [] | |
| if self.compute_world_frame_points_loss: | |
| pts3d_losses = [] | |
| for i in range(n_views): | |
| # Get the predicted dense quantities | |
| if not self.flatten_across_image_only: | |
| # Flatten the points across the entire batch with the masks | |
| pred_ray_directions = pred_info[i]["ray_directions"] | |
| gt_ray_directions = gt_info[i]["ray_directions"] | |
| pred_depth = pred_info[i][self.depth_type_for_loss][valid_masks[i]] | |
| gt_depth = gt_info[i][self.depth_type_for_loss][valid_masks[i]] | |
| pred_cam_pts3d = pred_info[i]["pts3d_cam"][valid_masks[i]] | |
| gt_cam_pts3d = gt_info[i]["pts3d_cam"][valid_masks[i]] | |
| if self.compute_world_frame_points_loss: | |
| pred_pts3d = pred_info[i]["pts3d"][valid_masks[i]] | |
| gt_pts3d = gt_info[i]["pts3d"][valid_masks[i]] | |
| else: | |
| # Flatten the H x W dimensions to H*W | |
| batch_size, _, _, direction_dim = gt_info[i]["ray_directions"].shape | |
| gt_ray_directions = gt_info[i]["ray_directions"].view( | |
| batch_size, -1, direction_dim | |
| ) | |
| pred_ray_directions = pred_info[i]["ray_directions"].view( | |
| batch_size, -1, direction_dim | |
| ) | |
| depth_dim = gt_info[i][self.depth_type_for_loss].shape[-1] | |
| gt_depth = gt_info[i][self.depth_type_for_loss].view( | |
| batch_size, -1, depth_dim | |
| ) | |
| pred_depth = pred_info[i][self.depth_type_for_loss].view( | |
| batch_size, -1, depth_dim | |
| ) | |
| cam_pts_dim = gt_info[i]["pts3d_cam"].shape[-1] | |
| gt_cam_pts3d = gt_info[i]["pts3d_cam"].view(batch_size, -1, cam_pts_dim) | |
| pred_cam_pts3d = pred_info[i]["pts3d_cam"].view( | |
| batch_size, -1, cam_pts_dim | |
| ) | |
| if self.compute_world_frame_points_loss: | |
| pts_dim = gt_info[i]["pts3d"].shape[-1] | |
| gt_pts3d = gt_info[i]["pts3d"].view(batch_size, -1, pts_dim) | |
| pred_pts3d = pred_info[i]["pts3d"].view(batch_size, -1, pts_dim) | |
| valid_masks[i] = valid_masks[i].view(batch_size, -1) | |
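| # Shapes after this branch: the dense tensors are (B, H*W, C) and valid_masks[i] is | |
| # (B, H*W); in the masked branch above they are instead flattened to (N_valid, C) | |
| # across the whole batch. | |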
| # Apply loss in log space for depth if specified | |
| if self.loss_in_log: | |
| gt_depth = apply_log_to_norm(gt_depth) | |
| pred_depth = apply_log_to_norm(pred_depth) | |
| gt_cam_pts3d = apply_log_to_norm(gt_cam_pts3d) | |
| pred_cam_pts3d = apply_log_to_norm(pred_cam_pts3d) | |
| if self.compute_world_frame_points_loss: | |
| gt_pts3d = apply_log_to_norm(gt_pts3d) | |
| pred_pts3d = apply_log_to_norm(pred_pts3d) | |
| if self.compute_pairwise_relative_pose_loss: | |
| # Get the inverse of current view predicted pose | |
| pred_inv_curr_view_pose_quats = quaternion_inverse( | |
| pred_info[i]["pose_quats"] | |
| ) | |
| pred_inv_curr_view_pose_rot_mat = quaternion_to_rotation_matrix( | |
| pred_inv_curr_view_pose_quats | |
| ) | |
| pred_inv_curr_view_pose_trans = -1 * ein.einsum( | |
| pred_inv_curr_view_pose_rot_mat, | |
| pred_info[i]["pose_trans"], | |
| "b i j, b j -> b i", | |
| ) | |
| # Get the inverse of the current view GT pose | |
| gt_inv_curr_view_pose_quats = quaternion_inverse( | |
| gt_info[i]["pose_quats"] | |
| ) | |
| gt_inv_curr_view_pose_rot_mat = quaternion_to_rotation_matrix( | |
| gt_inv_curr_view_pose_quats | |
| ) | |
| gt_inv_curr_view_pose_trans = -1 * ein.einsum( | |
| gt_inv_curr_view_pose_rot_mat, | |
| gt_info[i]["pose_trans"], | |
| "b i j, b j -> b i", | |
| ) | |
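| # For a view-to-view0 pose (q_i, t_i), the inverse is (q_i^-1, -R(q_i^-1) @ t_i); | |
| # composing it with view j's pose below gives view j expressed in view i's frame: | |
| # q_rel = q_i^-1 * q_j and t_rel = R(q_i^-1) @ t_j + t_i_inv. | |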
| # Get the other N-1 relative poses using the current pose as reference frame | |
| pred_rel_pose_quats = [] | |
| pred_rel_pose_trans = [] | |
| gt_rel_pose_quats = [] | |
| gt_rel_pose_trans = [] | |
| for ov_idx in range(n_views): | |
| if ov_idx == i: | |
| continue | |
| # Get the relative predicted pose | |
| pred_ov_rel_pose_quats = quaternion_multiply( | |
| pred_inv_curr_view_pose_quats, pred_info[ov_idx]["pose_quats"] | |
| ) | |
| pred_ov_rel_pose_trans = ( | |
| ein.einsum( | |
| pred_inv_curr_view_pose_rot_mat, | |
| pred_info[ov_idx]["pose_trans"], | |
| "b i j, b j -> b i", | |
| ) | |
| + pred_inv_curr_view_pose_trans | |
| ) | |
| # Get the relative GT pose | |
| gt_ov_rel_pose_quats = quaternion_multiply( | |
| gt_inv_curr_view_pose_quats, gt_info[ov_idx]["pose_quats"] | |
| ) | |
| gt_ov_rel_pose_trans = ( | |
| ein.einsum( | |
| gt_inv_curr_view_pose_rot_mat, | |
| gt_info[ov_idx]["pose_trans"], | |
| "b i j, b j -> b i", | |
| ) | |
| + gt_inv_curr_view_pose_trans | |
| ) | |
| # Get the valid translations using valid_norm_factor_masks for current view and other view | |
| overall_valid_mask_for_trans = ( | |
| valid_norm_factor_masks[i] & valid_norm_factor_masks[ov_idx] | |
| ) | |
| # Append the relative poses | |
| pred_rel_pose_quats.append(pred_ov_rel_pose_quats) | |
| pred_rel_pose_trans.append( | |
| pred_ov_rel_pose_trans[overall_valid_mask_for_trans] | |
| ) | |
| gt_rel_pose_quats.append(gt_ov_rel_pose_quats) | |
| gt_rel_pose_trans.append( | |
| gt_ov_rel_pose_trans[overall_valid_mask_for_trans] | |
| ) | |
| # Cat the N-1 relative poses along the batch dimension | |
| pred_rel_pose_quats = torch.cat(pred_rel_pose_quats, dim=0) | |
| pred_rel_pose_trans = torch.cat(pred_rel_pose_trans, dim=0) | |
| gt_rel_pose_quats = torch.cat(gt_rel_pose_quats, dim=0) | |
| gt_rel_pose_trans = torch.cat(gt_rel_pose_trans, dim=0) | |
| # Compute pose translation loss | |
| pose_trans_loss = self.criterion( | |
| pred_rel_pose_trans, gt_rel_pose_trans, factor="pose_trans" | |
| ) | |
| pose_trans_loss = pose_trans_loss * self.pose_trans_loss_weight | |
| pose_trans_losses.append(pose_trans_loss) | |
| # Compute pose rotation loss | |
| # Handle quaternion two-to-one mapping | |
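| # q and -q encode the same rotation (quaternions double-cover SO(3)), so take the | |
| # elementwise minimum of the loss against +gt and -gt to avoid penalizing a pure | |
| # sign flip. | |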
| pose_quats_loss = torch.minimum( | |
| self.criterion( | |
| pred_rel_pose_quats, gt_rel_pose_quats, factor="pose_quats" | |
| ), | |
| self.criterion( | |
| pred_rel_pose_quats, -gt_rel_pose_quats, factor="pose_quats" | |
| ), | |
| ) | |
| pose_quats_loss = pose_quats_loss * self.pose_quats_loss_weight | |
| pose_quats_losses.append(pose_quats_loss) | |
| else: | |
| # Get the pose info for the current view | |
| pred_pose_trans = pred_info[i]["pose_trans"][valid_norm_factor_masks[i]] | |
| gt_pose_trans = gt_info[i]["pose_trans"][valid_norm_factor_masks[i]] | |
| pred_pose_quats = pred_info[i]["pose_quats"] | |
| gt_pose_quats = gt_info[i]["pose_quats"] | |
| # Compute pose translation loss | |
| pose_trans_loss = self.criterion( | |
| pred_pose_trans, gt_pose_trans, factor="pose_trans" | |
| ) | |
| pose_trans_loss = pose_trans_loss * self.pose_trans_loss_weight | |
| pose_trans_losses.append(pose_trans_loss) | |
| # Compute pose rotation loss | |
| # Handle quaternion two-to-one mapping | |
| pose_quats_loss = torch.minimum( | |
| self.criterion(pred_pose_quats, gt_pose_quats, factor="pose_quats"), | |
| self.criterion( | |
| pred_pose_quats, -gt_pose_quats, factor="pose_quats" | |
| ), | |
| ) | |
| pose_quats_loss = pose_quats_loss * self.pose_quats_loss_weight | |
| pose_quats_losses.append(pose_quats_loss) | |
| # Compute ray direction loss | |
| ray_directions_loss = self.criterion( | |
| pred_ray_directions, gt_ray_directions, factor="ray_directions" | |
| ) | |
| ray_directions_loss = ray_directions_loss * self.ray_directions_loss_weight | |
| ray_directions_losses.append(ray_directions_loss) | |
| # Compute depth loss | |
| depth_loss = self.criterion(pred_depth, gt_depth, factor="depth") | |
| depth_loss = depth_loss * self.depth_loss_weight | |
| depth_losses.append(depth_loss) | |
| # Compute camera frame point loss | |
| cam_pts3d_loss = self.criterion( | |
| pred_cam_pts3d, gt_cam_pts3d, factor="points" | |
| ) | |
| cam_pts3d_loss = cam_pts3d_loss * self.cam_frame_points_loss_weight | |
| cam_pts3d_losses.append(cam_pts3d_loss) | |
| if self.compute_world_frame_points_loss: | |
| # Compute point loss | |
| pts3d_loss = self.criterion(pred_pts3d, gt_pts3d, factor="points") | |
| pts3d_loss = pts3d_loss * self.world_frame_points_loss_weight | |
| pts3d_losses.append(pts3d_loss) | |
| # Handle ambiguous pixels | |
| if self.ambiguous_loss_value > 0: | |
| if not self.flatten_across_image_only: | |
| depth_losses[i] = torch.where( | |
| ambiguous_masks[i][valid_masks[i]], | |
| self.ambiguous_loss_value, | |
| depth_losses[i], | |
| ) | |
| cam_pts3d_losses[i] = torch.where( | |
| ambiguous_masks[i][valid_masks[i]], | |
| self.ambiguous_loss_value, | |
| cam_pts3d_losses[i], | |
| ) | |
| if self.compute_world_frame_points_loss: | |
| pts3d_losses[i] = torch.where( | |
| ambiguous_masks[i][valid_masks[i]], | |
| self.ambiguous_loss_value, | |
| pts3d_losses[i], | |
| ) | |
| else: | |
| depth_losses[i] = torch.where( | |
| ambiguous_masks[i].view(ambiguous_masks[i].shape[0], -1), | |
| self.ambiguous_loss_value, | |
| depth_losses[i], | |
| ) | |
| cam_pts3d_losses[i] = torch.where( | |
| ambiguous_masks[i].view(ambiguous_masks[i].shape[0], -1), | |
| self.ambiguous_loss_value, | |
| cam_pts3d_losses[i], | |
| ) | |
| if self.compute_world_frame_points_loss: | |
| pts3d_losses[i] = torch.where( | |
| ambiguous_masks[i].view(ambiguous_masks[i].shape[0], -1), | |
| self.ambiguous_loss_value, | |
| pts3d_losses[i], | |
| ) | |
| # Use helper function to generate loss terms and details | |
| if self.compute_world_frame_points_loss: | |
| losses_dict = { | |
| "pts3d": { | |
| "values": pts3d_losses, | |
| "use_mask": True, | |
| "is_multi_view": True, | |
| }, | |
| } | |
| else: | |
| losses_dict = {} | |
| losses_dict.update( | |
| { | |
| "cam_pts3d": { | |
| "values": cam_pts3d_losses, | |
| "use_mask": True, | |
| "is_multi_view": True, | |
| }, | |
| self.depth_type_for_loss: { | |
| "values": depth_losses, | |
| "use_mask": True, | |
| "is_multi_view": True, | |
| }, | |
| "ray_directions": { | |
| "values": ray_directions_losses, | |
| "use_mask": False, | |
| "is_multi_view": True, | |
| }, | |
| "pose_quats": { | |
| "values": pose_quats_losses, | |
| "use_mask": False, | |
| "is_multi_view": True, | |
| }, | |
| "pose_trans": { | |
| "values": pose_trans_losses, | |
| "use_mask": False, | |
| "is_multi_view": True, | |
| }, | |
| } | |
| ) | |
| loss_terms, details = get_loss_terms_and_details( | |
| losses_dict, | |
| valid_masks, | |
| type(self).__name__, | |
| n_views, | |
| self.flatten_across_image_only, | |
| ) | |
| losses = Sum(*loss_terms) | |
| return losses, (details | {}) | |
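| # Example usage (hypothetical sketch; `criterion` stands in for any base criterion | |
| # defined earlier in this file that supports reduction="none" and the `factor=` | |
| # keyword used above): | |
| #   loss_fn = FactoredGeometryRegr3D(criterion, norm_mode="?avg_dis", loss_in_log=True) | |
| #   loss_sum, loss_details = loss_fn.compute_loss(batch, preds)  # batch/preds: per-view lists of dicts | |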
| class FactoredGeometryRegr3DPlusNormalGMLoss(FactoredGeometryRegr3D): | |
| """ | |
| Regression, Normals & Gradient Matching Loss for Factored Geometry. | |
| """ | |
| def __init__( | |
| self, | |
| criterion, | |
| norm_mode="?avg_dis", | |
| gt_scale=False, | |
| ambiguous_loss_value=0, | |
| max_metric_scale=False, | |
| loss_in_log=True, | |
| flatten_across_image_only=False, | |
| depth_type_for_loss="depth_along_ray", | |
| cam_frame_points_loss_weight=1, | |
| depth_loss_weight=1, | |
| ray_directions_loss_weight=1, | |
| pose_quats_loss_weight=1, | |
| pose_trans_loss_weight=1, | |
| compute_pairwise_relative_pose_loss=False, | |
| convert_predictions_to_view0_frame=False, | |
| compute_world_frame_points_loss=True, | |
| world_frame_points_loss_weight=1, | |
| apply_normal_and_gm_loss_to_synthetic_data_only=True, | |
| normal_loss_weight=1, | |
| gm_loss_weight=1, | |
| ): | |
| """ | |
| Initialize the loss criterion for Factored Geometry (see parent class for details). | |
| Additionally computes: | |
| (1) Normal Loss over the Camera Frame Pointmaps in euclidean coordinates, | |
| (2) Gradient Matching (GM) Loss over the Depth Z in log space. (MiDAS applied GM loss in disparity space) | |
| Args: | |
| criterion (BaseCriterion): The base criterion to use for computing the loss. | |
| norm_mode (str): Normalization mode for scene representation. Default: "?avg_dis". | |
| If prefixed with "?", normalization is only applied to non-metric scale data. | |
| gt_scale (bool): If True, enforce predictions to have the same scale as ground truth. | |
| If False, both GT and predictions are normalized independently. Default: False. | |
| ambiguous_loss_value (float): Value to use for ambiguous pixels in the loss. | |
| If 0, ambiguous pixels are ignored. Default: 0. | |
| max_metric_scale (float): Maximum scale for metric scale data. If data exceeds this | |
| value, it will be treated as non-metric. Default: False (no limit). | |
| loss_in_log (bool): If True, apply logarithmic transformation to input before | |
| computing the loss for depth and pointmaps. Default: True. | |
| flatten_across_image_only (bool): If True, flatten H x W dimensions only when computing | |
| the loss. If False, flatten across batch and spatial dimensions. Default: False. | |
| depth_type_for_loss (str): Type of depth to use for loss computation. Default: "depth_along_ray". | |
| Options: "depth_along_ray", "depth_z" | |
| cam_frame_points_loss_weight (float): Weight to use for the camera frame pointmap loss. Default: 1. | |
| depth_loss_weight (float): Weight to use for the depth loss. Default: 1. | |
| ray_directions_loss_weight (float): Weight to use for the ray directions loss. Default: 1. | |
| pose_quats_loss_weight (float): Weight to use for the pose quats loss. Default: 1. | |
| pose_trans_loss_weight (float): Weight to use for the pose trans loss. Default: 1. | |
| compute_pairwise_relative_pose_loss (bool): If True, the pose loss is computed on the | |
| exhaustive pairwise relative poses. Default: False. | |
| convert_predictions_to_view0_frame (bool): If True, convert predictions to view0 frame. | |
| Use this if the predictions are not already in the view0 frame. Default: False. | |
| compute_world_frame_points_loss (bool): If True, compute the world frame pointmap loss. Default: True. | |
| world_frame_points_loss_weight (float): Weight to use for the world frame pointmap loss. Default: 1. | |
| apply_normal_and_gm_loss_to_synthetic_data_only (bool): If True, apply the normal and gm loss only to synthetic data. | |
| If False, apply the normal and gm loss to all data. Default: True. | |
| normal_loss_weight (float): Weight to use for the normal loss. Default: 1. | |
| gm_loss_weight (float): Weight to use for the gm loss. Default: 1. | |
| """ | |
| super().__init__( | |
| criterion=criterion, | |
| norm_mode=norm_mode, | |
| gt_scale=gt_scale, | |
| ambiguous_loss_value=ambiguous_loss_value, | |
| max_metric_scale=max_metric_scale, | |
| loss_in_log=loss_in_log, | |
| flatten_across_image_only=flatten_across_image_only, | |
| depth_type_for_loss=depth_type_for_loss, | |
| cam_frame_points_loss_weight=cam_frame_points_loss_weight, | |
| depth_loss_weight=depth_loss_weight, | |
| ray_directions_loss_weight=ray_directions_loss_weight, | |
| pose_quats_loss_weight=pose_quats_loss_weight, | |
| pose_trans_loss_weight=pose_trans_loss_weight, | |
| compute_pairwise_relative_pose_loss=compute_pairwise_relative_pose_loss, | |
| convert_predictions_to_view0_frame=convert_predictions_to_view0_frame, | |
| compute_world_frame_points_loss=compute_world_frame_points_loss, | |
| world_frame_points_loss_weight=world_frame_points_loss_weight, | |
| ) | |
| self.apply_normal_and_gm_loss_to_synthetic_data_only = ( | |
| apply_normal_and_gm_loss_to_synthetic_data_only | |
| ) | |
| self.normal_loss_weight = normal_loss_weight | |
| self.gm_loss_weight = gm_loss_weight | |
| def compute_loss(self, batch, preds, **kw): | |
| gt_info, pred_info, valid_masks, ambiguous_masks = self.get_all_info( | |
| batch, preds, **kw | |
| ) | |
| n_views = len(batch) | |
| # Mask out samples in the batch where the gt depth validity mask is entirely zero | |
| valid_norm_factor_masks = [ | |
| mask.sum(dim=(1, 2)) > 0 for mask in valid_masks | |
| ] # List of (B,) | |
| if self.ambiguous_loss_value > 0: | |
| assert self.criterion.reduction == "none", ( | |
| "ambiguous_loss_value > 0 requires a criterion with reduction='none' (e.g. a confidence-weighted loss)" | |
| ) | |
| # Add the ambiguous pixels as "valid" pixels... | |
| valid_masks = [ | |
| mask | ambig_mask | |
| for mask, ambig_mask in zip(valid_masks, ambiguous_masks) | |
| ] | |
| normal_losses = [] | |
| gradient_matching_losses = [] | |
| pose_trans_losses = [] | |
| pose_quats_losses = [] | |
| ray_directions_losses = [] | |
| depth_losses = [] | |
| cam_pts3d_losses = [] | |
| if self.compute_world_frame_points_loss: | |
| pts3d_losses = [] | |
| for i in range(n_views): | |
| # Get the camera frame points, log space depth_z & valid masks | |
| pred_local_pts3d = pred_info[i]["pts3d_cam"] | |
| pred_depth_z = pred_local_pts3d[..., 2:] | |
| pred_depth_z = apply_log_to_norm(pred_depth_z) | |
| gt_local_pts3d = gt_info[i]["pts3d_cam"] | |
| gt_depth_z = gt_local_pts3d[..., 2:] | |
| gt_depth_z = apply_log_to_norm(gt_depth_z) | |
| valid_mask_for_normal_gm_loss = valid_masks[i].clone() | |
| # Update the validity mask for normal & gm loss based on the synthetic data mask if required | |
| if self.apply_normal_and_gm_loss_to_synthetic_data_only: | |
| synthetic_mask = batch[i]["is_synthetic"] # (B, ) | |
| synthetic_mask = synthetic_mask.unsqueeze(-1).unsqueeze(-1) # (B, 1, 1) | |
| synthetic_mask = synthetic_mask.expand( | |
| -1, pred_depth_z.shape[1], pred_depth_z.shape[2] | |
| ) # (B, H, W) | |
| valid_mask_for_normal_gm_loss = ( | |
| valid_mask_for_normal_gm_loss & synthetic_mask | |
| ) | |
| # Compute the normal loss | |
| normal_loss = compute_normal_loss( | |
| pred_local_pts3d, gt_local_pts3d, valid_mask_for_normal_gm_loss.clone() | |
| ) | |
| normal_loss = normal_loss * self.normal_loss_weight | |
| normal_losses.append(normal_loss) | |
| # Compute the gradient matching loss | |
| gradient_matching_loss = compute_gradient_matching_loss( | |
| pred_depth_z, gt_depth_z, valid_mask_for_normal_gm_loss.clone() | |
| ) | |
| gradient_matching_loss = gradient_matching_loss * self.gm_loss_weight | |
| gradient_matching_losses.append(gradient_matching_loss) | |
| # Get the predicted dense quantities | |
| if not self.flatten_across_image_only: | |
| # Flatten the points across the entire batch with the masks | |
| pred_ray_directions = pred_info[i]["ray_directions"] | |
| gt_ray_directions = gt_info[i]["ray_directions"] | |
| pred_depth = pred_info[i][self.depth_type_for_loss][valid_masks[i]] | |
| gt_depth = gt_info[i][self.depth_type_for_loss][valid_masks[i]] | |
| pred_cam_pts3d = pred_info[i]["pts3d_cam"][valid_masks[i]] | |
| gt_cam_pts3d = gt_info[i]["pts3d_cam"][valid_masks[i]] | |
| if self.compute_world_frame_points_loss: | |
| pred_pts3d = pred_info[i]["pts3d"][valid_masks[i]] | |
| gt_pts3d = gt_info[i]["pts3d"][valid_masks[i]] | |
| else: | |
| # Flatten the H x W dimensions to H*W | |
| batch_size, _, _, direction_dim = gt_info[i]["ray_directions"].shape | |
| gt_ray_directions = gt_info[i]["ray_directions"].view( | |
| batch_size, -1, direction_dim | |
| ) | |
| pred_ray_directions = pred_info[i]["ray_directions"].view( | |
| batch_size, -1, direction_dim | |
| ) | |
| depth_dim = gt_info[i][self.depth_type_for_loss].shape[-1] | |
| gt_depth = gt_info[i][self.depth_type_for_loss].view( | |
| batch_size, -1, depth_dim | |
| ) | |
| pred_depth = pred_info[i][self.depth_type_for_loss].view( | |
| batch_size, -1, depth_dim | |
| ) | |
| cam_pts_dim = gt_info[i]["pts3d_cam"].shape[-1] | |
| gt_cam_pts3d = gt_info[i]["pts3d_cam"].view(batch_size, -1, cam_pts_dim) | |
| pred_cam_pts3d = pred_info[i]["pts3d_cam"].view( | |
| batch_size, -1, cam_pts_dim | |
| ) | |
| if self.compute_world_frame_points_loss: | |
| pts_dim = gt_info[i]["pts3d"].shape[-1] | |
| gt_pts3d = gt_info[i]["pts3d"].view(batch_size, -1, pts_dim) | |
| pred_pts3d = pred_info[i]["pts3d"].view(batch_size, -1, pts_dim) | |
| valid_masks[i] = valid_masks[i].view(batch_size, -1) | |
| # Apply loss in log space for depth if specified | |
| if self.loss_in_log: | |
| gt_depth = apply_log_to_norm(gt_depth) | |
| pred_depth = apply_log_to_norm(pred_depth) | |
| gt_cam_pts3d = apply_log_to_norm(gt_cam_pts3d) | |
| pred_cam_pts3d = apply_log_to_norm(pred_cam_pts3d) | |
| if self.compute_world_frame_points_loss: | |
| gt_pts3d = apply_log_to_norm(gt_pts3d) | |
| pred_pts3d = apply_log_to_norm(pred_pts3d) | |
| if self.compute_pairwise_relative_pose_loss: | |
| # Get the inverse of current view predicted pose | |
| pred_inv_curr_view_pose_quats = quaternion_inverse( | |
| pred_info[i]["pose_quats"] | |
| ) | |
| pred_inv_curr_view_pose_rot_mat = quaternion_to_rotation_matrix( | |
| pred_inv_curr_view_pose_quats | |
| ) | |
| pred_inv_curr_view_pose_trans = -1 * ein.einsum( | |
| pred_inv_curr_view_pose_rot_mat, | |
| pred_info[i]["pose_trans"], | |
| "b i j, b j -> b i", | |
| ) | |
| # Get the inverse of the current view GT pose | |
| gt_inv_curr_view_pose_quats = quaternion_inverse( | |
| gt_info[i]["pose_quats"] | |
| ) | |
| gt_inv_curr_view_pose_rot_mat = quaternion_to_rotation_matrix( | |
| gt_inv_curr_view_pose_quats | |
| ) | |
| gt_inv_curr_view_pose_trans = -1 * ein.einsum( | |
| gt_inv_curr_view_pose_rot_mat, | |
| gt_info[i]["pose_trans"], | |
| "b i j, b j -> b i", | |
| ) | |
| # Get the other N-1 relative poses using the current pose as reference frame | |
| pred_rel_pose_quats = [] | |
| pred_rel_pose_trans = [] | |
| gt_rel_pose_quats = [] | |
| gt_rel_pose_trans = [] | |
| for ov_idx in range(n_views): | |
| if ov_idx == i: | |
| continue | |
| # Get the relative predicted pose | |
| pred_ov_rel_pose_quats = quaternion_multiply( | |
| pred_inv_curr_view_pose_quats, pred_info[ov_idx]["pose_quats"] | |
| ) | |
| pred_ov_rel_pose_trans = ( | |
| ein.einsum( | |
| pred_inv_curr_view_pose_rot_mat, | |
| pred_info[ov_idx]["pose_trans"], | |
| "b i j, b j -> b i", | |
| ) | |
| + pred_inv_curr_view_pose_trans | |
| ) | |
| # Get the relative GT pose | |
| gt_ov_rel_pose_quats = quaternion_multiply( | |
| gt_inv_curr_view_pose_quats, gt_info[ov_idx]["pose_quats"] | |
| ) | |
| gt_ov_rel_pose_trans = ( | |
| ein.einsum( | |
| gt_inv_curr_view_pose_rot_mat, | |
| gt_info[ov_idx]["pose_trans"], | |
| "b i j, b j -> b i", | |
| ) | |
| + gt_inv_curr_view_pose_trans | |
| ) | |
| # Get the valid translations using valid_norm_factor_masks for current view and other view | |
| overall_valid_mask_for_trans = ( | |
| valid_norm_factor_masks[i] & valid_norm_factor_masks[ov_idx] | |
| ) | |
| # Append the relative poses | |
| pred_rel_pose_quats.append(pred_ov_rel_pose_quats) | |
| pred_rel_pose_trans.append( | |
| pred_ov_rel_pose_trans[overall_valid_mask_for_trans] | |
| ) | |
| gt_rel_pose_quats.append(gt_ov_rel_pose_quats) | |
| gt_rel_pose_trans.append( | |
| gt_ov_rel_pose_trans[overall_valid_mask_for_trans] | |
| ) | |
| # Cat the N-1 relative poses along the batch dimension | |
| pred_rel_pose_quats = torch.cat(pred_rel_pose_quats, dim=0) | |
| pred_rel_pose_trans = torch.cat(pred_rel_pose_trans, dim=0) | |
| gt_rel_pose_quats = torch.cat(gt_rel_pose_quats, dim=0) | |
| gt_rel_pose_trans = torch.cat(gt_rel_pose_trans, dim=0) | |
| # Compute pose translation loss | |
| pose_trans_loss = self.criterion( | |
| pred_rel_pose_trans, gt_rel_pose_trans, factor="pose_trans" | |
| ) | |
| pose_trans_loss = pose_trans_loss * self.pose_trans_loss_weight | |
| pose_trans_losses.append(pose_trans_loss) | |
| # Compute pose rotation loss | |
| # Handle quaternion two-to-one mapping | |
| pose_quats_loss = torch.minimum( | |
| self.criterion( | |
| pred_rel_pose_quats, gt_rel_pose_quats, factor="pose_quats" | |
| ), | |
| self.criterion( | |
| pred_rel_pose_quats, -gt_rel_pose_quats, factor="pose_quats" | |
| ), | |
| ) | |
| pose_quats_loss = pose_quats_loss * self.pose_quats_loss_weight | |
| pose_quats_losses.append(pose_quats_loss) | |
| else: | |
| # Get the pose info for the current view | |
| pred_pose_trans = pred_info[i]["pose_trans"][valid_norm_factor_masks[i]] | |
| gt_pose_trans = gt_info[i]["pose_trans"][valid_norm_factor_masks[i]] | |
| pred_pose_quats = pred_info[i]["pose_quats"] | |
| gt_pose_quats = gt_info[i]["pose_quats"] | |
| # Compute pose translation loss | |
| pose_trans_loss = self.criterion( | |
| pred_pose_trans, gt_pose_trans, factor="pose_trans" | |
| ) | |
| pose_trans_loss = pose_trans_loss * self.pose_trans_loss_weight | |
| pose_trans_losses.append(pose_trans_loss) | |
| # Compute pose rotation loss | |
| # Handle quaternion two-to-one mapping | |
| pose_quats_loss = torch.minimum( | |
| self.criterion(pred_pose_quats, gt_pose_quats, factor="pose_quats"), | |
| self.criterion( | |
| pred_pose_quats, -gt_pose_quats, factor="pose_quats" | |
| ), | |
| ) | |
| pose_quats_loss = pose_quats_loss * self.pose_quats_loss_weight | |
| pose_quats_losses.append(pose_quats_loss) | |
| # Compute ray direction loss | |
| ray_directions_loss = self.criterion( | |
| pred_ray_directions, gt_ray_directions, factor="ray_directions" | |
| ) | |
| ray_directions_loss = ray_directions_loss * self.ray_directions_loss_weight | |
| ray_directions_losses.append(ray_directions_loss) | |
| # Compute depth loss | |
| depth_loss = self.criterion(pred_depth, gt_depth, factor="depth") | |
| depth_loss = depth_loss * self.depth_loss_weight | |
| depth_losses.append(depth_loss) | |
| # Compute camera frame point loss | |
| cam_pts3d_loss = self.criterion( | |
| pred_cam_pts3d, gt_cam_pts3d, factor="points" | |
| ) | |
| cam_pts3d_loss = cam_pts3d_loss * self.cam_frame_points_loss_weight | |
| cam_pts3d_losses.append(cam_pts3d_loss) | |
| if self.compute_world_frame_points_loss: | |
| # Compute point loss | |
| pts3d_loss = self.criterion(pred_pts3d, gt_pts3d, factor="points") | |
| pts3d_loss = pts3d_loss * self.world_frame_points_loss_weight | |
| pts3d_losses.append(pts3d_loss) | |
| # Handle ambiguous pixels | |
| if self.ambiguous_loss_value > 0: | |
| if not self.flatten_across_image_only: | |
| depth_losses[i] = torch.where( | |
| ambiguous_masks[i][valid_masks[i]], | |
| self.ambiguous_loss_value, | |
| depth_losses[i], | |
| ) | |
| cam_pts3d_losses[i] = torch.where( | |
| ambiguous_masks[i][valid_masks[i]], | |
| self.ambiguous_loss_value, | |
| cam_pts3d_losses[i], | |
| ) | |
| if self.compute_world_frame_points_loss: | |
| pts3d_losses[i] = torch.where( | |
| ambiguous_masks[i][valid_masks[i]], | |
| self.ambiguous_loss_value, | |
| pts3d_losses[i], | |
| ) | |
| else: | |
| depth_losses[i] = torch.where( | |
| ambiguous_masks[i].view(ambiguous_masks[i].shape[0], -1), | |
| self.ambiguous_loss_value, | |
| depth_losses[i], | |
| ) | |
| cam_pts3d_losses[i] = torch.where( | |
| ambiguous_masks[i].view(ambiguous_masks[i].shape[0], -1), | |
| self.ambiguous_loss_value, | |
| cam_pts3d_losses[i], | |
| ) | |
| if self.compute_world_frame_points_loss: | |
| pts3d_losses[i] = torch.where( | |
| ambiguous_masks[i].view(ambiguous_masks[i].shape[0], -1), | |
| self.ambiguous_loss_value, | |
| pts3d_losses[i], | |
| ) | |
| # Use helper function to generate loss terms and details | |
| if self.compute_world_frame_points_loss: | |
| losses_dict = { | |
| "pts3d": { | |
| "values": pts3d_losses, | |
| "use_mask": True, | |
| "is_multi_view": True, | |
| }, | |
| } | |
| else: | |
| losses_dict = {} | |
| losses_dict.update( | |
| { | |
| "cam_pts3d": { | |
| "values": cam_pts3d_losses, | |
| "use_mask": True, | |
| "is_multi_view": True, | |
| }, | |
| self.depth_type_for_loss: { | |
| "values": depth_losses, | |
| "use_mask": True, | |
| "is_multi_view": True, | |
| }, | |
| "ray_directions": { | |
| "values": ray_directions_losses, | |
| "use_mask": False, | |
| "is_multi_view": True, | |
| }, | |
| "pose_quats": { | |
| "values": pose_quats_losses, | |
| "use_mask": False, | |
| "is_multi_view": True, | |
| }, | |
| "pose_trans": { | |
| "values": pose_trans_losses, | |
| "use_mask": False, | |
| "is_multi_view": True, | |
| }, | |
| "normal": { | |
| "values": normal_losses, | |
| "use_mask": False, | |
| "is_multi_view": True, | |
| }, | |
| "gradient_matching": { | |
| "values": gradient_matching_losses, | |
| "use_mask": False, | |
| "is_multi_view": True, | |
| }, | |
| } | |
| ) | |
| loss_terms, details = get_loss_terms_and_details( | |
| losses_dict, | |
| valid_masks, | |
| type(self).__name__, | |
| n_views, | |
| self.flatten_across_image_only, | |
| ) | |
| losses = Sum(*loss_terms) | |
| return losses, (details | {}) | |
| class FactoredGeometryScaleRegr3D(Criterion, MultiLoss): | |
| """ | |
| Regression Loss for Factored Geometry & Scale. | |
| """ | |
| def __init__( | |
| self, | |
| criterion, | |
| norm_predictions=True, | |
| norm_mode="avg_dis", | |
| ambiguous_loss_value=0, | |
| loss_in_log=True, | |
| flatten_across_image_only=False, | |
| depth_type_for_loss="depth_along_ray", | |
| cam_frame_points_loss_weight=1, | |
| depth_loss_weight=1, | |
| ray_directions_loss_weight=1, | |
| pose_quats_loss_weight=1, | |
| pose_trans_loss_weight=1, | |
| scale_loss_weight=1, | |
| compute_pairwise_relative_pose_loss=False, | |
| convert_predictions_to_view0_frame=False, | |
| compute_world_frame_points_loss=True, | |
| world_frame_points_loss_weight=1, | |
| ): | |
| """ | |
| Initialize the loss criterion for Factored Geometry (Ray Directions, Depth, Pose), Scale | |
| and the Collective Geometry i.e. Local Frame Pointmaps & optionally World Frame Pointmaps. | |
| If world-frame pointmap loss is computed, the pixel-level losses are computed in the following order: | |
| (1) world points, (2) cam points, (3) depth, (4) ray directions, (5) pose quats, (6) pose trans, (7) scale. | |
| Else, the pixel-level losses are returned in the following order: | |
| (1) cam points, (2) depth, (3) ray directions, (4) pose quats, (5) pose trans, (6) scale. | |
| The predicted scene representation is always normalized w.r.t. the frame of view0. | |
| Loss is applied between the predicted metric scale and the ground truth metric scale. | |
| Args: | |
| criterion (BaseCriterion): The base criterion to use for computing the loss. | |
| norm_predictions (bool): If True, normalize the predictions before computing the loss. | |
| norm_mode (str): Normalization mode for the gt and predicted (optional) scene representation. Default: "avg_dis". | |
| ambiguous_loss_value (float): Value to use for ambiguous pixels in the loss. | |
| If 0, ambiguous pixels are ignored. Default: 0. | |
| loss_in_log (bool): If True, apply logarithmic transformation to input before | |
| computing the loss for depth, pointmaps and scale. Default: True. | |
| flatten_across_image_only (bool): If True, flatten H x W dimensions only when computing | |
| the loss. If False, flatten across batch and spatial dimensions. Default: False. | |
| depth_type_for_loss (str): Type of depth to use for loss computation. Default: "depth_along_ray". | |
| Options: "depth_along_ray", "depth_z" | |
| cam_frame_points_loss_weight (float): Weight to use for the camera frame pointmap loss. Default: 1. | |
| depth_loss_weight (float): Weight to use for the depth loss. Default: 1. | |
| ray_directions_loss_weight (float): Weight to use for the ray directions loss. Default: 1. | |
| pose_quats_loss_weight (float): Weight to use for the pose quats loss. Default: 1. | |
| pose_trans_loss_weight (float): Weight to use for the pose trans loss. Default: 1. | |
| scale_loss_weight (float): Weight to use for the scale loss. Default: 1. | |
| compute_pairwise_relative_pose_loss (bool): If True, the pose loss is computed on the | |
| exhaustive pairwise relative poses. Default: False. | |
| convert_predictions_to_view0_frame (bool): If True, convert predictions to view0 frame. | |
| Use this if the predictions are not already in the view0 frame. Default: False. | |
| compute_world_frame_points_loss (bool): If True, compute the world frame pointmap loss. Default: True. | |
| world_frame_points_loss_weight (float): Weight to use for the world frame pointmap loss. Default: 1. | |
| """ | |
| super().__init__(criterion) | |
| self.norm_predictions = norm_predictions | |
| self.norm_mode = norm_mode | |
| self.ambiguous_loss_value = ambiguous_loss_value | |
| self.loss_in_log = loss_in_log | |
| self.flatten_across_image_only = flatten_across_image_only | |
| self.depth_type_for_loss = depth_type_for_loss | |
| assert self.depth_type_for_loss in ["depth_along_ray", "depth_z"], ( | |
| "depth_type_for_loss must be one of ['depth_along_ray', 'depth_z']" | |
| ) | |
| self.cam_frame_points_loss_weight = cam_frame_points_loss_weight | |
| self.depth_loss_weight = depth_loss_weight | |
| self.ray_directions_loss_weight = ray_directions_loss_weight | |
| self.pose_quats_loss_weight = pose_quats_loss_weight | |
| self.pose_trans_loss_weight = pose_trans_loss_weight | |
| self.scale_loss_weight = scale_loss_weight | |
| self.compute_pairwise_relative_pose_loss = compute_pairwise_relative_pose_loss | |
| self.convert_predictions_to_view0_frame = convert_predictions_to_view0_frame | |
| self.compute_world_frame_points_loss = compute_world_frame_points_loss | |
| self.world_frame_points_loss_weight = world_frame_points_loss_weight | |
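| # Hedged usage sketch (not part of the original source): constructing and calling this | |
| # loss, assuming `base_criterion` is a BaseCriterion-style instance from this codebase | |
| # that accepts (pred, gt, factor=...) and supports reduction="none". | |
| # | |
| #   regr_loss = FactoredGeometryScaleRegr3D( | |
| #       criterion=base_criterion, | |
| #       norm_predictions=True, | |
| #       depth_type_for_loss="depth_along_ray", | |
| #       compute_world_frame_points_loss=True, | |
| #   ) | |
| #   losses, details = regr_loss.compute_loss(batch, preds) | |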
| def get_all_info(self, batch, preds, dist_clip=None): | |
| """ | |
| Function to get all the information needed to compute the loss. | |
| Returns all quantities normalized w.r.t. camera of view0. | |
| """ | |
| n_views = len(batch) | |
| # Everything is normalized w.r.t. camera of view0 | |
| # Initialize lists to store data for all views | |
| # Ground truth quantities | |
| in_camera0 = closed_form_pose_inverse(batch[0]["camera_pose"]) | |
| no_norm_gt_pts = [] | |
| no_norm_gt_pts_cam = [] | |
| no_norm_gt_depth = [] | |
| no_norm_gt_pose_trans = [] | |
| valid_masks = [] | |
| gt_ray_directions = [] | |
| gt_pose_quats = [] | |
| # Predicted quantities | |
| if self.convert_predictions_to_view0_frame: | |
| # Get the camera transform to convert quantities to view0 frame | |
| pred_camera0 = torch.eye(4, device=preds[0]["cam_quats"].device).unsqueeze( | |
| 0 | |
| ) | |
| batch_size = preds[0]["cam_quats"].shape[0] | |
| pred_camera0 = pred_camera0.repeat(batch_size, 1, 1) | |
| pred_camera0_rot = quaternion_to_rotation_matrix( | |
| preds[0]["cam_quats"].clone() | |
| ) | |
| pred_camera0[..., :3, :3] = pred_camera0_rot | |
| pred_camera0[..., :3, 3] = preds[0]["cam_trans"].clone() | |
| pred_in_camera0 = closed_form_pose_inverse(pred_camera0) | |
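| # pred_in_camera0 is the inverse of the predicted view0 pose; applying it below via | |
| # geotrf maps the predicted world-frame points into the predicted view0 camera frame. | |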
| no_norm_pr_pts = [] | |
| no_norm_pr_pts_cam = [] | |
| no_norm_pr_depth = [] | |
| no_norm_pr_pose_trans = [] | |
| pr_ray_directions = [] | |
| pr_pose_quats = [] | |
| metric_pr_pts_to_compute_scale = [] | |
| # Get ground truth & prediction info for all views | |
| for i in range(n_views): | |
| # Get the ground truth | |
| no_norm_gt_pts.append(geotrf(in_camera0, batch[i]["pts3d"])) | |
| valid_masks.append(batch[i]["valid_mask"].clone()) | |
| no_norm_gt_pts_cam.append(batch[i]["pts3d_cam"]) | |
| gt_ray_directions.append(batch[i]["ray_directions_cam"]) | |
| if self.depth_type_for_loss == "depth_along_ray": | |
| no_norm_gt_depth.append(batch[i]["depth_along_ray"]) | |
| elif self.depth_type_for_loss == "depth_z": | |
| no_norm_gt_depth.append(batch[i]["pts3d_cam"][..., 2:]) | |
| if i == 0: | |
| # For view0, initialize identity pose | |
| gt_pose_quats.append( | |
| torch.tensor( | |
| [0, 0, 0, 1], | |
| dtype=gt_ray_directions[0].dtype, | |
| device=gt_ray_directions[0].device, | |
| ) | |
| .unsqueeze(0) | |
| .repeat(gt_ray_directions[0].shape[0], 1) | |
| ) | |
| no_norm_gt_pose_trans.append( | |
| torch.tensor( | |
| [0, 0, 0], | |
| dtype=gt_ray_directions[0].dtype, | |
| device=gt_ray_directions[0].device, | |
| ) | |
| .unsqueeze(0) | |
| .repeat(gt_ray_directions[0].shape[0], 1) | |
| ) | |
| else: | |
| # For other views, transform pose to view0's frame | |
| gt_pose_quats_world = batch[i]["camera_pose_quats"] | |
| no_norm_gt_pose_trans_world = batch[i]["camera_pose_trans"] | |
| gt_pose_quats_in_view0, no_norm_gt_pose_trans_in_view0 = ( | |
| transform_pose_using_quats_and_trans_2_to_1( | |
| batch[0]["camera_pose_quats"], | |
| batch[0]["camera_pose_trans"], | |
| gt_pose_quats_world, | |
| no_norm_gt_pose_trans_world, | |
| ) | |
| ) | |
| gt_pose_quats.append(gt_pose_quats_in_view0) | |
| no_norm_gt_pose_trans.append(no_norm_gt_pose_trans_in_view0) | |
| # Get the global predictions in view0's frame | |
| if self.convert_predictions_to_view0_frame: | |
| # Convert predictions to view0 frame | |
| pr_pts3d_in_view0 = geotrf(pred_in_camera0, preds[i]["pts3d"]) | |
| pr_pose_quats_in_view0, pr_pose_trans_in_view0 = ( | |
| transform_pose_using_quats_and_trans_2_to_1( | |
| preds[0]["cam_quats"], | |
| preds[0]["cam_trans"], | |
| preds[i]["cam_quats"], | |
| preds[i]["cam_trans"], | |
| ) | |
| ) | |
| else: | |
| # Predictions are already in view0 frame | |
| pr_pts3d_in_view0 = preds[i]["pts3d"] | |
| pr_pose_trans_in_view0 = preds[i]["cam_trans"] | |
| pr_pose_quats_in_view0 = preds[i]["cam_quats"] | |
| # Get predictions for normalized loss | |
| if self.depth_type_for_loss == "depth_along_ray": | |
| curr_view_no_norm_depth = preds[i]["depth_along_ray"] | |
| elif self.depth_type_for_loss == "depth_z": | |
| curr_view_no_norm_depth = preds[i]["pts3d_cam"][..., 2:] | |
| if "metric_scaling_factor" in preds[i].keys(): | |
| # Divide by the predicted metric scaling factor to get the raw predicted points, depth_along_ray, and pose_trans | |
| # This detaches the predicted metric scaling factor from the geometry based loss | |
| curr_view_no_norm_pr_pts = pr_pts3d_in_view0 / preds[i][ | |
| "metric_scaling_factor" | |
| ].unsqueeze(-1).unsqueeze(-1) | |
| curr_view_no_norm_pr_pts_cam = preds[i]["pts3d_cam"] / preds[i][ | |
| "metric_scaling_factor" | |
| ].unsqueeze(-1).unsqueeze(-1) | |
| curr_view_no_norm_depth = curr_view_no_norm_depth / preds[i][ | |
| "metric_scaling_factor" | |
| ].unsqueeze(-1).unsqueeze(-1) | |
| curr_view_no_norm_pr_pose_trans = ( | |
| pr_pose_trans_in_view0 / preds[i]["metric_scaling_factor"] | |
| ) | |
| else: | |
| curr_view_no_norm_pr_pts = pr_pts3d_in_view0 | |
| curr_view_no_norm_pr_pts_cam = preds[i]["pts3d_cam"] | |
| curr_view_no_norm_depth = curr_view_no_norm_depth | |
| curr_view_no_norm_pr_pose_trans = pr_pose_trans_in_view0 | |
| no_norm_pr_pts.append(curr_view_no_norm_pr_pts) | |
| no_norm_pr_pts_cam.append(curr_view_no_norm_pr_pts_cam) | |
| no_norm_pr_depth.append(curr_view_no_norm_depth) | |
| no_norm_pr_pose_trans.append(curr_view_no_norm_pr_pose_trans) | |
| pr_ray_directions.append(preds[i]["ray_directions"]) | |
| pr_pose_quats.append(pr_pose_quats_in_view0) | |
| # Get the predicted metric scale points | |
| if "metric_scaling_factor" in preds[i].keys(): | |
| # Detach the raw predicted points so that the scale loss is only applied to the scaling factor | |
| curr_view_metric_pr_pts_to_compute_scale = ( | |
| curr_view_no_norm_pr_pts.detach() | |
| * preds[i]["metric_scaling_factor"].unsqueeze(-1).unsqueeze(-1) | |
| ) | |
| else: | |
| curr_view_metric_pr_pts_to_compute_scale = ( | |
| curr_view_no_norm_pr_pts.clone() | |
| ) | |
| metric_pr_pts_to_compute_scale.append( | |
| curr_view_metric_pr_pts_to_compute_scale | |
| ) | |
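| # These metric-scale points are only used further below to estimate the predicted scene | |
| # scale factor for the scale loss; the geometry losses operate on the (optionally | |
| # normalized) un-scaled quantities collected above. | |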
| if dist_clip is not None: | |
| # Mark points that are too far away as invalid | |
| for i in range(n_views): | |
| dis = no_norm_gt_pts[i].norm(dim=-1) | |
| valid_masks[i] = valid_masks[i] & (dis <= dist_clip) | |
| # Initialize normalized tensors | |
| gt_pts = [torch.zeros_like(pts) for pts in no_norm_gt_pts] | |
| gt_pts_cam = [torch.zeros_like(pts_cam) for pts_cam in no_norm_gt_pts_cam] | |
| gt_depth = [torch.zeros_like(depth) for depth in no_norm_gt_depth] | |
| gt_pose_trans = [torch.zeros_like(trans) for trans in no_norm_gt_pose_trans] | |
| pr_pts = [torch.zeros_like(pts) for pts in no_norm_pr_pts] | |
| pr_pts_cam = [torch.zeros_like(pts_cam) for pts_cam in no_norm_pr_pts_cam] | |
| pr_depth = [torch.zeros_like(depth) for depth in no_norm_pr_depth] | |
| pr_pose_trans = [torch.zeros_like(trans) for trans in no_norm_pr_pose_trans] | |
| # Normalize the predicted points if specified | |
| if self.norm_predictions: | |
| pr_normalization_output = normalize_multiple_pointclouds( | |
| no_norm_pr_pts, | |
| valid_masks, | |
| self.norm_mode, | |
| ret_factor=True, | |
| ) | |
| pr_pts_norm = pr_normalization_output[:-1] | |
| pr_norm_factor = pr_normalization_output[-1] | |
| # Normalize the ground truth points | |
| gt_normalization_output = normalize_multiple_pointclouds( | |
| no_norm_gt_pts, valid_masks, self.norm_mode, ret_factor=True | |
| ) | |
| gt_pts_norm = gt_normalization_output[:-1] | |
| gt_norm_factor = gt_normalization_output[-1] | |
| for i in range(n_views): | |
| if self.norm_predictions: | |
| # Assign the normalized predictions | |
| pr_pts[i] = pr_pts_norm[i] | |
| pr_pts_cam[i] = no_norm_pr_pts_cam[i] / pr_norm_factor | |
| pr_depth[i] = no_norm_pr_depth[i] / pr_norm_factor | |
| pr_pose_trans[i] = no_norm_pr_pose_trans[i] / pr_norm_factor[:, :, 0, 0] | |
| else: | |
| pr_pts[i] = no_norm_pr_pts[i] | |
| pr_pts_cam[i] = no_norm_pr_pts_cam[i] | |
| pr_depth[i] = no_norm_pr_depth[i] | |
| pr_pose_trans[i] = no_norm_pr_pose_trans[i] | |
| # Assign the normalized ground truth quantities | |
| gt_pts[i] = gt_pts_norm[i] | |
| gt_pts_cam[i] = no_norm_gt_pts_cam[i] / gt_norm_factor | |
| gt_depth[i] = no_norm_gt_depth[i] / gt_norm_factor | |
| gt_pose_trans[i] = no_norm_gt_pose_trans[i] / gt_norm_factor[:, :, 0, 0] | |
| # Get the mask indicating ground truth metric scale quantities | |
| metric_scale_mask = batch[0]["is_metric_scale"] | |
| valid_gt_norm_factor_mask = ( | |
| gt_norm_factor[:, 0, 0, 0] > 1e-8 | |
| ) # Mask out cases where depth for all views is invalid | |
| valid_metric_scale_mask = metric_scale_mask & valid_gt_norm_factor_mask | |
| if valid_metric_scale_mask.any(): | |
| # Compute the scale norm factor using the predicted metric scale points | |
| metric_pr_normalization_output = normalize_multiple_pointclouds( | |
| metric_pr_pts_to_compute_scale, | |
| valid_masks, | |
| self.norm_mode, | |
| ret_factor=True, | |
| ) | |
| pr_metric_norm_factor = metric_pr_normalization_output[-1] | |
| # Get the valid ground truth and predicted scale norm factors for the metric ground truth quantities | |
| gt_metric_norm_factor = gt_norm_factor[valid_metric_scale_mask] | |
| pr_metric_norm_factor = pr_metric_norm_factor[valid_metric_scale_mask] | |
| else: | |
| gt_metric_norm_factor = None | |
| pr_metric_norm_factor = None | |
| # Get ambiguous masks | |
| ambiguous_masks = [] | |
| for i in range(n_views): | |
| ambiguous_masks.append( | |
| (~batch[i]["non_ambiguous_mask"]) & (~valid_masks[i]) | |
| ) | |
| # Pack into info dicts | |
| gt_info = [] | |
| pred_info = [] | |
| for i in range(n_views): | |
| gt_info.append( | |
| { | |
| "ray_directions": gt_ray_directions[i], | |
| self.depth_type_for_loss: gt_depth[i], | |
| "pose_trans": gt_pose_trans[i], | |
| "pose_quats": gt_pose_quats[i], | |
| "pts3d": gt_pts[i], | |
| "pts3d_cam": gt_pts_cam[i], | |
| } | |
| ) | |
| pred_info.append( | |
| { | |
| "ray_directions": pr_ray_directions[i], | |
| self.depth_type_for_loss: pr_depth[i], | |
| "pose_trans": pr_pose_trans[i], | |
| "pose_quats": pr_pose_quats[i], | |
| "pts3d": pr_pts[i], | |
| "pts3d_cam": pr_pts_cam[i], | |
| } | |
| ) | |
| return ( | |
| gt_info, | |
| pred_info, | |
| valid_masks, | |
| ambiguous_masks, | |
| gt_metric_norm_factor, | |
| pr_metric_norm_factor, | |
| ) | |
| def compute_loss(self, batch, preds, **kw): | |
| ( | |
| gt_info, | |
| pred_info, | |
| valid_masks, | |
| ambiguous_masks, | |
| gt_metric_norm_factor, | |
| pr_metric_norm_factor, | |
| ) = self.get_all_info(batch, preds, **kw) | |
| n_views = len(batch) | |
| # Mask out samples in the batch where the gt depth validity mask is entirely zero | |
| valid_norm_factor_masks = [ | |
| mask.sum(dim=(1, 2)) > 0 for mask in valid_masks | |
| ] # List of (B,) | |
| if self.ambiguous_loss_value > 0: | |
| assert self.criterion.reduction == "none", ( | |
| "ambiguous_loss_value should be 0 if no conf loss" | |
| ) | |
| # Add the ambiguous pixel as "valid" pixels... | |
| valid_masks = [ | |
| mask | ambig_mask | |
| for mask, ambig_mask in zip(valid_masks, ambiguous_masks) | |
| ] | |
| pose_trans_losses = [] | |
| pose_quats_losses = [] | |
| ray_directions_losses = [] | |
| depth_losses = [] | |
| cam_pts3d_losses = [] | |
| if self.compute_world_frame_points_loss: | |
| pts3d_losses = [] | |
| for i in range(n_views): | |
| # Get the predicted dense quantities | |
| if not self.flatten_across_image_only: | |
| # Flatten the points across the entire batch with the masks | |
| pred_ray_directions = pred_info[i]["ray_directions"] | |
| gt_ray_directions = gt_info[i]["ray_directions"] | |
| pred_depth = pred_info[i][self.depth_type_for_loss][valid_masks[i]] | |
| gt_depth = gt_info[i][self.depth_type_for_loss][valid_masks[i]] | |
| pred_cam_pts3d = pred_info[i]["pts3d_cam"][valid_masks[i]] | |
| gt_cam_pts3d = gt_info[i]["pts3d_cam"][valid_masks[i]] | |
| if self.compute_world_frame_points_loss: | |
| pred_pts3d = pred_info[i]["pts3d"][valid_masks[i]] | |
| gt_pts3d = gt_info[i]["pts3d"][valid_masks[i]] | |
| else: | |
| # Flatten the H x W dimensions to H*W | |
| batch_size, _, _, direction_dim = gt_info[i]["ray_directions"].shape | |
| gt_ray_directions = gt_info[i]["ray_directions"].view( | |
| batch_size, -1, direction_dim | |
| ) | |
| pred_ray_directions = pred_info[i]["ray_directions"].view( | |
| batch_size, -1, direction_dim | |
| ) | |
| depth_dim = gt_info[i][self.depth_type_for_loss].shape[-1] | |
| gt_depth = gt_info[i][self.depth_type_for_loss].view( | |
| batch_size, -1, depth_dim | |
| ) | |
| pred_depth = pred_info[i][self.depth_type_for_loss].view( | |
| batch_size, -1, depth_dim | |
| ) | |
| cam_pts_dim = gt_info[i]["pts3d_cam"].shape[-1] | |
| gt_cam_pts3d = gt_info[i]["pts3d_cam"].view(batch_size, -1, cam_pts_dim) | |
| pred_cam_pts3d = pred_info[i]["pts3d_cam"].view( | |
| batch_size, -1, cam_pts_dim | |
| ) | |
| if self.compute_world_frame_points_loss: | |
| pts_dim = gt_info[i]["pts3d"].shape[-1] | |
| gt_pts3d = gt_info[i]["pts3d"].view(batch_size, -1, pts_dim) | |
| pred_pts3d = pred_info[i]["pts3d"].view(batch_size, -1, pts_dim) | |
| valid_masks[i] = valid_masks[i].view(batch_size, -1) | |
| # Apply loss in log space for depth and pointmaps if specified | |
| if self.loss_in_log: | |
| gt_depth = apply_log_to_norm(gt_depth) | |
| pred_depth = apply_log_to_norm(pred_depth) | |
| gt_cam_pts3d = apply_log_to_norm(gt_cam_pts3d) | |
| pred_cam_pts3d = apply_log_to_norm(pred_cam_pts3d) | |
| if self.compute_world_frame_points_loss: | |
| gt_pts3d = apply_log_to_norm(gt_pts3d) | |
| pred_pts3d = apply_log_to_norm(pred_pts3d) | |
| if self.compute_pairwise_relative_pose_loss: | |
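| # In this branch the pose loss is computed over all ordered relative poses that use the | |
| # current view as the reference frame; the N-1 relative poses are concatenated along | |
| # the batch dimension before the criterion is applied. | |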
| # Get the inverse of current view predicted pose | |
| pred_inv_curr_view_pose_quats = quaternion_inverse( | |
| pred_info[i]["pose_quats"] | |
| ) | |
| pred_inv_curr_view_pose_rot_mat = quaternion_to_rotation_matrix( | |
| pred_inv_curr_view_pose_quats | |
| ) | |
| pred_inv_curr_view_pose_trans = -1 * ein.einsum( | |
| pred_inv_curr_view_pose_rot_mat, | |
| pred_info[i]["pose_trans"], | |
| "b i j, b j -> b i", | |
| ) | |
| # Get the inverse of the current view GT pose | |
| gt_inv_curr_view_pose_quats = quaternion_inverse( | |
| gt_info[i]["pose_quats"] | |
| ) | |
| gt_inv_curr_view_pose_rot_mat = quaternion_to_rotation_matrix( | |
| gt_inv_curr_view_pose_quats | |
| ) | |
| gt_inv_curr_view_pose_trans = -1 * ein.einsum( | |
| gt_inv_curr_view_pose_rot_mat, | |
| gt_info[i]["pose_trans"], | |
| "b i j, b j -> b i", | |
| ) | |
| # Get the other N-1 relative poses using the current pose as reference frame | |
| pred_rel_pose_quats = [] | |
| pred_rel_pose_trans = [] | |
| gt_rel_pose_quats = [] | |
| gt_rel_pose_trans = [] | |
| for ov_idx in range(n_views): | |
| if ov_idx == i: | |
| continue | |
| # Get the relative predicted pose | |
| pred_ov_rel_pose_quats = quaternion_multiply( | |
| pred_inv_curr_view_pose_quats, pred_info[ov_idx]["pose_quats"] | |
| ) | |
| pred_ov_rel_pose_trans = ( | |
| ein.einsum( | |
| pred_inv_curr_view_pose_rot_mat, | |
| pred_info[ov_idx]["pose_trans"], | |
| "b i j, b j -> b i", | |
| ) | |
| + pred_inv_curr_view_pose_trans | |
| ) | |
| # Get the relative GT pose | |
| gt_ov_rel_pose_quats = quaternion_multiply( | |
| gt_inv_curr_view_pose_quats, gt_info[ov_idx]["pose_quats"] | |
| ) | |
| gt_ov_rel_pose_trans = ( | |
| ein.einsum( | |
| gt_inv_curr_view_pose_rot_mat, | |
| gt_info[ov_idx]["pose_trans"], | |
| "b i j, b j -> b i", | |
| ) | |
| + gt_inv_curr_view_pose_trans | |
| ) | |
| # Get the valid translations using valid_norm_factor_masks for current view and other view | |
| overall_valid_mask_for_trans = ( | |
| valid_norm_factor_masks[i] & valid_norm_factor_masks[ov_idx] | |
| ) | |
| # Append the relative poses | |
| pred_rel_pose_quats.append(pred_ov_rel_pose_quats) | |
| pred_rel_pose_trans.append( | |
| pred_ov_rel_pose_trans[overall_valid_mask_for_trans] | |
| ) | |
| gt_rel_pose_quats.append(gt_ov_rel_pose_quats) | |
| gt_rel_pose_trans.append( | |
| gt_ov_rel_pose_trans[overall_valid_mask_for_trans] | |
| ) | |
| # Cat the N-1 relative poses along the batch dimension | |
| pred_rel_pose_quats = torch.cat(pred_rel_pose_quats, dim=0) | |
| pred_rel_pose_trans = torch.cat(pred_rel_pose_trans, dim=0) | |
| gt_rel_pose_quats = torch.cat(gt_rel_pose_quats, dim=0) | |
| gt_rel_pose_trans = torch.cat(gt_rel_pose_trans, dim=0) | |
| # Compute pose translation loss | |
| pose_trans_loss = self.criterion( | |
| pred_rel_pose_trans, gt_rel_pose_trans, factor="pose_trans" | |
| ) | |
| pose_trans_loss = pose_trans_loss * self.pose_trans_loss_weight | |
| pose_trans_losses.append(pose_trans_loss) | |
| # Compute pose rotation loss | |
| # Handle quaternion two-to-one mapping | |
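| # (q and -q encode the same rotation, i.e. unit quaternions double-cover the rotation | |
| # group, so the loss is taken as the elementwise minimum over both sign conventions.) | |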
| pose_quats_loss = torch.minimum( | |
| self.criterion( | |
| pred_rel_pose_quats, gt_rel_pose_quats, factor="pose_quats" | |
| ), | |
| self.criterion( | |
| pred_rel_pose_quats, -gt_rel_pose_quats, factor="pose_quats" | |
| ), | |
| ) | |
| pose_quats_loss = pose_quats_loss * self.pose_quats_loss_weight | |
| pose_quats_losses.append(pose_quats_loss) | |
| else: | |
| # Get the pose info for the current view | |
| pred_pose_trans = pred_info[i]["pose_trans"][valid_norm_factor_masks[i]] | |
| gt_pose_trans = gt_info[i]["pose_trans"][valid_norm_factor_masks[i]] | |
| pred_pose_quats = pred_info[i]["pose_quats"] | |
| gt_pose_quats = gt_info[i]["pose_quats"] | |
| # Compute pose translation loss | |
| pose_trans_loss = self.criterion( | |
| pred_pose_trans, gt_pose_trans, factor="pose_trans" | |
| ) | |
| pose_trans_loss = pose_trans_loss * self.pose_trans_loss_weight | |
| pose_trans_losses.append(pose_trans_loss) | |
| # Compute pose rotation loss | |
| # Handle quaternion two-to-one mapping | |
| pose_quats_loss = torch.minimum( | |
| self.criterion(pred_pose_quats, gt_pose_quats, factor="pose_quats"), | |
| self.criterion( | |
| pred_pose_quats, -gt_pose_quats, factor="pose_quats" | |
| ), | |
| ) | |
| pose_quats_loss = pose_quats_loss * self.pose_quats_loss_weight | |
| pose_quats_losses.append(pose_quats_loss) | |
| # Compute ray direction loss | |
| ray_directions_loss = self.criterion( | |
| pred_ray_directions, gt_ray_directions, factor="ray_directions" | |
| ) | |
| ray_directions_loss = ray_directions_loss * self.ray_directions_loss_weight | |
| ray_directions_losses.append(ray_directions_loss) | |
| # Compute depth loss | |
| depth_loss = self.criterion(pred_depth, gt_depth, factor="depth") | |
| depth_loss = depth_loss * self.depth_loss_weight | |
| depth_losses.append(depth_loss) | |
| # Compute camera frame point loss | |
| cam_pts3d_loss = self.criterion( | |
| pred_cam_pts3d, gt_cam_pts3d, factor="points" | |
| ) | |
| cam_pts3d_loss = cam_pts3d_loss * self.cam_frame_points_loss_weight | |
| cam_pts3d_losses.append(cam_pts3d_loss) | |
| if self.compute_world_frame_points_loss: | |
| # Compute point loss | |
| pts3d_loss = self.criterion(pred_pts3d, gt_pts3d, factor="points") | |
| pts3d_loss = pts3d_loss * self.world_frame_points_loss_weight | |
| pts3d_losses.append(pts3d_loss) | |
| # Handle ambiguous pixels | |
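| # Ambiguous pixels (flagged by ~non_ambiguous_mask) were merged into valid_masks above, | |
| # so their regression error is replaced here by the constant ambiguous_loss_value | |
| # rather than being dropped from the loss. | |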
| if self.ambiguous_loss_value > 0: | |
| if not self.flatten_across_image_only: | |
| depth_losses[i] = torch.where( | |
| ambiguous_masks[i][valid_masks[i]], | |
| self.ambiguous_loss_value, | |
| depth_losses[i], | |
| ) | |
| cam_pts3d_losses[i] = torch.where( | |
| ambiguous_masks[i][valid_masks[i]], | |
| self.ambiguous_loss_value, | |
| cam_pts3d_losses[i], | |
| ) | |
| if self.compute_world_frame_points_loss: | |
| pts3d_losses[i] = torch.where( | |
| ambiguous_masks[i][valid_masks[i]], | |
| self.ambiguous_loss_value, | |
| pts3d_losses[i], | |
| ) | |
| else: | |
| depth_losses[i] = torch.where( | |
| ambiguous_masks[i].view(ambiguous_masks[i].shape[0], -1), | |
| self.ambiguous_loss_value, | |
| depth_losses[i], | |
| ) | |
| cam_pts3d_losses[i] = torch.where( | |
| ambiguous_masks[i].view(ambiguous_masks[i].shape[0], -1), | |
| self.ambiguous_loss_value, | |
| cam_pts3d_losses[i], | |
| ) | |
| if self.compute_world_frame_points_loss: | |
| pts3d_losses[i] = torch.where( | |
| ambiguous_masks[i].view(ambiguous_masks[i].shape[0], -1), | |
| self.ambiguous_loss_value, | |
| pts3d_losses[i], | |
| ) | |
| # Compute the scale loss | |
| if gt_metric_norm_factor is not None: | |
| if self.loss_in_log: | |
| gt_metric_norm_factor = apply_log_to_norm(gt_metric_norm_factor) | |
| pr_metric_norm_factor = apply_log_to_norm(pr_metric_norm_factor) | |
| scale_loss = ( | |
| self.criterion( | |
| pr_metric_norm_factor, gt_metric_norm_factor, factor="scale" | |
| ) | |
| * self.scale_loss_weight | |
| ) | |
| else: | |
| scale_loss = None | |
| # Use helper function to generate loss terms and details | |
| if self.compute_world_frame_points_loss: | |
| losses_dict = { | |
| "pts3d": { | |
| "values": pts3d_losses, | |
| "use_mask": True, | |
| "is_multi_view": True, | |
| }, | |
| } | |
| else: | |
| losses_dict = {} | |
| losses_dict.update( | |
| { | |
| "cam_pts3d": { | |
| "values": cam_pts3d_losses, | |
| "use_mask": True, | |
| "is_multi_view": True, | |
| }, | |
| self.depth_type_for_loss: { | |
| "values": depth_losses, | |
| "use_mask": True, | |
| "is_multi_view": True, | |
| }, | |
| "ray_directions": { | |
| "values": ray_directions_losses, | |
| "use_mask": False, | |
| "is_multi_view": True, | |
| }, | |
| "pose_quats": { | |
| "values": pose_quats_losses, | |
| "use_mask": False, | |
| "is_multi_view": True, | |
| }, | |
| "pose_trans": { | |
| "values": pose_trans_losses, | |
| "use_mask": False, | |
| "is_multi_view": True, | |
| }, | |
| "scale": { | |
| "values": scale_loss, | |
| "use_mask": False, | |
| "is_multi_view": False, | |
| }, | |
| } | |
| ) | |
| loss_terms, details = get_loss_terms_and_details( | |
| losses_dict, | |
| valid_masks, | |
| type(self).__name__, | |
| n_views, | |
| self.flatten_across_image_only, | |
| ) | |
| losses = Sum(*loss_terms) | |
| return losses, (details | {}) | |
| class FactoredGeometryScaleRegr3DPlusNormalGMLoss(FactoredGeometryScaleRegr3D): | |
| """ | |
| Regression, Normals & Gradient Matching Loss for Factored Geometry & Scale. | |
| """ | |
| def __init__( | |
| self, | |
| criterion, | |
| norm_predictions=True, | |
| norm_mode="avg_dis", | |
| ambiguous_loss_value=0, | |
| loss_in_log=True, | |
| flatten_across_image_only=False, | |
| depth_type_for_loss="depth_along_ray", | |
| cam_frame_points_loss_weight=1, | |
| depth_loss_weight=1, | |
| ray_directions_loss_weight=1, | |
| pose_quats_loss_weight=1, | |
| pose_trans_loss_weight=1, | |
| scale_loss_weight=1, | |
| compute_pairwise_relative_pose_loss=False, | |
| convert_predictions_to_view0_frame=False, | |
| compute_world_frame_points_loss=True, | |
| world_frame_points_loss_weight=1, | |
| apply_normal_and_gm_loss_to_synthetic_data_only=True, | |
| normal_loss_weight=1, | |
| gm_loss_weight=1, | |
| ): | |
| """ | |
| Initialize the loss criterion for Ray Directions, Depth, Pose, Pointmaps & Scale. | |
| Additionally computes: | |
| (1) Normal Loss over the Camera Frame Pointmaps in euclidean coordinates, | |
| (2) Gradient Matching (GM) Loss over the Depth Z in log space (MiDaS applies the GM loss in disparity space). | |
| Args: | |
| criterion (BaseCriterion): The base criterion to use for computing the loss. | |
| norm_predictions (bool): If True, normalize the predictions before computing the loss. | |
| norm_mode (str): Normalization mode for the gt and predicted (optional) scene representation. Default: "avg_dis". | |
| ambiguous_loss_value (float): Value to use for ambiguous pixels in the loss. | |
| If 0, ambiguous pixels are ignored. Default: 0. | |
| loss_in_log (bool): If True, apply logarithmic transformation to input before | |
| computing the loss for depth, pointmaps and scale. Default: True. | |
| flatten_across_image_only (bool): If True, flatten H x W dimensions only when computing | |
| the loss. If False, flatten across batch and spatial dimensions. Default: False. | |
| depth_type_for_loss (str): Type of depth to use for loss computation. Default: "depth_along_ray". | |
| Options: "depth_along_ray", "depth_z" | |
| cam_frame_points_loss_weight (float): Weight to use for the camera frame pointmap loss. Default: 1. | |
| depth_loss_weight (float): Weight to use for the depth loss. Default: 1. | |
| ray_directions_loss_weight (float): Weight to use for the ray directions loss. Default: 1. | |
| pose_quats_loss_weight (float): Weight to use for the pose quats loss. Default: 1. | |
| pose_trans_loss_weight (float): Weight to use for the pose trans loss. Default: 1. | |
| scale_loss_weight (float): Weight to use for the scale loss. Default: 1. | |
| compute_pairwise_relative_pose_loss (bool): If True, the pose loss is computed on the | |
| exhaustive pairwise relative poses. Default: False. | |
| convert_predictions_to_view0_frame (bool): If True, convert predictions to view0 frame. | |
| Use this if the predictions are not already in the view0 frame. Default: False. | |
| compute_world_frame_points_loss (bool): If True, compute the world frame pointmap loss. Default: True. | |
| world_frame_points_loss_weight (float): Weight to use for the world frame pointmap loss. Default: 1. | |
| apply_normal_and_gm_loss_to_synthetic_data_only (bool): If True, apply the normal and gm loss only to synthetic data. | |
| If False, apply the normal and gm loss to all data. Default: True. | |
| normal_loss_weight (float): Weight to use for the normal loss. Default: 1. | |
| gm_loss_weight (float): Weight to use for the gm loss. Default: 1. | |
| """ | |
| super().__init__( | |
| criterion=criterion, | |
| norm_predictions=norm_predictions, | |
| norm_mode=norm_mode, | |
| ambiguous_loss_value=ambiguous_loss_value, | |
| loss_in_log=loss_in_log, | |
| flatten_across_image_only=flatten_across_image_only, | |
| depth_type_for_loss=depth_type_for_loss, | |
| cam_frame_points_loss_weight=cam_frame_points_loss_weight, | |
| depth_loss_weight=depth_loss_weight, | |
| ray_directions_loss_weight=ray_directions_loss_weight, | |
| pose_quats_loss_weight=pose_quats_loss_weight, | |
| pose_trans_loss_weight=pose_trans_loss_weight, | |
| scale_loss_weight=scale_loss_weight, | |
| compute_pairwise_relative_pose_loss=compute_pairwise_relative_pose_loss, | |
| convert_predictions_to_view0_frame=convert_predictions_to_view0_frame, | |
| compute_world_frame_points_loss=compute_world_frame_points_loss, | |
| world_frame_points_loss_weight=world_frame_points_loss_weight, | |
| ) | |
| self.apply_normal_and_gm_loss_to_synthetic_data_only = ( | |
| apply_normal_and_gm_loss_to_synthetic_data_only | |
| ) | |
| self.normal_loss_weight = normal_loss_weight | |
| self.gm_loss_weight = gm_loss_weight | |
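| # Hedged sketch (an assumption, not the actual implementation): a single-scale variant | |
| # of the gradient matching term on log-depth, in the spirit of the MiDaS formulation. | |
| # The compute_gradient_matching_loss used below may differ (e.g. it may be multi-scale). | |
| # | |
| #   def gm_loss_single_scale(pred_log_z, gt_log_z, mask): | |
| #       # pred_log_z, gt_log_z: (B, H, W); mask: (B, H, W) boolean validity mask | |
| #       d = pred_log_z - gt_log_z | |
| #       m = mask.float() | |
| #       grad_x = (d[:, :, 1:] - d[:, :, :-1]).abs() * (m[:, :, 1:] * m[:, :, :-1]) | |
| #       grad_y = (d[:, 1:, :] - d[:, :-1, :]).abs() * (m[:, 1:, :] * m[:, :-1, :]) | |
| #       return (grad_x.sum() + grad_y.sum()) / m.sum().clamp(min=1) | |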
| def compute_loss(self, batch, preds, **kw): | |
| ( | |
| gt_info, | |
| pred_info, | |
| valid_masks, | |
| ambiguous_masks, | |
| gt_metric_norm_factor, | |
| pr_metric_norm_factor, | |
| ) = self.get_all_info(batch, preds, **kw) | |
| n_views = len(batch) | |
| # Mask out samples in the batch where the gt depth validity mask is entirely zero | |
| valid_norm_factor_masks = [ | |
| mask.sum(dim=(1, 2)) > 0 for mask in valid_masks | |
| ] # List of (B,) | |
| if self.ambiguous_loss_value > 0: | |
| assert self.criterion.reduction == "none", ( | |
| "ambiguous_loss_value should be 0 if no conf loss" | |
| ) | |
| # Add the ambiguous pixel as "valid" pixels... | |
| valid_masks = [ | |
| mask | ambig_mask | |
| for mask, ambig_mask in zip(valid_masks, ambiguous_masks) | |
| ] | |
| normal_losses = [] | |
| gradient_matching_losses = [] | |
| pose_trans_losses = [] | |
| pose_quats_losses = [] | |
| ray_directions_losses = [] | |
| depth_losses = [] | |
| cam_pts3d_losses = [] | |
| if self.compute_world_frame_points_loss: | |
| pts3d_losses = [] | |
| for i in range(n_views): | |
| # Get the camera frame points, log space depth_z & valid masks | |
| pred_local_pts3d = pred_info[i]["pts3d_cam"] | |
| pred_depth_z = pred_local_pts3d[..., 2:] | |
| pred_depth_z = apply_log_to_norm(pred_depth_z) | |
| gt_local_pts3d = gt_info[i]["pts3d_cam"] | |
| gt_depth_z = gt_local_pts3d[..., 2:] | |
| gt_depth_z = apply_log_to_norm(gt_depth_z) | |
| valid_mask_for_normal_gm_loss = valid_masks[i].clone() | |
| # Update the validity mask for normal & gm loss based on the synthetic data mask if required | |
| if self.apply_normal_and_gm_loss_to_synthetic_data_only: | |
| synthetic_mask = batch[i]["is_synthetic"] # (B, ) | |
| synthetic_mask = synthetic_mask.unsqueeze(-1).unsqueeze(-1) # (B, 1, 1) | |
| synthetic_mask = synthetic_mask.expand( | |
| -1, pred_depth_z.shape[1], pred_depth_z.shape[2] | |
| ) # (B, H, W) | |
| valid_mask_for_normal_gm_loss = ( | |
| valid_mask_for_normal_gm_loss & synthetic_mask | |
| ) | |
| # Compute the normal loss | |
| normal_loss = compute_normal_loss( | |
| pred_local_pts3d, gt_local_pts3d, valid_mask_for_normal_gm_loss.clone() | |
| ) | |
| normal_loss = normal_loss * self.normal_loss_weight | |
| normal_losses.append(normal_loss) | |
| # Compute the gradient matching loss | |
| gradient_matching_loss = compute_gradient_matching_loss( | |
| pred_depth_z, gt_depth_z, valid_mask_for_normal_gm_loss.clone() | |
| ) | |
| gradient_matching_loss = gradient_matching_loss * self.gm_loss_weight | |
| gradient_matching_losses.append(gradient_matching_loss) | |
| # Get the predicted dense quantities | |
| if not self.flatten_across_image_only: | |
| # Flatten the points across the entire batch with the masks and compute the metrics | |
| pred_ray_directions = pred_info[i]["ray_directions"] | |
| gt_ray_directions = gt_info[i]["ray_directions"] | |
| pred_depth = pred_info[i][self.depth_type_for_loss][valid_masks[i]] | |
| gt_depth = gt_info[i][self.depth_type_for_loss][valid_masks[i]] | |
| pred_cam_pts3d = pred_info[i]["pts3d_cam"][valid_masks[i]] | |
| gt_cam_pts3d = gt_info[i]["pts3d_cam"][valid_masks[i]] | |
| if self.compute_world_frame_points_loss: | |
| pred_pts3d = pred_info[i]["pts3d"][valid_masks[i]] | |
| gt_pts3d = gt_info[i]["pts3d"][valid_masks[i]] | |
| else: | |
| # Flatten the H x W dimensions to H*W and compute the metrics | |
| batch_size, _, _, direction_dim = gt_info[i]["ray_directions"].shape | |
| gt_ray_directions = gt_info[i]["ray_directions"].view( | |
| batch_size, -1, direction_dim | |
| ) | |
| pred_ray_directions = pred_info[i]["ray_directions"].view( | |
| batch_size, -1, direction_dim | |
| ) | |
| depth_dim = gt_info[i][self.depth_type_for_loss].shape[-1] | |
| gt_depth = gt_info[i][self.depth_type_for_loss].view( | |
| batch_size, -1, depth_dim | |
| ) | |
| pred_depth = pred_info[i][self.depth_type_for_loss].view( | |
| batch_size, -1, depth_dim | |
| ) | |
| cam_pts_dim = gt_info[i]["pts3d_cam"].shape[-1] | |
| gt_cam_pts3d = gt_info[i]["pts3d_cam"].view(batch_size, -1, cam_pts_dim) | |
| pred_cam_pts3d = pred_info[i]["pts3d_cam"].view( | |
| batch_size, -1, cam_pts_dim | |
| ) | |
| if self.compute_world_frame_points_loss: | |
| pts_dim = gt_info[i]["pts3d"].shape[-1] | |
| gt_pts3d = gt_info[i]["pts3d"].view(batch_size, -1, pts_dim) | |
| pred_pts3d = pred_info[i]["pts3d"].view(batch_size, -1, pts_dim) | |
| valid_masks[i] = valid_masks[i].view(batch_size, -1) | |
| # Apply loss in log space for depth and pointmaps if specified | |
| if self.loss_in_log: | |
| gt_depth = apply_log_to_norm(gt_depth) | |
| pred_depth = apply_log_to_norm(pred_depth) | |
| gt_cam_pts3d = apply_log_to_norm(gt_cam_pts3d) | |
| pred_cam_pts3d = apply_log_to_norm(pred_cam_pts3d) | |
| if self.compute_world_frame_points_loss: | |
| gt_pts3d = apply_log_to_norm(gt_pts3d) | |
| pred_pts3d = apply_log_to_norm(pred_pts3d) | |
| if self.compute_pairwise_relative_pose_loss: | |
| # Get the inverse of current view predicted pose | |
| pred_inv_curr_view_pose_quats = quaternion_inverse( | |
| pred_info[i]["pose_quats"] | |
| ) | |
| pred_inv_curr_view_pose_rot_mat = quaternion_to_rotation_matrix( | |
| pred_inv_curr_view_pose_quats | |
| ) | |
| pred_inv_curr_view_pose_trans = -1 * ein.einsum( | |
| pred_inv_curr_view_pose_rot_mat, | |
| pred_info[i]["pose_trans"], | |
| "b i j, b j -> b i", | |
| ) | |
| # Get the inverse of the current view GT pose | |
| gt_inv_curr_view_pose_quats = quaternion_inverse( | |
| gt_info[i]["pose_quats"] | |
| ) | |
| gt_inv_curr_view_pose_rot_mat = quaternion_to_rotation_matrix( | |
| gt_inv_curr_view_pose_quats | |
| ) | |
| gt_inv_curr_view_pose_trans = -1 * ein.einsum( | |
| gt_inv_curr_view_pose_rot_mat, | |
| gt_info[i]["pose_trans"], | |
| "b i j, b j -> b i", | |
| ) | |
| # Get the other N-1 relative poses using the current pose as reference frame | |
| pred_rel_pose_quats = [] | |
| pred_rel_pose_trans = [] | |
| gt_rel_pose_quats = [] | |
| gt_rel_pose_trans = [] | |
| for ov_idx in range(n_views): | |
| if ov_idx == i: | |
| continue | |
| # Get the relative predicted pose | |
| pred_ov_rel_pose_quats = quaternion_multiply( | |
| pred_inv_curr_view_pose_quats, pred_info[ov_idx]["pose_quats"] | |
| ) | |
| pred_ov_rel_pose_trans = ( | |
| ein.einsum( | |
| pred_inv_curr_view_pose_rot_mat, | |
| pred_info[ov_idx]["pose_trans"], | |
| "b i j, b j -> b i", | |
| ) | |
| + pred_inv_curr_view_pose_trans | |
| ) | |
| # Get the relative GT pose | |
| gt_ov_rel_pose_quats = quaternion_multiply( | |
| gt_inv_curr_view_pose_quats, gt_info[ov_idx]["pose_quats"] | |
| ) | |
| gt_ov_rel_pose_trans = ( | |
| ein.einsum( | |
| gt_inv_curr_view_pose_rot_mat, | |
| gt_info[ov_idx]["pose_trans"], | |
| "b i j, b j -> b i", | |
| ) | |
| + gt_inv_curr_view_pose_trans | |
| ) | |
| # Get the valid translations using valid_norm_factor_masks for current view and other view | |
| overall_valid_mask_for_trans = ( | |
| valid_norm_factor_masks[i] & valid_norm_factor_masks[ov_idx] | |
| ) | |
| # Append the relative poses | |
| pred_rel_pose_quats.append(pred_ov_rel_pose_quats) | |
| pred_rel_pose_trans.append( | |
| pred_ov_rel_pose_trans[overall_valid_mask_for_trans] | |
| ) | |
| gt_rel_pose_quats.append(gt_ov_rel_pose_quats) | |
| gt_rel_pose_trans.append( | |
| gt_ov_rel_pose_trans[overall_valid_mask_for_trans] | |
| ) | |
| # Cat the N-1 relative poses along the batch dimension | |
| pred_rel_pose_quats = torch.cat(pred_rel_pose_quats, dim=0) | |
| pred_rel_pose_trans = torch.cat(pred_rel_pose_trans, dim=0) | |
| gt_rel_pose_quats = torch.cat(gt_rel_pose_quats, dim=0) | |
| gt_rel_pose_trans = torch.cat(gt_rel_pose_trans, dim=0) | |
| # Compute pose translation loss | |
| pose_trans_loss = self.criterion( | |
| pred_rel_pose_trans, gt_rel_pose_trans, factor="pose_trans" | |
| ) | |
| pose_trans_loss = pose_trans_loss * self.pose_trans_loss_weight | |
| pose_trans_losses.append(pose_trans_loss) | |
| # Compute pose rotation loss | |
| # Handle quaternion two-to-one mapping | |
| pose_quats_loss = torch.minimum( | |
| self.criterion( | |
| pred_rel_pose_quats, gt_rel_pose_quats, factor="pose_quats" | |
| ), | |
| self.criterion( | |
| pred_rel_pose_quats, -gt_rel_pose_quats, factor="pose_quats" | |
| ), | |
| ) | |
| pose_quats_loss = pose_quats_loss * self.pose_quats_loss_weight | |
| pose_quats_losses.append(pose_quats_loss) | |
| else: | |
| # Get the pose info for the current view | |
| pred_pose_trans = pred_info[i]["pose_trans"][valid_norm_factor_masks[i]] | |
| gt_pose_trans = gt_info[i]["pose_trans"][valid_norm_factor_masks[i]] | |
| pred_pose_quats = pred_info[i]["pose_quats"] | |
| gt_pose_quats = gt_info[i]["pose_quats"] | |
| # Compute pose translation loss | |
| pose_trans_loss = self.criterion( | |
| pred_pose_trans, gt_pose_trans, factor="pose_trans" | |
| ) | |
| pose_trans_loss = pose_trans_loss * self.pose_trans_loss_weight | |
| pose_trans_losses.append(pose_trans_loss) | |
| # Compute pose rotation loss | |
| # Handle quaternion two-to-one mapping | |
| pose_quats_loss = torch.minimum( | |
| self.criterion(pred_pose_quats, gt_pose_quats, factor="pose_quats"), | |
| self.criterion( | |
| pred_pose_quats, -gt_pose_quats, factor="pose_quats" | |
| ), | |
| ) | |
| pose_quats_loss = pose_quats_loss * self.pose_quats_loss_weight | |
| pose_quats_losses.append(pose_quats_loss) | |
| # Compute ray direction loss | |
| ray_directions_loss = self.criterion( | |
| pred_ray_directions, gt_ray_directions, factor="ray_directions" | |
| ) | |
| ray_directions_loss = ray_directions_loss * self.ray_directions_loss_weight | |
| ray_directions_losses.append(ray_directions_loss) | |
| # Compute depth loss | |
| depth_loss = self.criterion(pred_depth, gt_depth, factor="depth") | |
| depth_loss = depth_loss * self.depth_loss_weight | |
| depth_losses.append(depth_loss) | |
| # Compute camera frame point loss | |
| cam_pts3d_loss = self.criterion( | |
| pred_cam_pts3d, gt_cam_pts3d, factor="points" | |
| ) | |
| cam_pts3d_loss = cam_pts3d_loss * self.cam_frame_points_loss_weight | |
| cam_pts3d_losses.append(cam_pts3d_loss) | |
| if self.compute_world_frame_points_loss: | |
| # Compute point loss | |
| pts3d_loss = self.criterion(pred_pts3d, gt_pts3d, factor="points") | |
| pts3d_loss = pts3d_loss * self.world_frame_points_loss_weight | |
| pts3d_losses.append(pts3d_loss) | |
| # Handle ambiguous pixels | |
| if self.ambiguous_loss_value > 0: | |
| if not self.flatten_across_image_only: | |
| depth_losses[i] = torch.where( | |
| ambiguous_masks[i][valid_masks[i]], | |
| self.ambiguous_loss_value, | |
| depth_losses[i], | |
| ) | |
| cam_pts3d_losses[i] = torch.where( | |
| ambiguous_masks[i][valid_masks[i]], | |
| self.ambiguous_loss_value, | |
| cam_pts3d_losses[i], | |
| ) | |
| if self.compute_world_frame_points_loss: | |
| pts3d_losses[i] = torch.where( | |
| ambiguous_masks[i][valid_masks[i]], | |
| self.ambiguous_loss_value, | |
| pts3d_losses[i], | |
| ) | |
| else: | |
| depth_losses[i] = torch.where( | |
| ambiguous_masks[i].view(ambiguous_masks[i].shape[0], -1), | |
| self.ambiguous_loss_value, | |
| depth_losses[i], | |
| ) | |
| cam_pts3d_losses[i] = torch.where( | |
| ambiguous_masks[i].view(ambiguous_masks[i].shape[0], -1), | |
| self.ambiguous_loss_value, | |
| cam_pts3d_losses[i], | |
| ) | |
| if self.compute_world_frame_points_loss: | |
| pts3d_losses[i] = torch.where( | |
| ambiguous_masks[i].view(ambiguous_masks[i].shape[0], -1), | |
| self.ambiguous_loss_value, | |
| pts3d_losses[i], | |
| ) | |
| # Compute the scale loss | |
| if gt_metric_norm_factor is not None: | |
| if self.loss_in_log: | |
| gt_metric_norm_factor = apply_log_to_norm(gt_metric_norm_factor) | |
| pr_metric_norm_factor = apply_log_to_norm(pr_metric_norm_factor) | |
| scale_loss = ( | |
| self.criterion( | |
| pr_metric_norm_factor, gt_metric_norm_factor, factor="scale" | |
| ) | |
| * self.scale_loss_weight | |
| ) | |
| else: | |
| scale_loss = None | |
| # Use helper function to generate loss terms and details | |
| if self.compute_world_frame_points_loss: | |
| losses_dict = { | |
| "pts3d": { | |
| "values": pts3d_losses, | |
| "use_mask": True, | |
| "is_multi_view": True, | |
| }, | |
| } | |
| else: | |
| losses_dict = {} | |
| losses_dict.update( | |
| { | |
| "cam_pts3d": { | |
| "values": cam_pts3d_losses, | |
| "use_mask": True, | |
| "is_multi_view": True, | |
| }, | |
| self.depth_type_for_loss: { | |
| "values": depth_losses, | |
| "use_mask": True, | |
| "is_multi_view": True, | |
| }, | |
| "ray_directions": { | |
| "values": ray_directions_losses, | |
| "use_mask": False, | |
| "is_multi_view": True, | |
| }, | |
| "pose_quats": { | |
| "values": pose_quats_losses, | |
| "use_mask": False, | |
| "is_multi_view": True, | |
| }, | |
| "pose_trans": { | |
| "values": pose_trans_losses, | |
| "use_mask": False, | |
| "is_multi_view": True, | |
| }, | |
| "scale": { | |
| "values": scale_loss, | |
| "use_mask": False, | |
| "is_multi_view": False, | |
| }, | |
| "normal": { | |
| "values": normal_losses, | |
| "use_mask": False, | |
| "is_multi_view": True, | |
| }, | |
| "gradient_matching": { | |
| "values": gradient_matching_losses, | |
| "use_mask": False, | |
| "is_multi_view": True, | |
| }, | |
| } | |
| ) | |
| loss_terms, details = get_loss_terms_and_details( | |
| losses_dict, | |
| valid_masks, | |
| type(self).__name__, | |
| n_views, | |
| self.flatten_across_image_only, | |
| ) | |
| losses = Sum(*loss_terms) | |
| return losses, (details | {}) | |
| class DisentangledFactoredGeometryScaleRegr3D(Criterion, MultiLoss): | |
| """ | |
| Disentangled Regression Loss for Factored Geometry & Scale. | |
| """ | |
| def __init__( | |
| self, | |
| criterion, | |
| norm_predictions=True, | |
| norm_mode="avg_dis", | |
| loss_in_log=True, | |
| flatten_across_image_only=False, | |
| depth_type_for_loss="depth_along_ray", | |
| depth_loss_weight=1, | |
| ray_directions_loss_weight=1, | |
| pose_quats_loss_weight=1, | |
| pose_trans_loss_weight=1, | |
| scale_loss_weight=1, | |
| ): | |
| """ | |
| Initialize the disentangled loss criterion for Factored Geometry (Ray Directions, Depth, Pose) & Scale. | |
| It isolates/disentangles the contribution of each factor to the final task of 3D reconstruction. | |
| All the losses live in the same space: the loss for each factor is computed by constructing world-frame pointmaps. | |
| This sidesteps the difficulty of finding a proper weighting between the factors. | |
| For instance, for predicted rays, the GT depth & pose are used to construct the predicted world-frame pointmaps on which the loss is computed. | |
| Inspired by https://openaccess.thecvf.com/content_ICCV_2019/papers/Simonelli_Disentangling_Monocular_3D_Object_Detection_ICCV_2019_paper.pdf | |
| The pixel-level losses are computed in the following order: | |
| (1) depth, (2) ray directions, (3) pose quats, (4) pose trans, (5) scale. | |
| The predicted scene representation is always normalized w.r.t. the frame of view0. | |
| Loss is applied between the predicted metric scale and the ground truth metric scale. | |
| Args: | |
| criterion (BaseCriterion): The base criterion to use for computing the loss. | |
| norm_predictions (bool): If True, normalize the predictions before computing the loss. | |
| norm_mode (str): Normalization mode for the gt and predicted (optional) scene representation. Default: "avg_dis". | |
| loss_in_log (bool): If True, apply logarithmic transformation to input before | |
| computing the loss for depth, pointmaps and scale. Default: True. | |
| flatten_across_image_only (bool): If True, flatten H x W dimensions only when computing | |
| the loss. If False, flatten across batch and spatial dimensions. Default: False. | |
| depth_type_for_loss (str): Type of depth to use for loss computation. Default: "depth_along_ray". | |
| Options: "depth_along_ray", "depth_z" | |
| depth_loss_weight (float): Weight to use for the depth loss. Default: 1. | |
| ray_directions_loss_weight (float): Weight to use for the ray directions loss. Default: 1. | |
| pose_quats_loss_weight (float): Weight to use for the pose quats loss. Default: 1. | |
| pose_trans_loss_weight (float): Weight to use for the pose trans loss. Default: 1. | |
| scale_loss_weight (float): Weight to use for the scale loss. Default: 1. | |
| """ | |
| super().__init__(criterion) | |
| self.norm_predictions = norm_predictions | |
| self.norm_mode = norm_mode | |
| self.loss_in_log = loss_in_log | |
| self.flatten_across_image_only = flatten_across_image_only | |
| self.depth_type_for_loss = depth_type_for_loss | |
| assert self.depth_type_for_loss in ["depth_along_ray", "depth_z"], ( | |
| "depth_type_for_loss must be one of ['depth_along_ray', 'depth_z']" | |
| ) | |
| self.depth_loss_weight = depth_loss_weight | |
| self.ray_directions_loss_weight = ray_directions_loss_weight | |
| self.pose_quats_loss_weight = pose_quats_loss_weight | |
| self.pose_trans_loss_weight = pose_trans_loss_weight | |
| self.scale_loss_weight = scale_loss_weight | |
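| # Illustration of the disentangling scheme used in compute_loss below: each factor's | |
| # prediction is combined with the ground truth of all other factors to build a | |
| # world-frame pointmap, and the regression loss is applied to that pointmap. For | |
| # example, the ray-direction term uses (pred rays, GT depth, GT pose), while the | |
| # depth term uses (GT rays, pred depth, GT pose). | |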
| def get_all_info(self, batch, preds, dist_clip=None): | |
| """ | |
| Function to get all the information needed to compute the loss. | |
| Returns all quantities normalized w.r.t. camera of view0. | |
| """ | |
| n_views = len(batch) | |
| # Everything is normalized w.r.t. camera of view0 | |
| # Initialize lists to store data for all views | |
| # Ground truth quantities | |
| in_camera0 = closed_form_pose_inverse(batch[0]["camera_pose"]) | |
| no_norm_gt_pts = [] | |
| no_norm_gt_pts_cam = [] | |
| no_norm_gt_depth = [] | |
| no_norm_gt_pose_trans = [] | |
| valid_masks = [] | |
| gt_ray_directions = [] | |
| gt_pose_quats = [] | |
| # Predicted quantities | |
| no_norm_pr_pts = [] | |
| no_norm_pr_pts_cam = [] | |
| no_norm_pr_depth = [] | |
| no_norm_pr_pose_trans = [] | |
| pr_ray_directions = [] | |
| pr_pose_quats = [] | |
| metric_pr_pts_to_compute_scale = [] | |
| # Get ground truth & prediction info for all views | |
| for i in range(n_views): | |
| # Get the ground truth | |
| no_norm_gt_pts.append(geotrf(in_camera0, batch[i]["pts3d"])) | |
| valid_masks.append(batch[i]["valid_mask"].clone()) | |
| no_norm_gt_pts_cam.append(batch[i]["pts3d_cam"]) | |
| gt_ray_directions.append(batch[i]["ray_directions_cam"]) | |
| if self.depth_type_for_loss == "depth_along_ray": | |
| no_norm_gt_depth.append(batch[i]["depth_along_ray"]) | |
| elif self.depth_type_for_loss == "depth_z": | |
| no_norm_gt_depth.append(batch[i]["pts3d_cam"][..., 2:]) | |
| if i == 0: | |
| # For view0, initialize identity pose | |
| gt_pose_quats.append( | |
| torch.tensor( | |
| [0, 0, 0, 1], | |
| dtype=gt_ray_directions[0].dtype, | |
| device=gt_ray_directions[0].device, | |
| ) | |
| .unsqueeze(0) | |
| .repeat(gt_ray_directions[0].shape[0], 1) | |
| ) | |
| no_norm_gt_pose_trans.append( | |
| torch.tensor( | |
| [0, 0, 0], | |
| dtype=gt_ray_directions[0].dtype, | |
| device=gt_ray_directions[0].device, | |
| ) | |
| .unsqueeze(0) | |
| .repeat(gt_ray_directions[0].shape[0], 1) | |
| ) | |
| else: | |
| # For other views, transform pose to view0's frame | |
| gt_pose_quats_world = batch[i]["camera_pose_quats"] | |
| no_norm_gt_pose_trans_world = batch[i]["camera_pose_trans"] | |
| gt_pose_quats_in_view0, no_norm_gt_pose_trans_in_view0 = ( | |
| transform_pose_using_quats_and_trans_2_to_1( | |
| batch[0]["camera_pose_quats"], | |
| batch[0]["camera_pose_trans"], | |
| gt_pose_quats_world, | |
| no_norm_gt_pose_trans_world, | |
| ) | |
| ) | |
| gt_pose_quats.append(gt_pose_quats_in_view0) | |
| no_norm_gt_pose_trans.append(no_norm_gt_pose_trans_in_view0) | |
| # Get predictions for normalized loss | |
| if self.depth_type_for_loss == "depth_along_ray": | |
| curr_view_no_norm_depth = preds[i]["depth_along_ray"] | |
| elif self.depth_type_for_loss == "depth_z": | |
| curr_view_no_norm_depth = preds[i]["pts3d_cam"][..., 2:] | |
| if "metric_scaling_factor" in preds[i].keys(): | |
| # Divide by the predicted metric scaling factor to get the raw predicted points, depth_along_ray, and pose_trans | |
| # This detaches the predicted metric scaling factor from the geometry based loss | |
| curr_view_no_norm_pr_pts = preds[i]["pts3d"] / preds[i][ | |
| "metric_scaling_factor" | |
| ].unsqueeze(-1).unsqueeze(-1) | |
| curr_view_no_norm_pr_pts_cam = preds[i]["pts3d_cam"] / preds[i][ | |
| "metric_scaling_factor" | |
| ].unsqueeze(-1).unsqueeze(-1) | |
| curr_view_no_norm_depth = curr_view_no_norm_depth / preds[i][ | |
| "metric_scaling_factor" | |
| ].unsqueeze(-1).unsqueeze(-1) | |
| curr_view_no_norm_pr_pose_trans = ( | |
| preds[i]["cam_trans"] / preds[i]["metric_scaling_factor"] | |
| ) | |
| else: | |
| curr_view_no_norm_pr_pts = preds[i]["pts3d"] | |
| curr_view_no_norm_pr_pts_cam = preds[i]["pts3d_cam"] | |
| curr_view_no_norm_depth = curr_view_no_norm_depth | |
| curr_view_no_norm_pr_pose_trans = preds[i]["cam_trans"] | |
| no_norm_pr_pts.append(curr_view_no_norm_pr_pts) | |
| no_norm_pr_pts_cam.append(curr_view_no_norm_pr_pts_cam) | |
| no_norm_pr_depth.append(curr_view_no_norm_depth) | |
| no_norm_pr_pose_trans.append(curr_view_no_norm_pr_pose_trans) | |
| pr_ray_directions.append(preds[i]["ray_directions"]) | |
| pr_pose_quats.append(preds[i]["cam_quats"]) | |
| # Get the predicted metric scale points | |
| if "metric_scaling_factor" in preds[i].keys(): | |
| # Detach the raw predicted points so that the scale loss is only applied to the scaling factor | |
| curr_view_metric_pr_pts_to_compute_scale = ( | |
| curr_view_no_norm_pr_pts.detach() | |
| * preds[i]["metric_scaling_factor"].unsqueeze(-1).unsqueeze(-1) | |
| ) | |
| else: | |
| curr_view_metric_pr_pts_to_compute_scale = ( | |
| curr_view_no_norm_pr_pts.clone() | |
| ) | |
| metric_pr_pts_to_compute_scale.append( | |
| curr_view_metric_pr_pts_to_compute_scale | |
| ) | |
| if dist_clip is not None: | |
| # Mark points that are too far away as invalid | |
| for i in range(n_views): | |
| dis = no_norm_gt_pts[i].norm(dim=-1) | |
| valid_masks[i] = valid_masks[i] & (dis <= dist_clip) | |
| # Initialize normalized tensors | |
| gt_pts = [torch.zeros_like(pts) for pts in no_norm_gt_pts] | |
| gt_pts_cam = [torch.zeros_like(pts_cam) for pts_cam in no_norm_gt_pts_cam] | |
| gt_depth = [torch.zeros_like(depth) for depth in no_norm_gt_depth] | |
| gt_pose_trans = [torch.zeros_like(trans) for trans in no_norm_gt_pose_trans] | |
| pr_pts = [torch.zeros_like(pts) for pts in no_norm_pr_pts] | |
| pr_pts_cam = [torch.zeros_like(pts_cam) for pts_cam in no_norm_pr_pts_cam] | |
| pr_depth = [torch.zeros_like(depth) for depth in no_norm_pr_depth] | |
| pr_pose_trans = [torch.zeros_like(trans) for trans in no_norm_pr_pose_trans] | |
| # Normalize the predicted points if specified | |
| if self.norm_predictions: | |
| pr_normalization_output = normalize_multiple_pointclouds( | |
| no_norm_pr_pts, | |
| valid_masks, | |
| self.norm_mode, | |
| ret_factor=True, | |
| ) | |
| pr_pts_norm = pr_normalization_output[:-1] | |
| pr_norm_factor = pr_normalization_output[-1] | |
| # Normalize the ground truth points | |
| gt_normalization_output = normalize_multiple_pointclouds( | |
| no_norm_gt_pts, valid_masks, self.norm_mode, ret_factor=True | |
| ) | |
| gt_pts_norm = gt_normalization_output[:-1] | |
| gt_norm_factor = gt_normalization_output[-1] | |
| for i in range(n_views): | |
| if self.norm_predictions: | |
| # Assign the normalized predictions | |
| pr_pts[i] = pr_pts_norm[i] | |
| pr_pts_cam[i] = no_norm_pr_pts_cam[i] / pr_norm_factor | |
| pr_depth[i] = no_norm_pr_depth[i] / pr_norm_factor | |
| pr_pose_trans[i] = no_norm_pr_pose_trans[i] / pr_norm_factor[:, :, 0, 0] | |
| else: | |
| pr_pts[i] = no_norm_pr_pts[i] | |
| pr_pts_cam[i] = no_norm_pr_pts_cam[i] | |
| pr_depth[i] = no_norm_pr_depth[i] | |
| pr_pose_trans[i] = no_norm_pr_pose_trans[i] | |
| # Assign the normalized ground truth quantities | |
| gt_pts[i] = gt_pts_norm[i] | |
| gt_pts_cam[i] = no_norm_gt_pts_cam[i] / gt_norm_factor | |
| gt_depth[i] = no_norm_gt_depth[i] / gt_norm_factor | |
| gt_pose_trans[i] = no_norm_gt_pose_trans[i] / gt_norm_factor[:, :, 0, 0] | |
| # Get the mask indicating ground truth metric scale quantities | |
| metric_scale_mask = batch[0]["is_metric_scale"] | |
| valid_gt_norm_factor_mask = ( | |
| gt_norm_factor[:, 0, 0, 0] > 1e-8 | |
| ) # Mask out cases where depth for all views is invalid | |
| valid_metric_scale_mask = metric_scale_mask & valid_gt_norm_factor_mask | |
| if valid_metric_scale_mask.any(): | |
| # Compute the scale norm factor using the predicted metric scale points | |
| metric_pr_normalization_output = normalize_multiple_pointclouds( | |
| metric_pr_pts_to_compute_scale, | |
| valid_masks, | |
| self.norm_mode, | |
| ret_factor=True, | |
| ) | |
| pr_metric_norm_factor = metric_pr_normalization_output[-1] | |
| # Get the valid ground truth and predicted scale norm factors for the metric ground truth quantities | |
| gt_metric_norm_factor = gt_norm_factor[valid_metric_scale_mask] | |
| pr_metric_norm_factor = pr_metric_norm_factor[valid_metric_scale_mask] | |
| else: | |
| gt_metric_norm_factor = None | |
| pr_metric_norm_factor = None | |
| # Get ambiguous masks | |
| ambiguous_masks = [] | |
| for i in range(n_views): | |
| ambiguous_masks.append( | |
| (~batch[i]["non_ambiguous_mask"]) & (~valid_masks[i]) | |
| ) | |
| # Pack into info dicts | |
| gt_info = [] | |
| pred_info = [] | |
| for i in range(n_views): | |
| gt_info.append( | |
| { | |
| "ray_directions": gt_ray_directions[i], | |
| self.depth_type_for_loss: gt_depth[i], | |
| "pose_trans": gt_pose_trans[i], | |
| "pose_quats": gt_pose_quats[i], | |
| "pts3d": gt_pts[i], | |
| "pts3d_cam": gt_pts_cam[i], | |
| } | |
| ) | |
| pred_info.append( | |
| { | |
| "ray_directions": pr_ray_directions[i], | |
| self.depth_type_for_loss: pr_depth[i], | |
| "pose_trans": pr_pose_trans[i], | |
| "pose_quats": pr_pose_quats[i], | |
| "pts3d": pr_pts[i], | |
| "pts3d_cam": pr_pts_cam[i], | |
| } | |
| ) | |
| return ( | |
| gt_info, | |
| pred_info, | |
| valid_masks, | |
| ambiguous_masks, | |
| gt_metric_norm_factor, | |
| pr_metric_norm_factor, | |
| ) | |
| def compute_loss(self, batch, preds, **kw): | |
| ( | |
| gt_info, | |
| pred_info, | |
| valid_masks, | |
| ambiguous_masks, | |
| gt_metric_norm_factor, | |
| pr_metric_norm_factor, | |
| ) = self.get_all_info(batch, preds, **kw) | |
| n_views = len(batch) | |
| pose_trans_losses = [] | |
| pose_quats_losses = [] | |
| ray_directions_losses = [] | |
| depth_losses = [] | |
| for i in range(n_views): | |
| # Get the GT factored quantities for the current view | |
| gt_pts3d = gt_info[i]["pts3d"] | |
| gt_ray_directions = gt_info[i]["ray_directions"] | |
| gt_depth = gt_info[i][self.depth_type_for_loss] | |
| gt_pose_trans = gt_info[i]["pose_trans"] | |
| gt_pose_quats = gt_info[i]["pose_quats"] | |
| # Get the predicted factored quantities for the current view | |
| pred_ray_directions = pred_info[i]["ray_directions"] | |
| pred_depth = pred_info[i][self.depth_type_for_loss] | |
| pred_pose_trans = pred_info[i]["pose_trans"] | |
| pred_pose_quats = pred_info[i]["pose_quats"] | |
| # Get the predicted world-frame pointmaps using the different factors | |
| if self.depth_type_for_loss == "depth_along_ray": | |
| pred_ray_directions_pts3d = ( | |
| convert_ray_dirs_depth_along_ray_pose_trans_quats_to_pointmap( | |
| pred_ray_directions, | |
| gt_depth, | |
| gt_pose_trans, | |
| gt_pose_quats, | |
| ) | |
| ) | |
| pred_depth_pts3d = ( | |
| convert_ray_dirs_depth_along_ray_pose_trans_quats_to_pointmap( | |
| gt_ray_directions, | |
| pred_depth, | |
| gt_pose_trans, | |
| gt_pose_quats, | |
| ) | |
| ) | |
| pred_pose_trans_pts3d = ( | |
| convert_ray_dirs_depth_along_ray_pose_trans_quats_to_pointmap( | |
| gt_ray_directions, | |
| gt_depth, | |
| pred_pose_trans, | |
| gt_pose_quats, | |
| ) | |
| ) | |
| pred_pose_quats_pts3d = ( | |
| convert_ray_dirs_depth_along_ray_pose_trans_quats_to_pointmap( | |
| gt_ray_directions, | |
| gt_depth, | |
| gt_pose_trans, | |
| pred_pose_quats, | |
| ) | |
| ) | |
| else: | |
| raise NotImplementedError | |
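| # How the four pointmaps above realize the disentangled loss (sketch): each predicted | |
| # factor is combined with ground truth values for all other factors, so each pointmap | |
| # differs from the GT pointmap through a single factor only: | |
| #   pred_ray_directions_pts3d : pred rays + GT depth + GT pose | |
| #   pred_depth_pts3d          : GT rays + pred depth + GT pose | |
| #   pred_pose_trans_pts3d     : GT rays + GT depth + pred translation + GT rotation | |
| #   pred_pose_quats_pts3d     : GT rays + GT depth + GT translation + pred rotation | |
| # Comparing each against gt_pts3d therefore attributes the 3D error to one factor at a | |
| # time. | |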
| # Keep only the valid entries, or flatten the spatial dimensions, as configured | |
| if not self.flatten_across_image_only: | |
| # Flatten the points across the entire batch with the masks | |
| pred_ray_directions_pts3d = pred_ray_directions_pts3d[valid_masks[i]] | |
| pred_depth_pts3d = pred_depth_pts3d[valid_masks[i]] | |
| pred_pose_trans_pts3d = pred_pose_trans_pts3d[valid_masks[i]] | |
| pred_pose_quats_pts3d = pred_pose_quats_pts3d[valid_masks[i]] | |
| gt_pts3d = gt_pts3d[valid_masks[i]] | |
| else: | |
| # Flatten the H x W dimensions to H*W | |
| batch_size, _, _, pts_dim = gt_pts3d.shape | |
| pred_ray_directions_pts3d = pred_ray_directions_pts3d.view( | |
| batch_size, -1, pts_dim | |
| ) | |
| pred_depth_pts3d = pred_depth_pts3d.view(batch_size, -1, pts_dim) | |
| pred_pose_trans_pts3d = pred_pose_trans_pts3d.view( | |
| batch_size, -1, pts_dim | |
| ) | |
| pred_pose_quats_pts3d = pred_pose_quats_pts3d.view( | |
| batch_size, -1, pts_dim | |
| ) | |
| gt_pts3d = gt_pts3d.view(batch_size, -1, pts_dim) | |
| valid_masks[i] = valid_masks[i].view(batch_size, -1) | |
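| # Shape sketch for the two flattening modes (assuming (B, H, W, 3) inputs): | |
| #   flatten_across_image_only=False: boolean-mask indexing pools valid pixels across | |
| #     the whole batch, giving (N_valid, 3) tensors for the criterion. | |
| #   flatten_across_image_only=True : tensors are reshaped to (B, H*W, 3) and the valid | |
| #     mask to (B, H*W); the mask is carried alongside each per-pixel loss term below. | |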
| # Apply loss in log space if specified | |
| if self.loss_in_log: | |
| gt_pts3d = apply_log_to_norm(gt_pts3d) | |
| pred_ray_directions_pts3d = apply_log_to_norm(pred_ray_directions_pts3d) | |
| pred_depth_pts3d = apply_log_to_norm(pred_depth_pts3d) | |
| pred_pose_trans_pts3d = apply_log_to_norm(pred_pose_trans_pts3d) | |
| pred_pose_quats_pts3d = apply_log_to_norm(pred_pose_quats_pts3d) | |
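| # Hedged note: apply_log_to_norm (mapanything.utils.geometry) applies a log-style | |
| # rescaling of the point norms, so distant points contribute more evenly to the | |
| # regression; prediction and ground truth go through the same transform before the | |
| # criterion is applied. | |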
| # Compute pose translation loss | |
| pose_trans_loss = self.criterion( | |
| pred_pose_trans_pts3d, gt_pts3d, factor="pose_trans" | |
| ) | |
| pose_trans_loss = pose_trans_loss * self.pose_trans_loss_weight | |
| pose_trans_losses.append(pose_trans_loss) | |
| # Compute pose rotation loss | |
| pose_quats_loss = self.criterion( | |
| pred_pose_quats_pts3d, gt_pts3d, factor="pose_quats" | |
| ) | |
| pose_quats_loss = pose_quats_loss * self.pose_quats_loss_weight | |
| pose_quats_losses.append(pose_quats_loss) | |
| # Compute ray direction loss | |
| ray_directions_loss = self.criterion( | |
| pred_ray_directions_pts3d, gt_pts3d, factor="ray_directions" | |
| ) | |
| ray_directions_loss = ray_directions_loss * self.ray_directions_loss_weight | |
| ray_directions_losses.append(ray_directions_loss) | |
| # Compute depth loss | |
| depth_loss = self.criterion(pred_depth_pts3d, gt_pts3d, factor="depth") | |
| depth_loss = depth_loss * self.depth_loss_weight | |
| depth_losses.append(depth_loss) | |
| # Compute the scale loss | |
| if gt_metric_norm_factor is not None: | |
| if self.loss_in_log: | |
| gt_metric_norm_factor = apply_log_to_norm(gt_metric_norm_factor) | |
| pr_metric_norm_factor = apply_log_to_norm(pr_metric_norm_factor) | |
| scale_loss = ( | |
| self.criterion( | |
| pr_metric_norm_factor, gt_metric_norm_factor, factor="scale" | |
| ) | |
| * self.scale_loss_weight | |
| ) | |
| else: | |
| scale_loss = None | |
| # Use helper function to generate loss terms and details | |
| losses_dict = {} | |
| losses_dict.update( | |
| { | |
| self.depth_type_for_loss: { | |
| "values": depth_losses, | |
| "use_mask": True, | |
| "is_multi_view": True, | |
| }, | |
| "ray_directions": { | |
| "values": ray_directions_losses, | |
| "use_mask": True, | |
| "is_multi_view": True, | |
| }, | |
| "pose_quats": { | |
| "values": pose_quats_losses, | |
| "use_mask": True, | |
| "is_multi_view": True, | |
| }, | |
| "pose_trans": { | |
| "values": pose_trans_losses, | |
| "use_mask": True, | |
| "is_multi_view": True, | |
| }, | |
| "scale": { | |
| "values": scale_loss, | |
| "use_mask": False, | |
| "is_multi_view": False, | |
| }, | |
| } | |
| ) | |
| loss_terms, details = get_loss_terms_and_details( | |
| losses_dict, | |
| valid_masks, | |
| type(self).__name__, | |
| n_views, | |
| self.flatten_across_image_only, | |
| ) | |
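| # Illustrative note: loss_terms is a list of (per-pixel loss, optional mask, name) | |
| # tuples consumed by Sum, and details holds scalar logs keyed per view and averaged, | |
| # e.g. "DisentangledFactoredGeometryScaleRegr3D_pose_trans_view1", "..._pose_trans_avg" | |
| # and "..._scale" (per-view keys are added only when the corresponding mean is | |
| # positive). | |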
| losses = Sum(*loss_terms) | |
| return losses, (details | {}) | |
| class DisentangledFactoredGeometryScaleRegr3DPlusNormalGMLoss( | |
| DisentangledFactoredGeometryScaleRegr3D | |
| ): | |
| """ | |
| Disentangled Regression, Normals & Gradient Matching Loss for Factored Geometry & Scale. | |
| """ | |
| def __init__( | |
| self, | |
| criterion, | |
| norm_predictions=True, | |
| norm_mode="avg_dis", | |
| loss_in_log=True, | |
| flatten_across_image_only=False, | |
| depth_type_for_loss="depth_along_ray", | |
| depth_loss_weight=1, | |
| ray_directions_loss_weight=1, | |
| pose_quats_loss_weight=1, | |
| pose_trans_loss_weight=1, | |
| scale_loss_weight=1, | |
| apply_normal_and_gm_loss_to_synthetic_data_only=True, | |
| normal_loss_weight=1, | |
| gm_loss_weight=1, | |
| ): | |
| """ | |
| Initialize the disentangled loss criterion for Factored Geometry (Ray Directions, Depth, Pose) & Scale. | |
| See parent class (DisentangledFactoredGeometryScaleRegr3D) for more details. | |
| Additionally computes: | |
| (1) Normal Loss over the Camera Frame Pointmaps in Euclidean coordinates, | |
| (2) Gradient Matching (GM) Loss over the Z-depth in log space. (MiDaS applies its GM loss in disparity space.) | |
| Args: | |
| criterion (BaseCriterion): The base criterion to use for computing the loss. | |
| norm_predictions (bool): If True, normalize the predictions before computing the loss. | |
| norm_mode (str): Normalization mode for the ground truth and (optionally) the predicted scene representation. Default: "avg_dis". | |
| loss_in_log (bool): If True, apply logarithmic transformation to input before | |
| computing the loss for depth, pointmaps and scale. Default: True. | |
| flatten_across_image_only (bool): If True, flatten H x W dimensions only when computing | |
| the loss. If False, flatten across batch and spatial dimensions. Default: False. | |
| depth_type_for_loss (str): Type of depth to use for loss computation. Default: "depth_along_ray". | |
| Options: "depth_along_ray", "depth_z" | |
| depth_loss_weight (float): Weight to use for the depth loss. Default: 1. | |
| ray_directions_loss_weight (float): Weight to use for the ray directions loss. Default: 1. | |
| pose_quats_loss_weight (float): Weight to use for the pose quats loss. Default: 1. | |
| pose_trans_loss_weight (float): Weight to use for the pose trans loss. Default: 1. | |
| scale_loss_weight (float): Weight to use for the scale loss. Default: 1. | |
| apply_normal_and_gm_loss_to_synthetic_data_only (bool): If True, apply the normal and gm loss only to synthetic data. | |
| If False, apply the normal and gm loss to all data. Default: True. | |
| normal_loss_weight (float): Weight to use for the normal loss. Default: 1. | |
| gm_loss_weight (float): Weight to use for the gm loss. Default: 1. | |
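| Example (illustrative only; the criterion and weights below are placeholders, not | |
| recommended settings): | |
| loss = DisentangledFactoredGeometryScaleRegr3DPlusNormalGMLoss( | |
| criterion=criterion,  # any pointwise regression criterion used with these losses | |
| norm_mode="avg_dis", | |
| depth_type_for_loss="depth_along_ray", | |
| normal_loss_weight=0.5, | |
| gm_loss_weight=0.5, | |
| ) | |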
| """ | |
| super().__init__( | |
| criterion=criterion, | |
| norm_predictions=norm_predictions, | |
| norm_mode=norm_mode, | |
| loss_in_log=loss_in_log, | |
| flatten_across_image_only=flatten_across_image_only, | |
| depth_type_for_loss=depth_type_for_loss, | |
| depth_loss_weight=depth_loss_weight, | |
| ray_directions_loss_weight=ray_directions_loss_weight, | |
| pose_quats_loss_weight=pose_quats_loss_weight, | |
| pose_trans_loss_weight=pose_trans_loss_weight, | |
| scale_loss_weight=scale_loss_weight, | |
| ) | |
| self.apply_normal_and_gm_loss_to_synthetic_data_only = ( | |
| apply_normal_and_gm_loss_to_synthetic_data_only | |
| ) | |
| self.normal_loss_weight = normal_loss_weight | |
| self.gm_loss_weight = gm_loss_weight | |
| def compute_loss(self, batch, preds, **kw): | |
| ( | |
| gt_info, | |
| pred_info, | |
| valid_masks, | |
| ambiguous_masks, | |
| gt_metric_norm_factor, | |
| pr_metric_norm_factor, | |
| ) = self.get_all_info(batch, preds, **kw) | |
| n_views = len(batch) | |
| normal_losses = [] | |
| gradient_matching_losses = [] | |
| pose_trans_losses = [] | |
| pose_quats_losses = [] | |
| ray_directions_losses = [] | |
| depth_losses = [] | |
| for i in range(n_views): | |
| # Get the camera frame points, log space depth_z & valid masks | |
| pred_local_pts3d = pred_info[i]["pts3d_cam"] | |
| pred_depth_z = pred_local_pts3d[..., 2:] | |
| pred_depth_z = apply_log_to_norm(pred_depth_z) | |
| gt_local_pts3d = gt_info[i]["pts3d_cam"] | |
| gt_depth_z = gt_local_pts3d[..., 2:] | |
| gt_depth_z = apply_log_to_norm(gt_depth_z) | |
| valid_mask_for_normal_gm_loss = valid_masks[i].clone() | |
| # Update the validity mask for normal & gm loss based on the synthetic data mask if required | |
| if self.apply_normal_and_gm_loss_to_synthetic_data_only: | |
| synthetic_mask = batch[i]["is_synthetic"] # (B, ) | |
| synthetic_mask = synthetic_mask.unsqueeze(-1).unsqueeze(-1) # (B, 1, 1) | |
| synthetic_mask = synthetic_mask.expand( | |
| -1, pred_depth_z.shape[1], pred_depth_z.shape[2] | |
| ) # (B, H, W) | |
| valid_mask_for_normal_gm_loss = ( | |
| valid_mask_for_normal_gm_loss & synthetic_mask | |
| ) | |
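| # Sketch of the gating above: is_synthetic is a per-sample flag of shape (B,), expanded | |
| # to (B, H, W) and ANDed with the per-pixel validity mask, so real-capture samples are | |
| # excluded from the normal and gradient-matching terms while all other losses are | |
| # unaffected. | |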
| # Compute the normal loss | |
| normal_loss = compute_normal_loss( | |
| pred_local_pts3d, gt_local_pts3d, valid_mask_for_normal_gm_loss.clone() | |
| ) | |
| normal_loss = normal_loss * self.normal_loss_weight | |
| normal_losses.append(normal_loss) | |
| # Compute the gradient matching loss | |
| gradient_matching_loss = compute_gradient_matching_loss( | |
| pred_depth_z, gt_depth_z, valid_mask_for_normal_gm_loss.clone() | |
| ) | |
| gradient_matching_loss = gradient_matching_loss * self.gm_loss_weight | |
| gradient_matching_losses.append(gradient_matching_loss) | |
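| # Hedged note: compute_gradient_matching_loss is expected to follow a MiDaS-style | |
| # gradient matching term, comparing finite differences of predicted and ground truth | |
| # log Z-depth within the valid mask; unlike MiDaS it is applied in log-depth rather | |
| # than disparity space (see the class docstring). | |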
| # Get the GT factored quantities for the current view | |
| gt_pts3d = gt_info[i]["pts3d"] | |
| gt_ray_directions = gt_info[i]["ray_directions"] | |
| gt_depth = gt_info[i][self.depth_type_for_loss] | |
| gt_pose_trans = gt_info[i]["pose_trans"] | |
| gt_pose_quats = gt_info[i]["pose_quats"] | |
| # Get the predicted factored quantities for the current view | |
| pred_ray_directions = pred_info[i]["ray_directions"] | |
| pred_depth = pred_info[i][self.depth_type_for_loss] | |
| pred_pose_trans = pred_info[i]["pose_trans"] | |
| pred_pose_quats = pred_info[i]["pose_quats"] | |
| # Get the predicted world-frame pointmaps using the different factors | |
| if self.depth_type_for_loss == "depth_along_ray": | |
| pred_ray_directions_pts3d = ( | |
| convert_ray_dirs_depth_along_ray_pose_trans_quats_to_pointmap( | |
| pred_ray_directions, | |
| gt_depth, | |
| gt_pose_trans, | |
| gt_pose_quats, | |
| ) | |
| ) | |
| pred_depth_pts3d = ( | |
| convert_ray_dirs_depth_along_ray_pose_trans_quats_to_pointmap( | |
| gt_ray_directions, | |
| pred_depth, | |
| gt_pose_trans, | |
| gt_pose_quats, | |
| ) | |
| ) | |
| pred_pose_trans_pts3d = ( | |
| convert_ray_dirs_depth_along_ray_pose_trans_quats_to_pointmap( | |
| gt_ray_directions, | |
| gt_depth, | |
| pred_pose_trans, | |
| gt_pose_quats, | |
| ) | |
| ) | |
| pred_pose_quats_pts3d = ( | |
| convert_ray_dirs_depth_along_ray_pose_trans_quats_to_pointmap( | |
| gt_ray_directions, | |
| gt_depth, | |
| gt_pose_trans, | |
| pred_pose_quats, | |
| ) | |
| ) | |
| else: | |
| raise NotImplementedError | |
| # Keep only the valid entries, or flatten the spatial dimensions, as configured | |
| if not self.flatten_across_image_only: | |
| # Flatten the points across the entire batch with the masks | |
| pred_ray_directions_pts3d = pred_ray_directions_pts3d[valid_masks[i]] | |
| pred_depth_pts3d = pred_depth_pts3d[valid_masks[i]] | |
| pred_pose_trans_pts3d = pred_pose_trans_pts3d[valid_masks[i]] | |
| pred_pose_quats_pts3d = pred_pose_quats_pts3d[valid_masks[i]] | |
| gt_pts3d = gt_pts3d[valid_masks[i]] | |
| else: | |
| # Flatten the H x W dimensions to H*W | |
| batch_size, _, _, pts_dim = gt_pts3d.shape | |
| pred_ray_directions_pts3d = pred_ray_directions_pts3d.view( | |
| batch_size, -1, pts_dim | |
| ) | |
| pred_depth_pts3d = pred_depth_pts3d.view(batch_size, -1, pts_dim) | |
| pred_pose_trans_pts3d = pred_pose_trans_pts3d.view( | |
| batch_size, -1, pts_dim | |
| ) | |
| pred_pose_quats_pts3d = pred_pose_quats_pts3d.view( | |
| batch_size, -1, pts_dim | |
| ) | |
| gt_pts3d = gt_pts3d.view(batch_size, -1, pts_dim) | |
| valid_masks[i] = valid_masks[i].view(batch_size, -1) | |
| # Apply loss in log space if specified | |
| if self.loss_in_log: | |
| gt_pts3d = apply_log_to_norm(gt_pts3d) | |
| pred_ray_directions_pts3d = apply_log_to_norm(pred_ray_directions_pts3d) | |
| pred_depth_pts3d = apply_log_to_norm(pred_depth_pts3d) | |
| pred_pose_trans_pts3d = apply_log_to_norm(pred_pose_trans_pts3d) | |
| pred_pose_quats_pts3d = apply_log_to_norm(pred_pose_quats_pts3d) | |
| # Compute pose translation loss | |
| pose_trans_loss = self.criterion( | |
| pred_pose_trans_pts3d, gt_pts3d, factor="pose_trans" | |
| ) | |
| pose_trans_loss = pose_trans_loss * self.pose_trans_loss_weight | |
| pose_trans_losses.append(pose_trans_loss) | |
| # Compute pose rotation loss | |
| pose_quats_loss = self.criterion( | |
| pred_pose_quats_pts3d, gt_pts3d, factor="pose_quats" | |
| ) | |
| pose_quats_loss = pose_quats_loss * self.pose_quats_loss_weight | |
| pose_quats_losses.append(pose_quats_loss) | |
| # Compute ray direction loss | |
| ray_directions_loss = self.criterion( | |
| pred_ray_directions_pts3d, gt_pts3d, factor="ray_directions" | |
| ) | |
| ray_directions_loss = ray_directions_loss * self.ray_directions_loss_weight | |
| ray_directions_losses.append(ray_directions_loss) | |
| # Compute depth loss | |
| depth_loss = self.criterion(pred_depth_pts3d, gt_pts3d, factor="depth") | |
| depth_loss = depth_loss * self.depth_loss_weight | |
| depth_losses.append(depth_loss) | |
| # Compute the scale loss | |
| if gt_metric_norm_factor is not None: | |
| if self.loss_in_log: | |
| gt_metric_norm_factor = apply_log_to_norm(gt_metric_norm_factor) | |
| pr_metric_norm_factor = apply_log_to_norm(pr_metric_norm_factor) | |
| scale_loss = ( | |
| self.criterion( | |
| pr_metric_norm_factor, gt_metric_norm_factor, factor="scale" | |
| ) | |
| * self.scale_loss_weight | |
| ) | |
| else: | |
| scale_loss = None | |
| # Use helper function to generate loss terms and details | |
| losses_dict = {} | |
| losses_dict.update( | |
| { | |
| self.depth_type_for_loss: { | |
| "values": depth_losses, | |
| "use_mask": True, | |
| "is_multi_view": True, | |
| }, | |
| "ray_directions": { | |
| "values": ray_directions_losses, | |
| "use_mask": True, | |
| "is_multi_view": True, | |
| }, | |
| "pose_quats": { | |
| "values": pose_quats_losses, | |
| "use_mask": True, | |
| "is_multi_view": True, | |
| }, | |
| "pose_trans": { | |
| "values": pose_trans_losses, | |
| "use_mask": True, | |
| "is_multi_view": True, | |
| }, | |
| "scale": { | |
| "values": scale_loss, | |
| "use_mask": False, | |
| "is_multi_view": False, | |
| }, | |
| "normal": { | |
| "values": normal_losses, | |
| "use_mask": False, | |
| "is_multi_view": True, | |
| }, | |
| "gradient_matching": { | |
| "values": gradient_matching_losses, | |
| "use_mask": False, | |
| "is_multi_view": True, | |
| }, | |
| } | |
| ) | |
| loss_terms, details = get_loss_terms_and_details( | |
| losses_dict, | |
| valid_masks, | |
| type(self).__name__, | |
| n_views, | |
| self.flatten_across_image_only, | |
| ) | |
| losses = Sum(*loss_terms) | |
| return losses, (details | {}) | |
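| # Illustrative usage sketch (commented out; the batch/preds structure shown is an | |
| # assumption based on the keys accessed above, not a documented API): | |
| # | |
| #   loss_fn = DisentangledFactoredGeometryScaleRegr3DPlusNormalGMLoss(criterion=criterion) | |
| #   # batch: list of per-view dicts with keys such as "is_metric_scale", "is_synthetic" | |
| #   # and "non_ambiguous_mask"; preds: list of per-view prediction dicts. | |
| #   losses, details = loss_fn.compute_loss(batch, preds) | |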