"""
Multi-view geometric losses for training 3D reconstruction models.
References: DUSt3R & MASt3R
"""
import math
from copy import copy, deepcopy
import einops as ein
import torch
import torch.nn as nn
from mapanything.utils.geometry import (
angle_diff_vec3,
apply_log_to_norm,
closed_form_pose_inverse,
convert_ray_dirs_depth_along_ray_pose_trans_quats_to_pointmap,
geotrf,
normalize_multiple_pointclouds,
quaternion_inverse,
quaternion_multiply,
quaternion_to_rotation_matrix,
transform_pose_using_quats_and_trans_2_to_1,
)
def get_loss_terms_and_details(
losses_dict, valid_masks, self_name, n_views, flatten_across_image_only
):
"""
Helper function to generate loss terms and details for different loss types.
Args:
losses_dict (dict): Dictionary mapping loss types to their values.
Format: {
'loss_type': {
'values': list_of_loss_tensors or single_tensor,
'use_mask': bool,
'is_multi_view': bool
}
}
valid_masks (list): List of valid masks for each view.
self_name (str): Name of the loss class.
n_views (int): Number of views.
flatten_across_image_only (bool): Whether flattening was done across image only.
Returns:
tuple: (loss_terms, details) where loss_terms is a list of tuples (loss, mask, type)
and details is a dictionary of loss details.
"""
loss_terms = []
details = {}
for loss_type, loss_info in losses_dict.items():
values = loss_info["values"]
use_mask = loss_info["use_mask"]
is_multi_view = loss_info["is_multi_view"]
if is_multi_view:
# Handle multi-view losses (list of tensors)
view_loss_details = []
for i in range(n_views):
mask = valid_masks[i] if use_mask else None
loss_terms.append((values[i], mask, loss_type))
# Add details for individual view
if not flatten_across_image_only or not use_mask:
values_after_masking = values[i]
else:
values_after_masking = values[i][mask]
if values_after_masking.numel() > 0:
view_loss_detail = float(values_after_masking.mean())
if view_loss_detail > 0:
details[f"{self_name}_{loss_type}_view{i + 1}"] = (
view_loss_detail
)
view_loss_details.append(view_loss_detail)
# Add average across views
if len(view_loss_details) > 0:
details[f"{self_name}_{loss_type}_avg"] = sum(view_loss_details) / len(
view_loss_details
)
else:
# Handle single tensor losses
if values is not None:
loss_terms.append((values, None, loss_type))
if values.numel() > 0:
loss_detail = float(values.mean())
if loss_detail > 0:
details[f"{self_name}_{loss_type}"] = loss_detail
return loss_terms, details
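# Illustrative call with placeholder shapes/values (not from the training pipeline):
#   >>> per_view = [torch.rand(2, 64), torch.rand(2, 64)]                # per-pixel losses for 2 views
#   >>> masks = [torch.ones(2, 64, dtype=torch.bool) for _ in range(2)]  # valid-pixel masks
#   >>> losses_dict = {
#   ...     "pts3d": {"values": per_view, "use_mask": True, "is_multi_view": True},
#   ...     "scale": {"values": torch.rand(2), "use_mask": False, "is_multi_view": False},
#   ... }
#   >>> terms, details = get_loss_terms_and_details(
#   ...     losses_dict, masks, "PointsPlusScaleRegr3D", n_views=2, flatten_across_image_only=True
#   ... )
#   >>> len(terms)  # two per-view "pts3d" terms plus one "scale" term
#   3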
def _smooth(err: torch.FloatTensor, beta: float = 0.0) -> torch.FloatTensor:
if beta == 0:
return err
else:
return torch.where(err < beta, 0.5 * err.square() / beta, err - 0.5 * beta)
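# _smooth is a Huber-style smoothing of the error: quadratic below beta, linear above.
# For example, _smooth(torch.tensor([0.01, 1.0]), beta=0.1) gives approximately [0.0005, 0.95].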
def compute_normal_loss(points, gt_points, mask):
"""
Compute the normal loss between the predicted and ground truth points.
References:
https://github.com/microsoft/MoGe/blob/a8c37341bc0325ca99b9d57981cc3bb2bd3e255b/moge/train/losses.py#L205
Args:
points (torch.Tensor): Predicted points. Shape: (..., H, W, 3).
gt_points (torch.Tensor): Ground truth points. Shape: (..., H, W, 3).
mask (torch.Tensor): Mask indicating valid points. Shape: (..., H, W).
Returns:
torch.Tensor: Normal loss.
"""
height, width = points.shape[-3:-1]
leftup, rightup, leftdown, rightdown = (
points[..., :-1, :-1, :],
points[..., :-1, 1:, :],
points[..., 1:, :-1, :],
points[..., 1:, 1:, :],
)
upxleft = torch.cross(rightup - rightdown, leftdown - rightdown, dim=-1)
leftxdown = torch.cross(leftup - rightup, rightdown - rightup, dim=-1)
downxright = torch.cross(leftdown - leftup, rightup - leftup, dim=-1)
rightxup = torch.cross(rightdown - leftdown, leftup - leftdown, dim=-1)
gt_leftup, gt_rightup, gt_leftdown, gt_rightdown = (
gt_points[..., :-1, :-1, :],
gt_points[..., :-1, 1:, :],
gt_points[..., 1:, :-1, :],
gt_points[..., 1:, 1:, :],
)
gt_upxleft = torch.cross(
gt_rightup - gt_rightdown, gt_leftdown - gt_rightdown, dim=-1
)
gt_leftxdown = torch.cross(
gt_leftup - gt_rightup, gt_rightdown - gt_rightup, dim=-1
)
gt_downxright = torch.cross(gt_leftdown - gt_leftup, gt_rightup - gt_leftup, dim=-1)
gt_rightxup = torch.cross(
gt_rightdown - gt_leftdown, gt_leftup - gt_leftdown, dim=-1
)
mask_leftup, mask_rightup, mask_leftdown, mask_rightdown = (
mask[..., :-1, :-1],
mask[..., :-1, 1:],
mask[..., 1:, :-1],
mask[..., 1:, 1:],
)
mask_upxleft = mask_rightup & mask_leftdown & mask_rightdown
mask_leftxdown = mask_leftup & mask_rightdown & mask_rightup
mask_downxright = mask_leftdown & mask_rightup & mask_leftup
mask_rightxup = mask_rightdown & mask_leftup & mask_leftdown
MIN_ANGLE, MAX_ANGLE, BETA_RAD = math.radians(1), math.radians(90), math.radians(3)
loss = (
mask_upxleft
* _smooth(
angle_diff_vec3(upxleft, gt_upxleft).clamp(MIN_ANGLE, MAX_ANGLE),
beta=BETA_RAD,
)
+ mask_leftxdown
* _smooth(
angle_diff_vec3(leftxdown, gt_leftxdown).clamp(MIN_ANGLE, MAX_ANGLE),
beta=BETA_RAD,
)
+ mask_downxright
* _smooth(
angle_diff_vec3(downxright, gt_downxright).clamp(MIN_ANGLE, MAX_ANGLE),
beta=BETA_RAD,
)
+ mask_rightxup
* _smooth(
angle_diff_vec3(rightxup, gt_rightxup).clamp(MIN_ANGLE, MAX_ANGLE),
beta=BETA_RAD,
)
)
total_valid_mask = mask_upxleft | mask_leftxdown | mask_downxright | mask_rightxup
valid_count = total_valid_mask.sum()
if valid_count > 0:
loss = loss.sum() / (valid_count * (4 * max(points.shape[-3:-1])))
else:
loss = 0 * loss.sum()
return loss
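# Illustrative call with random placeholder tensors:
#   >>> pred = torch.randn(2, 32, 32, 3)
#   >>> gt = pred + 0.01 * torch.randn_like(pred)
#   >>> mask = torch.ones(2, 32, 32, dtype=torch.bool)
#   >>> compute_normal_loss(pred, gt, mask)  # scalar angular-difference loss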
def compute_gradient_loss(prediction, gt_target, mask):
"""
Compute the gradient loss between the prediction and GT target at valid points.
References:
https://docs.nerf.studio/_modules/nerfstudio/model_components/losses.html#GradientLoss
https://github.com/autonomousvision/monosdf/blob/main/code/model/loss.py
Args:
prediction (torch.Tensor): Predicted scene representation. Shape: (B, H, W, C).
gt_target (torch.Tensor): Ground truth scene representation. Shape: (B, H, W, C).
mask (torch.Tensor): Mask indicating valid points. Shape: (B, H, W).
"""
# Expand mask to match number of channels in prediction
mask = mask[..., None].expand(-1, -1, -1, prediction.shape[-1])
summed_mask = torch.sum(mask, (1, 2, 3))
# Compute the gradient of the prediction and GT target
diff = prediction - gt_target
diff = torch.mul(mask, diff)
# Gradient in x direction
grad_x = torch.abs(diff[:, :, 1:] - diff[:, :, :-1])
mask_x = torch.mul(mask[:, :, 1:], mask[:, :, :-1])
grad_x = torch.mul(mask_x, grad_x)
# Gradient in y direction
grad_y = torch.abs(diff[:, 1:, :] - diff[:, :-1, :])
mask_y = torch.mul(mask[:, 1:, :], mask[:, :-1, :])
grad_y = torch.mul(mask_y, grad_y)
# Clamp the outlier gradients
grad_x = grad_x.clamp(max=100)
grad_y = grad_y.clamp(max=100)
# Compute the total loss
image_loss = torch.sum(grad_x, (1, 2, 3)) + torch.sum(grad_y, (1, 2, 3))
num_valid_pixels = torch.sum(summed_mask)
if num_valid_pixels > 0:
image_loss = torch.sum(image_loss) / num_valid_pixels
else:
image_loss = 0 * torch.sum(image_loss)
return image_loss
def compute_gradient_matching_loss(prediction, gt_target, mask, scales=4):
"""
Compute the multi-scale gradient matching loss between the prediction and GT target at valid points.
This loss biases discontinuities to be sharp and to coincide with discontinuities in the ground truth.
    More info in MiDaS: https://arxiv.org/pdf/1907.01341.pdf; Equation 11
References:
https://docs.nerf.studio/_modules/nerfstudio/model_components/losses.html#GradientLoss
https://github.com/autonomousvision/monosdf/blob/main/code/model/loss.py
Args:
prediction (torch.Tensor): Predicted scene representation. Shape: (B, H, W, C).
gt_target (torch.Tensor): Ground truth scene representation. Shape: (B, H, W, C).
mask (torch.Tensor): Mask indicating valid points. Shape: (B, H, W).
scales (int): Number of scales to compute the loss at. Default: 4.
"""
# Define total loss
total_loss = 0.0
# Compute the gradient loss at different scales
for scale in range(scales):
step = pow(2, scale)
grad_loss = compute_gradient_loss(
prediction[:, ::step, ::step],
gt_target[:, ::step, ::step],
mask[:, ::step, ::step],
)
total_loss += grad_loss
return total_loss
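# Illustrative call with random placeholder tensors (C=1, e.g. a log-depth map):
#   >>> pred_depth = torch.rand(2, 64, 64, 1)
#   >>> gt_depth = torch.rand(2, 64, 64, 1)
#   >>> mask = torch.ones(2, 64, 64, dtype=torch.bool)
#   >>> compute_gradient_matching_loss(pred_depth, gt_depth, mask, scales=4)  # scalar loss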
def Sum(*losses_and_masks):
"""
Aggregates multiple losses into a single loss value or returns the original losses.
Args:
*losses_and_masks: Variable number of tuples, each containing (loss, mask, rep_type)
- loss: Tensor containing loss values
- mask: Mask indicating valid pixels/regions
- rep_type: String indicating the type of representation (e.g., 'pts3d', 'depth')
Returns:
If the first loss has dimensions > 0:
Returns the original list of (loss, mask, rep_type) tuples
Otherwise:
Returns a scalar tensor that is the sum of all loss values
"""
loss, mask, rep_type = losses_and_masks[0]
if loss.ndim > 0:
        # we are actually returning the loss for every pixel
return losses_and_masks
else:
# we are returning the global loss
for loss2, mask2, rep_type2 in losses_and_masks[1:]:
loss = loss + loss2
return loss
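# Behaviour sketch: with per-pixel losses (reduction="none") the tuples are passed through
# unchanged for downstream wrappers such as ConfLoss; with scalar losses they are summed.
#   >>> Sum((torch.tensor(1.0), None, "pts3d"), (torch.tensor(2.0), None, "scale"))
#   tensor(3.)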
class BaseCriterion(nn.Module):
"Base Criterion to support different reduction methods"
def __init__(self, reduction="mean"):
super().__init__()
self.reduction = reduction
class LLoss(BaseCriterion):
"L-norm loss"
def forward(self, a, b, **kwargs):
assert a.shape == b.shape and a.ndim >= 2 and 1 <= a.shape[-1] <= 4, (
f"Bad shape = {a.shape}"
)
dist = self.distance(a, b, **kwargs)
assert dist.ndim == a.ndim - 1 # one dimension less
if self.reduction == "none":
return dist
if self.reduction == "sum":
return dist.sum()
if self.reduction == "mean":
return dist.mean() if dist.numel() > 0 else dist.new_zeros(())
raise ValueError(f"bad {self.reduction=} mode")
def distance(self, a, b, **kwargs):
raise NotImplementedError()
class L1Loss(LLoss):
"L1 distance"
def distance(self, a, b, **kwargs):
return torch.abs(a - b).sum(dim=-1)
class L2Loss(LLoss):
"Euclidean (L2 Norm) distance"
def distance(self, a, b, **kwargs):
return torch.norm(a - b, dim=-1)
class GenericLLoss(LLoss):
"Criterion that supports different L-norms"
def distance(self, a, b, loss_type, **kwargs):
if loss_type == "l1":
# L1 distance
return torch.abs(a - b).sum(dim=-1)
elif loss_type == "l2":
# Euclidean (L2 norm) distance
return torch.norm(a - b, dim=-1)
else:
raise ValueError(
f"Unsupported loss type: {loss_type}. Supported types are 'l1' and 'l2'."
)
class FactoredLLoss(LLoss):
"Criterion that supports different L-norms for the factored loss functions"
def __init__(
self,
reduction="mean",
points_loss_type="l2",
depth_loss_type="l1",
ray_directions_loss_type="l1",
pose_quats_loss_type="l1",
pose_trans_loss_type="l1",
scale_loss_type="l1",
):
super().__init__(reduction)
self.points_loss_type = points_loss_type
self.depth_loss_type = depth_loss_type
self.ray_directions_loss_type = ray_directions_loss_type
self.pose_quats_loss_type = pose_quats_loss_type
self.pose_trans_loss_type = pose_trans_loss_type
self.scale_loss_type = scale_loss_type
def _distance(self, a, b, loss_type):
if loss_type == "l1":
# L1 distance
return torch.abs(a - b).sum(dim=-1)
elif loss_type == "l2":
# Euclidean (L2 norm) distance
return torch.norm(a - b, dim=-1)
else:
raise ValueError(f"Unsupported loss type: {loss_type}.")
def distance(self, a, b, factor, **kwargs):
if factor == "points":
return self._distance(a, b, self.points_loss_type)
elif factor == "depth":
return self._distance(a, b, self.depth_loss_type)
elif factor == "ray_directions":
return self._distance(a, b, self.ray_directions_loss_type)
elif factor == "pose_quats":
return self._distance(a, b, self.pose_quats_loss_type)
elif factor == "pose_trans":
return self._distance(a, b, self.pose_trans_loss_type)
elif factor == "scale":
return self._distance(a, b, self.scale_loss_type)
else:
raise ValueError(f"Unsupported factor type: {factor}.")
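# Illustrative dispatch with the default norms (l2 for points, l1 for depth):
#   >>> crit = FactoredLLoss(reduction="mean")
#   >>> a, b = torch.tensor([[3.0, 4.0]]), torch.zeros(1, 2)
#   >>> crit(a, b, factor="points")  # L2 norm -> tensor(5.)
#   >>> crit(a, b, factor="depth")   # L1 norm -> tensor(7.)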
class RobustRegressionLoss(LLoss):
"""
Generalized Robust Loss introduced in https://arxiv.org/abs/1701.03077.
"""
def __init__(self, alpha=0.5, scaling_c=0.25, reduction="mean"):
"""
Initialize the Robust Regression Loss.
Args:
alpha (float): Shape parameter controlling the robustness of the loss.
Lower values make the loss more robust to outliers. Default: 0.5.
scaling_c (float): Scale parameter controlling the transition between
                quadratic and robust behavior. Default: 0.25.
reduction (str): Specifies the reduction to apply to the output:
'none' | 'mean' | 'sum'. Default: 'mean'.
"""
super().__init__(reduction)
self.alpha = alpha
self.scaling_c = scaling_c
def distance(self, a, b, **kwargs):
error_scaled = torch.sum(((a - b) / self.scaling_c) ** 2, dim=-1)
robust_loss = (abs(self.alpha - 2) / self.alpha) * (
torch.pow((error_scaled / abs(self.alpha - 2)) + 1, self.alpha / 2) - 1
)
return robust_loss
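# With e = sum over the last dim of ((a - b) / scaling_c) ** 2, the loss is
#   (|alpha - 2| / alpha) * ((e / |alpha - 2| + 1) ** (alpha / 2) - 1),
# which is roughly quadratic for small residuals and grows like |residual| ** alpha for
# large ones, down-weighting outliers relative to a plain L2 loss.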
class BCELoss(BaseCriterion):
"""Binary Cross Entropy loss"""
def forward(self, predicted_logits, reference_mask):
"""
Args:
predicted_logits: (B, H, W) tensor of predicted logits for the mask
reference_mask: (B, H, W) tensor of reference mask
Returns:
loss: scalar tensor of the BCE loss
"""
bce_loss = torch.nn.functional.binary_cross_entropy_with_logits(
predicted_logits, reference_mask.float()
)
return bce_loss
class Criterion(nn.Module):
"""
Base class for all criterion modules that wrap a BaseCriterion.
This class serves as a wrapper around BaseCriterion objects, providing
additional functionality like naming and reduction mode control.
Args:
criterion (BaseCriterion): The base criterion to wrap.
"""
def __init__(self, criterion=None):
super().__init__()
assert isinstance(criterion, BaseCriterion), (
f"{criterion} is not a proper criterion!"
)
self.criterion = copy(criterion)
def get_name(self):
"""
Returns a string representation of this criterion.
Returns:
str: A string containing the class name and the wrapped criterion.
"""
return f"{type(self).__name__}({self.criterion})"
def with_reduction(self, mode="none"):
"""
Creates a deep copy of this criterion with the specified reduction mode.
This method recursively sets the reduction mode for this criterion and
any chained MultiLoss criteria.
Args:
mode (str): The reduction mode to set. Default: "none".
Returns:
Criterion: A new criterion with the specified reduction mode.
"""
res = loss = deepcopy(self)
while loss is not None:
assert isinstance(loss, Criterion)
loss.criterion.reduction = mode # make it return the loss for each sample
loss = loss._loss2 # we assume loss is a Multiloss
return res
class MultiLoss(nn.Module):
"""
Base class for combinable loss functions with automatic tracking of individual loss values.
This class enables easy combination of multiple loss functions through arithmetic operations:
loss = MyLoss1() + 0.1*MyLoss2()
The combined loss functions maintain their individual weights and the forward pass
automatically computes and aggregates all losses while tracking individual loss values.
Usage:
Inherit from this class and override get_name() and compute_loss() methods.
Attributes:
_alpha (float): Weight multiplier for this loss component.
_loss2 (MultiLoss): Reference to the next loss in the chain, if any.
"""
def __init__(self):
"""Initialize the MultiLoss with default weight of 1 and no chained loss."""
super().__init__()
self._alpha = 1
self._loss2 = None
def compute_loss(self, *args, **kwargs):
"""
Compute the loss value for this specific loss component.
Args:
*args: Variable length argument list.
**kwargs: Arbitrary keyword arguments.
Returns:
torch.Tensor or tuple: Either the loss tensor or a tuple of (loss, details_dict).
Raises:
NotImplementedError: This method must be implemented by subclasses.
"""
raise NotImplementedError()
def get_name(self):
"""
Get the name of this loss component.
Returns:
str: The name of the loss.
Raises:
NotImplementedError: This method must be implemented by subclasses.
"""
raise NotImplementedError()
def __mul__(self, alpha):
"""
Multiply the loss by a scalar weight.
Args:
alpha (int or float): The weight to multiply the loss by.
Returns:
MultiLoss: A new loss object with the updated weight.
Raises:
AssertionError: If alpha is not a number.
"""
assert isinstance(alpha, (int, float))
res = copy(self)
res._alpha = alpha
return res
__rmul__ = __mul__ # Support both loss*alpha and alpha*loss
def __add__(self, loss2):
"""
Add another loss to this loss, creating a chain of losses.
Args:
loss2 (MultiLoss): Another loss to add to this one.
Returns:
MultiLoss: A new loss object representing the combined losses.
Raises:
AssertionError: If loss2 is not a MultiLoss.
"""
assert isinstance(loss2, MultiLoss)
res = cur = copy(self)
# Find the end of the chain
while cur._loss2 is not None:
cur = cur._loss2
cur._loss2 = loss2
return res
def __repr__(self):
"""
Create a string representation of the loss, including weights and chained losses.
Returns:
str: String representation of the loss.
"""
name = self.get_name()
if self._alpha != 1:
name = f"{self._alpha:g}*{name}"
if self._loss2:
name = f"{name} + {self._loss2}"
return name
def forward(self, *args, **kwargs):
"""
Compute the weighted loss and aggregate with any chained losses.
Args:
*args: Variable length argument list.
**kwargs: Arbitrary keyword arguments.
Returns:
tuple: A tuple containing:
- torch.Tensor: The total weighted loss.
- dict: Details about individual loss components.
"""
loss = self.compute_loss(*args, **kwargs)
if isinstance(loss, tuple):
loss, details = loss
elif loss.ndim == 0:
details = {self.get_name(): float(loss)}
else:
details = {}
loss = loss * self._alpha
if self._loss2:
loss2, details2 = self._loss2(*args, **kwargs)
loss = loss + loss2
details |= details2
return loss, details
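# Illustrative composition using classes defined later in this file (weights are placeholders):
#   >>> total_loss = ConfLoss(Regr3D(L2Loss()), alpha=0.2) + 0.1 * NonAmbiguousMaskLoss(BCELoss())
#   >>> # total_loss(batch, preds) returns (weighted scalar loss, dict of per-term details)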
class NonAmbiguousMaskLoss(Criterion, MultiLoss):
"""
Loss on non-ambiguous mask prediction logits.
"""
def __init__(self, criterion):
super().__init__(criterion)
def compute_loss(self, batch, preds, **kw):
"""
Args:
batch: list of dicts with the gt data
preds: list of dicts with the predictions
Returns:
            loss: Sum of the per-view losses and a dict with the loss details
"""
# Init loss list to keep track of individual losses for each view
loss_list = []
mask_loss_details = {}
mask_loss_total = 0
self_name = type(self).__name__
# Loop over the views
for view_idx, (gt, pred) in enumerate(zip(batch, preds)):
# Get the GT non-ambiguous masks
gt_non_ambiguous_mask = gt["non_ambiguous_mask"]
# Get the predicted non-ambiguous mask logits
pred_non_ambiguous_mask_logits = pred["non_ambiguous_mask_logits"]
# Compute the loss for the current view
loss = self.criterion(pred_non_ambiguous_mask_logits, gt_non_ambiguous_mask)
# Add the loss to the list
loss_list.append((loss, None, "non_ambiguous_mask"))
# Add the loss details to the dictionary
mask_loss_details[f"{self_name}_mask_view{view_idx + 1}"] = float(loss)
mask_loss_total += float(loss)
# Compute the average loss across all views
mask_loss_details[f"{self_name}_mask_avg"] = mask_loss_total / len(batch)
return Sum(*loss_list), (mask_loss_details | {})
class ConfLoss(MultiLoss):
"""
Applies confidence-weighted regression loss using model-predicted confidence values.
The confidence-weighted loss has the form:
conf_loss = raw_loss * conf - alpha * log(conf)
Where:
- raw_loss is the original per-pixel loss
- conf is the predicted confidence (higher values = higher confidence)
- alpha is a hyperparameter controlling the regularization strength
This loss can be selectively applied to specific loss components in factored and multi-view settings.
"""
def __init__(self, pixel_loss, alpha=1, loss_set_indices=None):
"""
Args:
pixel_loss (MultiLoss): The pixel-level regression loss to be used.
alpha (float): Hyperparameter controlling the confidence regularization strength.
loss_set_indices (list or None): Indices of the loss sets to apply confidence weighting to.
Each index selects a specific loss set across all views (with the same rep_type).
If None, defaults to [0] which applies to the first loss set only.
"""
super().__init__()
assert alpha > 0
self.alpha = alpha
self.pixel_loss = pixel_loss.with_reduction("none")
self.loss_set_indices = [0] if loss_set_indices is None else loss_set_indices
def get_name(self):
return f"ConfLoss({self.pixel_loss})"
def get_conf_log(self, x):
return x, torch.log(x)
def compute_loss(self, batch, preds, **kw):
# Init loss list and details
total_loss = 0
conf_loss_details = {}
running_avg_dict = {}
self_name = type(self.pixel_loss).__name__
n_views = len(batch)
# Compute per-pixel loss for each view
losses, pixel_loss_details = self.pixel_loss(batch, preds, **kw)
# Select specific loss sets based on indices
selected_losses = []
processed_indices = set()
for idx in self.loss_set_indices:
start_idx = idx * n_views
end_idx = min((idx + 1) * n_views, len(losses))
selected_losses.extend(losses[start_idx:end_idx])
processed_indices.update(range(start_idx, end_idx))
# Process selected losses with confidence weighting
for loss_idx, (loss, msk, rep_type) in enumerate(selected_losses):
view_idx = loss_idx % n_views # Map to corresponding view index
if loss.numel() == 0:
# print(f"NO VALID VALUES in loss idx {loss_idx} (Rep Type: {rep_type}, Num Views: {n_views})", force=True)
continue
# Get the confidence and log confidence
if (
hasattr(self.pixel_loss, "flatten_across_image_only")
and self.pixel_loss.flatten_across_image_only
):
# Reshape confidence to match the flattened dimensions
conf_reshaped = preds[view_idx]["conf"].view(
preds[view_idx]["conf"].shape[0], -1
)
conf, log_conf = self.get_conf_log(conf_reshaped[msk])
loss = loss[msk]
else:
conf, log_conf = self.get_conf_log(preds[view_idx]["conf"][msk])
# Weight the loss by the confidence
conf_loss = loss * conf - self.alpha * log_conf
# Only add to total loss and store details if there are valid elements
if conf_loss.numel() > 0:
conf_loss = conf_loss.mean()
total_loss = total_loss + conf_loss
# Store details
conf_loss_details[
f"{self_name}_{rep_type}_conf_loss_view{view_idx + 1}"
] = float(conf_loss)
# Initialize or update running average directly
avg_key = f"{self_name}_{rep_type}_conf_loss_avg"
if avg_key not in conf_loss_details:
conf_loss_details[avg_key] = float(conf_loss)
running_avg_dict[
f"{self_name}_{rep_type}_conf_loss_valid_views"
] = 1
else:
valid_views = (
running_avg_dict[
f"{self_name}_{rep_type}_conf_loss_valid_views"
]
+ 1
)
running_avg_dict[
f"{self_name}_{rep_type}_conf_loss_valid_views"
] = valid_views
conf_loss_details[avg_key] += (
float(conf_loss) - conf_loss_details[avg_key]
) / valid_views
# Add unmodified losses for sets not in selected_losses
for idx, (loss, msk, rep_type) in enumerate(losses):
if idx not in processed_indices:
if msk is not None:
loss_after_masking = loss[msk]
else:
loss_after_masking = loss
if loss_after_masking.numel() > 0:
loss_mean = loss_after_masking.mean()
else:
# print(f"NO VALID VALUES in loss idx {idx} (Rep Type: {rep_type}, Num Views: {n_views})", force=True)
loss_mean = 0
total_loss = total_loss + loss_mean
return total_loss, dict(**conf_loss_details, **pixel_loss_details)
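# Illustrative usage (hyperparameters and indices are placeholders): with a factored pixel loss
# that returns several loss sets, loss_set_indices selects which sets get confidence weighting;
# the remaining sets are simply averaged.
#   >>> pixel_loss = FactoredGeometryRegr3D(FactoredLLoss())
#   >>> loss_fn = ConfLoss(pixel_loss, alpha=0.2, loss_set_indices=[0, 1, 2])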
class ExcludeTopNPercentPixelLoss(MultiLoss):
"""
Pixel-level regression loss where for each instance in a batch the top N% of per-pixel loss values are ignored
for the mean loss computation.
Allows selecting which pixel-level regression loss sets to apply the exclusion to.
"""
def __init__(
self,
pixel_loss,
top_n_percent=5,
apply_to_real_data_only=True,
loss_set_indices=None,
):
"""
Args:
pixel_loss (MultiLoss): The pixel-level regression loss to be used.
top_n_percent (float): The percentage of top per-pixel loss values to ignore. Range: [0, 100]. Default: 5.
apply_to_real_data_only (bool): Whether to apply the loss only to real world data. Default: True.
loss_set_indices (list or None): Indices of the loss sets to apply the exclusion to.
Each index selects a specific loss set across all views (with the same rep_type).
If None, defaults to [0] which applies to the first loss set only.
"""
super().__init__()
self.pixel_loss = pixel_loss.with_reduction("none")
self.top_n_percent = top_n_percent
self.bottom_n_percent = 100 - top_n_percent
self.apply_to_real_data_only = apply_to_real_data_only
self.loss_set_indices = [0] if loss_set_indices is None else loss_set_indices
def get_name(self):
return f"ExcludeTopNPercentPixelLoss({self.pixel_loss})"
def keep_bottom_n_percent(self, tensor, mask, bottom_n_percent):
"""
        Keep only the bottom n percent of the valid per-pixel loss values for each batch instance.
Args:
tensor (torch.Tensor): The tensor containing the per-pixel loss values.
Shape: (B, N) where B is the batch size and N is the number of total pixels.
            mask (torch.Tensor): The mask indicating valid pixels. Shape: (B, N).
            bottom_n_percent (float): Percentage of the valid per-pixel loss values to keep. Range: [0, 100].
Returns:
torch.Tensor: Flattened tensor containing the bottom n percent of per-pixel loss values.
"""
B, N = tensor.shape
# Calculate the number of valid elements (where mask is True)
num_valid = mask.sum(dim=1)
# Calculate the number of elements to keep (n% of valid elements)
num_keep = (num_valid * bottom_n_percent / 100).long()
# Create a mask for the bottom n% elements
keep_mask = torch.arange(N, device=tensor.device).unsqueeze(
0
) < num_keep.unsqueeze(1)
# Create a tensor with inf where mask is False
masked_tensor = torch.where(
mask, tensor, torch.tensor(float("inf"), device=tensor.device)
)
# Sort the masked tensor along the N dimension
sorted_tensor, _ = torch.sort(masked_tensor, dim=1, descending=False)
# Get the bottom n% elements
bottom_n_percent_elements = sorted_tensor[keep_mask]
return bottom_n_percent_elements
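    # Illustrative behaviour (placeholder values): keep the smallest 50% of the valid
    # entries in each row, dropping the largest per-pixel losses.
    #   >>> t = torch.tensor([[4.0, 1.0, 3.0, 2.0]])
    #   >>> m = torch.ones(1, 4, dtype=torch.bool)
    #   >>> self.keep_bottom_n_percent(t, m, bottom_n_percent=50)
    #   tensor([1., 2.])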
def compute_loss(self, batch, preds, **kw):
# Compute per-pixel loss
losses, details = self.pixel_loss(batch, preds, **kw)
n_views = len(batch)
# Select specific loss sets based on indices
selected_losses = []
processed_indices = set()
for idx in self.loss_set_indices:
start_idx = idx * n_views
end_idx = min((idx + 1) * n_views, len(losses))
selected_losses.extend(losses[start_idx:end_idx])
processed_indices.update(range(start_idx, end_idx))
# Initialize total loss
total_loss = 0.0
loss_details = {}
running_avg_dict = {}
self_name = type(self.pixel_loss).__name__
# Process selected losses with top N percent exclusion
for loss_idx, (loss, msk, rep_type) in enumerate(selected_losses):
view_idx = loss_idx % n_views # Map to corresponding view index
if loss.numel() == 0:
# print(f"NO VALID VALUES in loss idx {loss_idx} (Rep Type: {rep_type}, Num Views: {n_views})", force=True)
continue
# Create empty list for current view's aggregated tensors
aggregated_losses = []
if self.apply_to_real_data_only:
# Get the synthetic and real world data mask
synthetic_mask = batch[view_idx]["is_synthetic"]
real_data_mask = ~batch[view_idx]["is_synthetic"]
else:
# Apply the filtering to all data
synthetic_mask = torch.zeros_like(batch[view_idx]["is_synthetic"])
real_data_mask = torch.ones_like(batch[view_idx]["is_synthetic"])
# Process synthetic data
if synthetic_mask.any():
synthetic_loss = loss[synthetic_mask]
synthetic_msk = msk[synthetic_mask]
aggregated_losses.append(synthetic_loss[synthetic_msk])
# Process real data
if real_data_mask.any():
real_loss = loss[real_data_mask]
real_msk = msk[real_data_mask]
real_bottom_n_percent_loss = self.keep_bottom_n_percent(
real_loss, real_msk, self.bottom_n_percent
)
aggregated_losses.append(real_bottom_n_percent_loss)
# Compute view loss
view_loss = torch.cat(aggregated_losses, dim=0)
# Only add to total loss and store details if there are valid elements
if view_loss.numel() > 0:
view_loss = view_loss.mean()
total_loss = total_loss + view_loss
# Store details
loss_details[
f"{self_name}_{rep_type}_bot{self.bottom_n_percent}%_loss_view{view_idx + 1}"
] = float(view_loss)
# Initialize or update running average directly
avg_key = f"{self_name}_{rep_type}_bot{self.bottom_n_percent}%_loss_avg"
if avg_key not in loss_details:
loss_details[avg_key] = float(view_loss)
running_avg_dict[
f"{self_name}_{rep_type}_bot{self.bottom_n_percent}%_valid_views"
] = 1
else:
valid_views = (
running_avg_dict[
f"{self_name}_{rep_type}_bot{self.bottom_n_percent}%_valid_views"
]
+ 1
)
running_avg_dict[
f"{self_name}_{rep_type}_bot{self.bottom_n_percent}%_valid_views"
] = valid_views
loss_details[avg_key] += (
float(view_loss) - loss_details[avg_key]
) / valid_views
# Add unmodified losses for sets not in selected_losses
for idx, (loss, msk, rep_type) in enumerate(losses):
if idx not in processed_indices:
if msk is not None:
loss_after_masking = loss[msk]
else:
loss_after_masking = loss
if loss_after_masking.numel() > 0:
loss_mean = loss_after_masking.mean()
else:
# print(f"NO VALID VALUES in loss idx {idx} (Rep Type: {rep_type}, Num Views: {n_views})", force=True)
loss_mean = 0
total_loss = total_loss + loss_mean
return total_loss, dict(**loss_details, **details)
class ConfAndExcludeTopNPercentPixelLoss(MultiLoss):
"""
Combined loss that applies ConfLoss to one set of pixel-level regression losses
and ExcludeTopNPercentPixelLoss to another set of pixel-level regression losses.
"""
def __init__(
self,
pixel_loss,
conf_alpha=1,
top_n_percent=5,
apply_to_real_data_only=True,
conf_loss_set_indices=None,
exclude_loss_set_indices=None,
):
"""
Args:
pixel_loss (MultiLoss): The pixel-level regression loss to be used.
conf_alpha (float): Alpha parameter for ConfLoss. Default: 1.
top_n_percent (float): Percentage of top per-pixel loss values to ignore. Range: [0, 100]. Default: 5.
apply_to_real_data_only (bool): Whether to apply the exclude loss only to real world data. Default: True.
conf_loss_set_indices (list or None): Indices of the loss sets to apply confidence weighting to.
Each index selects a specific loss set across all views (with the same rep_type).
If None, defaults to [0] which applies to the first loss set only.
exclude_loss_set_indices (list or None): Indices of the loss sets to apply top N percent exclusion to.
Each index selects a specific loss set across all views (with the same rep_type).
If None, defaults to [1] which applies to the second loss set only.
"""
super().__init__()
self.pixel_loss = pixel_loss.with_reduction("none")
assert conf_alpha > 0
self.conf_alpha = conf_alpha
self.top_n_percent = top_n_percent
self.bottom_n_percent = 100 - top_n_percent
self.apply_to_real_data_only = apply_to_real_data_only
self.conf_loss_set_indices = (
[0] if conf_loss_set_indices is None else conf_loss_set_indices
)
self.exclude_loss_set_indices = (
[1] if exclude_loss_set_indices is None else exclude_loss_set_indices
)
def get_name(self):
return f"ConfAndExcludeTopNPercentPixelLoss({self.pixel_loss})"
def get_conf_log(self, x):
return x, torch.log(x)
def keep_bottom_n_percent(self, tensor, mask, bottom_n_percent):
"""
        Keep only the bottom n percent of the valid per-pixel loss values for each batch instance.
"""
B, N = tensor.shape
# Calculate the number of valid elements (where mask is True)
num_valid = mask.sum(dim=1)
# Calculate the number of elements to keep (n% of valid elements)
num_keep = (num_valid * bottom_n_percent / 100).long()
# Create a mask for the bottom n% elements
keep_mask = torch.arange(N, device=tensor.device).unsqueeze(
0
) < num_keep.unsqueeze(1)
# Create a tensor with inf where mask is False
masked_tensor = torch.where(
mask, tensor, torch.tensor(float("inf"), device=tensor.device)
)
# Sort the masked tensor along the N dimension
sorted_tensor, _ = torch.sort(masked_tensor, dim=1, descending=False)
# Get the bottom n% elements
bottom_n_percent_elements = sorted_tensor[keep_mask]
return bottom_n_percent_elements
def compute_loss(self, batch, preds, **kw):
# Compute per-pixel loss
losses, pixel_loss_details = self.pixel_loss(batch, preds, **kw)
n_views = len(batch)
# Select specific loss sets for confidence weighting
conf_selected_losses = []
conf_processed_indices = set()
for idx in self.conf_loss_set_indices:
start_idx = idx * n_views
end_idx = min((idx + 1) * n_views, len(losses))
conf_selected_losses.extend(losses[start_idx:end_idx])
conf_processed_indices.update(range(start_idx, end_idx))
# Select specific loss sets for top N percent exclusion
exclude_selected_losses = []
exclude_processed_indices = set()
for idx in self.exclude_loss_set_indices:
start_idx = idx * n_views
end_idx = min((idx + 1) * n_views, len(losses))
exclude_selected_losses.extend(losses[start_idx:end_idx])
exclude_processed_indices.update(range(start_idx, end_idx))
# Initialize total loss and details
total_loss = 0
loss_details = {}
running_avg_dict = {}
self_name = type(self.pixel_loss).__name__
# Process selected losses with confidence weighting
for loss_idx, (loss, msk, rep_type) in enumerate(conf_selected_losses):
view_idx = loss_idx % n_views # Map to corresponding view index
if loss.numel() == 0:
# print(f"NO VALID VALUES in loss idx {loss_idx} (Rep Type: {rep_type}, Num Views: {n_views}) for conf loss", force=True)
continue
# Get the confidence and log confidence
if (
hasattr(self.pixel_loss, "flatten_across_image_only")
and self.pixel_loss.flatten_across_image_only
):
# Reshape confidence to match the flattened dimensions
conf_reshaped = preds[view_idx]["conf"].view(
preds[view_idx]["conf"].shape[0], -1
)
conf, log_conf = self.get_conf_log(conf_reshaped[msk])
loss = loss[msk]
else:
conf, log_conf = self.get_conf_log(preds[view_idx]["conf"][msk])
# Weight the loss by the confidence
conf_loss = loss * conf - self.conf_alpha * log_conf
# Only add to total loss and store details if there are valid elements
if conf_loss.numel() > 0:
conf_loss = conf_loss.mean()
total_loss = total_loss + conf_loss
# Store details
loss_details[f"{self_name}_{rep_type}_conf_loss_view{view_idx + 1}"] = (
float(conf_loss)
)
# Initialize or update running average directly
avg_key = f"{self_name}_{rep_type}_conf_loss_avg"
if avg_key not in loss_details:
loss_details[avg_key] = float(conf_loss)
running_avg_dict[
f"{self_name}_{rep_type}_conf_loss_valid_views"
] = 1
else:
valid_views = (
running_avg_dict[
f"{self_name}_{rep_type}_conf_loss_valid_views"
]
+ 1
)
running_avg_dict[
f"{self_name}_{rep_type}_conf_loss_valid_views"
] = valid_views
loss_details[avg_key] += (
float(conf_loss) - loss_details[avg_key]
) / valid_views
# Process selected losses with top N percent exclusion
for loss_idx, (loss, msk, rep_type) in enumerate(exclude_selected_losses):
view_idx = loss_idx % n_views # Map to corresponding view index
if loss.numel() == 0:
# print(f"NO VALID VALUES in loss idx {loss_idx} (Rep Type: {rep_type}, Num Views: {n_views}) for exclude loss", force=True)
continue
# Create empty list for current view's aggregated tensors
aggregated_losses = []
if self.apply_to_real_data_only:
# Get the synthetic and real world data mask
synthetic_mask = batch[view_idx]["is_synthetic"]
real_data_mask = ~batch[view_idx]["is_synthetic"]
else:
# Apply the filtering to all data
synthetic_mask = torch.zeros_like(batch[view_idx]["is_synthetic"])
real_data_mask = torch.ones_like(batch[view_idx]["is_synthetic"])
# Process synthetic data
if synthetic_mask.any():
synthetic_loss = loss[synthetic_mask]
synthetic_msk = msk[synthetic_mask]
aggregated_losses.append(synthetic_loss[synthetic_msk])
# Process real data
if real_data_mask.any():
real_loss = loss[real_data_mask]
real_msk = msk[real_data_mask]
real_bottom_n_percent_loss = self.keep_bottom_n_percent(
real_loss, real_msk, self.bottom_n_percent
)
aggregated_losses.append(real_bottom_n_percent_loss)
# Compute view loss
view_loss = torch.cat(aggregated_losses, dim=0)
# Only add to total loss and store details if there are valid elements
if view_loss.numel() > 0:
view_loss = view_loss.mean()
total_loss = total_loss + view_loss
# Store details
loss_details[
f"{self_name}_{rep_type}_bot{self.bottom_n_percent}%_loss_view{view_idx + 1}"
] = float(view_loss)
# Initialize or update running average directly
avg_key = f"{self_name}_{rep_type}_bot{self.bottom_n_percent}%_loss_avg"
if avg_key not in loss_details:
loss_details[avg_key] = float(view_loss)
running_avg_dict[
f"{self_name}_{rep_type}_bot{self.bottom_n_percent}%_valid_views"
] = 1
else:
valid_views = (
running_avg_dict[
f"{self_name}_{rep_type}_bot{self.bottom_n_percent}%_valid_views"
]
+ 1
)
running_avg_dict[
f"{self_name}_{rep_type}_bot{self.bottom_n_percent}%_valid_views"
] = valid_views
loss_details[avg_key] += (
float(view_loss) - loss_details[avg_key]
) / valid_views
# Add unmodified losses for sets not processed with either confidence or exclusion
all_processed_indices = conf_processed_indices.union(exclude_processed_indices)
for idx, (loss, msk, rep_type) in enumerate(losses):
if idx not in all_processed_indices:
if msk is not None:
loss_after_masking = loss[msk]
else:
loss_after_masking = loss
if loss_after_masking.numel() > 0:
loss_mean = loss_after_masking.mean()
else:
# print(f"NO VALID VALUES in loss idx {idx} (Rep Type: {rep_type}, Num Views: {n_views})", force=True)
loss_mean = 0
total_loss = total_loss + loss_mean
return total_loss, dict(**loss_details, **pixel_loss_details)
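# Illustrative setup (hyperparameters and indices are placeholders): confidence-weight the
# first loss set, apply bottom-N% filtering to the second, and average the rest unmodified.
#   >>> loss_fn = ConfAndExcludeTopNPercentPixelLoss(
#   ...     FactoredGeometryRegr3D(FactoredLLoss()),
#   ...     conf_alpha=0.2,
#   ...     top_n_percent=5,
#   ...     conf_loss_set_indices=[0],
#   ...     exclude_loss_set_indices=[1],
#   ... )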
class Regr3D(Criterion, MultiLoss):
"""
Regression Loss for World Frame Pointmaps.
Asymmetric loss where view 1 is supposed to be the anchor.
    For each view i:
        P_i = RT_i @ D_i
        loss_i = pred_P_i - (RT_1^-1 @ P_i)
    where RT_1 is the anchor (view 1) camera pose, D_i is the camera-frame pointmap of
    view i, and pred_P_i is the pointmap predicted in the anchor frame.
"""
def __init__(
self,
criterion,
norm_mode="?avg_dis",
gt_scale=False,
ambiguous_loss_value=0,
max_metric_scale=False,
loss_in_log=True,
flatten_across_image_only=False,
):
"""
Initialize the loss criterion for World Frame Pointmaps.
Args:
criterion (BaseCriterion): The base criterion to use for computing the loss.
norm_mode (str): Normalization mode for scene representation. Default: "?avg_dis".
If prefixed with "?", normalization is only applied to non-metric scale data.
gt_scale (bool): If True, enforce predictions to have the same scale as ground truth.
If False, both GT and predictions are normalized independently. Default: False.
ambiguous_loss_value (float): Value to use for ambiguous pixels in the loss.
If 0, ambiguous pixels are ignored. Default: 0.
max_metric_scale (float): Maximum scale for metric scale data. If data exceeds this
value, it will be treated as non-metric. Default: False (no limit).
loss_in_log (bool): If True, apply logarithmic transformation to input before
computing the loss for pointmaps. Default: True.
flatten_across_image_only (bool): If True, flatten H x W dimensions only when computing
the loss. If False, flatten across batch and spatial dimensions. Default: False.
"""
super().__init__(criterion)
if norm_mode.startswith("?"):
            # Do not norm pts from metric scale datasets
self.norm_all = False
self.norm_mode = norm_mode[1:]
else:
self.norm_all = True
self.norm_mode = norm_mode
self.gt_scale = gt_scale
self.ambiguous_loss_value = ambiguous_loss_value
self.max_metric_scale = max_metric_scale
self.loss_in_log = loss_in_log
self.flatten_across_image_only = flatten_across_image_only
def get_all_info(self, batch, preds, dist_clip=None):
n_views = len(batch)
in_camera0 = closed_form_pose_inverse(batch[0]["camera_pose"])
# Initialize lists to store points and masks
no_norm_gt_pts = []
valid_masks = []
# Process ground truth points and valid masks
for view_idx in range(n_views):
no_norm_gt_pts.append(
geotrf(in_camera0, batch[view_idx]["pts3d"])
) # B,H,W,3
valid_masks.append(batch[view_idx]["valid_mask"].clone())
if dist_clip is not None:
# Points that are too far-away == invalid
for view_idx in range(n_views):
dis = no_norm_gt_pts[view_idx].norm(dim=-1) # (B, H, W)
valid_masks[view_idx] = valid_masks[view_idx] & (dis <= dist_clip)
# Get predicted points
no_norm_pr_pts = []
for view_idx in range(n_views):
no_norm_pr_pts.append(preds[view_idx]["pts3d"])
if not self.norm_all:
if self.max_metric_scale:
B = valid_masks[0].shape[0]
# Calculate distances to camera for all views
dists_to_cam1 = []
for view_idx in range(n_views):
dist = torch.where(
valid_masks[view_idx],
torch.norm(no_norm_gt_pts[view_idx], dim=-1),
0,
).view(B, -1)
dists_to_cam1.append(dist)
# Update metric scale flags
metric_scale_mask = batch[0]["is_metric_scale"]
for dist in dists_to_cam1:
metric_scale_mask = metric_scale_mask & (
dist.max(dim=-1).values < self.max_metric_scale
)
for view_idx in range(n_views):
batch[view_idx]["is_metric_scale"] = metric_scale_mask
non_metric_scale_mask = ~batch[0]["is_metric_scale"]
else:
non_metric_scale_mask = torch.ones_like(batch[0]["is_metric_scale"])
# Initialize normalized points
gt_pts = [torch.zeros_like(pts) for pts in no_norm_gt_pts]
pr_pts = [torch.zeros_like(pts) for pts in no_norm_pr_pts]
# Normalize 3d points
if self.norm_mode and non_metric_scale_mask.any():
normalized_pr_pts = normalize_multiple_pointclouds(
[pts[non_metric_scale_mask] for pts in no_norm_pr_pts],
[mask[non_metric_scale_mask] for mask in valid_masks],
self.norm_mode,
)
for i in range(n_views):
pr_pts[i][non_metric_scale_mask] = normalized_pr_pts[i]
elif non_metric_scale_mask.any():
for i in range(n_views):
pr_pts[i][non_metric_scale_mask] = no_norm_pr_pts[i][
non_metric_scale_mask
]
if self.norm_mode and not self.gt_scale:
gt_normalization_output = normalize_multiple_pointclouds(
no_norm_gt_pts, valid_masks, self.norm_mode, ret_factor=True
)
normalized_gt_pts = gt_normalization_output[:-1]
norm_factor = gt_normalization_output[-1]
for i in range(n_views):
gt_pts[i] = normalized_gt_pts[i]
pr_pts[i][~non_metric_scale_mask] = (
no_norm_pr_pts[i][~non_metric_scale_mask]
/ norm_factor[~non_metric_scale_mask]
)
elif ~non_metric_scale_mask.any():
for i in range(n_views):
gt_pts[i] = no_norm_gt_pts[i]
pr_pts[i][~non_metric_scale_mask] = no_norm_pr_pts[i][
~non_metric_scale_mask
]
else:
for i in range(n_views):
gt_pts[i] = no_norm_gt_pts[i]
# Get ambiguous masks
ambiguous_masks = []
for view_idx in range(n_views):
ambiguous_masks.append(
(~batch[view_idx]["non_ambiguous_mask"]) & (~valid_masks[view_idx])
)
return gt_pts, pr_pts, valid_masks, ambiguous_masks, {}
def compute_loss(self, batch, preds, **kw):
gt_pts, pred_pts, masks, ambiguous_masks, monitoring = self.get_all_info(
batch, preds, **kw
)
n_views = len(batch)
if self.ambiguous_loss_value > 0:
assert self.criterion.reduction == "none", (
"ambiguous_loss_value should be 0 if no conf loss"
)
# Add the ambiguous pixels as "valid" pixels
masks = [mask | amb_mask for mask, amb_mask in zip(masks, ambiguous_masks)]
losses = []
details = {}
running_avg_dict = {}
self_name = type(self).__name__
if not self.flatten_across_image_only:
for view_idx in range(n_views):
pred = pred_pts[view_idx][masks[view_idx]]
gt = gt_pts[view_idx][masks[view_idx]]
if self.loss_in_log:
pred = apply_log_to_norm(pred)
gt = apply_log_to_norm(gt)
loss = self.criterion(pred, gt)
if self.ambiguous_loss_value > 0:
loss = torch.where(
ambiguous_masks[view_idx][masks[view_idx]],
self.ambiguous_loss_value,
loss,
)
losses.append((loss, masks[view_idx], "pts3d"))
if loss.numel() > 0:
loss_mean = float(loss.mean())
details[f"{self_name}_pts3d_view{view_idx + 1}"] = loss_mean
# Initialize or update running average directly
avg_key = f"{self_name}_pts3d_avg"
if avg_key not in details:
details[avg_key] = loss_mean
running_avg_dict[f"{self_name}_pts3d_valid_views"] = 1
else:
valid_views = (
running_avg_dict[f"{self_name}_pts3d_valid_views"] + 1
)
running_avg_dict[f"{self_name}_pts3d_valid_views"] = valid_views
details[avg_key] += (loss_mean - details[avg_key]) / valid_views
else:
batch_size, _, _, dim = gt_pts[0].shape
for view_idx in range(n_views):
gt = gt_pts[view_idx].view(batch_size, -1, dim)
pred = pred_pts[view_idx].view(batch_size, -1, dim)
view_mask = masks[view_idx].view(batch_size, -1)
amb_mask = ambiguous_masks[view_idx].view(batch_size, -1)
if self.loss_in_log:
pred = apply_log_to_norm(pred)
gt = apply_log_to_norm(gt)
loss = self.criterion(pred, gt)
if self.ambiguous_loss_value > 0:
loss = torch.where(amb_mask, self.ambiguous_loss_value, loss)
losses.append((loss, view_mask, "pts3d"))
loss_after_masking = loss[view_mask]
if loss_after_masking.numel() > 0:
loss_mean = float(loss_after_masking.mean())
details[f"{self_name}_pts3d_view{view_idx + 1}"] = loss_mean
# Initialize or update running average directly
avg_key = f"{self_name}_pts3d_avg"
if avg_key not in details:
details[avg_key] = loss_mean
running_avg_dict[f"{self_name}_pts3d_valid_views"] = 1
else:
valid_views = (
running_avg_dict[f"{self_name}_pts3d_valid_views"] + 1
)
running_avg_dict[f"{self_name}_pts3d_valid_views"] = valid_views
details[avg_key] += (loss_mean - details[avg_key]) / valid_views
return Sum(*losses), (details | monitoring)
class PointsPlusScaleRegr3D(Criterion, MultiLoss):
"""
Regression Loss for World Frame Pointmaps & Scale.
"""
def __init__(
self,
criterion,
norm_predictions=True,
norm_mode="avg_dis",
ambiguous_loss_value=0,
loss_in_log=True,
flatten_across_image_only=False,
world_frame_points_loss_weight=1,
scale_loss_weight=1,
):
"""
Initialize the loss criterion for World Frame Pointmaps & Scale.
        The predicted scene representation is always normalized w.r.t. the frame of view0.
Loss is applied between the predicted metric scale and the ground truth metric scale.
Args:
criterion (BaseCriterion): The base criterion to use for computing the loss.
norm_predictions (bool): If True, normalize the predictions before computing the loss.
norm_mode (str): Normalization mode for the gt and predicted (optional) scene representation. Default: "avg_dis".
ambiguous_loss_value (float): Value to use for ambiguous pixels in the loss.
If 0, ambiguous pixels are ignored. Default: 0.
loss_in_log (bool): If True, apply logarithmic transformation to input before
computing the loss for depth, pointmaps and scale. Default: True.
flatten_across_image_only (bool): If True, flatten H x W dimensions only when computing
the loss. If False, flatten across batch and spatial dimensions. Default: False.
world_frame_points_loss_weight (float): Weight to use for the world frame pointmap loss. Default: 1.
scale_loss_weight (float): Weight to use for the scale loss. Default: 1.
"""
super().__init__(criterion)
self.norm_predictions = norm_predictions
self.norm_mode = norm_mode
self.ambiguous_loss_value = ambiguous_loss_value
self.loss_in_log = loss_in_log
self.flatten_across_image_only = flatten_across_image_only
self.world_frame_points_loss_weight = world_frame_points_loss_weight
self.scale_loss_weight = scale_loss_weight
def get_all_info(self, batch, preds, dist_clip=None):
"""
Function to get all the information needed to compute the loss.
Returns all quantities normalized w.r.t. camera of view0.
"""
n_views = len(batch)
# Everything is normalized w.r.t. camera of view0
        # Initialize lists to store data for all views
# Ground truth quantities
in_camera0 = closed_form_pose_inverse(batch[0]["camera_pose"])
no_norm_gt_pts = []
valid_masks = []
# Predicted quantities
no_norm_pr_pts = []
metric_pr_pts_to_compute_scale = []
# Get ground truth & prediction info for all views
for i in range(n_views):
# Get the ground truth
no_norm_gt_pts.append(geotrf(in_camera0, batch[i]["pts3d"]))
valid_masks.append(batch[i]["valid_mask"].clone())
# Get predictions for normalized loss
if "metric_scaling_factor" in preds[i].keys():
                # Divide by the predicted metric scaling factor to get the raw predicted points
# This detaches the predicted metric scaling factor from the geometry based loss
curr_view_no_norm_pr_pts = preds[i]["pts3d"] / preds[i][
"metric_scaling_factor"
].unsqueeze(-1).unsqueeze(-1)
else:
curr_view_no_norm_pr_pts = preds[i]["pts3d"]
no_norm_pr_pts.append(curr_view_no_norm_pr_pts)
# Get the predicted metric scale points
if "metric_scaling_factor" in preds[i].keys():
# Detach the raw predicted points so that the scale loss is only applied to the scaling factor
curr_view_metric_pr_pts_to_compute_scale = (
curr_view_no_norm_pr_pts.detach()
* preds[i]["metric_scaling_factor"].unsqueeze(-1).unsqueeze(-1)
)
else:
curr_view_metric_pr_pts_to_compute_scale = (
curr_view_no_norm_pr_pts.clone()
)
metric_pr_pts_to_compute_scale.append(
curr_view_metric_pr_pts_to_compute_scale
)
if dist_clip is not None:
# Points that are too far-away == invalid
for i in range(n_views):
dis = no_norm_gt_pts[i].norm(dim=-1)
valid_masks[i] = valid_masks[i] & (dis <= dist_clip)
# Initialize normalized tensors
gt_pts = [torch.zeros_like(pts) for pts in no_norm_gt_pts]
pr_pts = [torch.zeros_like(pts) for pts in no_norm_pr_pts]
# Normalize the predicted points if specified
if self.norm_predictions:
pr_normalization_output = normalize_multiple_pointclouds(
no_norm_pr_pts,
valid_masks,
self.norm_mode,
ret_factor=True,
)
pr_pts_norm = pr_normalization_output[:-1]
# Normalize the ground truth points
gt_normalization_output = normalize_multiple_pointclouds(
no_norm_gt_pts, valid_masks, self.norm_mode, ret_factor=True
)
gt_pts_norm = gt_normalization_output[:-1]
gt_norm_factor = gt_normalization_output[-1]
for i in range(n_views):
if self.norm_predictions:
# Assign the normalized predictions
pr_pts[i] = pr_pts_norm[i]
else:
pr_pts[i] = no_norm_pr_pts[i]
# Assign the normalized ground truth quantities
gt_pts[i] = gt_pts_norm[i]
# Get the mask indicating ground truth metric scale quantities
metric_scale_mask = batch[0]["is_metric_scale"]
valid_gt_norm_factor_mask = (
gt_norm_factor[:, 0, 0, 0] > 1e-8
) # Mask out cases where depth for all views is invalid
valid_metric_scale_mask = metric_scale_mask & valid_gt_norm_factor_mask
if valid_metric_scale_mask.any():
# Compute the scale norm factor using the predicted metric scale points
metric_pr_normalization_output = normalize_multiple_pointclouds(
metric_pr_pts_to_compute_scale,
valid_masks,
self.norm_mode,
ret_factor=True,
)
pr_metric_norm_factor = metric_pr_normalization_output[-1]
# Get the valid ground truth and predicted scale norm factors for the metric ground truth quantities
gt_metric_norm_factor = gt_norm_factor[valid_metric_scale_mask]
pr_metric_norm_factor = pr_metric_norm_factor[valid_metric_scale_mask]
else:
gt_metric_norm_factor = None
pr_metric_norm_factor = None
# Get ambiguous masks
ambiguous_masks = []
for i in range(n_views):
ambiguous_masks.append(
(~batch[i]["non_ambiguous_mask"]) & (~valid_masks[i])
)
# Pack into info dicts
gt_info = []
pred_info = []
for i in range(n_views):
gt_info.append(
{
"pts3d": gt_pts[i],
}
)
pred_info.append(
{
"pts3d": pr_pts[i],
}
)
return (
gt_info,
pred_info,
valid_masks,
ambiguous_masks,
gt_metric_norm_factor,
pr_metric_norm_factor,
)
def compute_loss(self, batch, preds, **kw):
(
gt_info,
pred_info,
valid_masks,
ambiguous_masks,
gt_metric_norm_factor,
pr_metric_norm_factor,
) = self.get_all_info(batch, preds, **kw)
n_views = len(batch)
if self.ambiguous_loss_value > 0:
assert self.criterion.reduction == "none", (
"ambiguous_loss_value should be 0 if no conf loss"
)
            # Add the ambiguous pixels as "valid" pixels
valid_masks = [
mask | ambig_mask
for mask, ambig_mask in zip(valid_masks, ambiguous_masks)
]
pts3d_losses = []
for i in range(n_views):
# Get the predicted dense quantities
if not self.flatten_across_image_only:
# Flatten the points across the entire batch with the masks
pred_pts3d = pred_info[i]["pts3d"][valid_masks[i]]
gt_pts3d = gt_info[i]["pts3d"][valid_masks[i]]
else:
# Flatten the H x W dimensions to H*W
batch_size, _, _, pts_dim = gt_info[i]["pts3d"].shape
gt_pts3d = gt_info[i]["pts3d"].view(batch_size, -1, pts_dim)
pred_pts3d = pred_info[i]["pts3d"].view(batch_size, -1, pts_dim)
valid_masks[i] = valid_masks[i].view(batch_size, -1)
# Apply loss in log space if specified
if self.loss_in_log:
gt_pts3d = apply_log_to_norm(gt_pts3d)
pred_pts3d = apply_log_to_norm(pred_pts3d)
# Compute point loss
pts3d_loss = self.criterion(pred_pts3d, gt_pts3d, factor="points")
pts3d_loss = pts3d_loss * self.world_frame_points_loss_weight
pts3d_losses.append(pts3d_loss)
# Handle ambiguous pixels
if self.ambiguous_loss_value > 0:
if not self.flatten_across_image_only:
pts3d_losses[i] = torch.where(
ambiguous_masks[i][valid_masks[i]],
self.ambiguous_loss_value,
pts3d_losses[i],
)
else:
pts3d_losses[i] = torch.where(
ambiguous_masks[i].view(ambiguous_masks[i].shape[0], -1),
self.ambiguous_loss_value,
pts3d_losses[i],
)
# Compute the scale loss
if gt_metric_norm_factor is not None:
if self.loss_in_log:
gt_metric_norm_factor = apply_log_to_norm(gt_metric_norm_factor)
pr_metric_norm_factor = apply_log_to_norm(pr_metric_norm_factor)
scale_loss = (
self.criterion(
pr_metric_norm_factor, gt_metric_norm_factor, factor="scale"
)
* self.scale_loss_weight
)
else:
scale_loss = None
# Use helper function to generate loss terms and details
losses_dict = {
"pts3d": {
"values": pts3d_losses,
"use_mask": True,
"is_multi_view": True,
},
"scale": {
"values": scale_loss,
"use_mask": False,
"is_multi_view": False,
},
}
loss_terms, details = get_loss_terms_and_details(
losses_dict,
valid_masks,
type(self).__name__,
n_views,
self.flatten_across_image_only,
)
losses = Sum(*loss_terms)
return losses, (details | {})
class NormalGMLoss(MultiLoss):
"""
Normal & Gradient Matching Loss for Monocular Depth Training.
"""
def __init__(
self,
norm_predictions=True,
norm_mode="avg_dis",
apply_normal_and_gm_loss_to_synthetic_data_only=True,
):
"""
Initialize the loss criterion for Normal & Gradient Matching Loss (currently only valid for 1 view).
Computes:
            (1) Normal Loss over the PointMap (naturally in the local frame) in Euclidean coordinates,
            (2) Gradient Matching (GM) Loss over the Depth Z in log space (MiDaS applied the GM loss in disparity space).
Args:
norm_predictions (bool): If True, normalize the predictions before computing the loss.
norm_mode (str): Normalization mode for the gt and predicted (optional) scene representation. Default: "avg_dis".
apply_normal_and_gm_loss_to_synthetic_data_only (bool): If True, apply the normal and gm loss only to synthetic data.
If False, apply the normal and gm loss to all data. Default: True.
"""
super().__init__()
self.norm_predictions = norm_predictions
self.norm_mode = norm_mode
self.apply_normal_and_gm_loss_to_synthetic_data_only = (
apply_normal_and_gm_loss_to_synthetic_data_only
)
def get_all_info(self, batch, preds, dist_clip=None):
"""
Function to get all the information needed to compute the loss.
Returns all quantities normalized.
"""
n_views = len(batch)
assert n_views == 1, (
"Normal & Gradient Matching Loss Class only supports 1 view"
)
# Everything is normalized w.r.t. camera of view1
in_camera1 = closed_form_pose_inverse(batch[0]["camera_pose"])
# Initialize lists to store data for all views
no_norm_gt_pts = []
valid_masks = []
no_norm_pr_pts = []
# Get ground truth & prediction info for all views
for i in range(n_views):
# Get ground truth
no_norm_gt_pts.append(geotrf(in_camera1, batch[i]["pts3d"]))
valid_masks.append(batch[i]["valid_mask"].clone())
# Get predictions for normalized loss
if "metric_scaling_factor" in preds[i].keys():
# Divide by the predicted metric scaling factor to get the raw predicted points
# This detaches the predicted metric scaling factor from the geometry based loss
curr_view_no_norm_pr_pts = preds[i]["pts3d"] / preds[i][
"metric_scaling_factor"
].unsqueeze(-1).unsqueeze(-1)
else:
curr_view_no_norm_pr_pts = preds[i]["pts3d"]
no_norm_pr_pts.append(curr_view_no_norm_pr_pts)
if dist_clip is not None:
# Points that are too far-away == invalid
for i in range(n_views):
dis = no_norm_gt_pts[i].norm(dim=-1)
valid_masks[i] = valid_masks[i] & (dis <= dist_clip)
# Initialize normalized tensors
gt_pts = [torch.zeros_like(pts) for pts in no_norm_gt_pts]
pr_pts = [torch.zeros_like(pts) for pts in no_norm_pr_pts]
# Normalize the predicted points if specified
if self.norm_predictions:
pr_normalization_output = normalize_multiple_pointclouds(
no_norm_pr_pts,
valid_masks,
self.norm_mode,
ret_factor=True,
)
pr_pts_norm = pr_normalization_output[:-1]
# Normalize the ground truth points
gt_normalization_output = normalize_multiple_pointclouds(
no_norm_gt_pts, valid_masks, self.norm_mode, ret_factor=True
)
gt_pts_norm = gt_normalization_output[:-1]
for i in range(n_views):
if self.norm_predictions:
# Assign the normalized predictions
pr_pts[i] = pr_pts_norm[i]
else:
# Assign the raw predicted points
pr_pts[i] = no_norm_pr_pts[i]
# Assign the normalized ground truth
gt_pts[i] = gt_pts_norm[i]
return gt_pts, pr_pts, valid_masks
def compute_loss(self, batch, preds, **kw):
gt_pts, pred_pts, valid_masks = self.get_all_info(batch, preds, **kw)
n_views = len(batch)
assert n_views == 1, (
"Normal & Gradient Matching Loss Class only supports 1 view"
)
normal_losses = []
gradient_matching_losses = []
details = {}
running_avg_dict = {}
self_name = type(self).__name__
for i in range(n_views):
# Get the local frame points, log space depth_z & valid masks
pred_local_pts3d = pred_pts[i]
pred_depth_z = pred_local_pts3d[..., 2:]
pred_depth_z = apply_log_to_norm(pred_depth_z)
gt_local_pts3d = gt_pts[i]
gt_depth_z = gt_local_pts3d[..., 2:]
gt_depth_z = apply_log_to_norm(gt_depth_z)
valid_mask_for_normal_gm_loss = valid_masks[i].clone()
# Update the validity mask for normal & gm loss based on the synthetic data mask if required
if self.apply_normal_and_gm_loss_to_synthetic_data_only:
synthetic_mask = batch[i]["is_synthetic"] # (B, )
synthetic_mask = synthetic_mask.unsqueeze(-1).unsqueeze(-1) # (B, 1, 1)
synthetic_mask = synthetic_mask.expand(
-1, pred_depth_z.shape[1], pred_depth_z.shape[2]
) # (B, H, W)
valid_mask_for_normal_gm_loss = (
valid_mask_for_normal_gm_loss & synthetic_mask
)
# Compute the normal loss
normal_loss = compute_normal_loss(
pred_local_pts3d, gt_local_pts3d, valid_mask_for_normal_gm_loss.clone()
)
normal_losses.append(normal_loss)
# Compute the gradient matching loss
gradient_matching_loss = compute_gradient_matching_loss(
pred_depth_z, gt_depth_z, valid_mask_for_normal_gm_loss.clone()
)
gradient_matching_losses.append(gradient_matching_loss)
# Add loss details if only valid values are present
# Initialize or update running average directly
# Normal loss details
if float(normal_loss) > 0:
details[f"{self_name}_normal_view{i + 1}"] = float(normal_loss)
normal_avg_key = f"{self_name}_normal_avg"
if normal_avg_key not in details:
details[normal_avg_key] = float(normal_losses[i])
running_avg_dict[f"{self_name}_normal_valid_views"] = 1
else:
normal_valid_views = (
running_avg_dict[f"{self_name}_normal_valid_views"] + 1
)
running_avg_dict[f"{self_name}_normal_valid_views"] = (
normal_valid_views
)
details[normal_avg_key] += (
float(normal_losses[i]) - details[normal_avg_key]
) / normal_valid_views
# Gradient Matching loss details
if float(gradient_matching_loss) > 0:
details[f"{self_name}_gradient_matching_view{i + 1}"] = float(
gradient_matching_loss
)
# For gradient matching loss
gm_avg_key = f"{self_name}_gradient_matching_avg"
if gm_avg_key not in details:
details[gm_avg_key] = float(gradient_matching_losses[i])
running_avg_dict[f"{self_name}_gm_valid_views"] = 1
else:
gm_valid_views = running_avg_dict[f"{self_name}_gm_valid_views"] + 1
running_avg_dict[f"{self_name}_gm_valid_views"] = gm_valid_views
details[gm_avg_key] += (
float(gradient_matching_losses[i]) - details[gm_avg_key]
) / gm_valid_views
# Put the losses together
loss_terms = []
for i in range(n_views):
loss_terms.append((normal_losses[i], None, "normal"))
loss_terms.append((gradient_matching_losses[i], None, "gradient_matching"))
losses = Sum(*loss_terms)
return losses, details
class FactoredGeometryRegr3D(Criterion, MultiLoss):
"""
Regression Loss for Factored Geometry.
"""
def __init__(
self,
criterion,
norm_mode="?avg_dis",
gt_scale=False,
ambiguous_loss_value=0,
max_metric_scale=False,
loss_in_log=True,
flatten_across_image_only=False,
depth_type_for_loss="depth_along_ray",
cam_frame_points_loss_weight=1,
depth_loss_weight=1,
ray_directions_loss_weight=1,
pose_quats_loss_weight=1,
pose_trans_loss_weight=1,
compute_pairwise_relative_pose_loss=False,
convert_predictions_to_view0_frame=False,
compute_world_frame_points_loss=True,
world_frame_points_loss_weight=1,
):
"""
Initialize the loss criterion for Factored Geometry (Ray Directions, Depth, Pose),
and the Collective Geometry i.e. Local Frame Pointmaps & optionally World Frame Pointmaps.
If world-frame pointmap loss is computed, the pixel-level losses are computed in the following order:
(1) world points, (2) cam points, (3) depth, (4) ray directions, (5) pose quats, (6) pose trans.
Else, the pixel-level losses are returned in the following order:
(1) cam points, (2) depth, (3) ray directions, (4) pose quats, (5) pose trans.
Args:
criterion (BaseCriterion): The base criterion to use for computing the loss.
norm_mode (str): Normalization mode for scene representation. Default: "?avg_dis".
If prefixed with "?", normalization is only applied to non-metric scale data.
gt_scale (bool): If True, enforce predictions to have the same scale as ground truth.
If False, both GT and predictions are normalized independently. Default: False.
ambiguous_loss_value (float): Value to use for ambiguous pixels in the loss.
If 0, ambiguous pixels are ignored. Default: 0.
max_metric_scale (float): Maximum scale for metric scale data. If data exceeds this
value, it will be treated as non-metric. Default: False (no limit).
loss_in_log (bool): If True, apply logarithmic transformation to input before
computing the loss for depth and pointmaps. Default: True.
flatten_across_image_only (bool): If True, flatten H x W dimensions only when computing
the loss. If False, flatten across batch and spatial dimensions. Default: False.
depth_type_for_loss (str): Type of depth to use for loss computation. Default: "depth_along_ray".
Options: "depth_along_ray", "depth_z"
cam_frame_points_loss_weight (float): Weight to use for the camera frame pointmap loss. Default: 1.
depth_loss_weight (float): Weight to use for the depth loss. Default: 1.
ray_directions_loss_weight (float): Weight to use for the ray directions loss. Default: 1.
pose_quats_loss_weight (float): Weight to use for the pose quats loss. Default: 1.
pose_trans_loss_weight (float): Weight to use for the pose trans loss. Default: 1.
compute_pairwise_relative_pose_loss (bool): If True, the pose loss is computed on the
exhaustive pairwise relative poses. Default: False.
convert_predictions_to_view0_frame (bool): If True, convert predictions to view0 frame.
Use this if the predictions are not already in the view0 frame. Default: False.
compute_world_frame_points_loss (bool): If True, compute the world frame pointmap loss. Default: True.
world_frame_points_loss_weight (float): Weight to use for the world frame pointmap loss. Default: 1.
"""
super().__init__(criterion)
if norm_mode.startswith("?"):
# Do not normalize points from metric-scale datasets
self.norm_all = False
self.norm_mode = norm_mode[1:]
else:
self.norm_all = True
self.norm_mode = norm_mode
self.gt_scale = gt_scale
self.ambiguous_loss_value = ambiguous_loss_value
self.max_metric_scale = max_metric_scale
self.loss_in_log = loss_in_log
self.flatten_across_image_only = flatten_across_image_only
self.depth_type_for_loss = depth_type_for_loss
assert self.depth_type_for_loss in ["depth_along_ray", "depth_z"], (
"depth_type_for_loss must be one of ['depth_along_ray', 'depth_z']"
)
self.cam_frame_points_loss_weight = cam_frame_points_loss_weight
self.depth_loss_weight = depth_loss_weight
self.ray_directions_loss_weight = ray_directions_loss_weight
self.pose_quats_loss_weight = pose_quats_loss_weight
self.pose_trans_loss_weight = pose_trans_loss_weight
self.compute_pairwise_relative_pose_loss = compute_pairwise_relative_pose_loss
self.convert_predictions_to_view0_frame = convert_predictions_to_view0_frame
self.compute_world_frame_points_loss = compute_world_frame_points_loss
self.world_frame_points_loss_weight = world_frame_points_loss_weight
def get_all_info(self, batch, preds, dist_clip=None):
"""
Function to get all the information needed to compute the loss.
Returns all quantities normalized w.r.t. camera of view0.
"""
n_views = len(batch)
# Everything is normalized w.r.t. camera of view0
# Initialize lists to store data for all views
# Ground truth quantities
in_camera0 = closed_form_pose_inverse(batch[0]["camera_pose"])
no_norm_gt_pts = []
no_norm_gt_pts_cam = []
no_norm_gt_depth = []
no_norm_gt_pose_trans = []
valid_masks = []
gt_ray_directions = []
gt_pose_quats = []
# Predicted quantities
if self.convert_predictions_to_view0_frame:
# Get the camera transform to convert quantities to view0 frame
pred_camera0 = torch.eye(4, device=preds[0]["cam_quats"].device).unsqueeze(
0
)
batch_size = preds[0]["cam_quats"].shape[0]
pred_camera0 = pred_camera0.repeat(batch_size, 1, 1)
pred_camera0_rot = quaternion_to_rotation_matrix(
preds[0]["cam_quats"].clone()
)
pred_camera0[..., :3, :3] = pred_camera0_rot
pred_camera0[..., :3, 3] = preds[0]["cam_trans"].clone()
pred_in_camera0 = closed_form_pose_inverse(pred_camera0)
no_norm_pr_pts = []
no_norm_pr_pts_cam = []
no_norm_pr_depth = []
no_norm_pr_pose_trans = []
pr_ray_directions = []
pr_pose_quats = []
# Get ground truth & prediction info for all views
for i in range(n_views):
# Get ground truth
no_norm_gt_pts.append(geotrf(in_camera0, batch[i]["pts3d"]))
valid_masks.append(batch[i]["valid_mask"].clone())
no_norm_gt_pts_cam.append(batch[i]["pts3d_cam"])
gt_ray_directions.append(batch[i]["ray_directions_cam"])
if self.depth_type_for_loss == "depth_along_ray":
no_norm_gt_depth.append(batch[i]["depth_along_ray"])
elif self.depth_type_for_loss == "depth_z":
no_norm_gt_depth.append(batch[i]["pts3d_cam"][..., 2:])
if i == 0:
# For view0, initialize identity pose
gt_pose_quats.append(
torch.tensor(
[0, 0, 0, 1],
dtype=gt_ray_directions[0].dtype,
device=gt_ray_directions[0].device,
)
.unsqueeze(0)
.repeat(gt_ray_directions[0].shape[0], 1)
)
no_norm_gt_pose_trans.append(
torch.tensor(
[0, 0, 0],
dtype=gt_ray_directions[0].dtype,
device=gt_ray_directions[0].device,
)
.unsqueeze(0)
.repeat(gt_ray_directions[0].shape[0], 1)
)
else:
# For other views, transform pose to view0's frame
gt_pose_quats_world = batch[i]["camera_pose_quats"]
no_norm_gt_pose_trans_world = batch[i]["camera_pose_trans"]
gt_pose_quats_in_view0, no_norm_gt_pose_trans_in_view0 = (
transform_pose_using_quats_and_trans_2_to_1(
batch[0]["camera_pose_quats"],
batch[0]["camera_pose_trans"],
gt_pose_quats_world,
no_norm_gt_pose_trans_world,
)
)
gt_pose_quats.append(gt_pose_quats_in_view0)
no_norm_gt_pose_trans.append(no_norm_gt_pose_trans_in_view0)
# Get the local predictions
no_norm_pr_pts_cam.append(preds[i]["pts3d_cam"])
pr_ray_directions.append(preds[i]["ray_directions"])
if self.depth_type_for_loss == "depth_along_ray":
no_norm_pr_depth.append(preds[i]["depth_along_ray"])
elif self.depth_type_for_loss == "depth_z":
no_norm_pr_depth.append(preds[i]["pts3d_cam"][..., 2:])
# Get the predicted global predictions in view0's frame
if self.convert_predictions_to_view0_frame:
# Convert predictions to view0 frame
pr_pts3d_in_view0 = geotrf(pred_in_camera0, preds[i]["pts3d"])
pr_pose_quats_in_view0, pr_pose_trans_in_view0 = (
transform_pose_using_quats_and_trans_2_to_1(
preds[0]["cam_quats"],
preds[0]["cam_trans"],
preds[i]["cam_quats"],
preds[i]["cam_trans"],
)
)
no_norm_pr_pts.append(pr_pts3d_in_view0)
no_norm_pr_pose_trans.append(pr_pose_trans_in_view0)
pr_pose_quats.append(pr_pose_quats_in_view0)
else:
# Predictions are already in view0 frame
no_norm_pr_pts.append(preds[i]["pts3d"])
no_norm_pr_pose_trans.append(preds[i]["cam_trans"])
pr_pose_quats.append(preds[i]["cam_quats"])
if dist_clip is not None:
# Points that are too far-away == invalid
for i in range(n_views):
dis = no_norm_gt_pts[i].norm(dim=-1)
valid_masks[i] = valid_masks[i] & (dis <= dist_clip)
# Handle metric scale
if not self.norm_all:
if self.max_metric_scale:
B = valid_masks[0].shape[0]
dists_to_cam1 = []
for i in range(n_views):
dists_to_cam1.append(
torch.where(
valid_masks[i], torch.norm(no_norm_gt_pts[i], dim=-1), 0
).view(B, -1)
)
batch[0]["is_metric_scale"] = batch[0]["is_metric_scale"]
for dist in dists_to_cam1:
batch[0]["is_metric_scale"] &= (
dist.max(dim=-1).values < self.max_metric_scale
)
for i in range(1, n_views):
batch[i]["is_metric_scale"] = batch[0]["is_metric_scale"]
non_metric_scale_mask = ~batch[0]["is_metric_scale"]
else:
non_metric_scale_mask = torch.ones_like(batch[0]["is_metric_scale"])
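# non_metric_scale_mask selects the samples whose predictions are normalized with their own factor;
# metric-scale samples are instead rescaled by the GT normalization factor further below.
# When norm_all is True (norm_mode given without the "?" prefix), every sample is treated as non-metric.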
# Initialize normalized tensors
gt_pts = [torch.zeros_like(pts) for pts in no_norm_gt_pts]
gt_pts_cam = [torch.zeros_like(pts_cam) for pts_cam in no_norm_gt_pts_cam]
gt_depth = [torch.zeros_like(depth) for depth in no_norm_gt_depth]
gt_pose_trans = [torch.zeros_like(trans) for trans in no_norm_gt_pose_trans]
pr_pts = [torch.zeros_like(pts) for pts in no_norm_pr_pts]
pr_pts_cam = [torch.zeros_like(pts_cam) for pts_cam in no_norm_pr_pts_cam]
pr_depth = [torch.zeros_like(depth) for depth in no_norm_pr_depth]
pr_pose_trans = [torch.zeros_like(trans) for trans in no_norm_pr_pose_trans]
# Normalize points
if self.norm_mode and non_metric_scale_mask.any():
pr_normalization_output = normalize_multiple_pointclouds(
[pts[non_metric_scale_mask] for pts in no_norm_pr_pts],
[mask[non_metric_scale_mask] for mask in valid_masks],
self.norm_mode,
ret_factor=True,
)
pr_pts_norm = pr_normalization_output[:-1]
pr_norm_factor = pr_normalization_output[-1]
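# The same per-sample factor that normalizes the world-frame points is reused for the
# camera-frame points, depth and pose translations so all factored quantities stay consistent.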
for i in range(n_views):
pr_pts[i][non_metric_scale_mask] = pr_pts_norm[i]
pr_pts_cam[i][non_metric_scale_mask] = (
no_norm_pr_pts_cam[i][non_metric_scale_mask] / pr_norm_factor
)
pr_depth[i][non_metric_scale_mask] = (
no_norm_pr_depth[i][non_metric_scale_mask] / pr_norm_factor
)
pr_pose_trans[i][non_metric_scale_mask] = (
no_norm_pr_pose_trans[i][non_metric_scale_mask]
/ pr_norm_factor[:, :, 0, 0]
)
elif non_metric_scale_mask.any():
for i in range(n_views):
pr_pts[i][non_metric_scale_mask] = no_norm_pr_pts[i][
non_metric_scale_mask
]
pr_pts_cam[i][non_metric_scale_mask] = no_norm_pr_pts_cam[i][
non_metric_scale_mask
]
pr_depth[i][non_metric_scale_mask] = no_norm_pr_depth[i][
non_metric_scale_mask
]
pr_pose_trans[i][non_metric_scale_mask] = no_norm_pr_pose_trans[i][
non_metric_scale_mask
]
if self.norm_mode and not self.gt_scale:
gt_normalization_output = normalize_multiple_pointclouds(
no_norm_gt_pts, valid_masks, self.norm_mode, ret_factor=True
)
gt_pts_norm = gt_normalization_output[:-1]
norm_factor = gt_normalization_output[-1]
for i in range(n_views):
gt_pts[i] = gt_pts_norm[i]
gt_pts_cam[i] = no_norm_gt_pts_cam[i] / norm_factor
gt_depth[i] = no_norm_gt_depth[i] / norm_factor
gt_pose_trans[i] = no_norm_gt_pose_trans[i] / norm_factor[:, :, 0, 0]
pr_pts[i][~non_metric_scale_mask] = (
no_norm_pr_pts[i][~non_metric_scale_mask]
/ norm_factor[~non_metric_scale_mask]
)
pr_pts_cam[i][~non_metric_scale_mask] = (
no_norm_pr_pts_cam[i][~non_metric_scale_mask]
/ norm_factor[~non_metric_scale_mask]
)
pr_depth[i][~non_metric_scale_mask] = (
no_norm_pr_depth[i][~non_metric_scale_mask]
/ norm_factor[~non_metric_scale_mask]
)
pr_pose_trans[i][~non_metric_scale_mask] = (
no_norm_pr_pose_trans[i][~non_metric_scale_mask]
/ norm_factor[~non_metric_scale_mask][:, :, 0, 0]
)
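# Note: `~` applies to the result of `.any()`, so the elif below only triggers when every sample
# in the batch is metric scale (and GT normalization is skipped because norm_mode is unset or gt_scale is True).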
elif ~non_metric_scale_mask.any():
for i in range(n_views):
gt_pts[i] = no_norm_gt_pts[i]
gt_pts_cam[i] = no_norm_gt_pts_cam[i]
gt_depth[i] = no_norm_gt_depth[i]
gt_pose_trans[i] = no_norm_gt_pose_trans[i]
pr_pts[i][~non_metric_scale_mask] = no_norm_pr_pts[i][
~non_metric_scale_mask
]
pr_pts_cam[i][~non_metric_scale_mask] = no_norm_pr_pts_cam[i][
~non_metric_scale_mask
]
pr_depth[i][~non_metric_scale_mask] = no_norm_pr_depth[i][
~non_metric_scale_mask
]
pr_pose_trans[i][~non_metric_scale_mask] = no_norm_pr_pose_trans[i][
~non_metric_scale_mask
]
else:
for i in range(n_views):
gt_pts[i] = no_norm_gt_pts[i]
gt_pts_cam[i] = no_norm_gt_pts_cam[i]
gt_depth[i] = no_norm_gt_depth[i]
gt_pose_trans[i] = no_norm_gt_pose_trans[i]
# Get ambiguous masks
ambiguous_masks = []
for i in range(n_views):
ambiguous_masks.append(
(~batch[i]["non_ambiguous_mask"]) & (~valid_masks[i])
)
# Pack into info dicts
gt_info = []
pred_info = []
for i in range(n_views):
gt_info.append(
{
"ray_directions": gt_ray_directions[i],
self.depth_type_for_loss: gt_depth[i],
"pose_trans": gt_pose_trans[i],
"pose_quats": gt_pose_quats[i],
"pts3d": gt_pts[i],
"pts3d_cam": gt_pts_cam[i],
}
)
pred_info.append(
{
"ray_directions": pr_ray_directions[i],
self.depth_type_for_loss: pr_depth[i],
"pose_trans": pr_pose_trans[i],
"pose_quats": pr_pose_quats[i],
"pts3d": pr_pts[i],
"pts3d_cam": pr_pts_cam[i],
}
)
return gt_info, pred_info, valid_masks, ambiguous_masks
def compute_loss(self, batch, preds, **kw):
gt_info, pred_info, valid_masks, ambiguous_masks = self.get_all_info(
batch, preds, **kw
)
n_views = len(batch)
# Mask out samples in the batch where the gt depth validity mask is entirely zero
valid_norm_factor_masks = [
mask.sum(dim=(1, 2)) > 0 for mask in valid_masks
] # List of (B,)
if self.ambiguous_loss_value > 0:
assert self.criterion.reduction == "none", (
"ambiguous_loss_value should be 0 if no conf loss"
)
# Add the ambiguous pixel as "valid" pixels...
valid_masks = [
mask | ambig_mask
for mask, ambig_mask in zip(valid_masks, ambiguous_masks)
]
pose_trans_losses = []
pose_quats_losses = []
ray_directions_losses = []
depth_losses = []
cam_pts3d_losses = []
if self.compute_world_frame_points_loss:
pts3d_losses = []
for i in range(n_views):
# Get the predicted dense quantities
if not self.flatten_across_image_only:
# Flatten the points across the entire batch with the masks
pred_ray_directions = pred_info[i]["ray_directions"]
gt_ray_directions = gt_info[i]["ray_directions"]
pred_depth = pred_info[i][self.depth_type_for_loss][valid_masks[i]]
gt_depth = gt_info[i][self.depth_type_for_loss][valid_masks[i]]
pred_cam_pts3d = pred_info[i]["pts3d_cam"][valid_masks[i]]
gt_cam_pts3d = gt_info[i]["pts3d_cam"][valid_masks[i]]
if self.compute_world_frame_points_loss:
pred_pts3d = pred_info[i]["pts3d"][valid_masks[i]]
gt_pts3d = gt_info[i]["pts3d"][valid_masks[i]]
else:
# Flatten the H x W dimensions to H*W
batch_size, _, _, direction_dim = gt_info[i]["ray_directions"].shape
gt_ray_directions = gt_info[i]["ray_directions"].view(
batch_size, -1, direction_dim
)
pred_ray_directions = pred_info[i]["ray_directions"].view(
batch_size, -1, direction_dim
)
depth_dim = gt_info[i][self.depth_type_for_loss].shape[-1]
gt_depth = gt_info[i][self.depth_type_for_loss].view(
batch_size, -1, depth_dim
)
pred_depth = pred_info[i][self.depth_type_for_loss].view(
batch_size, -1, depth_dim
)
cam_pts_dim = gt_info[i]["pts3d_cam"].shape[-1]
gt_cam_pts3d = gt_info[i]["pts3d_cam"].view(batch_size, -1, cam_pts_dim)
pred_cam_pts3d = pred_info[i]["pts3d_cam"].view(
batch_size, -1, cam_pts_dim
)
if self.compute_world_frame_points_loss:
pts_dim = gt_info[i]["pts3d"].shape[-1]
gt_pts3d = gt_info[i]["pts3d"].view(batch_size, -1, pts_dim)
pred_pts3d = pred_info[i]["pts3d"].view(batch_size, -1, pts_dim)
valid_masks[i] = valid_masks[i].view(batch_size, -1)
# Apply loss in log space for depth if specified
if self.loss_in_log:
gt_depth = apply_log_to_norm(gt_depth)
pred_depth = apply_log_to_norm(pred_depth)
gt_cam_pts3d = apply_log_to_norm(gt_cam_pts3d)
pred_cam_pts3d = apply_log_to_norm(pred_cam_pts3d)
if self.compute_world_frame_points_loss:
gt_pts3d = apply_log_to_norm(gt_pts3d)
pred_pts3d = apply_log_to_norm(pred_pts3d)
if self.compute_pairwise_relative_pose_loss:
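# Pairwise relative pose loss: view i is used as the reference frame and the poses of the
# other N-1 views are expressed relative to it (q_rel = q_i^-1 * q_j, t_rel = R_i^-1 @ t_j + t_i_inv),
# then all pairs are pooled along the batch dimension before applying the criterion.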
# Get the inverse of current view predicted pose
pred_inv_curr_view_pose_quats = quaternion_inverse(
pred_info[i]["pose_quats"]
)
pred_inv_curr_view_pose_rot_mat = quaternion_to_rotation_matrix(
pred_inv_curr_view_pose_quats
)
pred_inv_curr_view_pose_trans = -1 * ein.einsum(
pred_inv_curr_view_pose_rot_mat,
pred_info[i]["pose_trans"],
"b i j, b j -> b i",
)
# Get the inverse of the current view GT pose
gt_inv_curr_view_pose_quats = quaternion_inverse(
gt_info[i]["pose_quats"]
)
gt_inv_curr_view_pose_rot_mat = quaternion_to_rotation_matrix(
gt_inv_curr_view_pose_quats
)
gt_inv_curr_view_pose_trans = -1 * ein.einsum(
gt_inv_curr_view_pose_rot_mat,
gt_info[i]["pose_trans"],
"b i j, b j -> b i",
)
# Get the other N-1 relative poses using the current pose as reference frame
pred_rel_pose_quats = []
pred_rel_pose_trans = []
gt_rel_pose_quats = []
gt_rel_pose_trans = []
for ov_idx in range(n_views):
if ov_idx == i:
continue
# Get the relative predicted pose
pred_ov_rel_pose_quats = quaternion_multiply(
pred_inv_curr_view_pose_quats, pred_info[ov_idx]["pose_quats"]
)
pred_ov_rel_pose_trans = (
ein.einsum(
pred_inv_curr_view_pose_rot_mat,
pred_info[ov_idx]["pose_trans"],
"b i j, b j -> b i",
)
+ pred_inv_curr_view_pose_trans
)
# Get the relative GT pose
gt_ov_rel_pose_quats = quaternion_multiply(
gt_inv_curr_view_pose_quats, gt_info[ov_idx]["pose_quats"]
)
gt_ov_rel_pose_trans = (
ein.einsum(
gt_inv_curr_view_pose_rot_mat,
gt_info[ov_idx]["pose_trans"],
"b i j, b j -> b i",
)
+ gt_inv_curr_view_pose_trans
)
# Get the valid translations using valid_norm_factor_masks for current view and other view
overall_valid_mask_for_trans = (
valid_norm_factor_masks[i] & valid_norm_factor_masks[ov_idx]
)
# Append the relative poses
pred_rel_pose_quats.append(pred_ov_rel_pose_quats)
pred_rel_pose_trans.append(
pred_ov_rel_pose_trans[overall_valid_mask_for_trans]
)
gt_rel_pose_quats.append(gt_ov_rel_pose_quats)
gt_rel_pose_trans.append(
gt_ov_rel_pose_trans[overall_valid_mask_for_trans]
)
# Cat the N-1 relative poses along the batch dimension
pred_rel_pose_quats = torch.cat(pred_rel_pose_quats, dim=0)
pred_rel_pose_trans = torch.cat(pred_rel_pose_trans, dim=0)
gt_rel_pose_quats = torch.cat(gt_rel_pose_quats, dim=0)
gt_rel_pose_trans = torch.cat(gt_rel_pose_trans, dim=0)
# Compute pose translation loss
pose_trans_loss = self.criterion(
pred_rel_pose_trans, gt_rel_pose_trans, factor="pose_trans"
)
pose_trans_loss = pose_trans_loss * self.pose_trans_loss_weight
pose_trans_losses.append(pose_trans_loss)
# Compute pose rotation loss
# Handle quaternion two-to-one mapping
pose_quats_loss = torch.minimum(
self.criterion(
pred_rel_pose_quats, gt_rel_pose_quats, factor="pose_quats"
),
self.criterion(
pred_rel_pose_quats, -gt_rel_pose_quats, factor="pose_quats"
),
)
pose_quats_loss = pose_quats_loss * self.pose_quats_loss_weight
pose_quats_losses.append(pose_quats_loss)
else:
# Get the pose info for the current view
pred_pose_trans = pred_info[i]["pose_trans"][valid_norm_factor_masks[i]]
gt_pose_trans = gt_info[i]["pose_trans"][valid_norm_factor_masks[i]]
pred_pose_quats = pred_info[i]["pose_quats"]
gt_pose_quats = gt_info[i]["pose_quats"]
# Compute pose translation loss
pose_trans_loss = self.criterion(
pred_pose_trans, gt_pose_trans, factor="pose_trans"
)
pose_trans_loss = pose_trans_loss * self.pose_trans_loss_weight
pose_trans_losses.append(pose_trans_loss)
# Compute pose rotation loss
# Handle quaternion two-to-one mapping
pose_quats_loss = torch.minimum(
self.criterion(pred_pose_quats, gt_pose_quats, factor="pose_quats"),
self.criterion(
pred_pose_quats, -gt_pose_quats, factor="pose_quats"
),
)
pose_quats_loss = pose_quats_loss * self.pose_quats_loss_weight
pose_quats_losses.append(pose_quats_loss)
# Compute ray direction loss
ray_directions_loss = self.criterion(
pred_ray_directions, gt_ray_directions, factor="ray_directions"
)
ray_directions_loss = ray_directions_loss * self.ray_directions_loss_weight
ray_directions_losses.append(ray_directions_loss)
# Compute depth loss
depth_loss = self.criterion(pred_depth, gt_depth, factor="depth")
depth_loss = depth_loss * self.depth_loss_weight
depth_losses.append(depth_loss)
# Compute camera frame point loss
cam_pts3d_loss = self.criterion(
pred_cam_pts3d, gt_cam_pts3d, factor="points"
)
cam_pts3d_loss = cam_pts3d_loss * self.cam_frame_points_loss_weight
cam_pts3d_losses.append(cam_pts3d_loss)
if self.compute_world_frame_points_loss:
# Compute point loss
pts3d_loss = self.criterion(pred_pts3d, gt_pts3d, factor="points")
pts3d_loss = pts3d_loss * self.world_frame_points_loss_weight
pts3d_losses.append(pts3d_loss)
# Handle ambiguous pixels
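# Pixels flagged as ambiguous keep a constant penalty (ambiguous_loss_value) instead of their
# regression error for the depth / pointmap terms; they were added to the valid mask above.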
if self.ambiguous_loss_value > 0:
if not self.flatten_across_image_only:
depth_losses[i] = torch.where(
ambiguous_masks[i][valid_masks[i]],
self.ambiguous_loss_value,
depth_losses[i],
)
cam_pts3d_losses[i] = torch.where(
ambiguous_masks[i][valid_masks[i]],
self.ambiguous_loss_value,
cam_pts3d_losses[i],
)
if self.compute_world_frame_points_loss:
pts3d_losses[i] = torch.where(
ambiguous_masks[i][valid_masks[i]],
self.ambiguous_loss_value,
pts3d_losses[i],
)
else:
depth_losses[i] = torch.where(
ambiguous_masks[i].view(ambiguous_masks[i].shape[0], -1),
self.ambiguous_loss_value,
depth_losses[i],
)
cam_pts3d_losses[i] = torch.where(
ambiguous_masks[i].view(ambiguous_masks[i].shape[0], -1),
self.ambiguous_loss_value,
cam_pts3d_losses[i],
)
if self.compute_world_frame_points_loss:
pts3d_losses[i] = torch.where(
ambiguous_masks[i].view(ambiguous_masks[i].shape[0], -1),
self.ambiguous_loss_value,
pts3d_losses[i],
)
# Use helper function to generate loss terms and details
if self.compute_world_frame_points_loss:
losses_dict = {
"pts3d": {
"values": pts3d_losses,
"use_mask": True,
"is_multi_view": True,
},
}
else:
losses_dict = {}
losses_dict.update(
{
"cam_pts3d": {
"values": cam_pts3d_losses,
"use_mask": True,
"is_multi_view": True,
},
self.depth_type_for_loss: {
"values": depth_losses,
"use_mask": True,
"is_multi_view": True,
},
"ray_directions": {
"values": ray_directions_losses,
"use_mask": False,
"is_multi_view": True,
},
"pose_quats": {
"values": pose_quats_losses,
"use_mask": False,
"is_multi_view": True,
},
"pose_trans": {
"values": pose_trans_losses,
"use_mask": False,
"is_multi_view": True,
},
}
)
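# The insertion order of losses_dict defines the order of the loss terms passed to Sum,
# matching the ordering documented in the class docstring.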
loss_terms, details = get_loss_terms_and_details(
losses_dict,
valid_masks,
type(self).__name__,
n_views,
self.flatten_across_image_only,
)
losses = Sum(*loss_terms)
return losses, (details | {})
class FactoredGeometryRegr3DPlusNormalGMLoss(FactoredGeometryRegr3D):
"""
Regression, Normals & Gradient Matching Loss for Factored Geometry.
"""
def __init__(
self,
criterion,
norm_mode="?avg_dis",
gt_scale=False,
ambiguous_loss_value=0,
max_metric_scale=False,
loss_in_log=True,
flatten_across_image_only=False,
depth_type_for_loss="depth_along_ray",
cam_frame_points_loss_weight=1,
depth_loss_weight=1,
ray_directions_loss_weight=1,
pose_quats_loss_weight=1,
pose_trans_loss_weight=1,
compute_pairwise_relative_pose_loss=False,
convert_predictions_to_view0_frame=False,
compute_world_frame_points_loss=True,
world_frame_points_loss_weight=1,
apply_normal_and_gm_loss_to_synthetic_data_only=True,
normal_loss_weight=1,
gm_loss_weight=1,
):
"""
Initialize the loss criterion for Factored Geometry (see parent class for details).
Additionally computes:
(1) Normal Loss over the Camera Frame Pointmaps in Euclidean coordinates,
(2) Gradient Matching (GM) Loss over the Depth Z in log space (MiDaS applies the GM loss in disparity space).
Args:
criterion (BaseCriterion): The base criterion to use for computing the loss.
norm_mode (str): Normalization mode for scene representation. Default: "?avg_dis".
If prefixed with "?", normalization is only applied to non-metric scale data.
gt_scale (bool): If True, enforce predictions to have the same scale as ground truth.
If False, both GT and predictions are normalized independently. Default: False.
ambiguous_loss_value (float): Value to use for ambiguous pixels in the loss.
If 0, ambiguous pixels are ignored. Default: 0.
max_metric_scale (float): Maximum scale for metric scale data. If data exceeds this
value, it will be treated as non-metric. Default: False (no limit).
loss_in_log (bool): If True, apply logarithmic transformation to input before
computing the loss for depth and pointmaps. Default: True.
flatten_across_image_only (bool): If True, flatten H x W dimensions only when computing
the loss. If False, flatten across batch and spatial dimensions. Default: False.
depth_type_for_loss (str): Type of depth to use for loss computation. Default: "depth_along_ray".
Options: "depth_along_ray", "depth_z"
cam_frame_points_loss_weight (float): Weight to use for the camera frame pointmap loss. Default: 1.
depth_loss_weight (float): Weight to use for the depth loss. Default: 1.
ray_directions_loss_weight (float): Weight to use for the ray directions loss. Default: 1.
pose_quats_loss_weight (float): Weight to use for the pose quats loss. Default: 1.
pose_trans_loss_weight (float): Weight to use for the pose trans loss. Default: 1.
compute_pairwise_relative_pose_loss (bool): If True, the pose loss is computed on the
exhaustive pairwise relative poses. Default: False.
convert_predictions_to_view0_frame (bool): If True, convert predictions to view0 frame.
Use this if the predictions are not already in the view0 frame. Default: False.
compute_world_frame_points_loss (bool): If True, compute the world frame pointmap loss. Default: True.
world_frame_points_loss_weight (float): Weight to use for the world frame pointmap loss. Default: 1.
apply_normal_and_gm_loss_to_synthetic_data_only (bool): If True, apply the normal and gm loss only to synthetic data.
If False, apply the normal and gm loss to all data. Default: True.
normal_loss_weight (float): Weight to use for the normal loss. Default: 1.
gm_loss_weight (float): Weight to use for the gm loss. Default: 1.
"""
super().__init__(
criterion=criterion,
norm_mode=norm_mode,
gt_scale=gt_scale,
ambiguous_loss_value=ambiguous_loss_value,
max_metric_scale=max_metric_scale,
loss_in_log=loss_in_log,
flatten_across_image_only=flatten_across_image_only,
depth_type_for_loss=depth_type_for_loss,
cam_frame_points_loss_weight=cam_frame_points_loss_weight,
depth_loss_weight=depth_loss_weight,
ray_directions_loss_weight=ray_directions_loss_weight,
pose_quats_loss_weight=pose_quats_loss_weight,
pose_trans_loss_weight=pose_trans_loss_weight,
compute_pairwise_relative_pose_loss=compute_pairwise_relative_pose_loss,
convert_predictions_to_view0_frame=convert_predictions_to_view0_frame,
compute_world_frame_points_loss=compute_world_frame_points_loss,
world_frame_points_loss_weight=world_frame_points_loss_weight,
)
self.apply_normal_and_gm_loss_to_synthetic_data_only = (
apply_normal_and_gm_loss_to_synthetic_data_only
)
self.normal_loss_weight = normal_loss_weight
self.gm_loss_weight = gm_loss_weight
def compute_loss(self, batch, preds, **kw):
gt_info, pred_info, valid_masks, ambiguous_masks = self.get_all_info(
batch, preds, **kw
)
n_views = len(batch)
# Mask out samples in the batch where the gt depth validity mask is entirely zero
valid_norm_factor_masks = [
mask.sum(dim=(1, 2)) > 0 for mask in valid_masks
] # List of (B,)
if self.ambiguous_loss_value > 0:
assert self.criterion.reduction == "none", (
"ambiguous_loss_value should be 0 if no conf loss"
)
# Add the ambiguous pixel as "valid" pixels...
valid_masks = [
mask | ambig_mask
for mask, ambig_mask in zip(valid_masks, ambiguous_masks)
]
normal_losses = []
gradient_matching_losses = []
pose_trans_losses = []
pose_quats_losses = []
ray_directions_losses = []
depth_losses = []
cam_pts3d_losses = []
if self.compute_world_frame_points_loss:
pts3d_losses = []
for i in range(n_views):
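# Normal & gradient matching losses are computed per view on the camera-frame geometry returned
# by get_all_info; when apply_normal_and_gm_loss_to_synthetic_data_only is True they are further
# restricted to samples flagged as synthetic in the batch.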
# Get the camera frame points, log space depth_z & valid masks
pred_local_pts3d = pred_info[i]["pts3d_cam"]
pred_depth_z = pred_local_pts3d[..., 2:]
pred_depth_z = apply_log_to_norm(pred_depth_z)
gt_local_pts3d = gt_info[i]["pts3d_cam"]
gt_depth_z = gt_local_pts3d[..., 2:]
gt_depth_z = apply_log_to_norm(gt_depth_z)
valid_mask_for_normal_gm_loss = valid_masks[i].clone()
# Update the validity mask for normal & gm loss based on the synthetic data mask if required
if self.apply_normal_and_gm_loss_to_synthetic_data_only:
synthetic_mask = batch[i]["is_synthetic"] # (B, )
synthetic_mask = synthetic_mask.unsqueeze(-1).unsqueeze(-1) # (B, 1, 1)
synthetic_mask = synthetic_mask.expand(
-1, pred_depth_z.shape[1], pred_depth_z.shape[2]
) # (B, H, W)
valid_mask_for_normal_gm_loss = (
valid_mask_for_normal_gm_loss & synthetic_mask
)
# Compute the normal loss
normal_loss = compute_normal_loss(
pred_local_pts3d, gt_local_pts3d, valid_mask_for_normal_gm_loss.clone()
)
normal_loss = normal_loss * self.normal_loss_weight
normal_losses.append(normal_loss)
# Compute the gradient matching loss
gradient_matching_loss = compute_gradient_matching_loss(
pred_depth_z, gt_depth_z, valid_mask_for_normal_gm_loss.clone()
)
gradient_matching_loss = gradient_matching_loss * self.gm_loss_weight
gradient_matching_losses.append(gradient_matching_loss)
# Get the predicted dense quantities
if not self.flatten_across_image_only:
# Flatten the points across the entire batch with the masks
pred_ray_directions = pred_info[i]["ray_directions"]
gt_ray_directions = gt_info[i]["ray_directions"]
pred_depth = pred_info[i][self.depth_type_for_loss][valid_masks[i]]
gt_depth = gt_info[i][self.depth_type_for_loss][valid_masks[i]]
pred_cam_pts3d = pred_info[i]["pts3d_cam"][valid_masks[i]]
gt_cam_pts3d = gt_info[i]["pts3d_cam"][valid_masks[i]]
if self.compute_world_frame_points_loss:
pred_pts3d = pred_info[i]["pts3d"][valid_masks[i]]
gt_pts3d = gt_info[i]["pts3d"][valid_masks[i]]
else:
# Flatten the H x W dimensions to H*W
batch_size, _, _, direction_dim = gt_info[i]["ray_directions"].shape
gt_ray_directions = gt_info[i]["ray_directions"].view(
batch_size, -1, direction_dim
)
pred_ray_directions = pred_info[i]["ray_directions"].view(
batch_size, -1, direction_dim
)
depth_dim = gt_info[i][self.depth_type_for_loss].shape[-1]
gt_depth = gt_info[i][self.depth_type_for_loss].view(
batch_size, -1, depth_dim
)
pred_depth = pred_info[i][self.depth_type_for_loss].view(
batch_size, -1, depth_dim
)
cam_pts_dim = gt_info[i]["pts3d_cam"].shape[-1]
gt_cam_pts3d = gt_info[i]["pts3d_cam"].view(batch_size, -1, cam_pts_dim)
pred_cam_pts3d = pred_info[i]["pts3d_cam"].view(
batch_size, -1, cam_pts_dim
)
if self.compute_world_frame_points_loss:
pts_dim = gt_info[i]["pts3d"].shape[-1]
gt_pts3d = gt_info[i]["pts3d"].view(batch_size, -1, pts_dim)
pred_pts3d = pred_info[i]["pts3d"].view(batch_size, -1, pts_dim)
valid_masks[i] = valid_masks[i].view(batch_size, -1)
# Apply loss in log space for depth if specified
if self.loss_in_log:
gt_depth = apply_log_to_norm(gt_depth)
pred_depth = apply_log_to_norm(pred_depth)
gt_cam_pts3d = apply_log_to_norm(gt_cam_pts3d)
pred_cam_pts3d = apply_log_to_norm(pred_cam_pts3d)
if self.compute_world_frame_points_loss:
gt_pts3d = apply_log_to_norm(gt_pts3d)
pred_pts3d = apply_log_to_norm(pred_pts3d)
if self.compute_pairwise_relative_pose_loss:
# Get the inverse of current view predicted pose
pred_inv_curr_view_pose_quats = quaternion_inverse(
pred_info[i]["pose_quats"]
)
pred_inv_curr_view_pose_rot_mat = quaternion_to_rotation_matrix(
pred_inv_curr_view_pose_quats
)
pred_inv_curr_view_pose_trans = -1 * ein.einsum(
pred_inv_curr_view_pose_rot_mat,
pred_info[i]["pose_trans"],
"b i j, b j -> b i",
)
# Get the inverse of the current view GT pose
gt_inv_curr_view_pose_quats = quaternion_inverse(
gt_info[i]["pose_quats"]
)
gt_inv_curr_view_pose_rot_mat = quaternion_to_rotation_matrix(
gt_inv_curr_view_pose_quats
)
gt_inv_curr_view_pose_trans = -1 * ein.einsum(
gt_inv_curr_view_pose_rot_mat,
gt_info[i]["pose_trans"],
"b i j, b j -> b i",
)
# Get the other N-1 relative poses using the current pose as reference frame
pred_rel_pose_quats = []
pred_rel_pose_trans = []
gt_rel_pose_quats = []
gt_rel_pose_trans = []
for ov_idx in range(n_views):
if ov_idx == i:
continue
# Get the relative predicted pose
pred_ov_rel_pose_quats = quaternion_multiply(
pred_inv_curr_view_pose_quats, pred_info[ov_idx]["pose_quats"]
)
pred_ov_rel_pose_trans = (
ein.einsum(
pred_inv_curr_view_pose_rot_mat,
pred_info[ov_idx]["pose_trans"],
"b i j, b j -> b i",
)
+ pred_inv_curr_view_pose_trans
)
# Get the relative GT pose
gt_ov_rel_pose_quats = quaternion_multiply(
gt_inv_curr_view_pose_quats, gt_info[ov_idx]["pose_quats"]
)
gt_ov_rel_pose_trans = (
ein.einsum(
gt_inv_curr_view_pose_rot_mat,
gt_info[ov_idx]["pose_trans"],
"b i j, b j -> b i",
)
+ gt_inv_curr_view_pose_trans
)
# Get the valid translations using valid_norm_factor_masks for current view and other view
overall_valid_mask_for_trans = (
valid_norm_factor_masks[i] & valid_norm_factor_masks[ov_idx]
)
# Append the relative poses
pred_rel_pose_quats.append(pred_ov_rel_pose_quats)
pred_rel_pose_trans.append(
pred_ov_rel_pose_trans[overall_valid_mask_for_trans]
)
gt_rel_pose_quats.append(gt_ov_rel_pose_quats)
gt_rel_pose_trans.append(
gt_ov_rel_pose_trans[overall_valid_mask_for_trans]
)
# Cat the N-1 relative poses along the batch dimension
pred_rel_pose_quats = torch.cat(pred_rel_pose_quats, dim=0)
pred_rel_pose_trans = torch.cat(pred_rel_pose_trans, dim=0)
gt_rel_pose_quats = torch.cat(gt_rel_pose_quats, dim=0)
gt_rel_pose_trans = torch.cat(gt_rel_pose_trans, dim=0)
# Compute pose translation loss
pose_trans_loss = self.criterion(
pred_rel_pose_trans, gt_rel_pose_trans, factor="pose_trans"
)
pose_trans_loss = pose_trans_loss * self.pose_trans_loss_weight
pose_trans_losses.append(pose_trans_loss)
# Compute pose rotation loss
# Handle quaternion two-to-one mapping
pose_quats_loss = torch.minimum(
self.criterion(
pred_rel_pose_quats, gt_rel_pose_quats, factor="pose_quats"
),
self.criterion(
pred_rel_pose_quats, -gt_rel_pose_quats, factor="pose_quats"
),
)
pose_quats_loss = pose_quats_loss * self.pose_quats_loss_weight
pose_quats_losses.append(pose_quats_loss)
else:
# Get the pose info for the current view
pred_pose_trans = pred_info[i]["pose_trans"][valid_norm_factor_masks[i]]
gt_pose_trans = gt_info[i]["pose_trans"][valid_norm_factor_masks[i]]
pred_pose_quats = pred_info[i]["pose_quats"]
gt_pose_quats = gt_info[i]["pose_quats"]
# Compute pose translation loss
pose_trans_loss = self.criterion(
pred_pose_trans, gt_pose_trans, factor="pose_trans"
)
pose_trans_loss = pose_trans_loss * self.pose_trans_loss_weight
pose_trans_losses.append(pose_trans_loss)
# Compute pose rotation loss
# Handle quaternion two-to-one mapping
pose_quats_loss = torch.minimum(
self.criterion(pred_pose_quats, gt_pose_quats, factor="pose_quats"),
self.criterion(
pred_pose_quats, -gt_pose_quats, factor="pose_quats"
),
)
pose_quats_loss = pose_quats_loss * self.pose_quats_loss_weight
pose_quats_losses.append(pose_quats_loss)
# Compute ray direction loss
ray_directions_loss = self.criterion(
pred_ray_directions, gt_ray_directions, factor="ray_directions"
)
ray_directions_loss = ray_directions_loss * self.ray_directions_loss_weight
ray_directions_losses.append(ray_directions_loss)
# Compute depth loss
depth_loss = self.criterion(pred_depth, gt_depth, factor="depth")
depth_loss = depth_loss * self.depth_loss_weight
depth_losses.append(depth_loss)
# Compute camera frame point loss
cam_pts3d_loss = self.criterion(
pred_cam_pts3d, gt_cam_pts3d, factor="points"
)
cam_pts3d_loss = cam_pts3d_loss * self.cam_frame_points_loss_weight
cam_pts3d_losses.append(cam_pts3d_loss)
if self.compute_world_frame_points_loss:
# Compute point loss
pts3d_loss = self.criterion(pred_pts3d, gt_pts3d, factor="points")
pts3d_loss = pts3d_loss * self.world_frame_points_loss_weight
pts3d_losses.append(pts3d_loss)
# Handle ambiguous pixels
if self.ambiguous_loss_value > 0:
if not self.flatten_across_image_only:
depth_losses[i] = torch.where(
ambiguous_masks[i][valid_masks[i]],
self.ambiguous_loss_value,
depth_losses[i],
)
cam_pts3d_losses[i] = torch.where(
ambiguous_masks[i][valid_masks[i]],
self.ambiguous_loss_value,
cam_pts3d_losses[i],
)
if self.compute_world_frame_points_loss:
pts3d_losses[i] = torch.where(
ambiguous_masks[i][valid_masks[i]],
self.ambiguous_loss_value,
pts3d_losses[i],
)
else:
depth_losses[i] = torch.where(
ambiguous_masks[i].view(ambiguous_masks[i].shape[0], -1),
self.ambiguous_loss_value,
depth_losses[i],
)
cam_pts3d_losses[i] = torch.where(
ambiguous_masks[i].view(ambiguous_masks[i].shape[0], -1),
self.ambiguous_loss_value,
cam_pts3d_losses[i],
)
if self.compute_world_frame_points_loss:
pts3d_losses[i] = torch.where(
ambiguous_masks[i].view(ambiguous_masks[i].shape[0], -1),
self.ambiguous_loss_value,
pts3d_losses[i],
)
# Use helper function to generate loss terms and details
if self.compute_world_frame_points_loss:
losses_dict = {
"pts3d": {
"values": pts3d_losses,
"use_mask": True,
"is_multi_view": True,
},
}
else:
losses_dict = {}
losses_dict.update(
{
"cam_pts3d": {
"values": cam_pts3d_losses,
"use_mask": True,
"is_multi_view": True,
},
self.depth_type_for_loss: {
"values": depth_losses,
"use_mask": True,
"is_multi_view": True,
},
"ray_directions": {
"values": ray_directions_losses,
"use_mask": False,
"is_multi_view": True,
},
"pose_quats": {
"values": pose_quats_losses,
"use_mask": False,
"is_multi_view": True,
},
"pose_trans": {
"values": pose_trans_losses,
"use_mask": False,
"is_multi_view": True,
},
"normal": {
"values": normal_losses,
"use_mask": False,
"is_multi_view": True,
},
"gradient_matching": {
"values": gradient_matching_losses,
"use_mask": False,
"is_multi_view": True,
},
}
)
loss_terms, details = get_loss_terms_and_details(
losses_dict,
valid_masks,
type(self).__name__,
n_views,
self.flatten_across_image_only,
)
losses = Sum(*loss_terms)
return losses, (details | {})
class FactoredGeometryScaleRegr3D(Criterion, MultiLoss):
"""
Regression Loss for Factored Geometry & Scale.
"""
def __init__(
self,
criterion,
norm_predictions=True,
norm_mode="avg_dis",
ambiguous_loss_value=0,
loss_in_log=True,
flatten_across_image_only=False,
depth_type_for_loss="depth_along_ray",
cam_frame_points_loss_weight=1,
depth_loss_weight=1,
ray_directions_loss_weight=1,
pose_quats_loss_weight=1,
pose_trans_loss_weight=1,
scale_loss_weight=1,
compute_pairwise_relative_pose_loss=False,
convert_predictions_to_view0_frame=False,
compute_world_frame_points_loss=True,
world_frame_points_loss_weight=1,
):
"""
Initialize the loss criterion for Factored Geometry (Ray Directions, Depth, Pose), Scale
and the Collective Geometry i.e. Local Frame Pointmaps & optionally World Frame Pointmaps.
If world-frame pointmap loss is computed, the pixel-level losses are computed in the following order:
(1) world points, (2) cam points, (3) depth, (4) ray directions, (5) pose quats, (6) pose trans, (7) scale.
Else, the pixel-level losses are returned in the following order:
(1) cam points, (2) depth, (3) ray directions, (4) pose quats, (5) pose trans, (6) scale.
The predicted scene representation is always normalized w.r.t. the frame of view0.
Loss is applied between the predicted metric scale and the ground truth metric scale.
Args:
criterion (BaseCriterion): The base criterion to use for computing the loss.
norm_predictions (bool): If True, normalize the predictions before computing the loss.
norm_mode (str): Normalization mode for the GT scene representation and, if norm_predictions is True, the predicted one. Default: "avg_dis".
ambiguous_loss_value (float): Value to use for ambiguous pixels in the loss.
If 0, ambiguous pixels are ignored. Default: 0.
loss_in_log (bool): If True, apply logarithmic transformation to input before
computing the loss for depth, pointmaps and scale. Default: True.
flatten_across_image_only (bool): If True, flatten H x W dimensions only when computing
the loss. If False, flatten across batch and spatial dimensions. Default: False.
depth_type_for_loss (str): Type of depth to use for loss computation. Default: "depth_along_ray".
Options: "depth_along_ray", "depth_z"
cam_frame_points_loss_weight (float): Weight to use for the camera frame pointmap loss. Default: 1.
depth_loss_weight (float): Weight to use for the depth loss. Default: 1.
ray_directions_loss_weight (float): Weight to use for the ray directions loss. Default: 1.
pose_quats_loss_weight (float): Weight to use for the pose quats loss. Default: 1.
pose_trans_loss_weight (float): Weight to use for the pose trans loss. Default: 1.
scale_loss_weight (float): Weight to use for the scale loss. Default: 1.
compute_pairwise_relative_pose_loss (bool): If True, the pose loss is computed on the
exhaustive pairwise relative poses. Default: False.
convert_predictions_to_view0_frame (bool): If True, convert predictions to view0 frame.
Use this if the predictions are not already in the view0 frame. Default: False.
compute_world_frame_points_loss (bool): If True, compute the world frame pointmap loss. Default: True.
world_frame_points_loss_weight (float): Weight to use for the world frame pointmap loss. Default: 1.
"""
super().__init__(criterion)
self.norm_predictions = norm_predictions
self.norm_mode = norm_mode
self.ambiguous_loss_value = ambiguous_loss_value
self.loss_in_log = loss_in_log
self.flatten_across_image_only = flatten_across_image_only
self.depth_type_for_loss = depth_type_for_loss
assert self.depth_type_for_loss in ["depth_along_ray", "depth_z"], (
"depth_type_for_loss must be one of ['depth_along_ray', 'depth_z']"
)
self.cam_frame_points_loss_weight = cam_frame_points_loss_weight
self.depth_loss_weight = depth_loss_weight
self.ray_directions_loss_weight = ray_directions_loss_weight
self.pose_quats_loss_weight = pose_quats_loss_weight
self.pose_trans_loss_weight = pose_trans_loss_weight
self.scale_loss_weight = scale_loss_weight
self.compute_pairwise_relative_pose_loss = compute_pairwise_relative_pose_loss
self.convert_predictions_to_view0_frame = convert_predictions_to_view0_frame
self.compute_world_frame_points_loss = compute_world_frame_points_loss
self.world_frame_points_loss_weight = world_frame_points_loss_weight
def get_all_info(self, batch, preds, dist_clip=None):
"""
Function to get all the information needed to compute the loss.
Returns all quantities normalized w.r.t. camera of view0.
"""
n_views = len(batch)
# Everything is normalized w.r.t. camera of view0
# Initialize lists to store data for all views
# Ground truth quantities
in_camera0 = closed_form_pose_inverse(batch[0]["camera_pose"])
no_norm_gt_pts = []
no_norm_gt_pts_cam = []
no_norm_gt_depth = []
no_norm_gt_pose_trans = []
valid_masks = []
gt_ray_directions = []
gt_pose_quats = []
# Predicted quantities
if self.convert_predictions_to_view0_frame:
# Get the camera transform to convert quantities to view0 frame
pred_camera0 = torch.eye(4, device=preds[0]["cam_quats"].device).unsqueeze(
0
)
batch_size = preds[0]["cam_quats"].shape[0]
pred_camera0 = pred_camera0.repeat(batch_size, 1, 1)
pred_camera0_rot = quaternion_to_rotation_matrix(
preds[0]["cam_quats"].clone()
)
pred_camera0[..., :3, :3] = pred_camera0_rot
pred_camera0[..., :3, 3] = preds[0]["cam_trans"].clone()
pred_in_camera0 = closed_form_pose_inverse(pred_camera0)
no_norm_pr_pts = []
no_norm_pr_pts_cam = []
no_norm_pr_depth = []
no_norm_pr_pose_trans = []
pr_ray_directions = []
pr_pose_quats = []
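# Predicted metric-scale points, used only to supervise the predicted metric scaling factor
# (the underlying geometry is detached so the scale loss does not affect it).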
metric_pr_pts_to_compute_scale = []
# Get ground truth & prediction info for all views
for i in range(n_views):
# Get the ground truth
no_norm_gt_pts.append(geotrf(in_camera0, batch[i]["pts3d"]))
valid_masks.append(batch[i]["valid_mask"].clone())
no_norm_gt_pts_cam.append(batch[i]["pts3d_cam"])
gt_ray_directions.append(batch[i]["ray_directions_cam"])
if self.depth_type_for_loss == "depth_along_ray":
no_norm_gt_depth.append(batch[i]["depth_along_ray"])
elif self.depth_type_for_loss == "depth_z":
no_norm_gt_depth.append(batch[i]["pts3d_cam"][..., 2:])
if i == 0:
# For view0, initialize identity pose
gt_pose_quats.append(
torch.tensor(
[0, 0, 0, 1],
dtype=gt_ray_directions[0].dtype,
device=gt_ray_directions[0].device,
)
.unsqueeze(0)
.repeat(gt_ray_directions[0].shape[0], 1)
)
no_norm_gt_pose_trans.append(
torch.tensor(
[0, 0, 0],
dtype=gt_ray_directions[0].dtype,
device=gt_ray_directions[0].device,
)
.unsqueeze(0)
.repeat(gt_ray_directions[0].shape[0], 1)
)
else:
# For other views, transform pose to view0's frame
gt_pose_quats_world = batch[i]["camera_pose_quats"]
no_norm_gt_pose_trans_world = batch[i]["camera_pose_trans"]
gt_pose_quats_in_view0, no_norm_gt_pose_trans_in_view0 = (
transform_pose_using_quats_and_trans_2_to_1(
batch[0]["camera_pose_quats"],
batch[0]["camera_pose_trans"],
gt_pose_quats_world,
no_norm_gt_pose_trans_world,
)
)
gt_pose_quats.append(gt_pose_quats_in_view0)
no_norm_gt_pose_trans.append(no_norm_gt_pose_trans_in_view0)
# Get the global predictions in view0's frame
if self.convert_predictions_to_view0_frame:
# Convert predictions to view0 frame
pr_pts3d_in_view0 = geotrf(pred_in_camera0, preds[i]["pts3d"])
pr_pose_quats_in_view0, pr_pose_trans_in_view0 = (
transform_pose_using_quats_and_trans_2_to_1(
preds[0]["cam_quats"],
preds[0]["cam_trans"],
preds[i]["cam_quats"],
preds[i]["cam_trans"],
)
)
else:
# Predictions are already in view0 frame
pr_pts3d_in_view0 = preds[i]["pts3d"]
pr_pose_trans_in_view0 = preds[i]["cam_trans"]
pr_pose_quats_in_view0 = preds[i]["cam_quats"]
# Get predictions for normalized loss
if self.depth_type_for_loss == "depth_along_ray":
curr_view_no_norm_depth = preds[i]["depth_along_ray"]
elif self.depth_type_for_loss == "depth_z":
curr_view_no_norm_depth = preds[i]["pts3d_cam"][..., 2:]
if "metric_scaling_factor" in preds[i].keys():
# Divide by the predicted metric scaling factor to get the raw predicted points, depth, and pose_trans
# This detaches the predicted metric scaling factor from the geometry based loss
curr_view_no_norm_pr_pts = pr_pts3d_in_view0 / preds[i][
"metric_scaling_factor"
].unsqueeze(-1).unsqueeze(-1)
curr_view_no_norm_pr_pts_cam = preds[i]["pts3d_cam"] / preds[i][
"metric_scaling_factor"
].unsqueeze(-1).unsqueeze(-1)
curr_view_no_norm_depth = curr_view_no_norm_depth / preds[i][
"metric_scaling_factor"
].unsqueeze(-1).unsqueeze(-1)
curr_view_no_norm_pr_pose_trans = (
pr_pose_trans_in_view0 / preds[i]["metric_scaling_factor"]
)
else:
curr_view_no_norm_pr_pts = pr_pts3d_in_view0
curr_view_no_norm_pr_pts_cam = preds[i]["pts3d_cam"]
curr_view_no_norm_depth = curr_view_no_norm_depth
curr_view_no_norm_pr_pose_trans = pr_pose_trans_in_view0
no_norm_pr_pts.append(curr_view_no_norm_pr_pts)
no_norm_pr_pts_cam.append(curr_view_no_norm_pr_pts_cam)
no_norm_pr_depth.append(curr_view_no_norm_depth)
no_norm_pr_pose_trans.append(curr_view_no_norm_pr_pose_trans)
pr_ray_directions.append(preds[i]["ray_directions"])
pr_pose_quats.append(pr_pose_quats_in_view0)
# Get the predicted metric scale points
if "metric_scaling_factor" in preds[i].keys():
# Detach the raw predicted points so that the scale loss is only applied to the scaling factor
curr_view_metric_pr_pts_to_compute_scale = (
curr_view_no_norm_pr_pts.detach()
* preds[i]["metric_scaling_factor"].unsqueeze(-1).unsqueeze(-1)
)
else:
curr_view_metric_pr_pts_to_compute_scale = (
curr_view_no_norm_pr_pts.clone()
)
metric_pr_pts_to_compute_scale.append(
curr_view_metric_pr_pts_to_compute_scale
)
if dist_clip is not None:
# Points that are too far-away == invalid
for i in range(n_views):
dis = no_norm_gt_pts[i].norm(dim=-1)
valid_masks[i] = valid_masks[i] & (dis <= dist_clip)
# Initialize normalized tensors
gt_pts = [torch.zeros_like(pts) for pts in no_norm_gt_pts]
gt_pts_cam = [torch.zeros_like(pts_cam) for pts_cam in no_norm_gt_pts_cam]
gt_depth = [torch.zeros_like(depth) for depth in no_norm_gt_depth]
gt_pose_trans = [torch.zeros_like(trans) for trans in no_norm_gt_pose_trans]
pr_pts = [torch.zeros_like(pts) for pts in no_norm_pr_pts]
pr_pts_cam = [torch.zeros_like(pts_cam) for pts_cam in no_norm_pr_pts_cam]
pr_depth = [torch.zeros_like(depth) for depth in no_norm_pr_depth]
pr_pose_trans = [torch.zeros_like(trans) for trans in no_norm_pr_pose_trans]
# Normalize the predicted points if specified
if self.norm_predictions:
pr_normalization_output = normalize_multiple_pointclouds(
no_norm_pr_pts,
valid_masks,
self.norm_mode,
ret_factor=True,
)
pr_pts_norm = pr_normalization_output[:-1]
pr_norm_factor = pr_normalization_output[-1]
# Normalize the ground truth points
gt_normalization_output = normalize_multiple_pointclouds(
no_norm_gt_pts, valid_masks, self.norm_mode, ret_factor=True
)
gt_pts_norm = gt_normalization_output[:-1]
gt_norm_factor = gt_normalization_output[-1]
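# The normalization factors have shape (B, 1, 1, 1), broadcastable over (B, H, W, C); the
# [:, :, 0, 0] indexing below reduces them to shape (B, 1) for the pose translations.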
for i in range(n_views):
if self.norm_predictions:
# Assign the normalized predictions
pr_pts[i] = pr_pts_norm[i]
pr_pts_cam[i] = no_norm_pr_pts_cam[i] / pr_norm_factor
pr_depth[i] = no_norm_pr_depth[i] / pr_norm_factor
pr_pose_trans[i] = no_norm_pr_pose_trans[i] / pr_norm_factor[:, :, 0, 0]
else:
pr_pts[i] = no_norm_pr_pts[i]
pr_pts_cam[i] = no_norm_pr_pts_cam[i]
pr_depth[i] = no_norm_pr_depth[i]
pr_pose_trans[i] = no_norm_pr_pose_trans[i]
# Assign the normalized ground truth quantities
gt_pts[i] = gt_pts_norm[i]
gt_pts_cam[i] = no_norm_gt_pts_cam[i] / gt_norm_factor
gt_depth[i] = no_norm_gt_depth[i] / gt_norm_factor
gt_pose_trans[i] = no_norm_gt_pose_trans[i] / gt_norm_factor[:, :, 0, 0]
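# Scale supervision: compare the normalization factor of the predicted metric-scale points against
# the GT normalization factor, but only for samples that are metric scale and have valid GT geometry.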
# Get the mask indicating ground truth metric scale quantities
metric_scale_mask = batch[0]["is_metric_scale"]
valid_gt_norm_factor_mask = (
gt_norm_factor[:, 0, 0, 0] > 1e-8
) # Mask out cases where depth for all views is invalid
valid_metric_scale_mask = metric_scale_mask & valid_gt_norm_factor_mask
if valid_metric_scale_mask.any():
# Compute the scale norm factor using the predicted metric scale points
metric_pr_normalization_output = normalize_multiple_pointclouds(
metric_pr_pts_to_compute_scale,
valid_masks,
self.norm_mode,
ret_factor=True,
)
pr_metric_norm_factor = metric_pr_normalization_output[-1]
# Get the valid ground truth and predicted scale norm factors for the metric ground truth quantities
gt_metric_norm_factor = gt_norm_factor[valid_metric_scale_mask]
pr_metric_norm_factor = pr_metric_norm_factor[valid_metric_scale_mask]
else:
gt_metric_norm_factor = None
pr_metric_norm_factor = None
# Get ambiguous masks
ambiguous_masks = []
for i in range(n_views):
ambiguous_masks.append(
(~batch[i]["non_ambiguous_mask"]) & (~valid_masks[i])
)
# Pack into info dicts
gt_info = []
pred_info = []
for i in range(n_views):
gt_info.append(
{
"ray_directions": gt_ray_directions[i],
self.depth_type_for_loss: gt_depth[i],
"pose_trans": gt_pose_trans[i],
"pose_quats": gt_pose_quats[i],
"pts3d": gt_pts[i],
"pts3d_cam": gt_pts_cam[i],
}
)
pred_info.append(
{
"ray_directions": pr_ray_directions[i],
self.depth_type_for_loss: pr_depth[i],
"pose_trans": pr_pose_trans[i],
"pose_quats": pr_pose_quats[i],
"pts3d": pr_pts[i],
"pts3d_cam": pr_pts_cam[i],
}
)
return (
gt_info,
pred_info,
valid_masks,
ambiguous_masks,
gt_metric_norm_factor,
pr_metric_norm_factor,
)
def compute_loss(self, batch, preds, **kw):
(
gt_info,
pred_info,
valid_masks,
ambiguous_masks,
gt_metric_norm_factor,
pr_metric_norm_factor,
) = self.get_all_info(batch, preds, **kw)
n_views = len(batch)
# Mask out samples in the batch where the gt depth validity mask is entirely zero
valid_norm_factor_masks = [
mask.sum(dim=(1, 2)) > 0 for mask in valid_masks
] # List of (B,)
if self.ambiguous_loss_value > 0:
assert self.criterion.reduction == "none", (
"ambiguous_loss_value should be 0 if no conf loss"
)
# Add the ambiguous pixel as "valid" pixels...
valid_masks = [
mask | ambig_mask
for mask, ambig_mask in zip(valid_masks, ambiguous_masks)
]
pose_trans_losses = []
pose_quats_losses = []
ray_directions_losses = []
depth_losses = []
cam_pts3d_losses = []
if self.compute_world_frame_points_loss:
pts3d_losses = []
for i in range(n_views):
# Get the predicted dense quantities
if not self.flatten_across_image_only:
# Flatten the points across the entire batch with the masks
pred_ray_directions = pred_info[i]["ray_directions"]
gt_ray_directions = gt_info[i]["ray_directions"]
pred_depth = pred_info[i][self.depth_type_for_loss][valid_masks[i]]
gt_depth = gt_info[i][self.depth_type_for_loss][valid_masks[i]]
pred_cam_pts3d = pred_info[i]["pts3d_cam"][valid_masks[i]]
gt_cam_pts3d = gt_info[i]["pts3d_cam"][valid_masks[i]]
if self.compute_world_frame_points_loss:
pred_pts3d = pred_info[i]["pts3d"][valid_masks[i]]
gt_pts3d = gt_info[i]["pts3d"][valid_masks[i]]
else:
# Flatten the H x W dimensions to H*W
batch_size, _, _, direction_dim = gt_info[i]["ray_directions"].shape
gt_ray_directions = gt_info[i]["ray_directions"].view(
batch_size, -1, direction_dim
)
pred_ray_directions = pred_info[i]["ray_directions"].view(
batch_size, -1, direction_dim
)
depth_dim = gt_info[i][self.depth_type_for_loss].shape[-1]
gt_depth = gt_info[i][self.depth_type_for_loss].view(
batch_size, -1, depth_dim
)
pred_depth = pred_info[i][self.depth_type_for_loss].view(
batch_size, -1, depth_dim
)
cam_pts_dim = gt_info[i]["pts3d_cam"].shape[-1]
gt_cam_pts3d = gt_info[i]["pts3d_cam"].view(batch_size, -1, cam_pts_dim)
pred_cam_pts3d = pred_info[i]["pts3d_cam"].view(
batch_size, -1, cam_pts_dim
)
if self.compute_world_frame_points_loss:
pts_dim = gt_info[i]["pts3d"].shape[-1]
gt_pts3d = gt_info[i]["pts3d"].view(batch_size, -1, pts_dim)
pred_pts3d = pred_info[i]["pts3d"].view(batch_size, -1, pts_dim)
valid_masks[i] = valid_masks[i].view(batch_size, -1)
# Apply loss in log space for depth if specified
if self.loss_in_log:
gt_depth = apply_log_to_norm(gt_depth)
pred_depth = apply_log_to_norm(pred_depth)
gt_cam_pts3d = apply_log_to_norm(gt_cam_pts3d)
pred_cam_pts3d = apply_log_to_norm(pred_cam_pts3d)
if self.compute_world_frame_points_loss:
gt_pts3d = apply_log_to_norm(gt_pts3d)
pred_pts3d = apply_log_to_norm(pred_pts3d)
if self.compute_pairwise_relative_pose_loss:
# Get the inverse of current view predicted pose
pred_inv_curr_view_pose_quats = quaternion_inverse(
pred_info[i]["pose_quats"]
)
pred_inv_curr_view_pose_rot_mat = quaternion_to_rotation_matrix(
pred_inv_curr_view_pose_quats
)
pred_inv_curr_view_pose_trans = -1 * ein.einsum(
pred_inv_curr_view_pose_rot_mat,
pred_info[i]["pose_trans"],
"b i j, b j -> b i",
)
# Get the inverse of the current view GT pose
gt_inv_curr_view_pose_quats = quaternion_inverse(
gt_info[i]["pose_quats"]
)
gt_inv_curr_view_pose_rot_mat = quaternion_to_rotation_matrix(
gt_inv_curr_view_pose_quats
)
gt_inv_curr_view_pose_trans = -1 * ein.einsum(
gt_inv_curr_view_pose_rot_mat,
gt_info[i]["pose_trans"],
"b i j, b j -> b i",
)
# Get the other N-1 relative poses using the current pose as reference frame
pred_rel_pose_quats = []
pred_rel_pose_trans = []
gt_rel_pose_quats = []
gt_rel_pose_trans = []
for ov_idx in range(n_views):
if ov_idx == i:
continue
# Get the relative predicted pose
pred_ov_rel_pose_quats = quaternion_multiply(
pred_inv_curr_view_pose_quats, pred_info[ov_idx]["pose_quats"]
)
pred_ov_rel_pose_trans = (
ein.einsum(
pred_inv_curr_view_pose_rot_mat,
pred_info[ov_idx]["pose_trans"],
"b i j, b j -> b i",
)
+ pred_inv_curr_view_pose_trans
)
# Get the relative GT pose
gt_ov_rel_pose_quats = quaternion_multiply(
gt_inv_curr_view_pose_quats, gt_info[ov_idx]["pose_quats"]
)
gt_ov_rel_pose_trans = (
ein.einsum(
gt_inv_curr_view_pose_rot_mat,
gt_info[ov_idx]["pose_trans"],
"b i j, b j -> b i",
)
+ gt_inv_curr_view_pose_trans
)
# Get the valid translations using valid_norm_factor_masks for current view and other view
overall_valid_mask_for_trans = (
valid_norm_factor_masks[i] & valid_norm_factor_masks[ov_idx]
)
# Append the relative poses
pred_rel_pose_quats.append(pred_ov_rel_pose_quats)
pred_rel_pose_trans.append(
pred_ov_rel_pose_trans[overall_valid_mask_for_trans]
)
gt_rel_pose_quats.append(gt_ov_rel_pose_quats)
gt_rel_pose_trans.append(
gt_ov_rel_pose_trans[overall_valid_mask_for_trans]
)
# Cat the N-1 relative poses along the batch dimension
pred_rel_pose_quats = torch.cat(pred_rel_pose_quats, dim=0)
pred_rel_pose_trans = torch.cat(pred_rel_pose_trans, dim=0)
gt_rel_pose_quats = torch.cat(gt_rel_pose_quats, dim=0)
gt_rel_pose_trans = torch.cat(gt_rel_pose_trans, dim=0)
# Compute pose translation loss
pose_trans_loss = self.criterion(
pred_rel_pose_trans, gt_rel_pose_trans, factor="pose_trans"
)
pose_trans_loss = pose_trans_loss * self.pose_trans_loss_weight
pose_trans_losses.append(pose_trans_loss)
# Compute pose rotation loss
# Handle quaternion two-to-one mapping
pose_quats_loss = torch.minimum(
self.criterion(
pred_rel_pose_quats, gt_rel_pose_quats, factor="pose_quats"
),
self.criterion(
pred_rel_pose_quats, -gt_rel_pose_quats, factor="pose_quats"
),
)
pose_quats_loss = pose_quats_loss * self.pose_quats_loss_weight
pose_quats_losses.append(pose_quats_loss)
else:
# Get the pose info for the current view
pred_pose_trans = pred_info[i]["pose_trans"][valid_norm_factor_masks[i]]
gt_pose_trans = gt_info[i]["pose_trans"][valid_norm_factor_masks[i]]
pred_pose_quats = pred_info[i]["pose_quats"]
gt_pose_quats = gt_info[i]["pose_quats"]
# Compute pose translation loss
pose_trans_loss = self.criterion(
pred_pose_trans, gt_pose_trans, factor="pose_trans"
)
pose_trans_loss = pose_trans_loss * self.pose_trans_loss_weight
pose_trans_losses.append(pose_trans_loss)
# Compute pose rotation loss
# Handle quaternion two-to-one mapping
pose_quats_loss = torch.minimum(
self.criterion(pred_pose_quats, gt_pose_quats, factor="pose_quats"),
self.criterion(
pred_pose_quats, -gt_pose_quats, factor="pose_quats"
),
)
pose_quats_loss = pose_quats_loss * self.pose_quats_loss_weight
pose_quats_losses.append(pose_quats_loss)
# Compute ray direction loss
ray_directions_loss = self.criterion(
pred_ray_directions, gt_ray_directions, factor="ray_directions"
)
ray_directions_loss = ray_directions_loss * self.ray_directions_loss_weight
ray_directions_losses.append(ray_directions_loss)
# Compute depth loss
depth_loss = self.criterion(pred_depth, gt_depth, factor="depth")
depth_loss = depth_loss * self.depth_loss_weight
depth_losses.append(depth_loss)
# Compute camera frame point loss
cam_pts3d_loss = self.criterion(
pred_cam_pts3d, gt_cam_pts3d, factor="points"
)
cam_pts3d_loss = cam_pts3d_loss * self.cam_frame_points_loss_weight
cam_pts3d_losses.append(cam_pts3d_loss)
if self.compute_world_frame_points_loss:
# Compute point loss
pts3d_loss = self.criterion(pred_pts3d, gt_pts3d, factor="points")
pts3d_loss = pts3d_loss * self.world_frame_points_loss_weight
pts3d_losses.append(pts3d_loss)
# Handle ambiguous pixels
if self.ambiguous_loss_value > 0:
if not self.flatten_across_image_only:
depth_losses[i] = torch.where(
ambiguous_masks[i][valid_masks[i]],
self.ambiguous_loss_value,
depth_losses[i],
)
cam_pts3d_losses[i] = torch.where(
ambiguous_masks[i][valid_masks[i]],
self.ambiguous_loss_value,
cam_pts3d_losses[i],
)
if self.compute_world_frame_points_loss:
pts3d_losses[i] = torch.where(
ambiguous_masks[i][valid_masks[i]],
self.ambiguous_loss_value,
pts3d_losses[i],
)
else:
depth_losses[i] = torch.where(
ambiguous_masks[i].view(ambiguous_masks[i].shape[0], -1),
self.ambiguous_loss_value,
depth_losses[i],
)
cam_pts3d_losses[i] = torch.where(
ambiguous_masks[i].view(ambiguous_masks[i].shape[0], -1),
self.ambiguous_loss_value,
cam_pts3d_losses[i],
)
if self.compute_world_frame_points_loss:
pts3d_losses[i] = torch.where(
ambiguous_masks[i].view(ambiguous_masks[i].shape[0], -1),
self.ambiguous_loss_value,
pts3d_losses[i],
)
# Compute the scale loss
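# The scale loss supervises the predicted scene normalization factor against the GT one;
# it is only defined for metric-scale samples with a valid GT norm factor (gt_metric_norm_factor is None otherwise)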
if gt_metric_norm_factor is not None:
if self.loss_in_log:
gt_metric_norm_factor = apply_log_to_norm(gt_metric_norm_factor)
pr_metric_norm_factor = apply_log_to_norm(pr_metric_norm_factor)
scale_loss = (
self.criterion(
pr_metric_norm_factor, gt_metric_norm_factor, factor="scale"
)
* self.scale_loss_weight
)
else:
scale_loss = None
# Use helper function to generate loss terms and details
if self.compute_world_frame_points_loss:
losses_dict = {
"pts3d": {
"values": pts3d_losses,
"use_mask": True,
"is_multi_view": True,
},
}
else:
losses_dict = {}
losses_dict.update(
{
"cam_pts3d": {
"values": cam_pts3d_losses,
"use_mask": True,
"is_multi_view": True,
},
self.depth_type_for_loss: {
"values": depth_losses,
"use_mask": True,
"is_multi_view": True,
},
"ray_directions": {
"values": ray_directions_losses,
"use_mask": False,
"is_multi_view": True,
},
"pose_quats": {
"values": pose_quats_losses,
"use_mask": False,
"is_multi_view": True,
},
"pose_trans": {
"values": pose_trans_losses,
"use_mask": False,
"is_multi_view": True,
},
"scale": {
"values": scale_loss,
"use_mask": False,
"is_multi_view": False,
},
}
)
loss_terms, details = get_loss_terms_and_details(
losses_dict,
valid_masks,
type(self).__name__,
n_views,
self.flatten_across_image_only,
)
losses = Sum(*loss_terms)
return losses, details
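# Illustrative sketch (not used by the classes in this file): the rotation losses above take
# torch.minimum over the criterion evaluated against +q and -q because unit quaternions
# double-cover SO(3), i.e. q and -q encode the same rotation. A minimal standalone version
# of the same idea, assuming (B, 4) quaternions and a plain elementwise L1 criterion:
def _example_quaternion_sign_invariant_l1(pred_quats, gt_quats):
    """Hypothetical helper: per-sample L1 that is invariant to the quaternion sign ambiguity."""
    loss_pos = (pred_quats - gt_quats).abs().sum(dim=-1)
    loss_neg = (pred_quats + gt_quats).abs().sum(dim=-1)
    # Take the smaller of the two equivalent GT representations per sample
    return torch.minimum(loss_pos, loss_neg)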
class FactoredGeometryScaleRegr3DPlusNormalGMLoss(FactoredGeometryScaleRegr3D):
"""
Regression, Normals & Gradient Matching Loss for Factored Geometry & Scale.
"""
def __init__(
self,
criterion,
norm_predictions=True,
norm_mode="avg_dis",
ambiguous_loss_value=0,
loss_in_log=True,
flatten_across_image_only=False,
depth_type_for_loss="depth_along_ray",
cam_frame_points_loss_weight=1,
depth_loss_weight=1,
ray_directions_loss_weight=1,
pose_quats_loss_weight=1,
pose_trans_loss_weight=1,
scale_loss_weight=1,
compute_pairwise_relative_pose_loss=False,
convert_predictions_to_view0_frame=False,
compute_world_frame_points_loss=True,
world_frame_points_loss_weight=1,
apply_normal_and_gm_loss_to_synthetic_data_only=True,
normal_loss_weight=1,
gm_loss_weight=1,
):
"""
Initialize the loss criterion for Ray Directions, Depth, Pose, Pointmaps & Scale.
Additionally computes:
(1) Normal Loss over the Camera Frame Pointmaps in Euclidean coordinates,
(2) Gradient Matching (GM) Loss over the Depth Z in log space. (MiDaS applied its GM loss in disparity space)
Args:
criterion (BaseCriterion): The base criterion to use for computing the loss.
norm_predictions (bool): If True, normalize the predictions before computing the loss.
norm_mode (str): Normalization mode for the gt and predicted (optional) scene representation. Default: "avg_dis".
ambiguous_loss_value (float): Value to use for ambiguous pixels in the loss.
If 0, ambiguous pixels are ignored. Default: 0.
loss_in_log (bool): If True, apply logarithmic transformation to input before
computing the loss for depth, pointmaps and scale. Default: True.
flatten_across_image_only (bool): If True, flatten H x W dimensions only when computing
the loss. If False, flatten across batch and spatial dimensions. Default: False.
depth_type_for_loss (str): Type of depth to use for loss computation. Default: "depth_along_ray".
Options: "depth_along_ray", "depth_z"
cam_frame_points_loss_weight (float): Weight to use for the camera frame pointmap loss. Default: 1.
depth_loss_weight (float): Weight to use for the depth loss. Default: 1.
ray_directions_loss_weight (float): Weight to use for the ray directions loss. Default: 1.
pose_quats_loss_weight (float): Weight to use for the pose quats loss. Default: 1.
pose_trans_loss_weight (float): Weight to use for the pose trans loss. Default: 1.
scale_loss_weight (float): Weight to use for the scale loss. Default: 1.
compute_pairwise_relative_pose_loss (bool): If True, the pose loss is computed on the
exhaustive pairwise relative poses. Default: False.
convert_predictions_to_view0_frame (bool): If True, convert predictions to view0 frame.
Use this if the predictions are not already in the view0 frame. Default: False.
compute_world_frame_points_loss (bool): If True, compute the world frame pointmap loss. Default: True.
world_frame_points_loss_weight (float): Weight to use for the world frame pointmap loss. Default: 1.
apply_normal_and_gm_loss_to_synthetic_data_only (bool): If True, apply the normal and gm loss only to synthetic data.
If False, apply the normal and gm loss to all data. Default: True.
normal_loss_weight (float): Weight to use for the normal loss. Default: 1.
gm_loss_weight (float): Weight to use for the gm loss. Default: 1.
"""
super().__init__(
criterion=criterion,
norm_predictions=norm_predictions,
norm_mode=norm_mode,
ambiguous_loss_value=ambiguous_loss_value,
loss_in_log=loss_in_log,
flatten_across_image_only=flatten_across_image_only,
depth_type_for_loss=depth_type_for_loss,
cam_frame_points_loss_weight=cam_frame_points_loss_weight,
depth_loss_weight=depth_loss_weight,
ray_directions_loss_weight=ray_directions_loss_weight,
pose_quats_loss_weight=pose_quats_loss_weight,
pose_trans_loss_weight=pose_trans_loss_weight,
scale_loss_weight=scale_loss_weight,
compute_pairwise_relative_pose_loss=compute_pairwise_relative_pose_loss,
convert_predictions_to_view0_frame=convert_predictions_to_view0_frame,
compute_world_frame_points_loss=compute_world_frame_points_loss,
world_frame_points_loss_weight=world_frame_points_loss_weight,
)
self.apply_normal_and_gm_loss_to_synthetic_data_only = (
apply_normal_and_gm_loss_to_synthetic_data_only
)
self.normal_loss_weight = normal_loss_weight
self.gm_loss_weight = gm_loss_weight
def compute_loss(self, batch, preds, **kw):
(
gt_info,
pred_info,
valid_masks,
ambiguous_masks,
gt_metric_norm_factor,
pr_metric_norm_factor,
) = self.get_all_info(batch, preds, **kw)
n_views = len(batch)
# Mask out samples in the batch where the gt depth validity mask is entirely zero
valid_norm_factor_masks = [
mask.sum(dim=(1, 2)) > 0 for mask in valid_masks
] # List of (B,)
if self.ambiguous_loss_value > 0:
assert self.criterion.reduction == "none", (
"ambiguous_loss_value > 0 requires a criterion with reduction='none' (e.g., when wrapped in a confidence loss)"
)
# Treat ambiguous pixels as "valid" so their per-pixel losses can later be replaced by ambiguous_loss_value
valid_masks = [
mask | ambig_mask
for mask, ambig_mask in zip(valid_masks, ambiguous_masks)
]
normal_losses = []
gradient_matching_losses = []
pose_trans_losses = []
pose_quats_losses = []
ray_directions_losses = []
depth_losses = []
cam_pts3d_losses = []
if self.compute_world_frame_points_loss:
pts3d_losses = []
for i in range(n_views):
# Get the camera frame points, log space depth_z & valid masks
pred_local_pts3d = pred_info[i]["pts3d_cam"]
pred_depth_z = pred_local_pts3d[..., 2:]
pred_depth_z = apply_log_to_norm(pred_depth_z)
gt_local_pts3d = gt_info[i]["pts3d_cam"]
gt_depth_z = gt_local_pts3d[..., 2:]
gt_depth_z = apply_log_to_norm(gt_depth_z)
valid_mask_for_normal_gm_loss = valid_masks[i].clone()
# Update the validity mask for normal & gm loss based on the synthetic data mask if required
if self.apply_normal_and_gm_loss_to_synthetic_data_only:
synthetic_mask = batch[i]["is_synthetic"] # (B, )
synthetic_mask = synthetic_mask.unsqueeze(-1).unsqueeze(-1) # (B, 1, 1)
synthetic_mask = synthetic_mask.expand(
-1, pred_depth_z.shape[1], pred_depth_z.shape[2]
) # (B, H, W)
valid_mask_for_normal_gm_loss = (
valid_mask_for_normal_gm_loss & synthetic_mask
)
# Compute the normal loss
normal_loss = compute_normal_loss(
pred_local_pts3d, gt_local_pts3d, valid_mask_for_normal_gm_loss.clone()
)
normal_loss = normal_loss * self.normal_loss_weight
normal_losses.append(normal_loss)
# Compute the gradient matching loss
gradient_matching_loss = compute_gradient_matching_loss(
pred_depth_z, gt_depth_z, valid_mask_for_normal_gm_loss.clone()
)
gradient_matching_loss = gradient_matching_loss * self.gm_loss_weight
gradient_matching_losses.append(gradient_matching_loss)
# Get the predicted dense quantities
if not self.flatten_across_image_only:
# Flatten the points across the entire batch with the masks
pred_ray_directions = pred_info[i]["ray_directions"]
gt_ray_directions = gt_info[i]["ray_directions"]
pred_depth = pred_info[i][self.depth_type_for_loss][valid_masks[i]]
gt_depth = gt_info[i][self.depth_type_for_loss][valid_masks[i]]
pred_cam_pts3d = pred_info[i]["pts3d_cam"][valid_masks[i]]
gt_cam_pts3d = gt_info[i]["pts3d_cam"][valid_masks[i]]
if self.compute_world_frame_points_loss:
pred_pts3d = pred_info[i]["pts3d"][valid_masks[i]]
gt_pts3d = gt_info[i]["pts3d"][valid_masks[i]]
else:
# Flatten the H x W dimensions to H*W
batch_size, _, _, direction_dim = gt_info[i]["ray_directions"].shape
gt_ray_directions = gt_info[i]["ray_directions"].view(
batch_size, -1, direction_dim
)
pred_ray_directions = pred_info[i]["ray_directions"].view(
batch_size, -1, direction_dim
)
depth_dim = gt_info[i][self.depth_type_for_loss].shape[-1]
gt_depth = gt_info[i][self.depth_type_for_loss].view(
batch_size, -1, depth_dim
)
pred_depth = pred_info[i][self.depth_type_for_loss].view(
batch_size, -1, depth_dim
)
cam_pts_dim = gt_info[i]["pts3d_cam"].shape[-1]
gt_cam_pts3d = gt_info[i]["pts3d_cam"].view(batch_size, -1, cam_pts_dim)
pred_cam_pts3d = pred_info[i]["pts3d_cam"].view(
batch_size, -1, cam_pts_dim
)
if self.compute_world_frame_points_loss:
pts_dim = gt_info[i]["pts3d"].shape[-1]
gt_pts3d = gt_info[i]["pts3d"].view(batch_size, -1, pts_dim)
pred_pts3d = pred_info[i]["pts3d"].view(batch_size, -1, pts_dim)
valid_masks[i] = valid_masks[i].view(batch_size, -1)
# Apply loss in log space for depth and pointmaps if specified
if self.loss_in_log:
gt_depth = apply_log_to_norm(gt_depth)
pred_depth = apply_log_to_norm(pred_depth)
gt_cam_pts3d = apply_log_to_norm(gt_cam_pts3d)
pred_cam_pts3d = apply_log_to_norm(pred_cam_pts3d)
if self.compute_world_frame_points_loss:
gt_pts3d = apply_log_to_norm(gt_pts3d)
pred_pts3d = apply_log_to_norm(pred_pts3d)
if self.compute_pairwise_relative_pose_loss:
# Get the inverse of current view predicted pose
pred_inv_curr_view_pose_quats = quaternion_inverse(
pred_info[i]["pose_quats"]
)
pred_inv_curr_view_pose_rot_mat = quaternion_to_rotation_matrix(
pred_inv_curr_view_pose_quats
)
pred_inv_curr_view_pose_trans = -1 * ein.einsum(
pred_inv_curr_view_pose_rot_mat,
pred_info[i]["pose_trans"],
"b i j, b j -> b i",
)
# Get the inverse of the current view GT pose
gt_inv_curr_view_pose_quats = quaternion_inverse(
gt_info[i]["pose_quats"]
)
gt_inv_curr_view_pose_rot_mat = quaternion_to_rotation_matrix(
gt_inv_curr_view_pose_quats
)
gt_inv_curr_view_pose_trans = -1 * ein.einsum(
gt_inv_curr_view_pose_rot_mat,
gt_info[i]["pose_trans"],
"b i j, b j -> b i",
)
# Get the other N-1 relative poses using the current pose as reference frame
pred_rel_pose_quats = []
pred_rel_pose_trans = []
gt_rel_pose_quats = []
gt_rel_pose_trans = []
for ov_idx in range(n_views):
if ov_idx == i:
continue
# Get the relative predicted pose
pred_ov_rel_pose_quats = quaternion_multiply(
pred_inv_curr_view_pose_quats, pred_info[ov_idx]["pose_quats"]
)
pred_ov_rel_pose_trans = (
ein.einsum(
pred_inv_curr_view_pose_rot_mat,
pred_info[ov_idx]["pose_trans"],
"b i j, b j -> b i",
)
+ pred_inv_curr_view_pose_trans
)
# Get the relative GT pose
gt_ov_rel_pose_quats = quaternion_multiply(
gt_inv_curr_view_pose_quats, gt_info[ov_idx]["pose_quats"]
)
gt_ov_rel_pose_trans = (
ein.einsum(
gt_inv_curr_view_pose_rot_mat,
gt_info[ov_idx]["pose_trans"],
"b i j, b j -> b i",
)
+ gt_inv_curr_view_pose_trans
)
# Get the valid translations using valid_norm_factor_masks for current view and other view
overall_valid_mask_for_trans = (
valid_norm_factor_masks[i] & valid_norm_factor_masks[ov_idx]
)
# Append the relative poses
pred_rel_pose_quats.append(pred_ov_rel_pose_quats)
pred_rel_pose_trans.append(
pred_ov_rel_pose_trans[overall_valid_mask_for_trans]
)
gt_rel_pose_quats.append(gt_ov_rel_pose_quats)
gt_rel_pose_trans.append(
gt_ov_rel_pose_trans[overall_valid_mask_for_trans]
)
# Cat the N-1 relative poses along the batch dimension
pred_rel_pose_quats = torch.cat(pred_rel_pose_quats, dim=0)
pred_rel_pose_trans = torch.cat(pred_rel_pose_trans, dim=0)
gt_rel_pose_quats = torch.cat(gt_rel_pose_quats, dim=0)
gt_rel_pose_trans = torch.cat(gt_rel_pose_trans, dim=0)
# Compute pose translation loss
pose_trans_loss = self.criterion(
pred_rel_pose_trans, gt_rel_pose_trans, factor="pose_trans"
)
pose_trans_loss = pose_trans_loss * self.pose_trans_loss_weight
pose_trans_losses.append(pose_trans_loss)
# Compute pose rotation loss
# Handle quaternion two-to-one mapping
pose_quats_loss = torch.minimum(
self.criterion(
pred_rel_pose_quats, gt_rel_pose_quats, factor="pose_quats"
),
self.criterion(
pred_rel_pose_quats, -gt_rel_pose_quats, factor="pose_quats"
),
)
pose_quats_loss = pose_quats_loss * self.pose_quats_loss_weight
pose_quats_losses.append(pose_quats_loss)
else:
# Get the pose info for the current view
pred_pose_trans = pred_info[i]["pose_trans"][valid_norm_factor_masks[i]]
gt_pose_trans = gt_info[i]["pose_trans"][valid_norm_factor_masks[i]]
pred_pose_quats = pred_info[i]["pose_quats"]
gt_pose_quats = gt_info[i]["pose_quats"]
# Compute pose translation loss
pose_trans_loss = self.criterion(
pred_pose_trans, gt_pose_trans, factor="pose_trans"
)
pose_trans_loss = pose_trans_loss * self.pose_trans_loss_weight
pose_trans_losses.append(pose_trans_loss)
# Compute pose rotation loss
# Handle quaternion two-to-one mapping
pose_quats_loss = torch.minimum(
self.criterion(pred_pose_quats, gt_pose_quats, factor="pose_quats"),
self.criterion(
pred_pose_quats, -gt_pose_quats, factor="pose_quats"
),
)
pose_quats_loss = pose_quats_loss * self.pose_quats_loss_weight
pose_quats_losses.append(pose_quats_loss)
# Compute ray direction loss
ray_directions_loss = self.criterion(
pred_ray_directions, gt_ray_directions, factor="ray_directions"
)
ray_directions_loss = ray_directions_loss * self.ray_directions_loss_weight
ray_directions_losses.append(ray_directions_loss)
# Compute depth loss
depth_loss = self.criterion(pred_depth, gt_depth, factor="depth")
depth_loss = depth_loss * self.depth_loss_weight
depth_losses.append(depth_loss)
# Compute camera frame point loss
cam_pts3d_loss = self.criterion(
pred_cam_pts3d, gt_cam_pts3d, factor="points"
)
cam_pts3d_loss = cam_pts3d_loss * self.cam_frame_points_loss_weight
cam_pts3d_losses.append(cam_pts3d_loss)
if self.compute_world_frame_points_loss:
# Compute point loss
pts3d_loss = self.criterion(pred_pts3d, gt_pts3d, factor="points")
pts3d_loss = pts3d_loss * self.world_frame_points_loss_weight
pts3d_losses.append(pts3d_loss)
# Handle ambiguous pixels
if self.ambiguous_loss_value > 0:
if not self.flatten_across_image_only:
depth_losses[i] = torch.where(
ambiguous_masks[i][valid_masks[i]],
self.ambiguous_loss_value,
depth_losses[i],
)
cam_pts3d_losses[i] = torch.where(
ambiguous_masks[i][valid_masks[i]],
self.ambiguous_loss_value,
cam_pts3d_losses[i],
)
if self.compute_world_frame_points_loss:
pts3d_losses[i] = torch.where(
ambiguous_masks[i][valid_masks[i]],
self.ambiguous_loss_value,
pts3d_losses[i],
)
else:
depth_losses[i] = torch.where(
ambiguous_masks[i].view(ambiguous_masks[i].shape[0], -1),
self.ambiguous_loss_value,
depth_losses[i],
)
cam_pts3d_losses[i] = torch.where(
ambiguous_masks[i].view(ambiguous_masks[i].shape[0], -1),
self.ambiguous_loss_value,
cam_pts3d_losses[i],
)
if self.compute_world_frame_points_loss:
pts3d_losses[i] = torch.where(
ambiguous_masks[i].view(ambiguous_masks[i].shape[0], -1),
self.ambiguous_loss_value,
pts3d_losses[i],
)
# Compute the scale loss
if gt_metric_norm_factor is not None:
if self.loss_in_log:
gt_metric_norm_factor = apply_log_to_norm(gt_metric_norm_factor)
pr_metric_norm_factor = apply_log_to_norm(pr_metric_norm_factor)
scale_loss = (
self.criterion(
pr_metric_norm_factor, gt_metric_norm_factor, factor="scale"
)
* self.scale_loss_weight
)
else:
scale_loss = None
# Use helper function to generate loss terms and details
if self.compute_world_frame_points_loss:
losses_dict = {
"pts3d": {
"values": pts3d_losses,
"use_mask": True,
"is_multi_view": True,
},
}
else:
losses_dict = {}
losses_dict.update(
{
"cam_pts3d": {
"values": cam_pts3d_losses,
"use_mask": True,
"is_multi_view": True,
},
self.depth_type_for_loss: {
"values": depth_losses,
"use_mask": True,
"is_multi_view": True,
},
"ray_directions": {
"values": ray_directions_losses,
"use_mask": False,
"is_multi_view": True,
},
"pose_quats": {
"values": pose_quats_losses,
"use_mask": False,
"is_multi_view": True,
},
"pose_trans": {
"values": pose_trans_losses,
"use_mask": False,
"is_multi_view": True,
},
"scale": {
"values": scale_loss,
"use_mask": False,
"is_multi_view": False,
},
"normal": {
"values": normal_losses,
"use_mask": False,
"is_multi_view": True,
},
"gradient_matching": {
"values": gradient_matching_losses,
"use_mask": False,
"is_multi_view": True,
},
}
)
loss_terms, details = get_loss_terms_and_details(
losses_dict,
valid_masks,
type(self).__name__,
n_views,
self.flatten_across_image_only,
)
losses = Sum(*loss_terms)
return losses, details
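# Illustrative sketch (not the module's compute_gradient_matching_loss): the class above adds a
# gradient-matching term on log-space depth_z. A simplified single-scale variant of that idea
# compares masked finite differences of predicted and GT log depth, assuming depths of shape
# (B, H, W, 1) and a boolean validity mask of shape (B, H, W):
def _example_single_scale_gradient_matching(pred_log_depth, gt_log_depth, valid_mask):
    """Hypothetical helper: per-sample single-scale gradient matching on log depth."""
    diff = (pred_log_depth - gt_log_depth).squeeze(-1)  # (B, H, W)
    mask = valid_mask.float()
    # Horizontal / vertical finite differences, counted only where both neighboring pixels are valid
    grad_x = (diff[:, :, 1:] - diff[:, :, :-1]).abs() * (mask[:, :, 1:] * mask[:, :, :-1])
    grad_y = (diff[:, 1:, :] - diff[:, :-1, :]).abs() * (mask[:, 1:, :] * mask[:, :-1, :])
    denom = mask.sum(dim=(1, 2)).clamp(min=1.0)
    return (grad_x.sum(dim=(1, 2)) + grad_y.sum(dim=(1, 2))) / denom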
class DisentangledFactoredGeometryScaleRegr3D(Criterion, MultiLoss):
"""
Disentangled Regression Loss for Factored Geometry & Scale.
"""
def __init__(
self,
criterion,
norm_predictions=True,
norm_mode="avg_dis",
loss_in_log=True,
flatten_across_image_only=False,
depth_type_for_loss="depth_along_ray",
depth_loss_weight=1,
ray_directions_loss_weight=1,
pose_quats_loss_weight=1,
pose_trans_loss_weight=1,
scale_loss_weight=1,
):
"""
Initialize the disentangled loss criterion for Factored Geometry (Ray Directions, Depth, Pose) & Scale.
It isolates/disentangles the contribution of each factor to the final task of 3D reconstruction.
All factor losses are computed in the same space: for each factor, a world-frame pointmap is constructed and the loss is applied to it.
This sidesteps the difficulty of finding a proper weighting between the factors.
For instance, for predicted ray directions, the GT depth & pose are used to construct the predicted world-frame pointmap on which the loss is computed.
Inspired by https://openaccess.thecvf.com/content_ICCV_2019/papers/Simonelli_Disentangling_Monocular_3D_Object_Detection_ICCV_2019_paper.pdf
The pixel-level losses are computed in the following order:
(1) depth, (2) ray directions, (3) pose quats, (4) pose trans, (5) scale.
The predicted scene representation is always normalized w.r.t. the frame of view0.
Loss is applied between the predicted metric scale and the ground truth metric scale.
Args:
criterion (BaseCriterion): The base criterion to use for computing the loss.
norm_predictions (bool): If True, normalize the predictions before computing the loss.
norm_mode (str): Normalization mode for the gt and predicted (optional) scene representation. Default: "avg_dis".
loss_in_log (bool): If True, apply logarithmic transformation to input before
computing the loss for depth, pointmaps and scale. Default: True.
flatten_across_image_only (bool): If True, flatten H x W dimensions only when computing
the loss. If False, flatten across batch and spatial dimensions. Default: False.
depth_type_for_loss (str): Type of depth to use for loss computation. Default: "depth_along_ray".
Options: "depth_along_ray", "depth_z"
depth_loss_weight (float): Weight to use for the depth loss. Default: 1.
ray_directions_loss_weight (float): Weight to use for the ray directions loss. Default: 1.
pose_quats_loss_weight (float): Weight to use for the pose quats loss. Default: 1.
pose_trans_loss_weight (float): Weight to use for the pose trans loss. Default: 1.
scale_loss_weight (float): Weight to use for the scale loss. Default: 1.
"""
super().__init__(criterion)
self.norm_predictions = norm_predictions
self.norm_mode = norm_mode
self.loss_in_log = loss_in_log
self.flatten_across_image_only = flatten_across_image_only
self.depth_type_for_loss = depth_type_for_loss
assert self.depth_type_for_loss in ["depth_along_ray", "depth_z"], (
"depth_type_for_loss must be one of ['depth_along_ray', 'depth_z']"
)
self.depth_loss_weight = depth_loss_weight
self.ray_directions_loss_weight = ray_directions_loss_weight
self.pose_quats_loss_weight = pose_quats_loss_weight
self.pose_trans_loss_weight = pose_trans_loss_weight
self.scale_loss_weight = scale_loss_weight
def get_all_info(self, batch, preds, dist_clip=None):
"""
Function to get all the information needed to compute the loss.
Returns all quantities normalized w.r.t. camera of view0.
"""
n_views = len(batch)
# Everything is normalized w.r.t. camera of view0
# Initialize lists to store data for all views
# Ground truth quantities
in_camera0 = closed_form_pose_inverse(batch[0]["camera_pose"])
no_norm_gt_pts = []
no_norm_gt_pts_cam = []
no_norm_gt_depth = []
no_norm_gt_pose_trans = []
valid_masks = []
gt_ray_directions = []
gt_pose_quats = []
# Predicted quantities
no_norm_pr_pts = []
no_norm_pr_pts_cam = []
no_norm_pr_depth = []
no_norm_pr_pose_trans = []
pr_ray_directions = []
pr_pose_quats = []
metric_pr_pts_to_compute_scale = []
# Get ground truth & prediction info for all views
for i in range(n_views):
# Get the ground truth
no_norm_gt_pts.append(geotrf(in_camera0, batch[i]["pts3d"]))
valid_masks.append(batch[i]["valid_mask"].clone())
no_norm_gt_pts_cam.append(batch[i]["pts3d_cam"])
gt_ray_directions.append(batch[i]["ray_directions_cam"])
if self.depth_type_for_loss == "depth_along_ray":
no_norm_gt_depth.append(batch[i]["depth_along_ray"])
elif self.depth_type_for_loss == "depth_z":
no_norm_gt_depth.append(batch[i]["pts3d_cam"][..., 2:])
if i == 0:
# For view0, initialize identity pose
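# Identity rotation as the unit quaternion (0, 0, 0, 1) in (x, y, z, w) convention, with zero translation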
gt_pose_quats.append(
torch.tensor(
[0, 0, 0, 1],
dtype=gt_ray_directions[0].dtype,
device=gt_ray_directions[0].device,
)
.unsqueeze(0)
.repeat(gt_ray_directions[0].shape[0], 1)
)
no_norm_gt_pose_trans.append(
torch.tensor(
[0, 0, 0],
dtype=gt_ray_directions[0].dtype,
device=gt_ray_directions[0].device,
)
.unsqueeze(0)
.repeat(gt_ray_directions[0].shape[0], 1)
)
else:
# For other views, transform pose to view0's frame
gt_pose_quats_world = batch[i]["camera_pose_quats"]
no_norm_gt_pose_trans_world = batch[i]["camera_pose_trans"]
gt_pose_quats_in_view0, no_norm_gt_pose_trans_in_view0 = (
transform_pose_using_quats_and_trans_2_to_1(
batch[0]["camera_pose_quats"],
batch[0]["camera_pose_trans"],
gt_pose_quats_world,
no_norm_gt_pose_trans_world,
)
)
gt_pose_quats.append(gt_pose_quats_in_view0)
no_norm_gt_pose_trans.append(no_norm_gt_pose_trans_in_view0)
# Get predictions for normalized loss
if self.depth_type_for_loss == "depth_along_ray":
curr_view_no_norm_depth = preds[i]["depth_along_ray"]
elif self.depth_type_for_loss == "depth_z":
curr_view_no_norm_depth = preds[i]["pts3d_cam"][..., 2:]
if "metric_scaling_factor" in preds[i].keys():
# Divide by the predicted metric scaling factor to get the raw predicted points, depth_along_ray, and pose_trans
# This detaches the predicted metric scaling factor from the geometry based loss
curr_view_no_norm_pr_pts = preds[i]["pts3d"] / preds[i][
"metric_scaling_factor"
].unsqueeze(-1).unsqueeze(-1)
curr_view_no_norm_pr_pts_cam = preds[i]["pts3d_cam"] / preds[i][
"metric_scaling_factor"
].unsqueeze(-1).unsqueeze(-1)
curr_view_no_norm_depth = curr_view_no_norm_depth / preds[i][
"metric_scaling_factor"
].unsqueeze(-1).unsqueeze(-1)
curr_view_no_norm_pr_pose_trans = (
preds[i]["cam_trans"] / preds[i]["metric_scaling_factor"]
)
else:
curr_view_no_norm_pr_pts = preds[i]["pts3d"]
curr_view_no_norm_pr_pts_cam = preds[i]["pts3d_cam"]
# curr_view_no_norm_depth already holds the raw prediction (no metric scaling factor to divide out)
curr_view_no_norm_pr_pose_trans = preds[i]["cam_trans"]
no_norm_pr_pts.append(curr_view_no_norm_pr_pts)
no_norm_pr_pts_cam.append(curr_view_no_norm_pr_pts_cam)
no_norm_pr_depth.append(curr_view_no_norm_depth)
no_norm_pr_pose_trans.append(curr_view_no_norm_pr_pose_trans)
pr_ray_directions.append(preds[i]["ray_directions"])
pr_pose_quats.append(preds[i]["cam_quats"])
# Get the predicted metric scale points
if "metric_scaling_factor" in preds[i].keys():
# Detach the raw predicted points so that the scale loss is only applied to the scaling factor
curr_view_metric_pr_pts_to_compute_scale = (
curr_view_no_norm_pr_pts.detach()
* preds[i]["metric_scaling_factor"].unsqueeze(-1).unsqueeze(-1)
)
else:
curr_view_metric_pr_pts_to_compute_scale = (
curr_view_no_norm_pr_pts.clone()
)
metric_pr_pts_to_compute_scale.append(
curr_view_metric_pr_pts_to_compute_scale
)
if dist_clip is not None:
# Points that are too far away are treated as invalid
for i in range(n_views):
dis = no_norm_gt_pts[i].norm(dim=-1)
valid_masks[i] = valid_masks[i] & (dis <= dist_clip)
# Initialize normalized tensors
gt_pts = [torch.zeros_like(pts) for pts in no_norm_gt_pts]
gt_pts_cam = [torch.zeros_like(pts_cam) for pts_cam in no_norm_gt_pts_cam]
gt_depth = [torch.zeros_like(depth) for depth in no_norm_gt_depth]
gt_pose_trans = [torch.zeros_like(trans) for trans in no_norm_gt_pose_trans]
pr_pts = [torch.zeros_like(pts) for pts in no_norm_pr_pts]
pr_pts_cam = [torch.zeros_like(pts_cam) for pts_cam in no_norm_pr_pts_cam]
pr_depth = [torch.zeros_like(depth) for depth in no_norm_pr_depth]
pr_pose_trans = [torch.zeros_like(trans) for trans in no_norm_pr_pose_trans]
# Normalize the predicted points if specified
if self.norm_predictions:
pr_normalization_output = normalize_multiple_pointclouds(
no_norm_pr_pts,
valid_masks,
self.norm_mode,
ret_factor=True,
)
pr_pts_norm = pr_normalization_output[:-1]
pr_norm_factor = pr_normalization_output[-1]
# Normalize the ground truth points
gt_normalization_output = normalize_multiple_pointclouds(
no_norm_gt_pts, valid_masks, self.norm_mode, ret_factor=True
)
gt_pts_norm = gt_normalization_output[:-1]
gt_norm_factor = gt_normalization_output[-1]
for i in range(n_views):
if self.norm_predictions:
# Assign the normalized predictions
pr_pts[i] = pr_pts_norm[i]
pr_pts_cam[i] = no_norm_pr_pts_cam[i] / pr_norm_factor
pr_depth[i] = no_norm_pr_depth[i] / pr_norm_factor
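# The norm factor carries trailing singleton spatial dims (see the [:, 0, 0, 0] indexing further below);
# [:, :, 0, 0] reduces it to (B, 1) so it broadcasts over the (B, 3) pose translation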
pr_pose_trans[i] = no_norm_pr_pose_trans[i] / pr_norm_factor[:, :, 0, 0]
else:
pr_pts[i] = no_norm_pr_pts[i]
pr_pts_cam[i] = no_norm_pr_pts_cam[i]
pr_depth[i] = no_norm_pr_depth[i]
pr_pose_trans[i] = no_norm_pr_pose_trans[i]
# Assign the normalized ground truth quantities
gt_pts[i] = gt_pts_norm[i]
gt_pts_cam[i] = no_norm_gt_pts_cam[i] / gt_norm_factor
gt_depth[i] = no_norm_gt_depth[i] / gt_norm_factor
gt_pose_trans[i] = no_norm_gt_pose_trans[i] / gt_norm_factor[:, :, 0, 0]
# Get the mask indicating ground truth metric scale quantities
metric_scale_mask = batch[0]["is_metric_scale"]
valid_gt_norm_factor_mask = (
gt_norm_factor[:, 0, 0, 0] > 1e-8
) # Mask out cases where depth for all views is invalid
valid_metric_scale_mask = metric_scale_mask & valid_gt_norm_factor_mask
if valid_metric_scale_mask.any():
# Compute the scale norm factor using the predicted metric scale points
metric_pr_normalization_output = normalize_multiple_pointclouds(
metric_pr_pts_to_compute_scale,
valid_masks,
self.norm_mode,
ret_factor=True,
)
pr_metric_norm_factor = metric_pr_normalization_output[-1]
# Get the valid ground truth and predicted scale norm factors for the metric ground truth quantities
gt_metric_norm_factor = gt_norm_factor[valid_metric_scale_mask]
pr_metric_norm_factor = pr_metric_norm_factor[valid_metric_scale_mask]
else:
gt_metric_norm_factor = None
pr_metric_norm_factor = None
# Get ambiguous masks
ambiguous_masks = []
for i in range(n_views):
ambiguous_masks.append(
(~batch[i]["non_ambiguous_mask"]) & (~valid_masks[i])
)
# Pack into info dicts
gt_info = []
pred_info = []
for i in range(n_views):
gt_info.append(
{
"ray_directions": gt_ray_directions[i],
self.depth_type_for_loss: gt_depth[i],
"pose_trans": gt_pose_trans[i],
"pose_quats": gt_pose_quats[i],
"pts3d": gt_pts[i],
"pts3d_cam": gt_pts_cam[i],
}
)
pred_info.append(
{
"ray_directions": pr_ray_directions[i],
self.depth_type_for_loss: pr_depth[i],
"pose_trans": pr_pose_trans[i],
"pose_quats": pr_pose_quats[i],
"pts3d": pr_pts[i],
"pts3d_cam": pr_pts_cam[i],
}
)
return (
gt_info,
pred_info,
valid_masks,
ambiguous_masks,
gt_metric_norm_factor,
pr_metric_norm_factor,
)
def compute_loss(self, batch, preds, **kw):
(
gt_info,
pred_info,
valid_masks,
ambiguous_masks,
gt_metric_norm_factor,
pr_metric_norm_factor,
) = self.get_all_info(batch, preds, **kw)
n_views = len(batch)
pose_trans_losses = []
pose_quats_losses = []
ray_directions_losses = []
depth_losses = []
for i in range(n_views):
# Get the GT factored quantities for the current view
gt_pts3d = gt_info[i]["pts3d"]
gt_ray_directions = gt_info[i]["ray_directions"]
gt_depth = gt_info[i][self.depth_type_for_loss]
gt_pose_trans = gt_info[i]["pose_trans"]
gt_pose_quats = gt_info[i]["pose_quats"]
# Get the predicted factored quantities for the current view
pred_ray_directions = pred_info[i]["ray_directions"]
pred_depth = pred_info[i][self.depth_type_for_loss]
pred_pose_trans = pred_info[i]["pose_trans"]
pred_pose_quats = pred_info[i]["pose_quats"]
# Get the predicted world-frame pointmaps using the different factors
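# Each predicted factor is swapped into an otherwise ground-truth factorization, so every loss term
# below isolates the error contributed by that single factor in world-frame pointmap space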
if self.depth_type_for_loss == "depth_along_ray":
pred_ray_directions_pts3d = (
convert_ray_dirs_depth_along_ray_pose_trans_quats_to_pointmap(
pred_ray_directions,
gt_depth,
gt_pose_trans,
gt_pose_quats,
)
)
pred_depth_pts3d = (
convert_ray_dirs_depth_along_ray_pose_trans_quats_to_pointmap(
gt_ray_directions,
pred_depth,
gt_pose_trans,
gt_pose_quats,
)
)
pred_pose_trans_pts3d = (
convert_ray_dirs_depth_along_ray_pose_trans_quats_to_pointmap(
gt_ray_directions,
gt_depth,
pred_pose_trans,
gt_pose_quats,
)
)
pred_pose_quats_pts3d = (
convert_ray_dirs_depth_along_ray_pose_trans_quats_to_pointmap(
gt_ray_directions,
gt_depth,
gt_pose_trans,
pred_pose_quats,
)
)
else:
raise NotImplementedError(
"Disentangled loss currently supports only depth_type_for_loss='depth_along_ray'"
)
# Select only the valid pixels or flatten per image, as configured
if not self.flatten_across_image_only:
# Flatten the points across the entire batch with the masks
pred_ray_directions_pts3d = pred_ray_directions_pts3d[valid_masks[i]]
pred_depth_pts3d = pred_depth_pts3d[valid_masks[i]]
pred_pose_trans_pts3d = pred_pose_trans_pts3d[valid_masks[i]]
pred_pose_quats_pts3d = pred_pose_quats_pts3d[valid_masks[i]]
gt_pts3d = gt_pts3d[valid_masks[i]]
else:
# Flatten the H x W dimensions to H*W
batch_size, _, _, pts_dim = gt_pts3d.shape
pred_ray_directions_pts3d = pred_ray_directions_pts3d.view(
batch_size, -1, pts_dim
)
pred_depth_pts3d = pred_depth_pts3d.view(batch_size, -1, pts_dim)
pred_pose_trans_pts3d = pred_pose_trans_pts3d.view(
batch_size, -1, pts_dim
)
pred_pose_quats_pts3d = pred_pose_quats_pts3d.view(
batch_size, -1, pts_dim
)
gt_pts3d = gt_pts3d.view(batch_size, -1, pts_dim)
valid_masks[i] = valid_masks[i].view(batch_size, -1)
# Apply loss in log space if specified
if self.loss_in_log:
gt_pts3d = apply_log_to_norm(gt_pts3d)
pred_ray_directions_pts3d = apply_log_to_norm(pred_ray_directions_pts3d)
pred_depth_pts3d = apply_log_to_norm(pred_depth_pts3d)
pred_pose_trans_pts3d = apply_log_to_norm(pred_pose_trans_pts3d)
pred_pose_quats_pts3d = apply_log_to_norm(pred_pose_quats_pts3d)
# Compute pose translation loss
pose_trans_loss = self.criterion(
pred_pose_trans_pts3d, gt_pts3d, factor="pose_trans"
)
pose_trans_loss = pose_trans_loss * self.pose_trans_loss_weight
pose_trans_losses.append(pose_trans_loss)
# Compute pose rotation loss
pose_quats_loss = self.criterion(
pred_pose_quats_pts3d, gt_pts3d, factor="pose_quats"
)
pose_quats_loss = pose_quats_loss * self.pose_quats_loss_weight
pose_quats_losses.append(pose_quats_loss)
# Compute ray direction loss
ray_directions_loss = self.criterion(
pred_ray_directions_pts3d, gt_pts3d, factor="ray_directions"
)
ray_directions_loss = ray_directions_loss * self.ray_directions_loss_weight
ray_directions_losses.append(ray_directions_loss)
# Compute depth loss
depth_loss = self.criterion(pred_depth_pts3d, gt_pts3d, factor="depth")
depth_loss = depth_loss * self.depth_loss_weight
depth_losses.append(depth_loss)
# Compute the scale loss
if gt_metric_norm_factor is not None:
if self.loss_in_log:
gt_metric_norm_factor = apply_log_to_norm(gt_metric_norm_factor)
pr_metric_norm_factor = apply_log_to_norm(pr_metric_norm_factor)
scale_loss = (
self.criterion(
pr_metric_norm_factor, gt_metric_norm_factor, factor="scale"
)
* self.scale_loss_weight
)
else:
scale_loss = None
# Use helper function to generate loss terms and details
losses_dict = {}
losses_dict.update(
{
self.depth_type_for_loss: {
"values": depth_losses,
"use_mask": True,
"is_multi_view": True,
},
"ray_directions": {
"values": ray_directions_losses,
"use_mask": True,
"is_multi_view": True,
},
"pose_quats": {
"values": pose_quats_losses,
"use_mask": True,
"is_multi_view": True,
},
"pose_trans": {
"values": pose_trans_losses,
"use_mask": True,
"is_multi_view": True,
},
"scale": {
"values": scale_loss,
"use_mask": False,
"is_multi_view": False,
},
}
)
loss_terms, details = get_loss_terms_and_details(
losses_dict,
valid_masks,
type(self).__name__,
n_views,
self.flatten_across_image_only,
)
losses = Sum(*loss_terms)
return losses, details
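# Illustrative sketch of the disentangling idea used above: each factor is supervised by swapping
# only that predicted factor into an otherwise ground-truth factorization and regressing the
# resulting world-frame pointmap against the GT pointmap. A minimal version for the depth factor
# alone, with assumed shapes rays (B, H, W, 3), depth (B, H, W, 1), trans (B, 3), quats (B, 4),
# pts3d (B, H, W, 3) and a boolean (B, H, W) mask:
def _example_disentangled_depth_loss(gt_rays, pred_depth, gt_trans, gt_quats, gt_pts3d, valid_mask):
    """Hypothetical helper: mean L1 on the pointmap rebuilt with the predicted depth only."""
    pred_depth_pts3d = convert_ray_dirs_depth_along_ray_pose_trans_quats_to_pointmap(
        gt_rays, pred_depth, gt_trans, gt_quats
    )
    return (pred_depth_pts3d[valid_mask] - gt_pts3d[valid_mask]).abs().mean()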
class DisentangledFactoredGeometryScaleRegr3DPlusNormalGMLoss(
DisentangledFactoredGeometryScaleRegr3D
):
"""
Disentangled Regression, Normals & Gradient Matching Loss for Factored Geometry & Scale.
"""
def __init__(
self,
criterion,
norm_predictions=True,
norm_mode="avg_dis",
loss_in_log=True,
flatten_across_image_only=False,
depth_type_for_loss="depth_along_ray",
depth_loss_weight=1,
ray_directions_loss_weight=1,
pose_quats_loss_weight=1,
pose_trans_loss_weight=1,
scale_loss_weight=1,
apply_normal_and_gm_loss_to_synthetic_data_only=True,
normal_loss_weight=1,
gm_loss_weight=1,
):
"""
Initialize the disentangled loss criterion for Factored Geometry (Ray Directions, Depth, Pose) & Scale.
See parent class (DisentangledFactoredGeometryScaleRegr3D) for more details.
Additionally computes:
(1) Normal Loss over the Camera Frame Pointmaps in Euclidean coordinates,
(2) Gradient Matching (GM) Loss over the Depth Z in log space. (MiDaS applied its GM loss in disparity space)
Args:
criterion (BaseCriterion): The base criterion to use for computing the loss.
norm_predictions (bool): If True, normalize the predictions before computing the loss.
norm_mode (str): Normalization mode for the gt and predicted (optional) scene representation. Default: "avg_dis".
loss_in_log (bool): If True, apply logarithmic transformation to input before
computing the loss for depth, pointmaps and scale. Default: True.
flatten_across_image_only (bool): If True, flatten H x W dimensions only when computing
the loss. If False, flatten across batch and spatial dimensions. Default: False.
depth_type_for_loss (str): Type of depth to use for loss computation. Default: "depth_along_ray".
Options: "depth_along_ray", "depth_z"
depth_loss_weight (float): Weight to use for the depth loss. Default: 1.
ray_directions_loss_weight (float): Weight to use for the ray directions loss. Default: 1.
pose_quats_loss_weight (float): Weight to use for the pose quats loss. Default: 1.
pose_trans_loss_weight (float): Weight to use for the pose trans loss. Default: 1.
scale_loss_weight (float): Weight to use for the scale loss. Default: 1.
apply_normal_and_gm_loss_to_synthetic_data_only (bool): If True, apply the normal and gm loss only to synthetic data.
If False, apply the normal and gm loss to all data. Default: True.
normal_loss_weight (float): Weight to use for the normal loss. Default: 1.
gm_loss_weight (float): Weight to use for the gm loss. Default: 1.
"""
super().__init__(
criterion=criterion,
norm_predictions=norm_predictions,
norm_mode=norm_mode,
loss_in_log=loss_in_log,
flatten_across_image_only=flatten_across_image_only,
depth_type_for_loss=depth_type_for_loss,
depth_loss_weight=depth_loss_weight,
ray_directions_loss_weight=ray_directions_loss_weight,
pose_quats_loss_weight=pose_quats_loss_weight,
pose_trans_loss_weight=pose_trans_loss_weight,
scale_loss_weight=scale_loss_weight,
)
self.apply_normal_and_gm_loss_to_synthetic_data_only = (
apply_normal_and_gm_loss_to_synthetic_data_only
)
self.normal_loss_weight = normal_loss_weight
self.gm_loss_weight = gm_loss_weight
def compute_loss(self, batch, preds, **kw):
(
gt_info,
pred_info,
valid_masks,
ambiguous_masks,
gt_metric_norm_factor,
pr_metric_norm_factor,
) = self.get_all_info(batch, preds, **kw)
n_views = len(batch)
normal_losses = []
gradient_matching_losses = []
pose_trans_losses = []
pose_quats_losses = []
ray_directions_losses = []
depth_losses = []
for i in range(n_views):
# Get the camera frame points, log space depth_z & valid masks
pred_local_pts3d = pred_info[i]["pts3d_cam"]
pred_depth_z = pred_local_pts3d[..., 2:]
pred_depth_z = apply_log_to_norm(pred_depth_z)
gt_local_pts3d = gt_info[i]["pts3d_cam"]
gt_depth_z = gt_local_pts3d[..., 2:]
gt_depth_z = apply_log_to_norm(gt_depth_z)
valid_mask_for_normal_gm_loss = valid_masks[i].clone()
# Update the validity mask for normal & gm loss based on the synthetic data mask if required
if self.apply_normal_and_gm_loss_to_synthetic_data_only:
synthetic_mask = batch[i]["is_synthetic"] # (B, )
synthetic_mask = synthetic_mask.unsqueeze(-1).unsqueeze(-1) # (B, 1, 1)
synthetic_mask = synthetic_mask.expand(
-1, pred_depth_z.shape[1], pred_depth_z.shape[2]
) # (B, H, W)
valid_mask_for_normal_gm_loss = (
valid_mask_for_normal_gm_loss & synthetic_mask
)
# Compute the normal loss
normal_loss = compute_normal_loss(
pred_local_pts3d, gt_local_pts3d, valid_mask_for_normal_gm_loss.clone()
)
normal_loss = normal_loss * self.normal_loss_weight
normal_losses.append(normal_loss)
# Compute the gradient matching loss
gradient_matching_loss = compute_gradient_matching_loss(
pred_depth_z, gt_depth_z, valid_mask_for_normal_gm_loss.clone()
)
gradient_matching_loss = gradient_matching_loss * self.gm_loss_weight
gradient_matching_losses.append(gradient_matching_loss)
# Get the GT factored quantities for the current view
gt_pts3d = gt_info[i]["pts3d"]
gt_ray_directions = gt_info[i]["ray_directions"]
gt_depth = gt_info[i][self.depth_type_for_loss]
gt_pose_trans = gt_info[i]["pose_trans"]
gt_pose_quats = gt_info[i]["pose_quats"]
# Get the predicted factored quantities for the current view
pred_ray_directions = pred_info[i]["ray_directions"]
pred_depth = pred_info[i][self.depth_type_for_loss]
pred_pose_trans = pred_info[i]["pose_trans"]
pred_pose_quats = pred_info[i]["pose_quats"]
# Get the predicted world-frame pointmaps using the different factors
if self.depth_type_for_loss == "depth_along_ray":
pred_ray_directions_pts3d = (
convert_ray_dirs_depth_along_ray_pose_trans_quats_to_pointmap(
pred_ray_directions,
gt_depth,
gt_pose_trans,
gt_pose_quats,
)
)
pred_depth_pts3d = (
convert_ray_dirs_depth_along_ray_pose_trans_quats_to_pointmap(
gt_ray_directions,
pred_depth,
gt_pose_trans,
gt_pose_quats,
)
)
pred_pose_trans_pts3d = (
convert_ray_dirs_depth_along_ray_pose_trans_quats_to_pointmap(
gt_ray_directions,
gt_depth,
pred_pose_trans,
gt_pose_quats,
)
)
pred_pose_quats_pts3d = (
convert_ray_dirs_depth_along_ray_pose_trans_quats_to_pointmap(
gt_ray_directions,
gt_depth,
gt_pose_trans,
pred_pose_quats,
)
)
else:
raise NotImplementedError(
"Disentangled loss currently supports only depth_type_for_loss='depth_along_ray'"
)
# Select only the valid pixels or flatten per image, as configured
if not self.flatten_across_image_only:
# Flatten the points across the entire batch with the masks
pred_ray_directions_pts3d = pred_ray_directions_pts3d[valid_masks[i]]
pred_depth_pts3d = pred_depth_pts3d[valid_masks[i]]
pred_pose_trans_pts3d = pred_pose_trans_pts3d[valid_masks[i]]
pred_pose_quats_pts3d = pred_pose_quats_pts3d[valid_masks[i]]
gt_pts3d = gt_pts3d[valid_masks[i]]
else:
# Flatten the H x W dimensions to H*W
batch_size, _, _, pts_dim = gt_pts3d.shape
pred_ray_directions_pts3d = pred_ray_directions_pts3d.view(
batch_size, -1, pts_dim
)
pred_depth_pts3d = pred_depth_pts3d.view(batch_size, -1, pts_dim)
pred_pose_trans_pts3d = pred_pose_trans_pts3d.view(
batch_size, -1, pts_dim
)
pred_pose_quats_pts3d = pred_pose_quats_pts3d.view(
batch_size, -1, pts_dim
)
gt_pts3d = gt_pts3d.view(batch_size, -1, pts_dim)
valid_masks[i] = valid_masks[i].view(batch_size, -1)
# Apply loss in log space if specified
if self.loss_in_log:
gt_pts3d = apply_log_to_norm(gt_pts3d)
pred_ray_directions_pts3d = apply_log_to_norm(pred_ray_directions_pts3d)
pred_depth_pts3d = apply_log_to_norm(pred_depth_pts3d)
pred_pose_trans_pts3d = apply_log_to_norm(pred_pose_trans_pts3d)
pred_pose_quats_pts3d = apply_log_to_norm(pred_pose_quats_pts3d)
# Compute pose translation loss
pose_trans_loss = self.criterion(
pred_pose_trans_pts3d, gt_pts3d, factor="pose_trans"
)
pose_trans_loss = pose_trans_loss * self.pose_trans_loss_weight
pose_trans_losses.append(pose_trans_loss)
# Compute pose rotation loss
pose_quats_loss = self.criterion(
pred_pose_quats_pts3d, gt_pts3d, factor="pose_quats"
)
pose_quats_loss = pose_quats_loss * self.pose_quats_loss_weight
pose_quats_losses.append(pose_quats_loss)
# Compute ray direction loss
ray_directions_loss = self.criterion(
pred_ray_directions_pts3d, gt_pts3d, factor="ray_directions"
)
ray_directions_loss = ray_directions_loss * self.ray_directions_loss_weight
ray_directions_losses.append(ray_directions_loss)
# Compute depth loss
depth_loss = self.criterion(pred_depth_pts3d, gt_pts3d, factor="depth")
depth_loss = depth_loss * self.depth_loss_weight
depth_losses.append(depth_loss)
# Compute the scale loss
if gt_metric_norm_factor is not None:
if self.loss_in_log:
gt_metric_norm_factor = apply_log_to_norm(gt_metric_norm_factor)
pr_metric_norm_factor = apply_log_to_norm(pr_metric_norm_factor)
scale_loss = (
self.criterion(
pr_metric_norm_factor, gt_metric_norm_factor, factor="scale"
)
* self.scale_loss_weight
)
else:
scale_loss = None
# Use helper function to generate loss terms and details
losses_dict = {}
losses_dict.update(
{
self.depth_type_for_loss: {
"values": depth_losses,
"use_mask": True,
"is_multi_view": True,
},
"ray_directions": {
"values": ray_directions_losses,
"use_mask": True,
"is_multi_view": True,
},
"pose_quats": {
"values": pose_quats_losses,
"use_mask": True,
"is_multi_view": True,
},
"pose_trans": {
"values": pose_trans_losses,
"use_mask": True,
"is_multi_view": True,
},
"scale": {
"values": scale_loss,
"use_mask": False,
"is_multi_view": False,
},
"normal": {
"values": normal_losses,
"use_mask": False,
"is_multi_view": True,
},
"gradient_matching": {
"values": gradient_matching_losses,
"use_mask": False,
"is_multi_view": True,
},
}
)
loss_terms, details = get_loss_terms_and_details(
losses_dict,
valid_masks,
type(self).__name__,
n_views,
self.flatten_across_image_only,
)
losses = Sum(*loss_terms)
return losses, details
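# Illustrative sketch of the losses_dict schema that every compute_loss above hands to
# get_loss_terms_and_details. Dummy tensors stand in for real per-view losses; the per-view
# entries are assumed to be already masked/flattened (matching flatten_across_image_only=False):
def _example_losses_dict_schema():
    """Hypothetical demo: build a minimal losses_dict for 2 views and flatten it into loss terms."""
    n_views = 2
    per_view_losses = [torch.rand(8) for _ in range(n_views)]
    valid_masks = [torch.ones(8, dtype=torch.bool) for _ in range(n_views)]
    losses_dict = {
        "depth_along_ray": {"values": per_view_losses, "use_mask": True, "is_multi_view": True},
        "scale": {"values": torch.rand(4), "use_mask": False, "is_multi_view": False},
    }
    return get_loss_terms_and_details(
        losses_dict, valid_masks, "ExampleLoss", n_views, flatten_across_image_only=False
    )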