Spaces:
Running
on
Zero
Running
on
Zero
aknapitsch user
commited on
Commit
·
37de32d
1
Parent(s):
8c1e404
simpler inference and refactoring
Browse files- app.py +165 -445
- hf_utils/vgg_geometry.py +0 -166
- hf_utils/visual_util.py +5 -5
- mapanything/__init__.py +0 -0
- mapanything/datasets/wai/ase.py +1 -1
- mapanything/datasets/wai/bedlam.py +1 -1
- mapanything/datasets/wai/blendedmvs.py +2 -9
- mapanything/datasets/wai/dl3dv.py +7 -28
- mapanything/datasets/wai/dtu.py +1 -1
- mapanything/datasets/wai/dynamicreplica.py +1 -1
- mapanything/datasets/wai/eth3d.py +1 -1
- mapanything/datasets/wai/gta_sfm.py +1 -1
- mapanything/datasets/wai/matrixcity.py +1 -1
- mapanything/datasets/wai/megadepth.py +2 -9
- mapanything/datasets/wai/mpsd.py +2 -9
- mapanything/datasets/wai/mvs_synth.py +1 -1
- mapanything/datasets/wai/paralleldomain4d.py +1 -1
- mapanything/datasets/wai/sailvos3d.py +1 -1
- mapanything/datasets/wai/scannetpp.py +1 -1
- mapanything/datasets/wai/spring.py +2 -9
- mapanything/datasets/wai/structured3d.py +1 -1
- mapanything/datasets/wai/tav2_wb.py +2 -9
- mapanything/datasets/wai/unrealstereo4k.py +1 -1
- mapanything/datasets/wai/xrooms.py +1 -1
- mapanything/models/external/README.md +5 -0
- mapanything/models/external/moge/models/v1.py +1 -1
- mapanything/models/external/moge/models/v2.py +1 -1
- mapanything/models/mapanything/ablations.py +4 -2
- mapanything/models/mapanything/model.py +220 -4
- mapanything/models/mapanything/modular_dust3r.py +4 -2
- mapanything/train/losses.py +283 -9
- mapanything/utils/geometry.py +91 -0
- mapanything/utils/image.py +11 -10
- mapanything/utils/inference.py +389 -0
- mapanything/utils/viz.py +2 -2
- mapanything/utils/wai/__init__.py +3 -0
- mapanything/utils/wai/basic_dataset.py +131 -0
- mapanything/utils/wai/camera.py +263 -0
- mapanything/utils/wai/colormaps/colors_fps_5k.npz +3 -0
- mapanything/utils/wai/core.py +492 -0
- mapanything/utils/wai/intersection_check.py +462 -0
- mapanything/utils/wai/io.py +1373 -0
- mapanything/utils/wai/m_ops.py +346 -0
- mapanything/utils/wai/ops.py +368 -0
- mapanything/utils/wai/scene_frame.py +431 -0
- mapanything/utils/wai/semantics.py +40 -0
- requirements.txt +1 -0
app.py
CHANGED
|
@@ -18,24 +18,21 @@ import gradio as gr
|
|
| 18 |
import numpy as np
|
| 19 |
import spaces
|
| 20 |
import torch
|
| 21 |
-
from huggingface_hub import hf_hub_download
|
| 22 |
|
| 23 |
sys.path.append("mapanything/")
|
| 24 |
|
| 25 |
from hf_utils.css_and_html import (
|
|
|
|
|
|
|
| 26 |
get_acknowledgements_html,
|
| 27 |
get_description_html,
|
| 28 |
get_gradio_theme,
|
| 29 |
get_header_html,
|
| 30 |
-
GRADIO_CSS,
|
| 31 |
-
MEASURE_INSTRUCTIONS_HTML,
|
| 32 |
)
|
| 33 |
-
from hf_utils.vgg_geometry import unproject_depth_map_to_point_map
|
| 34 |
from hf_utils.visual_util import predictions_to_glb
|
| 35 |
-
from mapanything.models import
|
| 36 |
-
from mapanything.utils.geometry import
|
| 37 |
from mapanything.utils.image import load_images, rgb
|
| 38 |
-
from mapanything.utils.inference import loss_of_one_batch_multi_view
|
| 39 |
|
| 40 |
|
| 41 |
def get_logo_base64():
|
|
@@ -103,69 +100,16 @@ def init_hydra_config(config_path, overrides=None):
|
|
| 103 |
return cfg
|
| 104 |
|
| 105 |
|
| 106 |
-
def init_inference_model(config, ckpt_path, device):
|
| 107 |
-
"Initialize the model for inference"
|
| 108 |
-
if isinstance(config, dict):
|
| 109 |
-
config_path = config["path"]
|
| 110 |
-
overrrides = config["config_overrides"]
|
| 111 |
-
model_args = init_hydra_config(config_path, overrides=overrrides)
|
| 112 |
-
model = init_model(model_args.model.model_str, model_args.model.model_config)
|
| 113 |
-
else:
|
| 114 |
-
config_path = config
|
| 115 |
-
model_args = init_hydra_config(config_path)
|
| 116 |
-
model = init_model(model_args.model_str, model_args.model_config)
|
| 117 |
-
model.to(device)
|
| 118 |
-
if ckpt_path is not None:
|
| 119 |
-
print("Loading model from: ", ckpt_path)
|
| 120 |
-
|
| 121 |
-
# Load HuggingFace token for private repositories
|
| 122 |
-
hf_token = load_hf_token()
|
| 123 |
-
|
| 124 |
-
# Try to download from HuggingFace Hub first if it's a HF URL
|
| 125 |
-
if "huggingface.co" in ckpt_path:
|
| 126 |
-
try:
|
| 127 |
-
# Extract repo_id and filename from URL
|
| 128 |
-
# URL format: https://huggingface.co/facebook/MapAnything/resolve/main/mapa_curri_24v_13d_48ipg_64g.pth
|
| 129 |
-
parts = ckpt_path.replace("https://huggingface.co/", "").split("/")
|
| 130 |
-
repo_id = f"{parts[0]}/{parts[1]}" # e.g., "facebook/MapAnything"
|
| 131 |
-
filename = "/".join(
|
| 132 |
-
parts[4:]
|
| 133 |
-
) # e.g., "mapa_curri_24v_13d_48ipg_64g.pth"
|
| 134 |
-
|
| 135 |
-
print(f"Downloading from HuggingFace Hub: {repo_id}/{filename}")
|
| 136 |
-
local_file = hf_hub_download(
|
| 137 |
-
repo_id=repo_id,
|
| 138 |
-
filename=filename,
|
| 139 |
-
token=hf_token,
|
| 140 |
-
cache_dir=None, # Use default cache
|
| 141 |
-
)
|
| 142 |
-
ckpt = torch.load(local_file, map_location=device, weights_only=False)
|
| 143 |
-
except Exception as e:
|
| 144 |
-
print(f"HuggingFace Hub download failed: {e}")
|
| 145 |
-
print("Falling back to torch.hub.load_state_dict_from_url...")
|
| 146 |
-
# Fallback to original method
|
| 147 |
-
ckpt = torch.hub.load_state_dict_from_url(
|
| 148 |
-
ckpt_path, map_location=device
|
| 149 |
-
)
|
| 150 |
-
else:
|
| 151 |
-
# Use original method for non-HF URLs
|
| 152 |
-
ckpt = torch.hub.load_state_dict_from_url(ckpt_path, map_location=device)
|
| 153 |
-
|
| 154 |
-
print(model.load_state_dict(ckpt["model"], strict=False))
|
| 155 |
-
model.eval()
|
| 156 |
-
return model
|
| 157 |
-
|
| 158 |
-
|
| 159 |
# MapAnything Configuration
|
| 160 |
high_level_config = {
|
| 161 |
"path": "configs/train.yaml",
|
|
|
|
| 162 |
"config_overrides": [
|
| 163 |
"machine=aws",
|
| 164 |
"model=mapanything",
|
| 165 |
"model/task=images_only",
|
| 166 |
"model.encoder.uses_torch_hub=false",
|
| 167 |
],
|
| 168 |
-
"checkpoint_path": "https://huggingface.co/facebook/MapAnything/resolve/main/mapa_curri_24v_13d_48ipg_64g.pth",
|
| 169 |
"trained_with_amp": True,
|
| 170 |
"trained_with_amp_dtype": "fp16",
|
| 171 |
"data_norm_type": "dinov2",
|
|
@@ -181,7 +125,7 @@ model = None
|
|
| 181 |
# 1) Core model inference
|
| 182 |
# -------------------------------------------------------------------------
|
| 183 |
@spaces.GPU(duration=120)
|
| 184 |
-
def run_model(target_dir, model_placeholder):
|
| 185 |
"""
|
| 186 |
Run the MapAnything model on images in the 'target_dir/images' folder and return predictions.
|
| 187 |
"""
|
|
@@ -191,15 +135,16 @@ def run_model(target_dir, model_placeholder):
|
|
| 191 |
# Device check
|
| 192 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 193 |
device = torch.device(device)
|
| 194 |
-
# if not torch.cuda.is_available():
|
| 195 |
-
# raise ValueError("CUDA is not available. Check your environment.")
|
| 196 |
|
| 197 |
# Initialize model if not already done
|
| 198 |
if model is None:
|
| 199 |
print("Initializing MapAnything model...")
|
| 200 |
-
|
| 201 |
-
|
|
|
|
|
|
|
| 202 |
)
|
|
|
|
| 203 |
else:
|
| 204 |
model = model.to(device)
|
| 205 |
|
|
@@ -208,30 +153,18 @@ def run_model(target_dir, model_placeholder):
|
|
| 208 |
# Load images using MapAnything's load_images function
|
| 209 |
print("Loading images...")
|
| 210 |
image_folder_path = os.path.join(target_dir, "images")
|
| 211 |
-
views = load_images(
|
| 212 |
-
image_folder_path,
|
| 213 |
-
resolution_set=high_level_config["resolution"],
|
| 214 |
-
verbose=False,
|
| 215 |
-
norm_type=high_level_config["data_norm_type"],
|
| 216 |
-
patch_size=high_level_config["patch_size"],
|
| 217 |
-
stride=1,
|
| 218 |
-
)
|
| 219 |
|
| 220 |
print(f"Loaded {len(views)} images")
|
| 221 |
if len(views) == 0:
|
| 222 |
raise ValueError("No images found. Check your upload.")
|
| 223 |
|
| 224 |
-
# Run
|
| 225 |
-
print("Running
|
| 226 |
-
|
| 227 |
-
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
None,
|
| 231 |
-
device,
|
| 232 |
-
use_amp=high_level_config["trained_with_amp"],
|
| 233 |
-
amp_dtype=high_level_config["trained_with_amp_dtype"],
|
| 234 |
-
)
|
| 235 |
|
| 236 |
# Convert predictions to format expected by visualization
|
| 237 |
predictions = {}
|
|
@@ -242,167 +175,40 @@ def run_model(target_dir, model_placeholder):
|
|
| 242 |
world_points_list = []
|
| 243 |
depth_maps_list = []
|
| 244 |
images_list = []
|
| 245 |
-
confidence_list = []
|
| 246 |
final_mask_list = []
|
| 247 |
|
| 248 |
-
#
|
| 249 |
-
|
| 250 |
-
|
| 251 |
-
|
| 252 |
-
|
| 253 |
-
|
| 254 |
-
break
|
| 255 |
-
|
| 256 |
-
# Extract predictions for each view
|
| 257 |
-
for view_idx, view in enumerate(views):
|
| 258 |
-
# Get image for colors
|
| 259 |
-
image = rgb(view["img"], norm_type=high_level_config["data_norm_type"])
|
| 260 |
|
| 261 |
-
|
| 262 |
-
|
| 263 |
-
|
| 264 |
-
|
| 265 |
-
# Get confidence data if available
|
| 266 |
-
confidence_map = None
|
| 267 |
-
if "conf" in pred_result[view_key]:
|
| 268 |
-
confidence_map = pred_result[view_key]["conf"][0].cpu().numpy()
|
| 269 |
-
|
| 270 |
-
# Compute final_mask just like in visualize_raw_inference_output function
|
| 271 |
-
# Create the prediction mask based on parameters
|
| 272 |
-
pred_mask = None
|
| 273 |
-
use_gt_mask_on_pred = False # Set based on your requirements
|
| 274 |
-
use_pred_mask = True # Set based on your requirements
|
| 275 |
-
use_non_ambi_mask = True # Set based on your requirements
|
| 276 |
-
use_conf_mask = False # Set based on your requirements
|
| 277 |
-
conf_percentile = 10 # Set based on your requirements
|
| 278 |
-
use_edge_mask = True # Set based on your requirements
|
| 279 |
-
pts_edge_tol = 5 # Set based on your requirements
|
| 280 |
-
depth_edge_rtol = 0.03 # Set based on your requirements
|
| 281 |
-
|
| 282 |
-
if use_pred_mask:
|
| 283 |
-
# Get non ambiguous mask if available and requested
|
| 284 |
-
has_non_ambiguous_mask = (
|
| 285 |
-
"non_ambiguous_mask" in pred_result[view_key] and use_non_ambi_mask
|
| 286 |
-
)
|
| 287 |
-
if has_non_ambiguous_mask:
|
| 288 |
-
non_ambiguous_mask = (
|
| 289 |
-
pred_result[view_key]["non_ambiguous_mask"][0].cpu().numpy()
|
| 290 |
-
)
|
| 291 |
-
pred_mask = non_ambiguous_mask
|
| 292 |
-
|
| 293 |
-
# Get confidence mask if available and requested
|
| 294 |
-
has_conf = "conf" in pred_result[view_key] and use_conf_mask
|
| 295 |
-
if has_conf:
|
| 296 |
-
confidences = pred_result[view_key]["conf"][0].cpu()
|
| 297 |
-
percentile_threshold = torch.quantile(
|
| 298 |
-
confidences, conf_percentile / 100.0
|
| 299 |
-
)
|
| 300 |
-
conf_mask = confidences > percentile_threshold
|
| 301 |
-
conf_mask = conf_mask.numpy()
|
| 302 |
-
if pred_mask is not None:
|
| 303 |
-
pred_mask = pred_mask & conf_mask
|
| 304 |
-
else:
|
| 305 |
-
pred_mask = conf_mask
|
| 306 |
-
|
| 307 |
-
# Apply edge mask if requested
|
| 308 |
-
if use_edge_mask and pred_mask is not None:
|
| 309 |
-
if "cam_quats" not in pred_result[view_key]:
|
| 310 |
-
# For direct point prediction
|
| 311 |
-
# Compute normals and edge mask
|
| 312 |
-
normals, normals_mask = points_to_normals(
|
| 313 |
-
pred_pts3d, mask=pred_mask
|
| 314 |
-
)
|
| 315 |
-
edge_mask = ~(
|
| 316 |
-
normals_edge(normals, tol=pts_edge_tol, mask=normals_mask)
|
| 317 |
-
)
|
| 318 |
-
else:
|
| 319 |
-
# For ray-based prediction
|
| 320 |
-
ray_depth = pred_result[view_key]["depth_along_ray"][0].cpu()
|
| 321 |
-
local_pts3d = (
|
| 322 |
-
pred_result[view_key]["ray_directions"][0].cpu() * ray_depth
|
| 323 |
-
)
|
| 324 |
-
depth_z = local_pts3d[..., 2].numpy()
|
| 325 |
|
| 326 |
-
|
| 327 |
-
|
| 328 |
-
|
| 329 |
-
|
| 330 |
-
|
| 331 |
-
|
| 332 |
-
|
| 333 |
-
)
|
| 334 |
-
if pred_mask is not None:
|
| 335 |
-
pred_mask = pred_mask & edge_mask
|
| 336 |
-
|
| 337 |
-
# Determine final mask to use (like in visualize_raw_inference_output)
|
| 338 |
-
final_mask = None
|
| 339 |
-
valid_mask = np.ones_like(
|
| 340 |
-
pred_pts3d[..., 0], dtype=bool
|
| 341 |
-
) # Create dummy valid_mask for app.py context
|
| 342 |
-
|
| 343 |
-
if use_gt_mask_on_pred:
|
| 344 |
-
final_mask = valid_mask
|
| 345 |
-
if use_pred_mask and pred_mask is not None:
|
| 346 |
-
final_mask = final_mask & pred_mask
|
| 347 |
-
elif use_pred_mask and pred_mask is not None:
|
| 348 |
-
final_mask = pred_mask
|
| 349 |
-
else:
|
| 350 |
-
final_mask = np.ones_like(valid_mask, dtype=bool)
|
| 351 |
-
|
| 352 |
-
# Check if we have camera pose and intrinsics data
|
| 353 |
-
if "cam_quats" in pred_result[view_key]:
|
| 354 |
-
# Get decoupled quantities (like in visualize_raw_custom_data_inference_output)
|
| 355 |
-
cam_quats = pred_result[view_key]["cam_quats"][0].cpu()
|
| 356 |
-
cam_trans = pred_result[view_key]["cam_trans"][0].cpu()
|
| 357 |
-
ray_directions = pred_result[view_key]["ray_directions"][0].cpu()
|
| 358 |
-
ray_depth = pred_result[view_key]["depth_along_ray"][0].cpu()
|
| 359 |
-
|
| 360 |
-
# Convert the quantities
|
| 361 |
-
from mapanything.utils.geometry import (
|
| 362 |
-
quaternion_to_rotation_matrix,
|
| 363 |
-
recover_pinhole_intrinsics_from_ray_directions,
|
| 364 |
-
)
|
| 365 |
|
| 366 |
-
|
| 367 |
-
|
| 368 |
-
cam_pose[:3, :3] = cam_rot
|
| 369 |
-
cam_pose[:3, 3] = cam_trans
|
| 370 |
-
cam_pose = np.linalg.inv(cam_pose)
|
| 371 |
-
cam_intrinsics = recover_pinhole_intrinsics_from_ray_directions(
|
| 372 |
-
ray_directions, use_geometric_calculation=True
|
| 373 |
-
)
|
| 374 |
|
| 375 |
-
|
| 376 |
-
local_pts3d = ray_directions * ray_depth
|
| 377 |
-
depth_z = local_pts3d[..., 2]
|
| 378 |
|
| 379 |
-
|
| 380 |
-
|
| 381 |
-
|
| 382 |
-
|
| 383 |
-
|
| 384 |
-
|
| 385 |
-
|
| 386 |
-
extrinsic = np.eye(3, 4) # Identity rotation, zero translation
|
| 387 |
-
# intrinsic: (3, 3) - camera intrinsic matrix
|
| 388 |
-
intrinsic = np.eye(3)
|
| 389 |
-
# depth_z: (H, W) - dummy depth values
|
| 390 |
-
depth_z = np.zeros_like(pred_pts3d[..., 0])
|
| 391 |
-
|
| 392 |
-
# Append to lists
|
| 393 |
-
extrinsic_list.append(extrinsic)
|
| 394 |
-
intrinsic_list.append(intrinsic)
|
| 395 |
-
world_points_list.append(pred_pts3d)
|
| 396 |
-
depth_maps_list.append(depth_z)
|
| 397 |
-
images_list.append(image[0]) # Add image to list
|
| 398 |
-
final_mask_list.append(final_mask) # Add final_mask to list
|
| 399 |
-
|
| 400 |
-
# Add confidence data (or None if not available)
|
| 401 |
-
if confidence_map is not None:
|
| 402 |
-
confidence_list.append(confidence_map)
|
| 403 |
-
elif has_confidence:
|
| 404 |
-
# If some views have confidence but this one doesn't, add dummy confidence
|
| 405 |
-
confidence_list.append(np.ones_like(depth_z))
|
| 406 |
|
| 407 |
# Convert lists to numpy arrays with required shapes
|
| 408 |
# extrinsic: (S, 3, 4) - batch of camera extrinsic matrices
|
|
@@ -419,26 +225,18 @@ def run_model(target_dir, model_placeholder):
|
|
| 419 |
# Add channel dimension if needed to match (S, H, W, 1) format
|
| 420 |
if len(depth_maps.shape) == 3:
|
| 421 |
depth_maps = depth_maps[..., np.newaxis]
|
|
|
|
| 422 |
predictions["depth"] = depth_maps
|
| 423 |
|
| 424 |
# images: (S, H, W, 3) - batch of input images
|
| 425 |
predictions["images"] = np.stack(images_list, axis=0)
|
| 426 |
|
| 427 |
-
# confidence: (S, H, W) - batch of confidence maps (only if available)
|
| 428 |
-
if confidence_list:
|
| 429 |
-
predictions["confidence"] = np.stack(confidence_list, axis=0)
|
| 430 |
-
|
| 431 |
# final_mask: (S, H, W) - batch of final masks for filtering
|
| 432 |
predictions["final_mask"] = np.stack(final_mask_list, axis=0)
|
| 433 |
|
| 434 |
-
world_points = unproject_depth_map_to_point_map(
|
| 435 |
-
depth_maps, predictions["extrinsic"], predictions["intrinsic"]
|
| 436 |
-
)
|
| 437 |
-
predictions["world_points_from_depth"] = world_points
|
| 438 |
-
|
| 439 |
# Process data for visualization tabs (depth, normal, measure)
|
| 440 |
processed_data = process_predictions_for_visualization(
|
| 441 |
-
|
| 442 |
)
|
| 443 |
|
| 444 |
# Clean up
|
|
@@ -474,43 +272,69 @@ def get_view_data_by_index(processed_data, view_index):
|
|
| 474 |
return processed_data[view_keys[view_index]]
|
| 475 |
|
| 476 |
|
| 477 |
-
def update_depth_view(processed_data, view_index
|
| 478 |
-
"""Update depth view for a specific view index
|
| 479 |
view_data = get_view_data_by_index(processed_data, view_index)
|
| 480 |
if view_data is None or view_data["depth"] is None:
|
| 481 |
return None
|
| 482 |
|
| 483 |
# Use confidence filtering if available
|
| 484 |
confidence = view_data.get("confidence")
|
| 485 |
-
return colorize_depth(
|
| 486 |
-
view_data["depth"], confidence=confidence, conf_thres=conf_thres
|
| 487 |
-
)
|
| 488 |
|
| 489 |
|
| 490 |
-
def update_normal_view(processed_data, view_index
|
| 491 |
-
"""Update normal view for a specific view index
|
| 492 |
view_data = get_view_data_by_index(processed_data, view_index)
|
| 493 |
if view_data is None or view_data["normal"] is None:
|
| 494 |
return None
|
| 495 |
|
| 496 |
# Use confidence filtering if available
|
| 497 |
confidence = view_data.get("confidence")
|
| 498 |
-
return colorize_normal(
|
| 499 |
-
view_data["normal"], confidence=confidence, conf_thres=conf_thres
|
| 500 |
-
)
|
| 501 |
|
| 502 |
|
| 503 |
def update_measure_view(processed_data, view_index):
|
| 504 |
-
"""Update measure view for a specific view index"""
|
| 505 |
view_data = get_view_data_by_index(processed_data, view_index)
|
| 506 |
if view_data is None:
|
| 507 |
return None, [] # image, measure_points
|
| 508 |
-
return view_data["image"], []
|
| 509 |
|
|
|
|
|
|
|
| 510 |
|
| 511 |
-
|
| 512 |
-
|
| 513 |
-
):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 514 |
"""Navigate depth view (direction: -1 for previous, +1 for next)"""
|
| 515 |
if processed_data is None or len(processed_data) == 0:
|
| 516 |
return "View 1", None
|
|
@@ -525,14 +349,12 @@ def navigate_depth_view(
|
|
| 525 |
new_view = (current_view + direction) % num_views
|
| 526 |
|
| 527 |
new_selector_value = f"View {new_view + 1}"
|
| 528 |
-
depth_vis = update_depth_view(processed_data, new_view
|
| 529 |
|
| 530 |
return new_selector_value, depth_vis
|
| 531 |
|
| 532 |
|
| 533 |
-
def navigate_normal_view(
|
| 534 |
-
processed_data, current_selector_value, direction, conf_thres=None
|
| 535 |
-
):
|
| 536 |
"""Navigate normal view (direction: -1 for previous, +1 for next)"""
|
| 537 |
if processed_data is None or len(processed_data) == 0:
|
| 538 |
return "View 1", None
|
|
@@ -547,7 +369,7 @@ def navigate_normal_view(
|
|
| 547 |
new_view = (current_view + direction) % num_views
|
| 548 |
|
| 549 |
new_selector_value = f"View {new_view + 1}"
|
| 550 |
-
normal_vis = update_normal_view(processed_data, new_view
|
| 551 |
|
| 552 |
return new_selector_value, normal_vis
|
| 553 |
|
|
@@ -572,14 +394,14 @@ def navigate_measure_view(processed_data, current_selector_value, direction):
|
|
| 572 |
return new_selector_value, measure_image, measure_points
|
| 573 |
|
| 574 |
|
| 575 |
-
def populate_visualization_tabs(processed_data
|
| 576 |
"""Populate the depth, normal, and measure tabs with processed data"""
|
| 577 |
if processed_data is None or len(processed_data) == 0:
|
| 578 |
return None, None, None, []
|
| 579 |
|
| 580 |
# Use update functions to ensure confidence filtering is applied from the start
|
| 581 |
-
depth_vis = update_depth_view(processed_data, 0
|
| 582 |
-
normal_vis = update_normal_view(processed_data, 0
|
| 583 |
measure_img, _ = update_measure_view(processed_data, 0)
|
| 584 |
|
| 585 |
return depth_vis, normal_vis, measure_img, []
|
|
@@ -683,13 +505,13 @@ def update_gallery_on_upload(input_video, input_images, s_time_interval=1.0):
|
|
| 683 |
@spaces.GPU(duration=120)
|
| 684 |
def gradio_demo(
|
| 685 |
target_dir,
|
| 686 |
-
conf_thres=3.0,
|
| 687 |
frame_filter="All",
|
| 688 |
show_cam=True,
|
| 689 |
filter_sky=False,
|
| 690 |
filter_black_bg=False,
|
| 691 |
filter_white_bg=False,
|
| 692 |
-
|
|
|
|
| 693 |
):
|
| 694 |
"""
|
| 695 |
Perform reconstruction using the already-created target_dir/images.
|
|
@@ -716,7 +538,9 @@ def gradio_demo(
|
|
| 716 |
|
| 717 |
print("Running MapAnything model...")
|
| 718 |
with torch.no_grad():
|
| 719 |
-
predictions, processed_data = run_model(
|
|
|
|
|
|
|
| 720 |
|
| 721 |
# Save predictions
|
| 722 |
prediction_save_path = os.path.join(target_dir, "predictions.npz")
|
|
@@ -729,13 +553,12 @@ def gradio_demo(
|
|
| 729 |
# Build a GLB file name
|
| 730 |
glbfile = os.path.join(
|
| 731 |
target_dir,
|
| 732 |
-
f"glbscene_{
|
| 733 |
)
|
| 734 |
|
| 735 |
# Convert predictions to GLB
|
| 736 |
glbscene = predictions_to_glb(
|
| 737 |
predictions,
|
| 738 |
-
conf_thres=conf_thres,
|
| 739 |
filter_by_frames=frame_filter,
|
| 740 |
show_cam=show_cam,
|
| 741 |
target_dir=target_dir,
|
|
@@ -743,7 +566,6 @@ def gradio_demo(
|
|
| 743 |
mask_sky=filter_sky,
|
| 744 |
mask_black_bg=filter_black_bg,
|
| 745 |
mask_white_bg=filter_white_bg,
|
| 746 |
-
mask_ambiguous=mask_ambiguous,
|
| 747 |
)
|
| 748 |
glbscene.export(file_obj=glbfile)
|
| 749 |
|
|
@@ -760,7 +582,7 @@ def gradio_demo(
|
|
| 760 |
|
| 761 |
# Populate visualization tabs with processed data
|
| 762 |
depth_vis, normal_vis, measure_img, measure_pts = populate_visualization_tabs(
|
| 763 |
-
processed_data
|
| 764 |
)
|
| 765 |
|
| 766 |
# Update view selectors based on available views
|
|
@@ -860,29 +682,30 @@ def colorize_normal(normal_map, confidence=None, conf_thres=None):
|
|
| 860 |
return normal_vis
|
| 861 |
|
| 862 |
|
| 863 |
-
def process_predictions_for_visualization(
|
| 864 |
"""Extract depth, normal, and 3D points from predictions for visualization"""
|
| 865 |
processed_data = {}
|
| 866 |
|
| 867 |
# Check if confidence data is available in any view
|
| 868 |
has_confidence_data = False
|
| 869 |
-
for view_idx, view in enumerate(views):
|
| 870 |
-
|
| 871 |
-
|
| 872 |
-
|
| 873 |
-
|
| 874 |
|
| 875 |
# Process each view
|
| 876 |
for view_idx, view in enumerate(views):
|
| 877 |
-
view_key = f"pred{view_idx + 1}"
|
| 878 |
-
if view_key not in pred_result:
|
| 879 |
-
|
| 880 |
|
| 881 |
# Get image
|
| 882 |
image = rgb(view["img"], norm_type=high_level_config["data_norm_type"])
|
|
|
|
| 883 |
|
| 884 |
# Get predicted points
|
| 885 |
-
pred_pts3d =
|
| 886 |
|
| 887 |
# Initialize data for this view
|
| 888 |
view_data = {
|
|
@@ -895,36 +718,12 @@ def process_predictions_for_visualization(pred_result, views, high_level_config)
|
|
| 895 |
"has_confidence": has_confidence_data,
|
| 896 |
}
|
| 897 |
|
| 898 |
-
|
| 899 |
-
if "conf" in pred_result[view_key]:
|
| 900 |
-
confidence = pred_result[view_key]["conf"][0].cpu().numpy()
|
| 901 |
-
view_data["confidence"] = confidence
|
| 902 |
|
| 903 |
-
|
| 904 |
-
has_non_ambiguous_mask = "non_ambiguous_mask" in pred_result[view_key]
|
| 905 |
-
if has_non_ambiguous_mask:
|
| 906 |
-
view_data["mask"] = (
|
| 907 |
-
pred_result[view_key]["non_ambiguous_mask"][0].cpu().numpy()
|
| 908 |
-
)
|
| 909 |
|
| 910 |
-
|
| 911 |
-
|
| 912 |
-
ray_directions = pred_result[view_key]["ray_directions"][0].cpu()
|
| 913 |
-
ray_depth = pred_result[view_key]["depth_along_ray"][0].cpu()
|
| 914 |
-
|
| 915 |
-
# Compute depth
|
| 916 |
-
local_pts3d = ray_directions * ray_depth
|
| 917 |
-
depth_z = local_pts3d[..., 2].numpy()
|
| 918 |
-
view_data["depth"] = depth_z
|
| 919 |
-
|
| 920 |
-
# Compute normals if we have valid points
|
| 921 |
-
if has_non_ambiguous_mask:
|
| 922 |
-
try:
|
| 923 |
-
normals, _ = points_to_normals(pred_pts3d, mask=view_data["mask"])
|
| 924 |
-
view_data["normal"] = normals
|
| 925 |
-
except:
|
| 926 |
-
# If normal computation fails, skip it
|
| 927 |
-
pass
|
| 928 |
|
| 929 |
processed_data[view_idx] = view_data
|
| 930 |
|
|
@@ -972,10 +771,29 @@ def measure(
|
|
| 972 |
point2d = event.index[0], event.index[1]
|
| 973 |
print(f"Clicked point: {point2d}")
|
| 974 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 975 |
measure_points.append(point2d)
|
| 976 |
|
| 977 |
-
# Get image and ensure it's valid
|
| 978 |
-
image =
|
| 979 |
if image is None:
|
| 980 |
return None, [], "No image available"
|
| 981 |
|
|
@@ -1093,14 +911,12 @@ def update_log():
|
|
| 1093 |
|
| 1094 |
def update_visualization(
|
| 1095 |
target_dir,
|
| 1096 |
-
conf_thres,
|
| 1097 |
frame_filter,
|
| 1098 |
show_cam,
|
| 1099 |
is_example,
|
| 1100 |
filter_sky=False,
|
| 1101 |
filter_black_bg=False,
|
| 1102 |
filter_white_bg=False,
|
| 1103 |
-
mask_ambiguous=False,
|
| 1104 |
):
|
| 1105 |
"""
|
| 1106 |
Reload saved predictions from npz, create (or reuse) the GLB for new parameters,
|
|
@@ -1135,13 +951,12 @@ def update_visualization(
|
|
| 1135 |
|
| 1136 |
glbfile = os.path.join(
|
| 1137 |
target_dir,
|
| 1138 |
-
f"glbscene_{
|
| 1139 |
)
|
| 1140 |
|
| 1141 |
if not os.path.exists(glbfile):
|
| 1142 |
glbscene = predictions_to_glb(
|
| 1143 |
predictions,
|
| 1144 |
-
conf_thres=conf_thres,
|
| 1145 |
filter_by_frames=frame_filter,
|
| 1146 |
show_cam=show_cam,
|
| 1147 |
target_dir=target_dir,
|
|
@@ -1149,7 +964,6 @@ def update_visualization(
|
|
| 1149 |
mask_sky=filter_sky,
|
| 1150 |
mask_black_bg=filter_black_bg,
|
| 1151 |
mask_white_bg=filter_white_bg,
|
| 1152 |
-
mask_ambiguous=mask_ambiguous,
|
| 1153 |
)
|
| 1154 |
glbscene.export(file_obj=glbfile)
|
| 1155 |
|
|
@@ -1346,6 +1160,9 @@ with gr.Blocks(theme=theme, css=GRADIO_CSS) as demo:
|
|
| 1346 |
interactive=False,
|
| 1347 |
sources=[],
|
| 1348 |
)
|
|
|
|
|
|
|
|
|
|
| 1349 |
measure_text = gr.Markdown("")
|
| 1350 |
|
| 1351 |
with gr.Row():
|
|
@@ -1363,17 +1180,11 @@ with gr.Blocks(theme=theme, css=GRADIO_CSS) as demo:
|
|
| 1363 |
)
|
| 1364 |
|
| 1365 |
with gr.Row():
|
| 1366 |
-
conf_thres = gr.Slider(
|
| 1367 |
-
minimum=0,
|
| 1368 |
-
maximum=100,
|
| 1369 |
-
value=0,
|
| 1370 |
-
step=0.1,
|
| 1371 |
-
label="Confidence Threshold (%), only shown in depth and normals",
|
| 1372 |
-
)
|
| 1373 |
frame_filter = gr.Dropdown(
|
| 1374 |
choices=["All"], value="All", label="Show Points from Frame"
|
| 1375 |
)
|
| 1376 |
with gr.Column():
|
|
|
|
| 1377 |
show_cam = gr.Checkbox(label="Show Camera", value=True)
|
| 1378 |
filter_sky = gr.Checkbox(
|
| 1379 |
label="Filter Sky (using skyseg.onnx)", value=False
|
|
@@ -1384,8 +1195,11 @@ with gr.Blocks(theme=theme, css=GRADIO_CSS) as demo:
|
|
| 1384 |
filter_white_bg = gr.Checkbox(
|
| 1385 |
label="Filter White Background", value=False
|
| 1386 |
)
|
| 1387 |
-
|
| 1388 |
-
|
|
|
|
|
|
|
|
|
|
| 1389 |
# ---------------------- Example Scenes Section ----------------------
|
| 1390 |
gr.Markdown("## Example Scenes")
|
| 1391 |
gr.Markdown("Click any thumbnail to load the scene for reconstruction.")
|
|
@@ -1446,13 +1260,13 @@ with gr.Blocks(theme=theme, css=GRADIO_CSS) as demo:
|
|
| 1446 |
fn=gradio_demo,
|
| 1447 |
inputs=[
|
| 1448 |
target_dir_output,
|
| 1449 |
-
conf_thres,
|
| 1450 |
frame_filter,
|
| 1451 |
show_cam,
|
| 1452 |
filter_sky,
|
| 1453 |
filter_black_bg,
|
| 1454 |
filter_white_bg,
|
| 1455 |
-
|
|
|
|
| 1456 |
],
|
| 1457 |
outputs=[
|
| 1458 |
reconstruction_output,
|
|
@@ -1476,76 +1290,10 @@ with gr.Blocks(theme=theme, css=GRADIO_CSS) as demo:
|
|
| 1476 |
# -------------------------------------------------------------------------
|
| 1477 |
# Real-time Visualization Updates
|
| 1478 |
# -------------------------------------------------------------------------
|
| 1479 |
-
def update_all_visualizations_on_conf_change(
|
| 1480 |
-
processed_data,
|
| 1481 |
-
depth_selector,
|
| 1482 |
-
normal_selector,
|
| 1483 |
-
conf_thres_val,
|
| 1484 |
-
target_dir,
|
| 1485 |
-
frame_filter,
|
| 1486 |
-
show_cam,
|
| 1487 |
-
is_example,
|
| 1488 |
-
):
|
| 1489 |
-
"""Update 3D view and all tabs when confidence threshold changes"""
|
| 1490 |
-
|
| 1491 |
-
# Update 3D pointcloud visualization
|
| 1492 |
-
glb_file, log_msg = update_visualization(
|
| 1493 |
-
target_dir,
|
| 1494 |
-
conf_thres_val,
|
| 1495 |
-
frame_filter,
|
| 1496 |
-
show_cam,
|
| 1497 |
-
is_example,
|
| 1498 |
-
)
|
| 1499 |
-
|
| 1500 |
-
# Update depth and normal tabs with new confidence threshold
|
| 1501 |
-
depth_vis = None
|
| 1502 |
-
normal_vis = None
|
| 1503 |
-
|
| 1504 |
-
if processed_data is not None:
|
| 1505 |
-
# Get current view indices from selectors
|
| 1506 |
-
try:
|
| 1507 |
-
depth_view_idx = (
|
| 1508 |
-
int(depth_selector.split()[1]) - 1 if depth_selector else 0
|
| 1509 |
-
)
|
| 1510 |
-
except:
|
| 1511 |
-
depth_view_idx = 0
|
| 1512 |
-
|
| 1513 |
-
try:
|
| 1514 |
-
normal_view_idx = (
|
| 1515 |
-
int(normal_selector.split()[1]) - 1 if normal_selector else 0
|
| 1516 |
-
)
|
| 1517 |
-
except:
|
| 1518 |
-
normal_view_idx = 0
|
| 1519 |
-
|
| 1520 |
-
# Update visualizations with new confidence threshold
|
| 1521 |
-
depth_vis = update_depth_view(
|
| 1522 |
-
processed_data, depth_view_idx, conf_thres=conf_thres_val
|
| 1523 |
-
)
|
| 1524 |
-
normal_vis = update_normal_view(
|
| 1525 |
-
processed_data, normal_view_idx, conf_thres=conf_thres_val
|
| 1526 |
-
)
|
| 1527 |
-
|
| 1528 |
-
return glb_file, log_msg, depth_vis, normal_vis
|
| 1529 |
-
|
| 1530 |
-
conf_thres.change(
|
| 1531 |
-
fn=update_all_visualizations_on_conf_change,
|
| 1532 |
-
inputs=[
|
| 1533 |
-
processed_data_state,
|
| 1534 |
-
depth_view_selector,
|
| 1535 |
-
normal_view_selector,
|
| 1536 |
-
conf_thres,
|
| 1537 |
-
target_dir_output,
|
| 1538 |
-
frame_filter,
|
| 1539 |
-
show_cam,
|
| 1540 |
-
is_example,
|
| 1541 |
-
],
|
| 1542 |
-
outputs=[reconstruction_output, log_output, depth_map, normal_map],
|
| 1543 |
-
)
|
| 1544 |
frame_filter.change(
|
| 1545 |
update_visualization,
|
| 1546 |
[
|
| 1547 |
target_dir_output,
|
| 1548 |
-
conf_thres,
|
| 1549 |
frame_filter,
|
| 1550 |
show_cam,
|
| 1551 |
is_example,
|
|
@@ -1556,7 +1304,6 @@ with gr.Blocks(theme=theme, css=GRADIO_CSS) as demo:
|
|
| 1556 |
update_visualization,
|
| 1557 |
[
|
| 1558 |
target_dir_output,
|
| 1559 |
-
conf_thres,
|
| 1560 |
frame_filter,
|
| 1561 |
show_cam,
|
| 1562 |
is_example,
|
|
@@ -1567,14 +1314,12 @@ with gr.Blocks(theme=theme, css=GRADIO_CSS) as demo:
|
|
| 1567 |
update_visualization,
|
| 1568 |
[
|
| 1569 |
target_dir_output,
|
| 1570 |
-
conf_thres,
|
| 1571 |
frame_filter,
|
| 1572 |
show_cam,
|
| 1573 |
is_example,
|
| 1574 |
filter_sky,
|
| 1575 |
filter_black_bg,
|
| 1576 |
filter_white_bg,
|
| 1577 |
-
mask_ambiguous,
|
| 1578 |
],
|
| 1579 |
[reconstruction_output, log_output],
|
| 1580 |
)
|
|
@@ -1582,14 +1327,12 @@ with gr.Blocks(theme=theme, css=GRADIO_CSS) as demo:
|
|
| 1582 |
update_visualization,
|
| 1583 |
[
|
| 1584 |
target_dir_output,
|
| 1585 |
-
conf_thres,
|
| 1586 |
frame_filter,
|
| 1587 |
show_cam,
|
| 1588 |
is_example,
|
| 1589 |
filter_sky,
|
| 1590 |
filter_black_bg,
|
| 1591 |
filter_white_bg,
|
| 1592 |
-
mask_ambiguous,
|
| 1593 |
],
|
| 1594 |
[reconstruction_output, log_output],
|
| 1595 |
)
|
|
@@ -1597,29 +1340,12 @@ with gr.Blocks(theme=theme, css=GRADIO_CSS) as demo:
|
|
| 1597 |
update_visualization,
|
| 1598 |
[
|
| 1599 |
target_dir_output,
|
| 1600 |
-
conf_thres,
|
| 1601 |
frame_filter,
|
| 1602 |
show_cam,
|
| 1603 |
is_example,
|
| 1604 |
filter_sky,
|
| 1605 |
filter_black_bg,
|
| 1606 |
filter_white_bg,
|
| 1607 |
-
mask_ambiguous,
|
| 1608 |
-
],
|
| 1609 |
-
[reconstruction_output, log_output],
|
| 1610 |
-
)
|
| 1611 |
-
mask_ambiguous.change(
|
| 1612 |
-
update_visualization,
|
| 1613 |
-
[
|
| 1614 |
-
target_dir_output,
|
| 1615 |
-
conf_thres,
|
| 1616 |
-
frame_filter,
|
| 1617 |
-
show_cam,
|
| 1618 |
-
is_example,
|
| 1619 |
-
filter_sky,
|
| 1620 |
-
filter_black_bg,
|
| 1621 |
-
filter_white_bg,
|
| 1622 |
-
mask_ambiguous,
|
| 1623 |
],
|
| 1624 |
[reconstruction_output, log_output],
|
| 1625 |
)
|
|
@@ -1653,67 +1379,61 @@ with gr.Blocks(theme=theme, css=GRADIO_CSS) as demo:
|
|
| 1653 |
|
| 1654 |
# Depth tab navigation
|
| 1655 |
prev_depth_btn.click(
|
| 1656 |
-
fn=lambda processed_data, current_selector
|
| 1657 |
-
processed_data, current_selector, -1
|
| 1658 |
),
|
| 1659 |
-
inputs=[processed_data_state, depth_view_selector
|
| 1660 |
outputs=[depth_view_selector, depth_map],
|
| 1661 |
)
|
| 1662 |
|
| 1663 |
next_depth_btn.click(
|
| 1664 |
-
fn=lambda processed_data, current_selector
|
| 1665 |
-
processed_data, current_selector, 1
|
| 1666 |
),
|
| 1667 |
-
inputs=[processed_data_state, depth_view_selector
|
| 1668 |
outputs=[depth_view_selector, depth_map],
|
| 1669 |
)
|
| 1670 |
|
| 1671 |
depth_view_selector.change(
|
| 1672 |
-
fn=lambda processed_data, selector_value
|
| 1673 |
update_depth_view(
|
| 1674 |
processed_data,
|
| 1675 |
int(selector_value.split()[1]) - 1,
|
| 1676 |
-
conf_thres=conf_thres_val,
|
| 1677 |
)
|
| 1678 |
if selector_value
|
| 1679 |
else None
|
| 1680 |
),
|
| 1681 |
-
inputs=[processed_data_state, depth_view_selector
|
| 1682 |
outputs=[depth_map],
|
| 1683 |
)
|
| 1684 |
|
| 1685 |
# Normal tab navigation
|
| 1686 |
prev_normal_btn.click(
|
| 1687 |
-
fn=lambda processed_data,
|
| 1688 |
-
|
| 1689 |
-
conf_thres_val: navigate_normal_view(
|
| 1690 |
-
processed_data, current_selector, -1, conf_thres=conf_thres_val
|
| 1691 |
),
|
| 1692 |
-
inputs=[processed_data_state, normal_view_selector
|
| 1693 |
outputs=[normal_view_selector, normal_map],
|
| 1694 |
)
|
| 1695 |
|
| 1696 |
next_normal_btn.click(
|
| 1697 |
-
fn=lambda processed_data,
|
| 1698 |
-
|
| 1699 |
-
conf_thres_val: navigate_normal_view(
|
| 1700 |
-
processed_data, current_selector, 1, conf_thres=conf_thres_val
|
| 1701 |
),
|
| 1702 |
-
inputs=[processed_data_state, normal_view_selector
|
| 1703 |
outputs=[normal_view_selector, normal_map],
|
| 1704 |
)
|
| 1705 |
|
| 1706 |
normal_view_selector.change(
|
| 1707 |
-
fn=lambda processed_data, selector_value
|
| 1708 |
update_normal_view(
|
| 1709 |
processed_data,
|
| 1710 |
int(selector_value.split()[1]) - 1,
|
| 1711 |
-
conf_thres=conf_thres_val,
|
| 1712 |
)
|
| 1713 |
if selector_value
|
| 1714 |
else None
|
| 1715 |
),
|
| 1716 |
-
inputs=[processed_data_state, normal_view_selector
|
| 1717 |
outputs=[normal_map],
|
| 1718 |
)
|
| 1719 |
|
|
|
|
| 18 |
import numpy as np
|
| 19 |
import spaces
|
| 20 |
import torch
|
|
|
|
| 21 |
|
| 22 |
sys.path.append("mapanything/")
|
| 23 |
|
| 24 |
from hf_utils.css_and_html import (
|
| 25 |
+
GRADIO_CSS,
|
| 26 |
+
MEASURE_INSTRUCTIONS_HTML,
|
| 27 |
get_acknowledgements_html,
|
| 28 |
get_description_html,
|
| 29 |
get_gradio_theme,
|
| 30 |
get_header_html,
|
|
|
|
|
|
|
| 31 |
)
|
|
|
|
| 32 |
from hf_utils.visual_util import predictions_to_glb
|
| 33 |
+
from mapanything.models import MapAnything
|
| 34 |
+
from mapanything.utils.geometry import depthmap_to_world_frame, points_to_normals
|
| 35 |
from mapanything.utils.image import load_images, rgb
|
|
|
|
| 36 |
|
| 37 |
|
| 38 |
def get_logo_base64():
|
|
|
|
| 100 |
return cfg
|
| 101 |
|
| 102 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 103 |
# MapAnything Configuration
|
| 104 |
high_level_config = {
|
| 105 |
"path": "configs/train.yaml",
|
| 106 |
+
"hf_model_name": "facebook/map-anything",
|
| 107 |
"config_overrides": [
|
| 108 |
"machine=aws",
|
| 109 |
"model=mapanything",
|
| 110 |
"model/task=images_only",
|
| 111 |
"model.encoder.uses_torch_hub=false",
|
| 112 |
],
|
|
|
|
| 113 |
"trained_with_amp": True,
|
| 114 |
"trained_with_amp_dtype": "fp16",
|
| 115 |
"data_norm_type": "dinov2",
|
|
|
|
| 125 |
# 1) Core model inference
|
| 126 |
# -------------------------------------------------------------------------
|
| 127 |
@spaces.GPU(duration=120)
|
| 128 |
+
def run_model(target_dir, model_placeholder, apply_mask=True, mask_edges=True):
|
| 129 |
"""
|
| 130 |
Run the MapAnything model on images in the 'target_dir/images' folder and return predictions.
|
| 131 |
"""
|
|
|
|
| 135 |
# Device check
|
| 136 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 137 |
device = torch.device(device)
|
|
|
|
|
|
|
| 138 |
|
| 139 |
# Initialize model if not already done
|
| 140 |
if model is None:
|
| 141 |
print("Initializing MapAnything model...")
|
| 142 |
+
|
| 143 |
+
print("Loading CC-BY-NC 4.0 licensed MapAnything model...")
|
| 144 |
+
model = MapAnything.from_pretrained(high_level_config["hf_model_name"]).to(
|
| 145 |
+
device
|
| 146 |
)
|
| 147 |
+
|
| 148 |
else:
|
| 149 |
model = model.to(device)
|
| 150 |
|
|
|
|
| 153 |
# Load images using MapAnything's load_images function
|
| 154 |
print("Loading images...")
|
| 155 |
image_folder_path = os.path.join(target_dir, "images")
|
| 156 |
+
views = load_images(image_folder_path)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 157 |
|
| 158 |
print(f"Loaded {len(views)} images")
|
| 159 |
if len(views) == 0:
|
| 160 |
raise ValueError("No images found. Check your upload.")
|
| 161 |
|
| 162 |
+
# Run model inference
|
| 163 |
+
print("Running inference...")
|
| 164 |
+
# apply_mask: Whether to apply the non-ambiguous mask to the output. Defaults to True.
|
| 165 |
+
# mask_edges: Whether to compute an edge mask based on normals and depth and apply it to the output. Defaults to True.
|
| 166 |
+
# Use checkbox values
|
| 167 |
+
outputs = model.infer(views, apply_mask=apply_mask, mask_edges=mask_edges)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 168 |
|
| 169 |
# Convert predictions to format expected by visualization
|
| 170 |
predictions = {}
|
|
|
|
| 175 |
world_points_list = []
|
| 176 |
depth_maps_list = []
|
| 177 |
images_list = []
|
|
|
|
| 178 |
final_mask_list = []
|
| 179 |
|
| 180 |
+
# Loop through the outputs
|
| 181 |
+
for pred in outputs:
|
| 182 |
+
# Extract data from predictions
|
| 183 |
+
depthmap_torch = pred["depth_z"][0].squeeze(-1) # (H, W)
|
| 184 |
+
intrinsics_torch = pred["intrinsics"][0] # (3, 3)
|
| 185 |
+
camera_pose_torch = pred["camera_poses"][0] # (4, 4)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 186 |
|
| 187 |
+
# Compute new pts3d using depth, intrinsics, and camera pose
|
| 188 |
+
pts3d_computed, valid_mask = depthmap_to_world_frame(
|
| 189 |
+
depthmap_torch, intrinsics_torch, camera_pose_torch
|
| 190 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 191 |
|
| 192 |
+
# Convert to numpy arrays for visualization
|
| 193 |
+
# Check if mask key exists in pred, if not, fill with boolean trues in the size of depthmap_torch
|
| 194 |
+
if "mask" in pred:
|
| 195 |
+
mask = pred["mask"][0].squeeze(-1).cpu().numpy().astype(bool)
|
| 196 |
+
else:
|
| 197 |
+
# Fill with boolean trues in the size of depthmap_torch
|
| 198 |
+
mask = np.ones_like(depthmap_torch.cpu().numpy(), dtype=bool)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 199 |
|
| 200 |
+
# Combine with valid depth mask
|
| 201 |
+
mask = mask & valid_mask.cpu().numpy()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 202 |
|
| 203 |
+
image = pred["img_no_norm"][0].cpu().numpy()
|
|
|
|
|
|
|
| 204 |
|
| 205 |
+
# Append to lists
|
| 206 |
+
extrinsic_list.append(camera_pose_torch.cpu().numpy())
|
| 207 |
+
intrinsic_list.append(intrinsics_torch.cpu().numpy())
|
| 208 |
+
world_points_list.append(pts3d_computed.cpu().numpy())
|
| 209 |
+
depth_maps_list.append(depthmap_torch.cpu().numpy())
|
| 210 |
+
images_list.append(image) # Add image to list
|
| 211 |
+
final_mask_list.append(mask) # Add final_mask to list
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 212 |
|
| 213 |
# Convert lists to numpy arrays with required shapes
|
| 214 |
# extrinsic: (S, 3, 4) - batch of camera extrinsic matrices
|
|
|
|
| 225 |
# Add channel dimension if needed to match (S, H, W, 1) format
|
| 226 |
if len(depth_maps.shape) == 3:
|
| 227 |
depth_maps = depth_maps[..., np.newaxis]
|
| 228 |
+
|
| 229 |
predictions["depth"] = depth_maps
|
| 230 |
|
| 231 |
# images: (S, H, W, 3) - batch of input images
|
| 232 |
predictions["images"] = np.stack(images_list, axis=0)
|
| 233 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 234 |
# final_mask: (S, H, W) - batch of final masks for filtering
|
| 235 |
predictions["final_mask"] = np.stack(final_mask_list, axis=0)
|
| 236 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 237 |
# Process data for visualization tabs (depth, normal, measure)
|
| 238 |
processed_data = process_predictions_for_visualization(
|
| 239 |
+
predictions, views, high_level_config
|
| 240 |
)
|
| 241 |
|
| 242 |
# Clean up
|
|
|
|
| 272 |
return processed_data[view_keys[view_index]]
|
| 273 |
|
| 274 |
|
| 275 |
+
def update_depth_view(processed_data, view_index):
|
| 276 |
+
"""Update depth view for a specific view index"""
|
| 277 |
view_data = get_view_data_by_index(processed_data, view_index)
|
| 278 |
if view_data is None or view_data["depth"] is None:
|
| 279 |
return None
|
| 280 |
|
| 281 |
# Use confidence filtering if available
|
| 282 |
confidence = view_data.get("confidence")
|
| 283 |
+
return colorize_depth(view_data["depth"], confidence=confidence)
|
|
|
|
|
|
|
| 284 |
|
| 285 |
|
| 286 |
+
def update_normal_view(processed_data, view_index):
|
| 287 |
+
"""Update normal view for a specific view index"""
|
| 288 |
view_data = get_view_data_by_index(processed_data, view_index)
|
| 289 |
if view_data is None or view_data["normal"] is None:
|
| 290 |
return None
|
| 291 |
|
| 292 |
# Use confidence filtering if available
|
| 293 |
confidence = view_data.get("confidence")
|
| 294 |
+
return colorize_normal(view_data["normal"], confidence=confidence)
|
|
|
|
|
|
|
| 295 |
|
| 296 |
|
| 297 |
def update_measure_view(processed_data, view_index):
|
| 298 |
+
"""Update measure view for a specific view index with mask overlay"""
|
| 299 |
view_data = get_view_data_by_index(processed_data, view_index)
|
| 300 |
if view_data is None:
|
| 301 |
return None, [] # image, measure_points
|
|
|
|
| 302 |
|
| 303 |
+
# Get the base image
|
| 304 |
+
image = view_data["image"].copy()
|
| 305 |
|
| 306 |
+
# Ensure image is in uint8 format
|
| 307 |
+
if image.dtype != np.uint8:
|
| 308 |
+
if image.max() <= 1.0:
|
| 309 |
+
image = (image * 255).astype(np.uint8)
|
| 310 |
+
else:
|
| 311 |
+
image = image.astype(np.uint8)
|
| 312 |
+
|
| 313 |
+
# Apply mask overlay if mask is available
|
| 314 |
+
if view_data["mask"] is not None:
|
| 315 |
+
mask = view_data["mask"]
|
| 316 |
+
|
| 317 |
+
# Create light grey overlay for masked areas
|
| 318 |
+
# Masked areas (False values) will be overlaid with light grey
|
| 319 |
+
invalid_mask = ~mask # Areas where mask is False
|
| 320 |
+
|
| 321 |
+
if invalid_mask.any():
|
| 322 |
+
# Create a light grey overlay (RGB: 192, 192, 192)
|
| 323 |
+
overlay_color = np.array([192, 192, 192], dtype=np.uint8)
|
| 324 |
+
|
| 325 |
+
# Apply overlay with some transparency
|
| 326 |
+
alpha = 0.5 # Transparency level
|
| 327 |
+
for c in range(3): # RGB channels
|
| 328 |
+
image[:, :, c] = np.where(
|
| 329 |
+
invalid_mask,
|
| 330 |
+
(1 - alpha) * image[:, :, c] + alpha * overlay_color[c],
|
| 331 |
+
image[:, :, c],
|
| 332 |
+
).astype(np.uint8)
|
| 333 |
+
|
| 334 |
+
return image, []
|
| 335 |
+
|
| 336 |
+
|
| 337 |
+
def navigate_depth_view(processed_data, current_selector_value, direction):
|
| 338 |
"""Navigate depth view (direction: -1 for previous, +1 for next)"""
|
| 339 |
if processed_data is None or len(processed_data) == 0:
|
| 340 |
return "View 1", None
|
|
|
|
| 349 |
new_view = (current_view + direction) % num_views
|
| 350 |
|
| 351 |
new_selector_value = f"View {new_view + 1}"
|
| 352 |
+
depth_vis = update_depth_view(processed_data, new_view)
|
| 353 |
|
| 354 |
return new_selector_value, depth_vis
|
| 355 |
|
| 356 |
|
| 357 |
+
def navigate_normal_view(processed_data, current_selector_value, direction):
|
|
|
|
|
|
|
| 358 |
"""Navigate normal view (direction: -1 for previous, +1 for next)"""
|
| 359 |
if processed_data is None or len(processed_data) == 0:
|
| 360 |
return "View 1", None
|
|
|
|
| 369 |
new_view = (current_view + direction) % num_views
|
| 370 |
|
| 371 |
new_selector_value = f"View {new_view + 1}"
|
| 372 |
+
normal_vis = update_normal_view(processed_data, new_view)
|
| 373 |
|
| 374 |
return new_selector_value, normal_vis
|
| 375 |
|
|
|
|
| 394 |
return new_selector_value, measure_image, measure_points
|
| 395 |
|
| 396 |
|
| 397 |
+
def populate_visualization_tabs(processed_data):
|
| 398 |
"""Populate the depth, normal, and measure tabs with processed data"""
|
| 399 |
if processed_data is None or len(processed_data) == 0:
|
| 400 |
return None, None, None, []
|
| 401 |
|
| 402 |
# Use update functions to ensure confidence filtering is applied from the start
|
| 403 |
+
depth_vis = update_depth_view(processed_data, 0)
|
| 404 |
+
normal_vis = update_normal_view(processed_data, 0)
|
| 405 |
measure_img, _ = update_measure_view(processed_data, 0)
|
| 406 |
|
| 407 |
return depth_vis, normal_vis, measure_img, []
|
|
|
|
| 505 |
@spaces.GPU(duration=120)
|
| 506 |
def gradio_demo(
|
| 507 |
target_dir,
|
|
|
|
| 508 |
frame_filter="All",
|
| 509 |
show_cam=True,
|
| 510 |
filter_sky=False,
|
| 511 |
filter_black_bg=False,
|
| 512 |
filter_white_bg=False,
|
| 513 |
+
apply_mask=True,
|
| 514 |
+
mask_edges=True,
|
| 515 |
):
|
| 516 |
"""
|
| 517 |
Perform reconstruction using the already-created target_dir/images.
|
|
|
|
| 538 |
|
| 539 |
print("Running MapAnything model...")
|
| 540 |
with torch.no_grad():
|
| 541 |
+
predictions, processed_data = run_model(
|
| 542 |
+
target_dir, None, apply_mask, mask_edges
|
| 543 |
+
)
|
| 544 |
|
| 545 |
# Save predictions
|
| 546 |
prediction_save_path = os.path.join(target_dir, "predictions.npz")
|
|
|
|
| 553 |
# Build a GLB file name
|
| 554 |
glbfile = os.path.join(
|
| 555 |
target_dir,
|
| 556 |
+
f"glbscene_{frame_filter.replace('.', '_').replace(':', '').replace(' ', '_')}_cam{show_cam}_sky{filter_sky}_black{filter_black_bg}_white{filter_white_bg}_pred{prediction_mode.replace(' ', '_')}.glb",
|
| 557 |
)
|
| 558 |
|
| 559 |
# Convert predictions to GLB
|
| 560 |
glbscene = predictions_to_glb(
|
| 561 |
predictions,
|
|
|
|
| 562 |
filter_by_frames=frame_filter,
|
| 563 |
show_cam=show_cam,
|
| 564 |
target_dir=target_dir,
|
|
|
|
| 566 |
mask_sky=filter_sky,
|
| 567 |
mask_black_bg=filter_black_bg,
|
| 568 |
mask_white_bg=filter_white_bg,
|
|
|
|
| 569 |
)
|
| 570 |
glbscene.export(file_obj=glbfile)
|
| 571 |
|
|
|
|
| 582 |
|
| 583 |
# Populate visualization tabs with processed data
|
| 584 |
depth_vis, normal_vis, measure_img, measure_pts = populate_visualization_tabs(
|
| 585 |
+
processed_data
|
| 586 |
)
|
| 587 |
|
| 588 |
# Update view selectors based on available views
|
|
|
|
| 682 |
return normal_vis
|
| 683 |
|
| 684 |
|
| 685 |
+
def process_predictions_for_visualization(predictions, views, high_level_config):
|
| 686 |
"""Extract depth, normal, and 3D points from predictions for visualization"""
|
| 687 |
processed_data = {}
|
| 688 |
|
| 689 |
# Check if confidence data is available in any view
|
| 690 |
has_confidence_data = False
|
| 691 |
+
# for view_idx, view in enumerate(views):
|
| 692 |
+
# view_key = f"pred{view_idx + 1}"
|
| 693 |
+
# if view_key in pred_result and "conf" in pred_result[view_key]:
|
| 694 |
+
# has_confidence_data = True
|
| 695 |
+
# break
|
| 696 |
|
| 697 |
# Process each view
|
| 698 |
for view_idx, view in enumerate(views):
|
| 699 |
+
# view_key = f"pred{view_idx + 1}"
|
| 700 |
+
# if view_key not in pred_result:
|
| 701 |
+
# continue
|
| 702 |
|
| 703 |
# Get image
|
| 704 |
image = rgb(view["img"], norm_type=high_level_config["data_norm_type"])
|
| 705 |
+
# image = rgb(view["img"], norm_type=high_level_config["data_norm_type"])
|
| 706 |
|
| 707 |
# Get predicted points
|
| 708 |
+
pred_pts3d = predictions["world_points"][view_idx]
|
| 709 |
|
| 710 |
# Initialize data for this view
|
| 711 |
view_data = {
|
|
|
|
| 718 |
"has_confidence": has_confidence_data,
|
| 719 |
}
|
| 720 |
|
| 721 |
+
view_data["mask"] = predictions["final_mask"][view_idx]
|
|
|
|
|
|
|
|
|
|
| 722 |
|
| 723 |
+
view_data["depth"] = predictions["depth"][view_idx].squeeze()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 724 |
|
| 725 |
+
normals, _ = points_to_normals(pred_pts3d, mask=view_data["mask"])
|
| 726 |
+
view_data["normal"] = normals
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 727 |
|
| 728 |
processed_data[view_idx] = view_data
|
| 729 |
|
|
|
|
| 771 |
point2d = event.index[0], event.index[1]
|
| 772 |
print(f"Clicked point: {point2d}")
|
| 773 |
|
| 774 |
+
# Check if the clicked point is in a masked area (prevent interaction)
|
| 775 |
+
if (
|
| 776 |
+
current_view["mask"] is not None
|
| 777 |
+
and 0 <= point2d[1] < current_view["mask"].shape[0]
|
| 778 |
+
and 0 <= point2d[0] < current_view["mask"].shape[1]
|
| 779 |
+
):
|
| 780 |
+
# Check if the point is in a masked (invalid) area
|
| 781 |
+
if not current_view["mask"][point2d[1], point2d[0]]:
|
| 782 |
+
print(f"Clicked point {point2d} is in masked area, ignoring click")
|
| 783 |
+
# Always return image with mask overlay
|
| 784 |
+
masked_image, _ = update_measure_view(
|
| 785 |
+
processed_data, current_view_index
|
| 786 |
+
)
|
| 787 |
+
return (
|
| 788 |
+
masked_image,
|
| 789 |
+
measure_points,
|
| 790 |
+
'<span style="color: red; font-weight: bold;">Cannot measure on masked areas (shown in grey)</span>',
|
| 791 |
+
)
|
| 792 |
+
|
| 793 |
measure_points.append(point2d)
|
| 794 |
|
| 795 |
+
# Get image with mask overlay and ensure it's valid
|
| 796 |
+
image, _ = update_measure_view(processed_data, current_view_index)
|
| 797 |
if image is None:
|
| 798 |
return None, [], "No image available"
|
| 799 |
|
|
|
|
| 911 |
|
| 912 |
def update_visualization(
|
| 913 |
target_dir,
|
|
|
|
| 914 |
frame_filter,
|
| 915 |
show_cam,
|
| 916 |
is_example,
|
| 917 |
filter_sky=False,
|
| 918 |
filter_black_bg=False,
|
| 919 |
filter_white_bg=False,
|
|
|
|
| 920 |
):
|
| 921 |
"""
|
| 922 |
Reload saved predictions from npz, create (or reuse) the GLB for new parameters,
|
|
|
|
| 951 |
|
| 952 |
glbfile = os.path.join(
|
| 953 |
target_dir,
|
| 954 |
+
f"glbscene_{frame_filter.replace('.', '_').replace(':', '').replace(' ', '_')}_cam{show_cam}_sky{filter_sky}_black{filter_black_bg}_white{filter_white_bg}_pred{prediction_mode.replace(' ', '_')}.glb",
|
| 955 |
)
|
| 956 |
|
| 957 |
if not os.path.exists(glbfile):
|
| 958 |
glbscene = predictions_to_glb(
|
| 959 |
predictions,
|
|
|
|
| 960 |
filter_by_frames=frame_filter,
|
| 961 |
show_cam=show_cam,
|
| 962 |
target_dir=target_dir,
|
|
|
|
| 964 |
mask_sky=filter_sky,
|
| 965 |
mask_black_bg=filter_black_bg,
|
| 966 |
mask_white_bg=filter_white_bg,
|
|
|
|
| 967 |
)
|
| 968 |
glbscene.export(file_obj=glbfile)
|
| 969 |
|
|
|
|
| 1160 |
interactive=False,
|
| 1161 |
sources=[],
|
| 1162 |
)
|
| 1163 |
+
gr.Markdown(
|
| 1164 |
+
"**Note:** Gray areas indicate regions with no depth information where measurements cannot be taken."
|
| 1165 |
+
)
|
| 1166 |
measure_text = gr.Markdown("")
|
| 1167 |
|
| 1168 |
with gr.Row():
|
|
|
|
| 1180 |
)
|
| 1181 |
|
| 1182 |
with gr.Row():
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1183 |
frame_filter = gr.Dropdown(
|
| 1184 |
choices=["All"], value="All", label="Show Points from Frame"
|
| 1185 |
)
|
| 1186 |
with gr.Column():
|
| 1187 |
+
gr.Markdown("### Pointcloud options (live updates)")
|
| 1188 |
show_cam = gr.Checkbox(label="Show Camera", value=True)
|
| 1189 |
filter_sky = gr.Checkbox(
|
| 1190 |
label="Filter Sky (using skyseg.onnx)", value=False
|
|
|
|
| 1195 |
filter_white_bg = gr.Checkbox(
|
| 1196 |
label="Filter White Background", value=False
|
| 1197 |
)
|
| 1198 |
+
gr.Markdown("### Reconstruction options: (updated on next run)")
|
| 1199 |
+
apply_mask_checkbox = gr.Checkbox(
|
| 1200 |
+
label="Apply non-ambiguous mask", value=True
|
| 1201 |
+
)
|
| 1202 |
+
mask_edges_checkbox = apply_mask_checkbox
|
| 1203 |
# ---------------------- Example Scenes Section ----------------------
|
| 1204 |
gr.Markdown("## Example Scenes")
|
| 1205 |
gr.Markdown("Click any thumbnail to load the scene for reconstruction.")
|
|
|
|
| 1260 |
fn=gradio_demo,
|
| 1261 |
inputs=[
|
| 1262 |
target_dir_output,
|
|
|
|
| 1263 |
frame_filter,
|
| 1264 |
show_cam,
|
| 1265 |
filter_sky,
|
| 1266 |
filter_black_bg,
|
| 1267 |
filter_white_bg,
|
| 1268 |
+
apply_mask_checkbox,
|
| 1269 |
+
mask_edges_checkbox,
|
| 1270 |
],
|
| 1271 |
outputs=[
|
| 1272 |
reconstruction_output,
|
|
|
|
| 1290 |
# -------------------------------------------------------------------------
|
| 1291 |
# Real-time Visualization Updates
|
| 1292 |
# -------------------------------------------------------------------------
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1293 |
frame_filter.change(
|
| 1294 |
update_visualization,
|
| 1295 |
[
|
| 1296 |
target_dir_output,
|
|
|
|
| 1297 |
frame_filter,
|
| 1298 |
show_cam,
|
| 1299 |
is_example,
|
|
|
|
| 1304 |
update_visualization,
|
| 1305 |
[
|
| 1306 |
target_dir_output,
|
|
|
|
| 1307 |
frame_filter,
|
| 1308 |
show_cam,
|
| 1309 |
is_example,
|
|
|
|
| 1314 |
update_visualization,
|
| 1315 |
[
|
| 1316 |
target_dir_output,
|
|
|
|
| 1317 |
frame_filter,
|
| 1318 |
show_cam,
|
| 1319 |
is_example,
|
| 1320 |
filter_sky,
|
| 1321 |
filter_black_bg,
|
| 1322 |
filter_white_bg,
|
|
|
|
| 1323 |
],
|
| 1324 |
[reconstruction_output, log_output],
|
| 1325 |
)
|
|
|
|
| 1327 |
update_visualization,
|
| 1328 |
[
|
| 1329 |
target_dir_output,
|
|
|
|
| 1330 |
frame_filter,
|
| 1331 |
show_cam,
|
| 1332 |
is_example,
|
| 1333 |
filter_sky,
|
| 1334 |
filter_black_bg,
|
| 1335 |
filter_white_bg,
|
|
|
|
| 1336 |
],
|
| 1337 |
[reconstruction_output, log_output],
|
| 1338 |
)
|
|
|
|
| 1340 |
update_visualization,
|
| 1341 |
[
|
| 1342 |
target_dir_output,
|
|
|
|
| 1343 |
frame_filter,
|
| 1344 |
show_cam,
|
| 1345 |
is_example,
|
| 1346 |
filter_sky,
|
| 1347 |
filter_black_bg,
|
| 1348 |
filter_white_bg,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1349 |
],
|
| 1350 |
[reconstruction_output, log_output],
|
| 1351 |
)
|
|
|
|
| 1379 |
|
| 1380 |
# Depth tab navigation
|
| 1381 |
prev_depth_btn.click(
|
| 1382 |
+
fn=lambda processed_data, current_selector: navigate_depth_view(
|
| 1383 |
+
processed_data, current_selector, -1
|
| 1384 |
),
|
| 1385 |
+
inputs=[processed_data_state, depth_view_selector],
|
| 1386 |
outputs=[depth_view_selector, depth_map],
|
| 1387 |
)
|
| 1388 |
|
| 1389 |
next_depth_btn.click(
|
| 1390 |
+
fn=lambda processed_data, current_selector: navigate_depth_view(
|
| 1391 |
+
processed_data, current_selector, 1
|
| 1392 |
),
|
| 1393 |
+
inputs=[processed_data_state, depth_view_selector],
|
| 1394 |
outputs=[depth_view_selector, depth_map],
|
| 1395 |
)
|
| 1396 |
|
| 1397 |
depth_view_selector.change(
|
| 1398 |
+
fn=lambda processed_data, selector_value: (
|
| 1399 |
update_depth_view(
|
| 1400 |
processed_data,
|
| 1401 |
int(selector_value.split()[1]) - 1,
|
|
|
|
| 1402 |
)
|
| 1403 |
if selector_value
|
| 1404 |
else None
|
| 1405 |
),
|
| 1406 |
+
inputs=[processed_data_state, depth_view_selector],
|
| 1407 |
outputs=[depth_map],
|
| 1408 |
)
|
| 1409 |
|
| 1410 |
# Normal tab navigation
|
| 1411 |
prev_normal_btn.click(
|
| 1412 |
+
fn=lambda processed_data, current_selector: navigate_normal_view(
|
| 1413 |
+
processed_data, current_selector, -1
|
|
|
|
|
|
|
| 1414 |
),
|
| 1415 |
+
inputs=[processed_data_state, normal_view_selector],
|
| 1416 |
outputs=[normal_view_selector, normal_map],
|
| 1417 |
)
|
| 1418 |
|
| 1419 |
next_normal_btn.click(
|
| 1420 |
+
fn=lambda processed_data, current_selector: navigate_normal_view(
|
| 1421 |
+
processed_data, current_selector, 1
|
|
|
|
|
|
|
| 1422 |
),
|
| 1423 |
+
inputs=[processed_data_state, normal_view_selector],
|
| 1424 |
outputs=[normal_view_selector, normal_map],
|
| 1425 |
)
|
| 1426 |
|
| 1427 |
normal_view_selector.change(
|
| 1428 |
+
fn=lambda processed_data, selector_value: (
|
| 1429 |
update_normal_view(
|
| 1430 |
processed_data,
|
| 1431 |
int(selector_value.split()[1]) - 1,
|
|
|
|
| 1432 |
)
|
| 1433 |
if selector_value
|
| 1434 |
else None
|
| 1435 |
),
|
| 1436 |
+
inputs=[processed_data_state, normal_view_selector],
|
| 1437 |
outputs=[normal_map],
|
| 1438 |
)
|
| 1439 |
|
hf_utils/vgg_geometry.py
DELETED
|
@@ -1,166 +0,0 @@
|
|
| 1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
-
# All rights reserved.
|
| 3 |
-
#
|
| 4 |
-
# This source code is licensed under the license found in the
|
| 5 |
-
# LICENSE file in the root directory of this source tree.
|
| 6 |
-
|
| 7 |
-
import os
|
| 8 |
-
import torch
|
| 9 |
-
import numpy as np
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
def unproject_depth_map_to_point_map(
|
| 13 |
-
depth_map: np.ndarray, extrinsics_cam: np.ndarray, intrinsics_cam: np.ndarray
|
| 14 |
-
) -> np.ndarray:
|
| 15 |
-
"""
|
| 16 |
-
Unproject a batch of depth maps to 3D world coordinates.
|
| 17 |
-
|
| 18 |
-
Args:
|
| 19 |
-
depth_map (np.ndarray): Batch of depth maps of shape (S, H, W, 1) or (S, H, W)
|
| 20 |
-
extrinsics_cam (np.ndarray): Batch of camera extrinsic matrices of shape (S, 3, 4)
|
| 21 |
-
intrinsics_cam (np.ndarray): Batch of camera intrinsic matrices of shape (S, 3, 3)
|
| 22 |
-
|
| 23 |
-
Returns:
|
| 24 |
-
np.ndarray: Batch of 3D world coordinates of shape (S, H, W, 3)
|
| 25 |
-
"""
|
| 26 |
-
if isinstance(depth_map, torch.Tensor):
|
| 27 |
-
depth_map = depth_map.cpu().numpy()
|
| 28 |
-
if isinstance(extrinsics_cam, torch.Tensor):
|
| 29 |
-
extrinsics_cam = extrinsics_cam.cpu().numpy()
|
| 30 |
-
if isinstance(intrinsics_cam, torch.Tensor):
|
| 31 |
-
intrinsics_cam = intrinsics_cam.cpu().numpy()
|
| 32 |
-
|
| 33 |
-
world_points_list = []
|
| 34 |
-
for frame_idx in range(depth_map.shape[0]):
|
| 35 |
-
cur_world_points, _, _ = depth_to_world_coords_points(
|
| 36 |
-
depth_map[frame_idx].squeeze(-1), extrinsics_cam[frame_idx], intrinsics_cam[frame_idx]
|
| 37 |
-
)
|
| 38 |
-
world_points_list.append(cur_world_points)
|
| 39 |
-
world_points_array = np.stack(world_points_list, axis=0)
|
| 40 |
-
|
| 41 |
-
return world_points_array
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
def depth_to_world_coords_points(
|
| 45 |
-
depth_map: np.ndarray,
|
| 46 |
-
extrinsic: np.ndarray,
|
| 47 |
-
intrinsic: np.ndarray,
|
| 48 |
-
eps=1e-8,
|
| 49 |
-
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
|
| 50 |
-
"""
|
| 51 |
-
Convert a depth map to world coordinates.
|
| 52 |
-
|
| 53 |
-
Args:
|
| 54 |
-
depth_map (np.ndarray): Depth map of shape (H, W).
|
| 55 |
-
intrinsic (np.ndarray): Camera intrinsic matrix of shape (3, 3).
|
| 56 |
-
extrinsic (np.ndarray): Camera extrinsic matrix of shape (3, 4). OpenCV camera coordinate convention, cam from world.
|
| 57 |
-
|
| 58 |
-
Returns:
|
| 59 |
-
tuple[np.ndarray, np.ndarray]: World coordinates (H, W, 3) and valid depth mask (H, W).
|
| 60 |
-
"""
|
| 61 |
-
if depth_map is None:
|
| 62 |
-
return None, None, None
|
| 63 |
-
|
| 64 |
-
# Valid depth mask
|
| 65 |
-
point_mask = depth_map > eps
|
| 66 |
-
|
| 67 |
-
# Convert depth map to camera coordinates
|
| 68 |
-
cam_coords_points = depth_to_cam_coords_points(depth_map, intrinsic)
|
| 69 |
-
|
| 70 |
-
# Multiply with the inverse of extrinsic matrix to transform to world coordinates
|
| 71 |
-
# extrinsic_inv is 4x4 (note closed_form_inverse_OpenCV is batched, the output is (N, 4, 4))
|
| 72 |
-
cam_to_world_extrinsic = closed_form_inverse_se3(extrinsic[None])[0]
|
| 73 |
-
|
| 74 |
-
R_cam_to_world = cam_to_world_extrinsic[:3, :3]
|
| 75 |
-
t_cam_to_world = cam_to_world_extrinsic[:3, 3]
|
| 76 |
-
|
| 77 |
-
# Apply the rotation and translation to the camera coordinates
|
| 78 |
-
world_coords_points = np.dot(cam_coords_points, R_cam_to_world.T) + t_cam_to_world # HxWx3, 3x3 -> HxWx3
|
| 79 |
-
# world_coords_points = np.einsum("ij,hwj->hwi", R_cam_to_world, cam_coords_points) + t_cam_to_world
|
| 80 |
-
|
| 81 |
-
return world_coords_points, cam_coords_points, point_mask
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
def depth_to_cam_coords_points(depth_map: np.ndarray, intrinsic: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
|
| 85 |
-
"""
|
| 86 |
-
Convert a depth map to camera coordinates.
|
| 87 |
-
|
| 88 |
-
Args:
|
| 89 |
-
depth_map (np.ndarray): Depth map of shape (H, W).
|
| 90 |
-
intrinsic (np.ndarray): Camera intrinsic matrix of shape (3, 3).
|
| 91 |
-
|
| 92 |
-
Returns:
|
| 93 |
-
tuple[np.ndarray, np.ndarray]: Camera coordinates (H, W, 3)
|
| 94 |
-
"""
|
| 95 |
-
H, W = depth_map.shape
|
| 96 |
-
assert intrinsic.shape == (3, 3), "Intrinsic matrix must be 3x3"
|
| 97 |
-
assert intrinsic[0, 1] == 0 and intrinsic[1, 0] == 0, "Intrinsic matrix must have zero skew"
|
| 98 |
-
|
| 99 |
-
# Intrinsic parameters
|
| 100 |
-
fu, fv = intrinsic[0, 0], intrinsic[1, 1]
|
| 101 |
-
cu, cv = intrinsic[0, 2], intrinsic[1, 2]
|
| 102 |
-
|
| 103 |
-
# Generate grid of pixel coordinates
|
| 104 |
-
u, v = np.meshgrid(np.arange(W), np.arange(H))
|
| 105 |
-
|
| 106 |
-
# Unproject to camera coordinates
|
| 107 |
-
x_cam = (u - cu) * depth_map / fu
|
| 108 |
-
y_cam = (v - cv) * depth_map / fv
|
| 109 |
-
z_cam = depth_map
|
| 110 |
-
|
| 111 |
-
# Stack to form camera coordinates
|
| 112 |
-
cam_coords = np.stack((x_cam, y_cam, z_cam), axis=-1).astype(np.float32)
|
| 113 |
-
|
| 114 |
-
return cam_coords
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
def closed_form_inverse_se3(se3, R=None, T=None):
|
| 118 |
-
"""
|
| 119 |
-
Compute the inverse of each 4x4 (or 3x4) SE3 matrix in a batch.
|
| 120 |
-
|
| 121 |
-
If `R` and `T` are provided, they must correspond to the rotation and translation
|
| 122 |
-
components of `se3`. Otherwise, they will be extracted from `se3`.
|
| 123 |
-
|
| 124 |
-
Args:
|
| 125 |
-
se3: Nx4x4 or Nx3x4 array or tensor of SE3 matrices.
|
| 126 |
-
R (optional): Nx3x3 array or tensor of rotation matrices.
|
| 127 |
-
T (optional): Nx3x1 array or tensor of translation vectors.
|
| 128 |
-
|
| 129 |
-
Returns:
|
| 130 |
-
Inverted SE3 matrices with the same type and device as `se3`.
|
| 131 |
-
|
| 132 |
-
Shapes:
|
| 133 |
-
se3: (N, 4, 4)
|
| 134 |
-
R: (N, 3, 3)
|
| 135 |
-
T: (N, 3, 1)
|
| 136 |
-
"""
|
| 137 |
-
# Check if se3 is a numpy array or a torch tensor
|
| 138 |
-
is_numpy = isinstance(se3, np.ndarray)
|
| 139 |
-
|
| 140 |
-
# Validate shapes
|
| 141 |
-
if se3.shape[-2:] != (4, 4) and se3.shape[-2:] != (3, 4):
|
| 142 |
-
raise ValueError(f"se3 must be of shape (N,4,4), got {se3.shape}.")
|
| 143 |
-
|
| 144 |
-
# Extract R and T if not provided
|
| 145 |
-
if R is None:
|
| 146 |
-
R = se3[:, :3, :3] # (N,3,3)
|
| 147 |
-
if T is None:
|
| 148 |
-
T = se3[:, :3, 3:] # (N,3,1)
|
| 149 |
-
|
| 150 |
-
# Transpose R
|
| 151 |
-
if is_numpy:
|
| 152 |
-
# Compute the transpose of the rotation for NumPy
|
| 153 |
-
R_transposed = np.transpose(R, (0, 2, 1))
|
| 154 |
-
# -R^T t for NumPy
|
| 155 |
-
top_right = -np.matmul(R_transposed, T)
|
| 156 |
-
inverted_matrix = np.tile(np.eye(4), (len(R), 1, 1))
|
| 157 |
-
else:
|
| 158 |
-
R_transposed = R.transpose(1, 2) # (N,3,3)
|
| 159 |
-
top_right = -torch.bmm(R_transposed, T) # (N,3,1)
|
| 160 |
-
inverted_matrix = torch.eye(4, 4)[None].repeat(len(R), 1, 1)
|
| 161 |
-
inverted_matrix = inverted_matrix.to(R.dtype).to(R.device)
|
| 162 |
-
|
| 163 |
-
inverted_matrix[:, :3, :3] = R_transposed
|
| 164 |
-
inverted_matrix[:, :3, 3:] = top_right
|
| 165 |
-
|
| 166 |
-
return inverted_matrix
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
hf_utils/visual_util.py
CHANGED
|
@@ -221,14 +221,14 @@ def predictions_to_glb(
|
|
| 221 |
|
| 222 |
# Prepare 4x4 matrices for camera extrinsics
|
| 223 |
num_cameras = len(camera_matrices)
|
| 224 |
-
extrinsics_matrices = np.zeros((num_cameras, 4, 4))
|
| 225 |
-
extrinsics_matrices[:, :3, :4] = camera_matrices
|
| 226 |
-
extrinsics_matrices[:, 3, 3] = 1
|
| 227 |
|
| 228 |
if show_cam:
|
| 229 |
# Add camera models to the scene
|
| 230 |
for i in range(num_cameras):
|
| 231 |
-
world_to_camera =
|
| 232 |
camera_to_world = np.linalg.inv(world_to_camera)
|
| 233 |
rgba_color = colormap(i / num_cameras)
|
| 234 |
current_color = tuple(int(255 * x) for x in rgba_color[:3])
|
|
@@ -238,7 +238,7 @@ def predictions_to_glb(
|
|
| 238 |
)
|
| 239 |
|
| 240 |
# Align scene to the observation of the first camera
|
| 241 |
-
scene_3d = apply_scene_alignment(scene_3d,
|
| 242 |
|
| 243 |
print("GLB Scene built")
|
| 244 |
return scene_3d
|
|
|
|
| 221 |
|
| 222 |
# Prepare 4x4 matrices for camera extrinsics
|
| 223 |
num_cameras = len(camera_matrices)
|
| 224 |
+
# extrinsics_matrices = np.zeros((num_cameras, 4, 4))
|
| 225 |
+
# extrinsics_matrices[:, :3, :4] = camera_matrices
|
| 226 |
+
# extrinsics_matrices[:, 3, 3] = 1
|
| 227 |
|
| 228 |
if show_cam:
|
| 229 |
# Add camera models to the scene
|
| 230 |
for i in range(num_cameras):
|
| 231 |
+
world_to_camera = camera_matrices[i]
|
| 232 |
camera_to_world = np.linalg.inv(world_to_camera)
|
| 233 |
rgba_color = colormap(i / num_cameras)
|
| 234 |
current_color = tuple(int(255 * x) for x in rgba_color[:3])
|
|
|
|
| 238 |
)
|
| 239 |
|
| 240 |
# Align scene to the observation of the first camera
|
| 241 |
+
scene_3d = apply_scene_alignment(scene_3d, camera_matrices)
|
| 242 |
|
| 243 |
print("GLB Scene built")
|
| 244 |
return scene_3d
|
mapanything/__init__.py
ADDED
|
File without changes
|
mapanything/datasets/wai/ase.py
CHANGED
|
@@ -7,7 +7,7 @@ import os
|
|
| 7 |
import numpy as np
|
| 8 |
|
| 9 |
from mapanything.datasets.base.base_dataset import BaseDataset
|
| 10 |
-
from wai import load_data, load_frame
|
| 11 |
|
| 12 |
|
| 13 |
class ASEWAI(BaseDataset):
|
|
|
|
| 7 |
import numpy as np
|
| 8 |
|
| 9 |
from mapanything.datasets.base.base_dataset import BaseDataset
|
| 10 |
+
from mapanything.utils.wai.core import load_data, load_frame
|
| 11 |
|
| 12 |
|
| 13 |
class ASEWAI(BaseDataset):
|
mapanything/datasets/wai/bedlam.py
CHANGED
|
@@ -7,7 +7,7 @@ import os
|
|
| 7 |
import numpy as np
|
| 8 |
|
| 9 |
from mapanything.datasets.base.base_dataset import BaseDataset
|
| 10 |
-
from wai import load_data, load_frame
|
| 11 |
|
| 12 |
|
| 13 |
class BedlamWAI(BaseDataset):
|
|
|
|
| 7 |
import numpy as np
|
| 8 |
|
| 9 |
from mapanything.datasets.base.base_dataset import BaseDataset
|
| 10 |
+
from mapanything.utils.wai.core import load_data, load_frame
|
| 11 |
|
| 12 |
|
| 13 |
class BedlamWAI(BaseDataset):
|
mapanything/datasets/wai/blendedmvs.py
CHANGED
|
@@ -8,7 +8,7 @@ import cv2
|
|
| 8 |
import numpy as np
|
| 9 |
|
| 10 |
from mapanything.datasets.base.base_dataset import BaseDataset
|
| 11 |
-
from wai import load_data, load_frame
|
| 12 |
|
| 13 |
|
| 14 |
class BlendedMVSWAI(BaseDataset):
|
|
@@ -108,16 +108,9 @@ class BlendedMVSWAI(BaseDataset):
|
|
| 108 |
view_data = load_frame(
|
| 109 |
scene_root,
|
| 110 |
view_file_name,
|
| 111 |
-
modalities=["image", "depth"],
|
| 112 |
-
# modalities=["image", "depth", "pred_mask/moge2"],
|
| 113 |
scene_meta=scene_meta,
|
| 114 |
)
|
| 115 |
-
### HOTFIX: Load required additional masks manually
|
| 116 |
-
### Remove once stability issue with scene_meta is fixed
|
| 117 |
-
mask_path = os.path.join(
|
| 118 |
-
scene_root, "moge", "v0", "mask", "moge2", f"{view_file_name}.png"
|
| 119 |
-
)
|
| 120 |
-
view_data["pred_mask/moge2"] = load_data(mask_path, "binary")
|
| 121 |
|
| 122 |
# Convert necessary data to numpy
|
| 123 |
image = view_data["image"].permute(1, 2, 0).numpy()
|
|
|
|
| 8 |
import numpy as np
|
| 9 |
|
| 10 |
from mapanything.datasets.base.base_dataset import BaseDataset
|
| 11 |
+
from mapanything.utils.wai.core import load_data, load_frame
|
| 12 |
|
| 13 |
|
| 14 |
class BlendedMVSWAI(BaseDataset):
|
|
|
|
| 108 |
view_data = load_frame(
|
| 109 |
scene_root,
|
| 110 |
view_file_name,
|
| 111 |
+
modalities=["image", "depth", "pred_mask/moge2"],
|
|
|
|
| 112 |
scene_meta=scene_meta,
|
| 113 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 114 |
|
| 115 |
# Convert necessary data to numpy
|
| 116 |
image = view_data["image"].permute(1, 2, 0).numpy()
|
mapanything/datasets/wai/dl3dv.py
CHANGED
|
@@ -12,7 +12,7 @@ from mapanything.utils.cropping import (
|
|
| 12 |
rescale_image_and_other_optional_info,
|
| 13 |
resize_with_nearest_interpolation_to_match_aspect_ratio,
|
| 14 |
)
|
| 15 |
-
from wai import load_data, load_frame
|
| 16 |
|
| 17 |
|
| 18 |
class DL3DVWAI(BaseDataset):
|
|
@@ -115,35 +115,14 @@ class DL3DVWAI(BaseDataset):
|
|
| 115 |
view_data = load_frame(
|
| 116 |
scene_root,
|
| 117 |
view_file_name,
|
| 118 |
-
modalities=[
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
# ],
|
| 125 |
scene_meta=scene_meta,
|
| 126 |
)
|
| 127 |
-
### HOTFIX: Load required additional modalities manually
|
| 128 |
-
### Remove once stability issue with scene_meta is fixed
|
| 129 |
-
mvs_depth_path = os.path.join(
|
| 130 |
-
scene_root, "mvsanywhere", "v0", "depth", f"{view_file_name}.exr"
|
| 131 |
-
)
|
| 132 |
-
mvs_conf_path = os.path.join(
|
| 133 |
-
scene_root,
|
| 134 |
-
"mvsanywhere",
|
| 135 |
-
"v0",
|
| 136 |
-
"depth_confidence",
|
| 137 |
-
f"{view_file_name}.exr",
|
| 138 |
-
)
|
| 139 |
-
mask_path = os.path.join(
|
| 140 |
-
scene_root, "moge", "v0", "mask", "moge2", f"{view_file_name}.png"
|
| 141 |
-
)
|
| 142 |
-
view_data["pred_depth/mvsanywhere"] = load_data(mvs_depth_path, "depth")
|
| 143 |
-
view_data["depth_confidence/mvsanywhere"] = load_data(
|
| 144 |
-
mvs_conf_path, "scalar"
|
| 145 |
-
)
|
| 146 |
-
view_data["pred_mask/moge2"] = load_data(mask_path, "binary")
|
| 147 |
|
| 148 |
# Convert necessary data to numpy
|
| 149 |
image = view_data["image"].permute(1, 2, 0).numpy()
|
|
|
|
| 12 |
rescale_image_and_other_optional_info,
|
| 13 |
resize_with_nearest_interpolation_to_match_aspect_ratio,
|
| 14 |
)
|
| 15 |
+
from mapanything.utils.wai.core import load_data, load_frame
|
| 16 |
|
| 17 |
|
| 18 |
class DL3DVWAI(BaseDataset):
|
|
|
|
| 115 |
view_data = load_frame(
|
| 116 |
scene_root,
|
| 117 |
view_file_name,
|
| 118 |
+
modalities=[
|
| 119 |
+
"image",
|
| 120 |
+
"pred_depth/mvsanywhere",
|
| 121 |
+
"pred_mask/moge2",
|
| 122 |
+
"depth_confidence/mvsanywhere",
|
| 123 |
+
],
|
|
|
|
| 124 |
scene_meta=scene_meta,
|
| 125 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 126 |
|
| 127 |
# Convert necessary data to numpy
|
| 128 |
image = view_data["image"].permute(1, 2, 0).numpy()
|
mapanything/datasets/wai/dtu.py
CHANGED
|
@@ -7,7 +7,7 @@ import os
|
|
| 7 |
import numpy as np
|
| 8 |
|
| 9 |
from mapanything.datasets.base.base_dataset import BaseDataset
|
| 10 |
-
from wai import load_data, load_frame
|
| 11 |
|
| 12 |
|
| 13 |
class DTUWAI(BaseDataset):
|
|
|
|
| 7 |
import numpy as np
|
| 8 |
|
| 9 |
from mapanything.datasets.base.base_dataset import BaseDataset
|
| 10 |
+
from mapanything.utils.wai.core import load_data, load_frame
|
| 11 |
|
| 12 |
|
| 13 |
class DTUWAI(BaseDataset):
|
mapanything/datasets/wai/dynamicreplica.py
CHANGED
|
@@ -7,7 +7,7 @@ import os
|
|
| 7 |
import numpy as np
|
| 8 |
|
| 9 |
from mapanything.datasets.base.base_dataset import BaseDataset
|
| 10 |
-
from wai import load_data, load_frame
|
| 11 |
|
| 12 |
|
| 13 |
class DynamicReplicaWAI(BaseDataset):
|
|
|
|
| 7 |
import numpy as np
|
| 8 |
|
| 9 |
from mapanything.datasets.base.base_dataset import BaseDataset
|
| 10 |
+
from mapanything.utils.wai.core import load_data, load_frame
|
| 11 |
|
| 12 |
|
| 13 |
class DynamicReplicaWAI(BaseDataset):
|
mapanything/datasets/wai/eth3d.py
CHANGED
|
@@ -7,7 +7,7 @@ import os
|
|
| 7 |
import numpy as np
|
| 8 |
|
| 9 |
from mapanything.datasets.base.base_dataset import BaseDataset
|
| 10 |
-
from wai import load_data, load_frame
|
| 11 |
|
| 12 |
|
| 13 |
class ETH3DWAI(BaseDataset):
|
|
|
|
| 7 |
import numpy as np
|
| 8 |
|
| 9 |
from mapanything.datasets.base.base_dataset import BaseDataset
|
| 10 |
+
from mapanything.utils.wai.core import load_data, load_frame
|
| 11 |
|
| 12 |
|
| 13 |
class ETH3DWAI(BaseDataset):
|
mapanything/datasets/wai/gta_sfm.py
CHANGED
|
@@ -7,7 +7,7 @@ import os
|
|
| 7 |
import numpy as np
|
| 8 |
|
| 9 |
from mapanything.datasets.base.base_dataset import BaseDataset
|
| 10 |
-
from wai import load_data, load_frame
|
| 11 |
|
| 12 |
|
| 13 |
class GTASfMWAI(BaseDataset):
|
|
|
|
| 7 |
import numpy as np
|
| 8 |
|
| 9 |
from mapanything.datasets.base.base_dataset import BaseDataset
|
| 10 |
+
from mapanything.utils.wai.core import load_data, load_frame
|
| 11 |
|
| 12 |
|
| 13 |
class GTASfMWAI(BaseDataset):
|
mapanything/datasets/wai/matrixcity.py
CHANGED
|
@@ -7,7 +7,7 @@ import os
|
|
| 7 |
import numpy as np
|
| 8 |
|
| 9 |
from mapanything.datasets.base.base_dataset import BaseDataset
|
| 10 |
-
from wai import load_data, load_frame
|
| 11 |
|
| 12 |
|
| 13 |
class MatrixCityWAI(BaseDataset):
|
|
|
|
| 7 |
import numpy as np
|
| 8 |
|
| 9 |
from mapanything.datasets.base.base_dataset import BaseDataset
|
| 10 |
+
from mapanything.utils.wai.core import load_data, load_frame
|
| 11 |
|
| 12 |
|
| 13 |
class MatrixCityWAI(BaseDataset):
|
mapanything/datasets/wai/megadepth.py
CHANGED
|
@@ -8,7 +8,7 @@ import cv2
|
|
| 8 |
import numpy as np
|
| 9 |
|
| 10 |
from mapanything.datasets.base.base_dataset import BaseDataset
|
| 11 |
-
from wai import load_data, load_frame
|
| 12 |
|
| 13 |
|
| 14 |
class MegaDepthWAI(BaseDataset):
|
|
@@ -109,16 +109,9 @@ class MegaDepthWAI(BaseDataset):
|
|
| 109 |
view_data = load_frame(
|
| 110 |
scene_root,
|
| 111 |
view_file_name,
|
| 112 |
-
modalities=["image", "depth"],
|
| 113 |
-
# modalities=["image", "depth", "pred_mask/moge2"],
|
| 114 |
scene_meta=scene_meta,
|
| 115 |
)
|
| 116 |
-
### HOTFIX: Load required additional masks manually
|
| 117 |
-
### Remove once stability issue with scene_meta is fixed
|
| 118 |
-
mask_path = os.path.join(
|
| 119 |
-
scene_root, "moge", "v0", "mask", "moge2", f"{view_file_name}.png"
|
| 120 |
-
)
|
| 121 |
-
view_data["pred_mask/moge2"] = load_data(mask_path, "binary")
|
| 122 |
|
| 123 |
# Convert necessary data to numpy
|
| 124 |
image = view_data["image"].permute(1, 2, 0).numpy()
|
|
|
|
| 8 |
import numpy as np
|
| 9 |
|
| 10 |
from mapanything.datasets.base.base_dataset import BaseDataset
|
| 11 |
+
from mapanything.utils.wai.core import load_data, load_frame
|
| 12 |
|
| 13 |
|
| 14 |
class MegaDepthWAI(BaseDataset):
|
|
|
|
| 109 |
view_data = load_frame(
|
| 110 |
scene_root,
|
| 111 |
view_file_name,
|
| 112 |
+
modalities=["image", "depth", "pred_mask/moge2"],
|
|
|
|
| 113 |
scene_meta=scene_meta,
|
| 114 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 115 |
|
| 116 |
# Convert necessary data to numpy
|
| 117 |
image = view_data["image"].permute(1, 2, 0).numpy()
|
mapanything/datasets/wai/mpsd.py
CHANGED
|
@@ -8,7 +8,7 @@ import cv2
|
|
| 8 |
import numpy as np
|
| 9 |
|
| 10 |
from mapanything.datasets.base.base_dataset import BaseDataset
|
| 11 |
-
from wai import load_data, load_frame
|
| 12 |
|
| 13 |
|
| 14 |
class MPSDWAI(BaseDataset):
|
|
@@ -108,16 +108,9 @@ class MPSDWAI(BaseDataset):
|
|
| 108 |
view_data = load_frame(
|
| 109 |
scene_root,
|
| 110 |
view_file_name,
|
| 111 |
-
modalities=["image", "depth"],
|
| 112 |
-
# modalities=["image", "depth", "pred_mask/moge2"],
|
| 113 |
scene_meta=scene_meta,
|
| 114 |
)
|
| 115 |
-
### HOTFIX: Load required additional masks manually
|
| 116 |
-
### Remove once stability issue with scene_meta is fixed
|
| 117 |
-
mask_path = os.path.join(
|
| 118 |
-
scene_root, "moge", "v0", "mask", "moge2", f"{view_file_name}.png"
|
| 119 |
-
)
|
| 120 |
-
view_data["pred_mask/moge2"] = load_data(mask_path, "binary")
|
| 121 |
|
| 122 |
# Convert necessary data to numpy
|
| 123 |
image = view_data["image"].permute(1, 2, 0).numpy()
|
|
|
|
| 8 |
import numpy as np
|
| 9 |
|
| 10 |
from mapanything.datasets.base.base_dataset import BaseDataset
|
| 11 |
+
from mapanything.utils.wai.core import load_data, load_frame
|
| 12 |
|
| 13 |
|
| 14 |
class MPSDWAI(BaseDataset):
|
|
|
|
| 108 |
view_data = load_frame(
|
| 109 |
scene_root,
|
| 110 |
view_file_name,
|
| 111 |
+
modalities=["image", "depth", "pred_mask/moge2"],
|
|
|
|
| 112 |
scene_meta=scene_meta,
|
| 113 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 114 |
|
| 115 |
# Convert necessary data to numpy
|
| 116 |
image = view_data["image"].permute(1, 2, 0).numpy()
|
mapanything/datasets/wai/mvs_synth.py
CHANGED
|
@@ -7,7 +7,7 @@ import os
|
|
| 7 |
import numpy as np
|
| 8 |
|
| 9 |
from mapanything.datasets.base.base_dataset import BaseDataset
|
| 10 |
-
from wai import load_data, load_frame
|
| 11 |
|
| 12 |
|
| 13 |
class MVSSynthWAI(BaseDataset):
|
|
|
|
| 7 |
import numpy as np
|
| 8 |
|
| 9 |
from mapanything.datasets.base.base_dataset import BaseDataset
|
| 10 |
+
from mapanything.utils.wai.core import load_data, load_frame
|
| 11 |
|
| 12 |
|
| 13 |
class MVSSynthWAI(BaseDataset):
|
mapanything/datasets/wai/paralleldomain4d.py
CHANGED
|
@@ -7,7 +7,7 @@ import os
|
|
| 7 |
import numpy as np
|
| 8 |
|
| 9 |
from mapanything.datasets.base.base_dataset import BaseDataset
|
| 10 |
-
from wai import load_data, load_frame
|
| 11 |
|
| 12 |
|
| 13 |
class ParallelDomain4DWAI(BaseDataset):
|
|
|
|
| 7 |
import numpy as np
|
| 8 |
|
| 9 |
from mapanything.datasets.base.base_dataset import BaseDataset
|
| 10 |
+
from mapanything.utils.wai.core import load_data, load_frame
|
| 11 |
|
| 12 |
|
| 13 |
class ParallelDomain4DWAI(BaseDataset):
|
mapanything/datasets/wai/sailvos3d.py
CHANGED
|
@@ -7,7 +7,7 @@ import os
|
|
| 7 |
import numpy as np
|
| 8 |
|
| 9 |
from mapanything.datasets.base.base_dataset import BaseDataset
|
| 10 |
-
from wai import load_data, load_frame
|
| 11 |
|
| 12 |
|
| 13 |
class SAILVOS3DWAI(BaseDataset):
|
|
|
|
| 7 |
import numpy as np
|
| 8 |
|
| 9 |
from mapanything.datasets.base.base_dataset import BaseDataset
|
| 10 |
+
from mapanything.utils.wai.core import load_data, load_frame
|
| 11 |
|
| 12 |
|
| 13 |
class SAILVOS3DWAI(BaseDataset):
|
mapanything/datasets/wai/scannetpp.py
CHANGED
|
@@ -7,7 +7,7 @@ import os
|
|
| 7 |
import numpy as np
|
| 8 |
|
| 9 |
from mapanything.datasets.base.base_dataset import BaseDataset
|
| 10 |
-
from wai import load_data, load_frame
|
| 11 |
|
| 12 |
|
| 13 |
class ScanNetPPWAI(BaseDataset):
|
|
|
|
| 7 |
import numpy as np
|
| 8 |
|
| 9 |
from mapanything.datasets.base.base_dataset import BaseDataset
|
| 10 |
+
from mapanything.utils.wai.core import load_data, load_frame
|
| 11 |
|
| 12 |
|
| 13 |
class ScanNetPPWAI(BaseDataset):
|
mapanything/datasets/wai/spring.py
CHANGED
|
@@ -8,7 +8,7 @@ import cv2
|
|
| 8 |
import numpy as np
|
| 9 |
|
| 10 |
from mapanything.datasets.base.base_dataset import BaseDataset
|
| 11 |
-
from wai import load_data, load_frame
|
| 12 |
|
| 13 |
|
| 14 |
class SpringWAI(BaseDataset):
|
|
@@ -107,16 +107,9 @@ class SpringWAI(BaseDataset):
|
|
| 107 |
view_data = load_frame(
|
| 108 |
scene_root,
|
| 109 |
view_file_name,
|
| 110 |
-
modalities=["image", "depth", "skymask"],
|
| 111 |
-
# modalities=["image", "depth", "skymask", "pred_mask/moge2"],
|
| 112 |
scene_meta=scene_meta,
|
| 113 |
)
|
| 114 |
-
### HOTFIX: Load required additional masks manually
|
| 115 |
-
### Remove once stability issue with scene_meta is fixed
|
| 116 |
-
mask_path = os.path.join(
|
| 117 |
-
scene_root, "moge", "v0", "mask", "moge2", f"{view_file_name}.png"
|
| 118 |
-
)
|
| 119 |
-
view_data["pred_mask/moge2"] = load_data(mask_path, "binary")
|
| 120 |
|
| 121 |
# Convert necessary data to numpy
|
| 122 |
image = view_data["image"].permute(1, 2, 0).numpy()
|
|
|
|
| 8 |
import numpy as np
|
| 9 |
|
| 10 |
from mapanything.datasets.base.base_dataset import BaseDataset
|
| 11 |
+
from mapanything.utils.wai.core import load_data, load_frame
|
| 12 |
|
| 13 |
|
| 14 |
class SpringWAI(BaseDataset):
|
|
|
|
| 107 |
view_data = load_frame(
|
| 108 |
scene_root,
|
| 109 |
view_file_name,
|
| 110 |
+
modalities=["image", "depth", "skymask", "pred_mask/moge2"],
|
|
|
|
| 111 |
scene_meta=scene_meta,
|
| 112 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 113 |
|
| 114 |
# Convert necessary data to numpy
|
| 115 |
image = view_data["image"].permute(1, 2, 0).numpy()
|
mapanything/datasets/wai/structured3d.py
CHANGED
|
@@ -7,7 +7,7 @@ import os
|
|
| 7 |
import numpy as np
|
| 8 |
|
| 9 |
from mapanything.datasets.base.base_dataset import BaseDataset
|
| 10 |
-
from wai import load_data, load_frame
|
| 11 |
|
| 12 |
|
| 13 |
class Structured3DWAI(BaseDataset):
|
|
|
|
| 7 |
import numpy as np
|
| 8 |
|
| 9 |
from mapanything.datasets.base.base_dataset import BaseDataset
|
| 10 |
+
from mapanything.utils.wai.core import load_data, load_frame
|
| 11 |
|
| 12 |
|
| 13 |
class Structured3DWAI(BaseDataset):
|
mapanything/datasets/wai/tav2_wb.py
CHANGED
|
@@ -8,7 +8,7 @@ import cv2
|
|
| 8 |
import numpy as np
|
| 9 |
|
| 10 |
from mapanything.datasets.base.base_dataset import BaseDataset
|
| 11 |
-
from wai import load_data, load_frame
|
| 12 |
|
| 13 |
|
| 14 |
class TartanAirV2WBWAI(BaseDataset):
|
|
@@ -108,16 +108,9 @@ class TartanAirV2WBWAI(BaseDataset):
|
|
| 108 |
view_data = load_frame(
|
| 109 |
scene_root,
|
| 110 |
view_file_name,
|
| 111 |
-
modalities=["image", "depth"],
|
| 112 |
-
# modalities=["image", "depth", "pred_mask/moge2"],
|
| 113 |
scene_meta=scene_meta,
|
| 114 |
)
|
| 115 |
-
### HOTFIX: Load required additional masks manually
|
| 116 |
-
### Remove once stability issue with scene_meta is fixed
|
| 117 |
-
mask_path = os.path.join(
|
| 118 |
-
scene_root, "moge", "v0", "mask", "moge2", f"{view_file_name}.png"
|
| 119 |
-
)
|
| 120 |
-
view_data["pred_mask/moge2"] = load_data(mask_path, "binary")
|
| 121 |
|
| 122 |
# Convert necessary data to numpy
|
| 123 |
image = view_data["image"].permute(1, 2, 0).numpy()
|
|
|
|
| 8 |
import numpy as np
|
| 9 |
|
| 10 |
from mapanything.datasets.base.base_dataset import BaseDataset
|
| 11 |
+
from mapanything.utils.wai.core import load_data, load_frame
|
| 12 |
|
| 13 |
|
| 14 |
class TartanAirV2WBWAI(BaseDataset):
|
|
|
|
| 108 |
view_data = load_frame(
|
| 109 |
scene_root,
|
| 110 |
view_file_name,
|
| 111 |
+
modalities=["image", "depth", "pred_mask/moge2"],
|
|
|
|
| 112 |
scene_meta=scene_meta,
|
| 113 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 114 |
|
| 115 |
# Convert necessary data to numpy
|
| 116 |
image = view_data["image"].permute(1, 2, 0).numpy()
|
mapanything/datasets/wai/unrealstereo4k.py
CHANGED
|
@@ -7,7 +7,7 @@ import os
|
|
| 7 |
import numpy as np
|
| 8 |
|
| 9 |
from mapanything.datasets.base.base_dataset import BaseDataset
|
| 10 |
-
from wai import load_data, load_frame
|
| 11 |
|
| 12 |
|
| 13 |
class UnrealStereo4KWAI(BaseDataset):
|
|
|
|
| 7 |
import numpy as np
|
| 8 |
|
| 9 |
from mapanything.datasets.base.base_dataset import BaseDataset
|
| 10 |
+
from mapanything.utils.wai.core import load_data, load_frame
|
| 11 |
|
| 12 |
|
| 13 |
class UnrealStereo4KWAI(BaseDataset):
|
mapanything/datasets/wai/xrooms.py
CHANGED
|
@@ -7,7 +7,7 @@ import os
|
|
| 7 |
import numpy as np
|
| 8 |
|
| 9 |
from mapanything.datasets.base.base_dataset import BaseDataset
|
| 10 |
-
from wai import load_data, load_frame
|
| 11 |
|
| 12 |
|
| 13 |
class XRoomsWAI(BaseDataset):
|
|
|
|
| 7 |
import numpy as np
|
| 8 |
|
| 9 |
from mapanything.datasets.base.base_dataset import BaseDataset
|
| 10 |
+
from mapanything.utils.wai.core import load_data, load_frame
|
| 11 |
|
| 12 |
|
| 13 |
class XRoomsWAI(BaseDataset):
|
mapanything/models/external/README.md
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# External Model Code for Benchmarking & Re-Training
|
| 2 |
+
|
| 3 |
+
This directory contains external model code that we use to train and benchmark external models fairly. These libraries are not part of the core MapAnything codebase and are included for only benchmarking purposes. The code in this directory is licensed under the same license as the source code from which it was derived, unless otherwise specified.
|
| 4 |
+
|
| 5 |
+
The open-source Apache 2.0 License of MapAnything does not apply to these libraries.
|
mapanything/models/external/moge/models/v1.py
CHANGED
|
@@ -475,7 +475,7 @@ class MoGeModel(nn.Module):
|
|
| 475 |
return_dict = {"points": points, "mask": mask}
|
| 476 |
return return_dict
|
| 477 |
|
| 478 |
-
@torch.inference_mode()
|
| 479 |
def infer(
|
| 480 |
self,
|
| 481 |
image: torch.Tensor,
|
|
|
|
| 475 |
return_dict = {"points": points, "mask": mask}
|
| 476 |
return return_dict
|
| 477 |
|
| 478 |
+
# @torch.inference_mode()
|
| 479 |
def infer(
|
| 480 |
self,
|
| 481 |
image: torch.Tensor,
|
mapanything/models/external/moge/models/v2.py
CHANGED
|
@@ -227,7 +227,7 @@ class MoGeModel(nn.Module):
|
|
| 227 |
|
| 228 |
return return_dict
|
| 229 |
|
| 230 |
-
@torch.inference_mode()
|
| 231 |
def infer(
|
| 232 |
self,
|
| 233 |
image: torch.Tensor,
|
|
|
|
| 227 |
|
| 228 |
return return_dict
|
| 229 |
|
| 230 |
+
# @torch.inference_mode()
|
| 231 |
def infer(
|
| 232 |
self,
|
| 233 |
image: torch.Tensor,
|
mapanything/models/mapanything/ablations.py
CHANGED
|
@@ -134,8 +134,10 @@ class MapAnythingAblations(nn.Module):
|
|
| 134 |
# Initialize image encoder
|
| 135 |
if self.encoder_config["uses_torch_hub"]:
|
| 136 |
self.encoder_config["torch_hub_force_reload"] = torch_hub_force_reload
|
| 137 |
-
|
| 138 |
-
|
|
|
|
|
|
|
| 139 |
|
| 140 |
# Initialize the encoder for ray directions
|
| 141 |
ray_dirs_encoder_config = self.geometric_input_config["ray_dirs_encoder_config"]
|
|
|
|
| 134 |
# Initialize image encoder
|
| 135 |
if self.encoder_config["uses_torch_hub"]:
|
| 136 |
self.encoder_config["torch_hub_force_reload"] = torch_hub_force_reload
|
| 137 |
+
# Create a copy of the config before deleting the key to preserve it for serialization
|
| 138 |
+
encoder_config_copy = self.encoder_config.copy()
|
| 139 |
+
del encoder_config_copy["uses_torch_hub"]
|
| 140 |
+
self.encoder = encoder_factory(**encoder_config_copy)
|
| 141 |
|
| 142 |
# Initialize the encoder for ray directions
|
| 143 |
ray_dirs_encoder_config = self.geometric_input_config["ray_dirs_encoder_config"]
|
mapanything/models/mapanything/model.py
CHANGED
|
@@ -2,11 +2,13 @@
|
|
| 2 |
MapAnything model class defined using UniCeption modules.
|
| 3 |
"""
|
| 4 |
|
|
|
|
| 5 |
from functools import partial
|
| 6 |
-
from typing import Callable, Dict, Type, Union
|
| 7 |
|
| 8 |
import torch
|
| 9 |
import torch.nn as nn
|
|
|
|
| 10 |
|
| 11 |
from mapanything.utils.geometry import (
|
| 12 |
apply_log_to_norm,
|
|
@@ -15,6 +17,11 @@ from mapanything.utils.geometry import (
|
|
| 15 |
normalize_pose_translations,
|
| 16 |
transform_pose_using_quats_and_trans_2_to_1,
|
| 17 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
from uniception.models.encoders import (
|
| 19 |
encoder_factory,
|
| 20 |
EncoderGlobalRepInput,
|
|
@@ -72,7 +79,7 @@ if hasattr(torch.backends.cuda, "matmul") and hasattr(
|
|
| 72 |
torch.backends.cuda.matmul.allow_tf32 = True
|
| 73 |
|
| 74 |
|
| 75 |
-
class MapAnything(nn.Module):
|
| 76 |
"Modular MapAnything model class that supports input of images & optional geometric modalities (multiple reconstruction tasks)."
|
| 77 |
|
| 78 |
def __init__(
|
|
@@ -139,8 +146,10 @@ class MapAnything(nn.Module):
|
|
| 139 |
# Initialize image encoder
|
| 140 |
if self.encoder_config["uses_torch_hub"]:
|
| 141 |
self.encoder_config["torch_hub_force_reload"] = torch_hub_force_reload
|
| 142 |
-
|
| 143 |
-
|
|
|
|
|
|
|
| 144 |
|
| 145 |
# Initialize the encoder for ray directions
|
| 146 |
ray_dirs_encoder_config = self.geometric_input_config["ray_dirs_encoder_config"]
|
|
@@ -199,6 +208,14 @@ class MapAnything(nn.Module):
|
|
| 199 |
# Load pretrained weights
|
| 200 |
self._load_pretrained_weights()
|
| 201 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 202 |
def _initialize_info_sharing(self, info_sharing_config):
|
| 203 |
"""
|
| 204 |
Initialize the information sharing module based on the configuration.
|
|
@@ -1717,3 +1734,202 @@ class MapAnything(nn.Module):
|
|
| 1717 |
res[i]["non_ambiguous_mask_logits"] = output_mask_logits_per_view[i]
|
| 1718 |
|
| 1719 |
return res
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
MapAnything model class defined using UniCeption modules.
|
| 3 |
"""
|
| 4 |
|
| 5 |
+
import warnings
|
| 6 |
from functools import partial
|
| 7 |
+
from typing import Any, Callable, Dict, List, Type, Union
|
| 8 |
|
| 9 |
import torch
|
| 10 |
import torch.nn as nn
|
| 11 |
+
from huggingface_hub import PyTorchModelHubMixin
|
| 12 |
|
| 13 |
from mapanything.utils.geometry import (
|
| 14 |
apply_log_to_norm,
|
|
|
|
| 17 |
normalize_pose_translations,
|
| 18 |
transform_pose_using_quats_and_trans_2_to_1,
|
| 19 |
)
|
| 20 |
+
from mapanything.utils.inference import (
|
| 21 |
+
postprocess_model_outputs_for_inference,
|
| 22 |
+
preprocess_input_views_for_inference,
|
| 23 |
+
validate_input_views_for_inference,
|
| 24 |
+
)
|
| 25 |
from uniception.models.encoders import (
|
| 26 |
encoder_factory,
|
| 27 |
EncoderGlobalRepInput,
|
|
|
|
| 79 |
torch.backends.cuda.matmul.allow_tf32 = True
|
| 80 |
|
| 81 |
|
| 82 |
+
class MapAnything(nn.Module, PyTorchModelHubMixin):
|
| 83 |
"Modular MapAnything model class that supports input of images & optional geometric modalities (multiple reconstruction tasks)."
|
| 84 |
|
| 85 |
def __init__(
|
|
|
|
| 146 |
# Initialize image encoder
|
| 147 |
if self.encoder_config["uses_torch_hub"]:
|
| 148 |
self.encoder_config["torch_hub_force_reload"] = torch_hub_force_reload
|
| 149 |
+
# Create a copy of the config before deleting the key to preserve it for serialization
|
| 150 |
+
encoder_config_copy = self.encoder_config.copy()
|
| 151 |
+
del encoder_config_copy["uses_torch_hub"]
|
| 152 |
+
self.encoder = encoder_factory(**encoder_config_copy)
|
| 153 |
|
| 154 |
# Initialize the encoder for ray directions
|
| 155 |
ray_dirs_encoder_config = self.geometric_input_config["ray_dirs_encoder_config"]
|
|
|
|
| 208 |
# Load pretrained weights
|
| 209 |
self._load_pretrained_weights()
|
| 210 |
|
| 211 |
+
@property
|
| 212 |
+
def device(self) -> torch.device:
|
| 213 |
+
return next(self.parameters()).device
|
| 214 |
+
|
| 215 |
+
@property
|
| 216 |
+
def dtype(self) -> torch.dtype:
|
| 217 |
+
return next(self.parameters()).dtype
|
| 218 |
+
|
| 219 |
def _initialize_info_sharing(self, info_sharing_config):
|
| 220 |
"""
|
| 221 |
Initialize the information sharing module based on the configuration.
|
|
|
|
| 1734 |
res[i]["non_ambiguous_mask_logits"] = output_mask_logits_per_view[i]
|
| 1735 |
|
| 1736 |
return res
|
| 1737 |
+
|
| 1738 |
+
def _configure_geometric_input_config(
|
| 1739 |
+
self,
|
| 1740 |
+
use_calibration: bool,
|
| 1741 |
+
use_depth: bool,
|
| 1742 |
+
use_pose: bool,
|
| 1743 |
+
use_depth_scale: bool,
|
| 1744 |
+
use_pose_scale: bool,
|
| 1745 |
+
):
|
| 1746 |
+
"""
|
| 1747 |
+
Configure the geometric input configuration
|
| 1748 |
+
"""
|
| 1749 |
+
# Store original config for restoration
|
| 1750 |
+
if not hasattr(self, "_original_geometric_config"):
|
| 1751 |
+
self._original_geometric_config = dict(self.geometric_input_config)
|
| 1752 |
+
|
| 1753 |
+
# Set the geometric input configuration
|
| 1754 |
+
if not (use_calibration or use_depth or use_pose):
|
| 1755 |
+
# No geometric inputs (images-only mode)
|
| 1756 |
+
self.geometric_input_config.update(
|
| 1757 |
+
{
|
| 1758 |
+
"overall_prob": 0.0,
|
| 1759 |
+
"dropout_prob": 1.0,
|
| 1760 |
+
"ray_dirs_prob": 0.0,
|
| 1761 |
+
"depth_prob": 0.0,
|
| 1762 |
+
"cam_prob": 0.0,
|
| 1763 |
+
"sparse_depth_prob": 0.0,
|
| 1764 |
+
"depth_scale_norm_all_prob": 0.0,
|
| 1765 |
+
"pose_scale_norm_all_prob": 0.0,
|
| 1766 |
+
}
|
| 1767 |
+
)
|
| 1768 |
+
else:
|
| 1769 |
+
# Enable geometric inputs with deterministic behavior
|
| 1770 |
+
self.geometric_input_config.update(
|
| 1771 |
+
{
|
| 1772 |
+
"overall_prob": 1.0,
|
| 1773 |
+
"dropout_prob": 0.0,
|
| 1774 |
+
"ray_dirs_prob": 1.0 if use_calibration else 0.0,
|
| 1775 |
+
"depth_prob": 1.0 if use_depth else 0.0,
|
| 1776 |
+
"cam_prob": 1.0 if use_pose else 0.0,
|
| 1777 |
+
"sparse_depth_prob": 0.0, # No sparsification during inference
|
| 1778 |
+
"depth_scale_norm_all_prob": 0.0 if use_depth_scale else 1.0,
|
| 1779 |
+
"pose_scale_norm_all_prob": 0.0 if use_pose_scale else 1.0,
|
| 1780 |
+
}
|
| 1781 |
+
)
|
| 1782 |
+
|
| 1783 |
+
def _restore_original_geometric_input_config(self):
|
| 1784 |
+
"""
|
| 1785 |
+
Restore original geometric input configuration
|
| 1786 |
+
"""
|
| 1787 |
+
if hasattr(self, "_original_geometric_config"):
|
| 1788 |
+
self.geometric_input_config.update(self._original_geometric_config)
|
| 1789 |
+
|
| 1790 |
+
@torch.inference_mode()
|
| 1791 |
+
def infer(
|
| 1792 |
+
self,
|
| 1793 |
+
views: List[Dict[str, Any]],
|
| 1794 |
+
use_amp: bool = True,
|
| 1795 |
+
amp_dtype: str = "bf16",
|
| 1796 |
+
apply_mask: bool = True,
|
| 1797 |
+
mask_edges: bool = True,
|
| 1798 |
+
edge_normal_threshold: float = 5.0,
|
| 1799 |
+
edge_depth_threshold: float = 0.03,
|
| 1800 |
+
apply_confidence_mask: bool = False,
|
| 1801 |
+
confidence_percentile: float = 10,
|
| 1802 |
+
ignore_calibration_inputs: bool = False,
|
| 1803 |
+
ignore_depth_inputs: bool = False,
|
| 1804 |
+
ignore_pose_inputs: bool = False,
|
| 1805 |
+
ignore_depth_scale_inputs: bool = False,
|
| 1806 |
+
ignore_pose_scale_inputs: bool = False,
|
| 1807 |
+
) -> List[Dict[str, torch.Tensor]]:
|
| 1808 |
+
"""
|
| 1809 |
+
User-friendly inference with strict input validation and automatic conversion.
|
| 1810 |
+
|
| 1811 |
+
Args:
|
| 1812 |
+
views: List of view dictionaries. Each dict can contain:
|
| 1813 |
+
Required:
|
| 1814 |
+
- 'img': torch.Tensor of shape (B, 3, H, W) - normalized RGB images
|
| 1815 |
+
- 'data_norm_type': str - normalization type used to normalize the images (must be equal to self.model.encoder.data_norm_type)
|
| 1816 |
+
|
| 1817 |
+
Optional Geometric Inputs (only one of intrinsics OR ray_directions):
|
| 1818 |
+
- 'intrinsics': torch.Tensor of shape (B, 3, 3) - will be converted to ray directions
|
| 1819 |
+
- 'ray_directions': torch.Tensor of shape (B, H, W, 3) - ray directions in camera frame
|
| 1820 |
+
- 'depth_z': torch.Tensor of shape (B, H, W, 1) - Z depth in camera frame (intrinsics or ray_directions must be provided)
|
| 1821 |
+
- 'camera_poses': torch.Tensor of shape (B, 4, 4) or tuple of (quats - (B, 4), trans - (B, 3)) - can be any world frame
|
| 1822 |
+
- 'is_metric_scale': bool or torch.Tensor of shape (B,) - if not provided, defaults to True
|
| 1823 |
+
|
| 1824 |
+
Optional Additional Info:
|
| 1825 |
+
- 'instance': List[str] where length of list is B - instance info for each view
|
| 1826 |
+
- 'idx': List[int] where length of list is B - index info for each view
|
| 1827 |
+
- 'true_shape': List[tuple] where length of list is B - true shape info (H, W) for each view
|
| 1828 |
+
|
| 1829 |
+
use_amp: Whether to use automatic mixed precision for faster inference. Defaults to True.
|
| 1830 |
+
amp_dtype: The dtype to use for mixed precision. Defaults to "bf16" (bfloat16). Options: "fp16", "bf16", "fp32".
|
| 1831 |
+
apply_mask: Whether to apply the non-ambiguous mask to the output. Defaults to True.
|
| 1832 |
+
mask_edges: Whether to compute an edge mask based on normals and depth and apply it to the output. Defaults to True.
|
| 1833 |
+
edge_normal_threshold: Tolerance threshold for normals-based edge detection. Defaults to 5.0.
|
| 1834 |
+
edge_depth_threshold: Relative tolerance threshold for depth-based edge detection. Defaults to 0.03.
|
| 1835 |
+
apply_confidence_mask: Whether to apply the confidence mask to the output. Defaults to False.
|
| 1836 |
+
confidence_percentile: The percentile to use for the confidence threshold. Defaults to 10.
|
| 1837 |
+
ignore_calibration_inputs: Whether to ignore the calibration inputs (intrinsics and ray_directions). Defaults to False.
|
| 1838 |
+
ignore_depth_inputs: Whether to ignore the depth inputs. Defaults to False.
|
| 1839 |
+
ignore_pose_inputs: Whether to ignore the pose inputs. Defaults to False.
|
| 1840 |
+
ignore_depth_scale_inputs: Whether to ignore the depth scale inputs. Defaults to False.
|
| 1841 |
+
ignore_pose_scale_inputs: Whether to ignore the pose scale inputs. Defaults to False.
|
| 1842 |
+
|
| 1843 |
+
IMPORTANT CONSTRAINTS:
|
| 1844 |
+
- Cannot provide both 'intrinsics' and 'ray_directions' (they represent the same information)
|
| 1845 |
+
- If 'depth' is provided, then 'intrinsics' or 'ray_directions' must also be provided
|
| 1846 |
+
- If ANY view has 'camera_poses', then view 0 (first view) MUST also have 'camera_poses'
|
| 1847 |
+
|
| 1848 |
+
Returns:
|
| 1849 |
+
List of prediction dictionaries, one per view. Each dict contains:
|
| 1850 |
+
- 'img_no_norm': torch.Tensor of shape (B, H, W, 3) - denormalized rgb images
|
| 1851 |
+
- 'pts3d': torch.Tensor of shape (B, H, W, 3) - predicted points in world frame
|
| 1852 |
+
- 'pts3d_cam': torch.Tensor of shape (B, H, W, 3) - predicted points in camera frame
|
| 1853 |
+
- 'ray_directions': torch.Tensor of shape (B, H, W, 3) - ray directions in camera frame
|
| 1854 |
+
- 'intrinsics': torch.Tensor of shape (B, 3, 3) - pinhole camera intrinsics recovered from ray directions
|
| 1855 |
+
- 'depth_along_ray': torch.Tensor of shape (B, H, W, 1) - depth along ray in camera frame
|
| 1856 |
+
- 'depth_z': torch.Tensor of shape (B, H, W, 1) - Z depth in camera frame
|
| 1857 |
+
- 'cam_trans': torch.Tensor of shape (B, 3) - camera translation in world frame
|
| 1858 |
+
- 'cam_quats': torch.Tensor of shape (B, 4) - camera quaternion in world frame
|
| 1859 |
+
- 'camera_poses': torch.Tensor of shape (B, 4, 4) - camera pose in world frame
|
| 1860 |
+
- 'metric_scaling_factor': torch.Tensor of shape (B,) - applied metric scaling factor
|
| 1861 |
+
- 'mask': torch.Tensor of shape (B, H, W, 1) - combo of non-ambiguous mask, edge mask and confidence-based mask if used
|
| 1862 |
+
- 'non_ambiguous_mask': torch.Tensor of shape (B, H, W) - non-ambiguous mask
|
| 1863 |
+
- 'non_ambiguous_mask_logits': torch.Tensor of shape (B, H, W) - non-ambiguous mask logits
|
| 1864 |
+
- 'conf': torch.Tensor of shape (B, H, W) - confidence
|
| 1865 |
+
|
| 1866 |
+
Raises:
|
| 1867 |
+
ValueError: For invalid inputs, missing required keys, conflicting modalities, or constraint violations
|
| 1868 |
+
"""
|
| 1869 |
+
# Determine the mixed precision floating point type
|
| 1870 |
+
if use_amp:
|
| 1871 |
+
if amp_dtype == "fp16":
|
| 1872 |
+
amp_dtype = torch.float16
|
| 1873 |
+
elif amp_dtype == "bf16":
|
| 1874 |
+
if torch.cuda.is_bf16_supported():
|
| 1875 |
+
amp_dtype = torch.bfloat16
|
| 1876 |
+
else:
|
| 1877 |
+
warnings.warn(
|
| 1878 |
+
"bf16 is not supported on this device. Using fp16 instead."
|
| 1879 |
+
)
|
| 1880 |
+
amp_dtype = torch.float16
|
| 1881 |
+
elif amp_dtype == "fp32":
|
| 1882 |
+
amp_dtype = torch.float32
|
| 1883 |
+
else:
|
| 1884 |
+
amp_dtype = torch.float32
|
| 1885 |
+
|
| 1886 |
+
# Validate the input views
|
| 1887 |
+
validated_views = validate_input_views_for_inference(views)
|
| 1888 |
+
|
| 1889 |
+
# Transfer the views to the same device as the model
|
| 1890 |
+
ignore_keys = set(
|
| 1891 |
+
[
|
| 1892 |
+
"instance",
|
| 1893 |
+
"idx",
|
| 1894 |
+
"true_shape",
|
| 1895 |
+
"data_norm_type",
|
| 1896 |
+
]
|
| 1897 |
+
)
|
| 1898 |
+
for view in validated_views:
|
| 1899 |
+
for name in view.keys():
|
| 1900 |
+
if name in ignore_keys:
|
| 1901 |
+
continue
|
| 1902 |
+
view[name] = view[name].to(self.device, non_blocking=True)
|
| 1903 |
+
|
| 1904 |
+
# Pre-process the input views
|
| 1905 |
+
processed_views = preprocess_input_views_for_inference(validated_views)
|
| 1906 |
+
|
| 1907 |
+
# Set the model input probabilities based on input args for ignoring inputs
|
| 1908 |
+
self._configure_geometric_input_config(
|
| 1909 |
+
use_calibration=not ignore_calibration_inputs,
|
| 1910 |
+
use_depth=not ignore_depth_inputs,
|
| 1911 |
+
use_pose=not ignore_pose_inputs,
|
| 1912 |
+
use_depth_scale=not ignore_depth_scale_inputs,
|
| 1913 |
+
use_pose_scale=not ignore_pose_scale_inputs,
|
| 1914 |
+
)
|
| 1915 |
+
|
| 1916 |
+
# Run the model
|
| 1917 |
+
with torch.autocast("cuda", enabled=bool(use_amp), dtype=amp_dtype):
|
| 1918 |
+
preds = self.forward(processed_views)
|
| 1919 |
+
|
| 1920 |
+
# Post-process the model outputs
|
| 1921 |
+
preds = postprocess_model_outputs_for_inference(
|
| 1922 |
+
raw_outputs=preds,
|
| 1923 |
+
input_views=processed_views,
|
| 1924 |
+
apply_mask=apply_mask,
|
| 1925 |
+
mask_edges=mask_edges,
|
| 1926 |
+
edge_normal_threshold=edge_normal_threshold,
|
| 1927 |
+
edge_depth_threshold=edge_depth_threshold,
|
| 1928 |
+
apply_confidence_mask=apply_confidence_mask,
|
| 1929 |
+
confidence_percentile=confidence_percentile,
|
| 1930 |
+
)
|
| 1931 |
+
|
| 1932 |
+
# Restore the original configuration
|
| 1933 |
+
self._restore_original_geometric_input_config()
|
| 1934 |
+
|
| 1935 |
+
return preds
|
mapanything/models/mapanything/modular_dust3r.py
CHANGED
|
@@ -99,8 +99,10 @@ class ModularDUSt3R(nn.Module):
|
|
| 99 |
# Initialize Encoder
|
| 100 |
if self.encoder_config["uses_torch_hub"]:
|
| 101 |
self.encoder_config["torch_hub_force_reload"] = torch_hub_force_reload
|
| 102 |
-
|
| 103 |
-
|
|
|
|
|
|
|
| 104 |
|
| 105 |
# Initialize Custom Positional Encoding if required
|
| 106 |
if custom_positional_encoding is not None:
|
|
|
|
| 99 |
# Initialize Encoder
|
| 100 |
if self.encoder_config["uses_torch_hub"]:
|
| 101 |
self.encoder_config["torch_hub_force_reload"] = torch_hub_force_reload
|
| 102 |
+
# Create a copy of the config before deleting the key to preserve it for serialization
|
| 103 |
+
encoder_config_copy = self.encoder_config.copy()
|
| 104 |
+
del encoder_config_copy["uses_torch_hub"]
|
| 105 |
+
self.encoder = encoder_factory(**encoder_config_copy)
|
| 106 |
|
| 107 |
# Initialize Custom Positional Encoding if required
|
| 108 |
if custom_positional_encoding is not None:
|
mapanything/train/losses.py
CHANGED
|
@@ -1766,6 +1766,202 @@ class PointsPlusScaleRegr3D(Criterion, MultiLoss):
|
|
| 1766 |
return losses, (details | {})
|
| 1767 |
|
| 1768 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1769 |
class FactoredGeometryRegr3D(Criterion, MultiLoss):
|
| 1770 |
"""
|
| 1771 |
Regression Loss for Factored Geometry.
|
|
@@ -1787,6 +1983,7 @@ class FactoredGeometryRegr3D(Criterion, MultiLoss):
|
|
| 1787 |
pose_quats_loss_weight=1,
|
| 1788 |
pose_trans_loss_weight=1,
|
| 1789 |
compute_pairwise_relative_pose_loss=False,
|
|
|
|
| 1790 |
compute_world_frame_points_loss=True,
|
| 1791 |
world_frame_points_loss_weight=1,
|
| 1792 |
):
|
|
@@ -1821,6 +2018,8 @@ class FactoredGeometryRegr3D(Criterion, MultiLoss):
|
|
| 1821 |
pose_trans_loss_weight (float): Weight to use for the pose trans loss. Default: 1.
|
| 1822 |
compute_pairwise_relative_pose_loss (bool): If True, the pose loss is computed on the
|
| 1823 |
exhaustive pairwise relative poses. Default: False.
|
|
|
|
|
|
|
| 1824 |
compute_world_frame_points_loss (bool): If True, compute the world frame pointmap loss. Default: True.
|
| 1825 |
world_frame_points_loss_weight (float): Weight to use for the world frame pointmap loss. Default: 1.
|
| 1826 |
"""
|
|
@@ -1847,6 +2046,7 @@ class FactoredGeometryRegr3D(Criterion, MultiLoss):
|
|
| 1847 |
self.pose_quats_loss_weight = pose_quats_loss_weight
|
| 1848 |
self.pose_trans_loss_weight = pose_trans_loss_weight
|
| 1849 |
self.compute_pairwise_relative_pose_loss = compute_pairwise_relative_pose_loss
|
|
|
|
| 1850 |
self.compute_world_frame_points_loss = compute_world_frame_points_loss
|
| 1851 |
self.world_frame_points_loss_weight = world_frame_points_loss_weight
|
| 1852 |
|
|
@@ -1869,6 +2069,19 @@ class FactoredGeometryRegr3D(Criterion, MultiLoss):
|
|
| 1869 |
gt_ray_directions = []
|
| 1870 |
gt_pose_quats = []
|
| 1871 |
# Predicted quantities
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1872 |
no_norm_pr_pts = []
|
| 1873 |
no_norm_pr_pts_cam = []
|
| 1874 |
no_norm_pr_depth = []
|
|
@@ -1922,16 +2135,34 @@ class FactoredGeometryRegr3D(Criterion, MultiLoss):
|
|
| 1922 |
gt_pose_quats.append(gt_pose_quats_in_view0)
|
| 1923 |
no_norm_gt_pose_trans.append(no_norm_gt_pose_trans_in_view0)
|
| 1924 |
|
| 1925 |
-
# Get predictions
|
| 1926 |
-
no_norm_pr_pts.append(preds[i]["pts3d"])
|
| 1927 |
no_norm_pr_pts_cam.append(preds[i]["pts3d_cam"])
|
| 1928 |
pr_ray_directions.append(preds[i]["ray_directions"])
|
| 1929 |
if self.depth_type_for_loss == "depth_along_ray":
|
| 1930 |
no_norm_pr_depth.append(preds[i]["depth_along_ray"])
|
| 1931 |
elif self.depth_type_for_loss == "depth_z":
|
| 1932 |
no_norm_pr_depth.append(preds[i]["pts3d_cam"][..., 2:])
|
| 1933 |
-
|
| 1934 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1935 |
|
| 1936 |
if dist_clip is not None:
|
| 1937 |
# Points that are too far-away == invalid
|
|
@@ -2443,6 +2674,7 @@ class FactoredGeometryRegr3DPlusNormalGMLoss(FactoredGeometryRegr3D):
|
|
| 2443 |
pose_quats_loss_weight=1,
|
| 2444 |
pose_trans_loss_weight=1,
|
| 2445 |
compute_pairwise_relative_pose_loss=False,
|
|
|
|
| 2446 |
compute_world_frame_points_loss=True,
|
| 2447 |
world_frame_points_loss_weight=1,
|
| 2448 |
apply_normal_and_gm_loss_to_synthetic_data_only=True,
|
|
@@ -2478,6 +2710,8 @@ class FactoredGeometryRegr3DPlusNormalGMLoss(FactoredGeometryRegr3D):
|
|
| 2478 |
pose_trans_loss_weight (float): Weight to use for the pose trans loss. Default: 1.
|
| 2479 |
compute_pairwise_relative_pose_loss (bool): If True, the pose loss is computed on the
|
| 2480 |
exhaustive pairwise relative poses. Default: False.
|
|
|
|
|
|
|
| 2481 |
compute_world_frame_points_loss (bool): If True, compute the world frame pointmap loss. Default: True.
|
| 2482 |
world_frame_points_loss_weight (float): Weight to use for the world frame pointmap loss. Default: 1.
|
| 2483 |
apply_normal_and_gm_loss_to_synthetic_data_only (bool): If True, apply the normal and gm loss only to synthetic data.
|
|
@@ -2500,6 +2734,7 @@ class FactoredGeometryRegr3DPlusNormalGMLoss(FactoredGeometryRegr3D):
|
|
| 2500 |
pose_quats_loss_weight=pose_quats_loss_weight,
|
| 2501 |
pose_trans_loss_weight=pose_trans_loss_weight,
|
| 2502 |
compute_pairwise_relative_pose_loss=compute_pairwise_relative_pose_loss,
|
|
|
|
| 2503 |
compute_world_frame_points_loss=compute_world_frame_points_loss,
|
| 2504 |
world_frame_points_loss_weight=world_frame_points_loss_weight,
|
| 2505 |
)
|
|
@@ -2895,6 +3130,7 @@ class FactoredGeometryScaleRegr3D(Criterion, MultiLoss):
|
|
| 2895 |
pose_trans_loss_weight=1,
|
| 2896 |
scale_loss_weight=1,
|
| 2897 |
compute_pairwise_relative_pose_loss=False,
|
|
|
|
| 2898 |
compute_world_frame_points_loss=True,
|
| 2899 |
world_frame_points_loss_weight=1,
|
| 2900 |
):
|
|
@@ -2928,6 +3164,8 @@ class FactoredGeometryScaleRegr3D(Criterion, MultiLoss):
|
|
| 2928 |
scale_loss_weight (float): Weight to use for the scale loss. Default: 1.
|
| 2929 |
compute_pairwise_relative_pose_loss (bool): If True, the pose loss is computed on the
|
| 2930 |
exhaustive pairwise relative poses. Default: False.
|
|
|
|
|
|
|
| 2931 |
compute_world_frame_points_loss (bool): If True, compute the world frame pointmap loss. Default: True.
|
| 2932 |
world_frame_points_loss_weight (float): Weight to use for the world frame pointmap loss. Default: 1.
|
| 2933 |
"""
|
|
@@ -2948,6 +3186,7 @@ class FactoredGeometryScaleRegr3D(Criterion, MultiLoss):
|
|
| 2948 |
self.pose_trans_loss_weight = pose_trans_loss_weight
|
| 2949 |
self.scale_loss_weight = scale_loss_weight
|
| 2950 |
self.compute_pairwise_relative_pose_loss = compute_pairwise_relative_pose_loss
|
|
|
|
| 2951 |
self.compute_world_frame_points_loss = compute_world_frame_points_loss
|
| 2952 |
self.world_frame_points_loss_weight = world_frame_points_loss_weight
|
| 2953 |
|
|
@@ -2970,6 +3209,19 @@ class FactoredGeometryScaleRegr3D(Criterion, MultiLoss):
|
|
| 2970 |
gt_ray_directions = []
|
| 2971 |
gt_pose_quats = []
|
| 2972 |
# Predicted quantities
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2973 |
no_norm_pr_pts = []
|
| 2974 |
no_norm_pr_pts_cam = []
|
| 2975 |
no_norm_pr_depth = []
|
|
@@ -3024,6 +3276,24 @@ class FactoredGeometryScaleRegr3D(Criterion, MultiLoss):
|
|
| 3024 |
gt_pose_quats.append(gt_pose_quats_in_view0)
|
| 3025 |
no_norm_gt_pose_trans.append(no_norm_gt_pose_trans_in_view0)
|
| 3026 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3027 |
# Get predictions for normalized loss
|
| 3028 |
if self.depth_type_for_loss == "depth_along_ray":
|
| 3029 |
curr_view_no_norm_depth = preds[i]["depth_along_ray"]
|
|
@@ -3032,7 +3302,7 @@ class FactoredGeometryScaleRegr3D(Criterion, MultiLoss):
|
|
| 3032 |
if "metric_scaling_factor" in preds[i].keys():
|
| 3033 |
# Divide by the predicted metric scaling factor to get the raw predicted points, depth_along_ray, and pose_trans
|
| 3034 |
# This detaches the predicted metric scaling factor from the geometry based loss
|
| 3035 |
-
curr_view_no_norm_pr_pts =
|
| 3036 |
"metric_scaling_factor"
|
| 3037 |
].unsqueeze(-1).unsqueeze(-1)
|
| 3038 |
curr_view_no_norm_pr_pts_cam = preds[i]["pts3d_cam"] / preds[i][
|
|
@@ -3042,19 +3312,19 @@ class FactoredGeometryScaleRegr3D(Criterion, MultiLoss):
|
|
| 3042 |
"metric_scaling_factor"
|
| 3043 |
].unsqueeze(-1).unsqueeze(-1)
|
| 3044 |
curr_view_no_norm_pr_pose_trans = (
|
| 3045 |
-
|
| 3046 |
)
|
| 3047 |
else:
|
| 3048 |
-
curr_view_no_norm_pr_pts =
|
| 3049 |
curr_view_no_norm_pr_pts_cam = preds[i]["pts3d_cam"]
|
| 3050 |
curr_view_no_norm_depth = curr_view_no_norm_depth
|
| 3051 |
-
curr_view_no_norm_pr_pose_trans =
|
| 3052 |
no_norm_pr_pts.append(curr_view_no_norm_pr_pts)
|
| 3053 |
no_norm_pr_pts_cam.append(curr_view_no_norm_pr_pts_cam)
|
| 3054 |
no_norm_pr_depth.append(curr_view_no_norm_depth)
|
| 3055 |
no_norm_pr_pose_trans.append(curr_view_no_norm_pr_pose_trans)
|
| 3056 |
pr_ray_directions.append(preds[i]["ray_directions"])
|
| 3057 |
-
pr_pose_quats.append(
|
| 3058 |
|
| 3059 |
# Get the predicted metric scale points
|
| 3060 |
if "metric_scaling_factor" in preds[i].keys():
|
|
@@ -3553,6 +3823,7 @@ class FactoredGeometryScaleRegr3DPlusNormalGMLoss(FactoredGeometryScaleRegr3D):
|
|
| 3553 |
pose_trans_loss_weight=1,
|
| 3554 |
scale_loss_weight=1,
|
| 3555 |
compute_pairwise_relative_pose_loss=False,
|
|
|
|
| 3556 |
compute_world_frame_points_loss=True,
|
| 3557 |
world_frame_points_loss_weight=1,
|
| 3558 |
apply_normal_and_gm_loss_to_synthetic_data_only=True,
|
|
@@ -3585,6 +3856,8 @@ class FactoredGeometryScaleRegr3DPlusNormalGMLoss(FactoredGeometryScaleRegr3D):
|
|
| 3585 |
scale_loss_weight (float): Weight to use for the scale loss. Default: 1.
|
| 3586 |
compute_pairwise_relative_pose_loss (bool): If True, the pose loss is computed on the
|
| 3587 |
exhaustive pairwise relative poses. Default: False.
|
|
|
|
|
|
|
| 3588 |
compute_world_frame_points_loss (bool): If True, compute the world frame pointmap loss. Default: True.
|
| 3589 |
world_frame_points_loss_weight (float): Weight to use for the world frame pointmap loss. Default: 1.
|
| 3590 |
apply_normal_and_gm_loss_to_synthetic_data_only (bool): If True, apply the normal and gm loss only to synthetic data.
|
|
@@ -3607,6 +3880,7 @@ class FactoredGeometryScaleRegr3DPlusNormalGMLoss(FactoredGeometryScaleRegr3D):
|
|
| 3607 |
pose_trans_loss_weight=pose_trans_loss_weight,
|
| 3608 |
scale_loss_weight=scale_loss_weight,
|
| 3609 |
compute_pairwise_relative_pose_loss=compute_pairwise_relative_pose_loss,
|
|
|
|
| 3610 |
compute_world_frame_points_loss=compute_world_frame_points_loss,
|
| 3611 |
world_frame_points_loss_weight=world_frame_points_loss_weight,
|
| 3612 |
)
|
|
|
|
| 1766 |
return losses, (details | {})
|
| 1767 |
|
| 1768 |
|
| 1769 |
+
class NormalGMLoss(MultiLoss):
|
| 1770 |
+
"""
|
| 1771 |
+
Normal & Gradient Matching Loss for Monocular Depth Training.
|
| 1772 |
+
"""
|
| 1773 |
+
|
| 1774 |
+
def __init__(
|
| 1775 |
+
self,
|
| 1776 |
+
norm_predictions=True,
|
| 1777 |
+
norm_mode="avg_dis",
|
| 1778 |
+
apply_normal_and_gm_loss_to_synthetic_data_only=True,
|
| 1779 |
+
):
|
| 1780 |
+
"""
|
| 1781 |
+
Initialize the loss criterion for Normal & Gradient Matching Loss (currently only valid for 1 view).
|
| 1782 |
+
Computes:
|
| 1783 |
+
(1) Normal Loss over the PointMap (naturally will be in local frame) in euclidean coordinates,
|
| 1784 |
+
(2) Gradient Matching (GM) Loss over the Depth Z in log space. (MiDAS applied GM loss in disparity space)
|
| 1785 |
+
|
| 1786 |
+
Args:
|
| 1787 |
+
norm_predictions (bool): If True, normalize the predictions before computing the loss.
|
| 1788 |
+
norm_mode (str): Normalization mode for the gt and predicted (optional) scene representation. Default: "avg_dis".
|
| 1789 |
+
apply_normal_and_gm_loss_to_synthetic_data_only (bool): If True, apply the normal and gm loss only to synthetic data.
|
| 1790 |
+
If False, apply the normal and gm loss to all data. Default: True.
|
| 1791 |
+
"""
|
| 1792 |
+
super().__init__()
|
| 1793 |
+
self.norm_predictions = norm_predictions
|
| 1794 |
+
self.norm_mode = norm_mode
|
| 1795 |
+
self.apply_normal_and_gm_loss_to_synthetic_data_only = (
|
| 1796 |
+
apply_normal_and_gm_loss_to_synthetic_data_only
|
| 1797 |
+
)
|
| 1798 |
+
|
| 1799 |
+
def get_all_info(self, batch, preds, dist_clip=None):
|
| 1800 |
+
"""
|
| 1801 |
+
Function to get all the information needed to compute the loss.
|
| 1802 |
+
Returns all quantities normalized.
|
| 1803 |
+
"""
|
| 1804 |
+
n_views = len(batch)
|
| 1805 |
+
assert n_views == 1, (
|
| 1806 |
+
"Normal & Gradient Matching Loss Class only supports 1 view"
|
| 1807 |
+
)
|
| 1808 |
+
|
| 1809 |
+
# Everything is normalized w.r.t. camera of view1
|
| 1810 |
+
in_camera1 = closed_form_pose_inverse(batch[0]["camera_pose"])
|
| 1811 |
+
|
| 1812 |
+
# Initialize lists to store data for all views
|
| 1813 |
+
no_norm_gt_pts = []
|
| 1814 |
+
valid_masks = []
|
| 1815 |
+
no_norm_pr_pts = []
|
| 1816 |
+
|
| 1817 |
+
# Get ground truth & prediction info for all views
|
| 1818 |
+
for i in range(n_views):
|
| 1819 |
+
# Get ground truth
|
| 1820 |
+
no_norm_gt_pts.append(geotrf(in_camera1, batch[i]["pts3d"]))
|
| 1821 |
+
valid_masks.append(batch[i]["valid_mask"].clone())
|
| 1822 |
+
|
| 1823 |
+
# Get predictions for normalized loss
|
| 1824 |
+
if "metric_scaling_factor" in preds[i].keys():
|
| 1825 |
+
# Divide by the predicted metric scaling factor to get the raw predicted points
|
| 1826 |
+
# This detaches the predicted metric scaling factor from the geometry based loss
|
| 1827 |
+
curr_view_no_norm_pr_pts = preds[i]["pts3d"] / preds[i][
|
| 1828 |
+
"metric_scaling_factor"
|
| 1829 |
+
].unsqueeze(-1).unsqueeze(-1)
|
| 1830 |
+
else:
|
| 1831 |
+
curr_view_no_norm_pr_pts = preds[i]["pts3d"]
|
| 1832 |
+
no_norm_pr_pts.append(curr_view_no_norm_pr_pts)
|
| 1833 |
+
|
| 1834 |
+
if dist_clip is not None:
|
| 1835 |
+
# Points that are too far-away == invalid
|
| 1836 |
+
for i in range(n_views):
|
| 1837 |
+
dis = no_norm_gt_pts[i].norm(dim=-1)
|
| 1838 |
+
valid_masks[i] = valid_masks[i] & (dis <= dist_clip)
|
| 1839 |
+
|
| 1840 |
+
# Initialize normalized tensors
|
| 1841 |
+
gt_pts = [torch.zeros_like(pts) for pts in no_norm_gt_pts]
|
| 1842 |
+
pr_pts = [torch.zeros_like(pts) for pts in no_norm_pr_pts]
|
| 1843 |
+
|
| 1844 |
+
# Normalize the predicted points if specified
|
| 1845 |
+
if self.norm_predictions:
|
| 1846 |
+
pr_normalization_output = normalize_multiple_pointclouds(
|
| 1847 |
+
no_norm_pr_pts,
|
| 1848 |
+
valid_masks,
|
| 1849 |
+
self.norm_mode,
|
| 1850 |
+
ret_factor=True,
|
| 1851 |
+
)
|
| 1852 |
+
pr_pts_norm = pr_normalization_output[:-1]
|
| 1853 |
+
|
| 1854 |
+
# Normalize the ground truth points
|
| 1855 |
+
gt_normalization_output = normalize_multiple_pointclouds(
|
| 1856 |
+
no_norm_gt_pts, valid_masks, self.norm_mode, ret_factor=True
|
| 1857 |
+
)
|
| 1858 |
+
gt_pts_norm = gt_normalization_output[:-1]
|
| 1859 |
+
|
| 1860 |
+
for i in range(n_views):
|
| 1861 |
+
if self.norm_predictions:
|
| 1862 |
+
# Assign the normalized predictions
|
| 1863 |
+
pr_pts[i] = pr_pts_norm[i]
|
| 1864 |
+
else:
|
| 1865 |
+
# Assign the raw predicted points
|
| 1866 |
+
pr_pts[i] = no_norm_pr_pts[i]
|
| 1867 |
+
# Assign the normalized ground truth
|
| 1868 |
+
gt_pts[i] = gt_pts_norm[i]
|
| 1869 |
+
|
| 1870 |
+
return gt_pts, pr_pts, valid_masks
|
| 1871 |
+
|
| 1872 |
+
def compute_loss(self, batch, preds, **kw):
|
| 1873 |
+
gt_pts, pred_pts, valid_masks = self.get_all_info(batch, preds, **kw)
|
| 1874 |
+
n_views = len(batch)
|
| 1875 |
+
assert n_views == 1, (
|
| 1876 |
+
"Normal & Gradient Matching Loss Class only supports 1 view"
|
| 1877 |
+
)
|
| 1878 |
+
|
| 1879 |
+
normal_losses = []
|
| 1880 |
+
gradient_matching_losses = []
|
| 1881 |
+
details = {}
|
| 1882 |
+
running_avg_dict = {}
|
| 1883 |
+
self_name = type(self).__name__
|
| 1884 |
+
|
| 1885 |
+
for i in range(n_views):
|
| 1886 |
+
# Get the local frame points, log space depth_z & valid masks
|
| 1887 |
+
pred_local_pts3d = pred_pts[i]
|
| 1888 |
+
pred_depth_z = pred_local_pts3d[..., 2:]
|
| 1889 |
+
pred_depth_z = apply_log_to_norm(pred_depth_z)
|
| 1890 |
+
gt_local_pts3d = gt_pts[i]
|
| 1891 |
+
gt_depth_z = gt_local_pts3d[..., 2:]
|
| 1892 |
+
gt_depth_z = apply_log_to_norm(gt_depth_z)
|
| 1893 |
+
valid_mask_for_normal_gm_loss = valid_masks[i].clone()
|
| 1894 |
+
|
| 1895 |
+
# Update the validity mask for normal & gm loss based on the synthetic data mask if required
|
| 1896 |
+
if self.apply_normal_and_gm_loss_to_synthetic_data_only:
|
| 1897 |
+
synthetic_mask = batch[i]["is_synthetic"] # (B, )
|
| 1898 |
+
synthetic_mask = synthetic_mask.unsqueeze(-1).unsqueeze(-1) # (B, 1, 1)
|
| 1899 |
+
synthetic_mask = synthetic_mask.expand(
|
| 1900 |
+
-1, pred_depth_z.shape[1], pred_depth_z.shape[2]
|
| 1901 |
+
) # (B, H, W)
|
| 1902 |
+
valid_mask_for_normal_gm_loss = (
|
| 1903 |
+
valid_mask_for_normal_gm_loss & synthetic_mask
|
| 1904 |
+
)
|
| 1905 |
+
|
| 1906 |
+
# Compute the normal loss
|
| 1907 |
+
normal_loss = compute_normal_loss(
|
| 1908 |
+
pred_local_pts3d, gt_local_pts3d, valid_mask_for_normal_gm_loss.clone()
|
| 1909 |
+
)
|
| 1910 |
+
normal_losses.append(normal_loss)
|
| 1911 |
+
|
| 1912 |
+
# Compute the gradient matching loss
|
| 1913 |
+
gradient_matching_loss = compute_gradient_matching_loss(
|
| 1914 |
+
pred_depth_z, gt_depth_z, valid_mask_for_normal_gm_loss.clone()
|
| 1915 |
+
)
|
| 1916 |
+
gradient_matching_losses.append(gradient_matching_loss)
|
| 1917 |
+
|
| 1918 |
+
# Add loss details if only valid values are present
|
| 1919 |
+
# Initialize or update running average directly
|
| 1920 |
+
# Normal loss details
|
| 1921 |
+
if float(normal_loss) > 0:
|
| 1922 |
+
details[f"{self_name}_normal_view{i + 1}"] = float(normal_loss)
|
| 1923 |
+
normal_avg_key = f"{self_name}_normal_avg"
|
| 1924 |
+
if normal_avg_key not in details:
|
| 1925 |
+
details[normal_avg_key] = float(normal_losses[i])
|
| 1926 |
+
running_avg_dict[f"{self_name}_normal_valid_views"] = 1
|
| 1927 |
+
else:
|
| 1928 |
+
normal_valid_views = (
|
| 1929 |
+
running_avg_dict[f"{self_name}_normal_valid_views"] + 1
|
| 1930 |
+
)
|
| 1931 |
+
running_avg_dict[f"{self_name}_normal_valid_views"] = (
|
| 1932 |
+
normal_valid_views
|
| 1933 |
+
)
|
| 1934 |
+
details[normal_avg_key] += (
|
| 1935 |
+
float(normal_losses[i]) - details[normal_avg_key]
|
| 1936 |
+
) / normal_valid_views
|
| 1937 |
+
|
| 1938 |
+
# Gradient Matching loss details
|
| 1939 |
+
if float(gradient_matching_loss) > 0:
|
| 1940 |
+
details[f"{self_name}_gradient_matching_view{i + 1}"] = float(
|
| 1941 |
+
gradient_matching_loss
|
| 1942 |
+
)
|
| 1943 |
+
# For gradient matching loss
|
| 1944 |
+
gm_avg_key = f"{self_name}_gradient_matching_avg"
|
| 1945 |
+
if gm_avg_key not in details:
|
| 1946 |
+
details[gm_avg_key] = float(gradient_matching_losses[i])
|
| 1947 |
+
running_avg_dict[f"{self_name}_gm_valid_views"] = 1
|
| 1948 |
+
else:
|
| 1949 |
+
gm_valid_views = running_avg_dict[f"{self_name}_gm_valid_views"] + 1
|
| 1950 |
+
running_avg_dict[f"{self_name}_gm_valid_views"] = gm_valid_views
|
| 1951 |
+
details[gm_avg_key] += (
|
| 1952 |
+
float(gradient_matching_losses[i]) - details[gm_avg_key]
|
| 1953 |
+
) / gm_valid_views
|
| 1954 |
+
|
| 1955 |
+
# Put the losses together
|
| 1956 |
+
loss_terms = []
|
| 1957 |
+
for i in range(n_views):
|
| 1958 |
+
loss_terms.append((normal_losses[i], None, "normal"))
|
| 1959 |
+
loss_terms.append((gradient_matching_losses[i], None, "gradient_matching"))
|
| 1960 |
+
losses = Sum(*loss_terms)
|
| 1961 |
+
|
| 1962 |
+
return losses, details
|
| 1963 |
+
|
| 1964 |
+
|
| 1965 |
class FactoredGeometryRegr3D(Criterion, MultiLoss):
|
| 1966 |
"""
|
| 1967 |
Regression Loss for Factored Geometry.
|
|
|
|
| 1983 |
pose_quats_loss_weight=1,
|
| 1984 |
pose_trans_loss_weight=1,
|
| 1985 |
compute_pairwise_relative_pose_loss=False,
|
| 1986 |
+
convert_predictions_to_view0_frame=False,
|
| 1987 |
compute_world_frame_points_loss=True,
|
| 1988 |
world_frame_points_loss_weight=1,
|
| 1989 |
):
|
|
|
|
| 2018 |
pose_trans_loss_weight (float): Weight to use for the pose trans loss. Default: 1.
|
| 2019 |
compute_pairwise_relative_pose_loss (bool): If True, the pose loss is computed on the
|
| 2020 |
exhaustive pairwise relative poses. Default: False.
|
| 2021 |
+
convert_predictions_to_view0_frame (bool): If True, convert predictions to view0 frame.
|
| 2022 |
+
Use this if the predictions are not already in the view0 frame. Default: False.
|
| 2023 |
compute_world_frame_points_loss (bool): If True, compute the world frame pointmap loss. Default: True.
|
| 2024 |
world_frame_points_loss_weight (float): Weight to use for the world frame pointmap loss. Default: 1.
|
| 2025 |
"""
|
|
|
|
| 2046 |
self.pose_quats_loss_weight = pose_quats_loss_weight
|
| 2047 |
self.pose_trans_loss_weight = pose_trans_loss_weight
|
| 2048 |
self.compute_pairwise_relative_pose_loss = compute_pairwise_relative_pose_loss
|
| 2049 |
+
self.convert_predictions_to_view0_frame = convert_predictions_to_view0_frame
|
| 2050 |
self.compute_world_frame_points_loss = compute_world_frame_points_loss
|
| 2051 |
self.world_frame_points_loss_weight = world_frame_points_loss_weight
|
| 2052 |
|
|
|
|
| 2069 |
gt_ray_directions = []
|
| 2070 |
gt_pose_quats = []
|
| 2071 |
# Predicted quantities
|
| 2072 |
+
if self.convert_predictions_to_view0_frame:
|
| 2073 |
+
# Get the camera transform to convert quantities to view0 frame
|
| 2074 |
+
pred_camera0 = torch.eye(4, device=preds[0]["cam_quats"].device).unsqueeze(
|
| 2075 |
+
0
|
| 2076 |
+
)
|
| 2077 |
+
batch_size = preds[0]["cam_quats"].shape[0]
|
| 2078 |
+
pred_camera0 = pred_camera0.repeat(batch_size, 1, 1)
|
| 2079 |
+
pred_camera0_rot = quaternion_to_rotation_matrix(
|
| 2080 |
+
preds[0]["cam_quats"].clone()
|
| 2081 |
+
)
|
| 2082 |
+
pred_camera0[..., :3, :3] = pred_camera0_rot
|
| 2083 |
+
pred_camera0[..., :3, 3] = preds[0]["cam_trans"].clone()
|
| 2084 |
+
pred_in_camera0 = closed_form_pose_inverse(pred_camera0)
|
| 2085 |
no_norm_pr_pts = []
|
| 2086 |
no_norm_pr_pts_cam = []
|
| 2087 |
no_norm_pr_depth = []
|
|
|
|
| 2135 |
gt_pose_quats.append(gt_pose_quats_in_view0)
|
| 2136 |
no_norm_gt_pose_trans.append(no_norm_gt_pose_trans_in_view0)
|
| 2137 |
|
| 2138 |
+
# Get the local predictions
|
|
|
|
| 2139 |
no_norm_pr_pts_cam.append(preds[i]["pts3d_cam"])
|
| 2140 |
pr_ray_directions.append(preds[i]["ray_directions"])
|
| 2141 |
if self.depth_type_for_loss == "depth_along_ray":
|
| 2142 |
no_norm_pr_depth.append(preds[i]["depth_along_ray"])
|
| 2143 |
elif self.depth_type_for_loss == "depth_z":
|
| 2144 |
no_norm_pr_depth.append(preds[i]["pts3d_cam"][..., 2:])
|
| 2145 |
+
|
| 2146 |
+
# Get the predicted global predictions in view0's frame
|
| 2147 |
+
if self.convert_predictions_to_view0_frame:
|
| 2148 |
+
# Convert predictions to view0 frame
|
| 2149 |
+
pr_pts3d_in_view0 = geotrf(pred_in_camera0, preds[i]["pts3d"])
|
| 2150 |
+
pr_pose_quats_in_view0, pr_pose_trans_in_view0 = (
|
| 2151 |
+
transform_pose_using_quats_and_trans_2_to_1(
|
| 2152 |
+
preds[0]["cam_quats"],
|
| 2153 |
+
preds[0]["cam_trans"],
|
| 2154 |
+
preds[i]["cam_quats"],
|
| 2155 |
+
preds[i]["cam_trans"],
|
| 2156 |
+
)
|
| 2157 |
+
)
|
| 2158 |
+
no_norm_pr_pts.append(pr_pts3d_in_view0)
|
| 2159 |
+
no_norm_pr_pose_trans.append(pr_pose_trans_in_view0)
|
| 2160 |
+
pr_pose_quats.append(pr_pose_quats_in_view0)
|
| 2161 |
+
else:
|
| 2162 |
+
# Predictions are already in view0 frame
|
| 2163 |
+
no_norm_pr_pts.append(preds[i]["pts3d"])
|
| 2164 |
+
no_norm_pr_pose_trans.append(preds[i]["cam_trans"])
|
| 2165 |
+
pr_pose_quats.append(preds[i]["cam_quats"])
|
| 2166 |
|
| 2167 |
if dist_clip is not None:
|
| 2168 |
# Points that are too far-away == invalid
|
|
|
|
| 2674 |
pose_quats_loss_weight=1,
|
| 2675 |
pose_trans_loss_weight=1,
|
| 2676 |
compute_pairwise_relative_pose_loss=False,
|
| 2677 |
+
convert_predictions_to_view0_frame=False,
|
| 2678 |
compute_world_frame_points_loss=True,
|
| 2679 |
world_frame_points_loss_weight=1,
|
| 2680 |
apply_normal_and_gm_loss_to_synthetic_data_only=True,
|
|
|
|
| 2710 |
pose_trans_loss_weight (float): Weight to use for the pose trans loss. Default: 1.
|
| 2711 |
compute_pairwise_relative_pose_loss (bool): If True, the pose loss is computed on the
|
| 2712 |
exhaustive pairwise relative poses. Default: False.
|
| 2713 |
+
convert_predictions_to_view0_frame (bool): If True, convert predictions to view0 frame.
|
| 2714 |
+
Use this if the predictions are not already in the view0 frame. Default: False.
|
| 2715 |
compute_world_frame_points_loss (bool): If True, compute the world frame pointmap loss. Default: True.
|
| 2716 |
world_frame_points_loss_weight (float): Weight to use for the world frame pointmap loss. Default: 1.
|
| 2717 |
apply_normal_and_gm_loss_to_synthetic_data_only (bool): If True, apply the normal and gm loss only to synthetic data.
|
|
|
|
| 2734 |
pose_quats_loss_weight=pose_quats_loss_weight,
|
| 2735 |
pose_trans_loss_weight=pose_trans_loss_weight,
|
| 2736 |
compute_pairwise_relative_pose_loss=compute_pairwise_relative_pose_loss,
|
| 2737 |
+
convert_predictions_to_view0_frame=convert_predictions_to_view0_frame,
|
| 2738 |
compute_world_frame_points_loss=compute_world_frame_points_loss,
|
| 2739 |
world_frame_points_loss_weight=world_frame_points_loss_weight,
|
| 2740 |
)
|
|
|
|
| 3130 |
pose_trans_loss_weight=1,
|
| 3131 |
scale_loss_weight=1,
|
| 3132 |
compute_pairwise_relative_pose_loss=False,
|
| 3133 |
+
convert_predictions_to_view0_frame=False,
|
| 3134 |
compute_world_frame_points_loss=True,
|
| 3135 |
world_frame_points_loss_weight=1,
|
| 3136 |
):
|
|
|
|
| 3164 |
scale_loss_weight (float): Weight to use for the scale loss. Default: 1.
|
| 3165 |
compute_pairwise_relative_pose_loss (bool): If True, the pose loss is computed on the
|
| 3166 |
exhaustive pairwise relative poses. Default: False.
|
| 3167 |
+
convert_predictions_to_view0_frame (bool): If True, convert predictions to view0 frame.
|
| 3168 |
+
Use this if the predictions are not already in the view0 frame. Default: False.
|
| 3169 |
compute_world_frame_points_loss (bool): If True, compute the world frame pointmap loss. Default: True.
|
| 3170 |
world_frame_points_loss_weight (float): Weight to use for the world frame pointmap loss. Default: 1.
|
| 3171 |
"""
|
|
|
|
| 3186 |
self.pose_trans_loss_weight = pose_trans_loss_weight
|
| 3187 |
self.scale_loss_weight = scale_loss_weight
|
| 3188 |
self.compute_pairwise_relative_pose_loss = compute_pairwise_relative_pose_loss
|
| 3189 |
+
self.convert_predictions_to_view0_frame = convert_predictions_to_view0_frame
|
| 3190 |
self.compute_world_frame_points_loss = compute_world_frame_points_loss
|
| 3191 |
self.world_frame_points_loss_weight = world_frame_points_loss_weight
|
| 3192 |
|
|
|
|
| 3209 |
gt_ray_directions = []
|
| 3210 |
gt_pose_quats = []
|
| 3211 |
# Predicted quantities
|
| 3212 |
+
if self.convert_predictions_to_view0_frame:
|
| 3213 |
+
# Get the camera transform to convert quantities to view0 frame
|
| 3214 |
+
pred_camera0 = torch.eye(4, device=preds[0]["cam_quats"].device).unsqueeze(
|
| 3215 |
+
0
|
| 3216 |
+
)
|
| 3217 |
+
batch_size = preds[0]["cam_quats"].shape[0]
|
| 3218 |
+
pred_camera0 = pred_camera0.repeat(batch_size, 1, 1)
|
| 3219 |
+
pred_camera0_rot = quaternion_to_rotation_matrix(
|
| 3220 |
+
preds[0]["cam_quats"].clone()
|
| 3221 |
+
)
|
| 3222 |
+
pred_camera0[..., :3, :3] = pred_camera0_rot
|
| 3223 |
+
pred_camera0[..., :3, 3] = preds[0]["cam_trans"].clone()
|
| 3224 |
+
pred_in_camera0 = closed_form_pose_inverse(pred_camera0)
|
| 3225 |
no_norm_pr_pts = []
|
| 3226 |
no_norm_pr_pts_cam = []
|
| 3227 |
no_norm_pr_depth = []
|
|
|
|
| 3276 |
gt_pose_quats.append(gt_pose_quats_in_view0)
|
| 3277 |
no_norm_gt_pose_trans.append(no_norm_gt_pose_trans_in_view0)
|
| 3278 |
|
| 3279 |
+
# Get the global predictions in view0's frame
|
| 3280 |
+
if self.convert_predictions_to_view0_frame:
|
| 3281 |
+
# Convert predictions to view0 frame
|
| 3282 |
+
pr_pts3d_in_view0 = geotrf(pred_in_camera0, preds[i]["pts3d"])
|
| 3283 |
+
pr_pose_quats_in_view0, pr_pose_trans_in_view0 = (
|
| 3284 |
+
transform_pose_using_quats_and_trans_2_to_1(
|
| 3285 |
+
preds[0]["cam_quats"],
|
| 3286 |
+
preds[0]["cam_trans"],
|
| 3287 |
+
preds[i]["cam_quats"],
|
| 3288 |
+
preds[i]["cam_trans"],
|
| 3289 |
+
)
|
| 3290 |
+
)
|
| 3291 |
+
else:
|
| 3292 |
+
# Predictions are already in view0 frame
|
| 3293 |
+
pr_pts3d_in_view0 = preds[i]["pts3d"]
|
| 3294 |
+
pr_pose_trans_in_view0 = preds[i]["cam_trans"]
|
| 3295 |
+
pr_pose_quats_in_view0 = preds[i]["cam_quats"]
|
| 3296 |
+
|
| 3297 |
# Get predictions for normalized loss
|
| 3298 |
if self.depth_type_for_loss == "depth_along_ray":
|
| 3299 |
curr_view_no_norm_depth = preds[i]["depth_along_ray"]
|
|
|
|
| 3302 |
if "metric_scaling_factor" in preds[i].keys():
|
| 3303 |
# Divide by the predicted metric scaling factor to get the raw predicted points, depth_along_ray, and pose_trans
|
| 3304 |
# This detaches the predicted metric scaling factor from the geometry based loss
|
| 3305 |
+
curr_view_no_norm_pr_pts = pr_pts3d_in_view0 / preds[i][
|
| 3306 |
"metric_scaling_factor"
|
| 3307 |
].unsqueeze(-1).unsqueeze(-1)
|
| 3308 |
curr_view_no_norm_pr_pts_cam = preds[i]["pts3d_cam"] / preds[i][
|
|
|
|
| 3312 |
"metric_scaling_factor"
|
| 3313 |
].unsqueeze(-1).unsqueeze(-1)
|
| 3314 |
curr_view_no_norm_pr_pose_trans = (
|
| 3315 |
+
pr_pose_trans_in_view0 / preds[i]["metric_scaling_factor"]
|
| 3316 |
)
|
| 3317 |
else:
|
| 3318 |
+
curr_view_no_norm_pr_pts = pr_pts3d_in_view0
|
| 3319 |
curr_view_no_norm_pr_pts_cam = preds[i]["pts3d_cam"]
|
| 3320 |
curr_view_no_norm_depth = curr_view_no_norm_depth
|
| 3321 |
+
curr_view_no_norm_pr_pose_trans = pr_pose_trans_in_view0
|
| 3322 |
no_norm_pr_pts.append(curr_view_no_norm_pr_pts)
|
| 3323 |
no_norm_pr_pts_cam.append(curr_view_no_norm_pr_pts_cam)
|
| 3324 |
no_norm_pr_depth.append(curr_view_no_norm_depth)
|
| 3325 |
no_norm_pr_pose_trans.append(curr_view_no_norm_pr_pose_trans)
|
| 3326 |
pr_ray_directions.append(preds[i]["ray_directions"])
|
| 3327 |
+
pr_pose_quats.append(pr_pose_quats_in_view0)
|
| 3328 |
|
| 3329 |
# Get the predicted metric scale points
|
| 3330 |
if "metric_scaling_factor" in preds[i].keys():
|
|
|
|
| 3823 |
pose_trans_loss_weight=1,
|
| 3824 |
scale_loss_weight=1,
|
| 3825 |
compute_pairwise_relative_pose_loss=False,
|
| 3826 |
+
convert_predictions_to_view0_frame=False,
|
| 3827 |
compute_world_frame_points_loss=True,
|
| 3828 |
world_frame_points_loss_weight=1,
|
| 3829 |
apply_normal_and_gm_loss_to_synthetic_data_only=True,
|
|
|
|
| 3856 |
scale_loss_weight (float): Weight to use for the scale loss. Default: 1.
|
| 3857 |
compute_pairwise_relative_pose_loss (bool): If True, the pose loss is computed on the
|
| 3858 |
exhaustive pairwise relative poses. Default: False.
|
| 3859 |
+
convert_predictions_to_view0_frame (bool): If True, convert predictions to view0 frame.
|
| 3860 |
+
Use this if the predictions are not already in the view0 frame. Default: False.
|
| 3861 |
compute_world_frame_points_loss (bool): If True, compute the world frame pointmap loss. Default: True.
|
| 3862 |
world_frame_points_loss_weight (float): Weight to use for the world frame pointmap loss. Default: 1.
|
| 3863 |
apply_normal_and_gm_loss_to_synthetic_data_only (bool): If True, apply the normal and gm loss only to synthetic data.
|
|
|
|
| 3880 |
pose_trans_loss_weight=pose_trans_loss_weight,
|
| 3881 |
scale_loss_weight=scale_loss_weight,
|
| 3882 |
compute_pairwise_relative_pose_loss=compute_pairwise_relative_pose_loss,
|
| 3883 |
+
convert_predictions_to_view0_frame=convert_predictions_to_view0_frame,
|
| 3884 |
compute_world_frame_points_loss=compute_world_frame_points_loss,
|
| 3885 |
world_frame_points_loss_weight=world_frame_points_loss_weight,
|
| 3886 |
)
|
mapanything/utils/geometry.py
CHANGED
|
@@ -10,6 +10,7 @@ from typing import Tuple, Union
|
|
| 10 |
import einops as ein
|
| 11 |
import numpy as np
|
| 12 |
import torch
|
|
|
|
| 13 |
|
| 14 |
from mapanything.utils.misc import invalid_to_zeros
|
| 15 |
from mapanything.utils.warnings import no_warnings
|
|
@@ -646,6 +647,96 @@ def quaternion_to_rotation_matrix(quat):
|
|
| 646 |
return rot_matrix
|
| 647 |
|
| 648 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 649 |
def quaternion_inverse(quat):
|
| 650 |
"""
|
| 651 |
Compute the inverse of a quaternion.
|
|
|
|
| 10 |
import einops as ein
|
| 11 |
import numpy as np
|
| 12 |
import torch
|
| 13 |
+
import torch.nn.functional as F
|
| 14 |
|
| 15 |
from mapanything.utils.misc import invalid_to_zeros
|
| 16 |
from mapanything.utils.warnings import no_warnings
|
|
|
|
| 647 |
return rot_matrix
|
| 648 |
|
| 649 |
|
| 650 |
+
def rotation_matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor:
|
| 651 |
+
"""
|
| 652 |
+
Convert rotations given as rotation matrices to quaternions.
|
| 653 |
+
|
| 654 |
+
Args:
|
| 655 |
+
matrix: Rotation matrices as tensor of shape (..., 3, 3).
|
| 656 |
+
|
| 657 |
+
Returns:
|
| 658 |
+
quaternions with real part last, as tensor of shape (..., 4).
|
| 659 |
+
Quaternion Order: XYZW or say ijkr, scalar-last
|
| 660 |
+
"""
|
| 661 |
+
if matrix.size(-1) != 3 or matrix.size(-2) != 3:
|
| 662 |
+
raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")
|
| 663 |
+
|
| 664 |
+
batch_dim = matrix.shape[:-2]
|
| 665 |
+
m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(
|
| 666 |
+
matrix.reshape(batch_dim + (9,)), dim=-1
|
| 667 |
+
)
|
| 668 |
+
|
| 669 |
+
q_abs = _sqrt_positive_part(
|
| 670 |
+
torch.stack(
|
| 671 |
+
[
|
| 672 |
+
1.0 + m00 + m11 + m22,
|
| 673 |
+
1.0 + m00 - m11 - m22,
|
| 674 |
+
1.0 - m00 + m11 - m22,
|
| 675 |
+
1.0 - m00 - m11 + m22,
|
| 676 |
+
],
|
| 677 |
+
dim=-1,
|
| 678 |
+
)
|
| 679 |
+
)
|
| 680 |
+
|
| 681 |
+
# we produce the desired quaternion multiplied by each of r, i, j, k
|
| 682 |
+
quat_by_rijk = torch.stack(
|
| 683 |
+
[
|
| 684 |
+
torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1),
|
| 685 |
+
torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1),
|
| 686 |
+
torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1),
|
| 687 |
+
torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1),
|
| 688 |
+
],
|
| 689 |
+
dim=-2,
|
| 690 |
+
)
|
| 691 |
+
|
| 692 |
+
# We floor here at 0.1 but the exact level is not important; if q_abs is small,
|
| 693 |
+
# the candidate won't be picked.
|
| 694 |
+
flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device)
|
| 695 |
+
quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr))
|
| 696 |
+
|
| 697 |
+
# if not for numerical problems, quat_candidates[i] should be same (up to a sign),
|
| 698 |
+
# forall i; we pick the best-conditioned one (with the largest denominator)
|
| 699 |
+
out = quat_candidates[
|
| 700 |
+
F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, :
|
| 701 |
+
].reshape(batch_dim + (4,))
|
| 702 |
+
|
| 703 |
+
# Convert from rijk to ijkr
|
| 704 |
+
out = out[..., [1, 2, 3, 0]]
|
| 705 |
+
|
| 706 |
+
out = standardize_quaternion(out)
|
| 707 |
+
|
| 708 |
+
return out
|
| 709 |
+
|
| 710 |
+
|
| 711 |
+
def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor:
|
| 712 |
+
"""
|
| 713 |
+
Returns torch.sqrt(torch.max(0, x))
|
| 714 |
+
but with a zero subgradient where x is 0.
|
| 715 |
+
"""
|
| 716 |
+
ret = torch.zeros_like(x)
|
| 717 |
+
positive_mask = x > 0
|
| 718 |
+
if torch.is_grad_enabled():
|
| 719 |
+
ret[positive_mask] = torch.sqrt(x[positive_mask])
|
| 720 |
+
else:
|
| 721 |
+
ret = torch.where(positive_mask, torch.sqrt(x), ret)
|
| 722 |
+
return ret
|
| 723 |
+
|
| 724 |
+
|
| 725 |
+
def standardize_quaternion(quaternions: torch.Tensor) -> torch.Tensor:
|
| 726 |
+
"""
|
| 727 |
+
Convert a unit quaternion to a standard form: one in which the real
|
| 728 |
+
part is non negative.
|
| 729 |
+
|
| 730 |
+
Args:
|
| 731 |
+
quaternions: Quaternions with real part last,
|
| 732 |
+
as tensor of shape (..., 4).
|
| 733 |
+
|
| 734 |
+
Returns:
|
| 735 |
+
Standardized quaternions as tensor of shape (..., 4).
|
| 736 |
+
"""
|
| 737 |
+
return torch.where(quaternions[..., 3:4] < 0, -quaternions, quaternions)
|
| 738 |
+
|
| 739 |
+
|
| 740 |
def quaternion_inverse(quat):
|
| 741 |
"""
|
| 742 |
Compute the inverse of a quaternion.
|
mapanything/utils/image.py
CHANGED
|
@@ -287,6 +287,17 @@ def load_images(
|
|
| 287 |
f"Using target resolution {target_size[0]}x{target_size[1]} (W x H) for all images"
|
| 288 |
)
|
| 289 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 290 |
# Second pass: Resize all images to the same target size
|
| 291 |
imgs = []
|
| 292 |
for path, img, W1, H1 in loaded_images:
|
|
@@ -298,16 +309,6 @@ def load_images(
|
|
| 298 |
if verbose:
|
| 299 |
print(f" - Adding {path} with resolution {W1}x{H1} --> {W2}x{H2}")
|
| 300 |
|
| 301 |
-
if norm_type in IMAGE_NORMALIZATION_DICT.keys():
|
| 302 |
-
img_norm = IMAGE_NORMALIZATION_DICT[norm_type]
|
| 303 |
-
ImgNorm = tvf.Compose(
|
| 304 |
-
[tvf.ToTensor(), tvf.Normalize(mean=img_norm.mean, std=img_norm.std)]
|
| 305 |
-
)
|
| 306 |
-
else:
|
| 307 |
-
raise ValueError(
|
| 308 |
-
f"Unknown image normalization type: {norm_type}. Available options: {list(IMAGE_NORMALIZATION_DICT.keys())}"
|
| 309 |
-
)
|
| 310 |
-
|
| 311 |
imgs.append(
|
| 312 |
dict(
|
| 313 |
img=ImgNorm(img)[None],
|
|
|
|
| 287 |
f"Using target resolution {target_size[0]}x{target_size[1]} (W x H) for all images"
|
| 288 |
)
|
| 289 |
|
| 290 |
+
# Get the image normalization function based on the norm_type
|
| 291 |
+
if norm_type in IMAGE_NORMALIZATION_DICT.keys():
|
| 292 |
+
img_norm = IMAGE_NORMALIZATION_DICT[norm_type]
|
| 293 |
+
ImgNorm = tvf.Compose(
|
| 294 |
+
[tvf.ToTensor(), tvf.Normalize(mean=img_norm.mean, std=img_norm.std)]
|
| 295 |
+
)
|
| 296 |
+
else:
|
| 297 |
+
raise ValueError(
|
| 298 |
+
f"Unknown image normalization type: {norm_type}. Available options: {list(IMAGE_NORMALIZATION_DICT.keys())}"
|
| 299 |
+
)
|
| 300 |
+
|
| 301 |
# Second pass: Resize all images to the same target size
|
| 302 |
imgs = []
|
| 303 |
for path, img, W1, H1 in loaded_images:
|
|
|
|
| 309 |
if verbose:
|
| 310 |
print(f" - Adding {path} with resolution {W1}x{H1} --> {W2}x{H2}")
|
| 311 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 312 |
imgs.append(
|
| 313 |
dict(
|
| 314 |
img=ImgNorm(img)[None],
|
mapanything/utils/inference.py
CHANGED
|
@@ -3,9 +3,43 @@ Inference utilities.
|
|
| 3 |
"""
|
| 4 |
|
| 5 |
import warnings
|
|
|
|
| 6 |
|
|
|
|
| 7 |
import torch
|
| 8 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
|
| 10 |
def loss_of_one_batch_multi_view(
|
| 11 |
batch,
|
|
@@ -84,3 +118,358 @@ def loss_of_one_batch_multi_view(
|
|
| 84 |
result["loss"] = loss
|
| 85 |
|
| 86 |
return result[ret] if ret else result
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
"""
|
| 4 |
|
| 5 |
import warnings
|
| 6 |
+
from typing import Any, Dict, List
|
| 7 |
|
| 8 |
+
import numpy as np
|
| 9 |
import torch
|
| 10 |
|
| 11 |
+
from mapanything.utils.geometry import (
|
| 12 |
+
depth_edge,
|
| 13 |
+
get_rays_in_camera_frame,
|
| 14 |
+
normals_edge,
|
| 15 |
+
points_to_normals,
|
| 16 |
+
quaternion_to_rotation_matrix,
|
| 17 |
+
recover_pinhole_intrinsics_from_ray_directions,
|
| 18 |
+
rotation_matrix_to_quaternion,
|
| 19 |
+
)
|
| 20 |
+
from mapanything.utils.image import rgb
|
| 21 |
+
|
| 22 |
+
# Hard constraints - exactly what users can provide
|
| 23 |
+
ALLOWED_VIEW_KEYS = {
|
| 24 |
+
"img", # Required - input images
|
| 25 |
+
"data_norm_type", # Required - normalization type of the input images
|
| 26 |
+
"depth_z", # Optional - Z depth maps
|
| 27 |
+
"ray_directions", # Optional - ray directions in camera frame
|
| 28 |
+
"intrinsics", # Optional - pinhole camera intrinsics (conflicts with ray_directions)
|
| 29 |
+
"camera_poses", # Optional - camera poses
|
| 30 |
+
"is_metric_scale", # Optional - whether inputs are metric scale
|
| 31 |
+
"true_shape", # Optional - original image shape
|
| 32 |
+
"idx", # Optional - index of the view
|
| 33 |
+
"instance", # Optional - instance info of the view
|
| 34 |
+
}
|
| 35 |
+
|
| 36 |
+
REQUIRED_KEYS = {"img", "data_norm_type"}
|
| 37 |
+
|
| 38 |
+
# Define conflicting keys that cannot be used together
|
| 39 |
+
CONFLICTING_KEYS = [
|
| 40 |
+
("intrinsics", "ray_directions") # Both represent camera projection
|
| 41 |
+
]
|
| 42 |
+
|
| 43 |
|
| 44 |
def loss_of_one_batch_multi_view(
|
| 45 |
batch,
|
|
|
|
| 118 |
result["loss"] = loss
|
| 119 |
|
| 120 |
return result[ret] if ret else result
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
def validate_input_views_for_inference(
|
| 124 |
+
views: List[Dict[str, Any]],
|
| 125 |
+
) -> List[Dict[str, Any]]:
|
| 126 |
+
"""
|
| 127 |
+
Strict validation and preprocessing of input views.
|
| 128 |
+
|
| 129 |
+
Args:
|
| 130 |
+
views: List of view dictionaries
|
| 131 |
+
|
| 132 |
+
Returns:
|
| 133 |
+
Validated and preprocessed views
|
| 134 |
+
|
| 135 |
+
Raises:
|
| 136 |
+
ValueError: For invalid keys, missing required keys, conflicting inputs, or invalid camera pose constraints
|
| 137 |
+
"""
|
| 138 |
+
# Ensure input is not empty
|
| 139 |
+
if not views:
|
| 140 |
+
raise ValueError("At least one view must be provided")
|
| 141 |
+
|
| 142 |
+
# Track which views have camera poses
|
| 143 |
+
views_with_poses = []
|
| 144 |
+
|
| 145 |
+
# Validate each view
|
| 146 |
+
for view_idx, view in enumerate(views):
|
| 147 |
+
# Check for invalid keys
|
| 148 |
+
provided_keys = set(view.keys())
|
| 149 |
+
invalid_keys = provided_keys - ALLOWED_VIEW_KEYS
|
| 150 |
+
if invalid_keys:
|
| 151 |
+
raise ValueError(
|
| 152 |
+
f"View {view_idx} contains invalid keys: {invalid_keys}. "
|
| 153 |
+
f"Allowed keys are: {sorted(ALLOWED_VIEW_KEYS)}"
|
| 154 |
+
)
|
| 155 |
+
|
| 156 |
+
# Check for missing required keys
|
| 157 |
+
missing_keys = REQUIRED_KEYS - provided_keys
|
| 158 |
+
if missing_keys:
|
| 159 |
+
raise ValueError(f"View {view_idx} missing required keys: {missing_keys}")
|
| 160 |
+
|
| 161 |
+
# Check for conflicting keys
|
| 162 |
+
for conflict_set in CONFLICTING_KEYS:
|
| 163 |
+
present_conflicts = [key for key in conflict_set if key in provided_keys]
|
| 164 |
+
if len(present_conflicts) > 1:
|
| 165 |
+
raise ValueError(
|
| 166 |
+
f"View {view_idx} contains conflicting keys: {present_conflicts}. "
|
| 167 |
+
f"Only one of {conflict_set} can be provided at a time."
|
| 168 |
+
)
|
| 169 |
+
|
| 170 |
+
# Check depth constraint: If depth is provided, intrinsics or ray_directions must also be provided
|
| 171 |
+
if "depth_z" in provided_keys:
|
| 172 |
+
if (
|
| 173 |
+
"intrinsics" not in provided_keys
|
| 174 |
+
and "ray_directions" not in provided_keys
|
| 175 |
+
):
|
| 176 |
+
raise ValueError(
|
| 177 |
+
f"View {view_idx} depth constraint violation: If 'depth_z' is provided, "
|
| 178 |
+
f"then 'intrinsics' or 'ray_directions' must also be provided. "
|
| 179 |
+
f"Z Depth values require camera calibration information to be meaningful for an image."
|
| 180 |
+
)
|
| 181 |
+
|
| 182 |
+
# Track views with camera poses
|
| 183 |
+
if "camera_poses" in provided_keys:
|
| 184 |
+
views_with_poses.append(view_idx)
|
| 185 |
+
|
| 186 |
+
# Cross-view constraint: If any view has camera_poses, view 0 must have them too
|
| 187 |
+
if views_with_poses and 0 not in views_with_poses:
|
| 188 |
+
raise ValueError(
|
| 189 |
+
f"Camera pose constraint violation: Views {views_with_poses} have camera_poses, "
|
| 190 |
+
f"but view 0 (reference view) does not. When using camera_poses, the first view "
|
| 191 |
+
f"must also provide camera_poses to serve as the reference frame."
|
| 192 |
+
)
|
| 193 |
+
|
| 194 |
+
return views
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
def preprocess_input_views_for_inference(
|
| 198 |
+
views: List[Dict[str, Any]],
|
| 199 |
+
) -> List[Dict[str, Any]]:
|
| 200 |
+
"""
|
| 201 |
+
Pre-process input views to match the expected internal input format.
|
| 202 |
+
|
| 203 |
+
The following steps are performed:
|
| 204 |
+
1. Convert intrinsics to ray directions when required. If ray directions are already provided, unit normalize them.
|
| 205 |
+
2. Convert depth_z to depth_along_ray
|
| 206 |
+
3. Convert camera_poses to the expected input keys (camera_pose_quats and camera_pose_trans)
|
| 207 |
+
4. Default is_metric_scale to True when not provided
|
| 208 |
+
|
| 209 |
+
Args:
|
| 210 |
+
views: List of view dictionaries
|
| 211 |
+
|
| 212 |
+
Returns:
|
| 213 |
+
Preprocessed views with consistent internal format
|
| 214 |
+
"""
|
| 215 |
+
processed_views = []
|
| 216 |
+
|
| 217 |
+
for view_idx, view in enumerate(views):
|
| 218 |
+
# Copy the view dictionary to avoid modifying the original input
|
| 219 |
+
processed_view = dict(view)
|
| 220 |
+
|
| 221 |
+
# Step 1: Convert intrinsics to ray_directions when required. If ray_directions are provided, unit normalize them.
|
| 222 |
+
if "intrinsics" in view:
|
| 223 |
+
images = view["img"]
|
| 224 |
+
height, width = images.shape[-2:]
|
| 225 |
+
intrinsics = view["intrinsics"]
|
| 226 |
+
_, ray_directions = get_rays_in_camera_frame(
|
| 227 |
+
intrinsics=intrinsics,
|
| 228 |
+
height=height,
|
| 229 |
+
width=width,
|
| 230 |
+
normalize_to_unit_sphere=True,
|
| 231 |
+
)
|
| 232 |
+
processed_view["ray_directions"] = ray_directions
|
| 233 |
+
del processed_view["intrinsics"]
|
| 234 |
+
elif "ray_directions" in view:
|
| 235 |
+
ray_directions = view["ray_directions"]
|
| 236 |
+
ray_norm = torch.norm(ray_directions, dim=-1, keepdim=True)
|
| 237 |
+
processed_view["ray_directions"] = ray_directions / (ray_norm + 1e-8)
|
| 238 |
+
|
| 239 |
+
# Step 2: Convert depth_z to depth_along_ray
|
| 240 |
+
if "depth_z" in view:
|
| 241 |
+
depth_z = view["depth_z"]
|
| 242 |
+
ray_directions = processed_view["ray_directions"]
|
| 243 |
+
ray_directions_unit_plane = ray_directions / ray_directions[..., 2:3]
|
| 244 |
+
pts3d_cam = depth_z * ray_directions_unit_plane
|
| 245 |
+
depth_along_ray = torch.norm(pts3d_cam, dim=-1, keepdim=True)
|
| 246 |
+
processed_view["depth_along_ray"] = depth_along_ray
|
| 247 |
+
del processed_view["depth_z"]
|
| 248 |
+
|
| 249 |
+
# Step 3: Convert camera_poses to expected input keys
|
| 250 |
+
if "camera_poses" in view:
|
| 251 |
+
camera_poses = view["camera_poses"]
|
| 252 |
+
if isinstance(camera_poses, tuple) and len(camera_poses) == 2:
|
| 253 |
+
quats, trans = camera_poses
|
| 254 |
+
processed_view["camera_pose_quats"] = quats
|
| 255 |
+
processed_view["camera_pose_trans"] = trans
|
| 256 |
+
elif torch.is_tensor(camera_poses) and camera_poses.shape[-2:] == (4, 4):
|
| 257 |
+
rotation_matrices = camera_poses[:, :3, :3]
|
| 258 |
+
translation_vectors = camera_poses[:, :3, 3]
|
| 259 |
+
quats = rotation_matrix_to_quaternion(rotation_matrices)
|
| 260 |
+
processed_view["camera_pose_quats"] = quats
|
| 261 |
+
processed_view["camera_pose_trans"] = translation_vectors
|
| 262 |
+
else:
|
| 263 |
+
raise ValueError(
|
| 264 |
+
f"View {view_idx}: camera_poses must be either a tuple of (quats, trans) "
|
| 265 |
+
f"or a tensor of (B, 4, 4) transformation matrices."
|
| 266 |
+
)
|
| 267 |
+
del processed_view["camera_poses"]
|
| 268 |
+
|
| 269 |
+
# Step 4: Default is_metric_scale to True when not provided
|
| 270 |
+
if "is_metric_scale" not in processed_view:
|
| 271 |
+
# Get batch size from the image tensor
|
| 272 |
+
batch_size = view["img"].shape[0]
|
| 273 |
+
# Default to True for all samples in the batch
|
| 274 |
+
processed_view["is_metric_scale"] = torch.ones(
|
| 275 |
+
batch_size, dtype=torch.bool, device=view["img"].device
|
| 276 |
+
)
|
| 277 |
+
|
| 278 |
+
# Rename keys to match expected model input format
|
| 279 |
+
if "ray_directions" in processed_view:
|
| 280 |
+
processed_view["ray_directions_cam"] = processed_view["ray_directions"]
|
| 281 |
+
del processed_view["ray_directions"]
|
| 282 |
+
|
| 283 |
+
# Append the processed view to the list
|
| 284 |
+
processed_views.append(processed_view)
|
| 285 |
+
|
| 286 |
+
return processed_views
|
| 287 |
+
|
| 288 |
+
|
| 289 |
+
def postprocess_model_outputs_for_inference(
|
| 290 |
+
raw_outputs: List[Dict[str, torch.Tensor]],
|
| 291 |
+
input_views: List[Dict[str, Any]],
|
| 292 |
+
apply_mask: bool = True,
|
| 293 |
+
mask_edges: bool = True,
|
| 294 |
+
edge_normal_threshold: float = 5.0,
|
| 295 |
+
edge_depth_threshold: float = 0.03,
|
| 296 |
+
apply_confidence_mask: bool = False,
|
| 297 |
+
confidence_percentile: float = 10,
|
| 298 |
+
) -> List[Dict[str, torch.Tensor]]:
|
| 299 |
+
"""
|
| 300 |
+
Post-process raw model outputs by copying raw outputs and adding essential derived fields.
|
| 301 |
+
|
| 302 |
+
This function simplifies the raw model outputs by:
|
| 303 |
+
1. Copying all raw outputs as-is
|
| 304 |
+
2. Adding denormalized images (img_no_norm)
|
| 305 |
+
3. Adding Z depth (depth_z) from camera frame points
|
| 306 |
+
4. Recovering pinhole camera intrinsics from ray directions
|
| 307 |
+
5. Adding camera pose matrices (camera_poses) if pose data is available
|
| 308 |
+
6. Applying mask to dense geometry outputs if requested (supports edge masking and confidence masking)
|
| 309 |
+
|
| 310 |
+
Args:
|
| 311 |
+
raw_outputs: List of raw model output dictionaries, one per view
|
| 312 |
+
input_views: List of original input view dictionaries, one per view
|
| 313 |
+
apply_mask: Whether to apply non-ambiguous mask to dense outputs. Defaults to True.
|
| 314 |
+
mask_edges: Whether to compute an edge mask based on normals and depth and apply it to the output. Defaults to True.
|
| 315 |
+
apply_confidence_mask: Whether to apply the confidence mask to the output. Defaults to False.
|
| 316 |
+
confidence_percentile: The percentile to use for the confidence threshold. Defaults to 10.
|
| 317 |
+
|
| 318 |
+
Returns:
|
| 319 |
+
List of processed output dictionaries containing:
|
| 320 |
+
- All original raw outputs (after masking dense geometry outputs if requested)
|
| 321 |
+
- 'img_no_norm': Denormalized RGB images (B, H, W, 3)
|
| 322 |
+
- 'depth_z': Z depth from camera frame (B, H, W, 1) if points in camera frame available
|
| 323 |
+
- 'intrinsics': Recovered pinhole camera intrinsics (B, 3, 3) if ray directions available
|
| 324 |
+
- 'camera_poses': 4x4 pose matrices (B, 4, 4) if pose data available
|
| 325 |
+
- 'mask': comprehensive mask for dense geometry outputs (B, H, W, 1) if requested
|
| 326 |
+
|
| 327 |
+
"""
|
| 328 |
+
processed_outputs = []
|
| 329 |
+
|
| 330 |
+
for view_idx, (raw_output, original_view) in enumerate(
|
| 331 |
+
zip(raw_outputs, input_views)
|
| 332 |
+
):
|
| 333 |
+
# Start by copying all raw outputs
|
| 334 |
+
processed_output = dict(raw_output)
|
| 335 |
+
|
| 336 |
+
# 1. Add denormalized images
|
| 337 |
+
img = original_view["img"] # Shape: (B, 3, H, W)
|
| 338 |
+
data_norm_type = original_view["data_norm_type"][0]
|
| 339 |
+
img_hwc = rgb(img, data_norm_type)
|
| 340 |
+
|
| 341 |
+
# Convert numpy back to torch if needed (rgb returns numpy)
|
| 342 |
+
if isinstance(img_hwc, np.ndarray):
|
| 343 |
+
img_hwc = torch.from_numpy(img_hwc).to(img.device)
|
| 344 |
+
|
| 345 |
+
processed_output["img_no_norm"] = img_hwc
|
| 346 |
+
|
| 347 |
+
# 2. Add Z depth if we have camera frame points
|
| 348 |
+
if "pts3d_cam" in processed_output:
|
| 349 |
+
processed_output["depth_z"] = processed_output["pts3d_cam"][..., 2:3]
|
| 350 |
+
|
| 351 |
+
# 3. Recover pinhole camera intrinsics from ray directions if available
|
| 352 |
+
if "ray_directions" in processed_output:
|
| 353 |
+
intrinsics = recover_pinhole_intrinsics_from_ray_directions(
|
| 354 |
+
processed_output["ray_directions"]
|
| 355 |
+
)
|
| 356 |
+
processed_output["intrinsics"] = intrinsics
|
| 357 |
+
|
| 358 |
+
# 4. Add camera pose matrices if both translation and quaternions are available
|
| 359 |
+
if "cam_trans" in processed_output and "cam_quats" in processed_output:
|
| 360 |
+
cam_trans = processed_output["cam_trans"] # (B, 3)
|
| 361 |
+
cam_quats = processed_output["cam_quats"] # (B, 4)
|
| 362 |
+
batch_size = cam_trans.shape[0]
|
| 363 |
+
|
| 364 |
+
# Convert quaternions to rotation matrices
|
| 365 |
+
rotation_matrices = quaternion_to_rotation_matrix(cam_quats) # (B, 3, 3)
|
| 366 |
+
|
| 367 |
+
# Create 4x4 pose matrices
|
| 368 |
+
pose_matrices = (
|
| 369 |
+
torch.eye(4, device=img.device).unsqueeze(0).repeat(batch_size, 1, 1)
|
| 370 |
+
)
|
| 371 |
+
pose_matrices[:, :3, :3] = rotation_matrices
|
| 372 |
+
pose_matrices[:, :3, 3] = cam_trans
|
| 373 |
+
|
| 374 |
+
processed_output["camera_poses"] = pose_matrices # (B, 4, 4)
|
| 375 |
+
|
| 376 |
+
# 5. Apply comprehensive mask to dense geometry outputs if requested
|
| 377 |
+
if apply_mask:
|
| 378 |
+
final_mask = None
|
| 379 |
+
|
| 380 |
+
# Start with non-ambiguous mask if available
|
| 381 |
+
if "non_ambiguous_mask" in processed_output:
|
| 382 |
+
non_ambiguous_mask = (
|
| 383 |
+
processed_output["non_ambiguous_mask"].cpu().numpy()
|
| 384 |
+
) # (B, H, W)
|
| 385 |
+
final_mask = non_ambiguous_mask
|
| 386 |
+
|
| 387 |
+
# Apply confidence mask if requested and available
|
| 388 |
+
if apply_confidence_mask and "conf" in processed_output:
|
| 389 |
+
confidences = processed_output["conf"].cpu() # (B, H, W)
|
| 390 |
+
# Compute percentile threshold for each batch element
|
| 391 |
+
batch_size = confidences.shape[0]
|
| 392 |
+
conf_mask = torch.zeros_like(confidences, dtype=torch.bool)
|
| 393 |
+
percentile_threshold = (
|
| 394 |
+
torch.quantile(
|
| 395 |
+
confidences.reshape(batch_size, -1),
|
| 396 |
+
confidence_percentile / 100.0,
|
| 397 |
+
dim=1,
|
| 398 |
+
)
|
| 399 |
+
.unsqueeze(-1)
|
| 400 |
+
.unsqueeze(-1)
|
| 401 |
+
) # Shape: (B, 1, 1)
|
| 402 |
+
|
| 403 |
+
# Compute mask for each batch element
|
| 404 |
+
conf_mask = confidences > percentile_threshold
|
| 405 |
+
conf_mask = conf_mask.numpy()
|
| 406 |
+
|
| 407 |
+
if final_mask is not None:
|
| 408 |
+
final_mask = final_mask & conf_mask
|
| 409 |
+
else:
|
| 410 |
+
final_mask = conf_mask
|
| 411 |
+
|
| 412 |
+
# Apply edge mask if requested and we have the required data
|
| 413 |
+
if mask_edges and final_mask is not None and "pts3d" in processed_output:
|
| 414 |
+
# Get 3D points for edge computation
|
| 415 |
+
pred_pts3d = processed_output["pts3d"].cpu().numpy() # (B, H, W, 3)
|
| 416 |
+
batch_size, height, width = final_mask.shape
|
| 417 |
+
|
| 418 |
+
edge_masks = []
|
| 419 |
+
for b in range(batch_size):
|
| 420 |
+
batch_final_mask = final_mask[b] # (H, W)
|
| 421 |
+
batch_pts3d = pred_pts3d[b] # (H, W, 3)
|
| 422 |
+
|
| 423 |
+
if batch_final_mask.any(): # Only compute if we have valid points
|
| 424 |
+
# Compute normals and normal-based edge mask
|
| 425 |
+
normals, normals_mask = points_to_normals(
|
| 426 |
+
batch_pts3d, mask=batch_final_mask
|
| 427 |
+
)
|
| 428 |
+
normal_edges = normals_edge(
|
| 429 |
+
normals, tol=edge_normal_threshold, mask=normals_mask
|
| 430 |
+
)
|
| 431 |
+
|
| 432 |
+
# Compute depth-based edge mask
|
| 433 |
+
depth_z = (
|
| 434 |
+
processed_output["depth_z"][b].squeeze(-1).cpu().numpy()
|
| 435 |
+
)
|
| 436 |
+
depth_edges = depth_edge(
|
| 437 |
+
depth_z, rtol=edge_depth_threshold, mask=batch_final_mask
|
| 438 |
+
)
|
| 439 |
+
|
| 440 |
+
# Combine both edge types
|
| 441 |
+
edge_mask = ~(depth_edges & normal_edges)
|
| 442 |
+
edge_masks.append(edge_mask)
|
| 443 |
+
else:
|
| 444 |
+
# No valid points, keep all as invalid
|
| 445 |
+
edge_masks.append(np.zeros_like(batch_final_mask, dtype=bool))
|
| 446 |
+
|
| 447 |
+
# Stack batch edge masks and combine with final mask
|
| 448 |
+
edge_mask = np.stack(edge_masks, axis=0) # (B, H, W)
|
| 449 |
+
final_mask = final_mask & edge_mask
|
| 450 |
+
|
| 451 |
+
# Apply final mask to dense geometry outputs if we have a mask
|
| 452 |
+
if final_mask is not None:
|
| 453 |
+
# Convert mask to torch tensor
|
| 454 |
+
final_mask_torch = torch.from_numpy(final_mask).to(
|
| 455 |
+
processed_output["pts3d"].device
|
| 456 |
+
)
|
| 457 |
+
final_mask_torch = final_mask_torch.unsqueeze(-1) # (B, H, W, 1)
|
| 458 |
+
|
| 459 |
+
# Apply mask to dense geometry outputs (zero out invalid regions)
|
| 460 |
+
dense_geometry_keys = [
|
| 461 |
+
"pts3d",
|
| 462 |
+
"pts3d_cam",
|
| 463 |
+
"depth_along_ray",
|
| 464 |
+
"depth_z",
|
| 465 |
+
]
|
| 466 |
+
for key in dense_geometry_keys:
|
| 467 |
+
if key in processed_output:
|
| 468 |
+
processed_output[key] = processed_output[key] * final_mask_torch
|
| 469 |
+
|
| 470 |
+
# Add mask to processed output
|
| 471 |
+
processed_output["mask"] = final_mask_torch
|
| 472 |
+
|
| 473 |
+
processed_outputs.append(processed_output)
|
| 474 |
+
|
| 475 |
+
return processed_outputs
|
mapanything/utils/viz.py
CHANGED
|
@@ -110,7 +110,7 @@ def script_add_rerun_args(parser: ArgumentParser) -> None:
|
|
| 110 |
parser.add_argument(
|
| 111 |
"--url",
|
| 112 |
type=str,
|
| 113 |
-
default="rerun+http://127.0.0.1:
|
| 114 |
help="Connect to this HTTP(S) URL",
|
| 115 |
)
|
| 116 |
parser.add_argument(
|
|
@@ -129,7 +129,7 @@ def init_rerun_args(
|
|
| 129 |
headless=True,
|
| 130 |
connect=True,
|
| 131 |
serve=False,
|
| 132 |
-
url="rerun+http://127.0.0.1:
|
| 133 |
save=None,
|
| 134 |
stdout=False,
|
| 135 |
) -> Namespace:
|
|
|
|
| 110 |
parser.add_argument(
|
| 111 |
"--url",
|
| 112 |
type=str,
|
| 113 |
+
default="rerun+http://127.0.0.1:2004/proxy",
|
| 114 |
help="Connect to this HTTP(S) URL",
|
| 115 |
)
|
| 116 |
parser.add_argument(
|
|
|
|
| 129 |
headless=True,
|
| 130 |
connect=True,
|
| 131 |
serve=False,
|
| 132 |
+
url="rerun+http://127.0.0.1:2004/proxy",
|
| 133 |
save=None,
|
| 134 |
stdout=False,
|
| 135 |
) -> Namespace:
|
mapanything/utils/wai/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
This utils module contains PORTAGE of wai-core scripts/methods for MapAnything.
|
| 3 |
+
"""
|
mapanything/utils/wai/basic_dataset.py
ADDED
|
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pathlib import Path
|
| 2 |
+
from typing import Any
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
from box import Box
|
| 6 |
+
|
| 7 |
+
from mapanything.utils.wai.core import get_frame_index, load_data, load_frame
|
| 8 |
+
from mapanything.utils.wai.ops import stack
|
| 9 |
+
from mapanything.utils.wai.scene_frame import get_scene_frame_names
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class BasicSceneframeDataset(torch.utils.data.Dataset):
|
| 13 |
+
"""Basic wai dataset to iterative over frames of scenes"""
|
| 14 |
+
|
| 15 |
+
@staticmethod
|
| 16 |
+
def collate_fn(batch: list[dict[str, Any]]) -> dict[str, Any]:
|
| 17 |
+
return stack(batch)
|
| 18 |
+
|
| 19 |
+
def __init__(
|
| 20 |
+
self,
|
| 21 |
+
cfg: Box,
|
| 22 |
+
):
|
| 23 |
+
"""
|
| 24 |
+
Initialize the BasicSceneframeDataset.
|
| 25 |
+
|
| 26 |
+
Args:
|
| 27 |
+
cfg (Box): Configuration object containing dataset parameters including:
|
| 28 |
+
- root: Root directory containing scene data
|
| 29 |
+
- frame_modalities: List of modalities to load for each frame
|
| 30 |
+
- key_remap: Optional dictionary mapping original keys to new keys
|
| 31 |
+
"""
|
| 32 |
+
super().__init__()
|
| 33 |
+
self.cfg = cfg
|
| 34 |
+
self.root = cfg.root
|
| 35 |
+
keyframes = cfg.get("use_keyframes", True)
|
| 36 |
+
self.scene_frame_names = get_scene_frame_names(cfg, keyframes=keyframes)
|
| 37 |
+
self.scene_frame_list = [
|
| 38 |
+
(scene_name, frame_name)
|
| 39 |
+
for scene_name, frame_names in self.scene_frame_names.items()
|
| 40 |
+
for frame_name in frame_names
|
| 41 |
+
]
|
| 42 |
+
self._scene_cache = {}
|
| 43 |
+
|
| 44 |
+
def __len__(self):
|
| 45 |
+
"""
|
| 46 |
+
Get the total number of scene-frame pairs in the dataset.
|
| 47 |
+
|
| 48 |
+
Returns:
|
| 49 |
+
int: The number of scene-frame pairs.
|
| 50 |
+
"""
|
| 51 |
+
return len(self.scene_frame_list)
|
| 52 |
+
|
| 53 |
+
def _load_scene(self, scene_name: str) -> dict[str, Any]:
|
| 54 |
+
"""
|
| 55 |
+
Load scene data for a given scene name.
|
| 56 |
+
|
| 57 |
+
Args:
|
| 58 |
+
scene_name (str): The name of the scene to load.
|
| 59 |
+
|
| 60 |
+
Returns:
|
| 61 |
+
dict: A dictionary containing scene data, including scene metadata.
|
| 62 |
+
"""
|
| 63 |
+
# load scene data
|
| 64 |
+
scene_data = {}
|
| 65 |
+
scene_data["meta"] = load_data(
|
| 66 |
+
Path(
|
| 67 |
+
self.root,
|
| 68 |
+
scene_name,
|
| 69 |
+
self.cfg.get("scene_meta_path", "scene_meta.json"),
|
| 70 |
+
),
|
| 71 |
+
"scene_meta",
|
| 72 |
+
)
|
| 73 |
+
|
| 74 |
+
return scene_data
|
| 75 |
+
|
| 76 |
+
def _load_scene_frame(
|
| 77 |
+
self, scene_name: str, frame_name: str | float
|
| 78 |
+
) -> dict[str, Any]:
|
| 79 |
+
"""
|
| 80 |
+
Load data for a specific frame from a specific scene.
|
| 81 |
+
|
| 82 |
+
This method loads scene data if not already cached, then loads the specified frame
|
| 83 |
+
from that scene with the modalities specified in the configuration.
|
| 84 |
+
|
| 85 |
+
Args:
|
| 86 |
+
scene_name (str): The name of the scene containing the frame.
|
| 87 |
+
frame_name (str or float): The name/timestamp of the frame to load.
|
| 88 |
+
|
| 89 |
+
Returns:
|
| 90 |
+
dict: A dictionary containing the loaded frame data with requested modalities.
|
| 91 |
+
"""
|
| 92 |
+
scene_frame_data = {}
|
| 93 |
+
if not (scene_data := self._scene_cache.get(scene_name)):
|
| 94 |
+
scene_data = self._load_scene(scene_name)
|
| 95 |
+
# for now only cache the last scene
|
| 96 |
+
self._scene_cache = {}
|
| 97 |
+
self._scene_cache[scene_name] = scene_data
|
| 98 |
+
|
| 99 |
+
frame_idx = get_frame_index(scene_data["meta"], frame_name)
|
| 100 |
+
|
| 101 |
+
scene_frame_data["scene_name"] = scene_name
|
| 102 |
+
scene_frame_data["frame_name"] = frame_name
|
| 103 |
+
scene_frame_data["scene_path"] = str(Path(self.root, scene_name))
|
| 104 |
+
scene_frame_data["frame_idx"] = frame_idx
|
| 105 |
+
scene_frame_data.update(
|
| 106 |
+
load_frame(
|
| 107 |
+
Path(self.root, scene_name),
|
| 108 |
+
frame_name,
|
| 109 |
+
modalities=self.cfg.frame_modalities,
|
| 110 |
+
scene_meta=scene_data["meta"],
|
| 111 |
+
)
|
| 112 |
+
)
|
| 113 |
+
# Remap key names
|
| 114 |
+
for key, new_key in self.cfg.get("key_remap", {}).items():
|
| 115 |
+
if key in scene_frame_data:
|
| 116 |
+
scene_frame_data[new_key] = scene_frame_data.pop(key)
|
| 117 |
+
|
| 118 |
+
return scene_frame_data
|
| 119 |
+
|
| 120 |
+
def __getitem__(self, index: int) -> dict[str, Any]:
|
| 121 |
+
"""
|
| 122 |
+
Get a specific scene-frame pair by index.
|
| 123 |
+
|
| 124 |
+
Args:
|
| 125 |
+
index (int): The index of the scene-frame pair to retrieve.
|
| 126 |
+
|
| 127 |
+
Returns:
|
| 128 |
+
dict: A dictionary containing the loaded frame data with requested modalities.
|
| 129 |
+
"""
|
| 130 |
+
scene_frame = self._load_scene_frame(*self.scene_frame_list[index])
|
| 131 |
+
return scene_frame
|
mapanything/utils/wai/camera.py
ADDED
|
@@ -0,0 +1,263 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
This utils script contains PORTAGE of wai-core camera methods for MapAnything.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from typing import Any
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
import torch
|
| 9 |
+
from scipy.spatial.transform import Rotation, Slerp
|
| 10 |
+
|
| 11 |
+
from mapanything.utils.wai.ops import get_dtype_device
|
| 12 |
+
|
| 13 |
+
# constants regarding camera models
|
| 14 |
+
PINHOLE_CAM_KEYS = ["fl_x", "fl_y", "cx", "cy", "h", "w"]
|
| 15 |
+
DISTORTION_PARAM_KEYS = [
|
| 16 |
+
"k1",
|
| 17 |
+
"k2",
|
| 18 |
+
"k3",
|
| 19 |
+
"k4",
|
| 20 |
+
"p1",
|
| 21 |
+
"p2",
|
| 22 |
+
] # order corresponds to the OpenCV convention
|
| 23 |
+
CAMERA_KEYS = PINHOLE_CAM_KEYS + DISTORTION_PARAM_KEYS
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def interpolate_intrinsics(
|
| 27 |
+
frame1: dict[str, Any],
|
| 28 |
+
frame2: dict[str, Any],
|
| 29 |
+
alpha: float,
|
| 30 |
+
) -> dict[str, Any]:
|
| 31 |
+
"""
|
| 32 |
+
Interpolate camera intrinsics linearly.
|
| 33 |
+
Args:
|
| 34 |
+
frame1: The first frame dictionary.
|
| 35 |
+
frame2: The second frame dictionary.
|
| 36 |
+
alpha: Interpolation parameter. alpha = 0 for frame1, alpha = 1 for frame2.
|
| 37 |
+
Returns:
|
| 38 |
+
frame_inter: dictionary with new intrinsics.
|
| 39 |
+
"""
|
| 40 |
+
frame_inter = {}
|
| 41 |
+
for key in CAMERA_KEYS:
|
| 42 |
+
if key in frame1 and key in frame2:
|
| 43 |
+
p1 = frame1[key]
|
| 44 |
+
p2 = frame2[key]
|
| 45 |
+
frame_inter[key] = (1 - alpha) * p1 + alpha * p2
|
| 46 |
+
return frame_inter
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def interpolate_extrinsics(
|
| 50 |
+
matrix1: list | np.ndarray | torch.Tensor,
|
| 51 |
+
matrix2: list | np.ndarray | torch.Tensor,
|
| 52 |
+
alpha: float,
|
| 53 |
+
) -> list | np.ndarray | torch.Tensor:
|
| 54 |
+
"""
|
| 55 |
+
Interpolate camera extrinsics 4x4 matrices using SLERP.
|
| 56 |
+
Args:
|
| 57 |
+
matrix1: The first matrix.
|
| 58 |
+
matrix2: The second matrix.
|
| 59 |
+
alpha: Interpolation parameter. alpha = 0 for matrix1, alpha = 1 for matrix2.
|
| 60 |
+
Returns:
|
| 61 |
+
matrix: 4x4 interpolated matrix, same type.
|
| 62 |
+
Raises:
|
| 63 |
+
ValueError: If different type.
|
| 64 |
+
"""
|
| 65 |
+
if not isinstance(matrix1, type(matrix2)):
|
| 66 |
+
raise ValueError("Both matrices should have the same type.")
|
| 67 |
+
|
| 68 |
+
dtype, device = get_dtype_device(matrix1)
|
| 69 |
+
if isinstance(matrix1, list):
|
| 70 |
+
mtype = "list"
|
| 71 |
+
matrix1 = np.array(matrix1)
|
| 72 |
+
matrix2 = np.array(matrix2)
|
| 73 |
+
elif isinstance(matrix1, np.ndarray):
|
| 74 |
+
mtype = "numpy"
|
| 75 |
+
elif isinstance(matrix1, torch.Tensor):
|
| 76 |
+
mtype = "torch"
|
| 77 |
+
matrix1 = matrix1.numpy()
|
| 78 |
+
matrix2 = matrix2.numpy()
|
| 79 |
+
else:
|
| 80 |
+
raise ValueError(
|
| 81 |
+
"Only list, numpy array and torch tensors are supported as inputs."
|
| 82 |
+
)
|
| 83 |
+
|
| 84 |
+
R1 = matrix1[:3, :3]
|
| 85 |
+
t1 = matrix1[:3, 3]
|
| 86 |
+
R2 = matrix2[:3, :3]
|
| 87 |
+
t2 = matrix2[:3, 3]
|
| 88 |
+
|
| 89 |
+
# interpolate translation
|
| 90 |
+
t = (1 - alpha) * t1 + alpha * t2
|
| 91 |
+
|
| 92 |
+
# interpolate rotations with SLERP
|
| 93 |
+
R1_quat = Rotation.from_matrix(R1).as_quat()
|
| 94 |
+
R2_quat = Rotation.from_matrix(R2).as_quat()
|
| 95 |
+
rotation_slerp = Slerp([0, 1], Rotation(np.stack([R1_quat, R2_quat])))
|
| 96 |
+
R = rotation_slerp(alpha).as_matrix()
|
| 97 |
+
matrix_inter = np.eye(4)
|
| 98 |
+
|
| 99 |
+
# combine together
|
| 100 |
+
matrix_inter[:3, :3] = R
|
| 101 |
+
matrix_inter[:3, 3] = t
|
| 102 |
+
|
| 103 |
+
if mtype == "list":
|
| 104 |
+
matrix_inter = matrix_inter.tolist()
|
| 105 |
+
elif mtype == "torch":
|
| 106 |
+
matrix_inter = torch.from_numpy(matrix_inter).to(dtype).to(device)
|
| 107 |
+
elif mtype == "numpy":
|
| 108 |
+
matrix_inter = matrix_inter.astype(dtype)
|
| 109 |
+
|
| 110 |
+
return matrix_inter
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
def convert_camera_coeffs_to_pinhole_matrix(
|
| 114 |
+
scene_meta, frame, fmt="torch"
|
| 115 |
+
) -> torch.Tensor | np.ndarray | list:
|
| 116 |
+
"""
|
| 117 |
+
Convert camera intrinsics from NeRFStudio format to a 3x3 intrinsics matrix.
|
| 118 |
+
|
| 119 |
+
Args:
|
| 120 |
+
scene_meta: Scene metadata containing camera parameters
|
| 121 |
+
frame: Frame-specific camera parameters that override scene_meta
|
| 122 |
+
|
| 123 |
+
Returns:
|
| 124 |
+
torch.Tensor: 3x3 camera intrinsics matrix
|
| 125 |
+
|
| 126 |
+
Raises:
|
| 127 |
+
ValueError: If camera model is not PINHOLE or if distortion coefficients are present
|
| 128 |
+
"""
|
| 129 |
+
# Check if camera model is supported
|
| 130 |
+
camera_model = frame.get("camera_model", scene_meta.get("camera_model"))
|
| 131 |
+
if camera_model != "PINHOLE":
|
| 132 |
+
raise ValueError("Only PINHOLE camera model supported")
|
| 133 |
+
|
| 134 |
+
# Check for unsupported distortion coefficients
|
| 135 |
+
if any(
|
| 136 |
+
(frame.get(coeff, 0) != 0) or (scene_meta.get(coeff, 0) != 0)
|
| 137 |
+
for coeff in DISTORTION_PARAM_KEYS
|
| 138 |
+
):
|
| 139 |
+
raise ValueError(
|
| 140 |
+
"Pinhole camera does not support radial/tangential distortion -> Undistort first"
|
| 141 |
+
)
|
| 142 |
+
|
| 143 |
+
# Extract camera intrinsic parameters
|
| 144 |
+
camera_coeffs = {}
|
| 145 |
+
for coeff in ["fl_x", "fl_y", "cx", "cy"]:
|
| 146 |
+
camera_coeffs[coeff] = frame.get(coeff, scene_meta.get(coeff))
|
| 147 |
+
if camera_coeffs[coeff] is None:
|
| 148 |
+
raise ValueError(f"Missing required camera parameter: {coeff}")
|
| 149 |
+
|
| 150 |
+
# Create intrinsics matrix
|
| 151 |
+
intrinsics = [
|
| 152 |
+
[camera_coeffs["fl_x"], 0.0, camera_coeffs["cx"]],
|
| 153 |
+
[0.0, camera_coeffs["fl_y"], camera_coeffs["cy"]],
|
| 154 |
+
[0.0, 0.0, 1.0],
|
| 155 |
+
]
|
| 156 |
+
if fmt == "torch":
|
| 157 |
+
intrinsics = torch.tensor(intrinsics)
|
| 158 |
+
elif fmt == "np":
|
| 159 |
+
intrinsics = np.array(intrinsics)
|
| 160 |
+
|
| 161 |
+
return intrinsics
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
def rotate_pinhole_90degcw(
|
| 165 |
+
W: int, H: int, fx: float, fy: float, cx: float, cy: float
|
| 166 |
+
) -> tuple[int, int, float, float, float, float]:
|
| 167 |
+
"""Rotates the intrinsics of a pinhole camera model by 90 degrees clockwise."""
|
| 168 |
+
W_new = H
|
| 169 |
+
H_new = W
|
| 170 |
+
fx_new = fy
|
| 171 |
+
fy_new = fx
|
| 172 |
+
cy_new = cx
|
| 173 |
+
cx_new = H - 1 - cy
|
| 174 |
+
return W_new, H_new, fx_new, fy_new, cx_new, cy_new
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
def _gl_cv_cmat() -> np.ndarray:
|
| 178 |
+
cmat = np.array([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]])
|
| 179 |
+
return cmat
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
def _apply_transformation(
|
| 183 |
+
c2ws: torch.Tensor | np.ndarray, cmat: np.ndarray
|
| 184 |
+
) -> torch.Tensor | np.ndarray:
|
| 185 |
+
"""
|
| 186 |
+
Convert camera poses using a provided conversion matrix.
|
| 187 |
+
|
| 188 |
+
Args:
|
| 189 |
+
c2ws (torch.Tensor or np.ndarray): Camera poses (batch_size, 4, 4) or (4, 4)
|
| 190 |
+
cmat (torch.Tensor or np.ndarray): Conversion matrix (4, 4)
|
| 191 |
+
|
| 192 |
+
Returns:
|
| 193 |
+
torch.Tensor or np.ndarray: Transformed camera poses (batch_size, 4, 4) or (4, 4)
|
| 194 |
+
"""
|
| 195 |
+
if isinstance(c2ws, torch.Tensor):
|
| 196 |
+
# Clone the input tensor to avoid modifying it in-place
|
| 197 |
+
c2ws_transformed = c2ws.clone()
|
| 198 |
+
# Apply the conversion matrix to the rotation part of the camera poses
|
| 199 |
+
if len(c2ws.shape) == 3:
|
| 200 |
+
c2ws_transformed[:, :3, :3] = c2ws_transformed[
|
| 201 |
+
:, :3, :3
|
| 202 |
+
] @ torch.from_numpy(cmat[:3, :3]).to(c2ws).unsqueeze(0)
|
| 203 |
+
else:
|
| 204 |
+
c2ws_transformed[:3, :3] = c2ws_transformed[:3, :3] @ torch.from_numpy(
|
| 205 |
+
cmat[:3, :3]
|
| 206 |
+
).to(c2ws)
|
| 207 |
+
|
| 208 |
+
elif isinstance(c2ws, np.ndarray):
|
| 209 |
+
# Clone the input array to avoid modifying it in-place
|
| 210 |
+
c2ws_transformed = c2ws.copy()
|
| 211 |
+
if len(c2ws.shape) == 3: # batched
|
| 212 |
+
# Apply the conversion matrix to the rotation part of the camera poses
|
| 213 |
+
c2ws_transformed[:, :3, :3] = np.einsum(
|
| 214 |
+
"ijk,lk->ijl", c2ws_transformed[:, :3, :3], cmat[:3, :3]
|
| 215 |
+
)
|
| 216 |
+
else: # single 4x4 matrix
|
| 217 |
+
# Apply the conversion matrix to the rotation part of the camera pose
|
| 218 |
+
c2ws_transformed[:3, :3] = np.dot(c2ws_transformed[:3, :3], cmat[:3, :3])
|
| 219 |
+
|
| 220 |
+
else:
|
| 221 |
+
raise ValueError("Input data type not supported.")
|
| 222 |
+
|
| 223 |
+
return c2ws_transformed
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
def gl2cv(
|
| 227 |
+
c2ws: torch.Tensor | np.ndarray,
|
| 228 |
+
return_cmat: bool = False,
|
| 229 |
+
) -> torch.Tensor | np.ndarray | tuple[torch.Tensor | np.ndarray, np.ndarray]:
|
| 230 |
+
"""
|
| 231 |
+
Convert camera poses from OpenGL to OpenCV coordinate system.
|
| 232 |
+
|
| 233 |
+
Args:
|
| 234 |
+
c2ws (torch.Tensor or np.ndarray): Camera poses (batch_size, 4, 4) or (4, 4)
|
| 235 |
+
return_cmat (bool): If True, return the conversion matrix along with the transformed poses
|
| 236 |
+
|
| 237 |
+
Returns:
|
| 238 |
+
torch.Tensor or np.ndarray: Transformed camera poses (batch_size, 4, 4) or (4, 4)
|
| 239 |
+
np.ndarray (optional): Conversion matrix if return_cmat is True
|
| 240 |
+
"""
|
| 241 |
+
cmat = _gl_cv_cmat()
|
| 242 |
+
if return_cmat:
|
| 243 |
+
return _apply_transformation(c2ws, cmat), cmat
|
| 244 |
+
return _apply_transformation(c2ws, cmat)
|
| 245 |
+
|
| 246 |
+
|
| 247 |
+
def intrinsics_to_fov(
|
| 248 |
+
fx: torch.Tensor, fy: torch.Tensor, h: torch.Tensor, w: torch.Tensor
|
| 249 |
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
| 250 |
+
"""
|
| 251 |
+
Compute the horizontal and vertical fields of view in radians from camera intrinsics.
|
| 252 |
+
|
| 253 |
+
Args:
|
| 254 |
+
fx (torch.Tensor): focal x
|
| 255 |
+
fy (torch.Tensor): focal y
|
| 256 |
+
h (torch.Tensor): Image height(s) with shape (B,).
|
| 257 |
+
w (torch.Tensor): Image width(s) with shape (B,).
|
| 258 |
+
|
| 259 |
+
Returns:
|
| 260 |
+
tuple[torch.Tensor, torch.Tensor]: A tuple containing the horizontal and vertical fields
|
| 261 |
+
of view in radians, both with shape (N,).
|
| 262 |
+
"""
|
| 263 |
+
return 2 * torch.atan((w / 2) / fx), 2 * torch.atan((h / 2) / fy)
|
mapanything/utils/wai/colormaps/colors_fps_5k.npz
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:fae94fe5fb565ff40d1c556ae2640d00fc068e732cb4af5bb64eef034790e07c
|
| 3 |
+
size 9478
|
mapanything/utils/wai/core.py
ADDED
|
@@ -0,0 +1,492 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
This utils script contains PORTAGE of wai-core core methods for MapAnything.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import logging
|
| 6 |
+
import re
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
from typing import Any
|
| 9 |
+
|
| 10 |
+
import numpy as np
|
| 11 |
+
import torch
|
| 12 |
+
|
| 13 |
+
from mapanything.utils.wai.camera import (
|
| 14 |
+
CAMERA_KEYS,
|
| 15 |
+
convert_camera_coeffs_to_pinhole_matrix,
|
| 16 |
+
interpolate_extrinsics,
|
| 17 |
+
interpolate_intrinsics,
|
| 18 |
+
)
|
| 19 |
+
from mapanything.utils.wai.io import _get_method, _load_scene_meta
|
| 20 |
+
from mapanything.utils.wai.ops import crop
|
| 21 |
+
|
| 22 |
+
logger = logging.getLogger(__name__)
|
| 23 |
+
|
| 24 |
+
WAI_COLORMAP_PATH = Path(__file__).parent / "colormaps"
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def load_data(fname: str | Path, format_type: str | None = None, **kwargs) -> Any:
|
| 28 |
+
"""
|
| 29 |
+
Loads data from a file using the appropriate method based on the file format.
|
| 30 |
+
|
| 31 |
+
Args:
|
| 32 |
+
fname (str or Path): The filename or path to load data from.
|
| 33 |
+
format_type (str, optional): The format type of the data. If None, it will be inferred from the file extension if possible.
|
| 34 |
+
Supported formats include: 'readable', 'scalar', 'image', 'binary', 'depth', 'normals',
|
| 35 |
+
'numpy', 'ptz', 'mmap', 'scene_meta', 'labeled_image', 'mesh', 'labeled_mesh', 'caption', "latents".
|
| 36 |
+
**kwargs: Additional keyword arguments to pass to the loading method.
|
| 37 |
+
|
| 38 |
+
Returns:
|
| 39 |
+
The loaded data in the format returned by the specific loading method.
|
| 40 |
+
|
| 41 |
+
Raises:
|
| 42 |
+
ValueError: If the format cannot be inferred from the file extension.
|
| 43 |
+
NotImplementedError: If the specified format is not supported.
|
| 44 |
+
FileExistsError: If the file does not exist.
|
| 45 |
+
"""
|
| 46 |
+
load_method = _get_method(fname, format_type, load=True)
|
| 47 |
+
return load_method(fname, **kwargs)
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def store_data(
|
| 51 |
+
fname: str | Path,
|
| 52 |
+
data: Any,
|
| 53 |
+
format_type: str | None = None,
|
| 54 |
+
**kwargs,
|
| 55 |
+
) -> Any:
|
| 56 |
+
"""
|
| 57 |
+
Stores data to a file using the appropriate method based on the file format.
|
| 58 |
+
|
| 59 |
+
Args:
|
| 60 |
+
fname (str or Path): The filename or path to store data to.
|
| 61 |
+
data: The data to be stored.
|
| 62 |
+
format_type (str, optional): The format type of the data. If None, it will be inferred from the file extension.
|
| 63 |
+
**kwargs: Additional keyword arguments to pass to the storing method.
|
| 64 |
+
|
| 65 |
+
Returns:
|
| 66 |
+
The result of the storing method, which may vary depending on the method used.
|
| 67 |
+
"""
|
| 68 |
+
store_method = _get_method(fname, format_type, load=False)
|
| 69 |
+
Path(fname).parent.mkdir(parents=True, exist_ok=True)
|
| 70 |
+
return store_method(fname, data, **kwargs)
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def get_frame(
|
| 74 |
+
scene_meta: dict[str, Any],
|
| 75 |
+
frame_key: int | str | float,
|
| 76 |
+
) -> dict[str, Any]:
|
| 77 |
+
"""
|
| 78 |
+
Get a frame from scene_meta based on name or index.
|
| 79 |
+
|
| 80 |
+
Args:
|
| 81 |
+
scene_meta: Dictionary containing scene metadata
|
| 82 |
+
frame_key: Either a string (frame name) or integer (frame index) or float (video timestamp)
|
| 83 |
+
|
| 84 |
+
Returns:
|
| 85 |
+
The frame data (dict)
|
| 86 |
+
"""
|
| 87 |
+
frame_idx = get_frame_index(scene_meta, frame_key)
|
| 88 |
+
if isinstance(frame_idx, int):
|
| 89 |
+
frame = scene_meta["frames"][frame_idx]
|
| 90 |
+
frame["_is_interpolated"] = False
|
| 91 |
+
else:
|
| 92 |
+
frame = {}
|
| 93 |
+
frame["frame_name"] = frame_key
|
| 94 |
+
left = int(frame_idx) # it's floor operation
|
| 95 |
+
assert left >= 0 and left < (len(scene_meta["frames"]) - 1), "Wrong index"
|
| 96 |
+
frame_left = scene_meta["frames"][left]
|
| 97 |
+
frame_right = scene_meta["frames"][left + 1]
|
| 98 |
+
# Interpolate intrinsics and extrinsics
|
| 99 |
+
frame["transform_matrix"] = interpolate_extrinsics(
|
| 100 |
+
frame_left["transform_matrix"],
|
| 101 |
+
frame_right["transform_matrix"],
|
| 102 |
+
frame_idx - left,
|
| 103 |
+
)
|
| 104 |
+
frame.update(
|
| 105 |
+
interpolate_intrinsics(
|
| 106 |
+
frame_left,
|
| 107 |
+
frame_right,
|
| 108 |
+
frame_idx - left,
|
| 109 |
+
)
|
| 110 |
+
)
|
| 111 |
+
frame["_is_interpolated"] = True
|
| 112 |
+
return frame
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
def get_intrinsics(
|
| 116 |
+
scene_meta,
|
| 117 |
+
frame_key,
|
| 118 |
+
fmt: str = "torch",
|
| 119 |
+
) -> torch.Tensor | np.ndarray | list:
|
| 120 |
+
frame = get_frame(scene_meta, frame_key)
|
| 121 |
+
return convert_camera_coeffs_to_pinhole_matrix(scene_meta, frame, fmt=fmt)
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
def get_extrinsics(
|
| 125 |
+
scene_meta,
|
| 126 |
+
frame_key,
|
| 127 |
+
fmt: str = "torch",
|
| 128 |
+
) -> torch.Tensor | np.ndarray | list | None:
|
| 129 |
+
frame = get_frame(scene_meta, frame_key)
|
| 130 |
+
if "transform_matrix" in frame:
|
| 131 |
+
if fmt == "torch":
|
| 132 |
+
return torch.tensor(frame["transform_matrix"]).reshape(4, 4).float()
|
| 133 |
+
elif fmt == "np":
|
| 134 |
+
return np.array(frame["transform_matrix"]).reshape(4, 4)
|
| 135 |
+
return frame["transform_matrix"]
|
| 136 |
+
else:
|
| 137 |
+
# TODO: should not happen if we enable interpolation
|
| 138 |
+
return None
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
def get_frame_index(
|
| 142 |
+
scene_meta: dict[str, Any],
|
| 143 |
+
frame_key: int | str | float,
|
| 144 |
+
frame_index_threshold_sec: float = 1e-4,
|
| 145 |
+
distance_threshold_sec: float = 2.0,
|
| 146 |
+
) -> int | float:
|
| 147 |
+
"""
|
| 148 |
+
Returns the frame index from scene_meta based on name (str) or index (int) or sub-frame index (float).
|
| 149 |
+
|
| 150 |
+
Args:
|
| 151 |
+
scene_meta: Dictionary containing scene metadata
|
| 152 |
+
frame_key: Either a string (frame name) or integer (frame index) or float (sub-frame index)
|
| 153 |
+
frame_index_threshold_sec: A threshold for nearest neighbor clipping for indexes (in seconds).
|
| 154 |
+
Default is 1e-4, which is 10000 fps.
|
| 155 |
+
distance_th: A threshold for maximum distance between interpolated frames (in seconds).
|
| 156 |
+
|
| 157 |
+
Returns:
|
| 158 |
+
Frame index (int)
|
| 159 |
+
|
| 160 |
+
Raises:
|
| 161 |
+
ValueError: If frame_key is not a string or integer or float
|
| 162 |
+
"""
|
| 163 |
+
if isinstance(frame_key, str):
|
| 164 |
+
try:
|
| 165 |
+
return scene_meta["frame_names"][frame_key]
|
| 166 |
+
except KeyError as err:
|
| 167 |
+
error_message = (
|
| 168 |
+
f"Frame name not found: {frame_key} - "
|
| 169 |
+
f"Please verify scene_meta.json of scene: {scene_meta['dataset_name']}/{scene_meta['scene_name']}"
|
| 170 |
+
)
|
| 171 |
+
logger.error(error_message)
|
| 172 |
+
raise KeyError(error_message) from err
|
| 173 |
+
|
| 174 |
+
if isinstance(frame_key, int):
|
| 175 |
+
return frame_key
|
| 176 |
+
|
| 177 |
+
if isinstance(frame_key, float):
|
| 178 |
+
# If exact hit
|
| 179 |
+
if frame_key in scene_meta["frame_names"]:
|
| 180 |
+
return scene_meta["frame_names"][frame_key]
|
| 181 |
+
|
| 182 |
+
frame_names = sorted(list(scene_meta["frame_names"].keys()))
|
| 183 |
+
distances = np.array([frm - frame_key for frm in frame_names])
|
| 184 |
+
left = int(np.nonzero(distances <= 0)[0][-1])
|
| 185 |
+
right = left + 1
|
| 186 |
+
|
| 187 |
+
# The last frame or rounding errors
|
| 188 |
+
if (
|
| 189 |
+
left == distances.shape[0] - 1
|
| 190 |
+
or abs(distances[left]) < frame_index_threshold_sec
|
| 191 |
+
):
|
| 192 |
+
return scene_meta["frame_names"][frame_names[int(left)]]
|
| 193 |
+
if abs(distances[right]) < frame_index_threshold_sec:
|
| 194 |
+
return scene_meta["frame_names"][frame_names[int(right)]]
|
| 195 |
+
|
| 196 |
+
interpolation_distance = distances[right] - distances[left]
|
| 197 |
+
if interpolation_distance > distance_threshold_sec:
|
| 198 |
+
raise ValueError(
|
| 199 |
+
f"Frame interpolation is forbidden for distances larger than {distance_threshold_sec}."
|
| 200 |
+
)
|
| 201 |
+
alpha = -distances[left] / interpolation_distance
|
| 202 |
+
|
| 203 |
+
return scene_meta["frame_names"][frame_names[int(left)]] + alpha
|
| 204 |
+
|
| 205 |
+
raise ValueError(f"Frame key type not supported: {frame_key} ({type(frame_key)}).")
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
def load_modality_data(
|
| 209 |
+
scene_root: Path | str,
|
| 210 |
+
results: dict[str, Any],
|
| 211 |
+
modality_dict: dict[str, Any],
|
| 212 |
+
modality: str,
|
| 213 |
+
frame: dict[str, Any] | None = None,
|
| 214 |
+
fmt: str = "torch",
|
| 215 |
+
) -> dict[str, Any]:
|
| 216 |
+
"""
|
| 217 |
+
Processes a modality by loading data from a specified path and updating the results dictionary.
|
| 218 |
+
This function extracts the format and path from the given modality dictionary, loads the data
|
| 219 |
+
from the specified path, and updates the results dictionary with the loaded data.
|
| 220 |
+
|
| 221 |
+
Args:
|
| 222 |
+
scene_root (str or Path): The root directory of the scene where the data is located.
|
| 223 |
+
results (dict): A dictionary to store the loaded modality data and optional frame path.
|
| 224 |
+
modality_dict (dict): A dictionary containing the modality information, including 'format'
|
| 225 |
+
and the path to the data.
|
| 226 |
+
modality (str): The key under which the loaded modality data will be stored in the results.
|
| 227 |
+
frame (dict, optional): A dictionary containing frame information. If provided, that means we are loading
|
| 228 |
+
frame modalities, otherwise it is scene modalities.
|
| 229 |
+
|
| 230 |
+
Returns:
|
| 231 |
+
dict: The updated results dictionary containing the loaded modality data.
|
| 232 |
+
"""
|
| 233 |
+
modality_format = modality_dict["format"]
|
| 234 |
+
|
| 235 |
+
# The modality is stored as a video
|
| 236 |
+
if "video" in modality_format:
|
| 237 |
+
assert isinstance(frame["frame_name"], float), "frame_name should be float"
|
| 238 |
+
video_file = None
|
| 239 |
+
if "chunks" in modality_dict:
|
| 240 |
+
video_list = modality_dict["chunks"]
|
| 241 |
+
# Get the correct chunk of the video
|
| 242 |
+
for video_chunk in video_list:
|
| 243 |
+
if video_chunk["start"] <= frame["frame_name"] <= video_chunk["end"]:
|
| 244 |
+
video_file = video_chunk
|
| 245 |
+
break
|
| 246 |
+
else:
|
| 247 |
+
# There is only one video (no chunks)
|
| 248 |
+
video_file = modality_dict
|
| 249 |
+
if "start" not in video_file:
|
| 250 |
+
video_file["start"] = 0
|
| 251 |
+
if "end" not in video_file:
|
| 252 |
+
video_file["end"] = float("inf")
|
| 253 |
+
if not (video_file["start"] <= frame["frame_name"] <= video_file["end"]):
|
| 254 |
+
video_file = None
|
| 255 |
+
|
| 256 |
+
# This timestamp is not available in any of the chunks
|
| 257 |
+
if video_file is None:
|
| 258 |
+
frame_name = frame["frame_name"]
|
| 259 |
+
logger.warning(
|
| 260 |
+
f"Modality {modality} ({modality_format}) is not available at time {frame_name}"
|
| 261 |
+
)
|
| 262 |
+
return results
|
| 263 |
+
|
| 264 |
+
# Load the modality from the video
|
| 265 |
+
loaded_modality = load_data(
|
| 266 |
+
Path(scene_root, video_file["file"]),
|
| 267 |
+
modality_format,
|
| 268 |
+
frame_key=frame["frame_name"] - video_file["start"],
|
| 269 |
+
)
|
| 270 |
+
|
| 271 |
+
if "bbox" in video_file:
|
| 272 |
+
loaded_modality = crop(loaded_modality, video_file["bbox"])
|
| 273 |
+
|
| 274 |
+
if loaded_modality is not None:
|
| 275 |
+
results[modality] = loaded_modality
|
| 276 |
+
|
| 277 |
+
if frame:
|
| 278 |
+
results[f"{modality}_fname"] = video_file["file"]
|
| 279 |
+
else:
|
| 280 |
+
modality_path = [v for k, v in modality_dict.items() if k != "format"][0]
|
| 281 |
+
if frame:
|
| 282 |
+
if modality_path in frame:
|
| 283 |
+
fname = frame[modality_path]
|
| 284 |
+
else:
|
| 285 |
+
fname = None
|
| 286 |
+
else:
|
| 287 |
+
fname = modality_path
|
| 288 |
+
if fname is not None:
|
| 289 |
+
loaded_modality = load_data(
|
| 290 |
+
Path(scene_root, fname),
|
| 291 |
+
modality_format,
|
| 292 |
+
frame_key=frame["frame_name"] if frame else None,
|
| 293 |
+
fmt=fmt,
|
| 294 |
+
)
|
| 295 |
+
results[modality] = loaded_modality
|
| 296 |
+
if frame:
|
| 297 |
+
results[f"{modality}_fname"] = frame[modality_path]
|
| 298 |
+
return results
|
| 299 |
+
|
| 300 |
+
|
| 301 |
+
def load_modality(
|
| 302 |
+
scene_root: Path | str,
|
| 303 |
+
modality_meta: dict[str, Any],
|
| 304 |
+
modality: str,
|
| 305 |
+
frame: dict[str, Any] | None = None,
|
| 306 |
+
fmt: str = "torch",
|
| 307 |
+
) -> dict[str, Any]:
|
| 308 |
+
"""
|
| 309 |
+
Loads modality data based on the provided metadata and updates the results dictionary.
|
| 310 |
+
This function navigates through the modality metadata to find the specified modality,
|
| 311 |
+
then loads the data for each modality found.
|
| 312 |
+
|
| 313 |
+
Args:
|
| 314 |
+
scene_root (str or Path): The root directory of the scene where the data is located.
|
| 315 |
+
modality_meta (dict): A nested dictionary containing metadata for various modalities.
|
| 316 |
+
modality (str): A string representing the path to the desired modality within the metadata,
|
| 317 |
+
using '/' as a separator for nested keys.
|
| 318 |
+
frame (dict, optional): A dictionary containing frame information. If provided, we are operating
|
| 319 |
+
on frame modalities, otherwise it is scene modalities.
|
| 320 |
+
|
| 321 |
+
Returns:
|
| 322 |
+
dict: A dictionary containing the loaded modality data.
|
| 323 |
+
"""
|
| 324 |
+
results = {}
|
| 325 |
+
# support for nested modalities like "pred_depth/metric3dv2"
|
| 326 |
+
modality_keys = modality.split("/")
|
| 327 |
+
current_modality = modality_meta
|
| 328 |
+
for key in modality_keys:
|
| 329 |
+
try:
|
| 330 |
+
current_modality = current_modality[key]
|
| 331 |
+
except KeyError as err:
|
| 332 |
+
error_message = (
|
| 333 |
+
f"Modality '{err.args[0]}' not found in modalities metadata. "
|
| 334 |
+
f"Please verify the scene_meta.json and the provided modalities in {scene_root}."
|
| 335 |
+
)
|
| 336 |
+
logger.error(error_message)
|
| 337 |
+
raise KeyError(error_message) from err
|
| 338 |
+
if "format" in current_modality:
|
| 339 |
+
results = load_modality_data(
|
| 340 |
+
scene_root, results, current_modality, modality, frame, fmt=fmt
|
| 341 |
+
)
|
| 342 |
+
else:
|
| 343 |
+
# nested modality, return last by default
|
| 344 |
+
logger.warning("Nested modality, returning last by default")
|
| 345 |
+
key = next(reversed(current_modality.keys()))
|
| 346 |
+
results = load_modality_data(
|
| 347 |
+
scene_root, results, current_modality[key], modality, frame, fmt=fmt
|
| 348 |
+
)
|
| 349 |
+
return results
|
| 350 |
+
|
| 351 |
+
|
| 352 |
+
def load_frame(
|
| 353 |
+
scene_root: Path | str,
|
| 354 |
+
frame_key: int | str | float,
|
| 355 |
+
modalities: str | list[str] | None = None,
|
| 356 |
+
scene_meta: dict[str, Any] | None = None,
|
| 357 |
+
load_intrinsics: bool = True,
|
| 358 |
+
load_extrinsics: bool = True,
|
| 359 |
+
fmt: str = "torch",
|
| 360 |
+
interpolate: bool = False,
|
| 361 |
+
) -> dict[str, Any]:
|
| 362 |
+
"""
|
| 363 |
+
Load a single frame from a scene with specified modalities.
|
| 364 |
+
|
| 365 |
+
Args:
|
| 366 |
+
scene_root (str or Path): The root directory of the scene where the data is located.
|
| 367 |
+
frame_key (int or str or float): Either a string (frame name) or integer (frame index) or float (video timestamp).
|
| 368 |
+
modalities (str or list[str], optional): The modality or list of modalities to load.
|
| 369 |
+
If None, only basic frame information is loaded.
|
| 370 |
+
scene_meta (dict, optional): Dictionary containing scene metadata. If None, it will be loaded
|
| 371 |
+
from scene_meta.json in the scene_root.
|
| 372 |
+
interpolate (bool, optional): Allow interpolating frames?
|
| 373 |
+
|
| 374 |
+
Returns:
|
| 375 |
+
dict: A dictionary containing the loaded frame data with the requested modalities.
|
| 376 |
+
"""
|
| 377 |
+
scene_root = Path(scene_root)
|
| 378 |
+
if scene_meta is None:
|
| 379 |
+
scene_meta = _load_scene_meta(scene_root / "scene_meta.json")
|
| 380 |
+
frame = get_frame(scene_meta, frame_key)
|
| 381 |
+
# compact, standarized frame representation
|
| 382 |
+
wai_frame = {}
|
| 383 |
+
if load_extrinsics:
|
| 384 |
+
extrinsics = get_extrinsics(
|
| 385 |
+
scene_meta,
|
| 386 |
+
frame_key,
|
| 387 |
+
fmt=fmt,
|
| 388 |
+
)
|
| 389 |
+
if extrinsics is not None:
|
| 390 |
+
wai_frame["extrinsics"] = extrinsics
|
| 391 |
+
if load_intrinsics:
|
| 392 |
+
camera_model = frame.get("camera_model", scene_meta.get("camera_model"))
|
| 393 |
+
wai_frame["camera_model"] = camera_model
|
| 394 |
+
if camera_model == "PINHOLE":
|
| 395 |
+
wai_frame["intrinsics"] = get_intrinsics(scene_meta, frame_key, fmt=fmt)
|
| 396 |
+
elif camera_model in ["OPENCV", "OPENCV_FISHEYE"]:
|
| 397 |
+
# optional per-frame intrinsics
|
| 398 |
+
for camera_key in CAMERA_KEYS:
|
| 399 |
+
if camera_key in frame:
|
| 400 |
+
wai_frame[camera_key] = float(frame[camera_key])
|
| 401 |
+
elif camera_key in scene_meta:
|
| 402 |
+
wai_frame[camera_key] = float(scene_meta[camera_key])
|
| 403 |
+
else:
|
| 404 |
+
error_message = (
|
| 405 |
+
f"Camera model not supported: {camera_model} - "
|
| 406 |
+
f"Please verify scene_meta.json of scene: {scene_meta['dataset_name']}/{scene_meta['scene_name']}"
|
| 407 |
+
)
|
| 408 |
+
logger.error(error_message)
|
| 409 |
+
raise NotImplementedError(error_message)
|
| 410 |
+
wai_frame["w"] = frame.get("w", scene_meta["w"] if "w" in scene_meta else None)
|
| 411 |
+
wai_frame["h"] = frame.get("h", scene_meta["h"] if "h" in scene_meta else None)
|
| 412 |
+
wai_frame["frame_name"] = frame["frame_name"]
|
| 413 |
+
wai_frame["frame_idx"] = get_frame_index(scene_meta, frame_key)
|
| 414 |
+
wai_frame["_is_interpolated"] = frame["_is_interpolated"]
|
| 415 |
+
|
| 416 |
+
if modalities is not None:
|
| 417 |
+
if isinstance(modalities, str):
|
| 418 |
+
modalities = [modalities]
|
| 419 |
+
for modality in modalities:
|
| 420 |
+
# Handle regex patterns in modality
|
| 421 |
+
if any(char in modality for char in ".|*+?()[]{}^$\\"):
|
| 422 |
+
# This is a regex pattern
|
| 423 |
+
pattern = re.compile(modality)
|
| 424 |
+
matching_modalities = [
|
| 425 |
+
m for m in scene_meta["frame_modalities"] if pattern.match(m)
|
| 426 |
+
]
|
| 427 |
+
if not matching_modalities:
|
| 428 |
+
raise ValueError(
|
| 429 |
+
f"No modalities match the pattern: {modality} in scene: {scene_root}"
|
| 430 |
+
)
|
| 431 |
+
# Use the first matching modality
|
| 432 |
+
modality = matching_modalities[0]
|
| 433 |
+
current_modalities = load_modality(
|
| 434 |
+
scene_root, scene_meta["frame_modalities"], modality, frame, fmt=fmt
|
| 435 |
+
)
|
| 436 |
+
wai_frame.update(current_modalities)
|
| 437 |
+
|
| 438 |
+
return wai_frame
|
| 439 |
+
|
| 440 |
+
|
| 441 |
+
def set_frame(
|
| 442 |
+
scene_meta: dict[str, Any],
|
| 443 |
+
frame_key: int | str,
|
| 444 |
+
new_frame: dict[str, Any],
|
| 445 |
+
sort: bool = False,
|
| 446 |
+
) -> dict[str, Any]:
|
| 447 |
+
"""
|
| 448 |
+
Replace a frame in scene_meta with a new frame.
|
| 449 |
+
|
| 450 |
+
Args:
|
| 451 |
+
scene_meta: Dictionary containing scene metadata.
|
| 452 |
+
frame_key: Either a string (frame name) or integer (frame index).
|
| 453 |
+
new_frame: New frame data to replace the existing frame.
|
| 454 |
+
sort: If True, sort the keys in the new_frame dictionary.
|
| 455 |
+
|
| 456 |
+
Returns:
|
| 457 |
+
Updated scene_meta dictionary.
|
| 458 |
+
"""
|
| 459 |
+
frame_idx = get_frame_index(scene_meta, frame_key)
|
| 460 |
+
if isinstance(frame_idx, float):
|
| 461 |
+
raise ValueError(
|
| 462 |
+
f"Setting frame for sub-frame frame_key is not supported: {frame_key} ({type(frame_key)})."
|
| 463 |
+
)
|
| 464 |
+
if sort:
|
| 465 |
+
new_frame = {k: new_frame[k] for k in sorted(new_frame)}
|
| 466 |
+
scene_meta["frames"][frame_idx] = new_frame
|
| 467 |
+
return scene_meta
|
| 468 |
+
|
| 469 |
+
|
| 470 |
+
def nest_modality(
|
| 471 |
+
frame_modalities: dict[str, Any],
|
| 472 |
+
modality_name: str,
|
| 473 |
+
) -> dict[str, Any]:
|
| 474 |
+
"""
|
| 475 |
+
Converts a flat modality structure into a nested one based on the modality name.
|
| 476 |
+
|
| 477 |
+
Args:
|
| 478 |
+
frame_modalities (dict): Dictionary containing frame modalities.
|
| 479 |
+
modality_name (str): The name of the modality to nest.
|
| 480 |
+
|
| 481 |
+
Returns:
|
| 482 |
+
dict: A dictionary with the nested modality structure.
|
| 483 |
+
"""
|
| 484 |
+
frame_modality = {}
|
| 485 |
+
if modality_name in frame_modalities:
|
| 486 |
+
frame_modality = frame_modalities[modality_name]
|
| 487 |
+
if "frame_key" in frame_modality:
|
| 488 |
+
# required for backwards compatibility
|
| 489 |
+
# converting non-nested format into nested one based on name
|
| 490 |
+
modality_name = frame_modality["frame_key"].split("_")[0]
|
| 491 |
+
frame_modality = {modality_name: frame_modality}
|
| 492 |
+
return frame_modality
|
mapanything/utils/wai/intersection_check.py
ADDED
|
@@ -0,0 +1,462 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from einops import rearrange, repeat
|
| 3 |
+
from tqdm import tqdm
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def create_frustum_from_intrinsics(
|
| 7 |
+
intrinsics: torch.Tensor,
|
| 8 |
+
near: torch.Tensor | float,
|
| 9 |
+
far: torch.Tensor | float,
|
| 10 |
+
) -> torch.Tensor:
|
| 11 |
+
r"""
|
| 12 |
+
Create a frustum from camera intrinsics.
|
| 13 |
+
|
| 14 |
+
Args:
|
| 15 |
+
intrinsics (torch.Tensor): Bx3x3 Intrinsics of cameras.
|
| 16 |
+
near (torch.Tensor or float): [B] Near plane distance.
|
| 17 |
+
far (torch.Tensor or float): [B] Far plane distance.
|
| 18 |
+
|
| 19 |
+
Returns:
|
| 20 |
+
frustum (torch.Tensor): Bx8x3 batch of frustum points following the order:
|
| 21 |
+
5 ---------- 4
|
| 22 |
+
|\ /|
|
| 23 |
+
6 \ / 7
|
| 24 |
+
\ 1 ---- 0 /
|
| 25 |
+
\| |/
|
| 26 |
+
2 ---- 3
|
| 27 |
+
"""
|
| 28 |
+
|
| 29 |
+
fx, fy = intrinsics[:, 0, 0], intrinsics[:, 1, 1]
|
| 30 |
+
cx, cy = intrinsics[:, 0, 2], intrinsics[:, 1, 2]
|
| 31 |
+
|
| 32 |
+
# Calculate the offsets at the near plane
|
| 33 |
+
near_x = near * (cx / fx)
|
| 34 |
+
near_y = near * (cy / fy)
|
| 35 |
+
far_x = far * (cx / fx)
|
| 36 |
+
far_y = far * (cy / fy)
|
| 37 |
+
|
| 38 |
+
# Define frustum vertices in camera space
|
| 39 |
+
near_plane = torch.stack(
|
| 40 |
+
[
|
| 41 |
+
torch.stack([near_x, near_y, near * torch.ones_like(near_x)], dim=-1),
|
| 42 |
+
torch.stack([-near_x, near_y, near * torch.ones_like(near_x)], dim=-1),
|
| 43 |
+
torch.stack([-near_x, -near_y, near * torch.ones_like(near_x)], dim=-1),
|
| 44 |
+
torch.stack([near_x, -near_y, near * torch.ones_like(near_x)], dim=-1),
|
| 45 |
+
],
|
| 46 |
+
dim=1,
|
| 47 |
+
)
|
| 48 |
+
|
| 49 |
+
far_plane = torch.stack(
|
| 50 |
+
[
|
| 51 |
+
torch.stack([far_x, far_y, far * torch.ones_like(far_x)], dim=-1),
|
| 52 |
+
torch.stack([-far_x, far_y, far * torch.ones_like(far_x)], dim=-1),
|
| 53 |
+
torch.stack([-far_x, -far_y, far * torch.ones_like(far_x)], dim=-1),
|
| 54 |
+
torch.stack([far_x, -far_y, far * torch.ones_like(far_x)], dim=-1),
|
| 55 |
+
],
|
| 56 |
+
dim=1,
|
| 57 |
+
)
|
| 58 |
+
|
| 59 |
+
return torch.cat([near_plane, far_plane], dim=1)
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def _frustum_to_triangles(frustum: torch.Tensor) -> torch.Tensor:
|
| 63 |
+
"""
|
| 64 |
+
Convert frustum to triangles.
|
| 65 |
+
|
| 66 |
+
Args:
|
| 67 |
+
frustums (torch.Tensor): Bx8 batch of frustum points.
|
| 68 |
+
|
| 69 |
+
Returns:
|
| 70 |
+
frustum_triangles (torch.Tensor): Bx3x3 batch of frustum triangles.
|
| 71 |
+
"""
|
| 72 |
+
|
| 73 |
+
triangle_inds = torch.tensor(
|
| 74 |
+
[
|
| 75 |
+
[0, 1, 2],
|
| 76 |
+
[0, 2, 3],
|
| 77 |
+
[0, 3, 7],
|
| 78 |
+
[0, 7, 4],
|
| 79 |
+
[1, 2, 6],
|
| 80 |
+
[1, 6, 5],
|
| 81 |
+
[1, 4, 5],
|
| 82 |
+
[1, 0, 4],
|
| 83 |
+
[2, 6, 7],
|
| 84 |
+
[2, 3, 7],
|
| 85 |
+
[6, 7, 4],
|
| 86 |
+
[6, 5, 4],
|
| 87 |
+
]
|
| 88 |
+
)
|
| 89 |
+
frustum_triangles = frustum[:, triangle_inds]
|
| 90 |
+
return frustum_triangles
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def segment_triangle_intersection_check(
|
| 94 |
+
start_points: torch.Tensor,
|
| 95 |
+
end_points: torch.Tensor,
|
| 96 |
+
triangles: torch.Tensor,
|
| 97 |
+
) -> torch.Tensor:
|
| 98 |
+
"""
|
| 99 |
+
Check if segments (lines with starting and end point) intersect triangles in 3D using the
|
| 100 |
+
Moller-Trumbore algorithm.
|
| 101 |
+
|
| 102 |
+
Args:
|
| 103 |
+
start_points (torch.Tensor): Bx3 Starting points of the segment.
|
| 104 |
+
end_points (torch.Tensor): Bx3 End points of the segment.
|
| 105 |
+
triangles (torch.Tensor): Bx3x3 Vertices of the triangles.
|
| 106 |
+
|
| 107 |
+
Returns:
|
| 108 |
+
intersects (torch.Tensor): B Boolean tensor indicating if each ray intersects its
|
| 109 |
+
corresponding triangle.
|
| 110 |
+
"""
|
| 111 |
+
vertex0 = triangles[:, 0, :]
|
| 112 |
+
vertex1 = triangles[:, 1, :]
|
| 113 |
+
vertex2 = triangles[:, 2, :]
|
| 114 |
+
edge1 = vertex1 - vertex0
|
| 115 |
+
edge2 = vertex2 - vertex0
|
| 116 |
+
ray_vectors = end_points - start_points
|
| 117 |
+
max_lengths = torch.norm(ray_vectors, dim=1)
|
| 118 |
+
ray_vectors = ray_vectors / max_lengths[:, None]
|
| 119 |
+
h = torch.cross(ray_vectors, edge2, dim=1)
|
| 120 |
+
a = (edge1 * h).sum(dim=1)
|
| 121 |
+
|
| 122 |
+
epsilon = 1e-6
|
| 123 |
+
mask = torch.abs(a) > epsilon
|
| 124 |
+
f = torch.zeros_like(a)
|
| 125 |
+
f[mask] = 1.0 / a[mask]
|
| 126 |
+
|
| 127 |
+
s = start_points - vertex0
|
| 128 |
+
u = f * (s * h).sum(dim=1)
|
| 129 |
+
q = torch.cross(s, edge1, dim=1)
|
| 130 |
+
v = f * (ray_vectors * q).sum(dim=1)
|
| 131 |
+
|
| 132 |
+
t = f * (edge2 * q).sum(dim=1)
|
| 133 |
+
|
| 134 |
+
# Check conditions
|
| 135 |
+
intersects = (
|
| 136 |
+
(u >= 0)
|
| 137 |
+
& (u <= 1)
|
| 138 |
+
& (v >= 0)
|
| 139 |
+
& (u + v <= 1)
|
| 140 |
+
& (t >= epsilon)
|
| 141 |
+
& (t <= max_lengths)
|
| 142 |
+
)
|
| 143 |
+
|
| 144 |
+
return intersects
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
def triangle_intersection_check(
|
| 148 |
+
triangles1: torch.Tensor,
|
| 149 |
+
triangles2: torch.Tensor,
|
| 150 |
+
) -> torch.Tensor:
|
| 151 |
+
"""
|
| 152 |
+
Check if two triangles intersect.
|
| 153 |
+
|
| 154 |
+
Args:
|
| 155 |
+
triangles1 (torch.Tensor): Bx3x3 Vertices of the first batch of triangles.
|
| 156 |
+
triangles2 (torch.Tensor): Bx3x3 Vertices of the first batch of triangles.
|
| 157 |
+
|
| 158 |
+
Returns:
|
| 159 |
+
triangle_intersection (torch.Tensor): B Boolean tensor indicating if triangles intersect.
|
| 160 |
+
"""
|
| 161 |
+
n = triangles1.shape[1]
|
| 162 |
+
start_points1 = rearrange(triangles1, "B N C -> (B N) C")
|
| 163 |
+
end_points1 = rearrange(
|
| 164 |
+
triangles1[:, torch.arange(1, n + 1) % n], "B N C -> (B N) C"
|
| 165 |
+
)
|
| 166 |
+
|
| 167 |
+
start_points2 = rearrange(triangles2, "B N C -> (B N) C")
|
| 168 |
+
end_points2 = rearrange(
|
| 169 |
+
triangles2[:, torch.arange(1, n + 1) % n], "B N C -> (B N) C"
|
| 170 |
+
)
|
| 171 |
+
intersection_1_2 = segment_triangle_intersection_check(
|
| 172 |
+
start_points1, end_points1, repeat(triangles2, "B N C -> (B N2) N C", N2=3)
|
| 173 |
+
)
|
| 174 |
+
intersection_2_1 = segment_triangle_intersection_check(
|
| 175 |
+
start_points2, end_points2, repeat(triangles1, "B N C -> (B N2) N C", N2=3)
|
| 176 |
+
)
|
| 177 |
+
triangle_intersection = torch.any(
|
| 178 |
+
rearrange(intersection_1_2, "(B N N2) -> B (N N2)", B=triangles1.shape[0], N=n),
|
| 179 |
+
dim=1,
|
| 180 |
+
) | torch.any(
|
| 181 |
+
rearrange(intersection_2_1, "(B N N2) -> B (N N2)", B=triangles1.shape[0], N=n),
|
| 182 |
+
dim=1,
|
| 183 |
+
)
|
| 184 |
+
return triangle_intersection
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
def frustum_intersection_check(
|
| 188 |
+
frustums: torch.Tensor,
|
| 189 |
+
check_inside: bool = True,
|
| 190 |
+
chunk_size: int = 500,
|
| 191 |
+
device: str | None = None,
|
| 192 |
+
) -> torch.Tensor:
|
| 193 |
+
"""
|
| 194 |
+
Check if any pair of the frustums intersect with each other.
|
| 195 |
+
|
| 196 |
+
Args:
|
| 197 |
+
frustums (torch.Tensor): Bx8 batch of frustum points.
|
| 198 |
+
check_inside (bool): If True, also checks if one frustum is inside another.
|
| 199 |
+
Defaults to True.
|
| 200 |
+
chunk_size (Optional[int]): Number of chunks to split the computation into.
|
| 201 |
+
Defaults to 500.
|
| 202 |
+
device (Optional[str]): Device to store exhuastive frustum intersection matrix on.
|
| 203 |
+
Defaults to None.
|
| 204 |
+
|
| 205 |
+
Returns:
|
| 206 |
+
frustum_intersection (torch.Tensor): BxB tensor of Booleans indicating if any pair
|
| 207 |
+
of frustums intersect with each other.
|
| 208 |
+
"""
|
| 209 |
+
B = frustums.shape[0]
|
| 210 |
+
if device is None:
|
| 211 |
+
device = frustums.device
|
| 212 |
+
frustum_triangles = _frustum_to_triangles(frustums)
|
| 213 |
+
T = frustum_triangles.shape[1]
|
| 214 |
+
|
| 215 |
+
# Perform frustum in frustum check if required
|
| 216 |
+
if check_inside:
|
| 217 |
+
frustum_intersection = frustums_in_frustum_check(
|
| 218 |
+
frustums=frustums, chunk_size=chunk_size, device=device
|
| 219 |
+
)
|
| 220 |
+
else:
|
| 221 |
+
frustum_intersection = torch.zeros((B, B), dtype=torch.bool, device=device)
|
| 222 |
+
|
| 223 |
+
# Check triangle intersections in chunks
|
| 224 |
+
for i in tqdm(range(0, B, chunk_size), desc="Checking triangle intersections"):
|
| 225 |
+
i_end = min(i + chunk_size, B)
|
| 226 |
+
chunk_i_size = i_end - i
|
| 227 |
+
|
| 228 |
+
for j in range(0, B, chunk_size):
|
| 229 |
+
j_end = min(j + chunk_size, B)
|
| 230 |
+
chunk_j_size = j_end - j
|
| 231 |
+
|
| 232 |
+
# Process all triangle pairs between the two chunks in a vectorized way
|
| 233 |
+
triangles_i = frustum_triangles[i:i_end] # [chunk_i, T, 3, 3]
|
| 234 |
+
triangles_j = frustum_triangles[j:j_end] # [chunk_j, T, 3, 3]
|
| 235 |
+
|
| 236 |
+
# Reshape to process all triangle pairs at once
|
| 237 |
+
tri_i = triangles_i.reshape(chunk_i_size * T, 3, 3)
|
| 238 |
+
tri_j = triangles_j.reshape(chunk_j_size * T, 3, 3)
|
| 239 |
+
|
| 240 |
+
# Expand for all pairs - explicitly specify dimensions instead of using ...
|
| 241 |
+
tri_i_exp = repeat(tri_i, "bt i j -> (bt bj_t) i j", bj_t=chunk_j_size * T)
|
| 242 |
+
tri_j_exp = repeat(tri_j, "bt i j -> (bi_t bt) i j", bi_t=chunk_i_size * T)
|
| 243 |
+
|
| 244 |
+
# Check intersection
|
| 245 |
+
batch_intersect = triangle_intersection_check(tri_i_exp, tri_j_exp)
|
| 246 |
+
|
| 247 |
+
# Reshape and check if any triangle pair intersects
|
| 248 |
+
batch_intersect = batch_intersect.reshape(chunk_i_size, T, chunk_j_size, T)
|
| 249 |
+
batch_intersect = batch_intersect.any(dim=(1, 3))
|
| 250 |
+
|
| 251 |
+
# Update result
|
| 252 |
+
frustum_intersection[i:i_end, j:j_end] |= batch_intersect.to(device)
|
| 253 |
+
|
| 254 |
+
return frustum_intersection
|
| 255 |
+
|
| 256 |
+
|
| 257 |
+
def ray_triangle_intersection_check(
|
| 258 |
+
ray_origins: torch.Tensor,
|
| 259 |
+
ray_vectors: torch.Tensor,
|
| 260 |
+
triangles: torch.Tensor,
|
| 261 |
+
max_lengths: torch.Tensor | None = None,
|
| 262 |
+
) -> torch.Tensor:
|
| 263 |
+
"""
|
| 264 |
+
Check if rays intersect triangles in 3D using the Moller-Trumbore algorithm, considering the
|
| 265 |
+
finite length of rays.
|
| 266 |
+
|
| 267 |
+
Args:
|
| 268 |
+
ray_origins (torch.Tensor): Bx3 Origins of the rays.
|
| 269 |
+
ray_vectors (torch.Tensor): Bx3 Direction vectors of the rays.
|
| 270 |
+
triangles (torch.Tensor): Bx3x3 Vertices of the triangles.
|
| 271 |
+
max_lengths Optional[torch.Tensor]: B Maximum lengths of the rays.
|
| 272 |
+
|
| 273 |
+
Returns:
|
| 274 |
+
intersects (torch.Tensor): B Boolean tensor indicating if each ray intersects its
|
| 275 |
+
corresponding triangle.
|
| 276 |
+
"""
|
| 277 |
+
vertex0 = triangles[:, 0, :]
|
| 278 |
+
vertex1 = triangles[:, 1, :]
|
| 279 |
+
vertex2 = triangles[:, 2, :]
|
| 280 |
+
edge1 = vertex1 - vertex0
|
| 281 |
+
edge2 = vertex2 - vertex0
|
| 282 |
+
h = torch.cross(ray_vectors, edge2, dim=1)
|
| 283 |
+
a = (edge1 * h).sum(dim=1)
|
| 284 |
+
|
| 285 |
+
epsilon = 1e-6
|
| 286 |
+
mask = torch.abs(a) > epsilon
|
| 287 |
+
f = torch.zeros_like(a)
|
| 288 |
+
f[mask] = 1.0 / a[mask]
|
| 289 |
+
|
| 290 |
+
s = ray_origins - vertex0
|
| 291 |
+
u = f * (s * h).sum(dim=1)
|
| 292 |
+
q = torch.cross(s, edge1, dim=1)
|
| 293 |
+
v = f * (ray_vectors * q).sum(dim=1)
|
| 294 |
+
|
| 295 |
+
t = f * (edge2 * q).sum(dim=1)
|
| 296 |
+
|
| 297 |
+
# Check conditions
|
| 298 |
+
intersects = (u >= 0) & (u <= 1) & (v >= 0) & (u + v <= 1) & (t >= epsilon)
|
| 299 |
+
if max_lengths is not None:
|
| 300 |
+
intersects &= t <= max_lengths
|
| 301 |
+
|
| 302 |
+
return intersects
|
| 303 |
+
|
| 304 |
+
|
| 305 |
+
#### Checks for frustums
|
| 306 |
+
def _frustum_to_planes(frustums: torch.Tensor) -> torch.Tensor:
|
| 307 |
+
r"""
|
| 308 |
+
Converts frustum parameters to plane representation.
|
| 309 |
+
|
| 310 |
+
Args:
|
| 311 |
+
frustums (torch.Tensor): Bx8 batch of frustum points following the order:
|
| 312 |
+
5 ---------- 4
|
| 313 |
+
|\ /|
|
| 314 |
+
6 \ / 7
|
| 315 |
+
\ 1 ---- 0 /
|
| 316 |
+
\| |/
|
| 317 |
+
2 ---- 3
|
| 318 |
+
|
| 319 |
+
Returns:
|
| 320 |
+
planes (torch.Tensor): Bx6x4 where 6 represents the six frustum planes and
|
| 321 |
+
4 represents plane parameters [a, b, c, d].
|
| 322 |
+
"""
|
| 323 |
+
planes = []
|
| 324 |
+
for inds in [[0, 1, 3], [1, 6, 2], [0, 3, 7], [2, 6, 3], [0, 5, 1], [6, 5, 4]]:
|
| 325 |
+
normal = torch.cross(
|
| 326 |
+
frustums[:, inds[1]] - frustums[:, inds[0]],
|
| 327 |
+
frustums[:, inds[2]] - frustums[:, inds[0]],
|
| 328 |
+
dim=1,
|
| 329 |
+
)
|
| 330 |
+
normal = normal / torch.norm(normal, dim=1, keepdim=True)
|
| 331 |
+
d = -torch.sum(normal * frustums[:, inds[0]], dim=1, keepdim=True)
|
| 332 |
+
planes.append(torch.cat([normal, d], -1))
|
| 333 |
+
return torch.stack(planes, 1)
|
| 334 |
+
|
| 335 |
+
|
| 336 |
+
def points_in_frustum_check(
|
| 337 |
+
frustums: torch.Tensor,
|
| 338 |
+
points: torch.Tensor,
|
| 339 |
+
chunk_size: int | None = None,
|
| 340 |
+
device: str | None = None,
|
| 341 |
+
):
|
| 342 |
+
"""
|
| 343 |
+
Check if points are inside frustums.
|
| 344 |
+
|
| 345 |
+
Args:
|
| 346 |
+
frustums (torch.Tensor): Bx8 batch of frustum points.
|
| 347 |
+
points (torch.Tensor): BxNx3 batch of points.
|
| 348 |
+
chunk_size (Optional[int]): Number of chunks to split the computation into. Defaults to None.
|
| 349 |
+
device (Optional[str]): Device to perfrom computation on. Defaults to None.
|
| 350 |
+
|
| 351 |
+
Returns:
|
| 352 |
+
inside (torch.Tensor): BxN batch of Booleans indicating if points are inside frustums.
|
| 353 |
+
"""
|
| 354 |
+
if device is None:
|
| 355 |
+
device = frustums.device
|
| 356 |
+
|
| 357 |
+
if chunk_size is not None:
|
| 358 |
+
# Split computation into chunks to avoid OOM errors for large batch sizes
|
| 359 |
+
point_plane_direction = []
|
| 360 |
+
for chunk_idx in range(0, frustums.shape[0], chunk_size):
|
| 361 |
+
chunk_frustum_planes = _frustum_to_planes(
|
| 362 |
+
frustums[chunk_idx : chunk_idx + chunk_size]
|
| 363 |
+
)
|
| 364 |
+
# Bx8x4 tensor of plane parameters [a, b, c, d]
|
| 365 |
+
chunk_points = points[chunk_idx : chunk_idx + chunk_size]
|
| 366 |
+
chunk_point_plane_direction = torch.einsum(
|
| 367 |
+
"bij,bnj->bni", (chunk_frustum_planes[:, :, :-1], chunk_points)
|
| 368 |
+
) + repeat(
|
| 369 |
+
chunk_frustum_planes[:, :, -1], "B P -> B N P", N=chunk_points.shape[1]
|
| 370 |
+
) # BxMxN tensor
|
| 371 |
+
point_plane_direction.append(chunk_point_plane_direction.to(device))
|
| 372 |
+
point_plane_direction = torch.cat(point_plane_direction)
|
| 373 |
+
else:
|
| 374 |
+
# Convert frustums to planes
|
| 375 |
+
frustum_planes = _frustum_to_planes(
|
| 376 |
+
frustums
|
| 377 |
+
) # Bx8x4 tensor of plane parameters [a, b, c, d]
|
| 378 |
+
# Compute dot product between each point and each plane
|
| 379 |
+
point_plane_direction = torch.einsum(
|
| 380 |
+
"bij,bnj->bni", (frustum_planes[:, :, :-1], points)
|
| 381 |
+
) + repeat(frustum_planes[:, :, -1], "B P -> B N P", N=points.shape[1]).to(
|
| 382 |
+
device
|
| 383 |
+
) # BxMxN tensor
|
| 384 |
+
|
| 385 |
+
inside = (point_plane_direction >= 0).all(-1)
|
| 386 |
+
return inside
|
| 387 |
+
|
| 388 |
+
|
| 389 |
+
def frustums_in_frustum_check(
|
| 390 |
+
frustums: torch.Tensor,
|
| 391 |
+
chunk_size: int,
|
| 392 |
+
device: str | None = None,
|
| 393 |
+
use_double_chunking: bool = True,
|
| 394 |
+
):
|
| 395 |
+
"""
|
| 396 |
+
Check if frustums are contained within other frustums.
|
| 397 |
+
|
| 398 |
+
Args:
|
| 399 |
+
frustums (torch.Tensor): Bx8 batch of frustum points.
|
| 400 |
+
chunk_size (Optional[int]): Number of chunks to split the computation into.
|
| 401 |
+
Defaults to None.
|
| 402 |
+
device (Optional[str]): Device to store exhuastive frustum containment matrix on.
|
| 403 |
+
Defaults to None.
|
| 404 |
+
use_double_chunking (bool): If True, use double chunking to avoid OOM errors.
|
| 405 |
+
Defaults to True.
|
| 406 |
+
|
| 407 |
+
Returns:
|
| 408 |
+
frustum_contained (torch.Tensor): BxB batch of Booleans indiciating if frustums are inside
|
| 409 |
+
other frustums.
|
| 410 |
+
"""
|
| 411 |
+
B = frustums.shape[0]
|
| 412 |
+
if device is None:
|
| 413 |
+
device = frustums.device
|
| 414 |
+
|
| 415 |
+
if use_double_chunking:
|
| 416 |
+
frustum_contained = torch.zeros((B, B), dtype=torch.bool, device=device)
|
| 417 |
+
# Check if frustums are containing each other by processing in chunks
|
| 418 |
+
for i in tqdm(range(0, B, chunk_size), desc="Checking frustum containment"):
|
| 419 |
+
i_end = min(i + chunk_size, B)
|
| 420 |
+
chunk_i_size = i_end - i
|
| 421 |
+
|
| 422 |
+
for j in range(0, B, chunk_size):
|
| 423 |
+
j_end = min(j + chunk_size, B)
|
| 424 |
+
chunk_j_size = j_end - j
|
| 425 |
+
|
| 426 |
+
# Process a chunk of frustums against another chunk
|
| 427 |
+
frustums_i = frustums[i:i_end]
|
| 428 |
+
frustums_j_vertices = frustums[
|
| 429 |
+
j:j_end, :1
|
| 430 |
+
] # Just need one vertex to check containment
|
| 431 |
+
|
| 432 |
+
# Perform points in frustum check
|
| 433 |
+
contained = rearrange(
|
| 434 |
+
points_in_frustum_check(
|
| 435 |
+
repeat(frustums_i, "B ... -> (B B2) ...", B2=chunk_j_size),
|
| 436 |
+
repeat(
|
| 437 |
+
frustums_j_vertices, "B ... -> (B2 B) ...", B2=chunk_i_size
|
| 438 |
+
),
|
| 439 |
+
)[:, 0],
|
| 440 |
+
"(B B2) -> B B2",
|
| 441 |
+
B=chunk_i_size,
|
| 442 |
+
).to(device)
|
| 443 |
+
|
| 444 |
+
# Map results back to the full matrix
|
| 445 |
+
frustum_contained[i:i_end, j:j_end] |= contained
|
| 446 |
+
frustum_contained[j:j_end, i:i_end] |= contained.transpose(
|
| 447 |
+
0, 1
|
| 448 |
+
) # Symmetric relation
|
| 449 |
+
else:
|
| 450 |
+
# Perform points in frustum check with a single chunked loop
|
| 451 |
+
frustum_contained = rearrange(
|
| 452 |
+
points_in_frustum_check(
|
| 453 |
+
repeat(frustums, "B ... -> (B B2) ...", B2=B),
|
| 454 |
+
repeat(frustums[:, :1], "B ... -> (B2 B) ...", B2=B),
|
| 455 |
+
chunk_size=chunk_size,
|
| 456 |
+
)[:, 0],
|
| 457 |
+
"(B B2) -> B B2",
|
| 458 |
+
B=B,
|
| 459 |
+
).to(device)
|
| 460 |
+
frustum_contained = frustum_contained | frustum_contained.T
|
| 461 |
+
|
| 462 |
+
return frustum_contained
|
mapanything/utils/wai/io.py
ADDED
|
@@ -0,0 +1,1373 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
This utils script contains PORTAGE of wai-core io methods for MapAnything.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import gzip
|
| 6 |
+
import io
|
| 7 |
+
import json
|
| 8 |
+
import logging
|
| 9 |
+
import os
|
| 10 |
+
from datetime import datetime
|
| 11 |
+
from pathlib import Path
|
| 12 |
+
from typing import Any, Callable, cast, IO, Literal, overload
|
| 13 |
+
|
| 14 |
+
os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
|
| 15 |
+
import cv2
|
| 16 |
+
import numpy as np
|
| 17 |
+
import torch
|
| 18 |
+
import trimesh
|
| 19 |
+
import yaml
|
| 20 |
+
from PIL import Image, PngImagePlugin
|
| 21 |
+
from plyfile import PlyData, PlyElement
|
| 22 |
+
from safetensors.torch import load_file as load_sft, save_file as save_sft
|
| 23 |
+
from torchvision.io import decode_image
|
| 24 |
+
from yaml import CLoader
|
| 25 |
+
|
| 26 |
+
from mapanything.utils.wai.ops import (
|
| 27 |
+
to_numpy,
|
| 28 |
+
)
|
| 29 |
+
from mapanything.utils.wai.semantics import (
|
| 30 |
+
apply_id_to_color_mapping,
|
| 31 |
+
INVALID_ID,
|
| 32 |
+
load_semantic_color_mapping,
|
| 33 |
+
)
|
| 34 |
+
|
| 35 |
+
# Try to use orjson for faster JSON processing
|
| 36 |
+
try:
|
| 37 |
+
import orjson
|
| 38 |
+
except ImportError:
|
| 39 |
+
orjson = None
|
| 40 |
+
|
| 41 |
+
logger = logging.getLogger(__name__)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
@overload
|
| 45 |
+
def _load_readable(
|
| 46 |
+
fname: Path | str, load_as_string: Literal[True], **kwargs
|
| 47 |
+
) -> str: ...
|
| 48 |
+
@overload
|
| 49 |
+
def _load_readable(
|
| 50 |
+
fname: Path | str, load_as_string: Literal[False] = False, **kwargs
|
| 51 |
+
) -> dict: ...
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def _load_readable(
|
| 55 |
+
fname: Path | str,
|
| 56 |
+
load_as_string: bool = False,
|
| 57 |
+
**kwargs,
|
| 58 |
+
) -> Any | str:
|
| 59 |
+
"""
|
| 60 |
+
Loads data from a human-readable file and will try to parse JSON or YAML files as a dict, list,
|
| 61 |
+
int, float, str, bool, or None object. Can optionally return the file contents as a string.
|
| 62 |
+
|
| 63 |
+
Args:
|
| 64 |
+
fname (str or Path): The filename to load data from.
|
| 65 |
+
load_as_string (bool, optional): Whether to return the loaded data as a string.
|
| 66 |
+
Defaults to False.
|
| 67 |
+
|
| 68 |
+
Returns:
|
| 69 |
+
The loaded data, which can be any type of object that can be represented in JSON or YAML.
|
| 70 |
+
|
| 71 |
+
Raises:
|
| 72 |
+
NotImplementedError: If the file suffix is not supported (i.e., not .json, .yaml, or .yml).
|
| 73 |
+
"""
|
| 74 |
+
if load_as_string:
|
| 75 |
+
return _load_readable_string(fname, **kwargs)
|
| 76 |
+
else:
|
| 77 |
+
return _load_readable_structured(fname, **kwargs)
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def _load_readable_structured(
|
| 81 |
+
fname: Path | str,
|
| 82 |
+
**kwargs,
|
| 83 |
+
) -> Any:
|
| 84 |
+
"""
|
| 85 |
+
Loads data from a human-readable file and will try to parse JSON or YAML files as a dict, list,
|
| 86 |
+
int, float, str, bool, or None object.
|
| 87 |
+
|
| 88 |
+
Args:
|
| 89 |
+
fname (str or Path): The filename to load data from.
|
| 90 |
+
|
| 91 |
+
Returns:
|
| 92 |
+
The loaded data, which can be any type of object that can be represented in JSON or YAML.
|
| 93 |
+
|
| 94 |
+
Raises:
|
| 95 |
+
NotImplementedError: If the file suffix is not supported (i.e., not .json, .yaml, or .yml).
|
| 96 |
+
"""
|
| 97 |
+
fname = Path(fname)
|
| 98 |
+
if not fname.exists():
|
| 99 |
+
raise FileNotFoundError(f"File does not exist: {fname}")
|
| 100 |
+
|
| 101 |
+
if fname.suffix == ".json":
|
| 102 |
+
# Use binary mode for JSON files
|
| 103 |
+
with open(fname, mode="rb") as f:
|
| 104 |
+
# Use orjson if available, otherwise use standard JSON
|
| 105 |
+
if orjson:
|
| 106 |
+
return orjson.loads(f.read())
|
| 107 |
+
return json.load(f)
|
| 108 |
+
|
| 109 |
+
if fname.suffix in [".yaml", ".yml"]:
|
| 110 |
+
# Use text mode with UTF-8 encoding for YAML files
|
| 111 |
+
with open(fname, mode="r", encoding="utf-8") as f:
|
| 112 |
+
return yaml.load(f, Loader=CLoader)
|
| 113 |
+
|
| 114 |
+
raise NotImplementedError(f"Readable format not supported: {fname.suffix}")
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
def _load_readable_string(
|
| 118 |
+
fname: Path | str,
|
| 119 |
+
**kwargs,
|
| 120 |
+
) -> str:
|
| 121 |
+
"""
|
| 122 |
+
Loads data from a human-readable file as a string.
|
| 123 |
+
|
| 124 |
+
Args:
|
| 125 |
+
fname (str or Path): The filename to load data from.
|
| 126 |
+
|
| 127 |
+
Returns:
|
| 128 |
+
The file's contents, as a string.
|
| 129 |
+
"""
|
| 130 |
+
fname = Path(fname)
|
| 131 |
+
if not fname.exists():
|
| 132 |
+
raise FileNotFoundError(f"File does not exist: {fname}")
|
| 133 |
+
|
| 134 |
+
with open(fname, mode="r", encoding="utf-8") as f:
|
| 135 |
+
contents = f.read()
|
| 136 |
+
|
| 137 |
+
return contents
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
def _store_readable(
|
| 141 |
+
fname: Path | str,
|
| 142 |
+
data: Any,
|
| 143 |
+
**kwargs,
|
| 144 |
+
) -> int:
|
| 145 |
+
"""
|
| 146 |
+
Stores data in a human-readable file (JSON or YAML).
|
| 147 |
+
|
| 148 |
+
Args:
|
| 149 |
+
fname (str or Path): The filename to store data in.
|
| 150 |
+
data: The data to store, which can be any type of object that can be represented in JSON or YAML.
|
| 151 |
+
|
| 152 |
+
Returns:
|
| 153 |
+
The number of bytes written to the file.
|
| 154 |
+
|
| 155 |
+
Raises:
|
| 156 |
+
NotImplementedError: If the file suffix is not supported (i.e., not .json, .yaml, or .yml).
|
| 157 |
+
"""
|
| 158 |
+
fname = Path(fname)
|
| 159 |
+
|
| 160 |
+
# Create parent directory if it doesn't exist
|
| 161 |
+
os.makedirs(fname.parent, exist_ok=True)
|
| 162 |
+
|
| 163 |
+
if fname.suffix == ".json":
|
| 164 |
+
if orjson:
|
| 165 |
+
# Define the operation for orjson
|
| 166 |
+
with open(fname, mode="wb") as f:
|
| 167 |
+
return f.write(orjson.dumps(data, option=orjson.OPT_INDENT_2))
|
| 168 |
+
else:
|
| 169 |
+
# Define the operation for standard json
|
| 170 |
+
with open(fname, mode="w", encoding="utf-8") as f:
|
| 171 |
+
json.dump(data, f, indent=2)
|
| 172 |
+
return f.tell()
|
| 173 |
+
|
| 174 |
+
elif fname.suffix in [".yaml", ".yml"]:
|
| 175 |
+
# Define the operation for YAML files
|
| 176 |
+
with open(fname, mode="w", encoding="utf-8") as f:
|
| 177 |
+
yaml.dump(data, f)
|
| 178 |
+
return f.tell()
|
| 179 |
+
else:
|
| 180 |
+
raise NotImplementedError(f"Writable format not supported: {fname.suffix}")
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
def get_processing_state(scene_root: Path | str) -> dict:
|
| 184 |
+
"""
|
| 185 |
+
Retrieves the processing state of a scene.
|
| 186 |
+
|
| 187 |
+
Args:
|
| 188 |
+
scene_root (Path or str): The root directory of the scene.
|
| 189 |
+
|
| 190 |
+
Returns:
|
| 191 |
+
dict: A dictionary containing the processing state of the scene.
|
| 192 |
+
If no processing log exists, or reading it fails, an empty
|
| 193 |
+
dictionary is returned.
|
| 194 |
+
"""
|
| 195 |
+
process_log_path = Path(scene_root) / "_process_log.json"
|
| 196 |
+
|
| 197 |
+
try:
|
| 198 |
+
return _load_readable_structured(process_log_path)
|
| 199 |
+
except FileNotFoundError:
|
| 200 |
+
logger.debug(f"Log file not found, returning empty dict: {process_log_path}")
|
| 201 |
+
return {}
|
| 202 |
+
except Exception:
|
| 203 |
+
logger.error(
|
| 204 |
+
f"Could not parse, returning empty dict: {process_log_path}", exc_info=True
|
| 205 |
+
)
|
| 206 |
+
return {}
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
def _write_exr(
|
| 210 |
+
fname: str | Path,
|
| 211 |
+
data: np.ndarray | torch.Tensor,
|
| 212 |
+
params: list | None = None,
|
| 213 |
+
**kwargs,
|
| 214 |
+
) -> bool:
|
| 215 |
+
"""
|
| 216 |
+
Writes an image as an EXR file using OpenCV.
|
| 217 |
+
|
| 218 |
+
Args:
|
| 219 |
+
fname (str or Path): The filename to save the image to.
|
| 220 |
+
data (numpy.ndarray, torch.Tensor): The image data to save. Must be a 2D or 3D array.
|
| 221 |
+
params (list, optional): A list of parameters to pass to OpenCV's imwrite function.
|
| 222 |
+
Defaults to None, which uses 32-bit with zip compression.
|
| 223 |
+
|
| 224 |
+
Returns:
|
| 225 |
+
bool: True if the image was saved successfully, False otherwise.
|
| 226 |
+
|
| 227 |
+
Raises:
|
| 228 |
+
ValueError: If the input data has less than two or more than three dimensions.
|
| 229 |
+
|
| 230 |
+
Notes:
|
| 231 |
+
Only 32-bit float (CV_32F) images can be saved.
|
| 232 |
+
For comparison of different compression methods, see P1732924327.
|
| 233 |
+
"""
|
| 234 |
+
if Path(fname).suffix != ".exr":
|
| 235 |
+
raise ValueError(
|
| 236 |
+
f"Only filenames with suffix .exr allowed but received: {fname}"
|
| 237 |
+
)
|
| 238 |
+
|
| 239 |
+
## Note: only 32-bit float (CV_32F) images can be saved
|
| 240 |
+
data_np = to_numpy(data, dtype=np.float32)
|
| 241 |
+
if (data_np.ndim > 3) or (data_np.ndim < 2):
|
| 242 |
+
raise ValueError(
|
| 243 |
+
f"Image needs to contain two or three dims but received: {data_np.shape}"
|
| 244 |
+
)
|
| 245 |
+
|
| 246 |
+
return cv2.imwrite(str(fname), data_np, params if params else [])
|
| 247 |
+
|
| 248 |
+
|
| 249 |
+
@overload
|
| 250 |
+
def _read_exr(fname: str | Path, fmt: Literal["np"], **kwargs) -> np.ndarray: ...
|
| 251 |
+
@overload
|
| 252 |
+
def _read_exr(fname: str | Path, fmt: Literal["PIL"], **kwargs) -> Image.Image: ...
|
| 253 |
+
@overload
|
| 254 |
+
def _read_exr(
|
| 255 |
+
fname: str | Path, fmt: Literal["torch"] = "torch", **kwargs
|
| 256 |
+
) -> torch.Tensor: ...
|
| 257 |
+
|
| 258 |
+
|
| 259 |
+
def _read_exr(
|
| 260 |
+
fname: str | Path, fmt: Literal["np", "PIL", "torch"] = "torch", **kwargs
|
| 261 |
+
) -> np.ndarray | torch.Tensor | Image.Image:
|
| 262 |
+
"""
|
| 263 |
+
Reads an EXR image file using OpenCV.
|
| 264 |
+
|
| 265 |
+
Args:
|
| 266 |
+
fname (str or Path): The filename of the EXR image to read.
|
| 267 |
+
fmt (str): The format of the output data. Can be one of:
|
| 268 |
+
- "torch": Returns a PyTorch tensor.
|
| 269 |
+
- "np": Returns a NumPy array.
|
| 270 |
+
- "PIL": Returns a PIL Image object.
|
| 271 |
+
Defaults to "torch".
|
| 272 |
+
|
| 273 |
+
Returns:
|
| 274 |
+
The EXR image data in the specified output format.
|
| 275 |
+
|
| 276 |
+
Raises:
|
| 277 |
+
NotImplementedError: If the specified output format is not supported.
|
| 278 |
+
ValueError: If data shape is not supported, e.g. multi-channel PIL float images.
|
| 279 |
+
|
| 280 |
+
Notes:
|
| 281 |
+
The EXR image is read in its original format, without any conversion or rescaling.
|
| 282 |
+
"""
|
| 283 |
+
data = cv2.imread(str(fname), cv2.IMREAD_UNCHANGED)
|
| 284 |
+
if data is None:
|
| 285 |
+
raise FileNotFoundError(f"Failed to read EXR file: {fname}")
|
| 286 |
+
if fmt == "torch":
|
| 287 |
+
# Convert to PyTorch tensor with float32 dtype
|
| 288 |
+
data = torch.from_numpy(data).float()
|
| 289 |
+
elif fmt == "np":
|
| 290 |
+
# Convert to NumPy array with float32 dtype
|
| 291 |
+
data = np.array(data, dtype=np.float32)
|
| 292 |
+
elif fmt == "PIL":
|
| 293 |
+
if data.ndim != 2:
|
| 294 |
+
raise ValueError("PIL does not support multi-channel EXR images")
|
| 295 |
+
|
| 296 |
+
# Convert to PIL Image object
|
| 297 |
+
data = Image.fromarray(data)
|
| 298 |
+
else:
|
| 299 |
+
raise NotImplementedError(f"fmt not supported: {fmt}")
|
| 300 |
+
return data
|
| 301 |
+
|
| 302 |
+
|
| 303 |
+
@overload
|
| 304 |
+
def _load_image(
|
| 305 |
+
fname: str | Path,
|
| 306 |
+
fmt: Literal["np"],
|
| 307 |
+
resize: tuple[int, int] | None = None,
|
| 308 |
+
**kwargs,
|
| 309 |
+
) -> np.ndarray: ...
|
| 310 |
+
@overload
|
| 311 |
+
def _load_image(
|
| 312 |
+
fname: str | Path,
|
| 313 |
+
fmt: Literal["pil"],
|
| 314 |
+
resize: tuple[int, int] | None = None,
|
| 315 |
+
**kwargs,
|
| 316 |
+
) -> Image.Image: ...
|
| 317 |
+
@overload
|
| 318 |
+
def _load_image(
|
| 319 |
+
fname: str | Path,
|
| 320 |
+
fmt: Literal["torch"] = "torch",
|
| 321 |
+
resize: tuple[int, int] | None = None,
|
| 322 |
+
**kwargs,
|
| 323 |
+
) -> torch.Tensor: ...
|
| 324 |
+
|
| 325 |
+
|
| 326 |
+
def _load_image(
|
| 327 |
+
fname: str | Path,
|
| 328 |
+
fmt: Literal["np", "pil", "torch"] = "torch",
|
| 329 |
+
resize: tuple[int, int] | None = None,
|
| 330 |
+
**kwargs,
|
| 331 |
+
) -> np.ndarray | torch.Tensor | Image.Image:
|
| 332 |
+
"""
|
| 333 |
+
Loads an image from a file.
|
| 334 |
+
|
| 335 |
+
Args:
|
| 336 |
+
fname (str or Path): The filename to load the image from.
|
| 337 |
+
fmt (str): The format of the output data. Can be one of:
|
| 338 |
+
- "torch": Returns a PyTorch tensor with shape (C, H, W).
|
| 339 |
+
- "np": Returns a NumPy array with shape (H, W, C).
|
| 340 |
+
- "pil": Returns a PIL Image object.
|
| 341 |
+
Defaults to "torch".
|
| 342 |
+
resize (tuple, optional): A tuple of two integers representing the desired width and height of the image.
|
| 343 |
+
If None, the image is not resized. Defaults to None.
|
| 344 |
+
|
| 345 |
+
Returns:
|
| 346 |
+
The loaded image in the specified output format.
|
| 347 |
+
|
| 348 |
+
Raises:
|
| 349 |
+
NotImplementedError: If the specified output format is not supported.
|
| 350 |
+
|
| 351 |
+
Notes:
|
| 352 |
+
This function loads non-binary images in RGB mode and normalizes pixel values to the range [0, 1].
|
| 353 |
+
"""
|
| 354 |
+
|
| 355 |
+
# Fastest way to load into torch tensor
|
| 356 |
+
if resize is None and fmt == "torch":
|
| 357 |
+
return decode_image(str(fname)).float() / 255.0
|
| 358 |
+
|
| 359 |
+
# Load using PIL
|
| 360 |
+
with open(fname, "rb") as f:
|
| 361 |
+
pil_image = Image.open(f)
|
| 362 |
+
pil_image.load()
|
| 363 |
+
|
| 364 |
+
if pil_image.mode not in ["RGB", "RGBA"]:
|
| 365 |
+
raise OSError(
|
| 366 |
+
f"Expected a RGB or RGBA image in {fname}, but instead found an image with mode {pil_image.mode}"
|
| 367 |
+
)
|
| 368 |
+
|
| 369 |
+
if resize is not None:
|
| 370 |
+
pil_image = pil_image.resize(resize)
|
| 371 |
+
|
| 372 |
+
if fmt == "torch":
|
| 373 |
+
return (
|
| 374 |
+
torch.from_numpy(np.array(pil_image)).permute(2, 0, 1).float() / 255.0
|
| 375 |
+
)
|
| 376 |
+
elif fmt == "np":
|
| 377 |
+
return np.array(pil_image, dtype=np.float32) / 255.0
|
| 378 |
+
elif fmt == "pil":
|
| 379 |
+
return pil_image
|
| 380 |
+
else:
|
| 381 |
+
raise NotImplementedError(f"Image format not supported: {fmt}")
|
| 382 |
+
|
| 383 |
+
|
| 384 |
+
def _store_image(
|
| 385 |
+
fname: str | Path, img_data: np.ndarray | torch.Tensor | Image.Image, **kwargs
|
| 386 |
+
) -> None:
|
| 387 |
+
"""
|
| 388 |
+
Stores an image in a file.
|
| 389 |
+
|
| 390 |
+
Args:
|
| 391 |
+
fname (str or Path): The filename to store the image in.
|
| 392 |
+
img_data (numpy.ndarray, torch.tensor or PIL.Image.Image): The image data to store.
|
| 393 |
+
|
| 394 |
+
Notes (for numpy.ndarray or torch.tensor inputs):
|
| 395 |
+
This function assumes that the input image data is in the range [0, 1], and has shape
|
| 396 |
+
(H, W, C), or (C, H, W) for PyTorch tensors, with C being 3 or 4.
|
| 397 |
+
It converts the image data to uint8 format and saves it as a compressed image file.
|
| 398 |
+
"""
|
| 399 |
+
if isinstance(img_data, torch.Tensor):
|
| 400 |
+
if img_data.ndim != 3:
|
| 401 |
+
raise ValueError(f"Tensor needs to be 3D but received: {img_data.shape=}")
|
| 402 |
+
|
| 403 |
+
if img_data.shape[0] in [3, 4]:
|
| 404 |
+
# Convert to HWC format expected by pillow `Image.save` below
|
| 405 |
+
img_data = img_data.permute(1, 2, 0)
|
| 406 |
+
|
| 407 |
+
img_data = img_data.contiguous()
|
| 408 |
+
|
| 409 |
+
if isinstance(img_data, (np.ndarray, torch.Tensor)):
|
| 410 |
+
if img_data.shape[-1] not in [3, 4]:
|
| 411 |
+
raise ValueError(
|
| 412 |
+
f"Image must have 3 or 4 channels, but received: {img_data.shape=}"
|
| 413 |
+
)
|
| 414 |
+
|
| 415 |
+
img_data_np = to_numpy(img_data, dtype=np.float32)
|
| 416 |
+
img_data = Image.fromarray((255 * img_data_np).round().astype(np.uint8))
|
| 417 |
+
|
| 418 |
+
with open(fname, "wb") as f:
|
| 419 |
+
pil_kwargs = {
|
| 420 |
+
# Make PNGs faster to save using minimal compression
|
| 421 |
+
"optimize": False,
|
| 422 |
+
"compress_level": 1,
|
| 423 |
+
# Higher JPEG image quality
|
| 424 |
+
"quality": "high",
|
| 425 |
+
}
|
| 426 |
+
pil_kwargs.update(kwargs)
|
| 427 |
+
img_data.save(cast(IO[bytes], f), **pil_kwargs)
|
| 428 |
+
|
| 429 |
+
|
| 430 |
+
def _load_binary_mask(
|
| 431 |
+
fname: str | Path,
|
| 432 |
+
fmt: str = "torch",
|
| 433 |
+
resize: tuple[int, int] | None = None,
|
| 434 |
+
**kwargs,
|
| 435 |
+
) -> np.ndarray | torch.Tensor | Image.Image:
|
| 436 |
+
"""
|
| 437 |
+
Loads a binary image from a file.
|
| 438 |
+
|
| 439 |
+
Args:
|
| 440 |
+
fname (str or Path): The filename to load the binary image from.
|
| 441 |
+
fmt (str): The format of the output data. Can be one of:
|
| 442 |
+
- "torch": Returns a PyTorch Boolean tensor with shape H x W.
|
| 443 |
+
- "np": Returns a NumPy Boolean array with shape H x W.
|
| 444 |
+
- "pil": Returns a PIL Image object.
|
| 445 |
+
Defaults to "torch".
|
| 446 |
+
resize (tuple, optional): A tuple of two integers representing the desired width and height of the binary image.
|
| 447 |
+
If None, the image is not resized. Defaults to None.
|
| 448 |
+
|
| 449 |
+
Returns:
|
| 450 |
+
The loaded binary image in the specified output format.
|
| 451 |
+
|
| 452 |
+
Raises:
|
| 453 |
+
NotImplementedError: If the specified output format is not supported.
|
| 454 |
+
"""
|
| 455 |
+
if fmt not in ["pil", "np", "torch"]:
|
| 456 |
+
raise NotImplementedError(f"Image format not supported: {fmt}")
|
| 457 |
+
|
| 458 |
+
with open(fname, "rb") as f:
|
| 459 |
+
pil_image = Image.open(f)
|
| 460 |
+
pil_image.load()
|
| 461 |
+
|
| 462 |
+
if pil_image.mode == "L":
|
| 463 |
+
pil_image = pil_image.convert("1")
|
| 464 |
+
|
| 465 |
+
elif pil_image.mode != "1":
|
| 466 |
+
raise OSError(
|
| 467 |
+
f"Expected a binary or grayscale image in {fname}, but instead found an image with mode {pil_image.mode}"
|
| 468 |
+
)
|
| 469 |
+
|
| 470 |
+
if resize is not None:
|
| 471 |
+
pil_image = pil_image.resize(resize)
|
| 472 |
+
|
| 473 |
+
if fmt == "pil":
|
| 474 |
+
return pil_image
|
| 475 |
+
|
| 476 |
+
mask = np.array(pil_image, copy=True)
|
| 477 |
+
return mask if fmt == "np" else torch.from_numpy(mask)
|
| 478 |
+
|
| 479 |
+
|
| 480 |
+
def _store_binary_mask(
|
| 481 |
+
fname: str | Path, img_data: np.ndarray | torch.Tensor | Image.Image, **kwargs
|
| 482 |
+
) -> None:
|
| 483 |
+
"""
|
| 484 |
+
Stores a binary image in a compressed image file.
|
| 485 |
+
|
| 486 |
+
Args:
|
| 487 |
+
fname (str or Path): The filename to store the binary image in.
|
| 488 |
+
img_data (numpy.ndarray, torch.tensor or PIL.Image.Image): The binary image data to store.
|
| 489 |
+
"""
|
| 490 |
+
if isinstance(img_data, Image.Image):
|
| 491 |
+
if img_data.mode not in ["1", "L"]:
|
| 492 |
+
raise RuntimeError(
|
| 493 |
+
f'Expected a PIL image with mode "1" or "L", but instead got a PIL image with mode {img_data.mode}'
|
| 494 |
+
)
|
| 495 |
+
elif isinstance(img_data, np.ndarray) or isinstance(img_data, torch.Tensor):
|
| 496 |
+
if len(img_data.squeeze().shape) != 2:
|
| 497 |
+
raise RuntimeError(
|
| 498 |
+
f"Expected a PyTorch tensor or NumPy array with shape (H, W, 1), (1, H, W) or (H, W), but the shape is {img_data.shape}"
|
| 499 |
+
)
|
| 500 |
+
img_data = img_data.squeeze()
|
| 501 |
+
else:
|
| 502 |
+
raise NotImplementedError(f"Input format not supported: {type(img_data)}")
|
| 503 |
+
|
| 504 |
+
if not isinstance(img_data, Image.Image):
|
| 505 |
+
img_data = to_numpy(img_data, dtype=bool)
|
| 506 |
+
img_data = Image.fromarray(img_data)
|
| 507 |
+
|
| 508 |
+
img_data = img_data.convert("1")
|
| 509 |
+
with open(fname, "wb") as f:
|
| 510 |
+
img_data.save(f, compress_level=1, optimize=False)
|
| 511 |
+
|
| 512 |
+
|
| 513 |
+
def _load_sft(
|
| 514 |
+
fname: str | Path,
|
| 515 |
+
fmt: str = "torch",
|
| 516 |
+
**kwargs,
|
| 517 |
+
) -> torch.Tensor:
|
| 518 |
+
"""
|
| 519 |
+
Loads a tensor from a safetensor file.
|
| 520 |
+
|
| 521 |
+
Args:
|
| 522 |
+
fname (str | Path): The filename of the safetensor file to load.
|
| 523 |
+
fmt (str, optional): The format of the output data. Currently only "torch" is supported.
|
| 524 |
+
**kwargs: Additional keyword arguments (unused).
|
| 525 |
+
|
| 526 |
+
Returns:
|
| 527 |
+
torch.Tensor: The loaded tensor.
|
| 528 |
+
|
| 529 |
+
Raises:
|
| 530 |
+
AssertionError: If the file extension is not .sft or if fmt is not "torch".
|
| 531 |
+
"""
|
| 532 |
+
assert Path(fname).suffix == ".sft", "Only .sft (safetensor) is supported"
|
| 533 |
+
assert fmt == "torch", "Only torch format is supported for latent"
|
| 534 |
+
out = load_sft(str(fname))
|
| 535 |
+
return out["latent"]
|
| 536 |
+
|
| 537 |
+
|
| 538 |
+
def _store_sft(fname: str | Path, data: torch.Tensor, **kwargs) -> None:
|
| 539 |
+
"""
|
| 540 |
+
Stores a tensor to a safetensor file.
|
| 541 |
+
|
| 542 |
+
Args:
|
| 543 |
+
fname (str | Path): The filename to store the latent in.
|
| 544 |
+
data (torch.Tensor): The latent tensor to store.
|
| 545 |
+
**kwargs: Additional keyword arguments (unused).
|
| 546 |
+
|
| 547 |
+
Raises:
|
| 548 |
+
AssertionError: If the file extension is not .sft or if data is not a torch.Tensor.
|
| 549 |
+
"""
|
| 550 |
+
assert Path(fname).suffix == ".sft", "Only .sft (safetensor) is supported"
|
| 551 |
+
assert isinstance(data, torch.Tensor)
|
| 552 |
+
save_sft(tensors={"latent": data}, filename=str(fname))
|
| 553 |
+
|
| 554 |
+
|
| 555 |
+
def _store_depth(fname: str | Path, data: np.ndarray | torch.Tensor, **kwargs) -> bool:
|
| 556 |
+
"""
|
| 557 |
+
Stores a depth map in an EXR file.
|
| 558 |
+
|
| 559 |
+
Args:
|
| 560 |
+
fname (str or Path): The filename to save the depth map to.
|
| 561 |
+
data (numpy.ndarray, torch.tensor): The depth map to save.
|
| 562 |
+
|
| 563 |
+
Returns:
|
| 564 |
+
bool: True if the depth map was saved successfully, False otherwise.
|
| 565 |
+
|
| 566 |
+
Raises:
|
| 567 |
+
ValueError: If the input data does not have two dimensions after removing singleton dimensions.
|
| 568 |
+
"""
|
| 569 |
+
data_np = to_numpy(data, dtype=np.float32)
|
| 570 |
+
data_np = data_np.squeeze() # remove all 1-dim entries
|
| 571 |
+
if data_np.ndim != 2:
|
| 572 |
+
raise ValueError(f"Depth image needs to be 2d, but received: {data_np.shape}")
|
| 573 |
+
|
| 574 |
+
if "params" in kwargs:
|
| 575 |
+
params = kwargs["params"]
|
| 576 |
+
else:
|
| 577 |
+
# use 16-bit with zip compression for depth maps
|
| 578 |
+
params = [
|
| 579 |
+
cv2.IMWRITE_EXR_TYPE,
|
| 580 |
+
cv2.IMWRITE_EXR_TYPE_HALF,
|
| 581 |
+
cv2.IMWRITE_EXR_COMPRESSION,
|
| 582 |
+
cv2.IMWRITE_EXR_COMPRESSION_ZIP,
|
| 583 |
+
]
|
| 584 |
+
|
| 585 |
+
return _write_exr(fname, data_np, params=params)
|
| 586 |
+
|
| 587 |
+
|
| 588 |
+
def _load_depth(
|
| 589 |
+
fname: str | Path, fmt: str = "torch", **kwargs
|
| 590 |
+
) -> np.ndarray | torch.Tensor | Image.Image:
|
| 591 |
+
"""
|
| 592 |
+
Loads a depth image from an EXR file.
|
| 593 |
+
|
| 594 |
+
Args:
|
| 595 |
+
fname (str or Path): The filename of the EXR file to load.
|
| 596 |
+
fmt (str): The format of the output data. Can be one of:
|
| 597 |
+
- "torch": Returns a PyTorch tensor.
|
| 598 |
+
- "np": Returns a NumPy array.
|
| 599 |
+
- "PIL": Returns a PIL Image object.
|
| 600 |
+
Defaults to "torch".
|
| 601 |
+
|
| 602 |
+
Returns:
|
| 603 |
+
The loaded depth image in the specified output format.
|
| 604 |
+
|
| 605 |
+
Raises:
|
| 606 |
+
ValueError: If the loaded depth image does not have two dimensions.
|
| 607 |
+
|
| 608 |
+
Notes:
|
| 609 |
+
This function assumes that the EXR file contains a single-channel depth image.
|
| 610 |
+
"""
|
| 611 |
+
data = _read_exr(fname, fmt)
|
| 612 |
+
if (fmt != "PIL") and (data.ndim != 2):
|
| 613 |
+
raise ValueError(f"Depth image needs to be 2D, but loaded: {data.shape}")
|
| 614 |
+
return data
|
| 615 |
+
|
| 616 |
+
|
| 617 |
+
def _store_normals(
|
| 618 |
+
fname: str | Path, data: np.ndarray | torch.Tensor, **kwargs
|
| 619 |
+
) -> bool:
|
| 620 |
+
"""
|
| 621 |
+
Stores a normals image in an EXR file.
|
| 622 |
+
|
| 623 |
+
Args:
|
| 624 |
+
fname (str or Path): The filename to save the normals image to.
|
| 625 |
+
data (numpy.ndarray): The normals image data to save. Will be converted to a 32-bit float array.
|
| 626 |
+
|
| 627 |
+
Returns:
|
| 628 |
+
bool: True if the normals image was saved successfully, False otherwise.
|
| 629 |
+
|
| 630 |
+
Raises:
|
| 631 |
+
ValueError: If the input data has more than three dimensions after removing singleton dimensions.
|
| 632 |
+
ValueError: If the input data does not have exactly three channels.
|
| 633 |
+
ValueError: If the input data is not normalized (i.e., maximum absolute value exceeds 1).
|
| 634 |
+
|
| 635 |
+
Notes:
|
| 636 |
+
This function assumes that the input data is in HWC (height, width, channels) format.
|
| 637 |
+
If the input data is in CHW (channels, height, width) format, it will be automatically transposed to HWC.
|
| 638 |
+
"""
|
| 639 |
+
data_np = to_numpy(data, dtype=np.float32)
|
| 640 |
+
data_np = data_np.squeeze() # remove all singleton dimensions
|
| 641 |
+
|
| 642 |
+
if data_np.ndim != 3:
|
| 643 |
+
raise ValueError(
|
| 644 |
+
f"Normals image needs to be 3-dim but received: {data_np.shape}"
|
| 645 |
+
)
|
| 646 |
+
|
| 647 |
+
if (data_np.shape[0] == 3) and (data_np.shape[2] != 3):
|
| 648 |
+
# ensure HWC format
|
| 649 |
+
data_np = data_np.transpose(1, 2, 0)
|
| 650 |
+
|
| 651 |
+
if data_np.shape[2] != 3:
|
| 652 |
+
raise ValueError(
|
| 653 |
+
f"Normals image needs have 3 channels but received: {data_np.shape}"
|
| 654 |
+
)
|
| 655 |
+
|
| 656 |
+
# We want to check that the norm values are either 1 (valid) or 0 (invalid values are 0s)
|
| 657 |
+
norm = np.linalg.norm(data_np, axis=-1)
|
| 658 |
+
is_one = np.isclose(norm, 1.0, atol=1e-3)
|
| 659 |
+
is_zero = np.isclose(norm, 0.0)
|
| 660 |
+
if not np.all([is_one | is_zero]):
|
| 661 |
+
raise ValueError("Normals image must be normalized")
|
| 662 |
+
|
| 663 |
+
return _write_exr(fname, data_np)
|
| 664 |
+
|
| 665 |
+
|
| 666 |
+
def _load_normals(
|
| 667 |
+
fname: str | Path, fmt: str = "torch", **kwargs
|
| 668 |
+
) -> np.ndarray | torch.Tensor | Image.Image:
|
| 669 |
+
"""
|
| 670 |
+
Loads a normals image from an EXR file.
|
| 671 |
+
|
| 672 |
+
Args:
|
| 673 |
+
fname (str or Path): The filename of the EXR file to load.
|
| 674 |
+
fmt (str): The format of the output data. Can be one of:
|
| 675 |
+
- "torch": Returns a PyTorch tensor.
|
| 676 |
+
- "np": Returns a NumPy array.
|
| 677 |
+
- "PIL": Returns a PIL Image object.
|
| 678 |
+
Defaults to "torch".
|
| 679 |
+
|
| 680 |
+
Returns:
|
| 681 |
+
The loaded normals image in the specified output format.
|
| 682 |
+
|
| 683 |
+
Raises:
|
| 684 |
+
Warning: If the loaded normals image has more than two dimensions.
|
| 685 |
+
|
| 686 |
+
Notes:
|
| 687 |
+
This function assumes that the EXR file contains a 3-channel normals image.
|
| 688 |
+
"""
|
| 689 |
+
data = _read_exr(fname, fmt)
|
| 690 |
+
|
| 691 |
+
if data.ndim != 3:
|
| 692 |
+
raise ValueError(f"Normals image needs to be 3-dim but received: {data.shape}")
|
| 693 |
+
|
| 694 |
+
if data.shape[2] != 3:
|
| 695 |
+
raise ValueError(
|
| 696 |
+
f"Normals image needs have 3 channels but received: {data.shape}"
|
| 697 |
+
)
|
| 698 |
+
|
| 699 |
+
return data
|
| 700 |
+
|
| 701 |
+
|
| 702 |
+
def _load_numpy(fname: str | Path, allow_pickle: bool = False, **kwargs) -> np.ndarray:
|
| 703 |
+
"""
|
| 704 |
+
Loads a NumPy array from a file.
|
| 705 |
+
|
| 706 |
+
Args:
|
| 707 |
+
fname (str or Path): The filename to load the NumPy array from.
|
| 708 |
+
allow_pickle (bool, optional): Whether to allow pickled objects in the NumPy file.
|
| 709 |
+
Defaults to False.
|
| 710 |
+
|
| 711 |
+
Returns:
|
| 712 |
+
numpy.ndarray: The loaded NumPy array.
|
| 713 |
+
|
| 714 |
+
Raises:
|
| 715 |
+
NotImplementedError: If the file suffix is not supported (i.e., not .npy or .npz).
|
| 716 |
+
|
| 717 |
+
Notes:
|
| 718 |
+
This function supports loading NumPy arrays from .npy and .npz files.
|
| 719 |
+
For .npz files, it assumes that the array is stored under the key "arr_0".
|
| 720 |
+
"""
|
| 721 |
+
fname = Path(fname)
|
| 722 |
+
with open(fname, "rb") as fid:
|
| 723 |
+
if fname.suffix == ".npy":
|
| 724 |
+
return np.load(fid, allow_pickle=allow_pickle)
|
| 725 |
+
elif fname.suffix == ".npz":
|
| 726 |
+
return np.load(fid, allow_pickle=allow_pickle).get("arr_0")
|
| 727 |
+
else:
|
| 728 |
+
raise NotImplementedError(f"Numpy format not supported: {fname.suffix}")
|
| 729 |
+
|
| 730 |
+
|
| 731 |
+
def _store_numpy(fname: str | Path, data: np.ndarray, **kwargs) -> None:
|
| 732 |
+
"""
|
| 733 |
+
Stores a NumPy array in a file.
|
| 734 |
+
|
| 735 |
+
Args:
|
| 736 |
+
fname (str or Path): The filename to store the NumPy array in.
|
| 737 |
+
data (numpy.ndarray): The NumPy array to store.
|
| 738 |
+
|
| 739 |
+
Raises:
|
| 740 |
+
NotImplementedError: If the file suffix is not supported (i.e., not .npy or .npz).
|
| 741 |
+
|
| 742 |
+
Notes:
|
| 743 |
+
This function supports storing NumPy arrays in .npy and .npz files.
|
| 744 |
+
For .npz files, it uses compression to reduce the file size.
|
| 745 |
+
"""
|
| 746 |
+
fname = Path(fname)
|
| 747 |
+
with open(fname, "wb") as fid:
|
| 748 |
+
if fname.suffix == ".npy":
|
| 749 |
+
np.save(fid, data)
|
| 750 |
+
elif fname.suffix == ".npz":
|
| 751 |
+
np.savez_compressed(fid, arr_0=data)
|
| 752 |
+
else:
|
| 753 |
+
raise NotImplementedError(f"Numpy format not supported: {fname.suffix}")
|
| 754 |
+
|
| 755 |
+
|
| 756 |
+
def _load_ptz(fname: str | Path, **kwargs) -> torch.Tensor:
|
| 757 |
+
"""
|
| 758 |
+
Loads a PyTorch tensor from a PTZ file.
|
| 759 |
+
|
| 760 |
+
Args:
|
| 761 |
+
fname (str or Path): The filename to load the tensor from.
|
| 762 |
+
|
| 763 |
+
Returns:
|
| 764 |
+
torch.Tensor: The loaded PyTorch tensor.
|
| 765 |
+
|
| 766 |
+
Notes:
|
| 767 |
+
This function assumes that the PTZ file contains a PyTorch tensor saved using `torch.save`.
|
| 768 |
+
If the tensor was saved in a different format, this function may fail.
|
| 769 |
+
"""
|
| 770 |
+
with open(fname, "rb") as fid:
|
| 771 |
+
data = gzip.decompress(fid.read())
|
| 772 |
+
## Note: if the following line fails, save PyTorch tensors in PTZ instead of NumPy
|
| 773 |
+
return torch.load(io.BytesIO(data), map_location="cpu", weights_only=True)
|
| 774 |
+
|
| 775 |
+
|
| 776 |
+
def _store_ptz(fname: str | Path, data: torch.Tensor, **kwargs) -> None:
|
| 777 |
+
"""
|
| 778 |
+
Stores a PyTorch tensor in a PTZ file.
|
| 779 |
+
|
| 780 |
+
Args:
|
| 781 |
+
fname (str or Path): The filename to store the tensor in.
|
| 782 |
+
data (torch.Tensor): The PyTorch tensor to store.
|
| 783 |
+
|
| 784 |
+
Notes:
|
| 785 |
+
This function saves the tensor using `torch.save` and compresses it using gzip.
|
| 786 |
+
"""
|
| 787 |
+
with open(fname, "wb") as fid:
|
| 788 |
+
with gzip.open(fid, "wb") as gfid:
|
| 789 |
+
torch.save(data, gfid)
|
| 790 |
+
|
| 791 |
+
|
| 792 |
+
def _store_mmap(fname: str | Path, data: np.ndarray | torch.Tensor, **kwargs) -> str:
|
| 793 |
+
"""
|
| 794 |
+
Stores matrix-shaped data in a memory-mapped file.
|
| 795 |
+
|
| 796 |
+
Args:
|
| 797 |
+
fname (str or Path): The filename to store the data in.
|
| 798 |
+
data (numpy.ndarray): The matrix-shaped data to store.
|
| 799 |
+
|
| 800 |
+
Returns:
|
| 801 |
+
str: The name of the stored memory-mapped file.
|
| 802 |
+
|
| 803 |
+
Notes:
|
| 804 |
+
This function stores the data in a .npy file with a modified filename that includes the shape of the data.
|
| 805 |
+
The data is converted to float32 format before storing.
|
| 806 |
+
"""
|
| 807 |
+
fname = Path(fname)
|
| 808 |
+
# add dimensions to the file name for loading
|
| 809 |
+
data_np = to_numpy(data, dtype=np.float32)
|
| 810 |
+
shape_string = "x".join([str(dim) for dim in data_np.shape])
|
| 811 |
+
mmap_name = f"{fname.stem}--{shape_string}.npy"
|
| 812 |
+
with open(fname.parent / mmap_name, "wb") as fid:
|
| 813 |
+
np.save(fid, data_np)
|
| 814 |
+
return mmap_name
|
| 815 |
+
|
| 816 |
+
|
| 817 |
+
def _load_mmap(fname: str | Path, **kwargs) -> np.memmap:
|
| 818 |
+
"""
|
| 819 |
+
Loads matrix-shaped data from a memory-mapped file.
|
| 820 |
+
|
| 821 |
+
Args:
|
| 822 |
+
fname (str or Path): The filename of the memory-mapped file to load.
|
| 823 |
+
|
| 824 |
+
Returns:
|
| 825 |
+
numpy.memmap: A memory-mapped array containing the loaded data.
|
| 826 |
+
|
| 827 |
+
Notes:
|
| 828 |
+
This function assumes that the filename contains the shape of the data, separated by 'x' or ','.
|
| 829 |
+
It uses this information to create a memory-mapped array with the correct shape.
|
| 830 |
+
"""
|
| 831 |
+
shape_string = Path(Path(fname).name.split("--")[1]).stem
|
| 832 |
+
shape = [int(dim) for dim in shape_string.replace(",", "x").split("x")]
|
| 833 |
+
with open(fname, "rb") as fid:
|
| 834 |
+
return np.memmap(fid, dtype=np.float32, mode="r", shape=shape, offset=128)
|
| 835 |
+
|
| 836 |
+
|
| 837 |
+
def _store_scene_meta(fname: Path | str, scene_meta: dict[str, Any], **kwargs) -> None:
|
| 838 |
+
"""
|
| 839 |
+
Stores scene metadata in a readable file.
|
| 840 |
+
|
| 841 |
+
Args:
|
| 842 |
+
fname (str or Path): The filename to store the scene metadata in.
|
| 843 |
+
scene_meta (dict): The scene metadata to store.
|
| 844 |
+
|
| 845 |
+
Notes:
|
| 846 |
+
This function updates the "last_modified" field of the scene metadata to the current date and time before storing it.
|
| 847 |
+
It also removes the "frame_names" field from the scene metadata, as it is not necessary to store this information.
|
| 848 |
+
Creates a backup of the existing file before overwriting it.
|
| 849 |
+
"""
|
| 850 |
+
# update the modified date
|
| 851 |
+
scene_meta["last_modified"] = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
| 852 |
+
if "frame_names" in scene_meta:
|
| 853 |
+
del scene_meta["frame_names"]
|
| 854 |
+
|
| 855 |
+
# create/overwrite backup
|
| 856 |
+
fname_path = Path(fname)
|
| 857 |
+
if fname_path.exists():
|
| 858 |
+
backup_fname = fname_path.parent / f"_{fname_path.stem}_backup.json"
|
| 859 |
+
if backup_fname.exists():
|
| 860 |
+
backup_fname.unlink()
|
| 861 |
+
fname_path.rename(backup_fname)
|
| 862 |
+
|
| 863 |
+
_store_readable(fname, scene_meta)
|
| 864 |
+
|
| 865 |
+
|
| 866 |
+
def _load_scene_meta(fname: Path | str, **kwargs) -> dict[str, Any]:
|
| 867 |
+
"""
|
| 868 |
+
Loads scene metadata from a readable file.
|
| 869 |
+
|
| 870 |
+
Args:
|
| 871 |
+
fname (str or Path): The filename to load the scene metadata from.
|
| 872 |
+
|
| 873 |
+
Returns:
|
| 874 |
+
dict: The loaded scene metadata, including an additional "frame_names" field that maps frame names to their indices.
|
| 875 |
+
|
| 876 |
+
Notes:
|
| 877 |
+
This function creates the "frame_names" field in the scene metadata for efficient lookup of frame indices by name.
|
| 878 |
+
"""
|
| 879 |
+
scene_meta = _load_readable_structured(fname)
|
| 880 |
+
# create the frame_name -> frame_idx for efficiency
|
| 881 |
+
scene_meta["frame_names"] = {
|
| 882 |
+
frame["frame_name"]: frame_idx
|
| 883 |
+
for frame_idx, frame in enumerate(scene_meta["frames"])
|
| 884 |
+
}
|
| 885 |
+
return scene_meta
|
| 886 |
+
|
| 887 |
+
|
| 888 |
+
def _load_labeled_image(
|
| 889 |
+
fname: str | Path,
|
| 890 |
+
fmt: str = "torch",
|
| 891 |
+
resize: tuple[int, int] | None = None,
|
| 892 |
+
**kwargs,
|
| 893 |
+
) -> np.ndarray | torch.Tensor | Image.Image:
|
| 894 |
+
"""
|
| 895 |
+
Loads a labeled image from a PNG file.
|
| 896 |
+
|
| 897 |
+
Args:
|
| 898 |
+
fname (str or Path): The filename to load the image from.
|
| 899 |
+
fmt (str): The format of the output data. Can be one of:
|
| 900 |
+
- "torch": Returns a PyTorch int32 tensor with shape (H, W).
|
| 901 |
+
- "np": Returns a NumPy int32 array with shape (H, W).
|
| 902 |
+
- "pil": Returns a PIL Image object.
|
| 903 |
+
Defaults to "torch".
|
| 904 |
+
resize (tuple, optional): A tuple of two integers representing the desired width and height of the image.
|
| 905 |
+
If None, the image is not resized. Defaults to None.
|
| 906 |
+
|
| 907 |
+
Returns:
|
| 908 |
+
The loaded image in the specified output format.
|
| 909 |
+
|
| 910 |
+
Raises:
|
| 911 |
+
NotImplementedError: If the specified output format is not supported.
|
| 912 |
+
RuntimeError: If the 'id_to_color_mapping' is missing in the PNG metadata.
|
| 913 |
+
|
| 914 |
+
Notes:
|
| 915 |
+
The function expects the PNG file to contain metadata with a key 'id_to_color_mapping',
|
| 916 |
+
which maps from label ids to tuples of RGB values.
|
| 917 |
+
"""
|
| 918 |
+
with open(fname, "rb") as f:
|
| 919 |
+
pil_image = Image.open(f)
|
| 920 |
+
pil_image.load()
|
| 921 |
+
if pil_image.mode != "RGB":
|
| 922 |
+
raise OSError(
|
| 923 |
+
f"Expected a RGB image in {fname}, but instead found an image with mode {pil_image.mode}"
|
| 924 |
+
)
|
| 925 |
+
|
| 926 |
+
# Load id to RGB mapping
|
| 927 |
+
color_palette_json = pil_image.info.get("id_to_color_mapping", None)
|
| 928 |
+
if color_palette_json is None:
|
| 929 |
+
raise RuntimeError("'id_to_color_mapping' is missing in the PNG metadata.")
|
| 930 |
+
color_palette = json.loads(color_palette_json)
|
| 931 |
+
color_to_id_mapping = {
|
| 932 |
+
tuple(color): int(id) for id, color in color_palette.items()
|
| 933 |
+
}
|
| 934 |
+
|
| 935 |
+
if resize is not None:
|
| 936 |
+
pil_image = pil_image.resize(resize, Image.NEAREST)
|
| 937 |
+
|
| 938 |
+
if fmt == "pil":
|
| 939 |
+
return pil_image
|
| 940 |
+
|
| 941 |
+
# Reverse the color mapping: map from RGB colors to ids
|
| 942 |
+
img_data = np.array(pil_image)
|
| 943 |
+
|
| 944 |
+
# Create a lookup table for fast mapping
|
| 945 |
+
max_color_value = 256 # Assuming 8-bit per channel
|
| 946 |
+
lookup_table = np.full(
|
| 947 |
+
(max_color_value, max_color_value, max_color_value),
|
| 948 |
+
INVALID_ID,
|
| 949 |
+
dtype=np.int32,
|
| 950 |
+
)
|
| 951 |
+
for color, index in color_to_id_mapping.items():
|
| 952 |
+
lookup_table[color] = index
|
| 953 |
+
# Map colors to ids using the lookup table
|
| 954 |
+
img_data = lookup_table[img_data[..., 0], img_data[..., 1], img_data[..., 2]]
|
| 955 |
+
|
| 956 |
+
if fmt == "np":
|
| 957 |
+
return img_data
|
| 958 |
+
elif fmt == "torch":
|
| 959 |
+
return torch.from_numpy(img_data)
|
| 960 |
+
else:
|
| 961 |
+
raise NotImplementedError(f"Image format not supported: {fmt}")
|
| 962 |
+
|
| 963 |
+
|
| 964 |
+
def _store_labeled_image(
|
| 965 |
+
fname: str | Path,
|
| 966 |
+
img_data: np.ndarray | torch.Tensor | Image.Image,
|
| 967 |
+
semantic_color_mapping: np.ndarray | None = None,
|
| 968 |
+
**kwargs,
|
| 969 |
+
) -> None:
|
| 970 |
+
"""
|
| 971 |
+
Stores a labeled image as a uint8 RGB PNG file.
|
| 972 |
+
|
| 973 |
+
Args:
|
| 974 |
+
fname (str or Path): The filename to store the image in.
|
| 975 |
+
img_data (numpy.ndarray, torch.Tensor or PIL.Image.Image): The per-pixel label ids to store.
|
| 976 |
+
semantic_color_mapping (np.ndarray): Optional, preloaded NumPy array of semantic colors.
|
| 977 |
+
|
| 978 |
+
Raises:
|
| 979 |
+
ValueError: If the file suffix is not supported (i.e., not .png).
|
| 980 |
+
RuntimeError: If the type of the image data is different from uint16, int16 or int32.
|
| 981 |
+
|
| 982 |
+
Notes:
|
| 983 |
+
The function takes an image with per-pixel label ids and converts it into an RGB image
|
| 984 |
+
using a specified mapping from label ids to RGB colors. The resulting image is saved as
|
| 985 |
+
a PNG file, with the mapping stored as metadata.
|
| 986 |
+
"""
|
| 987 |
+
if Path(fname).suffix != ".png":
|
| 988 |
+
raise ValueError(
|
| 989 |
+
f"Only filenames with suffix .png allowed but received: {fname}"
|
| 990 |
+
)
|
| 991 |
+
|
| 992 |
+
if isinstance(img_data, Image.Image) and img_data.mode != "I;16":
|
| 993 |
+
raise RuntimeError(
|
| 994 |
+
f"The provided image does not seem to be a labeled image. The provided PIL image has mode {img_data.mode}."
|
| 995 |
+
)
|
| 996 |
+
|
| 997 |
+
if isinstance(img_data, np.ndarray) and img_data.dtype not in [
|
| 998 |
+
np.uint16,
|
| 999 |
+
np.int16,
|
| 1000 |
+
np.int32,
|
| 1001 |
+
]:
|
| 1002 |
+
raise RuntimeError(
|
| 1003 |
+
f"The provided NumPy array has type {img_data.dtype} but the expected type is np.uint16, np.int16 or np.int32."
|
| 1004 |
+
)
|
| 1005 |
+
|
| 1006 |
+
if isinstance(img_data, torch.Tensor):
|
| 1007 |
+
if img_data.dtype not in [torch.uint16, torch.int16, torch.int32]:
|
| 1008 |
+
raise RuntimeError(
|
| 1009 |
+
f"The provided PyTorch tensor has type {img_data.dtype} but the expected type is torch.uint16, torch.int16 or torch.int32."
|
| 1010 |
+
)
|
| 1011 |
+
img_data = img_data.numpy()
|
| 1012 |
+
|
| 1013 |
+
if semantic_color_mapping is None:
|
| 1014 |
+
# Mapping from ids to colors not provided, load it now
|
| 1015 |
+
semantic_color_mapping = load_semantic_color_mapping()
|
| 1016 |
+
|
| 1017 |
+
img_data, color_palette = apply_id_to_color_mapping(
|
| 1018 |
+
img_data, semantic_color_mapping
|
| 1019 |
+
)
|
| 1020 |
+
pil_image = Image.fromarray(img_data, "RGB")
|
| 1021 |
+
|
| 1022 |
+
# Create a PngInfo object to store metadata
|
| 1023 |
+
meta = PngImagePlugin.PngInfo()
|
| 1024 |
+
meta.add_text("id_to_color_mapping", json.dumps(color_palette))
|
| 1025 |
+
|
| 1026 |
+
pil_image.save(fname, pnginfo=meta)
|
| 1027 |
+
|
| 1028 |
+
|
| 1029 |
+
def _load_generic_mesh(mesh_path: str | Path, **kwargs) -> trimesh.Trimesh:
|
| 1030 |
+
"""Load mesh with the trimesh library.
|
| 1031 |
+
|
| 1032 |
+
Args:
|
| 1033 |
+
mesh_path (str): Path to the mesh file
|
| 1034 |
+
|
| 1035 |
+
Returns:
|
| 1036 |
+
The trimesh object from trimesh.load().
|
| 1037 |
+
|
| 1038 |
+
Raises:
|
| 1039 |
+
ValueError: If the file format is not supported.
|
| 1040 |
+
"""
|
| 1041 |
+
|
| 1042 |
+
# needed to load big texture files
|
| 1043 |
+
Image.MAX_IMAGE_PIXELS = None
|
| 1044 |
+
|
| 1045 |
+
# load mesh with trimesh
|
| 1046 |
+
mesh_data = trimesh.load(mesh_path, process=False)
|
| 1047 |
+
|
| 1048 |
+
return mesh_data
|
| 1049 |
+
|
| 1050 |
+
|
| 1051 |
+
def _store_generic_mesh(
|
| 1052 |
+
file_path: str | Path, mesh_data: dict | trimesh.Trimesh, **kwargs
|
| 1053 |
+
) -> None:
|
| 1054 |
+
"""
|
| 1055 |
+
Dummy function for storing generic mesh data.
|
| 1056 |
+
|
| 1057 |
+
Args:
|
| 1058 |
+
file_path (str): The filename to store the mesh in.
|
| 1059 |
+
mesh_data (dict): Dictionary containing mesh data.
|
| 1060 |
+
**kwargs: Additional keyword arguments.
|
| 1061 |
+
|
| 1062 |
+
Raises:
|
| 1063 |
+
NotImplementedError: This function is not implemented yet.
|
| 1064 |
+
"""
|
| 1065 |
+
raise NotImplementedError("Storing generic meshes is not implemented yet.")
|
| 1066 |
+
|
| 1067 |
+
|
| 1068 |
+
def _load_labeled_mesh(
|
| 1069 |
+
file_path: str | Path,
|
| 1070 |
+
fmt: str = "torch",
|
| 1071 |
+
palette: str = "rgb",
|
| 1072 |
+
**kwargs,
|
| 1073 |
+
) -> dict | trimesh.Trimesh:
|
| 1074 |
+
"""
|
| 1075 |
+
Loads a mesh from a labeled mesh file (PLY binary format).
|
| 1076 |
+
|
| 1077 |
+
Args:
|
| 1078 |
+
file_path (str): The path to the labeled mesh file (.ply).
|
| 1079 |
+
fmt (str): Output format of the mesh data. Can be one of:
|
| 1080 |
+
- "torch": Returns a dict of PyTorch tensors containing mesh data.
|
| 1081 |
+
- "np": Returns a dict of NumPy arrays containing mesh data.
|
| 1082 |
+
- "trimesh": Returns a trimesh mesh object.
|
| 1083 |
+
Defaults to "torch".
|
| 1084 |
+
palette (str): Output color of the trimesh mesh data. Can be one of:
|
| 1085 |
+
- "rgb": Colors the mesh with original rgb colors
|
| 1086 |
+
- "semantic_class": Colors the mesh with semantic class colors
|
| 1087 |
+
- "instance": Colors the mesh with semantic instance colors
|
| 1088 |
+
Applied only when fmt is "trimesh".
|
| 1089 |
+
|
| 1090 |
+
Returns:
|
| 1091 |
+
The loaded mesh in the specified output format.
|
| 1092 |
+
|
| 1093 |
+
Raises:
|
| 1094 |
+
NotImplementedError: If the specified output format is not supported.
|
| 1095 |
+
|
| 1096 |
+
Notes:
|
| 1097 |
+
This function reads a binary PLY file with vertex position, color, and optional
|
| 1098 |
+
semantic class and instance IDs. The faces are stored as lists of vertex indices.
|
| 1099 |
+
"""
|
| 1100 |
+
# load data (NOTE: define known_list_len to enable faster read)
|
| 1101 |
+
ply_data = PlyData.read(file_path, known_list_len={"face": {"vertex_indices": 3}})
|
| 1102 |
+
|
| 1103 |
+
# get vertices
|
| 1104 |
+
vertex_data = ply_data["vertex"].data
|
| 1105 |
+
vertices = np.column_stack(
|
| 1106 |
+
(vertex_data["x"], vertex_data["y"], vertex_data["z"])
|
| 1107 |
+
).astype(np.float32)
|
| 1108 |
+
|
| 1109 |
+
# initialize output data
|
| 1110 |
+
mesh_data = {}
|
| 1111 |
+
mesh_data["is_labeled_mesh"] = True
|
| 1112 |
+
mesh_data["vertices"] = vertices
|
| 1113 |
+
|
| 1114 |
+
# get faces if available
|
| 1115 |
+
if "face" in ply_data:
|
| 1116 |
+
faces = np.asarray(ply_data["face"].data["vertex_indices"]).astype(np.int32)
|
| 1117 |
+
mesh_data["faces"] = faces
|
| 1118 |
+
|
| 1119 |
+
# get rgb colors if available
|
| 1120 |
+
if all(color in vertex_data.dtype.names for color in ["red", "green", "blue"]):
|
| 1121 |
+
vertices_color = np.column_stack(
|
| 1122 |
+
(vertex_data["red"], vertex_data["green"], vertex_data["blue"])
|
| 1123 |
+
).astype(np.uint8)
|
| 1124 |
+
mesh_data["vertices_color"] = vertices_color
|
| 1125 |
+
|
| 1126 |
+
# get vertices class and instance if available
|
| 1127 |
+
if "semantic_class_id" in vertex_data.dtype.names:
|
| 1128 |
+
vertices_class = vertex_data["semantic_class_id"].astype(np.int32)
|
| 1129 |
+
mesh_data["vertices_semantic_class_id"] = vertices_class
|
| 1130 |
+
|
| 1131 |
+
if "instance_id" in vertex_data.dtype.names:
|
| 1132 |
+
vertices_instance = vertex_data["instance_id"].astype(np.int32)
|
| 1133 |
+
mesh_data["vertices_instance_id"] = vertices_instance
|
| 1134 |
+
|
| 1135 |
+
# get class colors if available
|
| 1136 |
+
if all(
|
| 1137 |
+
color in vertex_data.dtype.names
|
| 1138 |
+
for color in [
|
| 1139 |
+
"semantic_class_red",
|
| 1140 |
+
"semantic_class_green",
|
| 1141 |
+
"semantic_class_blue",
|
| 1142 |
+
]
|
| 1143 |
+
):
|
| 1144 |
+
vertices_semantic_class_color = np.column_stack(
|
| 1145 |
+
(
|
| 1146 |
+
vertex_data["semantic_class_red"],
|
| 1147 |
+
vertex_data["semantic_class_green"],
|
| 1148 |
+
vertex_data["semantic_class_blue"],
|
| 1149 |
+
)
|
| 1150 |
+
).astype(np.uint8)
|
| 1151 |
+
mesh_data["vertices_semantic_class_color"] = vertices_semantic_class_color
|
| 1152 |
+
|
| 1153 |
+
# get instance colors if available
|
| 1154 |
+
if all(
|
| 1155 |
+
color in vertex_data.dtype.names
|
| 1156 |
+
for color in ["instance_red", "instance_green", "instance_blue"]
|
| 1157 |
+
):
|
| 1158 |
+
vertices_instance_color = np.column_stack(
|
| 1159 |
+
(
|
| 1160 |
+
vertex_data["instance_red"],
|
| 1161 |
+
vertex_data["instance_green"],
|
| 1162 |
+
vertex_data["instance_blue"],
|
| 1163 |
+
)
|
| 1164 |
+
).astype(np.uint8)
|
| 1165 |
+
mesh_data["vertices_instance_color"] = vertices_instance_color
|
| 1166 |
+
|
| 1167 |
+
# convert data into output format (if needed)
|
| 1168 |
+
if fmt == "np":
|
| 1169 |
+
return mesh_data
|
| 1170 |
+
elif fmt == "torch":
|
| 1171 |
+
return {k: torch.tensor(v) for k, v in mesh_data.items()}
|
| 1172 |
+
elif fmt == "trimesh":
|
| 1173 |
+
trimesh_mesh = trimesh.Trimesh(
|
| 1174 |
+
vertices=mesh_data["vertices"], faces=mesh_data["faces"]
|
| 1175 |
+
)
|
| 1176 |
+
# color the mesh according to the palette
|
| 1177 |
+
if palette == "rgb":
|
| 1178 |
+
# original rgb colors
|
| 1179 |
+
if "vertices_color" in mesh_data:
|
| 1180 |
+
trimesh_mesh.visual.vertex_colors = mesh_data["vertices_color"]
|
| 1181 |
+
else:
|
| 1182 |
+
raise ValueError(
|
| 1183 |
+
f"Palette {palette} could not be applied. Missing vertices_color in mesh data."
|
| 1184 |
+
)
|
| 1185 |
+
elif palette == "semantic_class":
|
| 1186 |
+
# semantic class colors
|
| 1187 |
+
if "vertices_semantic_class_color" in mesh_data:
|
| 1188 |
+
trimesh_mesh.visual.vertex_colors = mesh_data[
|
| 1189 |
+
"vertices_semantic_class_color"
|
| 1190 |
+
]
|
| 1191 |
+
else:
|
| 1192 |
+
raise ValueError(
|
| 1193 |
+
f"Palette {palette} could not be applied. Missing vertices_semantic_class_color in mesh data."
|
| 1194 |
+
)
|
| 1195 |
+
elif palette == "instance":
|
| 1196 |
+
# semantic instance colors
|
| 1197 |
+
if "vertices_instance_color" in mesh_data:
|
| 1198 |
+
trimesh_mesh.visual.vertex_colors = mesh_data["vertices_instance_color"]
|
| 1199 |
+
else:
|
| 1200 |
+
raise ValueError(
|
| 1201 |
+
f"Palette {palette} could not be applied. Missing vertices_instance_color in mesh data."
|
| 1202 |
+
)
|
| 1203 |
+
else:
|
| 1204 |
+
raise ValueError(f"Invalid palette: {palette}.")
|
| 1205 |
+
return trimesh_mesh
|
| 1206 |
+
else:
|
| 1207 |
+
raise NotImplementedError(f"Labeled mesh format not supported: {fmt}")
|
| 1208 |
+
|
| 1209 |
+
|
| 1210 |
+
def _store_labeled_mesh(file_path: str | Path, mesh_data: dict, **kwargs) -> None:
|
| 1211 |
+
"""
|
| 1212 |
+
Stores a mesh in WAI format (PLY binary format).
|
| 1213 |
+
|
| 1214 |
+
Args:
|
| 1215 |
+
file_path (str): The filename to store the mesh in.
|
| 1216 |
+
mesh_data (dict): Dictionary containing mesh data with keys:
|
| 1217 |
+
- 'vertices' (numpy.ndarray): Array of vertex coordinates with shape (N, 3).
|
| 1218 |
+
- 'faces' (numpy.ndarray, optional): Array of face indices.
|
| 1219 |
+
- 'vertices_color' (numpy.ndarray, optional): Array of vertex colors with shape (N, 3).
|
| 1220 |
+
- 'vertices_semantic_class_id' (numpy.ndarray, optional): Array of semantic classes for each vertex with shape (N).
|
| 1221 |
+
- 'vertices_instance_id' (numpy.ndarray, optional): Array of instance IDs for each vertex with shape (N).
|
| 1222 |
+
- 'vertices_semantic_class_color' (numpy.ndarray, optional): Array of vertex semantic class colors with shape (N, 3).
|
| 1223 |
+
- 'vertices_instance_color' (numpy.ndarray, optional): Array of vertex instance colors with shape (N, 3).
|
| 1224 |
+
|
| 1225 |
+
Notes:
|
| 1226 |
+
This function writes a binary PLY file with vertex position, color, and optional
|
| 1227 |
+
semantic class and instance IDs. The faces are stored as lists of vertex indices.
|
| 1228 |
+
"""
|
| 1229 |
+
# Validate input data
|
| 1230 |
+
if "vertices" not in mesh_data:
|
| 1231 |
+
raise ValueError("Mesh data must contain 'vertices'")
|
| 1232 |
+
|
| 1233 |
+
# create vertex data with properties
|
| 1234 |
+
vertex_dtype = [("x", "f4"), ("y", "f4"), ("z", "f4")]
|
| 1235 |
+
if "vertices_color" in mesh_data:
|
| 1236 |
+
vertex_dtype.extend([("red", "u1"), ("green", "u1"), ("blue", "u1")])
|
| 1237 |
+
if "vertices_semantic_class_id" in mesh_data:
|
| 1238 |
+
vertex_dtype.append(("semantic_class_id", "i4"))
|
| 1239 |
+
if "vertices_instance_id" in mesh_data:
|
| 1240 |
+
vertex_dtype.append(("instance_id", "i4"))
|
| 1241 |
+
if "vertices_semantic_class_color" in mesh_data:
|
| 1242 |
+
vertex_dtype.extend(
|
| 1243 |
+
[
|
| 1244 |
+
("semantic_class_red", "u1"),
|
| 1245 |
+
("semantic_class_green", "u1"),
|
| 1246 |
+
("semantic_class_blue", "u1"),
|
| 1247 |
+
]
|
| 1248 |
+
)
|
| 1249 |
+
if "vertices_instance_color" in mesh_data:
|
| 1250 |
+
vertex_dtype.extend(
|
| 1251 |
+
[("instance_red", "u1"), ("instance_green", "u1"), ("instance_blue", "u1")]
|
| 1252 |
+
)
|
| 1253 |
+
vertex_count = len(mesh_data["vertices"])
|
| 1254 |
+
vertex_data = np.zeros(vertex_count, dtype=vertex_dtype)
|
| 1255 |
+
|
| 1256 |
+
# vertex positions
|
| 1257 |
+
vertex_data["x"] = mesh_data["vertices"][:, 0]
|
| 1258 |
+
vertex_data["y"] = mesh_data["vertices"][:, 1]
|
| 1259 |
+
vertex_data["z"] = mesh_data["vertices"][:, 2]
|
| 1260 |
+
|
| 1261 |
+
# vertex colors
|
| 1262 |
+
if "vertices_color" in mesh_data:
|
| 1263 |
+
vertex_data["red"] = mesh_data["vertices_color"][:, 0]
|
| 1264 |
+
vertex_data["green"] = mesh_data["vertices_color"][:, 1]
|
| 1265 |
+
vertex_data["blue"] = mesh_data["vertices_color"][:, 2]
|
| 1266 |
+
|
| 1267 |
+
# vertex class
|
| 1268 |
+
if "vertices_semantic_class_id" in mesh_data:
|
| 1269 |
+
vertex_data["semantic_class_id"] = mesh_data["vertices_semantic_class_id"]
|
| 1270 |
+
|
| 1271 |
+
# vertex instance
|
| 1272 |
+
if "vertices_instance_id" in mesh_data:
|
| 1273 |
+
vertex_data["instance_id"] = mesh_data["vertices_instance_id"]
|
| 1274 |
+
|
| 1275 |
+
# vertex class colors
|
| 1276 |
+
if "vertices_semantic_class_color" in mesh_data:
|
| 1277 |
+
vertex_data["semantic_class_red"] = mesh_data["vertices_semantic_class_color"][
|
| 1278 |
+
:, 0
|
| 1279 |
+
]
|
| 1280 |
+
vertex_data["semantic_class_green"] = mesh_data[
|
| 1281 |
+
"vertices_semantic_class_color"
|
| 1282 |
+
][:, 1]
|
| 1283 |
+
vertex_data["semantic_class_blue"] = mesh_data["vertices_semantic_class_color"][
|
| 1284 |
+
:, 2
|
| 1285 |
+
]
|
| 1286 |
+
|
| 1287 |
+
# vertex instance colors
|
| 1288 |
+
if "vertices_instance_color" in mesh_data:
|
| 1289 |
+
vertex_data["instance_red"] = mesh_data["vertices_instance_color"][:, 0]
|
| 1290 |
+
vertex_data["instance_green"] = mesh_data["vertices_instance_color"][:, 1]
|
| 1291 |
+
vertex_data["instance_blue"] = mesh_data["vertices_instance_color"][:, 2]
|
| 1292 |
+
|
| 1293 |
+
# initialize data to save
|
| 1294 |
+
vertex_element = PlyElement.describe(vertex_data, "vertex")
|
| 1295 |
+
data_to_save = [vertex_element]
|
| 1296 |
+
|
| 1297 |
+
# faces data
|
| 1298 |
+
if "faces" in mesh_data:
|
| 1299 |
+
face_dtype = [("vertex_indices", "i4", (3,))]
|
| 1300 |
+
face_data = np.zeros(len(mesh_data["faces"]), dtype=face_dtype)
|
| 1301 |
+
face_data["vertex_indices"] = mesh_data["faces"]
|
| 1302 |
+
face_element = PlyElement.describe(face_data, "face")
|
| 1303 |
+
data_to_save.append(face_element)
|
| 1304 |
+
|
| 1305 |
+
# Create and write a binary PLY file
|
| 1306 |
+
ply_data = PlyData(data_to_save, text=False)
|
| 1307 |
+
ply_data.write(file_path)
|
| 1308 |
+
|
| 1309 |
+
|
| 1310 |
+
def _get_method(
|
| 1311 |
+
fname: Path | str, format_type: str | None = None, load: bool = True
|
| 1312 |
+
) -> Callable:
|
| 1313 |
+
"""
|
| 1314 |
+
Returns a method for loading or storing data in a specific format.
|
| 1315 |
+
|
| 1316 |
+
Args:
|
| 1317 |
+
fname (str or Path): The filename to load or store data from/to.
|
| 1318 |
+
format_type (str, optional): The format of the data. If None, it will be inferred from the file extension.
|
| 1319 |
+
Defaults to None.
|
| 1320 |
+
load (bool, optional): Whether to return a method for loading or storing data.
|
| 1321 |
+
Defaults to True.
|
| 1322 |
+
|
| 1323 |
+
Returns:
|
| 1324 |
+
callable: A method for loading or storing data in the specified format.
|
| 1325 |
+
|
| 1326 |
+
Raises:
|
| 1327 |
+
ValueError: If the format cannot be inferred from the file extension.
|
| 1328 |
+
NotImplementedError: If the specified format is not supported.
|
| 1329 |
+
|
| 1330 |
+
Notes:
|
| 1331 |
+
This function supports various formats, including readable files (JSON, YAML), images, NumPy arrays,
|
| 1332 |
+
PyTorch tensors, memory-mapped files, and scene metadata.
|
| 1333 |
+
"""
|
| 1334 |
+
fname = Path(fname)
|
| 1335 |
+
if format_type is None:
|
| 1336 |
+
# use default formats
|
| 1337 |
+
if fname.suffix in [".json", ".yaml", ".yml"]:
|
| 1338 |
+
format_type = "readable"
|
| 1339 |
+
elif fname.suffix in [".jpg", ".jpeg", ".png", ".webp"]:
|
| 1340 |
+
format_type = "image"
|
| 1341 |
+
elif fname.suffix in [".npy", ".npz"]:
|
| 1342 |
+
format_type = "numpy"
|
| 1343 |
+
elif fname.suffix == ".ptz":
|
| 1344 |
+
format_type = "ptz"
|
| 1345 |
+
elif fname.suffix == ".sft":
|
| 1346 |
+
format_type = "sft"
|
| 1347 |
+
elif fname.suffix == ".exr":
|
| 1348 |
+
format_type = "scalar"
|
| 1349 |
+
elif fname.suffix in [".glb", ".obj", ".ply"]:
|
| 1350 |
+
format_type = "mesh"
|
| 1351 |
+
else:
|
| 1352 |
+
raise ValueError(f"Cannot infer format for {fname}")
|
| 1353 |
+
methods = {
|
| 1354 |
+
"readable": (_load_readable, _store_readable),
|
| 1355 |
+
"scalar": (_read_exr, _write_exr),
|
| 1356 |
+
"image": (_load_image, _store_image),
|
| 1357 |
+
"binary": (_load_binary_mask, _store_binary_mask),
|
| 1358 |
+
"latent": (_load_sft, _store_sft),
|
| 1359 |
+
"depth": (_load_depth, _store_depth),
|
| 1360 |
+
"normals": (_load_normals, _store_normals),
|
| 1361 |
+
"numpy": (_load_numpy, _store_numpy),
|
| 1362 |
+
"ptz": (_load_ptz, _store_ptz),
|
| 1363 |
+
"sft": (_load_sft, _store_sft),
|
| 1364 |
+
"mmap": (_load_mmap, _store_mmap),
|
| 1365 |
+
"scene_meta": (_load_scene_meta, _store_scene_meta),
|
| 1366 |
+
"labeled_image": (_load_labeled_image, _store_labeled_image),
|
| 1367 |
+
"mesh": (_load_generic_mesh, _store_generic_mesh),
|
| 1368 |
+
"labeled_mesh": (_load_labeled_mesh, _store_labeled_mesh),
|
| 1369 |
+
}
|
| 1370 |
+
try:
|
| 1371 |
+
return methods[format_type][0 if load else 1]
|
| 1372 |
+
except KeyError as e:
|
| 1373 |
+
raise NotImplementedError(f"Format not supported: {format_type}") from e
|
mapanything/utils/wai/m_ops.py
ADDED
|
@@ -0,0 +1,346 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import torch
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def m_dot(
|
| 6 |
+
transform: torch.Tensor,
|
| 7 |
+
points: torch.Tensor | list,
|
| 8 |
+
maintain_shape: bool = False,
|
| 9 |
+
) -> torch.Tensor | list:
|
| 10 |
+
"""
|
| 11 |
+
Apply batch matrix multiplication between transform matrices and points.
|
| 12 |
+
|
| 13 |
+
Args:
|
| 14 |
+
transform: Batch of transformation matrices [..., 3/4, 3/4]
|
| 15 |
+
points: Batch of points [..., N, 3] or a list of points
|
| 16 |
+
maintain_shape: If True, preserves the original shape of points
|
| 17 |
+
|
| 18 |
+
Returns:
|
| 19 |
+
Transformed points with shape [..., N, 3] or a list of transformed points
|
| 20 |
+
"""
|
| 21 |
+
if isinstance(points, list):
|
| 22 |
+
return [m_dot(t, p, maintain_shape) for t, p in zip(transform, points)]
|
| 23 |
+
|
| 24 |
+
# Store original shape and flatten batch dimensions
|
| 25 |
+
orig_shape = points.shape
|
| 26 |
+
batch_dims = points.shape[:-3]
|
| 27 |
+
|
| 28 |
+
# Reshape to standard batch format
|
| 29 |
+
transform_flat = transform.reshape(-1, transform.shape[-2], transform.shape[-1])
|
| 30 |
+
points_flat = points.reshape(transform_flat.shape[0], -1, points.shape[-1])
|
| 31 |
+
|
| 32 |
+
# Apply transformation
|
| 33 |
+
pts = torch.bmm(
|
| 34 |
+
transform_flat[:, :3, :3],
|
| 35 |
+
points_flat[..., :3].permute(0, 2, 1).to(transform_flat.dtype),
|
| 36 |
+
).permute(0, 2, 1)
|
| 37 |
+
|
| 38 |
+
if transform.shape[-1] == 4:
|
| 39 |
+
pts = pts + transform_flat[:, :3, 3].unsqueeze(1)
|
| 40 |
+
|
| 41 |
+
# Restore original shape
|
| 42 |
+
if maintain_shape:
|
| 43 |
+
return pts.reshape(orig_shape)
|
| 44 |
+
else:
|
| 45 |
+
return pts.reshape(*batch_dims, -1, 3)
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def m_unproject(
|
| 49 |
+
depth: torch.Tensor,
|
| 50 |
+
intrinsic: torch.Tensor,
|
| 51 |
+
cam2world: torch.Tensor = None,
|
| 52 |
+
img_grid: torch.Tensor = None,
|
| 53 |
+
valid: torch.Tensor = None,
|
| 54 |
+
H: int | None = None,
|
| 55 |
+
W: int | None = None,
|
| 56 |
+
img_feats: torch.Tensor = None,
|
| 57 |
+
maintain_shape: bool = False,
|
| 58 |
+
) -> torch.Tensor:
|
| 59 |
+
"""
|
| 60 |
+
Unproject 2D image points with depth values to 3D points in camera or world space.
|
| 61 |
+
|
| 62 |
+
Args:
|
| 63 |
+
depth: Depth values, either a tensor of shape ...xHxW or a float value
|
| 64 |
+
intrinsic: Camera intrinsic matrix of shape ...x3x3
|
| 65 |
+
cam2world: Optional camera-to-world transformation matrix of shape ...x4x4
|
| 66 |
+
img_grid: Optional pre-computed image grid. If None, will be created
|
| 67 |
+
valid: Optional mask for valid depth values or minimum depth threshold
|
| 68 |
+
H: Image height (required if depth is a scalar)
|
| 69 |
+
W: Image width (required if depth is a scalar)
|
| 70 |
+
img_feats: Optional image features to append to 3D points
|
| 71 |
+
maintain_shape: If True, preserves the original shape of points
|
| 72 |
+
|
| 73 |
+
Returns:
|
| 74 |
+
3D points in camera or world space, with optional features appended
|
| 75 |
+
"""
|
| 76 |
+
# Get device and shape information from intrinsic matrix
|
| 77 |
+
device = intrinsic.device
|
| 78 |
+
pre_shape = intrinsic.shape[:-2] # Batch dimensions
|
| 79 |
+
|
| 80 |
+
# Validate inputs
|
| 81 |
+
if isinstance(depth, (int, float)) and H is None:
|
| 82 |
+
raise ValueError("H must be provided if depth is a scalar")
|
| 83 |
+
|
| 84 |
+
# Determine image dimensions from depth if not provided
|
| 85 |
+
if isinstance(depth, torch.Tensor) and H is None:
|
| 86 |
+
H, W = depth.shape[-2:]
|
| 87 |
+
|
| 88 |
+
# Create image grid if not provided
|
| 89 |
+
if img_grid is None:
|
| 90 |
+
# Create coordinate grid with shape HxWx3 (last dimension is homogeneous)
|
| 91 |
+
img_grid = _create_image_grid(H, W, device)
|
| 92 |
+
# Add homogeneous coordinate
|
| 93 |
+
img_grid = torch.cat([img_grid, torch.ones_like(img_grid[..., :1])], -1)
|
| 94 |
+
|
| 95 |
+
# Expand img_grid to match batch dimensions of intrinsic
|
| 96 |
+
if img_grid.dim() <= intrinsic.dim():
|
| 97 |
+
img_grid = img_grid.unsqueeze(0)
|
| 98 |
+
img_grid = img_grid.expand(*pre_shape, *img_grid.shape[-3:])
|
| 99 |
+
|
| 100 |
+
# Handle valid mask or minimum depth threshold
|
| 101 |
+
depth_mask = None
|
| 102 |
+
if valid is not None:
|
| 103 |
+
if isinstance(valid, float):
|
| 104 |
+
# Create mask for minimum depth value
|
| 105 |
+
depth_mask = depth > valid
|
| 106 |
+
elif isinstance(valid, torch.Tensor):
|
| 107 |
+
depth_mask = valid
|
| 108 |
+
|
| 109 |
+
# Apply mask to image grid and other inputs
|
| 110 |
+
img_grid = masking(img_grid, depth_mask, dim=intrinsic.dim())
|
| 111 |
+
if not isinstance(depth, (int, float)):
|
| 112 |
+
depth = masking(depth, depth_mask, dim=intrinsic.dim() - 1)
|
| 113 |
+
if img_feats is not None:
|
| 114 |
+
img_feats = masking(img_feats, depth_mask, dim=intrinsic.dim() - 1)
|
| 115 |
+
|
| 116 |
+
# Unproject 2D points to 3D camera space
|
| 117 |
+
cam_pts: torch.Tensor = m_dot(
|
| 118 |
+
m_inverse_intrinsics(intrinsic),
|
| 119 |
+
img_grid[..., [1, 0, 2]],
|
| 120 |
+
maintain_shape=True,
|
| 121 |
+
)
|
| 122 |
+
# Scale by depth values
|
| 123 |
+
cam_pts = mult(cam_pts, depth.unsqueeze(-1))
|
| 124 |
+
|
| 125 |
+
# Transform to world space if cam2world is provided
|
| 126 |
+
if cam2world is not None:
|
| 127 |
+
cam_pts = m_dot(cam2world, cam_pts, maintain_shape=True)
|
| 128 |
+
|
| 129 |
+
# Append image features if provided
|
| 130 |
+
if img_feats is not None:
|
| 131 |
+
if isinstance(cam_pts, list):
|
| 132 |
+
if isinstance(cam_pts[0], list):
|
| 133 |
+
# Handle nested list case
|
| 134 |
+
result = []
|
| 135 |
+
for batch_idx, batch in enumerate(cam_pts):
|
| 136 |
+
batch_result = []
|
| 137 |
+
for view_idx, view in enumerate(batch):
|
| 138 |
+
batch_result.append(
|
| 139 |
+
torch.cat([view, img_feats[batch_idx][view_idx]], -1)
|
| 140 |
+
)
|
| 141 |
+
result.append(batch_result)
|
| 142 |
+
cam_pts = result
|
| 143 |
+
else:
|
| 144 |
+
# Handle single list case
|
| 145 |
+
cam_pts = [
|
| 146 |
+
torch.cat([pts, feats], -1)
|
| 147 |
+
for pts, feats in zip(cam_pts, img_feats)
|
| 148 |
+
]
|
| 149 |
+
else:
|
| 150 |
+
# Handle tensor case
|
| 151 |
+
cam_pts = torch.cat([cam_pts, img_feats], -1)
|
| 152 |
+
|
| 153 |
+
if maintain_shape:
|
| 154 |
+
return cam_pts
|
| 155 |
+
|
| 156 |
+
# Flatten last dimension
|
| 157 |
+
return cam_pts.reshape(*pre_shape, -1, 3)
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
def m_project(
|
| 161 |
+
world_pts: torch.Tensor,
|
| 162 |
+
intrinsic: torch.Tensor,
|
| 163 |
+
world2cam: torch.Tensor | None = None,
|
| 164 |
+
maintain_shape: bool = False,
|
| 165 |
+
) -> torch.Tensor:
|
| 166 |
+
"""
|
| 167 |
+
Project 3D world points to 2D image coordinates.
|
| 168 |
+
|
| 169 |
+
Args:
|
| 170 |
+
world_pts: 3D points in world coordinates
|
| 171 |
+
intrinsic: Camera intrinsic matrix
|
| 172 |
+
world2cam: Optional transformation from world to camera coordinates
|
| 173 |
+
maintain_shape: If True, preserves the original shape of points
|
| 174 |
+
|
| 175 |
+
Returns:
|
| 176 |
+
Image points with coordinates in img_y,img_x,z order
|
| 177 |
+
"""
|
| 178 |
+
# Transform points from world to camera space if world2cam is provided
|
| 179 |
+
cam_pts: torch.Tensor = world_pts
|
| 180 |
+
if world2cam is not None:
|
| 181 |
+
cam_pts = m_dot(world2cam, world_pts, maintain_shape=maintain_shape)
|
| 182 |
+
|
| 183 |
+
# Get shapes to properly expand intrinsics
|
| 184 |
+
shared_dims = intrinsic.shape[:-2]
|
| 185 |
+
extra_dims = cam_pts.shape[len(shared_dims) : -1]
|
| 186 |
+
|
| 187 |
+
# Expand intrinsics to match cam_pts shape
|
| 188 |
+
expanded_intrinsic = intrinsic.view(*shared_dims, *([1] * len(extra_dims)), 3, 3)
|
| 189 |
+
expanded_intrinsic = expanded_intrinsic.expand(*shared_dims, *extra_dims, 3, 3)
|
| 190 |
+
|
| 191 |
+
# Project points from camera space to image space
|
| 192 |
+
depth_abs = cam_pts[..., 2].abs().clamp(min=1e-5)
|
| 193 |
+
return torch.stack(
|
| 194 |
+
[
|
| 195 |
+
expanded_intrinsic[..., 1, 1] * cam_pts[..., 1] / depth_abs
|
| 196 |
+
+ expanded_intrinsic[..., 1, 2],
|
| 197 |
+
expanded_intrinsic[..., 0, 0] * cam_pts[..., 0] / depth_abs
|
| 198 |
+
+ expanded_intrinsic[..., 0, 2],
|
| 199 |
+
cam_pts[..., 2],
|
| 200 |
+
],
|
| 201 |
+
-1,
|
| 202 |
+
)
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
def in_image(
|
| 206 |
+
image_pts: torch.Tensor | list,
|
| 207 |
+
H: int,
|
| 208 |
+
W: int,
|
| 209 |
+
min_depth: float = 0.0,
|
| 210 |
+
) -> torch.Tensor | list:
|
| 211 |
+
"""
|
| 212 |
+
Check if image points are within the image boundaries.
|
| 213 |
+
|
| 214 |
+
Args:
|
| 215 |
+
image_pts: Image points in pixel coordinates
|
| 216 |
+
H: Image height
|
| 217 |
+
W: Image width
|
| 218 |
+
min_depth: Minimum valid depth
|
| 219 |
+
|
| 220 |
+
Returns:
|
| 221 |
+
Boolean mask indicating which points are within the image
|
| 222 |
+
"""
|
| 223 |
+
is_list = isinstance(image_pts, list)
|
| 224 |
+
if is_list:
|
| 225 |
+
return [in_image(pts, H, W, min_depth=min_depth) for pts in image_pts]
|
| 226 |
+
|
| 227 |
+
in_image_mask = (
|
| 228 |
+
torch.all(image_pts >= 0, -1)
|
| 229 |
+
& (image_pts[..., 0] < H)
|
| 230 |
+
& (image_pts[..., 1] < W)
|
| 231 |
+
)
|
| 232 |
+
if (min_depth is not None) and image_pts.shape[-1] == 3:
|
| 233 |
+
in_image_mask &= image_pts[..., 2] > min_depth
|
| 234 |
+
return in_image_mask
|
| 235 |
+
|
| 236 |
+
|
| 237 |
+
def _create_image_grid(H: int, W: int, device: torch.device) -> torch.Tensor:
|
| 238 |
+
"""
|
| 239 |
+
Create a coordinate grid for image pixels.
|
| 240 |
+
|
| 241 |
+
Args:
|
| 242 |
+
H: Image height
|
| 243 |
+
W: Image width
|
| 244 |
+
device: Computation device
|
| 245 |
+
|
| 246 |
+
Returns:
|
| 247 |
+
Image grid with shape HxWx3 (last dimension is homogeneous)
|
| 248 |
+
"""
|
| 249 |
+
y_coords = torch.arange(H, device=device)
|
| 250 |
+
x_coords = torch.arange(W, device=device)
|
| 251 |
+
|
| 252 |
+
# Use meshgrid with indexing="ij" for correct orientation
|
| 253 |
+
y_grid, x_grid = torch.meshgrid(y_coords, x_coords, indexing="ij")
|
| 254 |
+
|
| 255 |
+
# Stack coordinates and add homogeneous coordinate
|
| 256 |
+
img_grid = torch.stack([y_grid, x_grid, torch.ones_like(y_grid)], dim=-1)
|
| 257 |
+
|
| 258 |
+
return img_grid
|
| 259 |
+
|
| 260 |
+
|
| 261 |
+
def masking(
|
| 262 |
+
X: torch.Tensor | list,
|
| 263 |
+
mask: torch.Tensor | list,
|
| 264 |
+
dim: int = 3,
|
| 265 |
+
) -> torch.Tensor | list:
|
| 266 |
+
"""
|
| 267 |
+
Apply a Boolean mask to tensor or list elements.
|
| 268 |
+
Handles nested structures by recursively applying the mask.
|
| 269 |
+
|
| 270 |
+
Args:
|
| 271 |
+
X: Input tensor or list to be masked
|
| 272 |
+
mask: Boolean mask to apply
|
| 273 |
+
dim: Dimension threshold for recursive processing
|
| 274 |
+
|
| 275 |
+
Returns:
|
| 276 |
+
Masked tensor or list with the same structure as input
|
| 277 |
+
"""
|
| 278 |
+
if isinstance(X, list) or (isinstance(X, torch.Tensor) and X.dim() >= dim):
|
| 279 |
+
return [masking(x, m, dim) for x, m in zip(X, mask)]
|
| 280 |
+
return X[mask]
|
| 281 |
+
|
| 282 |
+
|
| 283 |
+
def m_inverse_intrinsics(intrinsics: torch.Tensor) -> torch.Tensor:
|
| 284 |
+
"""
|
| 285 |
+
Compute the inverse of camera intrinsics matrices analytically.
|
| 286 |
+
This is much faster than using torch.inverse() for intrinsics matrices.
|
| 287 |
+
|
| 288 |
+
The intrinsics matrix has the form:
|
| 289 |
+
K = [fx s cx]
|
| 290 |
+
[0 fy cy]
|
| 291 |
+
[0 0 1]
|
| 292 |
+
|
| 293 |
+
And its inverse is:
|
| 294 |
+
K^-1 = [1/fx -s/(fx*fy) (s*cy-cx*fy)/(fx*fy)]
|
| 295 |
+
[0 1/fy -cy/fy ]
|
| 296 |
+
[0 0 1 ]
|
| 297 |
+
|
| 298 |
+
Args:
|
| 299 |
+
intrinsics: Camera intrinsics matrices of shape [..., 3, 3]
|
| 300 |
+
|
| 301 |
+
Returns:
|
| 302 |
+
Inverse intrinsics matrices of shape [..., 3, 3]
|
| 303 |
+
"""
|
| 304 |
+
# Extract the components of the intrinsics matrix
|
| 305 |
+
fx = intrinsics[..., 0, 0]
|
| 306 |
+
s = intrinsics[..., 0, 1] # skew, usually 0
|
| 307 |
+
cx = intrinsics[..., 0, 2]
|
| 308 |
+
fy = intrinsics[..., 1, 1]
|
| 309 |
+
cy = intrinsics[..., 1, 2]
|
| 310 |
+
|
| 311 |
+
# Create output tensor with same shape and device
|
| 312 |
+
inv_intrinsics = torch.zeros_like(intrinsics)
|
| 313 |
+
|
| 314 |
+
# Compute the inverse analytically
|
| 315 |
+
inv_intrinsics[..., 0, 0] = 1.0 / fx
|
| 316 |
+
inv_intrinsics[..., 0, 1] = -s / (fx * fy)
|
| 317 |
+
inv_intrinsics[..., 0, 2] = (s * cy - cx * fy) / (fx * fy)
|
| 318 |
+
inv_intrinsics[..., 1, 1] = 1.0 / fy
|
| 319 |
+
inv_intrinsics[..., 1, 2] = -cy / fy
|
| 320 |
+
inv_intrinsics[..., 2, 2] = 1.0
|
| 321 |
+
|
| 322 |
+
return inv_intrinsics
|
| 323 |
+
|
| 324 |
+
|
| 325 |
+
def mult(
|
| 326 |
+
A: torch.Tensor | np.ndarray | list | float | int,
|
| 327 |
+
B: torch.Tensor | np.ndarray | list | float | int,
|
| 328 |
+
) -> torch.Tensor | np.ndarray | list | float | int:
|
| 329 |
+
"""
|
| 330 |
+
Multiply two objects with support for lists, tensors, arrays, and scalars.
|
| 331 |
+
Handles nested structures by recursively applying multiplication.
|
| 332 |
+
|
| 333 |
+
Args:
|
| 334 |
+
A: First operand (tensor, array, list, or scalar)
|
| 335 |
+
B: Second operand (tensor, array, list, or scalar)
|
| 336 |
+
|
| 337 |
+
Returns:
|
| 338 |
+
Result of multiplication with the same structure as inputs
|
| 339 |
+
"""
|
| 340 |
+
if isinstance(A, list) and isinstance(B, (int, float)):
|
| 341 |
+
return [mult(a, B) for a in A]
|
| 342 |
+
if isinstance(B, list) and isinstance(A, (int, float)):
|
| 343 |
+
return [mult(A, b) for b in B]
|
| 344 |
+
if isinstance(A, list) and isinstance(B, list):
|
| 345 |
+
return [mult(a, b) for a, b in zip(A, B)]
|
| 346 |
+
return A * B
|
mapanything/utils/wai/ops.py
ADDED
|
@@ -0,0 +1,368 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
This utils script contains PORTAGE of wai-core ops methods for MapAnything.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
from PIL import Image
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def to_numpy(
|
| 12 |
+
data: torch.Tensor | np.ndarray | int | float,
|
| 13 |
+
dtype: np.dtype | str | type = np.float32,
|
| 14 |
+
) -> np.ndarray:
|
| 15 |
+
"""
|
| 16 |
+
Convert data to a NumPy array with the specified dtype (default: float32).
|
| 17 |
+
|
| 18 |
+
This function handles conversion from NumPy arrays and PyTorch tensors to a NumPy array.
|
| 19 |
+
|
| 20 |
+
Args:
|
| 21 |
+
data: Input data (torch.Tensor, np.ndarray, or scalar)
|
| 22 |
+
dtype: Target data type (NumPy dtype, str, or type). Default: np.float32.
|
| 23 |
+
|
| 24 |
+
Returns:
|
| 25 |
+
Converted data as NumPy array with specified dtype.
|
| 26 |
+
"""
|
| 27 |
+
# Set default dtype if not defined
|
| 28 |
+
assert dtype is not None, "dtype cannot be None"
|
| 29 |
+
dtype = np.dtype(dtype)
|
| 30 |
+
|
| 31 |
+
# Handle torch.Tensor
|
| 32 |
+
if isinstance(data, torch.Tensor):
|
| 33 |
+
return data.detach().cpu().numpy().astype(dtype)
|
| 34 |
+
|
| 35 |
+
# Handle numpy.ndarray
|
| 36 |
+
if isinstance(data, np.ndarray):
|
| 37 |
+
return data.astype(dtype)
|
| 38 |
+
|
| 39 |
+
# Handle scalar values
|
| 40 |
+
if isinstance(data, (int, float)):
|
| 41 |
+
return np.array(data, dtype=dtype)
|
| 42 |
+
|
| 43 |
+
raise NotImplementedError(f"Unsupported data type: {type(data)}")
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def get_dtype_device(
|
| 47 |
+
data: torch.Tensor | np.ndarray | dict | list,
|
| 48 |
+
) -> tuple[torch.dtype | np.dtype | None, torch.device | str | type | None]:
|
| 49 |
+
"""
|
| 50 |
+
Determine the data type and device of the input data.
|
| 51 |
+
|
| 52 |
+
This function recursively inspects the input data and determines its data type
|
| 53 |
+
and device. It handles PyTorch tensors, NumPy arrays, dictionaries, and lists.
|
| 54 |
+
|
| 55 |
+
Args:
|
| 56 |
+
data: Input data (torch.Tensor, np.ndarray, dict, list, or other)
|
| 57 |
+
|
| 58 |
+
Returns:
|
| 59 |
+
tuple: (dtype, device) where:
|
| 60 |
+
- dtype: The data type (torch.dtype or np.dtype)
|
| 61 |
+
- device: The device (torch.device, 'cpu', 'cuda:X', or np.ndarray)
|
| 62 |
+
|
| 63 |
+
Raises:
|
| 64 |
+
ValueError: If tensors in a dictionary are on different CUDA devices
|
| 65 |
+
"""
|
| 66 |
+
if isinstance(data, torch.Tensor):
|
| 67 |
+
return data.dtype, data.device
|
| 68 |
+
|
| 69 |
+
if isinstance(data, np.ndarray):
|
| 70 |
+
return data.dtype, np.ndarray
|
| 71 |
+
|
| 72 |
+
if isinstance(data, dict):
|
| 73 |
+
dtypes = {get_dtype_device(v)[0] for v in data.values()}
|
| 74 |
+
devices = {get_dtype_device(v)[1] for v in data.values()}
|
| 75 |
+
cuda_devices = {device for device in devices if str(device).startswith("cuda")}
|
| 76 |
+
cpu_devices = {device for device in devices if str(device).startswith("cpu")}
|
| 77 |
+
if (len(cuda_devices) > 0) or (len(cpu_devices) > 0):
|
| 78 |
+
# torch.tensor
|
| 79 |
+
dtype = torch.float
|
| 80 |
+
if all(dtype == torch.half for dtype in dtypes):
|
| 81 |
+
dtype = torch.half
|
| 82 |
+
device = None
|
| 83 |
+
if len(cuda_devices) > 1:
|
| 84 |
+
raise ValueError("All tensors must be on the same device")
|
| 85 |
+
if len(cuda_devices) == 1:
|
| 86 |
+
device = list(cuda_devices)[0]
|
| 87 |
+
if (device is None) and (len(cpu_devices) == 1):
|
| 88 |
+
device = list(cpu_devices)[0]
|
| 89 |
+
else:
|
| 90 |
+
dtype = np.float32
|
| 91 |
+
# Fix typo in numpy float16 check
|
| 92 |
+
if all(dtype == np.float16 for dtype in dtypes):
|
| 93 |
+
dtype = np.float16
|
| 94 |
+
device = np.ndarray
|
| 95 |
+
|
| 96 |
+
elif isinstance(data, list):
|
| 97 |
+
if not data: # Handle empty list case
|
| 98 |
+
return None, None
|
| 99 |
+
dtype, device = get_dtype_device(data[0])
|
| 100 |
+
|
| 101 |
+
else:
|
| 102 |
+
return np.float32, np.ndarray
|
| 103 |
+
|
| 104 |
+
return dtype, device
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def crop(
|
| 108 |
+
data: np.ndarray | torch.Tensor | Image.Image,
|
| 109 |
+
bbox: tuple[int, int, int, int] | tuple[int, int],
|
| 110 |
+
) -> np.ndarray | torch.Tensor | Image.Image:
|
| 111 |
+
"""
|
| 112 |
+
Crop data of different formats (numpy arrays, PyTorch tensors, PIL Images) to a target size.
|
| 113 |
+
|
| 114 |
+
Args:
|
| 115 |
+
data: Input data to resize (numpy.ndarray, torch.Tensor, or PIL.Image.Image)
|
| 116 |
+
size: Target size as tuple (offset_height, offset_width, height, width) or tuple (height, width)
|
| 117 |
+
|
| 118 |
+
Returns:
|
| 119 |
+
Cropped data in the same format as the input
|
| 120 |
+
"""
|
| 121 |
+
if len(bbox) == 4:
|
| 122 |
+
offset_height, offset_width, target_height, target_width = bbox
|
| 123 |
+
elif len(bbox) == 2:
|
| 124 |
+
target_height, target_width = bbox
|
| 125 |
+
offset_height, offset_width = 0, 0
|
| 126 |
+
else:
|
| 127 |
+
raise ValueError(f"Unsupported size length {len(bbox)}.")
|
| 128 |
+
|
| 129 |
+
end_height = offset_height + target_height
|
| 130 |
+
end_width = offset_width + target_width
|
| 131 |
+
|
| 132 |
+
if any([sz < 0 for sz in bbox]):
|
| 133 |
+
raise ValueError("Bounding box can't have negative values.")
|
| 134 |
+
|
| 135 |
+
if isinstance(data, np.ndarray):
|
| 136 |
+
if (
|
| 137 |
+
max(offset_height, end_height) > data.shape[0]
|
| 138 |
+
or max(offset_width, end_width) > data.shape[1]
|
| 139 |
+
):
|
| 140 |
+
raise ValueError("Invalid bounding box.")
|
| 141 |
+
cropped_data = data[offset_height:end_height, offset_width:end_width, ...]
|
| 142 |
+
return cropped_data
|
| 143 |
+
|
| 144 |
+
# Handle PIL images
|
| 145 |
+
elif isinstance(data, Image.Image):
|
| 146 |
+
if (
|
| 147 |
+
max(offset_height, end_height) > data.size[1]
|
| 148 |
+
or max(offset_width, end_width) > data.size[0]
|
| 149 |
+
):
|
| 150 |
+
raise ValueError("Invalid bounding box.")
|
| 151 |
+
return data.crop((offset_width, offset_height, end_width, end_height))
|
| 152 |
+
|
| 153 |
+
# Handle PyTorch tensors
|
| 154 |
+
elif isinstance(data, torch.Tensor):
|
| 155 |
+
if data.is_nested:
|
| 156 |
+
# special handling for nested tensors
|
| 157 |
+
return torch.stack([crop(nested_tensor, bbox) for nested_tensor in data])
|
| 158 |
+
if (
|
| 159 |
+
max(offset_height, end_height) > data.shape[-2]
|
| 160 |
+
or max(offset_width, end_width) > data.shape[-1]
|
| 161 |
+
):
|
| 162 |
+
raise ValueError("Invalid bounding box.")
|
| 163 |
+
cropped_data = data[..., offset_height:end_height, offset_width:end_width]
|
| 164 |
+
return cropped_data
|
| 165 |
+
else:
|
| 166 |
+
raise TypeError(f"Unsupported data type '{type(data)}'.")
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
def stack(
|
| 170 |
+
data: list[
|
| 171 |
+
dict[str, torch.Tensor | np.ndarray]
|
| 172 |
+
| list[torch.Tensor | np.ndarray]
|
| 173 |
+
| tuple[torch.Tensor | np.ndarray]
|
| 174 |
+
],
|
| 175 |
+
) -> dict[str, torch.Tensor | np.ndarray] | list[torch.Tensor | np.ndarray]:
|
| 176 |
+
"""
|
| 177 |
+
Stack a list of dictionaries into a single dictionary with stacked values.
|
| 178 |
+
Or when given a list of sublists, stack the sublists using torch or numpy stack
|
| 179 |
+
if the items are of equal size, or nested tensors if the items are PyTorch tensors
|
| 180 |
+
of different size.
|
| 181 |
+
|
| 182 |
+
This utility function is similar to PyTorch's collate function, but specifically
|
| 183 |
+
designed for stacking dictionaries of numpy arrays or PyTorch tensors.
|
| 184 |
+
|
| 185 |
+
Args:
|
| 186 |
+
data (list): A list of dictionaries with the same keys, where values are
|
| 187 |
+
either numpy arrays or PyTorch tensors.
|
| 188 |
+
OR
|
| 189 |
+
A list of sublist, where the values of sublists are PyTorch tensors
|
| 190 |
+
or np arrrays.
|
| 191 |
+
|
| 192 |
+
Returns:
|
| 193 |
+
dict: A dictionary with the same keys as input dictionaries, but with values
|
| 194 |
+
stacked along a new first dimension.
|
| 195 |
+
OR
|
| 196 |
+
list: If the input was a list with sublists, it returns a list with a stacked
|
| 197 |
+
output for each original input sublist.
|
| 198 |
+
|
| 199 |
+
Raises:
|
| 200 |
+
ValueError: If dictionaries in the list have inconsistent keys.
|
| 201 |
+
NotImplementedError: If input is not a list or contains non-dictionary elements.
|
| 202 |
+
"""
|
| 203 |
+
if not isinstance(data, list):
|
| 204 |
+
raise NotImplementedError(f"Stack: Data type not supported: {data}")
|
| 205 |
+
|
| 206 |
+
if len(data) == 0:
|
| 207 |
+
return data
|
| 208 |
+
|
| 209 |
+
if all(isinstance(entry, dict) for entry in data):
|
| 210 |
+
stacked_data = {}
|
| 211 |
+
keys = list(data[0].keys())
|
| 212 |
+
if any(set(entry.keys()) != set(keys) for entry in data):
|
| 213 |
+
raise ValueError("Data not consistent for stacking")
|
| 214 |
+
|
| 215 |
+
for key in keys:
|
| 216 |
+
stacked_data[key] = []
|
| 217 |
+
for entry in data:
|
| 218 |
+
stacked_data[key].append(entry[key])
|
| 219 |
+
|
| 220 |
+
# stack it according to data format
|
| 221 |
+
if all(isinstance(v, np.ndarray) for v in stacked_data[key]):
|
| 222 |
+
stacked_data[key] = np.stack(stacked_data[key])
|
| 223 |
+
elif all(isinstance(v, torch.Tensor) for v in stacked_data[key]):
|
| 224 |
+
# Check if all tensors have the same shape
|
| 225 |
+
first_shape = stacked_data[key][0].shape
|
| 226 |
+
if all(tensor.shape == first_shape for tensor in stacked_data[key]):
|
| 227 |
+
stacked_data[key] = torch.stack(stacked_data[key])
|
| 228 |
+
else:
|
| 229 |
+
# Use nested tensors if shapes are not consistent
|
| 230 |
+
stacked_data[key] = torch.nested.nested_tensor(stacked_data[key])
|
| 231 |
+
return stacked_data
|
| 232 |
+
|
| 233 |
+
if all(isinstance(entry, list) for entry in data):
|
| 234 |
+
# new stacked data will be a list with all of the sublist
|
| 235 |
+
stacked_data = []
|
| 236 |
+
for sublist in data:
|
| 237 |
+
# stack it according to data format
|
| 238 |
+
if all(isinstance(v, np.ndarray) for v in sublist):
|
| 239 |
+
stacked_data.append(np.stack(sublist))
|
| 240 |
+
elif all(isinstance(v, torch.Tensor) for v in sublist):
|
| 241 |
+
# Check if all tensors have the same shape
|
| 242 |
+
first_shape = sublist[0].shape
|
| 243 |
+
if all(tensor.shape == first_shape for tensor in sublist):
|
| 244 |
+
stacked_data.append(torch.stack(sublist))
|
| 245 |
+
else:
|
| 246 |
+
# Use nested tensors if shapes are not consistent
|
| 247 |
+
stacked_data.append(torch.nested.nested_tensor(sublist))
|
| 248 |
+
return stacked_data
|
| 249 |
+
|
| 250 |
+
raise NotImplementedError(f"Stack: Data type not supported: {data}")
|
| 251 |
+
|
| 252 |
+
|
| 253 |
+
def resize(
|
| 254 |
+
data: np.ndarray | torch.Tensor | Image.Image,
|
| 255 |
+
size: tuple[int, int] | int | None = None,
|
| 256 |
+
scale: float | None = None,
|
| 257 |
+
modality_format: str | None = None,
|
| 258 |
+
) -> np.ndarray | torch.Tensor | Image.Image:
|
| 259 |
+
"""
|
| 260 |
+
Resize data of different formats (numpy arrays, PyTorch tensors, PIL Images) to a target size.
|
| 261 |
+
|
| 262 |
+
Args:
|
| 263 |
+
data: Input data to resize (numpy.ndarray, torch.Tensor, or PIL.Image.Image)
|
| 264 |
+
size: Target size as tuple (height, width) or single int for long-side scaling
|
| 265 |
+
scale: Scale factor to apply to the original dimensions
|
| 266 |
+
modality_format: Type of data being resized ('depth', 'normals', or None)
|
| 267 |
+
Affects interpolation method used
|
| 268 |
+
|
| 269 |
+
Returns:
|
| 270 |
+
Resized data in the same format as the input
|
| 271 |
+
|
| 272 |
+
Raises:
|
| 273 |
+
ValueError: If neither size nor scale is provided, or if both are provided
|
| 274 |
+
TypeError: If data is not a supported type
|
| 275 |
+
"""
|
| 276 |
+
# Validate input parameters
|
| 277 |
+
if size is not None and scale is not None:
|
| 278 |
+
raise ValueError("Only one of size or scale should be provided.")
|
| 279 |
+
|
| 280 |
+
# Calculate size from scale if needed
|
| 281 |
+
if size is None:
|
| 282 |
+
if scale is None:
|
| 283 |
+
raise ValueError("Either size or scale must be provided.")
|
| 284 |
+
|
| 285 |
+
size = (1, 1)
|
| 286 |
+
if isinstance(data, (np.ndarray, torch.Tensor)):
|
| 287 |
+
size = (int(data.shape[-2] * scale), int(data.shape[-1] * scale))
|
| 288 |
+
elif isinstance(data, Image.Image):
|
| 289 |
+
size = (int(data.size[1] * scale), int(data.size[0] * scale))
|
| 290 |
+
else:
|
| 291 |
+
raise TypeError(f"Unsupported data type '{type(data)}'.")
|
| 292 |
+
|
| 293 |
+
# Handle long-side scaling when size is a single integer
|
| 294 |
+
elif isinstance(size, int):
|
| 295 |
+
long_side = size
|
| 296 |
+
if isinstance(data, (np.ndarray, torch.Tensor)):
|
| 297 |
+
if isinstance(data, torch.Tensor) and data.is_nested:
|
| 298 |
+
raise ValueError(
|
| 299 |
+
"Long-side scaling not support for nested tensors, use fixed size instead."
|
| 300 |
+
)
|
| 301 |
+
h, w = data.shape[-2], data.shape[-1]
|
| 302 |
+
elif isinstance(data, Image.Image):
|
| 303 |
+
w, h = data.size
|
| 304 |
+
else:
|
| 305 |
+
raise TypeError(f"Unsupported data type '{type(data)}'.")
|
| 306 |
+
if h > w:
|
| 307 |
+
size = (long_side, int(w * long_side / h))
|
| 308 |
+
else:
|
| 309 |
+
size = (int(h * long_side / w), long_side)
|
| 310 |
+
|
| 311 |
+
target_height, target_width = size
|
| 312 |
+
|
| 313 |
+
# Set interpolation method based on modality
|
| 314 |
+
if modality_format in ["depth", "normals"]:
|
| 315 |
+
interpolation = Image.Resampling.NEAREST
|
| 316 |
+
torch_interpolation = "nearest"
|
| 317 |
+
else:
|
| 318 |
+
interpolation = Image.Resampling.LANCZOS
|
| 319 |
+
torch_interpolation = "bilinear"
|
| 320 |
+
|
| 321 |
+
# Handle numpy arrays
|
| 322 |
+
if isinstance(data, np.ndarray):
|
| 323 |
+
pil_image = Image.fromarray(data)
|
| 324 |
+
resized_image = pil_image.resize((target_width, target_height), interpolation)
|
| 325 |
+
return np.array(resized_image)
|
| 326 |
+
|
| 327 |
+
# Handle PIL images
|
| 328 |
+
elif isinstance(data, Image.Image):
|
| 329 |
+
return data.resize((target_width, target_height), interpolation)
|
| 330 |
+
|
| 331 |
+
# Handle PyTorch tensors
|
| 332 |
+
elif isinstance(data, torch.Tensor):
|
| 333 |
+
if data.is_nested:
|
| 334 |
+
# special handling for nested tensors
|
| 335 |
+
return torch.stack(
|
| 336 |
+
[
|
| 337 |
+
resize(nested_tensor, size, scale, modality_format)
|
| 338 |
+
for nested_tensor in data
|
| 339 |
+
]
|
| 340 |
+
)
|
| 341 |
+
original_dim = data.ndim
|
| 342 |
+
if original_dim == 2: # (H, W)
|
| 343 |
+
data = data.unsqueeze(0).unsqueeze(0) # Add channel and batch dimensions
|
| 344 |
+
elif original_dim == 3: # (C/B, H W)
|
| 345 |
+
if modality_format == "depth":
|
| 346 |
+
data = data.unsqueeze(1) # channel batch dimension
|
| 347 |
+
else:
|
| 348 |
+
data = data.unsqueeze(0) # Add batch dimension
|
| 349 |
+
resized_tensor = F.interpolate(
|
| 350 |
+
data,
|
| 351 |
+
size=(target_height, target_width),
|
| 352 |
+
mode=torch_interpolation,
|
| 353 |
+
align_corners=False if torch_interpolation != "nearest" else None,
|
| 354 |
+
)
|
| 355 |
+
if original_dim == 2:
|
| 356 |
+
return resized_tensor.squeeze(0).squeeze(
|
| 357 |
+
0
|
| 358 |
+
) # Remove batch and channel dimensions
|
| 359 |
+
elif original_dim == 3:
|
| 360 |
+
if modality_format == "depth":
|
| 361 |
+
return resized_tensor.squeeze(1) # Remove channel dimension
|
| 362 |
+
|
| 363 |
+
return resized_tensor.squeeze(0) # Remove batch dimension
|
| 364 |
+
else:
|
| 365 |
+
return resized_tensor
|
| 366 |
+
|
| 367 |
+
else:
|
| 368 |
+
raise TypeError(f"Unsupported data type '{type(data)}'.")
|
mapanything/utils/wai/scene_frame.py
ADDED
|
@@ -0,0 +1,431 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import os
|
| 3 |
+
import random
|
| 4 |
+
import re
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
from typing import Any
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
|
| 10 |
+
from mapanything.utils.wai.io import (
|
| 11 |
+
_load_readable,
|
| 12 |
+
_load_scene_meta,
|
| 13 |
+
get_processing_state,
|
| 14 |
+
)
|
| 15 |
+
|
| 16 |
+
logger = logging.getLogger(__name__)
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def get_scene_frame_names(
|
| 20 |
+
cfg: dict | object,
|
| 21 |
+
root: Path | str | None = None,
|
| 22 |
+
scene_frames_fn: str | None = None,
|
| 23 |
+
keyframes: bool = True,
|
| 24 |
+
) -> dict[str, list[str | float]] | None:
|
| 25 |
+
"""
|
| 26 |
+
Retrieve scene frame names based on configuration and optional parameters.
|
| 27 |
+
|
| 28 |
+
This function determines the scene frame names by resolving the scene frame file
|
| 29 |
+
and applying any necessary filters based on the provided configuration.
|
| 30 |
+
|
| 31 |
+
Args:
|
| 32 |
+
cfg: Configuration object containing settings and parameters.
|
| 33 |
+
root: Optional root directory path. If not provided, it will be fetched from cfg.
|
| 34 |
+
scene_frames_fn: Optional scene frames file name. If not provided, it will be fetched from cfg.
|
| 35 |
+
keyframes: Optional, used only for a video. If True (default), return only keyframes (with camera poses).
|
| 36 |
+
|
| 37 |
+
Returns:
|
| 38 |
+
A dictionary mapping scene names to their respective frame names.
|
| 39 |
+
"""
|
| 40 |
+
scene_frames_fn = (
|
| 41 |
+
cfg.get("scene_frames_fn") if scene_frames_fn is None else scene_frames_fn
|
| 42 |
+
)
|
| 43 |
+
scene_frame_names = None
|
| 44 |
+
if scene_frames_fn is not None:
|
| 45 |
+
# load scene_frames based on scene_frame file
|
| 46 |
+
scene_frame_names = _resolve_scene_frames_fn(scene_frames_fn)
|
| 47 |
+
|
| 48 |
+
scene_names = get_scene_names(
|
| 49 |
+
cfg,
|
| 50 |
+
root=root,
|
| 51 |
+
scene_names=(
|
| 52 |
+
list(scene_frame_names.keys()) if scene_frame_names is not None else None
|
| 53 |
+
),
|
| 54 |
+
)
|
| 55 |
+
scene_frame_names = _resolve_scene_frame_names(
|
| 56 |
+
cfg,
|
| 57 |
+
scene_names,
|
| 58 |
+
root=root,
|
| 59 |
+
scene_frame_names=scene_frame_names,
|
| 60 |
+
keyframes=keyframes,
|
| 61 |
+
)
|
| 62 |
+
return scene_frame_names
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def get_scene_names(
|
| 66 |
+
cfg: dict | object,
|
| 67 |
+
root: Path | str | None = None,
|
| 68 |
+
scene_names: list[str] | None = None,
|
| 69 |
+
shuffle: bool = False,
|
| 70 |
+
) -> list[str]:
|
| 71 |
+
"""
|
| 72 |
+
Retrieve scene names based on the provided configuration and optional parameters.
|
| 73 |
+
|
| 74 |
+
This function determines the scene names by checking the root directory for subdirectories
|
| 75 |
+
and applying any necessary filters based on the provided configuration.
|
| 76 |
+
|
| 77 |
+
Args:
|
| 78 |
+
cfg: Configuration object containing settings and parameters.
|
| 79 |
+
root: Optional root directory path. If not provided, it will be fetched from cfg.
|
| 80 |
+
scene_names: Optional list of scene names. If not provided, it will be determined from the root directory.
|
| 81 |
+
shuffle: Optional bool. Default to False. If True, it will return the list of scene names in random order.
|
| 82 |
+
|
| 83 |
+
Returns:
|
| 84 |
+
A list of scene names after applying any filters specified in the configuration.
|
| 85 |
+
"""
|
| 86 |
+
root = cfg.get("root") if root is None else root
|
| 87 |
+
if root is not None:
|
| 88 |
+
# Check if the root exists
|
| 89 |
+
if not Path(root).exists():
|
| 90 |
+
raise IOError(f"Root directory does not exist: {root}")
|
| 91 |
+
|
| 92 |
+
# Check if the root is a directory
|
| 93 |
+
if not Path(root).is_dir():
|
| 94 |
+
raise IOError(f"Root directory is not a directory: {root}")
|
| 95 |
+
|
| 96 |
+
if scene_names is None:
|
| 97 |
+
scene_filters = cfg.get("scene_filters")
|
| 98 |
+
if (
|
| 99 |
+
scene_filters
|
| 100 |
+
and len(scene_filters) == 1
|
| 101 |
+
and isinstance(scene_filters[0], list)
|
| 102 |
+
and all(isinstance(entry, str) for entry in scene_filters[0])
|
| 103 |
+
):
|
| 104 |
+
# Shortcut the scene_names if the scene_filters is only a list of scene names
|
| 105 |
+
scene_names = scene_filters[0]
|
| 106 |
+
else:
|
| 107 |
+
# List all subdirectories in the root as scenes
|
| 108 |
+
scene_names = sorted(
|
| 109 |
+
[entry.name for entry in os.scandir(root) if entry.is_dir()]
|
| 110 |
+
)
|
| 111 |
+
# Filter scenes based on scene_filters
|
| 112 |
+
scene_names = _filter_scenes(root, scene_names, cfg.get("scene_filters"))
|
| 113 |
+
|
| 114 |
+
# shuffle the list if needed (in place)
|
| 115 |
+
if shuffle:
|
| 116 |
+
random.shuffle(scene_names)
|
| 117 |
+
|
| 118 |
+
return scene_names
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
def _filter_scenes(
|
| 122 |
+
root: Path | str,
|
| 123 |
+
scene_names: list[str],
|
| 124 |
+
scene_filters: tuple | list | None,
|
| 125 |
+
) -> list[str]:
|
| 126 |
+
if scene_filters is None:
|
| 127 |
+
return scene_names
|
| 128 |
+
|
| 129 |
+
if not isinstance(scene_filters, (tuple, list)):
|
| 130 |
+
raise ValueError("scene_filters must be a list or tuple")
|
| 131 |
+
|
| 132 |
+
for scene_filter in scene_filters:
|
| 133 |
+
if scene_filter in [None, "all"]:
|
| 134 |
+
pass
|
| 135 |
+
|
| 136 |
+
elif isinstance(scene_filter, (tuple, list)):
|
| 137 |
+
if len(scene_filter) == 0:
|
| 138 |
+
raise ValueError("scene_filter cannot be empty")
|
| 139 |
+
|
| 140 |
+
elif all(isinstance(x, int) for x in scene_filter):
|
| 141 |
+
if len(scene_filter) == 2:
|
| 142 |
+
# start/end index
|
| 143 |
+
scene_names = scene_names[scene_filter[0] : scene_filter[1]]
|
| 144 |
+
elif len(scene_filter) == 3:
|
| 145 |
+
# start/end/step
|
| 146 |
+
scene_names = scene_names[
|
| 147 |
+
scene_filter[0] : scene_filter[1] : scene_filter[2]
|
| 148 |
+
]
|
| 149 |
+
else:
|
| 150 |
+
# omegaconf conversion issue (converts strings to integers whenever possible)
|
| 151 |
+
if str(scene_filter[0]) in scene_names:
|
| 152 |
+
scene_names = [str(s) for s in scene_filter]
|
| 153 |
+
else:
|
| 154 |
+
raise ValueError(
|
| 155 |
+
"scene_filter format [start_idx, end_idx] or [start_idx, end_idx, step_size] or [scene_name1, scene_name2, ...]"
|
| 156 |
+
)
|
| 157 |
+
|
| 158 |
+
elif all(isinstance(x, str) for x in scene_filter):
|
| 159 |
+
# explicit scene names
|
| 160 |
+
if set(scene_filter).issubset(set(scene_names)):
|
| 161 |
+
scene_names = list(scene_filter)
|
| 162 |
+
else:
|
| 163 |
+
logger.warning(
|
| 164 |
+
f"Scene(s) not available: {set(scene_filter) - set(scene_names)}"
|
| 165 |
+
)
|
| 166 |
+
scene_names = list(set(scene_names) & set(scene_filter))
|
| 167 |
+
else:
|
| 168 |
+
raise TypeError(
|
| 169 |
+
f"Scene filter type not supported: {type(scene_filter)}"
|
| 170 |
+
)
|
| 171 |
+
|
| 172 |
+
elif isinstance(scene_filter, dict):
|
| 173 |
+
# reserved key words
|
| 174 |
+
if modality := scene_filter.get("exists"):
|
| 175 |
+
scene_names = [
|
| 176 |
+
scene_name
|
| 177 |
+
for scene_name in scene_names
|
| 178 |
+
if Path(root, scene_name, modality).exists()
|
| 179 |
+
]
|
| 180 |
+
|
| 181 |
+
elif modality := scene_filter.get("exists_not"):
|
| 182 |
+
scene_names = [
|
| 183 |
+
scene_name
|
| 184 |
+
for scene_name in scene_names
|
| 185 |
+
if not Path(root, scene_name, modality).exists()
|
| 186 |
+
]
|
| 187 |
+
|
| 188 |
+
elif process_filter := scene_filter.get("process_state"):
|
| 189 |
+
# filter for where <process_key> has <process_state>
|
| 190 |
+
(process_key, process_state) = process_filter
|
| 191 |
+
filtered_scene_names = []
|
| 192 |
+
for scene_name in scene_names:
|
| 193 |
+
# load processing state and check for
|
| 194 |
+
processing_state = get_processing_state(Path(root, scene_name))
|
| 195 |
+
if "*" in process_key: # regex matching
|
| 196 |
+
for process_name in processing_state:
|
| 197 |
+
if re.match(process_key, process_name):
|
| 198 |
+
process_key = process_name
|
| 199 |
+
break
|
| 200 |
+
if process_key not in processing_state:
|
| 201 |
+
continue
|
| 202 |
+
if processing_state[process_key]["state"] == process_state:
|
| 203 |
+
filtered_scene_names.append(scene_name)
|
| 204 |
+
scene_names = filtered_scene_names
|
| 205 |
+
|
| 206 |
+
elif process_filter := scene_filter.get("process_state_not"):
|
| 207 |
+
# filter for where <process_key> does not have <process_state>
|
| 208 |
+
(process_key, process_state) = process_filter
|
| 209 |
+
filtered_scene_names = []
|
| 210 |
+
for scene_name in scene_names:
|
| 211 |
+
# load processing state and check for
|
| 212 |
+
try:
|
| 213 |
+
processing_state = get_processing_state(Path(root, scene_name))
|
| 214 |
+
except Exception:
|
| 215 |
+
filtered_scene_names.append(scene_name)
|
| 216 |
+
continue
|
| 217 |
+
if "*" in process_key: # regex matching
|
| 218 |
+
for process_name in processing_state:
|
| 219 |
+
if re.match(process_key, process_name):
|
| 220 |
+
process_key = process_name
|
| 221 |
+
break
|
| 222 |
+
if (process_key not in processing_state) or (
|
| 223 |
+
processing_state[process_key]["state"] != process_state
|
| 224 |
+
):
|
| 225 |
+
filtered_scene_names.append(scene_name)
|
| 226 |
+
scene_names = filtered_scene_names
|
| 227 |
+
|
| 228 |
+
else:
|
| 229 |
+
raise ValueError(f"Scene filter not supported: {scene_filter}")
|
| 230 |
+
|
| 231 |
+
elif isinstance(scene_filter, str):
|
| 232 |
+
# regex
|
| 233 |
+
scene_names = [
|
| 234 |
+
scene_name
|
| 235 |
+
for scene_name in scene_names
|
| 236 |
+
if re.fullmatch(scene_filter, scene_name)
|
| 237 |
+
]
|
| 238 |
+
else:
|
| 239 |
+
raise ValueError(f"Scene filter not supported: {scene_filter}")
|
| 240 |
+
|
| 241 |
+
return scene_names
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
def _resolve_scene_frames_fn(scene_frames_fn: str) -> dict[str, list[str] | None]:
|
| 245 |
+
# support for file list in forms of lists or dicts
|
| 246 |
+
# containing scene_names [-> frames]
|
| 247 |
+
scene_frames_list = _load_readable(scene_frames_fn)
|
| 248 |
+
scene_frame_names = {}
|
| 249 |
+
|
| 250 |
+
# TODO: The following code seems unreachable as scene_frames_list is always a dict
|
| 251 |
+
if isinstance(scene_frames_list, (list, tuple)):
|
| 252 |
+
for entry in scene_frames_list:
|
| 253 |
+
if isinstance(entry, (tuple, list)):
|
| 254 |
+
if (
|
| 255 |
+
(len(entry) != 2)
|
| 256 |
+
or (not isinstance(entry[0], str))
|
| 257 |
+
or (not isinstance(entry[1], list))
|
| 258 |
+
):
|
| 259 |
+
raise NotImplementedError(
|
| 260 |
+
"Only supports lists of [<scene_name>, [frame_names]]"
|
| 261 |
+
)
|
| 262 |
+
scene_frame_names[entry[0]] = entry[1]
|
| 263 |
+
elif isinstance(entry, str):
|
| 264 |
+
scene_frame_names[entry] = None
|
| 265 |
+
elif isinstance(entry, dict):
|
| 266 |
+
# scene_name -> frames
|
| 267 |
+
raise NotImplementedError("Dict entry not supported yet")
|
| 268 |
+
else:
|
| 269 |
+
raise IOError(f"File list contains an entry of wrong format: {entry}")
|
| 270 |
+
|
| 271 |
+
elif isinstance(scene_frames_list, dict):
|
| 272 |
+
# scene_name -> frames
|
| 273 |
+
for scene_name, frame in scene_frames_list.items():
|
| 274 |
+
if isinstance(frame, (tuple, list)):
|
| 275 |
+
scene_frame_names[scene_name] = frame
|
| 276 |
+
elif isinstance(frame, dict):
|
| 277 |
+
if "frame_names" in frame:
|
| 278 |
+
scene_frame_names[scene_name] = frame["frame_names"]
|
| 279 |
+
else:
|
| 280 |
+
raise IOError(f"Scene frames format not supported: {frame}")
|
| 281 |
+
elif frame is None:
|
| 282 |
+
scene_frame_names[scene_name] = frame
|
| 283 |
+
else:
|
| 284 |
+
raise IOError(f"Scene frames format not supported: {frame}")
|
| 285 |
+
|
| 286 |
+
else:
|
| 287 |
+
raise IOError(f"Scene frames format not supported: {scene_frames_list}")
|
| 288 |
+
|
| 289 |
+
return scene_frame_names
|
| 290 |
+
|
| 291 |
+
|
| 292 |
+
def _resolve_scene_frame_names(
|
| 293 |
+
cfg: dict | object,
|
| 294 |
+
scene_names: list[str],
|
| 295 |
+
root: Path | str | None = None,
|
| 296 |
+
scene_frame_names: dict[str, list[str | float] | None] | None = None,
|
| 297 |
+
keyframes: bool = True,
|
| 298 |
+
) -> dict[str, list[str]]:
|
| 299 |
+
root = cfg.get("root") if root is None else root
|
| 300 |
+
if scene_frame_names is not None:
|
| 301 |
+
# restrict to the additional scene-level prefiltering
|
| 302 |
+
scene_frame_names = {
|
| 303 |
+
scene_name: scene_frame_names[scene_name] for scene_name in scene_names
|
| 304 |
+
}
|
| 305 |
+
# dict already loaded, apply additional filters
|
| 306 |
+
for scene_name, frame_names in scene_frame_names.items():
|
| 307 |
+
if frame_names is None:
|
| 308 |
+
scene_meta = _load_scene_meta(
|
| 309 |
+
Path(
|
| 310 |
+
root, scene_name, cfg.get("scene_meta_path", "scene_meta.json")
|
| 311 |
+
)
|
| 312 |
+
)
|
| 313 |
+
frame_names = [frame["frame_name"] for frame in scene_meta["frames"]]
|
| 314 |
+
# TODO: add some logic for video keyframes
|
| 315 |
+
|
| 316 |
+
scene_frame_names[scene_name] = _filter_frame_names(
|
| 317 |
+
root, frame_names, scene_name, cfg.get("frame_filters")
|
| 318 |
+
)
|
| 319 |
+
else:
|
| 320 |
+
scene_frame_names = {}
|
| 321 |
+
for scene_name in scene_names:
|
| 322 |
+
scene_meta = _load_scene_meta(
|
| 323 |
+
Path(root, scene_name, cfg.get("scene_meta_path", "scene_meta.json"))
|
| 324 |
+
)
|
| 325 |
+
if not keyframes:
|
| 326 |
+
frame_names = get_video_frames(scene_meta)
|
| 327 |
+
if frame_names is None:
|
| 328 |
+
keyframes = True
|
| 329 |
+
if keyframes:
|
| 330 |
+
frame_names = [frame["frame_name"] for frame in scene_meta["frames"]]
|
| 331 |
+
frame_names = _filter_frame_names(
|
| 332 |
+
root, frame_names, scene_name, cfg.get("frame_filters")
|
| 333 |
+
)
|
| 334 |
+
scene_frame_names[scene_name] = frame_names
|
| 335 |
+
return scene_frame_names
|
| 336 |
+
|
| 337 |
+
|
| 338 |
+
def _filter_frame_names(
|
| 339 |
+
root: Path | str,
|
| 340 |
+
frame_names: list[str],
|
| 341 |
+
scene_name: str,
|
| 342 |
+
frame_filters: list | tuple | None,
|
| 343 |
+
) -> list[str]:
|
| 344 |
+
if frame_filters is None:
|
| 345 |
+
return frame_names
|
| 346 |
+
|
| 347 |
+
if not isinstance(frame_filters, (tuple, list)):
|
| 348 |
+
raise ValueError("frame_filters must be a list or tuple")
|
| 349 |
+
|
| 350 |
+
for frame_filter in frame_filters:
|
| 351 |
+
if frame_filter in [None, "all"]:
|
| 352 |
+
pass
|
| 353 |
+
|
| 354 |
+
elif isinstance(frame_filter, (tuple, list)):
|
| 355 |
+
if len(frame_filter) == 0:
|
| 356 |
+
raise ValueError("frame_filter cannot be empty")
|
| 357 |
+
|
| 358 |
+
if isinstance(frame_filter[0], int):
|
| 359 |
+
if len(frame_filter) == 2:
|
| 360 |
+
# start/end index
|
| 361 |
+
frame_names = frame_names[frame_filter[0] : frame_filter[1]]
|
| 362 |
+
|
| 363 |
+
elif len(frame_filter) == 3:
|
| 364 |
+
# start/end/step
|
| 365 |
+
frame_names = frame_names[
|
| 366 |
+
frame_filter[0] : frame_filter[1] : frame_filter[2]
|
| 367 |
+
]
|
| 368 |
+
|
| 369 |
+
else:
|
| 370 |
+
raise ValueError(
|
| 371 |
+
"frame_filter format [start_idx, end_idx] or [start_idx, end_idx,step_size]"
|
| 372 |
+
)
|
| 373 |
+
else:
|
| 374 |
+
raise TypeError(
|
| 375 |
+
f"frame_filter[0] type not supported: {type(frame_filter[0])}"
|
| 376 |
+
)
|
| 377 |
+
|
| 378 |
+
elif isinstance(frame_filter, str):
|
| 379 |
+
# reserved key words
|
| 380 |
+
if match := re.match("exists: (.+)", frame_filter):
|
| 381 |
+
modality = match.group(1)
|
| 382 |
+
frame_names = [
|
| 383 |
+
frame_name
|
| 384 |
+
for frame_name in frame_names
|
| 385 |
+
if any(Path(root, scene_name, modality).glob(f"{frame_name}.*"))
|
| 386 |
+
]
|
| 387 |
+
|
| 388 |
+
elif match := re.match("!exists: (.+)", frame_filter):
|
| 389 |
+
modality = match.group(1)
|
| 390 |
+
frame_names = [
|
| 391 |
+
frame_name
|
| 392 |
+
for frame_name in frame_names
|
| 393 |
+
if not any(Path(root, scene_name, modality).glob(f"{frame_name}.*"))
|
| 394 |
+
]
|
| 395 |
+
|
| 396 |
+
else: # general regex
|
| 397 |
+
frame_names = [
|
| 398 |
+
frame_name
|
| 399 |
+
for frame_name in frame_names
|
| 400 |
+
if re.match(frame_filter, frame_name)
|
| 401 |
+
]
|
| 402 |
+
|
| 403 |
+
else:
|
| 404 |
+
raise ValueError(f"frame_filter type not supported: {type(frame_filter)}")
|
| 405 |
+
|
| 406 |
+
return frame_names
|
| 407 |
+
|
| 408 |
+
|
| 409 |
+
def get_video_frames(scene_meta: dict[str, Any]):
|
| 410 |
+
"""
|
| 411 |
+
Return names of video frames.
|
| 412 |
+
Args:
|
| 413 |
+
scene_meta: dictionary with scene_meat data.
|
| 414 |
+
|
| 415 |
+
Returns:
|
| 416 |
+
A list of video frame names.
|
| 417 |
+
"""
|
| 418 |
+
image_modality = [mod for mod in scene_meta["frame_modalities"] if "image" in mod]
|
| 419 |
+
if len(image_modality) > 0:
|
| 420 |
+
image_modality = scene_meta["frame_modalities"][image_modality[0]]
|
| 421 |
+
if "chunks" in image_modality:
|
| 422 |
+
file_list = image_modality["chunks"]
|
| 423 |
+
else:
|
| 424 |
+
file_list = [image_modality]
|
| 425 |
+
frame_names = []
|
| 426 |
+
for chunk in file_list:
|
| 427 |
+
start, end, fps = chunk["start"], chunk["end"], chunk["fps"]
|
| 428 |
+
chunk_frame_names = np.arange(start, end, 1.0 / fps).tolist()
|
| 429 |
+
frame_names += chunk_frame_names
|
| 430 |
+
return frame_names
|
| 431 |
+
return None
|
mapanything/utils/wai/semantics.py
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
This utils script contains PORTAGE of wai-core semantics methods for MapAnything.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
from PIL import Image
|
| 7 |
+
|
| 8 |
+
INVALID_ID = 0
|
| 9 |
+
INVALID_COLOR = (0, 0, 0)
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def load_semantic_color_mapping(filename: str = "colors_fps_5k.npz") -> np.ndarray:
|
| 13 |
+
"""Loads a precomputed colormap."""
|
| 14 |
+
from mapanything.utils.wai.core import WAI_COLORMAP_PATH
|
| 15 |
+
|
| 16 |
+
return np.load(WAI_COLORMAP_PATH / filename).get("arr_0")
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def apply_id_to_color_mapping(
|
| 20 |
+
data_id: np.ndarray | Image.Image,
|
| 21 |
+
semantic_color_mapping: np.ndarray,
|
| 22 |
+
) -> tuple[np.ndarray, dict[int, tuple[int, int, int]]]:
|
| 23 |
+
"""Maps semantic class/instance IDs to RGB colors."""
|
| 24 |
+
if isinstance(data_id, Image.Image):
|
| 25 |
+
data_id = np.array(data_id)
|
| 26 |
+
|
| 27 |
+
max_color_id = semantic_color_mapping.shape[0] - 1
|
| 28 |
+
max_data_id = data_id.max()
|
| 29 |
+
if max_data_id > max_color_id:
|
| 30 |
+
raise ValueError("The provided color palette does not have enough colors!")
|
| 31 |
+
|
| 32 |
+
# Create palette containing the id->color mappings of the input data IDs
|
| 33 |
+
unique_indices = np.unique(data_id).tolist()
|
| 34 |
+
color_palette = {
|
| 35 |
+
index: semantic_color_mapping[index, :].tolist() for index in unique_indices
|
| 36 |
+
}
|
| 37 |
+
|
| 38 |
+
data_colors = semantic_color_mapping[data_id]
|
| 39 |
+
|
| 40 |
+
return data_colors, color_palette
|
requirements.txt
CHANGED
|
@@ -18,4 +18,5 @@ einops
|
|
| 18 |
requests
|
| 19 |
psutil
|
| 20 |
tqdm
|
|
|
|
| 21 |
uniception==0.1.4
|
|
|
|
| 18 |
requests
|
| 19 |
psutil
|
| 20 |
tqdm
|
| 21 |
+
safetensors
|
| 22 |
uniception==0.1.4
|