aknapitsch committed · Commit 37de32d · 1 Parent(s): 8c1e404

simpler inference and refactoring

Files changed (47)
  1. app.py +165 -445
  2. hf_utils/vgg_geometry.py +0 -166
  3. hf_utils/visual_util.py +5 -5
  4. mapanything/__init__.py +0 -0
  5. mapanything/datasets/wai/ase.py +1 -1
  6. mapanything/datasets/wai/bedlam.py +1 -1
  7. mapanything/datasets/wai/blendedmvs.py +2 -9
  8. mapanything/datasets/wai/dl3dv.py +7 -28
  9. mapanything/datasets/wai/dtu.py +1 -1
  10. mapanything/datasets/wai/dynamicreplica.py +1 -1
  11. mapanything/datasets/wai/eth3d.py +1 -1
  12. mapanything/datasets/wai/gta_sfm.py +1 -1
  13. mapanything/datasets/wai/matrixcity.py +1 -1
  14. mapanything/datasets/wai/megadepth.py +2 -9
  15. mapanything/datasets/wai/mpsd.py +2 -9
  16. mapanything/datasets/wai/mvs_synth.py +1 -1
  17. mapanything/datasets/wai/paralleldomain4d.py +1 -1
  18. mapanything/datasets/wai/sailvos3d.py +1 -1
  19. mapanything/datasets/wai/scannetpp.py +1 -1
  20. mapanything/datasets/wai/spring.py +2 -9
  21. mapanything/datasets/wai/structured3d.py +1 -1
  22. mapanything/datasets/wai/tav2_wb.py +2 -9
  23. mapanything/datasets/wai/unrealstereo4k.py +1 -1
  24. mapanything/datasets/wai/xrooms.py +1 -1
  25. mapanything/models/external/README.md +5 -0
  26. mapanything/models/external/moge/models/v1.py +1 -1
  27. mapanything/models/external/moge/models/v2.py +1 -1
  28. mapanything/models/mapanything/ablations.py +4 -2
  29. mapanything/models/mapanything/model.py +220 -4
  30. mapanything/models/mapanything/modular_dust3r.py +4 -2
  31. mapanything/train/losses.py +283 -9
  32. mapanything/utils/geometry.py +91 -0
  33. mapanything/utils/image.py +11 -10
  34. mapanything/utils/inference.py +389 -0
  35. mapanything/utils/viz.py +2 -2
  36. mapanything/utils/wai/__init__.py +3 -0
  37. mapanything/utils/wai/basic_dataset.py +131 -0
  38. mapanything/utils/wai/camera.py +263 -0
  39. mapanything/utils/wai/colormaps/colors_fps_5k.npz +3 -0
  40. mapanything/utils/wai/core.py +492 -0
  41. mapanything/utils/wai/intersection_check.py +462 -0
  42. mapanything/utils/wai/io.py +1373 -0
  43. mapanything/utils/wai/m_ops.py +346 -0
  44. mapanything/utils/wai/ops.py +368 -0
  45. mapanything/utils/wai/scene_frame.py +431 -0
  46. mapanything/utils/wai/semantics.py +40 -0
  47. requirements.txt +1 -0
app.py CHANGED
@@ -18,24 +18,21 @@ import gradio as gr
18
  import numpy as np
19
  import spaces
20
  import torch
21
- from huggingface_hub import hf_hub_download
22
 
23
  sys.path.append("mapanything/")
24
 
25
  from hf_utils.css_and_html import (
 
 
26
  get_acknowledgements_html,
27
  get_description_html,
28
  get_gradio_theme,
29
  get_header_html,
30
- GRADIO_CSS,
31
- MEASURE_INSTRUCTIONS_HTML,
32
  )
33
- from hf_utils.vgg_geometry import unproject_depth_map_to_point_map
34
  from hf_utils.visual_util import predictions_to_glb
35
- from mapanything.models import init_model
36
- from mapanything.utils.geometry import depth_edge, normals_edge, points_to_normals
37
  from mapanything.utils.image import load_images, rgb
38
- from mapanything.utils.inference import loss_of_one_batch_multi_view
39
 
40
 
41
  def get_logo_base64():
@@ -103,69 +100,16 @@ def init_hydra_config(config_path, overrides=None):
103
  return cfg
104
 
105
 
106
- def init_inference_model(config, ckpt_path, device):
107
- "Initialize the model for inference"
108
- if isinstance(config, dict):
109
- config_path = config["path"]
110
- overrrides = config["config_overrides"]
111
- model_args = init_hydra_config(config_path, overrides=overrrides)
112
- model = init_model(model_args.model.model_str, model_args.model.model_config)
113
- else:
114
- config_path = config
115
- model_args = init_hydra_config(config_path)
116
- model = init_model(model_args.model_str, model_args.model_config)
117
- model.to(device)
118
- if ckpt_path is not None:
119
- print("Loading model from: ", ckpt_path)
120
-
121
- # Load HuggingFace token for private repositories
122
- hf_token = load_hf_token()
123
-
124
- # Try to download from HuggingFace Hub first if it's a HF URL
125
- if "huggingface.co" in ckpt_path:
126
- try:
127
- # Extract repo_id and filename from URL
128
- # URL format: https://huggingface.co/facebook/MapAnything/resolve/main/mapa_curri_24v_13d_48ipg_64g.pth
129
- parts = ckpt_path.replace("https://huggingface.co/", "").split("/")
130
- repo_id = f"{parts[0]}/{parts[1]}" # e.g., "facebook/MapAnything"
131
- filename = "/".join(
132
- parts[4:]
133
- ) # e.g., "mapa_curri_24v_13d_48ipg_64g.pth"
134
-
135
- print(f"Downloading from HuggingFace Hub: {repo_id}/{filename}")
136
- local_file = hf_hub_download(
137
- repo_id=repo_id,
138
- filename=filename,
139
- token=hf_token,
140
- cache_dir=None, # Use default cache
141
- )
142
- ckpt = torch.load(local_file, map_location=device, weights_only=False)
143
- except Exception as e:
144
- print(f"HuggingFace Hub download failed: {e}")
145
- print("Falling back to torch.hub.load_state_dict_from_url...")
146
- # Fallback to original method
147
- ckpt = torch.hub.load_state_dict_from_url(
148
- ckpt_path, map_location=device
149
- )
150
- else:
151
- # Use original method for non-HF URLs
152
- ckpt = torch.hub.load_state_dict_from_url(ckpt_path, map_location=device)
153
-
154
- print(model.load_state_dict(ckpt["model"], strict=False))
155
- model.eval()
156
- return model
157
-
158
-
159
  # MapAnything Configuration
160
  high_level_config = {
161
  "path": "configs/train.yaml",
 
162
  "config_overrides": [
163
  "machine=aws",
164
  "model=mapanything",
165
  "model/task=images_only",
166
  "model.encoder.uses_torch_hub=false",
167
  ],
168
- "checkpoint_path": "https://huggingface.co/facebook/MapAnything/resolve/main/mapa_curri_24v_13d_48ipg_64g.pth",
169
  "trained_with_amp": True,
170
  "trained_with_amp_dtype": "fp16",
171
  "data_norm_type": "dinov2",
@@ -181,7 +125,7 @@ model = None
181
  # 1) Core model inference
182
  # -------------------------------------------------------------------------
183
  @spaces.GPU(duration=120)
184
- def run_model(target_dir, model_placeholder):
185
  """
186
  Run the MapAnything model on images in the 'target_dir/images' folder and return predictions.
187
  """
@@ -191,15 +135,16 @@ def run_model(target_dir, model_placeholder):
191
  # Device check
192
  device = "cuda" if torch.cuda.is_available() else "cpu"
193
  device = torch.device(device)
194
- # if not torch.cuda.is_available():
195
- # raise ValueError("CUDA is not available. Check your environment.")
196
 
197
  # Initialize model if not already done
198
  if model is None:
199
  print("Initializing MapAnything model...")
200
- model = init_inference_model(
201
- high_level_config, high_level_config["checkpoint_path"], device
 
 
202
  )
 
203
  else:
204
  model = model.to(device)
205
 
@@ -208,30 +153,18 @@ def run_model(target_dir, model_placeholder):
208
  # Load images using MapAnything's load_images function
209
  print("Loading images...")
210
  image_folder_path = os.path.join(target_dir, "images")
211
- views = load_images(
212
- image_folder_path,
213
- resolution_set=high_level_config["resolution"],
214
- verbose=False,
215
- norm_type=high_level_config["data_norm_type"],
216
- patch_size=high_level_config["patch_size"],
217
- stride=1,
218
- )
219
 
220
  print(f"Loaded {len(views)} images")
221
  if len(views) == 0:
222
  raise ValueError("No images found. Check your upload.")
223
 
224
- # Run inference using MapAnything's inference function
225
- print("Running MapAnything inference...")
226
- with torch.no_grad():
227
- pred_result = loss_of_one_batch_multi_view(
228
- views,
229
- model,
230
- None,
231
- device,
232
- use_amp=high_level_config["trained_with_amp"],
233
- amp_dtype=high_level_config["trained_with_amp_dtype"],
234
- )
235
 
236
  # Convert predictions to format expected by visualization
237
  predictions = {}
@@ -242,167 +175,40 @@ def run_model(target_dir, model_placeholder):
242
  world_points_list = []
243
  depth_maps_list = []
244
  images_list = []
245
- confidence_list = []
246
  final_mask_list = []
247
 
248
- # Check if confidence data is available
249
- has_confidence = False
250
- for view_idx, view in enumerate(views):
251
- view_key = f"pred{view_idx + 1}"
252
- if view_key in pred_result and "conf" in pred_result[view_key]:
253
- has_confidence = True
254
- break
255
-
256
- # Extract predictions for each view
257
- for view_idx, view in enumerate(views):
258
- # Get image for colors
259
- image = rgb(view["img"], norm_type=high_level_config["data_norm_type"])
260
 
261
- view_key = f"pred{view_idx + 1}"
262
- if view_key in pred_result:
263
- pred_pts3d = pred_result[view_key]["pts3d"][0].cpu().numpy()
264
-
265
- # Get confidence data if available
266
- confidence_map = None
267
- if "conf" in pred_result[view_key]:
268
- confidence_map = pred_result[view_key]["conf"][0].cpu().numpy()
269
-
270
- # Compute final_mask just like in visualize_raw_inference_output function
271
- # Create the prediction mask based on parameters
272
- pred_mask = None
273
- use_gt_mask_on_pred = False # Set based on your requirements
274
- use_pred_mask = True # Set based on your requirements
275
- use_non_ambi_mask = True # Set based on your requirements
276
- use_conf_mask = False # Set based on your requirements
277
- conf_percentile = 10 # Set based on your requirements
278
- use_edge_mask = True # Set based on your requirements
279
- pts_edge_tol = 5 # Set based on your requirements
280
- depth_edge_rtol = 0.03 # Set based on your requirements
281
-
282
- if use_pred_mask:
283
- # Get non ambiguous mask if available and requested
284
- has_non_ambiguous_mask = (
285
- "non_ambiguous_mask" in pred_result[view_key] and use_non_ambi_mask
286
- )
287
- if has_non_ambiguous_mask:
288
- non_ambiguous_mask = (
289
- pred_result[view_key]["non_ambiguous_mask"][0].cpu().numpy()
290
- )
291
- pred_mask = non_ambiguous_mask
292
-
293
- # Get confidence mask if available and requested
294
- has_conf = "conf" in pred_result[view_key] and use_conf_mask
295
- if has_conf:
296
- confidences = pred_result[view_key]["conf"][0].cpu()
297
- percentile_threshold = torch.quantile(
298
- confidences, conf_percentile / 100.0
299
- )
300
- conf_mask = confidences > percentile_threshold
301
- conf_mask = conf_mask.numpy()
302
- if pred_mask is not None:
303
- pred_mask = pred_mask & conf_mask
304
- else:
305
- pred_mask = conf_mask
306
-
307
- # Apply edge mask if requested
308
- if use_edge_mask and pred_mask is not None:
309
- if "cam_quats" not in pred_result[view_key]:
310
- # For direct point prediction
311
- # Compute normals and edge mask
312
- normals, normals_mask = points_to_normals(
313
- pred_pts3d, mask=pred_mask
314
- )
315
- edge_mask = ~(
316
- normals_edge(normals, tol=pts_edge_tol, mask=normals_mask)
317
- )
318
- else:
319
- # For ray-based prediction
320
- ray_depth = pred_result[view_key]["depth_along_ray"][0].cpu()
321
- local_pts3d = (
322
- pred_result[view_key]["ray_directions"][0].cpu() * ray_depth
323
- )
324
- depth_z = local_pts3d[..., 2].numpy()
325
 
326
- # Compute normals and edge mask
327
- normals, normals_mask = points_to_normals(
328
- pred_pts3d, mask=pred_mask
329
- )
330
- edge_mask = ~(
331
- depth_edge(depth_z, rtol=depth_edge_rtol, mask=pred_mask)
332
- & normals_edge(normals, tol=pts_edge_tol, mask=normals_mask)
333
- )
334
- if pred_mask is not None:
335
- pred_mask = pred_mask & edge_mask
336
-
337
- # Determine final mask to use (like in visualize_raw_inference_output)
338
- final_mask = None
339
- valid_mask = np.ones_like(
340
- pred_pts3d[..., 0], dtype=bool
341
- ) # Create dummy valid_mask for app.py context
342
-
343
- if use_gt_mask_on_pred:
344
- final_mask = valid_mask
345
- if use_pred_mask and pred_mask is not None:
346
- final_mask = final_mask & pred_mask
347
- elif use_pred_mask and pred_mask is not None:
348
- final_mask = pred_mask
349
- else:
350
- final_mask = np.ones_like(valid_mask, dtype=bool)
351
-
352
- # Check if we have camera pose and intrinsics data
353
- if "cam_quats" in pred_result[view_key]:
354
- # Get decoupled quantities (like in visualize_raw_custom_data_inference_output)
355
- cam_quats = pred_result[view_key]["cam_quats"][0].cpu()
356
- cam_trans = pred_result[view_key]["cam_trans"][0].cpu()
357
- ray_directions = pred_result[view_key]["ray_directions"][0].cpu()
358
- ray_depth = pred_result[view_key]["depth_along_ray"][0].cpu()
359
-
360
- # Convert the quantities
361
- from mapanything.utils.geometry import (
362
- quaternion_to_rotation_matrix,
363
- recover_pinhole_intrinsics_from_ray_directions,
364
- )
365
 
366
- cam_rot = quaternion_to_rotation_matrix(cam_quats)
367
- cam_pose = torch.eye(4)
368
- cam_pose[:3, :3] = cam_rot
369
- cam_pose[:3, 3] = cam_trans
370
- cam_pose = np.linalg.inv(cam_pose)
371
- cam_intrinsics = recover_pinhole_intrinsics_from_ray_directions(
372
- ray_directions, use_geometric_calculation=True
373
- )
374
 
375
- # Compute depth as in app_map.py
376
- local_pts3d = ray_directions * ray_depth
377
- depth_z = local_pts3d[..., 2]
378
 
379
- # Convert to numpy and extract 3x4 extrinsic (remove bottom row)
380
- extrinsic = cam_pose[:3, :4].numpy() # Shape: (3, 4)
381
- intrinsic = cam_intrinsics.numpy() # Shape: (3, 3)
382
- depth_z = depth_z.numpy() # Shape: (H, W)
383
- else:
384
- # Use dummy values if camera info not available
385
- # extrinsic: (3, 4) - [R|t] matrix
386
- extrinsic = np.eye(3, 4) # Identity rotation, zero translation
387
- # intrinsic: (3, 3) - camera intrinsic matrix
388
- intrinsic = np.eye(3)
389
- # depth_z: (H, W) - dummy depth values
390
- depth_z = np.zeros_like(pred_pts3d[..., 0])
391
-
392
- # Append to lists
393
- extrinsic_list.append(extrinsic)
394
- intrinsic_list.append(intrinsic)
395
- world_points_list.append(pred_pts3d)
396
- depth_maps_list.append(depth_z)
397
- images_list.append(image[0]) # Add image to list
398
- final_mask_list.append(final_mask) # Add final_mask to list
399
-
400
- # Add confidence data (or None if not available)
401
- if confidence_map is not None:
402
- confidence_list.append(confidence_map)
403
- elif has_confidence:
404
- # If some views have confidence but this one doesn't, add dummy confidence
405
- confidence_list.append(np.ones_like(depth_z))
406
 
407
  # Convert lists to numpy arrays with required shapes
408
  # extrinsic: (S, 3, 4) - batch of camera extrinsic matrices
@@ -419,26 +225,18 @@ def run_model(target_dir, model_placeholder):
419
  # Add channel dimension if needed to match (S, H, W, 1) format
420
  if len(depth_maps.shape) == 3:
421
  depth_maps = depth_maps[..., np.newaxis]
 
422
  predictions["depth"] = depth_maps
423
 
424
  # images: (S, H, W, 3) - batch of input images
425
  predictions["images"] = np.stack(images_list, axis=0)
426
 
427
- # confidence: (S, H, W) - batch of confidence maps (only if available)
428
- if confidence_list:
429
- predictions["confidence"] = np.stack(confidence_list, axis=0)
430
-
431
  # final_mask: (S, H, W) - batch of final masks for filtering
432
  predictions["final_mask"] = np.stack(final_mask_list, axis=0)
433
 
434
- world_points = unproject_depth_map_to_point_map(
435
- depth_maps, predictions["extrinsic"], predictions["intrinsic"]
436
- )
437
- predictions["world_points_from_depth"] = world_points
438
-
439
  # Process data for visualization tabs (depth, normal, measure)
440
  processed_data = process_predictions_for_visualization(
441
- pred_result, views, high_level_config
442
  )
443
 
444
  # Clean up
@@ -474,43 +272,69 @@ def get_view_data_by_index(processed_data, view_index):
474
  return processed_data[view_keys[view_index]]
475
 
476
 
477
- def update_depth_view(processed_data, view_index, conf_thres=None):
478
- """Update depth view for a specific view index with optional confidence filtering"""
479
  view_data = get_view_data_by_index(processed_data, view_index)
480
  if view_data is None or view_data["depth"] is None:
481
  return None
482
 
483
  # Use confidence filtering if available
484
  confidence = view_data.get("confidence")
485
- return colorize_depth(
486
- view_data["depth"], confidence=confidence, conf_thres=conf_thres
487
- )
488
 
489
 
490
- def update_normal_view(processed_data, view_index, conf_thres=None):
491
- """Update normal view for a specific view index with optional confidence filtering"""
492
  view_data = get_view_data_by_index(processed_data, view_index)
493
  if view_data is None or view_data["normal"] is None:
494
  return None
495
 
496
  # Use confidence filtering if available
497
  confidence = view_data.get("confidence")
498
- return colorize_normal(
499
- view_data["normal"], confidence=confidence, conf_thres=conf_thres
500
- )
501
 
502
 
503
  def update_measure_view(processed_data, view_index):
504
- """Update measure view for a specific view index"""
505
  view_data = get_view_data_by_index(processed_data, view_index)
506
  if view_data is None:
507
  return None, [] # image, measure_points
508
- return view_data["image"], []
509
 
 
 
510
 
511
- def navigate_depth_view(
512
- processed_data, current_selector_value, direction, conf_thres=None
513
- ):
 
 
 
 
514
  """Navigate depth view (direction: -1 for previous, +1 for next)"""
515
  if processed_data is None or len(processed_data) == 0:
516
  return "View 1", None
@@ -525,14 +349,12 @@ def navigate_depth_view(
525
  new_view = (current_view + direction) % num_views
526
 
527
  new_selector_value = f"View {new_view + 1}"
528
- depth_vis = update_depth_view(processed_data, new_view, conf_thres=conf_thres)
529
 
530
  return new_selector_value, depth_vis
531
 
532
 
533
- def navigate_normal_view(
534
- processed_data, current_selector_value, direction, conf_thres=None
535
- ):
536
  """Navigate normal view (direction: -1 for previous, +1 for next)"""
537
  if processed_data is None or len(processed_data) == 0:
538
  return "View 1", None
@@ -547,7 +369,7 @@ def navigate_normal_view(
547
  new_view = (current_view + direction) % num_views
548
 
549
  new_selector_value = f"View {new_view + 1}"
550
- normal_vis = update_normal_view(processed_data, new_view, conf_thres=conf_thres)
551
 
552
  return new_selector_value, normal_vis
553
 
@@ -572,14 +394,14 @@ def navigate_measure_view(processed_data, current_selector_value, direction):
572
  return new_selector_value, measure_image, measure_points
573
 
574
 
575
- def populate_visualization_tabs(processed_data, conf_thres=None):
576
  """Populate the depth, normal, and measure tabs with processed data"""
577
  if processed_data is None or len(processed_data) == 0:
578
  return None, None, None, []
579
 
580
  # Use update functions to ensure confidence filtering is applied from the start
581
- depth_vis = update_depth_view(processed_data, 0, conf_thres=conf_thres)
582
- normal_vis = update_normal_view(processed_data, 0, conf_thres=conf_thres)
583
  measure_img, _ = update_measure_view(processed_data, 0)
584
 
585
  return depth_vis, normal_vis, measure_img, []
@@ -683,13 +505,13 @@ def update_gallery_on_upload(input_video, input_images, s_time_interval=1.0):
683
  @spaces.GPU(duration=120)
684
  def gradio_demo(
685
  target_dir,
686
- conf_thres=3.0,
687
  frame_filter="All",
688
  show_cam=True,
689
  filter_sky=False,
690
  filter_black_bg=False,
691
  filter_white_bg=False,
692
- mask_ambiguous=False,
 
693
  ):
694
  """
695
  Perform reconstruction using the already-created target_dir/images.
@@ -716,7 +538,9 @@ def gradio_demo(
716
 
717
  print("Running MapAnything model...")
718
  with torch.no_grad():
719
- predictions, processed_data = run_model(target_dir, None)
 
 
720
 
721
  # Save predictions
722
  prediction_save_path = os.path.join(target_dir, "predictions.npz")
@@ -729,13 +553,12 @@ def gradio_demo(
729
  # Build a GLB file name
730
  glbfile = os.path.join(
731
  target_dir,
732
- f"glbscene_{conf_thres}_{frame_filter.replace('.', '_').replace(':', '').replace(' ', '_')}_cam{show_cam}_sky{filter_sky}_black{filter_black_bg}_white{filter_white_bg}_mask{mask_ambiguous}_pred{prediction_mode.replace(' ', '_')}.glb",
733
  )
734
 
735
  # Convert predictions to GLB
736
  glbscene = predictions_to_glb(
737
  predictions,
738
- conf_thres=conf_thres,
739
  filter_by_frames=frame_filter,
740
  show_cam=show_cam,
741
  target_dir=target_dir,
@@ -743,7 +566,6 @@ def gradio_demo(
743
  mask_sky=filter_sky,
744
  mask_black_bg=filter_black_bg,
745
  mask_white_bg=filter_white_bg,
746
- mask_ambiguous=mask_ambiguous,
747
  )
748
  glbscene.export(file_obj=glbfile)
749
 
@@ -760,7 +582,7 @@ def gradio_demo(
760
 
761
  # Populate visualization tabs with processed data
762
  depth_vis, normal_vis, measure_img, measure_pts = populate_visualization_tabs(
763
- processed_data, conf_thres=conf_thres
764
  )
765
 
766
  # Update view selectors based on available views
@@ -860,29 +682,30 @@ def colorize_normal(normal_map, confidence=None, conf_thres=None):
860
  return normal_vis
861
 
862
 
863
- def process_predictions_for_visualization(pred_result, views, high_level_config):
864
  """Extract depth, normal, and 3D points from predictions for visualization"""
865
  processed_data = {}
866
 
867
  # Check if confidence data is available in any view
868
  has_confidence_data = False
869
- for view_idx, view in enumerate(views):
870
- view_key = f"pred{view_idx + 1}"
871
- if view_key in pred_result and "conf" in pred_result[view_key]:
872
- has_confidence_data = True
873
- break
874
 
875
  # Process each view
876
  for view_idx, view in enumerate(views):
877
- view_key = f"pred{view_idx + 1}"
878
- if view_key not in pred_result:
879
- continue
880
 
881
  # Get image
882
  image = rgb(view["img"], norm_type=high_level_config["data_norm_type"])
 
883
 
884
  # Get predicted points
885
- pred_pts3d = pred_result[view_key]["pts3d"][0].cpu().numpy()
886
 
887
  # Initialize data for this view
888
  view_data = {
@@ -895,36 +718,12 @@ def process_predictions_for_visualization(pred_result, views, high_level_config)
895
  "has_confidence": has_confidence_data,
896
  }
897
 
898
- # Get confidence data if available
899
- if "conf" in pred_result[view_key]:
900
- confidence = pred_result[view_key]["conf"][0].cpu().numpy()
901
- view_data["confidence"] = confidence
902
 
903
- # Get masks if available
904
- has_non_ambiguous_mask = "non_ambiguous_mask" in pred_result[view_key]
905
- if has_non_ambiguous_mask:
906
- view_data["mask"] = (
907
- pred_result[view_key]["non_ambiguous_mask"][0].cpu().numpy()
908
- )
909
 
910
- # Extract depth and camera info if available
911
- if "cam_quats" in pred_result[view_key]:
912
- ray_directions = pred_result[view_key]["ray_directions"][0].cpu()
913
- ray_depth = pred_result[view_key]["depth_along_ray"][0].cpu()
914
-
915
- # Compute depth
916
- local_pts3d = ray_directions * ray_depth
917
- depth_z = local_pts3d[..., 2].numpy()
918
- view_data["depth"] = depth_z
919
-
920
- # Compute normals if we have valid points
921
- if has_non_ambiguous_mask:
922
- try:
923
- normals, _ = points_to_normals(pred_pts3d, mask=view_data["mask"])
924
- view_data["normal"] = normals
925
- except:
926
- # If normal computation fails, skip it
927
- pass
928
 
929
  processed_data[view_idx] = view_data
930
 
@@ -972,10 +771,29 @@ def measure(
972
  point2d = event.index[0], event.index[1]
973
  print(f"Clicked point: {point2d}")
974
 
 
 
 
 
975
  measure_points.append(point2d)
976
 
977
- # Get image and ensure it's valid
978
- image = current_view["image"]
979
  if image is None:
980
  return None, [], "No image available"
981
 
@@ -1093,14 +911,12 @@ def update_log():
1093
 
1094
  def update_visualization(
1095
  target_dir,
1096
- conf_thres,
1097
  frame_filter,
1098
  show_cam,
1099
  is_example,
1100
  filter_sky=False,
1101
  filter_black_bg=False,
1102
  filter_white_bg=False,
1103
- mask_ambiguous=False,
1104
  ):
1105
  """
1106
  Reload saved predictions from npz, create (or reuse) the GLB for new parameters,
@@ -1135,13 +951,12 @@ def update_visualization(
1135
 
1136
  glbfile = os.path.join(
1137
  target_dir,
1138
- f"glbscene_{conf_thres}_{frame_filter.replace('.', '_').replace(':', '').replace(' ', '_')}_cam{show_cam}_sky{filter_sky}_black{filter_black_bg}_white{filter_white_bg}_pred{prediction_mode.replace(' ', '_')}.glb",
1139
  )
1140
 
1141
  if not os.path.exists(glbfile):
1142
  glbscene = predictions_to_glb(
1143
  predictions,
1144
- conf_thres=conf_thres,
1145
  filter_by_frames=frame_filter,
1146
  show_cam=show_cam,
1147
  target_dir=target_dir,
@@ -1149,7 +964,6 @@ def update_visualization(
1149
  mask_sky=filter_sky,
1150
  mask_black_bg=filter_black_bg,
1151
  mask_white_bg=filter_white_bg,
1152
- mask_ambiguous=mask_ambiguous,
1153
  )
1154
  glbscene.export(file_obj=glbfile)
1155
 
@@ -1346,6 +1160,9 @@ with gr.Blocks(theme=theme, css=GRADIO_CSS) as demo:
1346
  interactive=False,
1347
  sources=[],
1348
  )
 
 
 
1349
  measure_text = gr.Markdown("")
1350
 
1351
  with gr.Row():
@@ -1363,17 +1180,11 @@ with gr.Blocks(theme=theme, css=GRADIO_CSS) as demo:
1363
  )
1364
 
1365
  with gr.Row():
1366
- conf_thres = gr.Slider(
1367
- minimum=0,
1368
- maximum=100,
1369
- value=0,
1370
- step=0.1,
1371
- label="Confidence Threshold (%), only shown in depth and normals",
1372
- )
1373
  frame_filter = gr.Dropdown(
1374
  choices=["All"], value="All", label="Show Points from Frame"
1375
  )
1376
  with gr.Column():
 
1377
  show_cam = gr.Checkbox(label="Show Camera", value=True)
1378
  filter_sky = gr.Checkbox(
1379
  label="Filter Sky (using skyseg.onnx)", value=False
@@ -1384,8 +1195,11 @@ with gr.Blocks(theme=theme, css=GRADIO_CSS) as demo:
1384
  filter_white_bg = gr.Checkbox(
1385
  label="Filter White Background", value=False
1386
  )
1387
- mask_ambiguous = gr.Checkbox(label="Mask Ambiguous", value=True)
1388
-
 
 
 
1389
  # ---------------------- Example Scenes Section ----------------------
1390
  gr.Markdown("## Example Scenes")
1391
  gr.Markdown("Click any thumbnail to load the scene for reconstruction.")
@@ -1446,13 +1260,13 @@ with gr.Blocks(theme=theme, css=GRADIO_CSS) as demo:
1446
  fn=gradio_demo,
1447
  inputs=[
1448
  target_dir_output,
1449
- conf_thres,
1450
  frame_filter,
1451
  show_cam,
1452
  filter_sky,
1453
  filter_black_bg,
1454
  filter_white_bg,
1455
- mask_ambiguous,
 
1456
  ],
1457
  outputs=[
1458
  reconstruction_output,
@@ -1476,76 +1290,10 @@ with gr.Blocks(theme=theme, css=GRADIO_CSS) as demo:
1476
  # -------------------------------------------------------------------------
1477
  # Real-time Visualization Updates
1478
  # -------------------------------------------------------------------------
1479
- def update_all_visualizations_on_conf_change(
1480
- processed_data,
1481
- depth_selector,
1482
- normal_selector,
1483
- conf_thres_val,
1484
- target_dir,
1485
- frame_filter,
1486
- show_cam,
1487
- is_example,
1488
- ):
1489
- """Update 3D view and all tabs when confidence threshold changes"""
1490
-
1491
- # Update 3D pointcloud visualization
1492
- glb_file, log_msg = update_visualization(
1493
- target_dir,
1494
- conf_thres_val,
1495
- frame_filter,
1496
- show_cam,
1497
- is_example,
1498
- )
1499
-
1500
- # Update depth and normal tabs with new confidence threshold
1501
- depth_vis = None
1502
- normal_vis = None
1503
-
1504
- if processed_data is not None:
1505
- # Get current view indices from selectors
1506
- try:
1507
- depth_view_idx = (
1508
- int(depth_selector.split()[1]) - 1 if depth_selector else 0
1509
- )
1510
- except:
1511
- depth_view_idx = 0
1512
-
1513
- try:
1514
- normal_view_idx = (
1515
- int(normal_selector.split()[1]) - 1 if normal_selector else 0
1516
- )
1517
- except:
1518
- normal_view_idx = 0
1519
-
1520
- # Update visualizations with new confidence threshold
1521
- depth_vis = update_depth_view(
1522
- processed_data, depth_view_idx, conf_thres=conf_thres_val
1523
- )
1524
- normal_vis = update_normal_view(
1525
- processed_data, normal_view_idx, conf_thres=conf_thres_val
1526
- )
1527
-
1528
- return glb_file, log_msg, depth_vis, normal_vis
1529
-
1530
- conf_thres.change(
1531
- fn=update_all_visualizations_on_conf_change,
1532
- inputs=[
1533
- processed_data_state,
1534
- depth_view_selector,
1535
- normal_view_selector,
1536
- conf_thres,
1537
- target_dir_output,
1538
- frame_filter,
1539
- show_cam,
1540
- is_example,
1541
- ],
1542
- outputs=[reconstruction_output, log_output, depth_map, normal_map],
1543
- )
1544
  frame_filter.change(
1545
  update_visualization,
1546
  [
1547
  target_dir_output,
1548
- conf_thres,
1549
  frame_filter,
1550
  show_cam,
1551
  is_example,
@@ -1556,7 +1304,6 @@ with gr.Blocks(theme=theme, css=GRADIO_CSS) as demo:
1556
  update_visualization,
1557
  [
1558
  target_dir_output,
1559
- conf_thres,
1560
  frame_filter,
1561
  show_cam,
1562
  is_example,
@@ -1567,14 +1314,12 @@ with gr.Blocks(theme=theme, css=GRADIO_CSS) as demo:
1567
  update_visualization,
1568
  [
1569
  target_dir_output,
1570
- conf_thres,
1571
  frame_filter,
1572
  show_cam,
1573
  is_example,
1574
  filter_sky,
1575
  filter_black_bg,
1576
  filter_white_bg,
1577
- mask_ambiguous,
1578
  ],
1579
  [reconstruction_output, log_output],
1580
  )
@@ -1582,14 +1327,12 @@ with gr.Blocks(theme=theme, css=GRADIO_CSS) as demo:
1582
  update_visualization,
1583
  [
1584
  target_dir_output,
1585
- conf_thres,
1586
  frame_filter,
1587
  show_cam,
1588
  is_example,
1589
  filter_sky,
1590
  filter_black_bg,
1591
  filter_white_bg,
1592
- mask_ambiguous,
1593
  ],
1594
  [reconstruction_output, log_output],
1595
  )
@@ -1597,29 +1340,12 @@ with gr.Blocks(theme=theme, css=GRADIO_CSS) as demo:
1597
  update_visualization,
1598
  [
1599
  target_dir_output,
1600
- conf_thres,
1601
  frame_filter,
1602
  show_cam,
1603
  is_example,
1604
  filter_sky,
1605
  filter_black_bg,
1606
  filter_white_bg,
1607
- mask_ambiguous,
1608
- ],
1609
- [reconstruction_output, log_output],
1610
- )
1611
- mask_ambiguous.change(
1612
- update_visualization,
1613
- [
1614
- target_dir_output,
1615
- conf_thres,
1616
- frame_filter,
1617
- show_cam,
1618
- is_example,
1619
- filter_sky,
1620
- filter_black_bg,
1621
- filter_white_bg,
1622
- mask_ambiguous,
1623
  ],
1624
  [reconstruction_output, log_output],
1625
  )
@@ -1653,67 +1379,61 @@ with gr.Blocks(theme=theme, css=GRADIO_CSS) as demo:
1653
 
1654
  # Depth tab navigation
1655
  prev_depth_btn.click(
1656
- fn=lambda processed_data, current_selector, conf_thres_val: navigate_depth_view(
1657
- processed_data, current_selector, -1, conf_thres=conf_thres_val
1658
  ),
1659
- inputs=[processed_data_state, depth_view_selector, conf_thres],
1660
  outputs=[depth_view_selector, depth_map],
1661
  )
1662
 
1663
  next_depth_btn.click(
1664
- fn=lambda processed_data, current_selector, conf_thres_val: navigate_depth_view(
1665
- processed_data, current_selector, 1, conf_thres=conf_thres_val
1666
  ),
1667
- inputs=[processed_data_state, depth_view_selector, conf_thres],
1668
  outputs=[depth_view_selector, depth_map],
1669
  )
1670
 
1671
  depth_view_selector.change(
1672
- fn=lambda processed_data, selector_value, conf_thres_val: (
1673
  update_depth_view(
1674
  processed_data,
1675
  int(selector_value.split()[1]) - 1,
1676
- conf_thres=conf_thres_val,
1677
  )
1678
  if selector_value
1679
  else None
1680
  ),
1681
- inputs=[processed_data_state, depth_view_selector, conf_thres],
1682
  outputs=[depth_map],
1683
  )
1684
 
1685
  # Normal tab navigation
1686
  prev_normal_btn.click(
1687
- fn=lambda processed_data,
1688
- current_selector,
1689
- conf_thres_val: navigate_normal_view(
1690
- processed_data, current_selector, -1, conf_thres=conf_thres_val
1691
  ),
1692
- inputs=[processed_data_state, normal_view_selector, conf_thres],
1693
  outputs=[normal_view_selector, normal_map],
1694
  )
1695
 
1696
  next_normal_btn.click(
1697
- fn=lambda processed_data,
1698
- current_selector,
1699
- conf_thres_val: navigate_normal_view(
1700
- processed_data, current_selector, 1, conf_thres=conf_thres_val
1701
  ),
1702
- inputs=[processed_data_state, normal_view_selector, conf_thres],
1703
  outputs=[normal_view_selector, normal_map],
1704
  )
1705
 
1706
  normal_view_selector.change(
1707
- fn=lambda processed_data, selector_value, conf_thres_val: (
1708
  update_normal_view(
1709
  processed_data,
1710
  int(selector_value.split()[1]) - 1,
1711
- conf_thres=conf_thres_val,
1712
  )
1713
  if selector_value
1714
  else None
1715
  ),
1716
- inputs=[processed_data_state, normal_view_selector, conf_thres],
1717
  outputs=[normal_map],
1718
  )
1719
 
 
18
  import numpy as np
19
  import spaces
20
  import torch
 
21
 
22
  sys.path.append("mapanything/")
23
 
24
  from hf_utils.css_and_html import (
25
+ GRADIO_CSS,
26
+ MEASURE_INSTRUCTIONS_HTML,
27
  get_acknowledgements_html,
28
  get_description_html,
29
  get_gradio_theme,
30
  get_header_html,
 
 
31
  )
 
32
  from hf_utils.visual_util import predictions_to_glb
33
+ from mapanything.models import MapAnything
34
+ from mapanything.utils.geometry import depthmap_to_world_frame, points_to_normals
35
  from mapanything.utils.image import load_images, rgb
 
36
 
37
 
38
  def get_logo_base64():
 
100
  return cfg
101
 
102
 
 
 
 
 
103
  # MapAnything Configuration
104
  high_level_config = {
105
  "path": "configs/train.yaml",
106
+ "hf_model_name": "facebook/map-anything",
107
  "config_overrides": [
108
  "machine=aws",
109
  "model=mapanything",
110
  "model/task=images_only",
111
  "model.encoder.uses_torch_hub=false",
112
  ],
 
113
  "trained_with_amp": True,
114
  "trained_with_amp_dtype": "fp16",
115
  "data_norm_type": "dinov2",
 
125
  # 1) Core model inference
126
  # -------------------------------------------------------------------------
127
  @spaces.GPU(duration=120)
128
+ def run_model(target_dir, model_placeholder, apply_mask=True, mask_edges=True):
129
  """
130
  Run the MapAnything model on images in the 'target_dir/images' folder and return predictions.
131
  """
 
135
  # Device check
136
  device = "cuda" if torch.cuda.is_available() else "cpu"
137
  device = torch.device(device)
 
 
138
 
139
  # Initialize model if not already done
140
  if model is None:
141
  print("Initializing MapAnything model...")
142
+
143
+ print("Loading CC-BY-NC 4.0 licensed MapAnything model...")
144
+ model = MapAnything.from_pretrained(high_level_config["hf_model_name"]).to(
145
+ device
146
  )
147
+
148
  else:
149
  model = model.to(device)
150
 
 
153
  # Load images using MapAnything's load_images function
154
  print("Loading images...")
155
  image_folder_path = os.path.join(target_dir, "images")
156
+ views = load_images(image_folder_path)
 
 
 
 
 
 
 
157
 
158
  print(f"Loaded {len(views)} images")
159
  if len(views) == 0:
160
  raise ValueError("No images found. Check your upload.")
161
 
162
+ # Run model inference
163
+ print("Running inference...")
164
+ # apply_mask: Whether to apply the non-ambiguous mask to the output. Defaults to True.
165
+ # mask_edges: Whether to compute an edge mask based on normals and depth and apply it to the output. Defaults to True.
166
+ # Use checkbox values
167
+ outputs = model.infer(views, apply_mask=apply_mask, mask_edges=mask_edges)
 
 
 
 
 
168
 
169
  # Convert predictions to format expected by visualization
170
  predictions = {}
 
175
  world_points_list = []
176
  depth_maps_list = []
177
  images_list = []
 
178
  final_mask_list = []
179
 
180
+ # Loop through the outputs
181
+ for pred in outputs:
182
+ # Extract data from predictions
183
+ depthmap_torch = pred["depth_z"][0].squeeze(-1) # (H, W)
184
+ intrinsics_torch = pred["intrinsics"][0] # (3, 3)
185
+ camera_pose_torch = pred["camera_poses"][0] # (4, 4)
 
 
 
 
 
 
186
 
187
+ # Compute new pts3d using depth, intrinsics, and camera pose
188
+ pts3d_computed, valid_mask = depthmap_to_world_frame(
189
+ depthmap_torch, intrinsics_torch, camera_pose_torch
190
+ )
 
 
 
 
191
 
192
+ # Convert to numpy arrays for visualization
193
+ # Check if mask key exists in pred, if not, fill with boolean trues in the size of depthmap_torch
194
+ if "mask" in pred:
195
+ mask = pred["mask"][0].squeeze(-1).cpu().numpy().astype(bool)
196
+ else:
197
+ # Fill with boolean trues in the size of depthmap_torch
198
+ mask = np.ones_like(depthmap_torch.cpu().numpy(), dtype=bool)
 
 
 
 
199
 
200
+ # Combine with valid depth mask
201
+ mask = mask & valid_mask.cpu().numpy()
 
 
 
 
 
 
202
 
203
+ image = pred["img_no_norm"][0].cpu().numpy()
 
 
204
 
205
+ # Append to lists
206
+ extrinsic_list.append(camera_pose_torch.cpu().numpy())
207
+ intrinsic_list.append(intrinsics_torch.cpu().numpy())
208
+ world_points_list.append(pts3d_computed.cpu().numpy())
209
+ depth_maps_list.append(depthmap_torch.cpu().numpy())
210
+ images_list.append(image) # Add image to list
211
+ final_mask_list.append(mask) # Add final_mask to list
 
 
 
 
212
 
213
  # Convert lists to numpy arrays with required shapes
214
  # extrinsic: (S, 3, 4) - batch of camera extrinsic matrices
 
225
  # Add channel dimension if needed to match (S, H, W, 1) format
226
  if len(depth_maps.shape) == 3:
227
  depth_maps = depth_maps[..., np.newaxis]
228
+
229
  predictions["depth"] = depth_maps
230
 
231
  # images: (S, H, W, 3) - batch of input images
232
  predictions["images"] = np.stack(images_list, axis=0)
233
 
 
 
 
 
234
  # final_mask: (S, H, W) - batch of final masks for filtering
235
  predictions["final_mask"] = np.stack(final_mask_list, axis=0)
236
 
 
 
 
 
 
237
  # Process data for visualization tabs (depth, normal, measure)
238
  processed_data = process_predictions_for_visualization(
239
+ predictions, views, high_level_config
240
  )
241
 
242
  # Clean up
 
272
  return processed_data[view_keys[view_index]]
273
 
274
 
275
+ def update_depth_view(processed_data, view_index):
276
+ """Update depth view for a specific view index"""
277
  view_data = get_view_data_by_index(processed_data, view_index)
278
  if view_data is None or view_data["depth"] is None:
279
  return None
280
 
281
  # Use confidence filtering if available
282
  confidence = view_data.get("confidence")
283
+ return colorize_depth(view_data["depth"], confidence=confidence)
 
 
284
 
285
 
286
+ def update_normal_view(processed_data, view_index):
287
+ """Update normal view for a specific view index"""
288
  view_data = get_view_data_by_index(processed_data, view_index)
289
  if view_data is None or view_data["normal"] is None:
290
  return None
291
 
292
  # Use confidence filtering if available
293
  confidence = view_data.get("confidence")
294
+ return colorize_normal(view_data["normal"], confidence=confidence)
 
 
295
 
296
 
297
  def update_measure_view(processed_data, view_index):
298
+ """Update measure view for a specific view index with mask overlay"""
299
  view_data = get_view_data_by_index(processed_data, view_index)
300
  if view_data is None:
301
  return None, [] # image, measure_points
 
302
 
303
+ # Get the base image
304
+ image = view_data["image"].copy()
305
 
306
+ # Ensure image is in uint8 format
307
+ if image.dtype != np.uint8:
308
+ if image.max() <= 1.0:
309
+ image = (image * 255).astype(np.uint8)
310
+ else:
311
+ image = image.astype(np.uint8)
312
+
313
+ # Apply mask overlay if mask is available
314
+ if view_data["mask"] is not None:
315
+ mask = view_data["mask"]
316
+
317
+ # Create light grey overlay for masked areas
318
+ # Masked areas (False values) will be overlaid with light grey
319
+ invalid_mask = ~mask # Areas where mask is False
320
+
321
+ if invalid_mask.any():
322
+ # Create a light grey overlay (RGB: 192, 192, 192)
323
+ overlay_color = np.array([192, 192, 192], dtype=np.uint8)
324
+
325
+ # Apply overlay with some transparency
326
+ alpha = 0.5 # Transparency level
327
+ for c in range(3): # RGB channels
328
+ image[:, :, c] = np.where(
329
+ invalid_mask,
330
+ (1 - alpha) * image[:, :, c] + alpha * overlay_color[c],
331
+ image[:, :, c],
332
+ ).astype(np.uint8)
333
+
334
+ return image, []
335
+
336
+
337
+ def navigate_depth_view(processed_data, current_selector_value, direction):
338
  """Navigate depth view (direction: -1 for previous, +1 for next)"""
339
  if processed_data is None or len(processed_data) == 0:
340
  return "View 1", None
 
349
  new_view = (current_view + direction) % num_views
350
 
351
  new_selector_value = f"View {new_view + 1}"
352
+ depth_vis = update_depth_view(processed_data, new_view)
353
 
354
  return new_selector_value, depth_vis
355
 
356
 
357
+ def navigate_normal_view(processed_data, current_selector_value, direction):
 
 
358
  """Navigate normal view (direction: -1 for previous, +1 for next)"""
359
  if processed_data is None or len(processed_data) == 0:
360
  return "View 1", None
 
369
  new_view = (current_view + direction) % num_views
370
 
371
  new_selector_value = f"View {new_view + 1}"
372
+ normal_vis = update_normal_view(processed_data, new_view)
373
 
374
  return new_selector_value, normal_vis
375
 
 
394
  return new_selector_value, measure_image, measure_points
395
 
396
 
397
+ def populate_visualization_tabs(processed_data):
398
  """Populate the depth, normal, and measure tabs with processed data"""
399
  if processed_data is None or len(processed_data) == 0:
400
  return None, None, None, []
401
 
402
  # Use update functions to ensure confidence filtering is applied from the start
403
+ depth_vis = update_depth_view(processed_data, 0)
404
+ normal_vis = update_normal_view(processed_data, 0)
405
  measure_img, _ = update_measure_view(processed_data, 0)
406
 
407
  return depth_vis, normal_vis, measure_img, []
 
505
  @spaces.GPU(duration=120)
506
  def gradio_demo(
507
  target_dir,
 
508
  frame_filter="All",
509
  show_cam=True,
510
  filter_sky=False,
511
  filter_black_bg=False,
512
  filter_white_bg=False,
513
+ apply_mask=True,
514
+ mask_edges=True,
515
  ):
516
  """
517
  Perform reconstruction using the already-created target_dir/images.
 
538
 
539
  print("Running MapAnything model...")
540
  with torch.no_grad():
541
+ predictions, processed_data = run_model(
542
+ target_dir, None, apply_mask, mask_edges
543
+ )
544
 
545
  # Save predictions
546
  prediction_save_path = os.path.join(target_dir, "predictions.npz")
 
553
  # Build a GLB file name
554
  glbfile = os.path.join(
555
  target_dir,
556
+ f"glbscene_{frame_filter.replace('.', '_').replace(':', '').replace(' ', '_')}_cam{show_cam}_sky{filter_sky}_black{filter_black_bg}_white{filter_white_bg}_pred{prediction_mode.replace(' ', '_')}.glb",
557
  )
558
 
559
  # Convert predictions to GLB
560
  glbscene = predictions_to_glb(
561
  predictions,
 
562
  filter_by_frames=frame_filter,
563
  show_cam=show_cam,
564
  target_dir=target_dir,
 
566
  mask_sky=filter_sky,
567
  mask_black_bg=filter_black_bg,
568
  mask_white_bg=filter_white_bg,
 
569
  )
570
  glbscene.export(file_obj=glbfile)
571
 
 
582
 
583
  # Populate visualization tabs with processed data
584
  depth_vis, normal_vis, measure_img, measure_pts = populate_visualization_tabs(
585
+ processed_data
586
  )
587
 
588
  # Update view selectors based on available views
 
682
  return normal_vis
683
 
684
 
685
+ def process_predictions_for_visualization(predictions, views, high_level_config):
686
  """Extract depth, normal, and 3D points from predictions for visualization"""
687
  processed_data = {}
688
 
689
  # Check if confidence data is available in any view
690
  has_confidence_data = False
691
+ # for view_idx, view in enumerate(views):
692
+ # view_key = f"pred{view_idx + 1}"
693
+ # if view_key in pred_result and "conf" in pred_result[view_key]:
694
+ # has_confidence_data = True
695
+ # break
696
 
697
  # Process each view
698
  for view_idx, view in enumerate(views):
699
+ # view_key = f"pred{view_idx + 1}"
700
+ # if view_key not in pred_result:
701
+ # continue
702
 
703
  # Get image
704
  image = rgb(view["img"], norm_type=high_level_config["data_norm_type"])
705
+ # image = rgb(view["img"], norm_type=high_level_config["data_norm_type"])
706
 
707
  # Get predicted points
708
+ pred_pts3d = predictions["world_points"][view_idx]
709
 
710
  # Initialize data for this view
711
  view_data = {
 
718
  "has_confidence": has_confidence_data,
719
  }
720
 
721
+ view_data["mask"] = predictions["final_mask"][view_idx]
 
 
 
722
 
723
+ view_data["depth"] = predictions["depth"][view_idx].squeeze()
 
 
 
 
 
724
 
725
+ normals, _ = points_to_normals(pred_pts3d, mask=view_data["mask"])
726
+ view_data["normal"] = normals
 
 
 
 
727
 
728
  processed_data[view_idx] = view_data
729
 
 
771
  point2d = event.index[0], event.index[1]
772
  print(f"Clicked point: {point2d}")
773
 
774
+ # Check if the clicked point is in a masked area (prevent interaction)
775
+ if (
776
+ current_view["mask"] is not None
777
+ and 0 <= point2d[1] < current_view["mask"].shape[0]
778
+ and 0 <= point2d[0] < current_view["mask"].shape[1]
779
+ ):
780
+ # Check if the point is in a masked (invalid) area
781
+ if not current_view["mask"][point2d[1], point2d[0]]:
782
+ print(f"Clicked point {point2d} is in masked area, ignoring click")
783
+ # Always return image with mask overlay
784
+ masked_image, _ = update_measure_view(
785
+ processed_data, current_view_index
786
+ )
787
+ return (
788
+ masked_image,
789
+ measure_points,
790
+ '<span style="color: red; font-weight: bold;">Cannot measure on masked areas (shown in grey)</span>',
791
+ )
792
+
793
  measure_points.append(point2d)
794
 
795
+ # Get image with mask overlay and ensure it's valid
796
+ image, _ = update_measure_view(processed_data, current_view_index)
797
  if image is None:
798
  return None, [], "No image available"
799
 
 
911
 
912
  def update_visualization(
913
  target_dir,
 
914
  frame_filter,
915
  show_cam,
916
  is_example,
917
  filter_sky=False,
918
  filter_black_bg=False,
919
  filter_white_bg=False,
 
920
  ):
921
  """
922
  Reload saved predictions from npz, create (or reuse) the GLB for new parameters,
 
951
 
952
  glbfile = os.path.join(
953
  target_dir,
954
+ f"glbscene_{frame_filter.replace('.', '_').replace(':', '').replace(' ', '_')}_cam{show_cam}_sky{filter_sky}_black{filter_black_bg}_white{filter_white_bg}_pred{prediction_mode.replace(' ', '_')}.glb",
955
  )
956
 
957
  if not os.path.exists(glbfile):
958
  glbscene = predictions_to_glb(
959
  predictions,
 
960
  filter_by_frames=frame_filter,
961
  show_cam=show_cam,
962
  target_dir=target_dir,
 
964
  mask_sky=filter_sky,
965
  mask_black_bg=filter_black_bg,
966
  mask_white_bg=filter_white_bg,
 
967
  )
968
  glbscene.export(file_obj=glbfile)
969
 
 
1160
  interactive=False,
1161
  sources=[],
1162
  )
1163
+ gr.Markdown(
1164
+ "**Note:** Gray areas indicate regions with no depth information where measurements cannot be taken."
1165
+ )
1166
  measure_text = gr.Markdown("")
1167
 
1168
  with gr.Row():
 
1180
  )
1181
 
1182
  with gr.Row():
 
 
 
 
 
 
 
1183
  frame_filter = gr.Dropdown(
1184
  choices=["All"], value="All", label="Show Points from Frame"
1185
  )
1186
  with gr.Column():
1187
+ gr.Markdown("### Pointcloud options (live updates)")
1188
  show_cam = gr.Checkbox(label="Show Camera", value=True)
1189
  filter_sky = gr.Checkbox(
1190
  label="Filter Sky (using skyseg.onnx)", value=False
 
1195
  filter_white_bg = gr.Checkbox(
1196
  label="Filter White Background", value=False
1197
  )
1198
+ gr.Markdown("### Reconstruction options: (updated on next run)")
1199
+ apply_mask_checkbox = gr.Checkbox(
1200
+ label="Apply non-ambiguous mask", value=True
1201
+ )
1202
+ mask_edges_checkbox = apply_mask_checkbox
1203
  # ---------------------- Example Scenes Section ----------------------
1204
  gr.Markdown("## Example Scenes")
1205
  gr.Markdown("Click any thumbnail to load the scene for reconstruction.")
 
1260
  fn=gradio_demo,
1261
  inputs=[
1262
  target_dir_output,
 
1263
  frame_filter,
1264
  show_cam,
1265
  filter_sky,
1266
  filter_black_bg,
1267
  filter_white_bg,
1268
+ apply_mask_checkbox,
1269
+ mask_edges_checkbox,
1270
  ],
1271
  outputs=[
1272
  reconstruction_output,
 
1290
  # -------------------------------------------------------------------------
1291
  # Real-time Visualization Updates
1292
  # -------------------------------------------------------------------------
 
 
 
 
1293
  frame_filter.change(
1294
  update_visualization,
1295
  [
1296
  target_dir_output,
 
1297
  frame_filter,
1298
  show_cam,
1299
  is_example,
 
1304
  update_visualization,
1305
  [
1306
  target_dir_output,
 
1307
  frame_filter,
1308
  show_cam,
1309
  is_example,
 
1314
  update_visualization,
1315
  [
1316
  target_dir_output,
 
1317
  frame_filter,
1318
  show_cam,
1319
  is_example,
1320
  filter_sky,
1321
  filter_black_bg,
1322
  filter_white_bg,
 
1323
  ],
1324
  [reconstruction_output, log_output],
1325
  )
 
1327
  update_visualization,
1328
  [
1329
  target_dir_output,
 
1330
  frame_filter,
1331
  show_cam,
1332
  is_example,
1333
  filter_sky,
1334
  filter_black_bg,
1335
  filter_white_bg,
 
1336
  ],
1337
  [reconstruction_output, log_output],
1338
  )
 
1340
  update_visualization,
1341
  [
1342
  target_dir_output,
 
1343
  frame_filter,
1344
  show_cam,
1345
  is_example,
1346
  filter_sky,
1347
  filter_black_bg,
1348
  filter_white_bg,
 
 
 
 
1349
  ],
1350
  [reconstruction_output, log_output],
1351
  )
 
1379
 
1380
  # Depth tab navigation
1381
  prev_depth_btn.click(
1382
+ fn=lambda processed_data, current_selector: navigate_depth_view(
1383
+ processed_data, current_selector, -1
1384
  ),
1385
+ inputs=[processed_data_state, depth_view_selector],
1386
  outputs=[depth_view_selector, depth_map],
1387
  )
1388
 
1389
  next_depth_btn.click(
1390
+ fn=lambda processed_data, current_selector: navigate_depth_view(
1391
+ processed_data, current_selector, 1
1392
  ),
1393
+ inputs=[processed_data_state, depth_view_selector],
1394
  outputs=[depth_view_selector, depth_map],
1395
  )
1396
 
1397
  depth_view_selector.change(
1398
+ fn=lambda processed_data, selector_value: (
1399
  update_depth_view(
1400
  processed_data,
1401
  int(selector_value.split()[1]) - 1,
 
1402
  )
1403
  if selector_value
1404
  else None
1405
  ),
1406
+ inputs=[processed_data_state, depth_view_selector],
1407
  outputs=[depth_map],
1408
  )
1409
 
1410
  # Normal tab navigation
1411
  prev_normal_btn.click(
1412
+ fn=lambda processed_data, current_selector: navigate_normal_view(
1413
+ processed_data, current_selector, -1
 
 
1414
  ),
1415
+ inputs=[processed_data_state, normal_view_selector],
1416
  outputs=[normal_view_selector, normal_map],
1417
  )
1418
 
1419
  next_normal_btn.click(
1420
+ fn=lambda processed_data, current_selector: navigate_normal_view(
1421
+ processed_data, current_selector, 1
 
 
1422
  ),
1423
+ inputs=[processed_data_state, normal_view_selector],
1424
  outputs=[normal_view_selector, normal_map],
1425
  )
1426
 
1427
  normal_view_selector.change(
1428
+ fn=lambda processed_data, selector_value: (
1429
  update_normal_view(
1430
  processed_data,
1431
  int(selector_value.split()[1]) - 1,
 
1432
  )
1433
  if selector_value
1434
  else None
1435
  ),
1436
+ inputs=[processed_data_state, normal_view_selector],
1437
  outputs=[normal_map],
1438
  )
1439
 
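For readers skimming the app.py diff above, the new inference path can be condensed into a short sketch. It only restates calls that appear in the diff (MapAnything.from_pretrained, load_images, model.infer, depthmap_to_world_frame); the folder path is a placeholder, and the output keys and shapes are taken from the comments in the new run_model, so treat anything beyond that as an assumption rather than documented API.

import torch

from mapanything.models import MapAnything
from mapanything.utils.geometry import depthmap_to_world_frame
from mapanything.utils.image import load_images

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the CC-BY-NC 4.0 licensed weights from the Hub and the input views
model = MapAnything.from_pretrained("facebook/map-anything").to(device)
views = load_images("path/to/images")  # placeholder folder of input frames

# apply_mask / mask_edges mirror the new checkboxes wired into gradio_demo
outputs = model.infer(views, apply_mask=True, mask_edges=True)

for pred in outputs:
    depth_z = pred["depth_z"][0].squeeze(-1)   # (H, W)
    intrinsics = pred["intrinsics"][0]         # (3, 3)
    camera_pose = pred["camera_poses"][0]      # (4, 4)
    # Lift the depth map to world-frame points, as the new run_model does
    pts3d, valid_mask = depthmap_to_world_frame(depth_z, intrinsics, camera_pose)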
hf_utils/vgg_geometry.py DELETED
@@ -1,166 +0,0 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- # All rights reserved.
3
- #
4
- # This source code is licensed under the license found in the
5
- # LICENSE file in the root directory of this source tree.
6
-
7
- import os
8
- import torch
9
- import numpy as np
10
-
11
-
12
- def unproject_depth_map_to_point_map(
13
- depth_map: np.ndarray, extrinsics_cam: np.ndarray, intrinsics_cam: np.ndarray
14
- ) -> np.ndarray:
15
- """
16
- Unproject a batch of depth maps to 3D world coordinates.
17
-
18
- Args:
19
- depth_map (np.ndarray): Batch of depth maps of shape (S, H, W, 1) or (S, H, W)
20
- extrinsics_cam (np.ndarray): Batch of camera extrinsic matrices of shape (S, 3, 4)
21
- intrinsics_cam (np.ndarray): Batch of camera intrinsic matrices of shape (S, 3, 3)
22
-
23
- Returns:
24
- np.ndarray: Batch of 3D world coordinates of shape (S, H, W, 3)
25
- """
26
- if isinstance(depth_map, torch.Tensor):
27
- depth_map = depth_map.cpu().numpy()
28
- if isinstance(extrinsics_cam, torch.Tensor):
29
- extrinsics_cam = extrinsics_cam.cpu().numpy()
30
- if isinstance(intrinsics_cam, torch.Tensor):
31
- intrinsics_cam = intrinsics_cam.cpu().numpy()
32
-
33
- world_points_list = []
34
- for frame_idx in range(depth_map.shape[0]):
35
- cur_world_points, _, _ = depth_to_world_coords_points(
36
-            depth_map[frame_idx].squeeze(-1), extrinsics_cam[frame_idx], intrinsics_cam[frame_idx]
-        )
-        world_points_list.append(cur_world_points)
-    world_points_array = np.stack(world_points_list, axis=0)
-
-    return world_points_array
-
-
-def depth_to_world_coords_points(
-    depth_map: np.ndarray,
-    extrinsic: np.ndarray,
-    intrinsic: np.ndarray,
-    eps=1e-8,
-) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
-    """
-    Convert a depth map to world coordinates.
-
-    Args:
-        depth_map (np.ndarray): Depth map of shape (H, W).
-        intrinsic (np.ndarray): Camera intrinsic matrix of shape (3, 3).
-        extrinsic (np.ndarray): Camera extrinsic matrix of shape (3, 4). OpenCV camera coordinate convention, cam from world.
-
-    Returns:
-        tuple[np.ndarray, np.ndarray]: World coordinates (H, W, 3) and valid depth mask (H, W).
-    """
-    if depth_map is None:
-        return None, None, None
-
-    # Valid depth mask
-    point_mask = depth_map > eps
-
-    # Convert depth map to camera coordinates
-    cam_coords_points = depth_to_cam_coords_points(depth_map, intrinsic)
-
-    # Multiply with the inverse of extrinsic matrix to transform to world coordinates
-    # extrinsic_inv is 4x4 (note closed_form_inverse_OpenCV is batched, the output is (N, 4, 4))
-    cam_to_world_extrinsic = closed_form_inverse_se3(extrinsic[None])[0]
-
-    R_cam_to_world = cam_to_world_extrinsic[:3, :3]
-    t_cam_to_world = cam_to_world_extrinsic[:3, 3]
-
-    # Apply the rotation and translation to the camera coordinates
-    world_coords_points = np.dot(cam_coords_points, R_cam_to_world.T) + t_cam_to_world  # HxWx3, 3x3 -> HxWx3
-    # world_coords_points = np.einsum("ij,hwj->hwi", R_cam_to_world, cam_coords_points) + t_cam_to_world
-
-    return world_coords_points, cam_coords_points, point_mask
-
-
-def depth_to_cam_coords_points(depth_map: np.ndarray, intrinsic: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
-    """
-    Convert a depth map to camera coordinates.
-
-    Args:
-        depth_map (np.ndarray): Depth map of shape (H, W).
-        intrinsic (np.ndarray): Camera intrinsic matrix of shape (3, 3).
-
-    Returns:
-        tuple[np.ndarray, np.ndarray]: Camera coordinates (H, W, 3)
-    """
-    H, W = depth_map.shape
-    assert intrinsic.shape == (3, 3), "Intrinsic matrix must be 3x3"
-    assert intrinsic[0, 1] == 0 and intrinsic[1, 0] == 0, "Intrinsic matrix must have zero skew"
-
-    # Intrinsic parameters
-    fu, fv = intrinsic[0, 0], intrinsic[1, 1]
-    cu, cv = intrinsic[0, 2], intrinsic[1, 2]
-
-    # Generate grid of pixel coordinates
-    u, v = np.meshgrid(np.arange(W), np.arange(H))
-
-    # Unproject to camera coordinates
-    x_cam = (u - cu) * depth_map / fu
-    y_cam = (v - cv) * depth_map / fv
-    z_cam = depth_map
-
-    # Stack to form camera coordinates
-    cam_coords = np.stack((x_cam, y_cam, z_cam), axis=-1).astype(np.float32)
-
-    return cam_coords
-
-
-def closed_form_inverse_se3(se3, R=None, T=None):
-    """
-    Compute the inverse of each 4x4 (or 3x4) SE3 matrix in a batch.
-
-    If `R` and `T` are provided, they must correspond to the rotation and translation
-    components of `se3`. Otherwise, they will be extracted from `se3`.
-
-    Args:
-        se3: Nx4x4 or Nx3x4 array or tensor of SE3 matrices.
-        R (optional): Nx3x3 array or tensor of rotation matrices.
-        T (optional): Nx3x1 array or tensor of translation vectors.
-
-    Returns:
-        Inverted SE3 matrices with the same type and device as `se3`.
-
-    Shapes:
-        se3: (N, 4, 4)
-        R: (N, 3, 3)
-        T: (N, 3, 1)
-    """
-    # Check if se3 is a numpy array or a torch tensor
-    is_numpy = isinstance(se3, np.ndarray)
-
-    # Validate shapes
-    if se3.shape[-2:] != (4, 4) and se3.shape[-2:] != (3, 4):
-        raise ValueError(f"se3 must be of shape (N,4,4), got {se3.shape}.")
-
-    # Extract R and T if not provided
-    if R is None:
-        R = se3[:, :3, :3]  # (N,3,3)
-    if T is None:
-        T = se3[:, :3, 3:]  # (N,3,1)
-
-    # Transpose R
-    if is_numpy:
-        # Compute the transpose of the rotation for NumPy
-        R_transposed = np.transpose(R, (0, 2, 1))
-        # -R^T t for NumPy
-        top_right = -np.matmul(R_transposed, T)
-        inverted_matrix = np.tile(np.eye(4), (len(R), 1, 1))
-    else:
-        R_transposed = R.transpose(1, 2)  # (N,3,3)
-        top_right = -torch.bmm(R_transposed, T)  # (N,3,1)
-        inverted_matrix = torch.eye(4, 4)[None].repeat(len(R), 1, 1)
-        inverted_matrix = inverted_matrix.to(R.dtype).to(R.device)
-
-    inverted_matrix[:, :3, :3] = R_transposed
-    inverted_matrix[:, :3, 3:] = top_right
-
-    return inverted_matrix
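For reference, the two identities that the removed helpers implement: `depth_to_cam_coords_points` unprojects a pixel $(u, v)$ with depth $z$ through a zero-skew pinhole intrinsic, and `closed_form_inverse_se3` applies the closed-form inverse of a rigid transform rather than a generic matrix inverse:

$$
x_\text{cam} = \frac{(u - c_u)\,z}{f_u}, \qquad
y_\text{cam} = \frac{(v - c_v)\,z}{f_v}, \qquad
\begin{bmatrix} R & t \\ \mathbf{0}^\top & 1 \end{bmatrix}^{-1}
= \begin{bmatrix} R^\top & -R^\top t \\ \mathbf{0}^\top & 1 \end{bmatrix}.
$$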
hf_utils/visual_util.py CHANGED
@@ -221,14 +221,14 @@ def predictions_to_glb(
 
     # Prepare 4x4 matrices for camera extrinsics
     num_cameras = len(camera_matrices)
-    extrinsics_matrices = np.zeros((num_cameras, 4, 4))
-    extrinsics_matrices[:, :3, :4] = camera_matrices
-    extrinsics_matrices[:, 3, 3] = 1
+    # extrinsics_matrices = np.zeros((num_cameras, 4, 4))
+    # extrinsics_matrices[:, :3, :4] = camera_matrices
+    # extrinsics_matrices[:, 3, 3] = 1
 
     if show_cam:
         # Add camera models to the scene
         for i in range(num_cameras):
-            world_to_camera = extrinsics_matrices[i]
+            world_to_camera = camera_matrices[i]
             camera_to_world = np.linalg.inv(world_to_camera)
             rgba_color = colormap(i / num_cameras)
             current_color = tuple(int(255 * x) for x in rgba_color[:3])
@@ -238,7 +238,7 @@ def predictions_to_glb(
             )
 
     # Align scene to the observation of the first camera
-    scene_3d = apply_scene_alignment(scene_3d, extrinsics_matrices)
+    scene_3d = apply_scene_alignment(scene_3d, camera_matrices)
 
     print("GLB Scene built")
     return scene_3d
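The change above drops the 3x4-to-4x4 padding and treats each entry of `camera_matrices` as a full 4x4 world-to-camera matrix that can be inverted and passed to `apply_scene_alignment` directly. A minimal caller-side sketch, assuming a pipeline that still produces (N, 3, 4) matrices and needs to pad them itself (the helper name is hypothetical):

```python
import numpy as np

def pad_w2c_to_4x4(w2c_3x4: np.ndarray) -> np.ndarray:
    # (N, 3, 4) OpenCV world-to-camera -> (N, 4, 4) homogeneous matrices
    w2c = np.tile(np.eye(4), (len(w2c_3x4), 1, 1))
    w2c[:, :3, :4] = w2c_3x4
    return w2c
```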
mapanything/__init__.py ADDED
File without changes
mapanything/datasets/wai/ase.py CHANGED
@@ -7,7 +7,7 @@ import os
 import numpy as np
 
 from mapanything.datasets.base.base_dataset import BaseDataset
-from wai import load_data, load_frame
+from mapanything.utils.wai.core import load_data, load_frame
 
 
 class ASEWAI(BaseDataset):
mapanything/datasets/wai/bedlam.py CHANGED
@@ -7,7 +7,7 @@ import os
 import numpy as np
 
 from mapanything.datasets.base.base_dataset import BaseDataset
-from wai import load_data, load_frame
+from mapanything.utils.wai.core import load_data, load_frame
 
 
 class BedlamWAI(BaseDataset):
mapanything/datasets/wai/blendedmvs.py CHANGED
@@ -8,7 +8,7 @@ import cv2
 import numpy as np
 
 from mapanything.datasets.base.base_dataset import BaseDataset
-from wai import load_data, load_frame
+from mapanything.utils.wai.core import load_data, load_frame
 
 
 class BlendedMVSWAI(BaseDataset):
@@ -108,16 +108,9 @@ class BlendedMVSWAI(BaseDataset):
         view_data = load_frame(
             scene_root,
             view_file_name,
-            modalities=["image", "depth"],
-            # modalities=["image", "depth", "pred_mask/moge2"],
+            modalities=["image", "depth", "pred_mask/moge2"],
             scene_meta=scene_meta,
         )
-        ### HOTFIX: Load required additional masks manually
-        ### Remove once stability issue with scene_meta is fixed
-        mask_path = os.path.join(
-            scene_root, "moge", "v0", "mask", "moge2", f"{view_file_name}.png"
-        )
-        view_data["pred_mask/moge2"] = load_data(mask_path, "binary")
 
         # Convert necessary data to numpy
         image = view_data["image"].permute(1, 2, 0).numpy()
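A small sketch of the loading pattern the WAI dataset classes now share, with the extra masks requested as `load_frame` modalities instead of the per-file hotfix; the paths below are placeholders and `scene_meta` is loaded elsewhere in each dataset class:

```python
from mapanything.utils.wai.core import load_frame

scene_root = "/path/to/wai/scene"   # placeholder
view_file_name = "frame_000000"     # placeholder
scene_meta = ...                    # loaded once per scene by the dataset class (not shown in this diff)

view_data = load_frame(
    scene_root,
    view_file_name,
    modalities=["image", "depth", "pred_mask/moge2"],  # masks now arrive directly from load_frame
    scene_meta=scene_meta,
)
image = view_data["image"].permute(1, 2, 0).numpy()
mask = view_data["pred_mask/moge2"]
```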
mapanything/datasets/wai/dl3dv.py CHANGED
@@ -12,7 +12,7 @@ from mapanything.utils.cropping import (
     rescale_image_and_other_optional_info,
     resize_with_nearest_interpolation_to_match_aspect_ratio,
 )
-from wai import load_data, load_frame
+from mapanything.utils.wai.core import load_data, load_frame
 
 
 class DL3DVWAI(BaseDataset):
@@ -115,35 +115,14 @@ class DL3DVWAI(BaseDataset):
         view_data = load_frame(
             scene_root,
             view_file_name,
-            modalities=["image"],
-            # modalities=[
-            #     "image",
-            #     "pred_depth/mvsanywhere",
-            #     "pred_mask/moge2",
-            #     "depth_confidence/mvsanywhere",
-            # ],
+            modalities=[
+                "image",
+                "pred_depth/mvsanywhere",
+                "pred_mask/moge2",
+                "depth_confidence/mvsanywhere",
+            ],
             scene_meta=scene_meta,
         )
-        ### HOTFIX: Load required additional modalities manually
-        ### Remove once stability issue with scene_meta is fixed
-        mvs_depth_path = os.path.join(
-            scene_root, "mvsanywhere", "v0", "depth", f"{view_file_name}.exr"
-        )
-        mvs_conf_path = os.path.join(
-            scene_root,
-            "mvsanywhere",
-            "v0",
-            "depth_confidence",
-            f"{view_file_name}.exr",
-        )
-        mask_path = os.path.join(
-            scene_root, "moge", "v0", "mask", "moge2", f"{view_file_name}.png"
-        )
-        view_data["pred_depth/mvsanywhere"] = load_data(mvs_depth_path, "depth")
-        view_data["depth_confidence/mvsanywhere"] = load_data(
-            mvs_conf_path, "scalar"
-        )
-        view_data["pred_mask/moge2"] = load_data(mask_path, "binary")
 
         # Convert necessary data to numpy
         image = view_data["image"].permute(1, 2, 0).numpy()
mapanything/datasets/wai/dtu.py CHANGED
@@ -7,7 +7,7 @@ import os
 import numpy as np
 
 from mapanything.datasets.base.base_dataset import BaseDataset
-from wai import load_data, load_frame
+from mapanything.utils.wai.core import load_data, load_frame
 
 
 class DTUWAI(BaseDataset):
mapanything/datasets/wai/dynamicreplica.py CHANGED
@@ -7,7 +7,7 @@ import os
 import numpy as np
 
 from mapanything.datasets.base.base_dataset import BaseDataset
-from wai import load_data, load_frame
+from mapanything.utils.wai.core import load_data, load_frame
 
 
 class DynamicReplicaWAI(BaseDataset):
mapanything/datasets/wai/eth3d.py CHANGED
@@ -7,7 +7,7 @@ import os
 import numpy as np
 
 from mapanything.datasets.base.base_dataset import BaseDataset
-from wai import load_data, load_frame
+from mapanything.utils.wai.core import load_data, load_frame
 
 
 class ETH3DWAI(BaseDataset):
mapanything/datasets/wai/gta_sfm.py CHANGED
@@ -7,7 +7,7 @@ import os
 import numpy as np
 
 from mapanything.datasets.base.base_dataset import BaseDataset
-from wai import load_data, load_frame
+from mapanything.utils.wai.core import load_data, load_frame
 
 
 class GTASfMWAI(BaseDataset):
mapanything/datasets/wai/matrixcity.py CHANGED
@@ -7,7 +7,7 @@ import os
 import numpy as np
 
 from mapanything.datasets.base.base_dataset import BaseDataset
-from wai import load_data, load_frame
+from mapanything.utils.wai.core import load_data, load_frame
 
 
 class MatrixCityWAI(BaseDataset):
mapanything/datasets/wai/megadepth.py CHANGED
@@ -8,7 +8,7 @@ import cv2
 import numpy as np
 
 from mapanything.datasets.base.base_dataset import BaseDataset
-from wai import load_data, load_frame
+from mapanything.utils.wai.core import load_data, load_frame
 
 
 class MegaDepthWAI(BaseDataset):
@@ -109,16 +109,9 @@ class MegaDepthWAI(BaseDataset):
         view_data = load_frame(
             scene_root,
             view_file_name,
-            modalities=["image", "depth"],
-            # modalities=["image", "depth", "pred_mask/moge2"],
+            modalities=["image", "depth", "pred_mask/moge2"],
             scene_meta=scene_meta,
         )
-        ### HOTFIX: Load required additional masks manually
-        ### Remove once stability issue with scene_meta is fixed
-        mask_path = os.path.join(
-            scene_root, "moge", "v0", "mask", "moge2", f"{view_file_name}.png"
-        )
-        view_data["pred_mask/moge2"] = load_data(mask_path, "binary")
 
         # Convert necessary data to numpy
         image = view_data["image"].permute(1, 2, 0).numpy()
mapanything/datasets/wai/mpsd.py CHANGED
@@ -8,7 +8,7 @@ import cv2
 import numpy as np
 
 from mapanything.datasets.base.base_dataset import BaseDataset
-from wai import load_data, load_frame
+from mapanything.utils.wai.core import load_data, load_frame
 
 
 class MPSDWAI(BaseDataset):
@@ -108,16 +108,9 @@ class MPSDWAI(BaseDataset):
         view_data = load_frame(
             scene_root,
             view_file_name,
-            modalities=["image", "depth"],
-            # modalities=["image", "depth", "pred_mask/moge2"],
+            modalities=["image", "depth", "pred_mask/moge2"],
            scene_meta=scene_meta,
         )
-        ### HOTFIX: Load required additional masks manually
-        ### Remove once stability issue with scene_meta is fixed
-        mask_path = os.path.join(
-            scene_root, "moge", "v0", "mask", "moge2", f"{view_file_name}.png"
-        )
-        view_data["pred_mask/moge2"] = load_data(mask_path, "binary")
 
         # Convert necessary data to numpy
         image = view_data["image"].permute(1, 2, 0).numpy()
mapanything/datasets/wai/mvs_synth.py CHANGED
@@ -7,7 +7,7 @@ import os
 import numpy as np
 
 from mapanything.datasets.base.base_dataset import BaseDataset
-from wai import load_data, load_frame
+from mapanything.utils.wai.core import load_data, load_frame
 
 
 class MVSSynthWAI(BaseDataset):
mapanything/datasets/wai/paralleldomain4d.py CHANGED
@@ -7,7 +7,7 @@ import os
 import numpy as np
 
 from mapanything.datasets.base.base_dataset import BaseDataset
-from wai import load_data, load_frame
+from mapanything.utils.wai.core import load_data, load_frame
 
 
 class ParallelDomain4DWAI(BaseDataset):
mapanything/datasets/wai/sailvos3d.py CHANGED
@@ -7,7 +7,7 @@ import os
 import numpy as np
 
 from mapanything.datasets.base.base_dataset import BaseDataset
-from wai import load_data, load_frame
+from mapanything.utils.wai.core import load_data, load_frame
 
 
 class SAILVOS3DWAI(BaseDataset):
mapanything/datasets/wai/scannetpp.py CHANGED
@@ -7,7 +7,7 @@ import os
 import numpy as np
 
 from mapanything.datasets.base.base_dataset import BaseDataset
-from wai import load_data, load_frame
+from mapanything.utils.wai.core import load_data, load_frame
 
 
 class ScanNetPPWAI(BaseDataset):
mapanything/datasets/wai/spring.py CHANGED
@@ -8,7 +8,7 @@ import cv2
 import numpy as np
 
 from mapanything.datasets.base.base_dataset import BaseDataset
-from wai import load_data, load_frame
+from mapanything.utils.wai.core import load_data, load_frame
 
 
 class SpringWAI(BaseDataset):
@@ -107,16 +107,9 @@ class SpringWAI(BaseDataset):
         view_data = load_frame(
             scene_root,
             view_file_name,
-            modalities=["image", "depth", "skymask"],
-            # modalities=["image", "depth", "skymask", "pred_mask/moge2"],
+            modalities=["image", "depth", "skymask", "pred_mask/moge2"],
             scene_meta=scene_meta,
         )
-        ### HOTFIX: Load required additional masks manually
-        ### Remove once stability issue with scene_meta is fixed
-        mask_path = os.path.join(
-            scene_root, "moge", "v0", "mask", "moge2", f"{view_file_name}.png"
-        )
-        view_data["pred_mask/moge2"] = load_data(mask_path, "binary")
 
         # Convert necessary data to numpy
         image = view_data["image"].permute(1, 2, 0).numpy()
mapanything/datasets/wai/structured3d.py CHANGED
@@ -7,7 +7,7 @@ import os
 import numpy as np
 
 from mapanything.datasets.base.base_dataset import BaseDataset
-from wai import load_data, load_frame
+from mapanything.utils.wai.core import load_data, load_frame
 
 
 class Structured3DWAI(BaseDataset):
mapanything/datasets/wai/tav2_wb.py CHANGED
@@ -8,7 +8,7 @@ import cv2
 import numpy as np
 
 from mapanything.datasets.base.base_dataset import BaseDataset
-from wai import load_data, load_frame
+from mapanything.utils.wai.core import load_data, load_frame
 
 
 class TartanAirV2WBWAI(BaseDataset):
@@ -108,16 +108,9 @@ class TartanAirV2WBWAI(BaseDataset):
         view_data = load_frame(
             scene_root,
             view_file_name,
-            modalities=["image", "depth"],
-            # modalities=["image", "depth", "pred_mask/moge2"],
+            modalities=["image", "depth", "pred_mask/moge2"],
             scene_meta=scene_meta,
         )
-        ### HOTFIX: Load required additional masks manually
-        ### Remove once stability issue with scene_meta is fixed
-        mask_path = os.path.join(
-            scene_root, "moge", "v0", "mask", "moge2", f"{view_file_name}.png"
-        )
-        view_data["pred_mask/moge2"] = load_data(mask_path, "binary")
 
         # Convert necessary data to numpy
         image = view_data["image"].permute(1, 2, 0).numpy()
mapanything/datasets/wai/unrealstereo4k.py CHANGED
@@ -7,7 +7,7 @@ import os
 import numpy as np
 
 from mapanything.datasets.base.base_dataset import BaseDataset
-from wai import load_data, load_frame
+from mapanything.utils.wai.core import load_data, load_frame
 
 
 class UnrealStereo4KWAI(BaseDataset):
mapanything/datasets/wai/xrooms.py CHANGED
@@ -7,7 +7,7 @@ import os
 import numpy as np
 
 from mapanything.datasets.base.base_dataset import BaseDataset
-from wai import load_data, load_frame
+from mapanything.utils.wai.core import load_data, load_frame
 
 
 class XRoomsWAI(BaseDataset):
mapanything/models/external/README.md ADDED
@@ -0,0 +1,5 @@
+# External Model Code for Benchmarking & Re-Training
+
+This directory contains external model code that we use to train and benchmark external models fairly. These libraries are not part of the core MapAnything codebase and are included only for benchmarking and re-training purposes. The code in this directory is licensed under the same license as the source code from which it was derived, unless otherwise specified.
+
+The open-source Apache 2.0 License of MapAnything does not apply to these libraries.
mapanything/models/external/moge/models/v1.py CHANGED
@@ -475,7 +475,7 @@ class MoGeModel(nn.Module):
         return_dict = {"points": points, "mask": mask}
         return return_dict
 
-    @torch.inference_mode()
+    # @torch.inference_mode()
     def infer(
         self,
         image: torch.Tensor,
mapanything/models/external/moge/models/v2.py CHANGED
@@ -227,7 +227,7 @@ class MoGeModel(nn.Module):
 
         return return_dict
 
-    @torch.inference_mode()
+    # @torch.inference_mode()
     def infer(
         self,
         image: torch.Tensor,
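Both MoGe wrappers leave `@torch.inference_mode()` commented out rather than deleted, which suggests the surrounding code (benchmarking or re-training loops) is now expected to manage the autograd context itself. A minimal call-site sketch, with `moge_model` and `image` assumed to exist:

```python
import torch

with torch.inference_mode():        # explicit at the call site instead of baked into infer()
    output = moge_model.infer(image)
```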
mapanything/models/mapanything/ablations.py CHANGED
@@ -134,8 +134,10 @@ class MapAnythingAblations(nn.Module):
         # Initialize image encoder
         if self.encoder_config["uses_torch_hub"]:
             self.encoder_config["torch_hub_force_reload"] = torch_hub_force_reload
-        del self.encoder_config["uses_torch_hub"]
-        self.encoder = encoder_factory(**self.encoder_config)
+        # Create a copy of the config before deleting the key to preserve it for serialization
+        encoder_config_copy = self.encoder_config.copy()
+        del encoder_config_copy["uses_torch_hub"]
+        self.encoder = encoder_factory(**encoder_config_copy)
 
         # Initialize the encoder for ray directions
         ray_dirs_encoder_config = self.geometric_input_config["ray_dirs_encoder_config"]
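The same copy-before-delete pattern is applied in `model.py` and `modular_dust3r.py` below. A minimal standalone illustration of why the copy matters (the config values are made up):

```python
encoder_config = {"uses_torch_hub": True, "encoder_str": "dinov2_large"}  # made-up example values

encoder_config_copy = encoder_config.copy()   # shallow copy is enough: only a top-level key is removed
del encoder_config_copy["uses_torch_hub"]     # the encoder factory does not accept this key
# self.encoder = encoder_factory(**encoder_config_copy)

assert "uses_torch_hub" in encoder_config     # the original config is preserved for serialization
```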
mapanything/models/mapanything/model.py CHANGED
@@ -2,11 +2,13 @@
2
  MapAnything model class defined using UniCeption modules.
3
  """
4
 
 
5
  from functools import partial
6
- from typing import Callable, Dict, Type, Union
7
 
8
  import torch
9
  import torch.nn as nn
 
10
 
11
  from mapanything.utils.geometry import (
12
  apply_log_to_norm,
@@ -15,6 +17,11 @@ from mapanything.utils.geometry import (
15
  normalize_pose_translations,
16
  transform_pose_using_quats_and_trans_2_to_1,
17
  )
 
 
 
 
 
18
  from uniception.models.encoders import (
19
  encoder_factory,
20
  EncoderGlobalRepInput,
@@ -72,7 +79,7 @@ if hasattr(torch.backends.cuda, "matmul") and hasattr(
72
  torch.backends.cuda.matmul.allow_tf32 = True
73
 
74
 
75
- class MapAnything(nn.Module):
76
  "Modular MapAnything model class that supports input of images & optional geometric modalities (multiple reconstruction tasks)."
77
 
78
  def __init__(
@@ -139,8 +146,10 @@ class MapAnything(nn.Module):
139
  # Initialize image encoder
140
  if self.encoder_config["uses_torch_hub"]:
141
  self.encoder_config["torch_hub_force_reload"] = torch_hub_force_reload
142
- del self.encoder_config["uses_torch_hub"]
143
- self.encoder = encoder_factory(**self.encoder_config)
 
 
144
 
145
  # Initialize the encoder for ray directions
146
  ray_dirs_encoder_config = self.geometric_input_config["ray_dirs_encoder_config"]
@@ -199,6 +208,14 @@ class MapAnything(nn.Module):
199
  # Load pretrained weights
200
  self._load_pretrained_weights()
201
 
 
 
 
 
 
 
 
 
202
  def _initialize_info_sharing(self, info_sharing_config):
203
  """
204
  Initialize the information sharing module based on the configuration.
@@ -1717,3 +1734,202 @@ class MapAnything(nn.Module):
1717
  res[i]["non_ambiguous_mask_logits"] = output_mask_logits_per_view[i]
1718
 
1719
  return res
2
  MapAnything model class defined using UniCeption modules.
3
  """
4
 
5
+ import warnings
6
  from functools import partial
7
+ from typing import Any, Callable, Dict, List, Type, Union
8
 
9
  import torch
10
  import torch.nn as nn
11
+ from huggingface_hub import PyTorchModelHubMixin
12
 
13
  from mapanything.utils.geometry import (
14
  apply_log_to_norm,
 
17
  normalize_pose_translations,
18
  transform_pose_using_quats_and_trans_2_to_1,
19
  )
20
+ from mapanything.utils.inference import (
21
+ postprocess_model_outputs_for_inference,
22
+ preprocess_input_views_for_inference,
23
+ validate_input_views_for_inference,
24
+ )
25
  from uniception.models.encoders import (
26
  encoder_factory,
27
  EncoderGlobalRepInput,
 
79
  torch.backends.cuda.matmul.allow_tf32 = True
80
 
81
 
82
+ class MapAnything(nn.Module, PyTorchModelHubMixin):
83
  "Modular MapAnything model class that supports input of images & optional geometric modalities (multiple reconstruction tasks)."
84
 
85
  def __init__(
 
146
  # Initialize image encoder
147
  if self.encoder_config["uses_torch_hub"]:
148
  self.encoder_config["torch_hub_force_reload"] = torch_hub_force_reload
149
+ # Create a copy of the config before deleting the key to preserve it for serialization
150
+ encoder_config_copy = self.encoder_config.copy()
151
+ del encoder_config_copy["uses_torch_hub"]
152
+ self.encoder = encoder_factory(**encoder_config_copy)
153
 
154
  # Initialize the encoder for ray directions
155
  ray_dirs_encoder_config = self.geometric_input_config["ray_dirs_encoder_config"]
 
208
  # Load pretrained weights
209
  self._load_pretrained_weights()
210
 
211
+ @property
212
+ def device(self) -> torch.device:
213
+ return next(self.parameters()).device
214
+
215
+ @property
216
+ def dtype(self) -> torch.dtype:
217
+ return next(self.parameters()).dtype
218
+
219
  def _initialize_info_sharing(self, info_sharing_config):
220
  """
221
  Initialize the information sharing module based on the configuration.
 
1734
  res[i]["non_ambiguous_mask_logits"] = output_mask_logits_per_view[i]
1735
 
1736
  return res
1737
+
1738
+ def _configure_geometric_input_config(
1739
+ self,
1740
+ use_calibration: bool,
1741
+ use_depth: bool,
1742
+ use_pose: bool,
1743
+ use_depth_scale: bool,
1744
+ use_pose_scale: bool,
1745
+ ):
1746
+ """
1747
+ Configure the geometric input configuration
1748
+ """
1749
+ # Store original config for restoration
1750
+ if not hasattr(self, "_original_geometric_config"):
1751
+ self._original_geometric_config = dict(self.geometric_input_config)
1752
+
1753
+ # Set the geometric input configuration
1754
+ if not (use_calibration or use_depth or use_pose):
1755
+ # No geometric inputs (images-only mode)
1756
+ self.geometric_input_config.update(
1757
+ {
1758
+ "overall_prob": 0.0,
1759
+ "dropout_prob": 1.0,
1760
+ "ray_dirs_prob": 0.0,
1761
+ "depth_prob": 0.0,
1762
+ "cam_prob": 0.0,
1763
+ "sparse_depth_prob": 0.0,
1764
+ "depth_scale_norm_all_prob": 0.0,
1765
+ "pose_scale_norm_all_prob": 0.0,
1766
+ }
1767
+ )
1768
+ else:
1769
+ # Enable geometric inputs with deterministic behavior
1770
+ self.geometric_input_config.update(
1771
+ {
1772
+ "overall_prob": 1.0,
1773
+ "dropout_prob": 0.0,
1774
+ "ray_dirs_prob": 1.0 if use_calibration else 0.0,
1775
+ "depth_prob": 1.0 if use_depth else 0.0,
1776
+ "cam_prob": 1.0 if use_pose else 0.0,
1777
+ "sparse_depth_prob": 0.0, # No sparsification during inference
1778
+ "depth_scale_norm_all_prob": 0.0 if use_depth_scale else 1.0,
1779
+ "pose_scale_norm_all_prob": 0.0 if use_pose_scale else 1.0,
1780
+ }
1781
+ )
1782
+
1783
+ def _restore_original_geometric_input_config(self):
1784
+ """
1785
+ Restore original geometric input configuration
1786
+ """
1787
+ if hasattr(self, "_original_geometric_config"):
1788
+ self.geometric_input_config.update(self._original_geometric_config)
1789
+
1790
+ @torch.inference_mode()
1791
+ def infer(
1792
+ self,
1793
+ views: List[Dict[str, Any]],
1794
+ use_amp: bool = True,
1795
+ amp_dtype: str = "bf16",
1796
+ apply_mask: bool = True,
1797
+ mask_edges: bool = True,
1798
+ edge_normal_threshold: float = 5.0,
1799
+ edge_depth_threshold: float = 0.03,
1800
+ apply_confidence_mask: bool = False,
1801
+ confidence_percentile: float = 10,
1802
+ ignore_calibration_inputs: bool = False,
1803
+ ignore_depth_inputs: bool = False,
1804
+ ignore_pose_inputs: bool = False,
1805
+ ignore_depth_scale_inputs: bool = False,
1806
+ ignore_pose_scale_inputs: bool = False,
1807
+ ) -> List[Dict[str, torch.Tensor]]:
1808
+ """
1809
+ User-friendly inference with strict input validation and automatic conversion.
1810
+
1811
+ Args:
1812
+ views: List of view dictionaries. Each dict can contain:
1813
+ Required:
1814
+ - 'img': torch.Tensor of shape (B, 3, H, W) - normalized RGB images
1815
+ - 'data_norm_type': str - normalization type used to normalize the images (must be equal to self.model.encoder.data_norm_type)
1816
+
1817
+ Optional Geometric Inputs (only one of intrinsics OR ray_directions):
1818
+ - 'intrinsics': torch.Tensor of shape (B, 3, 3) - will be converted to ray directions
1819
+ - 'ray_directions': torch.Tensor of shape (B, H, W, 3) - ray directions in camera frame
1820
+ - 'depth_z': torch.Tensor of shape (B, H, W, 1) - Z depth in camera frame (intrinsics or ray_directions must be provided)
1821
+ - 'camera_poses': torch.Tensor of shape (B, 4, 4) or tuple of (quats - (B, 4), trans - (B, 3)) - can be any world frame
1822
+ - 'is_metric_scale': bool or torch.Tensor of shape (B,) - if not provided, defaults to True
1823
+
1824
+ Optional Additional Info:
1825
+ - 'instance': List[str] where length of list is B - instance info for each view
1826
+ - 'idx': List[int] where length of list is B - index info for each view
1827
+ - 'true_shape': List[tuple] where length of list is B - true shape info (H, W) for each view
1828
+
1829
+ use_amp: Whether to use automatic mixed precision for faster inference. Defaults to True.
1830
+ amp_dtype: The dtype to use for mixed precision. Defaults to "bf16" (bfloat16). Options: "fp16", "bf16", "fp32".
1831
+ apply_mask: Whether to apply the non-ambiguous mask to the output. Defaults to True.
1832
+ mask_edges: Whether to compute an edge mask based on normals and depth and apply it to the output. Defaults to True.
1833
+ edge_normal_threshold: Tolerance threshold for normals-based edge detection. Defaults to 5.0.
1834
+ edge_depth_threshold: Relative tolerance threshold for depth-based edge detection. Defaults to 0.03.
1835
+ apply_confidence_mask: Whether to apply the confidence mask to the output. Defaults to False.
1836
+ confidence_percentile: The percentile to use for the confidence threshold. Defaults to 10.
1837
+ ignore_calibration_inputs: Whether to ignore the calibration inputs (intrinsics and ray_directions). Defaults to False.
1838
+ ignore_depth_inputs: Whether to ignore the depth inputs. Defaults to False.
1839
+ ignore_pose_inputs: Whether to ignore the pose inputs. Defaults to False.
1840
+ ignore_depth_scale_inputs: Whether to ignore the depth scale inputs. Defaults to False.
1841
+ ignore_pose_scale_inputs: Whether to ignore the pose scale inputs. Defaults to False.
1842
+
1843
+ IMPORTANT CONSTRAINTS:
1844
+ - Cannot provide both 'intrinsics' and 'ray_directions' (they represent the same information)
1845
+ - If 'depth' is provided, then 'intrinsics' or 'ray_directions' must also be provided
1846
+ - If ANY view has 'camera_poses', then view 0 (first view) MUST also have 'camera_poses'
1847
+
1848
+ Returns:
1849
+ List of prediction dictionaries, one per view. Each dict contains:
1850
+ - 'img_no_norm': torch.Tensor of shape (B, H, W, 3) - denormalized rgb images
1851
+ - 'pts3d': torch.Tensor of shape (B, H, W, 3) - predicted points in world frame
1852
+ - 'pts3d_cam': torch.Tensor of shape (B, H, W, 3) - predicted points in camera frame
1853
+ - 'ray_directions': torch.Tensor of shape (B, H, W, 3) - ray directions in camera frame
1854
+ - 'intrinsics': torch.Tensor of shape (B, 3, 3) - pinhole camera intrinsics recovered from ray directions
1855
+ - 'depth_along_ray': torch.Tensor of shape (B, H, W, 1) - depth along ray in camera frame
1856
+ - 'depth_z': torch.Tensor of shape (B, H, W, 1) - Z depth in camera frame
1857
+ - 'cam_trans': torch.Tensor of shape (B, 3) - camera translation in world frame
1858
+ - 'cam_quats': torch.Tensor of shape (B, 4) - camera quaternion in world frame
1859
+ - 'camera_poses': torch.Tensor of shape (B, 4, 4) - camera pose in world frame
1860
+ - 'metric_scaling_factor': torch.Tensor of shape (B,) - applied metric scaling factor
1861
+ - 'mask': torch.Tensor of shape (B, H, W, 1) - combo of non-ambiguous mask, edge mask and confidence-based mask if used
1862
+ - 'non_ambiguous_mask': torch.Tensor of shape (B, H, W) - non-ambiguous mask
1863
+ - 'non_ambiguous_mask_logits': torch.Tensor of shape (B, H, W) - non-ambiguous mask logits
1864
+ - 'conf': torch.Tensor of shape (B, H, W) - confidence
1865
+
1866
+ Raises:
1867
+ ValueError: For invalid inputs, missing required keys, conflicting modalities, or constraint violations
1868
+ """
1869
+ # Determine the mixed precision floating point type
1870
+ if use_amp:
1871
+ if amp_dtype == "fp16":
1872
+ amp_dtype = torch.float16
1873
+ elif amp_dtype == "bf16":
1874
+ if torch.cuda.is_bf16_supported():
1875
+ amp_dtype = torch.bfloat16
1876
+ else:
1877
+ warnings.warn(
1878
+ "bf16 is not supported on this device. Using fp16 instead."
1879
+ )
1880
+ amp_dtype = torch.float16
1881
+ elif amp_dtype == "fp32":
1882
+ amp_dtype = torch.float32
1883
+ else:
1884
+ amp_dtype = torch.float32
1885
+
1886
+ # Validate the input views
1887
+ validated_views = validate_input_views_for_inference(views)
1888
+
1889
+ # Transfer the views to the same device as the model
1890
+ ignore_keys = set(
1891
+ [
1892
+ "instance",
1893
+ "idx",
1894
+ "true_shape",
1895
+ "data_norm_type",
1896
+ ]
1897
+ )
1898
+ for view in validated_views:
1899
+ for name in view.keys():
1900
+ if name in ignore_keys:
1901
+ continue
1902
+ view[name] = view[name].to(self.device, non_blocking=True)
1903
+
1904
+ # Pre-process the input views
1905
+ processed_views = preprocess_input_views_for_inference(validated_views)
1906
+
1907
+ # Set the model input probabilities based on input args for ignoring inputs
1908
+ self._configure_geometric_input_config(
1909
+ use_calibration=not ignore_calibration_inputs,
1910
+ use_depth=not ignore_depth_inputs,
1911
+ use_pose=not ignore_pose_inputs,
1912
+ use_depth_scale=not ignore_depth_scale_inputs,
1913
+ use_pose_scale=not ignore_pose_scale_inputs,
1914
+ )
1915
+
1916
+ # Run the model
1917
+ with torch.autocast("cuda", enabled=bool(use_amp), dtype=amp_dtype):
1918
+ preds = self.forward(processed_views)
1919
+
1920
+ # Post-process the model outputs
1921
+ preds = postprocess_model_outputs_for_inference(
1922
+ raw_outputs=preds,
1923
+ input_views=processed_views,
1924
+ apply_mask=apply_mask,
1925
+ mask_edges=mask_edges,
1926
+ edge_normal_threshold=edge_normal_threshold,
1927
+ edge_depth_threshold=edge_depth_threshold,
1928
+ apply_confidence_mask=apply_confidence_mask,
1929
+ confidence_percentile=confidence_percentile,
1930
+ )
1931
+
1932
+ # Restore the original configuration
1933
+ self._restore_original_geometric_input_config()
1934
+
1935
+ return preds
mapanything/models/mapanything/modular_dust3r.py CHANGED
@@ -99,8 +99,10 @@ class ModularDUSt3R(nn.Module):
99
  # Initialize Encoder
100
  if self.encoder_config["uses_torch_hub"]:
101
  self.encoder_config["torch_hub_force_reload"] = torch_hub_force_reload
102
- del self.encoder_config["uses_torch_hub"]
103
- self.encoder = encoder_factory(**self.encoder_config)
 
 
104
 
105
  # Initialize Custom Positional Encoding if required
106
  if custom_positional_encoding is not None:
 
99
  # Initialize Encoder
100
  if self.encoder_config["uses_torch_hub"]:
101
  self.encoder_config["torch_hub_force_reload"] = torch_hub_force_reload
102
+ # Create a copy of the config before deleting the key to preserve it for serialization
103
+ encoder_config_copy = self.encoder_config.copy()
104
+ del encoder_config_copy["uses_torch_hub"]
105
+ self.encoder = encoder_factory(**encoder_config_copy)
106
 
107
  # Initialize Custom Positional Encoding if required
108
  if custom_positional_encoding is not None:
mapanything/train/losses.py CHANGED
@@ -1766,6 +1766,202 @@ class PointsPlusScaleRegr3D(Criterion, MultiLoss):
1766
  return losses, (details | {})
1767
 
1768
 
1769
  class FactoredGeometryRegr3D(Criterion, MultiLoss):
1770
  """
1771
  Regression Loss for Factored Geometry.
@@ -1787,6 +1983,7 @@ class FactoredGeometryRegr3D(Criterion, MultiLoss):
1787
  pose_quats_loss_weight=1,
1788
  pose_trans_loss_weight=1,
1789
  compute_pairwise_relative_pose_loss=False,
 
1790
  compute_world_frame_points_loss=True,
1791
  world_frame_points_loss_weight=1,
1792
  ):
@@ -1821,6 +2018,8 @@ class FactoredGeometryRegr3D(Criterion, MultiLoss):
1821
  pose_trans_loss_weight (float): Weight to use for the pose trans loss. Default: 1.
1822
  compute_pairwise_relative_pose_loss (bool): If True, the pose loss is computed on the
1823
  exhaustive pairwise relative poses. Default: False.
 
 
1824
  compute_world_frame_points_loss (bool): If True, compute the world frame pointmap loss. Default: True.
1825
  world_frame_points_loss_weight (float): Weight to use for the world frame pointmap loss. Default: 1.
1826
  """
@@ -1847,6 +2046,7 @@ class FactoredGeometryRegr3D(Criterion, MultiLoss):
1847
  self.pose_quats_loss_weight = pose_quats_loss_weight
1848
  self.pose_trans_loss_weight = pose_trans_loss_weight
1849
  self.compute_pairwise_relative_pose_loss = compute_pairwise_relative_pose_loss
 
1850
  self.compute_world_frame_points_loss = compute_world_frame_points_loss
1851
  self.world_frame_points_loss_weight = world_frame_points_loss_weight
1852
 
@@ -1869,6 +2069,19 @@ class FactoredGeometryRegr3D(Criterion, MultiLoss):
1869
  gt_ray_directions = []
1870
  gt_pose_quats = []
1871
  # Predicted quantities
 
 
 
 
 
 
 
 
 
 
 
 
 
1872
  no_norm_pr_pts = []
1873
  no_norm_pr_pts_cam = []
1874
  no_norm_pr_depth = []
@@ -1922,16 +2135,34 @@ class FactoredGeometryRegr3D(Criterion, MultiLoss):
1922
  gt_pose_quats.append(gt_pose_quats_in_view0)
1923
  no_norm_gt_pose_trans.append(no_norm_gt_pose_trans_in_view0)
1924
 
1925
- # Get predictions
1926
- no_norm_pr_pts.append(preds[i]["pts3d"])
1927
  no_norm_pr_pts_cam.append(preds[i]["pts3d_cam"])
1928
  pr_ray_directions.append(preds[i]["ray_directions"])
1929
  if self.depth_type_for_loss == "depth_along_ray":
1930
  no_norm_pr_depth.append(preds[i]["depth_along_ray"])
1931
  elif self.depth_type_for_loss == "depth_z":
1932
  no_norm_pr_depth.append(preds[i]["pts3d_cam"][..., 2:])
1933
- no_norm_pr_pose_trans.append(preds[i]["cam_trans"])
1934
- pr_pose_quats.append(preds[i]["cam_quats"])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1935
 
1936
  if dist_clip is not None:
1937
  # Points that are too far-away == invalid
@@ -2443,6 +2674,7 @@ class FactoredGeometryRegr3DPlusNormalGMLoss(FactoredGeometryRegr3D):
2443
  pose_quats_loss_weight=1,
2444
  pose_trans_loss_weight=1,
2445
  compute_pairwise_relative_pose_loss=False,
 
2446
  compute_world_frame_points_loss=True,
2447
  world_frame_points_loss_weight=1,
2448
  apply_normal_and_gm_loss_to_synthetic_data_only=True,
@@ -2478,6 +2710,8 @@ class FactoredGeometryRegr3DPlusNormalGMLoss(FactoredGeometryRegr3D):
2478
  pose_trans_loss_weight (float): Weight to use for the pose trans loss. Default: 1.
2479
  compute_pairwise_relative_pose_loss (bool): If True, the pose loss is computed on the
2480
  exhaustive pairwise relative poses. Default: False.
 
 
2481
  compute_world_frame_points_loss (bool): If True, compute the world frame pointmap loss. Default: True.
2482
  world_frame_points_loss_weight (float): Weight to use for the world frame pointmap loss. Default: 1.
2483
  apply_normal_and_gm_loss_to_synthetic_data_only (bool): If True, apply the normal and gm loss only to synthetic data.
@@ -2500,6 +2734,7 @@ class FactoredGeometryRegr3DPlusNormalGMLoss(FactoredGeometryRegr3D):
2500
  pose_quats_loss_weight=pose_quats_loss_weight,
2501
  pose_trans_loss_weight=pose_trans_loss_weight,
2502
  compute_pairwise_relative_pose_loss=compute_pairwise_relative_pose_loss,
 
2503
  compute_world_frame_points_loss=compute_world_frame_points_loss,
2504
  world_frame_points_loss_weight=world_frame_points_loss_weight,
2505
  )
@@ -2895,6 +3130,7 @@ class FactoredGeometryScaleRegr3D(Criterion, MultiLoss):
2895
  pose_trans_loss_weight=1,
2896
  scale_loss_weight=1,
2897
  compute_pairwise_relative_pose_loss=False,
 
2898
  compute_world_frame_points_loss=True,
2899
  world_frame_points_loss_weight=1,
2900
  ):
@@ -2928,6 +3164,8 @@ class FactoredGeometryScaleRegr3D(Criterion, MultiLoss):
2928
  scale_loss_weight (float): Weight to use for the scale loss. Default: 1.
2929
  compute_pairwise_relative_pose_loss (bool): If True, the pose loss is computed on the
2930
  exhaustive pairwise relative poses. Default: False.
 
 
2931
  compute_world_frame_points_loss (bool): If True, compute the world frame pointmap loss. Default: True.
2932
  world_frame_points_loss_weight (float): Weight to use for the world frame pointmap loss. Default: 1.
2933
  """
@@ -2948,6 +3186,7 @@ class FactoredGeometryScaleRegr3D(Criterion, MultiLoss):
2948
  self.pose_trans_loss_weight = pose_trans_loss_weight
2949
  self.scale_loss_weight = scale_loss_weight
2950
  self.compute_pairwise_relative_pose_loss = compute_pairwise_relative_pose_loss
 
2951
  self.compute_world_frame_points_loss = compute_world_frame_points_loss
2952
  self.world_frame_points_loss_weight = world_frame_points_loss_weight
2953
 
@@ -2970,6 +3209,19 @@ class FactoredGeometryScaleRegr3D(Criterion, MultiLoss):
2970
  gt_ray_directions = []
2971
  gt_pose_quats = []
2972
  # Predicted quantities
 
 
 
 
 
 
 
 
 
 
 
 
 
2973
  no_norm_pr_pts = []
2974
  no_norm_pr_pts_cam = []
2975
  no_norm_pr_depth = []
@@ -3024,6 +3276,24 @@ class FactoredGeometryScaleRegr3D(Criterion, MultiLoss):
3024
  gt_pose_quats.append(gt_pose_quats_in_view0)
3025
  no_norm_gt_pose_trans.append(no_norm_gt_pose_trans_in_view0)
3026
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3027
  # Get predictions for normalized loss
3028
  if self.depth_type_for_loss == "depth_along_ray":
3029
  curr_view_no_norm_depth = preds[i]["depth_along_ray"]
@@ -3032,7 +3302,7 @@ class FactoredGeometryScaleRegr3D(Criterion, MultiLoss):
3032
  if "metric_scaling_factor" in preds[i].keys():
3033
  # Divide by the predicted metric scaling factor to get the raw predicted points, depth_along_ray, and pose_trans
3034
  # This detaches the predicted metric scaling factor from the geometry based loss
3035
- curr_view_no_norm_pr_pts = preds[i]["pts3d"] / preds[i][
3036
  "metric_scaling_factor"
3037
  ].unsqueeze(-1).unsqueeze(-1)
3038
  curr_view_no_norm_pr_pts_cam = preds[i]["pts3d_cam"] / preds[i][
@@ -3042,19 +3312,19 @@ class FactoredGeometryScaleRegr3D(Criterion, MultiLoss):
3042
  "metric_scaling_factor"
3043
  ].unsqueeze(-1).unsqueeze(-1)
3044
  curr_view_no_norm_pr_pose_trans = (
3045
- preds[i]["cam_trans"] / preds[i]["metric_scaling_factor"]
3046
  )
3047
  else:
3048
- curr_view_no_norm_pr_pts = preds[i]["pts3d"]
3049
  curr_view_no_norm_pr_pts_cam = preds[i]["pts3d_cam"]
3050
  curr_view_no_norm_depth = curr_view_no_norm_depth
3051
- curr_view_no_norm_pr_pose_trans = preds[i]["cam_trans"]
3052
  no_norm_pr_pts.append(curr_view_no_norm_pr_pts)
3053
  no_norm_pr_pts_cam.append(curr_view_no_norm_pr_pts_cam)
3054
  no_norm_pr_depth.append(curr_view_no_norm_depth)
3055
  no_norm_pr_pose_trans.append(curr_view_no_norm_pr_pose_trans)
3056
  pr_ray_directions.append(preds[i]["ray_directions"])
3057
- pr_pose_quats.append(preds[i]["cam_quats"])
3058
 
3059
  # Get the predicted metric scale points
3060
  if "metric_scaling_factor" in preds[i].keys():
@@ -3553,6 +3823,7 @@ class FactoredGeometryScaleRegr3DPlusNormalGMLoss(FactoredGeometryScaleRegr3D):
3553
  pose_trans_loss_weight=1,
3554
  scale_loss_weight=1,
3555
  compute_pairwise_relative_pose_loss=False,
 
3556
  compute_world_frame_points_loss=True,
3557
  world_frame_points_loss_weight=1,
3558
  apply_normal_and_gm_loss_to_synthetic_data_only=True,
@@ -3585,6 +3856,8 @@ class FactoredGeometryScaleRegr3DPlusNormalGMLoss(FactoredGeometryScaleRegr3D):
3585
  scale_loss_weight (float): Weight to use for the scale loss. Default: 1.
3586
  compute_pairwise_relative_pose_loss (bool): If True, the pose loss is computed on the
3587
  exhaustive pairwise relative poses. Default: False.
 
 
3588
  compute_world_frame_points_loss (bool): If True, compute the world frame pointmap loss. Default: True.
3589
  world_frame_points_loss_weight (float): Weight to use for the world frame pointmap loss. Default: 1.
3590
  apply_normal_and_gm_loss_to_synthetic_data_only (bool): If True, apply the normal and gm loss only to synthetic data.
@@ -3607,6 +3880,7 @@ class FactoredGeometryScaleRegr3DPlusNormalGMLoss(FactoredGeometryScaleRegr3D):
3607
  pose_trans_loss_weight=pose_trans_loss_weight,
3608
  scale_loss_weight=scale_loss_weight,
3609
  compute_pairwise_relative_pose_loss=compute_pairwise_relative_pose_loss,
 
3610
  compute_world_frame_points_loss=compute_world_frame_points_loss,
3611
  world_frame_points_loss_weight=world_frame_points_loss_weight,
3612
  )
 
1766
  return losses, (details | {})
1767
 
1768
 
1769
+ class NormalGMLoss(MultiLoss):
1770
+ """
1771
+ Normal & Gradient Matching Loss for Monocular Depth Training.
1772
+ """
1773
+
1774
+ def __init__(
1775
+ self,
1776
+ norm_predictions=True,
1777
+ norm_mode="avg_dis",
1778
+ apply_normal_and_gm_loss_to_synthetic_data_only=True,
1779
+ ):
1780
+ """
1781
+ Initialize the loss criterion for Normal & Gradient Matching Loss (currently only valid for 1 view).
1782
+ Computes:
1783
+ (1) Normal Loss over the PointMap (naturally will be in local frame) in euclidean coordinates,
1784
+ (2) Gradient Matching (GM) Loss over the Depth Z in log space. (MiDAS applied GM loss in disparity space)
1785
+
1786
+ Args:
1787
+ norm_predictions (bool): If True, normalize the predictions before computing the loss.
1788
+ norm_mode (str): Normalization mode for the gt and predicted (optional) scene representation. Default: "avg_dis".
1789
+ apply_normal_and_gm_loss_to_synthetic_data_only (bool): If True, apply the normal and gm loss only to synthetic data.
1790
+ If False, apply the normal and gm loss to all data. Default: True.
1791
+ """
1792
+ super().__init__()
1793
+ self.norm_predictions = norm_predictions
1794
+ self.norm_mode = norm_mode
1795
+ self.apply_normal_and_gm_loss_to_synthetic_data_only = (
1796
+ apply_normal_and_gm_loss_to_synthetic_data_only
1797
+ )
1798
+
1799
+ def get_all_info(self, batch, preds, dist_clip=None):
1800
+ """
1801
+ Function to get all the information needed to compute the loss.
1802
+ Returns all quantities normalized.
1803
+ """
1804
+ n_views = len(batch)
1805
+ assert n_views == 1, (
1806
+ "Normal & Gradient Matching Loss Class only supports 1 view"
1807
+ )
1808
+
1809
+ # Everything is normalized w.r.t. camera of view1
1810
+ in_camera1 = closed_form_pose_inverse(batch[0]["camera_pose"])
1811
+
1812
+ # Initialize lists to store data for all views
1813
+ no_norm_gt_pts = []
1814
+ valid_masks = []
1815
+ no_norm_pr_pts = []
1816
+
1817
+ # Get ground truth & prediction info for all views
1818
+ for i in range(n_views):
1819
+ # Get ground truth
1820
+ no_norm_gt_pts.append(geotrf(in_camera1, batch[i]["pts3d"]))
1821
+ valid_masks.append(batch[i]["valid_mask"].clone())
1822
+
1823
+ # Get predictions for normalized loss
1824
+ if "metric_scaling_factor" in preds[i].keys():
1825
+ # Divide by the predicted metric scaling factor to get the raw predicted points
1826
+ # This detaches the predicted metric scaling factor from the geometry based loss
1827
+ curr_view_no_norm_pr_pts = preds[i]["pts3d"] / preds[i][
1828
+ "metric_scaling_factor"
1829
+ ].unsqueeze(-1).unsqueeze(-1)
1830
+ else:
1831
+ curr_view_no_norm_pr_pts = preds[i]["pts3d"]
1832
+ no_norm_pr_pts.append(curr_view_no_norm_pr_pts)
1833
+
1834
+ if dist_clip is not None:
1835
+ # Points that are too far-away == invalid
1836
+ for i in range(n_views):
1837
+ dis = no_norm_gt_pts[i].norm(dim=-1)
1838
+ valid_masks[i] = valid_masks[i] & (dis <= dist_clip)
1839
+
1840
+ # Initialize normalized tensors
1841
+ gt_pts = [torch.zeros_like(pts) for pts in no_norm_gt_pts]
1842
+ pr_pts = [torch.zeros_like(pts) for pts in no_norm_pr_pts]
1843
+
1844
+ # Normalize the predicted points if specified
1845
+ if self.norm_predictions:
1846
+ pr_normalization_output = normalize_multiple_pointclouds(
1847
+ no_norm_pr_pts,
1848
+ valid_masks,
1849
+ self.norm_mode,
1850
+ ret_factor=True,
1851
+ )
1852
+ pr_pts_norm = pr_normalization_output[:-1]
1853
+
1854
+ # Normalize the ground truth points
1855
+ gt_normalization_output = normalize_multiple_pointclouds(
1856
+ no_norm_gt_pts, valid_masks, self.norm_mode, ret_factor=True
1857
+ )
1858
+ gt_pts_norm = gt_normalization_output[:-1]
1859
+
1860
+ for i in range(n_views):
1861
+ if self.norm_predictions:
1862
+ # Assign the normalized predictions
1863
+ pr_pts[i] = pr_pts_norm[i]
1864
+ else:
1865
+ # Assign the raw predicted points
1866
+ pr_pts[i] = no_norm_pr_pts[i]
1867
+ # Assign the normalized ground truth
1868
+ gt_pts[i] = gt_pts_norm[i]
1869
+
1870
+ return gt_pts, pr_pts, valid_masks
1871
+
1872
+ def compute_loss(self, batch, preds, **kw):
1873
+ gt_pts, pred_pts, valid_masks = self.get_all_info(batch, preds, **kw)
1874
+ n_views = len(batch)
1875
+ assert n_views == 1, (
1876
+ "Normal & Gradient Matching Loss Class only supports 1 view"
1877
+ )
1878
+
1879
+ normal_losses = []
1880
+ gradient_matching_losses = []
1881
+ details = {}
1882
+ running_avg_dict = {}
1883
+ self_name = type(self).__name__
1884
+
1885
+ for i in range(n_views):
1886
+ # Get the local frame points, log space depth_z & valid masks
1887
+ pred_local_pts3d = pred_pts[i]
1888
+ pred_depth_z = pred_local_pts3d[..., 2:]
1889
+ pred_depth_z = apply_log_to_norm(pred_depth_z)
1890
+ gt_local_pts3d = gt_pts[i]
1891
+ gt_depth_z = gt_local_pts3d[..., 2:]
1892
+ gt_depth_z = apply_log_to_norm(gt_depth_z)
1893
+ valid_mask_for_normal_gm_loss = valid_masks[i].clone()
1894
+
1895
+ # Update the validity mask for normal & gm loss based on the synthetic data mask if required
1896
+ if self.apply_normal_and_gm_loss_to_synthetic_data_only:
1897
+ synthetic_mask = batch[i]["is_synthetic"] # (B, )
1898
+ synthetic_mask = synthetic_mask.unsqueeze(-1).unsqueeze(-1) # (B, 1, 1)
1899
+ synthetic_mask = synthetic_mask.expand(
1900
+ -1, pred_depth_z.shape[1], pred_depth_z.shape[2]
1901
+ ) # (B, H, W)
1902
+ valid_mask_for_normal_gm_loss = (
1903
+ valid_mask_for_normal_gm_loss & synthetic_mask
1904
+ )
1905
+
1906
+ # Compute the normal loss
1907
+ normal_loss = compute_normal_loss(
1908
+ pred_local_pts3d, gt_local_pts3d, valid_mask_for_normal_gm_loss.clone()
1909
+ )
1910
+ normal_losses.append(normal_loss)
1911
+
1912
+ # Compute the gradient matching loss
1913
+ gradient_matching_loss = compute_gradient_matching_loss(
1914
+ pred_depth_z, gt_depth_z, valid_mask_for_normal_gm_loss.clone()
1915
+ )
1916
+ gradient_matching_losses.append(gradient_matching_loss)
1917
+
1918
+ # Add loss details if only valid values are present
1919
+ # Initialize or update running average directly
1920
+ # Normal loss details
1921
+ if float(normal_loss) > 0:
1922
+ details[f"{self_name}_normal_view{i + 1}"] = float(normal_loss)
1923
+ normal_avg_key = f"{self_name}_normal_avg"
1924
+ if normal_avg_key not in details:
1925
+ details[normal_avg_key] = float(normal_losses[i])
1926
+ running_avg_dict[f"{self_name}_normal_valid_views"] = 1
1927
+ else:
1928
+ normal_valid_views = (
1929
+ running_avg_dict[f"{self_name}_normal_valid_views"] + 1
1930
+ )
1931
+ running_avg_dict[f"{self_name}_normal_valid_views"] = (
1932
+ normal_valid_views
1933
+ )
1934
+ details[normal_avg_key] += (
1935
+ float(normal_losses[i]) - details[normal_avg_key]
1936
+ ) / normal_valid_views
1937
+
1938
+ # Gradient Matching loss details
1939
+ if float(gradient_matching_loss) > 0:
1940
+ details[f"{self_name}_gradient_matching_view{i + 1}"] = float(
1941
+ gradient_matching_loss
1942
+ )
1943
+ # For gradient matching loss
1944
+ gm_avg_key = f"{self_name}_gradient_matching_avg"
1945
+ if gm_avg_key not in details:
1946
+ details[gm_avg_key] = float(gradient_matching_losses[i])
1947
+ running_avg_dict[f"{self_name}_gm_valid_views"] = 1
1948
+ else:
1949
+ gm_valid_views = running_avg_dict[f"{self_name}_gm_valid_views"] + 1
1950
+ running_avg_dict[f"{self_name}_gm_valid_views"] = gm_valid_views
1951
+ details[gm_avg_key] += (
1952
+ float(gradient_matching_losses[i]) - details[gm_avg_key]
1953
+ ) / gm_valid_views
1954
+
1955
+ # Put the losses together
1956
+ loss_terms = []
1957
+ for i in range(n_views):
1958
+ loss_terms.append((normal_losses[i], None, "normal"))
1959
+ loss_terms.append((gradient_matching_losses[i], None, "gradient_matching"))
1960
+ losses = Sum(*loss_terms)
1961
+
1962
+ return losses, details
1963
+
1964
+
1965
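For context on the loss terms added in this class: the gradient-matching term follows the MiDaS recipe but is applied to log depth rather than disparity (per the docstring above). The exact expression implemented by `compute_gradient_matching_loss` is not shown in this diff; a reference multi-scale formulation is

$$
\mathcal{L}_{\mathrm{GM}} = \frac{1}{M} \sum_{s=1}^{S} \sum_{i} \left( \left| \nabla_x R^{s}_{i} \right| + \left| \nabla_y R^{s}_{i} \right| \right),
\qquad R_i = \log \hat{z}_i - \log z_i,
$$

where $R^{s}$ is the residual downsampled to scale $s$, the inner sum runs over valid pixels, and $M$ is the number of valid pixels.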
  class FactoredGeometryRegr3D(Criterion, MultiLoss):
1966
  """
1967
  Regression Loss for Factored Geometry.
 
1983
  pose_quats_loss_weight=1,
1984
  pose_trans_loss_weight=1,
1985
  compute_pairwise_relative_pose_loss=False,
1986
+ convert_predictions_to_view0_frame=False,
1987
  compute_world_frame_points_loss=True,
1988
  world_frame_points_loss_weight=1,
1989
  ):
 
2018
  pose_trans_loss_weight (float): Weight to use for the pose trans loss. Default: 1.
2019
  compute_pairwise_relative_pose_loss (bool): If True, the pose loss is computed on the
2020
  exhaustive pairwise relative poses. Default: False.
2021
+ convert_predictions_to_view0_frame (bool): If True, convert predictions to view0 frame.
2022
+ Use this if the predictions are not already in the view0 frame. Default: False.
2023
  compute_world_frame_points_loss (bool): If True, compute the world frame pointmap loss. Default: True.
2024
  world_frame_points_loss_weight (float): Weight to use for the world frame pointmap loss. Default: 1.
2025
  """
 
2046
  self.pose_quats_loss_weight = pose_quats_loss_weight
2047
  self.pose_trans_loss_weight = pose_trans_loss_weight
2048
  self.compute_pairwise_relative_pose_loss = compute_pairwise_relative_pose_loss
2049
+ self.convert_predictions_to_view0_frame = convert_predictions_to_view0_frame
2050
  self.compute_world_frame_points_loss = compute_world_frame_points_loss
2051
  self.world_frame_points_loss_weight = world_frame_points_loss_weight
2052
 
 
2069
  gt_ray_directions = []
2070
  gt_pose_quats = []
2071
  # Predicted quantities
2072
+ if self.convert_predictions_to_view0_frame:
2073
+ # Get the camera transform to convert quantities to view0 frame
2074
+ pred_camera0 = torch.eye(4, device=preds[0]["cam_quats"].device).unsqueeze(
2075
+ 0
2076
+ )
2077
+ batch_size = preds[0]["cam_quats"].shape[0]
2078
+ pred_camera0 = pred_camera0.repeat(batch_size, 1, 1)
2079
+ pred_camera0_rot = quaternion_to_rotation_matrix(
2080
+ preds[0]["cam_quats"].clone()
2081
+ )
2082
+ pred_camera0[..., :3, :3] = pred_camera0_rot
2083
+ pred_camera0[..., :3, 3] = preds[0]["cam_trans"].clone()
2084
+ pred_in_camera0 = closed_form_pose_inverse(pred_camera0)
2085
  no_norm_pr_pts = []
2086
  no_norm_pr_pts_cam = []
2087
  no_norm_pr_depth = []
 
2135
  gt_pose_quats.append(gt_pose_quats_in_view0)
2136
  no_norm_gt_pose_trans.append(no_norm_gt_pose_trans_in_view0)
2137
 
2138
+ # Get the local predictions
 
2139
  no_norm_pr_pts_cam.append(preds[i]["pts3d_cam"])
2140
  pr_ray_directions.append(preds[i]["ray_directions"])
2141
  if self.depth_type_for_loss == "depth_along_ray":
2142
  no_norm_pr_depth.append(preds[i]["depth_along_ray"])
2143
  elif self.depth_type_for_loss == "depth_z":
2144
  no_norm_pr_depth.append(preds[i]["pts3d_cam"][..., 2:])
2145
+
2146
+ # Get the predicted global predictions in view0's frame
2147
+ if self.convert_predictions_to_view0_frame:
2148
+ # Convert predictions to view0 frame
2149
+ pr_pts3d_in_view0 = geotrf(pred_in_camera0, preds[i]["pts3d"])
2150
+ pr_pose_quats_in_view0, pr_pose_trans_in_view0 = (
2151
+ transform_pose_using_quats_and_trans_2_to_1(
2152
+ preds[0]["cam_quats"],
2153
+ preds[0]["cam_trans"],
2154
+ preds[i]["cam_quats"],
2155
+ preds[i]["cam_trans"],
2156
+ )
2157
+ )
2158
+ no_norm_pr_pts.append(pr_pts3d_in_view0)
2159
+ no_norm_pr_pose_trans.append(pr_pose_trans_in_view0)
2160
+ pr_pose_quats.append(pr_pose_quats_in_view0)
2161
+ else:
2162
+ # Predictions are already in view0 frame
2163
+ no_norm_pr_pts.append(preds[i]["pts3d"])
2164
+ no_norm_pr_pose_trans.append(preds[i]["cam_trans"])
2165
+ pr_pose_quats.append(preds[i]["cam_quats"])
2166
 
2167
  if dist_clip is not None:
2168
  # Points that are too far-away == invalid
 
2674
  pose_quats_loss_weight=1,
2675
  pose_trans_loss_weight=1,
2676
  compute_pairwise_relative_pose_loss=False,
2677
+ convert_predictions_to_view0_frame=False,
2678
  compute_world_frame_points_loss=True,
2679
  world_frame_points_loss_weight=1,
2680
  apply_normal_and_gm_loss_to_synthetic_data_only=True,
 
2710
  pose_trans_loss_weight (float): Weight to use for the pose trans loss. Default: 1.
2711
  compute_pairwise_relative_pose_loss (bool): If True, the pose loss is computed on the
2712
  exhaustive pairwise relative poses. Default: False.
2713
+ convert_predictions_to_view0_frame (bool): If True, convert predictions to view0 frame.
2714
+ Use this if the predictions are not already in the view0 frame. Default: False.
2715
  compute_world_frame_points_loss (bool): If True, compute the world frame pointmap loss. Default: True.
2716
  world_frame_points_loss_weight (float): Weight to use for the world frame pointmap loss. Default: 1.
2717
  apply_normal_and_gm_loss_to_synthetic_data_only (bool): If True, apply the normal and gm loss only to synthetic data.
 
2734
  pose_quats_loss_weight=pose_quats_loss_weight,
2735
  pose_trans_loss_weight=pose_trans_loss_weight,
2736
  compute_pairwise_relative_pose_loss=compute_pairwise_relative_pose_loss,
2737
+ convert_predictions_to_view0_frame=convert_predictions_to_view0_frame,
2738
  compute_world_frame_points_loss=compute_world_frame_points_loss,
2739
  world_frame_points_loss_weight=world_frame_points_loss_weight,
2740
  )
 
3130
  pose_trans_loss_weight=1,
3131
  scale_loss_weight=1,
3132
  compute_pairwise_relative_pose_loss=False,
3133
+ convert_predictions_to_view0_frame=False,
3134
  compute_world_frame_points_loss=True,
3135
  world_frame_points_loss_weight=1,
3136
  ):
 
3164
  scale_loss_weight (float): Weight to use for the scale loss. Default: 1.
3165
  compute_pairwise_relative_pose_loss (bool): If True, the pose loss is computed on the
3166
  exhaustive pairwise relative poses. Default: False.
3167
+ convert_predictions_to_view0_frame (bool): If True, convert predictions to view0 frame.
3168
+ Use this if the predictions are not already in the view0 frame. Default: False.
3169
  compute_world_frame_points_loss (bool): If True, compute the world frame pointmap loss. Default: True.
3170
  world_frame_points_loss_weight (float): Weight to use for the world frame pointmap loss. Default: 1.
3171
  """
 
3186
  self.pose_trans_loss_weight = pose_trans_loss_weight
3187
  self.scale_loss_weight = scale_loss_weight
3188
  self.compute_pairwise_relative_pose_loss = compute_pairwise_relative_pose_loss
3189
+ self.convert_predictions_to_view0_frame = convert_predictions_to_view0_frame
3190
  self.compute_world_frame_points_loss = compute_world_frame_points_loss
3191
  self.world_frame_points_loss_weight = world_frame_points_loss_weight
3192
 
 
3209
  gt_ray_directions = []
3210
  gt_pose_quats = []
3211
  # Predicted quantities
3212
+ if self.convert_predictions_to_view0_frame:
3213
+ # Get the camera transform to convert quantities to view0 frame
3214
+ pred_camera0 = torch.eye(4, device=preds[0]["cam_quats"].device).unsqueeze(
3215
+ 0
3216
+ )
3217
+ batch_size = preds[0]["cam_quats"].shape[0]
3218
+ pred_camera0 = pred_camera0.repeat(batch_size, 1, 1)
3219
+ pred_camera0_rot = quaternion_to_rotation_matrix(
3220
+ preds[0]["cam_quats"].clone()
3221
+ )
3222
+ pred_camera0[..., :3, :3] = pred_camera0_rot
3223
+ pred_camera0[..., :3, 3] = preds[0]["cam_trans"].clone()
3224
+ pred_in_camera0 = closed_form_pose_inverse(pred_camera0)
3225
  no_norm_pr_pts = []
3226
  no_norm_pr_pts_cam = []
3227
  no_norm_pr_depth = []
 
3276
  gt_pose_quats.append(gt_pose_quats_in_view0)
3277
  no_norm_gt_pose_trans.append(no_norm_gt_pose_trans_in_view0)
3278
 
3279
+ # Get the global predictions in view0's frame
3280
+ if self.convert_predictions_to_view0_frame:
3281
+ # Convert predictions to view0 frame
3282
+ pr_pts3d_in_view0 = geotrf(pred_in_camera0, preds[i]["pts3d"])
3283
+ pr_pose_quats_in_view0, pr_pose_trans_in_view0 = (
3284
+ transform_pose_using_quats_and_trans_2_to_1(
3285
+ preds[0]["cam_quats"],
3286
+ preds[0]["cam_trans"],
3287
+ preds[i]["cam_quats"],
3288
+ preds[i]["cam_trans"],
3289
+ )
3290
+ )
3291
+ else:
3292
+ # Predictions are already in view0 frame
3293
+ pr_pts3d_in_view0 = preds[i]["pts3d"]
3294
+ pr_pose_trans_in_view0 = preds[i]["cam_trans"]
3295
+ pr_pose_quats_in_view0 = preds[i]["cam_quats"]
3296
+
3297
  # Get predictions for normalized loss
3298
  if self.depth_type_for_loss == "depth_along_ray":
3299
  curr_view_no_norm_depth = preds[i]["depth_along_ray"]
 
3302
  if "metric_scaling_factor" in preds[i].keys():
3303
  # Divide by the predicted metric scaling factor to get the raw predicted points, depth_along_ray, and pose_trans
3304
  # This detaches the predicted metric scaling factor from the geometry based loss
3305
+ curr_view_no_norm_pr_pts = pr_pts3d_in_view0 / preds[i][
3306
  "metric_scaling_factor"
3307
  ].unsqueeze(-1).unsqueeze(-1)
3308
  curr_view_no_norm_pr_pts_cam = preds[i]["pts3d_cam"] / preds[i][
 
3312
  "metric_scaling_factor"
3313
  ].unsqueeze(-1).unsqueeze(-1)
3314
  curr_view_no_norm_pr_pose_trans = (
3315
+ pr_pose_trans_in_view0 / preds[i]["metric_scaling_factor"]
3316
  )
3317
  else:
3318
+ curr_view_no_norm_pr_pts = pr_pts3d_in_view0
3319
  curr_view_no_norm_pr_pts_cam = preds[i]["pts3d_cam"]
3320
  curr_view_no_norm_depth = curr_view_no_norm_depth
3321
+ curr_view_no_norm_pr_pose_trans = pr_pose_trans_in_view0
3322
  no_norm_pr_pts.append(curr_view_no_norm_pr_pts)
3323
  no_norm_pr_pts_cam.append(curr_view_no_norm_pr_pts_cam)
3324
  no_norm_pr_depth.append(curr_view_no_norm_depth)
3325
  no_norm_pr_pose_trans.append(curr_view_no_norm_pr_pose_trans)
3326
  pr_ray_directions.append(preds[i]["ray_directions"])
3327
+ pr_pose_quats.append(pr_pose_quats_in_view0)
3328
 
3329
  # Get the predicted metric scale points
3330
  if "metric_scaling_factor" in preds[i].keys():
 
3823
  pose_trans_loss_weight=1,
3824
  scale_loss_weight=1,
3825
  compute_pairwise_relative_pose_loss=False,
3826
+ convert_predictions_to_view0_frame=False,
3827
  compute_world_frame_points_loss=True,
3828
  world_frame_points_loss_weight=1,
3829
  apply_normal_and_gm_loss_to_synthetic_data_only=True,
 
3856
  scale_loss_weight (float): Weight to use for the scale loss. Default: 1.
3857
  compute_pairwise_relative_pose_loss (bool): If True, the pose loss is computed on the
3858
  exhaustive pairwise relative poses. Default: False.
3859
+ convert_predictions_to_view0_frame (bool): If True, convert predictions to view0 frame.
3860
+ Use this if the predictions are not already in the view0 frame. Default: False.
3861
  compute_world_frame_points_loss (bool): If True, compute the world frame pointmap loss. Default: True.
3862
  world_frame_points_loss_weight (float): Weight to use for the world frame pointmap loss. Default: 1.
3863
  apply_normal_and_gm_loss_to_synthetic_data_only (bool): If True, apply the normal and gm loss only to synthetic data.
 
3880
  pose_trans_loss_weight=pose_trans_loss_weight,
3881
  scale_loss_weight=scale_loss_weight,
3882
  compute_pairwise_relative_pose_loss=compute_pairwise_relative_pose_loss,
3883
+ convert_predictions_to_view0_frame=convert_predictions_to_view0_frame,
3884
  compute_world_frame_points_loss=compute_world_frame_points_loss,
3885
  world_frame_points_loss_weight=world_frame_points_loss_weight,
3886
  )
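For reference, a minimal standalone sketch of the view0-frame conversion that the new convert_predictions_to_view0_frame flag performs, mirroring the loss code above. The helper signatures are taken from how they are called in this diff, and the import location of closed_form_pose_inverse / geotrf is assumed to be mapanything.utils.geometry (quaternion_to_rotation_matrix is defined there):

import torch

from mapanything.utils.geometry import (  # assumed import location for the pose helpers
    closed_form_pose_inverse,
    geotrf,
    quaternion_to_rotation_matrix,
)


def world_points_in_view0_frame(pred_view0, pred_view_i):
    # Build the 4x4 world-from-camera0 pose from view 0's predicted quaternion and translation
    batch_size = pred_view0["cam_quats"].shape[0]
    camera0 = torch.eye(4, device=pred_view0["cam_quats"].device).unsqueeze(0)
    camera0 = camera0.repeat(batch_size, 1, 1)
    camera0[..., :3, :3] = quaternion_to_rotation_matrix(pred_view0["cam_quats"])
    camera0[..., :3, 3] = pred_view0["cam_trans"]
    # Invert it and re-express view i's predicted world-frame points in view 0's frame
    world_in_camera0 = closed_form_pose_inverse(camera0)
    return geotrf(world_in_camera0, pred_view_i["pts3d"])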
mapanything/utils/geometry.py CHANGED
@@ -10,6 +10,7 @@ from typing import Tuple, Union
10
  import einops as ein
11
  import numpy as np
12
  import torch
 
13
 
14
  from mapanything.utils.misc import invalid_to_zeros
15
  from mapanything.utils.warnings import no_warnings
@@ -646,6 +647,96 @@ def quaternion_to_rotation_matrix(quat):
646
  return rot_matrix
647
 
648
 
 
649
  def quaternion_inverse(quat):
650
  """
651
  Compute the inverse of a quaternion.
 
10
  import einops as ein
11
  import numpy as np
12
  import torch
13
+ import torch.nn.functional as F
14
 
15
  from mapanything.utils.misc import invalid_to_zeros
16
  from mapanything.utils.warnings import no_warnings
 
647
  return rot_matrix
648
 
649
 
650
+ def rotation_matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor:
651
+ """
652
+ Convert rotations given as rotation matrices to quaternions.
653
+
654
+ Args:
655
+ matrix: Rotation matrices as tensor of shape (..., 3, 3).
656
+
657
+ Returns:
658
+ quaternions with real part last, as tensor of shape (..., 4).
659
+ Quaternion Order: XYZW or say ijkr, scalar-last
660
+ """
661
+ if matrix.size(-1) != 3 or matrix.size(-2) != 3:
662
+ raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")
663
+
664
+ batch_dim = matrix.shape[:-2]
665
+ m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(
666
+ matrix.reshape(batch_dim + (9,)), dim=-1
667
+ )
668
+
669
+ q_abs = _sqrt_positive_part(
670
+ torch.stack(
671
+ [
672
+ 1.0 + m00 + m11 + m22,
673
+ 1.0 + m00 - m11 - m22,
674
+ 1.0 - m00 + m11 - m22,
675
+ 1.0 - m00 - m11 + m22,
676
+ ],
677
+ dim=-1,
678
+ )
679
+ )
680
+
681
+ # we produce the desired quaternion multiplied by each of r, i, j, k
682
+ quat_by_rijk = torch.stack(
683
+ [
684
+ torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1),
685
+ torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1),
686
+ torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1),
687
+ torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1),
688
+ ],
689
+ dim=-2,
690
+ )
691
+
692
+ # We floor here at 0.1 but the exact level is not important; if q_abs is small,
693
+ # the candidate won't be picked.
694
+ flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device)
695
+ quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr))
696
+
697
+ # if not for numerical problems, quat_candidates[i] should be same (up to a sign),
698
+ # forall i; we pick the best-conditioned one (with the largest denominator)
699
+ out = quat_candidates[
700
+ F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, :
701
+ ].reshape(batch_dim + (4,))
702
+
703
+ # Convert from rijk to ijkr
704
+ out = out[..., [1, 2, 3, 0]]
705
+
706
+ out = standardize_quaternion(out)
707
+
708
+ return out
709
+
710
+
711
+ def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor:
712
+ """
713
+ Returns torch.sqrt(torch.max(0, x))
714
+ but with a zero subgradient where x is 0.
715
+ """
716
+ ret = torch.zeros_like(x)
717
+ positive_mask = x > 0
718
+ if torch.is_grad_enabled():
719
+ ret[positive_mask] = torch.sqrt(x[positive_mask])
720
+ else:
721
+ ret = torch.where(positive_mask, torch.sqrt(x), ret)
722
+ return ret
723
+
724
+
725
+ def standardize_quaternion(quaternions: torch.Tensor) -> torch.Tensor:
726
+ """
727
+ Convert a unit quaternion to a standard form: one in which the real
728
+ part is non-negative.
729
+
730
+ Args:
731
+ quaternions: Quaternions with real part last,
732
+ as tensor of shape (..., 4).
733
+
734
+ Returns:
735
+ Standardized quaternions as tensor of shape (..., 4).
736
+ """
737
+ return torch.where(quaternions[..., 3:4] < 0, -quaternions, quaternions)
738
+
739
+
740
  def quaternion_inverse(quat):
741
  """
742
  Compute the inverse of a quaternion.
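A quick round-trip check for the new rotation_matrix_to_quaternion helper. This is a sketch; it assumes quaternion_to_rotation_matrix uses the same XYZW, scalar-last convention, which is what the conversion above targets:

import torch
import torch.nn.functional as F

from mapanything.utils.geometry import (
    quaternion_to_rotation_matrix,
    rotation_matrix_to_quaternion,
    standardize_quaternion,
)

# Random unit quaternions in XYZW (scalar-last) order, standardized to a non-negative real part
quats = standardize_quaternion(F.normalize(torch.randn(8, 4), dim=-1))

# quat -> rotation matrix -> quat should recover the input
# (the sign ambiguity is removed by the standardization on both sides)
recovered = rotation_matrix_to_quaternion(quaternion_to_rotation_matrix(quats))
assert torch.allclose(recovered, quats, atol=1e-5)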
mapanything/utils/image.py CHANGED
@@ -287,6 +287,17 @@ def load_images(
287
  f"Using target resolution {target_size[0]}x{target_size[1]} (W x H) for all images"
288
  )
289
 
 
 
290
  # Second pass: Resize all images to the same target size
291
  imgs = []
292
  for path, img, W1, H1 in loaded_images:
@@ -298,16 +309,6 @@ def load_images(
298
  if verbose:
299
  print(f" - Adding {path} with resolution {W1}x{H1} --> {W2}x{H2}")
300
 
301
- if norm_type in IMAGE_NORMALIZATION_DICT.keys():
302
- img_norm = IMAGE_NORMALIZATION_DICT[norm_type]
303
- ImgNorm = tvf.Compose(
304
- [tvf.ToTensor(), tvf.Normalize(mean=img_norm.mean, std=img_norm.std)]
305
- )
306
- else:
307
- raise ValueError(
308
- f"Unknown image normalization type: {norm_type}. Available options: {list(IMAGE_NORMALIZATION_DICT.keys())}"
309
- )
310
-
311
  imgs.append(
312
  dict(
313
  img=ImgNorm(img)[None],
 
287
  f"Using target resolution {target_size[0]}x{target_size[1]} (W x H) for all images"
288
  )
289
 
290
+ # Get the image normalization function based on the norm_type
291
+ if norm_type in IMAGE_NORMALIZATION_DICT.keys():
292
+ img_norm = IMAGE_NORMALIZATION_DICT[norm_type]
293
+ ImgNorm = tvf.Compose(
294
+ [tvf.ToTensor(), tvf.Normalize(mean=img_norm.mean, std=img_norm.std)]
295
+ )
296
+ else:
297
+ raise ValueError(
298
+ f"Unknown image normalization type: {norm_type}. Available options: {list(IMAGE_NORMALIZATION_DICT.keys())}"
299
+ )
300
+
301
  # Second pass: Resize all images to the same target size
302
  imgs = []
303
  for path, img, W1, H1 in loaded_images:
 
309
  if verbose:
310
  print(f" - Adding {path} with resolution {W1}x{H1} --> {W2}x{H2}")
311
 
 
 
312
  imgs.append(
313
  dict(
314
  img=ImgNorm(img)[None],
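The change above simply hoists the normalization transform out of the per-image resize loop so it is built once. The same pattern in isolation (the mean/std values below are the standard ImageNet statistics, used purely for illustration; load_images takes the actual values from IMAGE_NORMALIZATION_DICT[norm_type]):

import torchvision.transforms as tvf
from PIL import Image

# Build the normalization transform once, outside the loop
img_norm = tvf.Compose(
    [
        tvf.ToTensor(),
        tvf.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ]
)

# Reuse it for every image instead of re-creating it per iteration
paths = ["view0.png", "view1.png"]  # hypothetical image paths
tensors = [img_norm(Image.open(p).convert("RGB"))[None] for p in paths]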
mapanything/utils/inference.py CHANGED
@@ -3,9 +3,43 @@ Inference utilities.
3
  """
4
 
5
  import warnings
 
6
 
 
7
  import torch
8
 
 
 
9
 
10
  def loss_of_one_batch_multi_view(
11
  batch,
@@ -84,3 +118,358 @@ def loss_of_one_batch_multi_view(
84
  result["loss"] = loss
85
 
86
  return result[ret] if ret else result
 
 
3
  """
4
 
5
  import warnings
6
+ from typing import Any, Dict, List
7
 
8
+ import numpy as np
9
  import torch
10
 
11
+ from mapanything.utils.geometry import (
12
+ depth_edge,
13
+ get_rays_in_camera_frame,
14
+ normals_edge,
15
+ points_to_normals,
16
+ quaternion_to_rotation_matrix,
17
+ recover_pinhole_intrinsics_from_ray_directions,
18
+ rotation_matrix_to_quaternion,
19
+ )
20
+ from mapanything.utils.image import rgb
21
+
22
+ # Hard constraints - exactly what users can provide
23
+ ALLOWED_VIEW_KEYS = {
24
+ "img", # Required - input images
25
+ "data_norm_type", # Required - normalization type of the input images
26
+ "depth_z", # Optional - Z depth maps
27
+ "ray_directions", # Optional - ray directions in camera frame
28
+ "intrinsics", # Optional - pinhole camera intrinsics (conflicts with ray_directions)
29
+ "camera_poses", # Optional - camera poses
30
+ "is_metric_scale", # Optional - whether inputs are metric scale
31
+ "true_shape", # Optional - original image shape
32
+ "idx", # Optional - index of the view
33
+ "instance", # Optional - instance info of the view
34
+ }
35
+
36
+ REQUIRED_KEYS = {"img", "data_norm_type"}
37
+
38
+ # Define conflicting keys that cannot be used together
39
+ CONFLICTING_KEYS = [
40
+ ("intrinsics", "ray_directions") # Both represent camera projection
41
+ ]
42
+
43
 
44
  def loss_of_one_batch_multi_view(
45
  batch,
 
118
  result["loss"] = loss
119
 
120
  return result[ret] if ret else result
121
+
122
+
123
+ def validate_input_views_for_inference(
124
+ views: List[Dict[str, Any]],
125
+ ) -> List[Dict[str, Any]]:
126
+ """
127
+ Strict validation of input views prior to inference.
128
+
129
+ Args:
130
+ views: List of view dictionaries
131
+
132
+ Returns:
133
+ The validated views (returned unchanged when all checks pass)
134
+
135
+ Raises:
136
+ ValueError: For invalid keys, missing required keys, conflicting inputs, or invalid camera pose constraints
137
+ """
138
+ # Ensure input is not empty
139
+ if not views:
140
+ raise ValueError("At least one view must be provided")
141
+
142
+ # Track which views have camera poses
143
+ views_with_poses = []
144
+
145
+ # Validate each view
146
+ for view_idx, view in enumerate(views):
147
+ # Check for invalid keys
148
+ provided_keys = set(view.keys())
149
+ invalid_keys = provided_keys - ALLOWED_VIEW_KEYS
150
+ if invalid_keys:
151
+ raise ValueError(
152
+ f"View {view_idx} contains invalid keys: {invalid_keys}. "
153
+ f"Allowed keys are: {sorted(ALLOWED_VIEW_KEYS)}"
154
+ )
155
+
156
+ # Check for missing required keys
157
+ missing_keys = REQUIRED_KEYS - provided_keys
158
+ if missing_keys:
159
+ raise ValueError(f"View {view_idx} missing required keys: {missing_keys}")
160
+
161
+ # Check for conflicting keys
162
+ for conflict_set in CONFLICTING_KEYS:
163
+ present_conflicts = [key for key in conflict_set if key in provided_keys]
164
+ if len(present_conflicts) > 1:
165
+ raise ValueError(
166
+ f"View {view_idx} contains conflicting keys: {present_conflicts}. "
167
+ f"Only one of {conflict_set} can be provided at a time."
168
+ )
169
+
170
+ # Check depth constraint: If depth is provided, intrinsics or ray_directions must also be provided
171
+ if "depth_z" in provided_keys:
172
+ if (
173
+ "intrinsics" not in provided_keys
174
+ and "ray_directions" not in provided_keys
175
+ ):
176
+ raise ValueError(
177
+ f"View {view_idx} depth constraint violation: If 'depth_z' is provided, "
178
+ f"then 'intrinsics' or 'ray_directions' must also be provided. "
179
+ f"Z Depth values require camera calibration information to be meaningful for an image."
180
+ )
181
+
182
+ # Track views with camera poses
183
+ if "camera_poses" in provided_keys:
184
+ views_with_poses.append(view_idx)
185
+
186
+ # Cross-view constraint: If any view has camera_poses, view 0 must have them too
187
+ if views_with_poses and 0 not in views_with_poses:
188
+ raise ValueError(
189
+ f"Camera pose constraint violation: Views {views_with_poses} have camera_poses, "
190
+ f"but view 0 (reference view) does not. When using camera_poses, the first view "
191
+ f"must also provide camera_poses to serve as the reference frame."
192
+ )
193
+
194
+ return views
195
+
196
+
197
+ def preprocess_input_views_for_inference(
198
+ views: List[Dict[str, Any]],
199
+ ) -> List[Dict[str, Any]]:
200
+ """
201
+ Pre-process input views to match the expected internal input format.
202
+
203
+ The following steps are performed:
204
+ 1. Convert intrinsics to ray directions when required. If ray directions are already provided, unit normalize them.
205
+ 2. Convert depth_z to depth_along_ray
206
+ 3. Convert camera_poses to the expected input keys (camera_pose_quats and camera_pose_trans)
207
+ 4. Default is_metric_scale to True when not provided
208
+
209
+ Args:
210
+ views: List of view dictionaries
211
+
212
+ Returns:
213
+ Preprocessed views with consistent internal format
214
+ """
215
+ processed_views = []
216
+
217
+ for view_idx, view in enumerate(views):
218
+ # Copy the view dictionary to avoid modifying the original input
219
+ processed_view = dict(view)
220
+
221
+ # Step 1: Convert intrinsics to ray_directions when required. If ray_directions are provided, unit normalize them.
222
+ if "intrinsics" in view:
223
+ images = view["img"]
224
+ height, width = images.shape[-2:]
225
+ intrinsics = view["intrinsics"]
226
+ _, ray_directions = get_rays_in_camera_frame(
227
+ intrinsics=intrinsics,
228
+ height=height,
229
+ width=width,
230
+ normalize_to_unit_sphere=True,
231
+ )
232
+ processed_view["ray_directions"] = ray_directions
233
+ del processed_view["intrinsics"]
234
+ elif "ray_directions" in view:
235
+ ray_directions = view["ray_directions"]
236
+ ray_norm = torch.norm(ray_directions, dim=-1, keepdim=True)
237
+ processed_view["ray_directions"] = ray_directions / (ray_norm + 1e-8)
238
+
239
+ # Step 2: Convert depth_z to depth_along_ray
240
+ if "depth_z" in view:
241
+ depth_z = view["depth_z"]
242
+ ray_directions = processed_view["ray_directions"]
243
+ ray_directions_unit_plane = ray_directions / ray_directions[..., 2:3]
244
+ pts3d_cam = depth_z * ray_directions_unit_plane
245
+ depth_along_ray = torch.norm(pts3d_cam, dim=-1, keepdim=True)
246
+ processed_view["depth_along_ray"] = depth_along_ray
247
+ del processed_view["depth_z"]
248
+
249
+ # Step 3: Convert camera_poses to expected input keys
250
+ if "camera_poses" in view:
251
+ camera_poses = view["camera_poses"]
252
+ if isinstance(camera_poses, tuple) and len(camera_poses) == 2:
253
+ quats, trans = camera_poses
254
+ processed_view["camera_pose_quats"] = quats
255
+ processed_view["camera_pose_trans"] = trans
256
+ elif torch.is_tensor(camera_poses) and camera_poses.shape[-2:] == (4, 4):
257
+ rotation_matrices = camera_poses[:, :3, :3]
258
+ translation_vectors = camera_poses[:, :3, 3]
259
+ quats = rotation_matrix_to_quaternion(rotation_matrices)
260
+ processed_view["camera_pose_quats"] = quats
261
+ processed_view["camera_pose_trans"] = translation_vectors
262
+ else:
263
+ raise ValueError(
264
+ f"View {view_idx}: camera_poses must be either a tuple of (quats, trans) "
265
+ f"or a tensor of (B, 4, 4) transformation matrices."
266
+ )
267
+ del processed_view["camera_poses"]
268
+
269
+ # Step 4: Default is_metric_scale to True when not provided
270
+ if "is_metric_scale" not in processed_view:
271
+ # Get batch size from the image tensor
272
+ batch_size = view["img"].shape[0]
273
+ # Default to True for all samples in the batch
274
+ processed_view["is_metric_scale"] = torch.ones(
275
+ batch_size, dtype=torch.bool, device=view["img"].device
276
+ )
277
+
278
+ # Rename keys to match expected model input format
279
+ if "ray_directions" in processed_view:
280
+ processed_view["ray_directions_cam"] = processed_view["ray_directions"]
281
+ del processed_view["ray_directions"]
282
+
283
+ # Append the processed view to the list
284
+ processed_views.append(processed_view)
285
+
286
+ return processed_views
287
+
288
+
289
+ def postprocess_model_outputs_for_inference(
290
+ raw_outputs: List[Dict[str, torch.Tensor]],
291
+ input_views: List[Dict[str, Any]],
292
+ apply_mask: bool = True,
293
+ mask_edges: bool = True,
294
+ edge_normal_threshold: float = 5.0,
295
+ edge_depth_threshold: float = 0.03,
296
+ apply_confidence_mask: bool = False,
297
+ confidence_percentile: float = 10,
298
+ ) -> List[Dict[str, torch.Tensor]]:
299
+ """
300
+ Post-process raw model outputs by copying raw outputs and adding essential derived fields.
301
+
302
+ This function simplifies the raw model outputs by:
303
+ 1. Copying all raw outputs as-is
304
+ 2. Adding denormalized images (img_no_norm)
305
+ 3. Adding Z depth (depth_z) from camera frame points
306
+ 4. Recovering pinhole camera intrinsics from ray directions
307
+ 5. Adding camera pose matrices (camera_poses) if pose data is available
308
+ 6. Applying mask to dense geometry outputs if requested (supports edge masking and confidence masking)
309
+
310
+ Args:
311
+ raw_outputs: List of raw model output dictionaries, one per view
312
+ input_views: List of original input view dictionaries, one per view
313
+ apply_mask: Whether to apply non-ambiguous mask to dense outputs. Defaults to True.
314
+ mask_edges: Whether to compute an edge mask based on normals and depth and apply it to the output. Defaults to True.
315
+ apply_confidence_mask: Whether to apply the confidence mask to the output. Defaults to False.
316
+ confidence_percentile: The percentile to use for the confidence threshold. Defaults to 10.
317
+
318
+ Returns:
319
+ List of processed output dictionaries containing:
320
+ - All original raw outputs (after masking dense geometry outputs if requested)
321
+ - 'img_no_norm': Denormalized RGB images (B, H, W, 3)
322
+ - 'depth_z': Z depth from camera frame (B, H, W, 1) if points in camera frame available
323
+ - 'intrinsics': Recovered pinhole camera intrinsics (B, 3, 3) if ray directions available
324
+ - 'camera_poses': 4x4 pose matrices (B, 4, 4) if pose data available
325
+ - 'mask': comprehensive mask for dense geometry outputs (B, H, W, 1) if requested
326
+
327
+ """
328
+ processed_outputs = []
329
+
330
+ for view_idx, (raw_output, original_view) in enumerate(
331
+ zip(raw_outputs, input_views)
332
+ ):
333
+ # Start by copying all raw outputs
334
+ processed_output = dict(raw_output)
335
+
336
+ # 1. Add denormalized images
337
+ img = original_view["img"] # Shape: (B, 3, H, W)
338
+ data_norm_type = original_view["data_norm_type"][0]
339
+ img_hwc = rgb(img, data_norm_type)
340
+
341
+ # Convert numpy back to torch if needed (rgb returns numpy)
342
+ if isinstance(img_hwc, np.ndarray):
343
+ img_hwc = torch.from_numpy(img_hwc).to(img.device)
344
+
345
+ processed_output["img_no_norm"] = img_hwc
346
+
347
+ # 2. Add Z depth if we have camera frame points
348
+ if "pts3d_cam" in processed_output:
349
+ processed_output["depth_z"] = processed_output["pts3d_cam"][..., 2:3]
350
+
351
+ # 3. Recover pinhole camera intrinsics from ray directions if available
352
+ if "ray_directions" in processed_output:
353
+ intrinsics = recover_pinhole_intrinsics_from_ray_directions(
354
+ processed_output["ray_directions"]
355
+ )
356
+ processed_output["intrinsics"] = intrinsics
357
+
358
+ # 4. Add camera pose matrices if both translation and quaternions are available
359
+ if "cam_trans" in processed_output and "cam_quats" in processed_output:
360
+ cam_trans = processed_output["cam_trans"] # (B, 3)
361
+ cam_quats = processed_output["cam_quats"] # (B, 4)
362
+ batch_size = cam_trans.shape[0]
363
+
364
+ # Convert quaternions to rotation matrices
365
+ rotation_matrices = quaternion_to_rotation_matrix(cam_quats) # (B, 3, 3)
366
+
367
+ # Create 4x4 pose matrices
368
+ pose_matrices = (
369
+ torch.eye(4, device=img.device).unsqueeze(0).repeat(batch_size, 1, 1)
370
+ )
371
+ pose_matrices[:, :3, :3] = rotation_matrices
372
+ pose_matrices[:, :3, 3] = cam_trans
373
+
374
+ processed_output["camera_poses"] = pose_matrices # (B, 4, 4)
375
+
376
+ # 5. Apply comprehensive mask to dense geometry outputs if requested
377
+ if apply_mask:
378
+ final_mask = None
379
+
380
+ # Start with non-ambiguous mask if available
381
+ if "non_ambiguous_mask" in processed_output:
382
+ non_ambiguous_mask = (
383
+ processed_output["non_ambiguous_mask"].cpu().numpy()
384
+ ) # (B, H, W)
385
+ final_mask = non_ambiguous_mask
386
+
387
+ # Apply confidence mask if requested and available
388
+ if apply_confidence_mask and "conf" in processed_output:
389
+ confidences = processed_output["conf"].cpu() # (B, H, W)
390
+ # Compute percentile threshold for each batch element
391
+ batch_size = confidences.shape[0]
392
+ conf_mask = torch.zeros_like(confidences, dtype=torch.bool)
393
+ percentile_threshold = (
394
+ torch.quantile(
395
+ confidences.reshape(batch_size, -1),
396
+ confidence_percentile / 100.0,
397
+ dim=1,
398
+ )
399
+ .unsqueeze(-1)
400
+ .unsqueeze(-1)
401
+ ) # Shape: (B, 1, 1)
402
+
403
+ # Compute mask for each batch element
404
+ conf_mask = confidences > percentile_threshold
405
+ conf_mask = conf_mask.numpy()
406
+
407
+ if final_mask is not None:
408
+ final_mask = final_mask & conf_mask
409
+ else:
410
+ final_mask = conf_mask
411
+
412
+ # Apply edge mask if requested and we have the required data
413
+ if mask_edges and final_mask is not None and "pts3d" in processed_output:
414
+ # Get 3D points for edge computation
415
+ pred_pts3d = processed_output["pts3d"].cpu().numpy() # (B, H, W, 3)
416
+ batch_size, height, width = final_mask.shape
417
+
418
+ edge_masks = []
419
+ for b in range(batch_size):
420
+ batch_final_mask = final_mask[b] # (H, W)
421
+ batch_pts3d = pred_pts3d[b] # (H, W, 3)
422
+
423
+ if batch_final_mask.any(): # Only compute if we have valid points
424
+ # Compute normals and normal-based edge mask
425
+ normals, normals_mask = points_to_normals(
426
+ batch_pts3d, mask=batch_final_mask
427
+ )
428
+ normal_edges = normals_edge(
429
+ normals, tol=edge_normal_threshold, mask=normals_mask
430
+ )
431
+
432
+ # Compute depth-based edge mask
433
+ depth_z = (
434
+ processed_output["depth_z"][b].squeeze(-1).cpu().numpy()
435
+ )
436
+ depth_edges = depth_edge(
437
+ depth_z, rtol=edge_depth_threshold, mask=batch_final_mask
438
+ )
439
+
440
+ # Combine both edge types
441
+ edge_mask = ~(depth_edges & normal_edges)
442
+ edge_masks.append(edge_mask)
443
+ else:
444
+ # No valid points, keep all as invalid
445
+ edge_masks.append(np.zeros_like(batch_final_mask, dtype=bool))
446
+
447
+ # Stack batch edge masks and combine with final mask
448
+ edge_mask = np.stack(edge_masks, axis=0) # (B, H, W)
449
+ final_mask = final_mask & edge_mask
450
+
451
+ # Apply final mask to dense geometry outputs if we have a mask
452
+ if final_mask is not None:
453
+ # Convert mask to torch tensor
454
+ final_mask_torch = torch.from_numpy(final_mask).to(
455
+ processed_output["pts3d"].device
456
+ )
457
+ final_mask_torch = final_mask_torch.unsqueeze(-1) # (B, H, W, 1)
458
+
459
+ # Apply mask to dense geometry outputs (zero out invalid regions)
460
+ dense_geometry_keys = [
461
+ "pts3d",
462
+ "pts3d_cam",
463
+ "depth_along_ray",
464
+ "depth_z",
465
+ ]
466
+ for key in dense_geometry_keys:
467
+ if key in processed_output:
468
+ processed_output[key] = processed_output[key] * final_mask_torch
469
+
470
+ # Add mask to processed output
471
+ processed_output["mask"] = final_mask_torch
472
+
473
+ processed_outputs.append(processed_output)
474
+
475
+ return processed_outputs
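Taken together, the new helpers support a simple validate -> preprocess -> model -> postprocess flow. A usage sketch follows; only the three helpers are taken from this file, while the model call, its output format, and the assumption that load_images returns view dicts carrying "img" and "data_norm_type" are not specified here and may differ:

import torch

from mapanything.utils.image import load_images
from mapanything.utils.inference import (
    postprocess_model_outputs_for_inference,
    preprocess_input_views_for_inference,
    validate_input_views_for_inference,
)

# One dict per image; "img" and "data_norm_type" are required,
# calibration/pose keys are optional (see ALLOWED_VIEW_KEYS above)
views = load_images(["view0.png", "view1.png"])  # hypothetical paths
views = validate_input_views_for_inference(views)
views = preprocess_input_views_for_inference(views)

model = ...  # placeholder: an initialized MapAnything model (call signature assumed)
with torch.no_grad():
    raw_outputs = model(views)  # assumed: one raw output dict per view

outputs = postprocess_model_outputs_for_inference(
    raw_outputs,
    views,
    apply_mask=True,
    mask_edges=True,
)
pts3d = outputs[0]["pts3d"]  # masked world-frame pointmap of the first view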
mapanything/utils/viz.py CHANGED
@@ -110,7 +110,7 @@ def script_add_rerun_args(parser: ArgumentParser) -> None:
110
  parser.add_argument(
111
  "--url",
112
  type=str,
113
- default="rerun+http://127.0.0.1:9081/proxy",
114
  help="Connect to this HTTP(S) URL",
115
  )
116
  parser.add_argument(
@@ -129,7 +129,7 @@ def init_rerun_args(
129
  headless=True,
130
  connect=True,
131
  serve=False,
132
- url="rerun+http://127.0.0.1:9081/proxy",
133
  save=None,
134
  stdout=False,
135
  ) -> Namespace:
 
110
  parser.add_argument(
111
  "--url",
112
  type=str,
113
+ default="rerun+http://127.0.0.1:2004/proxy",
114
  help="Connect to this HTTP(S) URL",
115
  )
116
  parser.add_argument(
 
129
  headless=True,
130
  connect=True,
131
  serve=False,
132
+ url="rerun+http://127.0.0.1:2004/proxy",
133
  save=None,
134
  stdout=False,
135
  ) -> Namespace:
mapanything/utils/wai/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ """
2
+ This utils module contains a port of wai-core scripts/methods for MapAnything.
3
+ """
mapanything/utils/wai/basic_dataset.py ADDED
@@ -0,0 +1,131 @@
 
 
1
+ from pathlib import Path
2
+ from typing import Any
3
+
4
+ import torch
5
+ from box import Box
6
+
7
+ from mapanything.utils.wai.core import get_frame_index, load_data, load_frame
8
+ from mapanything.utils.wai.ops import stack
9
+ from mapanything.utils.wai.scene_frame import get_scene_frame_names
10
+
11
+
12
+ class BasicSceneframeDataset(torch.utils.data.Dataset):
13
+ """Basic wai dataset to iterative over frames of scenes"""
14
+
15
+ @staticmethod
16
+ def collate_fn(batch: list[dict[str, Any]]) -> dict[str, Any]:
17
+ return stack(batch)
18
+
19
+ def __init__(
20
+ self,
21
+ cfg: Box,
22
+ ):
23
+ """
24
+ Initialize the BasicSceneframeDataset.
25
+
26
+ Args:
27
+ cfg (Box): Configuration object containing dataset parameters including:
28
+ - root: Root directory containing scene data
29
+ - frame_modalities: List of modalities to load for each frame
30
+ - key_remap: Optional dictionary mapping original keys to new keys
31
+ """
32
+ super().__init__()
33
+ self.cfg = cfg
34
+ self.root = cfg.root
35
+ keyframes = cfg.get("use_keyframes", True)
36
+ self.scene_frame_names = get_scene_frame_names(cfg, keyframes=keyframes)
37
+ self.scene_frame_list = [
38
+ (scene_name, frame_name)
39
+ for scene_name, frame_names in self.scene_frame_names.items()
40
+ for frame_name in frame_names
41
+ ]
42
+ self._scene_cache = {}
43
+
44
+ def __len__(self):
45
+ """
46
+ Get the total number of scene-frame pairs in the dataset.
47
+
48
+ Returns:
49
+ int: The number of scene-frame pairs.
50
+ """
51
+ return len(self.scene_frame_list)
52
+
53
+ def _load_scene(self, scene_name: str) -> dict[str, Any]:
54
+ """
55
+ Load scene data for a given scene name.
56
+
57
+ Args:
58
+ scene_name (str): The name of the scene to load.
59
+
60
+ Returns:
61
+ dict: A dictionary containing scene data, including scene metadata.
62
+ """
63
+ # load scene data
64
+ scene_data = {}
65
+ scene_data["meta"] = load_data(
66
+ Path(
67
+ self.root,
68
+ scene_name,
69
+ self.cfg.get("scene_meta_path", "scene_meta.json"),
70
+ ),
71
+ "scene_meta",
72
+ )
73
+
74
+ return scene_data
75
+
76
+ def _load_scene_frame(
77
+ self, scene_name: str, frame_name: str | float
78
+ ) -> dict[str, Any]:
79
+ """
80
+ Load data for a specific frame from a specific scene.
81
+
82
+ This method loads scene data if not already cached, then loads the specified frame
83
+ from that scene with the modalities specified in the configuration.
84
+
85
+ Args:
86
+ scene_name (str): The name of the scene containing the frame.
87
+ frame_name (str or float): The name/timestamp of the frame to load.
88
+
89
+ Returns:
90
+ dict: A dictionary containing the loaded frame data with requested modalities.
91
+ """
92
+ scene_frame_data = {}
93
+ if not (scene_data := self._scene_cache.get(scene_name)):
94
+ scene_data = self._load_scene(scene_name)
95
+ # for now only cache the last scene
96
+ self._scene_cache = {}
97
+ self._scene_cache[scene_name] = scene_data
98
+
99
+ frame_idx = get_frame_index(scene_data["meta"], frame_name)
100
+
101
+ scene_frame_data["scene_name"] = scene_name
102
+ scene_frame_data["frame_name"] = frame_name
103
+ scene_frame_data["scene_path"] = str(Path(self.root, scene_name))
104
+ scene_frame_data["frame_idx"] = frame_idx
105
+ scene_frame_data.update(
106
+ load_frame(
107
+ Path(self.root, scene_name),
108
+ frame_name,
109
+ modalities=self.cfg.frame_modalities,
110
+ scene_meta=scene_data["meta"],
111
+ )
112
+ )
113
+ # Remap key names
114
+ for key, new_key in self.cfg.get("key_remap", {}).items():
115
+ if key in scene_frame_data:
116
+ scene_frame_data[new_key] = scene_frame_data.pop(key)
117
+
118
+ return scene_frame_data
119
+
120
+ def __getitem__(self, index: int) -> dict[str, Any]:
121
+ """
122
+ Get a specific scene-frame pair by index.
123
+
124
+ Args:
125
+ index (int): The index of the scene-frame pair to retrieve.
126
+
127
+ Returns:
128
+ dict: A dictionary containing the loaded frame data with requested modalities.
129
+ """
130
+ scene_frame = self._load_scene_frame(*self.scene_frame_list[index])
131
+ return scene_frame
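A usage sketch for the new dataset. The config keys below are illustrative; the exact keys consumed by get_scene_frame_names (e.g. scene selection filters) depend on the wai scene layout:

import torch
from box import Box

from mapanything.utils.wai.basic_dataset import BasicSceneframeDataset

cfg = Box(
    {
        "root": "/path/to/wai/dataset",          # hypothetical dataset root
        "frame_modalities": ["image", "depth"],  # hypothetical modality names
    }
)

dataset = BasicSceneframeDataset(cfg)
loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=4,
    collate_fn=BasicSceneframeDataset.collate_fn,  # stacks per-frame dicts into a batch
)
batch = next(iter(loader))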
mapanything/utils/wai/camera.py ADDED
@@ -0,0 +1,263 @@
 
 
1
+ """
2
+ This utils script contains a port of wai-core camera methods for MapAnything.
3
+ """
4
+
5
+ from typing import Any
6
+
7
+ import numpy as np
8
+ import torch
9
+ from scipy.spatial.transform import Rotation, Slerp
10
+
11
+ from mapanything.utils.wai.ops import get_dtype_device
12
+
13
+ # constants regarding camera models
14
+ PINHOLE_CAM_KEYS = ["fl_x", "fl_y", "cx", "cy", "h", "w"]
15
+ DISTORTION_PARAM_KEYS = [
16
+ "k1",
17
+ "k2",
18
+ "k3",
19
+ "k4",
20
+ "p1",
21
+ "p2",
22
+ ] # order corresponds to the OpenCV convention
23
+ CAMERA_KEYS = PINHOLE_CAM_KEYS + DISTORTION_PARAM_KEYS
24
+
25
+
26
+ def interpolate_intrinsics(
27
+ frame1: dict[str, Any],
28
+ frame2: dict[str, Any],
29
+ alpha: float,
30
+ ) -> dict[str, Any]:
31
+ """
32
+ Interpolate camera intrinsics linearly.
33
+ Args:
34
+ frame1: The first frame dictionary.
35
+ frame2: The second frame dictionary.
36
+ alpha: Interpolation parameter. alpha = 0 for frame1, alpha = 1 for frame2.
37
+ Returns:
38
+ frame_inter: dictionary with new intrinsics.
39
+ """
40
+ frame_inter = {}
41
+ for key in CAMERA_KEYS:
42
+ if key in frame1 and key in frame2:
43
+ p1 = frame1[key]
44
+ p2 = frame2[key]
45
+ frame_inter[key] = (1 - alpha) * p1 + alpha * p2
46
+ return frame_inter
47
+
48
+
49
+ def interpolate_extrinsics(
50
+ matrix1: list | np.ndarray | torch.Tensor,
51
+ matrix2: list | np.ndarray | torch.Tensor,
52
+ alpha: float,
53
+ ) -> list | np.ndarray | torch.Tensor:
54
+ """
55
+ Interpolate camera extrinsics 4x4 matrices using SLERP.
56
+ Args:
57
+ matrix1: The first matrix.
58
+ matrix2: The second matrix.
59
+ alpha: Interpolation parameter. alpha = 0 for matrix1, alpha = 1 for matrix2.
60
+ Returns:
61
+ matrix: 4x4 interpolated matrix, same type.
62
+ Raises:
63
+ ValueError: If the two matrices are not of the same type.
64
+ """
65
+ if not isinstance(matrix1, type(matrix2)):
66
+ raise ValueError("Both matrices should have the same type.")
67
+
68
+ dtype, device = get_dtype_device(matrix1)
69
+ if isinstance(matrix1, list):
70
+ mtype = "list"
71
+ matrix1 = np.array(matrix1)
72
+ matrix2 = np.array(matrix2)
73
+ elif isinstance(matrix1, np.ndarray):
74
+ mtype = "numpy"
75
+ elif isinstance(matrix1, torch.Tensor):
76
+ mtype = "torch"
77
+ matrix1 = matrix1.numpy()
78
+ matrix2 = matrix2.numpy()
79
+ else:
80
+ raise ValueError(
81
+ "Only list, numpy array and torch tensors are supported as inputs."
82
+ )
83
+
84
+ R1 = matrix1[:3, :3]
85
+ t1 = matrix1[:3, 3]
86
+ R2 = matrix2[:3, :3]
87
+ t2 = matrix2[:3, 3]
88
+
89
+ # interpolate translation
90
+ t = (1 - alpha) * t1 + alpha * t2
91
+
92
+ # interpolate rotations with SLERP
93
+ R1_quat = Rotation.from_matrix(R1).as_quat()
94
+ R2_quat = Rotation.from_matrix(R2).as_quat()
95
+ rotation_slerp = Slerp([0, 1], Rotation(np.stack([R1_quat, R2_quat])))
96
+ R = rotation_slerp(alpha).as_matrix()
97
+ matrix_inter = np.eye(4)
98
+
99
+ # combine together
100
+ matrix_inter[:3, :3] = R
101
+ matrix_inter[:3, 3] = t
102
+
103
+ if mtype == "list":
104
+ matrix_inter = matrix_inter.tolist()
105
+ elif mtype == "torch":
106
+ matrix_inter = torch.from_numpy(matrix_inter).to(dtype).to(device)
107
+ elif mtype == "numpy":
108
+ matrix_inter = matrix_inter.astype(dtype)
109
+
110
+ return matrix_inter
111
+
112
+
113
+ def convert_camera_coeffs_to_pinhole_matrix(
114
+ scene_meta, frame, fmt="torch"
115
+ ) -> torch.Tensor | np.ndarray | list:
116
+ """
117
+ Convert camera intrinsics from NeRFStudio format to a 3x3 intrinsics matrix.
118
+
119
+ Args:
120
+ scene_meta: Scene metadata containing camera parameters
121
+ frame: Frame-specific camera parameters that override scene_meta
122
+
123
+ Returns:
124
+ torch.Tensor: 3x3 camera intrinsics matrix
125
+
126
+ Raises:
127
+ ValueError: If camera model is not PINHOLE or if distortion coefficients are present
128
+ """
129
+ # Check if camera model is supported
130
+ camera_model = frame.get("camera_model", scene_meta.get("camera_model"))
131
+ if camera_model != "PINHOLE":
132
+ raise ValueError("Only PINHOLE camera model supported")
133
+
134
+ # Check for unsupported distortion coefficients
135
+ if any(
136
+ (frame.get(coeff, 0) != 0) or (scene_meta.get(coeff, 0) != 0)
137
+ for coeff in DISTORTION_PARAM_KEYS
138
+ ):
139
+ raise ValueError(
140
+ "Pinhole camera does not support radial/tangential distortion -> Undistort first"
141
+ )
142
+
143
+ # Extract camera intrinsic parameters
144
+ camera_coeffs = {}
145
+ for coeff in ["fl_x", "fl_y", "cx", "cy"]:
146
+ camera_coeffs[coeff] = frame.get(coeff, scene_meta.get(coeff))
147
+ if camera_coeffs[coeff] is None:
148
+ raise ValueError(f"Missing required camera parameter: {coeff}")
149
+
150
+ # Create intrinsics matrix
151
+ intrinsics = [
152
+ [camera_coeffs["fl_x"], 0.0, camera_coeffs["cx"]],
153
+ [0.0, camera_coeffs["fl_y"], camera_coeffs["cy"]],
154
+ [0.0, 0.0, 1.0],
155
+ ]
156
+ if fmt == "torch":
157
+ intrinsics = torch.tensor(intrinsics)
158
+ elif fmt == "np":
159
+ intrinsics = np.array(intrinsics)
160
+
161
+ return intrinsics
162
+
163
+
164
+ def rotate_pinhole_90degcw(
165
+ W: int, H: int, fx: float, fy: float, cx: float, cy: float
166
+ ) -> tuple[int, int, float, float, float, float]:
167
+ """Rotates the intrinsics of a pinhole camera model by 90 degrees clockwise."""
168
+ W_new = H
169
+ H_new = W
170
+ fx_new = fy
171
+ fy_new = fx
172
+ cy_new = cx
173
+ cx_new = H - 1 - cy
174
+ return W_new, H_new, fx_new, fy_new, cx_new, cy_new
175
+
176
+
177
+ def _gl_cv_cmat() -> np.ndarray:
178
+ cmat = np.array([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]])
179
+ return cmat
180
+
181
+
182
+ def _apply_transformation(
183
+ c2ws: torch.Tensor | np.ndarray, cmat: np.ndarray
184
+ ) -> torch.Tensor | np.ndarray:
185
+ """
186
+ Convert camera poses using a provided conversion matrix.
187
+
188
+ Args:
189
+ c2ws (torch.Tensor or np.ndarray): Camera poses (batch_size, 4, 4) or (4, 4)
190
+ cmat (torch.Tensor or np.ndarray): Conversion matrix (4, 4)
191
+
192
+ Returns:
193
+ torch.Tensor or np.ndarray: Transformed camera poses (batch_size, 4, 4) or (4, 4)
194
+ """
195
+ if isinstance(c2ws, torch.Tensor):
196
+ # Clone the input tensor to avoid modifying it in-place
197
+ c2ws_transformed = c2ws.clone()
198
+ # Apply the conversion matrix to the rotation part of the camera poses
199
+ if len(c2ws.shape) == 3:
200
+ c2ws_transformed[:, :3, :3] = c2ws_transformed[
201
+ :, :3, :3
202
+ ] @ torch.from_numpy(cmat[:3, :3]).to(c2ws).unsqueeze(0)
203
+ else:
204
+ c2ws_transformed[:3, :3] = c2ws_transformed[:3, :3] @ torch.from_numpy(
205
+ cmat[:3, :3]
206
+ ).to(c2ws)
207
+
208
+ elif isinstance(c2ws, np.ndarray):
209
+ # Clone the input array to avoid modifying it in-place
210
+ c2ws_transformed = c2ws.copy()
211
+ if len(c2ws.shape) == 3: # batched
212
+ # Apply the conversion matrix to the rotation part of the camera poses
213
+ c2ws_transformed[:, :3, :3] = np.einsum(
214
+ "ijk,lk->ijl", c2ws_transformed[:, :3, :3], cmat[:3, :3]
215
+ )
216
+ else: # single 4x4 matrix
217
+ # Apply the conversion matrix to the rotation part of the camera pose
218
+ c2ws_transformed[:3, :3] = np.dot(c2ws_transformed[:3, :3], cmat[:3, :3])
219
+
220
+ else:
221
+ raise ValueError("Input data type not supported.")
222
+
223
+ return c2ws_transformed
224
+
225
+
226
+ def gl2cv(
227
+ c2ws: torch.Tensor | np.ndarray,
228
+ return_cmat: bool = False,
229
+ ) -> torch.Tensor | np.ndarray | tuple[torch.Tensor | np.ndarray, np.ndarray]:
230
+ """
231
+ Convert camera poses from OpenGL to OpenCV coordinate system.
232
+
233
+ Args:
234
+ c2ws (torch.Tensor or np.ndarray): Camera poses (batch_size, 4, 4) or (4, 4)
235
+ return_cmat (bool): If True, return the conversion matrix along with the transformed poses
236
+
237
+ Returns:
238
+ torch.Tensor or np.ndarray: Transformed camera poses (batch_size, 4, 4) or (4, 4)
239
+ np.ndarray (optional): Conversion matrix if return_cmat is True
240
+ """
241
+ cmat = _gl_cv_cmat()
242
+ if return_cmat:
243
+ return _apply_transformation(c2ws, cmat), cmat
244
+ return _apply_transformation(c2ws, cmat)
245
+
246
+
247
+ def intrinsics_to_fov(
248
+ fx: torch.Tensor, fy: torch.Tensor, h: torch.Tensor, w: torch.Tensor
249
+ ) -> tuple[torch.Tensor, torch.Tensor]:
250
+ """
251
+ Compute the horizontal and vertical fields of view in radians from camera intrinsics.
252
+
253
+ Args:
254
+ fx (torch.Tensor): focal x
255
+ fy (torch.Tensor): focal y
256
+ h (torch.Tensor): Image height(s) with shape (B,).
257
+ w (torch.Tensor): Image width(s) with shape (B,).
258
+
259
+ Returns:
260
+ tuple[torch.Tensor, torch.Tensor]: A tuple containing the horizontal and vertical fields
261
+ of view in radians, both with shape (B,).
262
+ """
263
+ return 2 * torch.atan((w / 2) / fx), 2 * torch.atan((h / 2) / fy)
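A short sketch of two commonly needed helpers from this port, gl2cv and intrinsics_to_fov (the numeric values are illustrative only):

import numpy as np
import torch

from mapanything.utils.wai.camera import gl2cv, intrinsics_to_fov

# OpenGL -> OpenCV convention for a camera-to-world pose
c2w_gl = np.eye(4)
c2w_cv, cmat = gl2cv(c2w_gl, return_cmat=True)

# Horizontal / vertical field of view (radians) from pinhole intrinsics
fx = torch.tensor([500.0])
fy = torch.tensor([500.0])
h = torch.tensor([480.0])
w = torch.tensor([640.0])
fov_x, fov_y = intrinsics_to_fov(fx, fy, h, w)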
mapanything/utils/wai/colormaps/colors_fps_5k.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fae94fe5fb565ff40d1c556ae2640d00fc068e732cb4af5bb64eef034790e07c
3
+ size 9478
mapanything/utils/wai/core.py ADDED
@@ -0,0 +1,492 @@
 
 
1
+ """
2
+ This utils script contains a port of wai-core core methods for MapAnything.
3
+ """
4
+
5
+ import logging
6
+ import re
7
+ from pathlib import Path
8
+ from typing import Any
9
+
10
+ import numpy as np
11
+ import torch
12
+
13
+ from mapanything.utils.wai.camera import (
14
+ CAMERA_KEYS,
15
+ convert_camera_coeffs_to_pinhole_matrix,
16
+ interpolate_extrinsics,
17
+ interpolate_intrinsics,
18
+ )
19
+ from mapanything.utils.wai.io import _get_method, _load_scene_meta
20
+ from mapanything.utils.wai.ops import crop
21
+
22
+ logger = logging.getLogger(__name__)
23
+
24
+ WAI_COLORMAP_PATH = Path(__file__).parent / "colormaps"
25
+
26
+
27
+ def load_data(fname: str | Path, format_type: str | None = None, **kwargs) -> Any:
28
+ """
29
+ Loads data from a file using the appropriate method based on the file format.
30
+
31
+ Args:
32
+ fname (str or Path): The filename or path to load data from.
33
+ format_type (str, optional): The format type of the data. If None, it will be inferred from the file extension if possible.
34
+ Supported formats include: 'readable', 'scalar', 'image', 'binary', 'depth', 'normals',
35
+ 'numpy', 'ptz', 'mmap', 'scene_meta', 'labeled_image', 'mesh', 'labeled_mesh', 'caption', "latents".
36
+ **kwargs: Additional keyword arguments to pass to the loading method.
37
+
38
+ Returns:
39
+ The loaded data in the format returned by the specific loading method.
40
+
41
+ Raises:
42
+ ValueError: If the format cannot be inferred from the file extension.
43
+ NotImplementedError: If the specified format is not supported.
44
+ FileExistsError: If the file does not exist.
45
+ """
46
+ load_method = _get_method(fname, format_type, load=True)
47
+ return load_method(fname, **kwargs)
48
+
49
+
50
+ def store_data(
51
+ fname: str | Path,
52
+ data: Any,
53
+ format_type: str | None = None,
54
+ **kwargs,
55
+ ) -> Any:
56
+ """
57
+ Stores data to a file using the appropriate method based on the file format.
58
+
59
+ Args:
60
+ fname (str or Path): The filename or path to store data to.
61
+ data: The data to be stored.
62
+ format_type (str, optional): The format type of the data. If None, it will be inferred from the file extension.
63
+ **kwargs: Additional keyword arguments to pass to the storing method.
64
+
65
+ Returns:
66
+ The result of the storing method, which may vary depending on the method used.
67
+ """
68
+ store_method = _get_method(fname, format_type, load=False)
69
+ Path(fname).parent.mkdir(parents=True, exist_ok=True)
70
+ return store_method(fname, data, **kwargs)
71
+
72
+
73
+ def get_frame(
74
+ scene_meta: dict[str, Any],
75
+ frame_key: int | str | float,
76
+ ) -> dict[str, Any]:
77
+ """
78
+ Get a frame from scene_meta based on name or index.
79
+
80
+ Args:
81
+ scene_meta: Dictionary containing scene metadata
82
+ frame_key: Either a string (frame name) or integer (frame index) or float (video timestamp)
83
+
84
+ Returns:
85
+ The frame data (dict)
86
+ """
87
+ frame_idx = get_frame_index(scene_meta, frame_key)
88
+ if isinstance(frame_idx, int):
89
+ frame = scene_meta["frames"][frame_idx]
90
+ frame["_is_interpolated"] = False
91
+ else:
92
+ frame = {}
93
+ frame["frame_name"] = frame_key
94
+ left = int(frame_idx) # it's floor operation
95
+ assert left >= 0 and left < (len(scene_meta["frames"]) - 1), "Wrong index"
96
+ frame_left = scene_meta["frames"][left]
97
+ frame_right = scene_meta["frames"][left + 1]
98
+ # Interpolate intrinsics and extrinsics
99
+ frame["transform_matrix"] = interpolate_extrinsics(
100
+ frame_left["transform_matrix"],
101
+ frame_right["transform_matrix"],
102
+ frame_idx - left,
103
+ )
104
+ frame.update(
105
+ interpolate_intrinsics(
106
+ frame_left,
107
+ frame_right,
108
+ frame_idx - left,
109
+ )
110
+ )
111
+ frame["_is_interpolated"] = True
112
+ return frame
113
+
114
+
115
+ def get_intrinsics(
116
+ scene_meta,
117
+ frame_key,
118
+ fmt: str = "torch",
119
+ ) -> torch.Tensor | np.ndarray | list:
120
+ frame = get_frame(scene_meta, frame_key)
121
+ return convert_camera_coeffs_to_pinhole_matrix(scene_meta, frame, fmt=fmt)
122
+
123
+
124
+ def get_extrinsics(
125
+ scene_meta,
126
+ frame_key,
127
+ fmt: str = "torch",
128
+ ) -> torch.Tensor | np.ndarray | list | None:
129
+ frame = get_frame(scene_meta, frame_key)
130
+ if "transform_matrix" in frame:
131
+ if fmt == "torch":
132
+ return torch.tensor(frame["transform_matrix"]).reshape(4, 4).float()
133
+ elif fmt == "np":
134
+ return np.array(frame["transform_matrix"]).reshape(4, 4)
135
+ return frame["transform_matrix"]
136
+ else:
137
+ # TODO: should not happen if we enable interpolation
138
+ return None
139
+
140
+
141
+ def get_frame_index(
142
+ scene_meta: dict[str, Any],
143
+ frame_key: int | str | float,
144
+ frame_index_threshold_sec: float = 1e-4,
145
+ distance_threshold_sec: float = 2.0,
146
+ ) -> int | float:
147
+ """
148
+ Returns the frame index from scene_meta based on name (str) or index (int) or sub-frame index (float).
149
+
150
+ Args:
151
+ scene_meta: Dictionary containing scene metadata
152
+ frame_key: Either a string (frame name) or integer (frame index) or float (sub-frame index)
153
+ frame_index_threshold_sec: A threshold for nearest neighbor clipping for indexes (in seconds).
154
+ Default is 1e-4, which is 10000 fps.
155
+ distance_threshold_sec: A threshold for the maximum distance between interpolated frames (in seconds).
156
+
157
+ Returns:
158
+ Frame index (int)
159
+
160
+ Raises:
161
+ ValueError: If frame_key is not a string or integer or float
162
+ """
163
+ if isinstance(frame_key, str):
164
+ try:
165
+ return scene_meta["frame_names"][frame_key]
166
+ except KeyError as err:
167
+ error_message = (
168
+ f"Frame name not found: {frame_key} - "
169
+ f"Please verify scene_meta.json of scene: {scene_meta['dataset_name']}/{scene_meta['scene_name']}"
170
+ )
171
+ logger.error(error_message)
172
+ raise KeyError(error_message) from err
173
+
174
+ if isinstance(frame_key, int):
175
+ return frame_key
176
+
177
+ if isinstance(frame_key, float):
178
+ # If exact hit
179
+ if frame_key in scene_meta["frame_names"]:
180
+ return scene_meta["frame_names"][frame_key]
181
+
182
+ frame_names = sorted(list(scene_meta["frame_names"].keys()))
183
+ distances = np.array([frm - frame_key for frm in frame_names])
184
+ left = int(np.nonzero(distances <= 0)[0][-1])
185
+ right = left + 1
186
+
187
+ # The last frame or rounding errors
188
+ if (
189
+ left == distances.shape[0] - 1
190
+ or abs(distances[left]) < frame_index_threshold_sec
191
+ ):
192
+ return scene_meta["frame_names"][frame_names[int(left)]]
193
+ if abs(distances[right]) < frame_index_threshold_sec:
194
+ return scene_meta["frame_names"][frame_names[int(right)]]
195
+
196
+ interpolation_distance = distances[right] - distances[left]
197
+ if interpolation_distance > distance_threshold_sec:
198
+ raise ValueError(
199
+ f"Frame interpolation is forbidden for distances larger than {distance_threshold_sec}."
200
+ )
201
+ alpha = -distances[left] / interpolation_distance
202
+
203
+ return scene_meta["frame_names"][frame_names[int(left)]] + alpha
204
+
205
+ raise ValueError(f"Frame key type not supported: {frame_key} ({type(frame_key)}).")
206
+
207
+
208
+ def load_modality_data(
209
+ scene_root: Path | str,
210
+ results: dict[str, Any],
211
+ modality_dict: dict[str, Any],
212
+ modality: str,
213
+ frame: dict[str, Any] | None = None,
214
+ fmt: str = "torch",
215
+ ) -> dict[str, Any]:
216
+ """
217
+ Processes a modality by loading data from a specified path and updating the results dictionary.
218
+ This function extracts the format and path from the given modality dictionary, loads the data
219
+ from the specified path, and updates the results dictionary with the loaded data.
220
+
221
+ Args:
222
+ scene_root (str or Path): The root directory of the scene where the data is located.
223
+ results (dict): A dictionary to store the loaded modality data and optional frame path.
224
+ modality_dict (dict): A dictionary containing the modality information, including 'format'
225
+ and the path to the data.
226
+ modality (str): The key under which the loaded modality data will be stored in the results.
227
+ frame (dict, optional): A dictionary containing frame information. If provided, that means we are loading
228
+ frame modalities, otherwise it is scene modalities.
229
+
230
+ Returns:
231
+ dict: The updated results dictionary containing the loaded modality data.
232
+ """
233
+ modality_format = modality_dict["format"]
234
+
235
+ # The modality is stored as a video
236
+ if "video" in modality_format:
237
+ assert isinstance(frame["frame_name"], float), "frame_name should be float"
238
+ video_file = None
239
+ if "chunks" in modality_dict:
240
+ video_list = modality_dict["chunks"]
241
+ # Get the correct chunk of the video
242
+ for video_chunk in video_list:
243
+ if video_chunk["start"] <= frame["frame_name"] <= video_chunk["end"]:
244
+ video_file = video_chunk
245
+ break
246
+ else:
247
+ # There is only one video (no chunks)
248
+ video_file = modality_dict
249
+ if "start" not in video_file:
250
+ video_file["start"] = 0
251
+ if "end" not in video_file:
252
+ video_file["end"] = float("inf")
253
+ if not (video_file["start"] <= frame["frame_name"] <= video_file["end"]):
254
+ video_file = None
255
+
256
+ # This timestamp is not available in any of the chunks
257
+ if video_file is None:
258
+ frame_name = frame["frame_name"]
259
+ logger.warning(
260
+ f"Modality {modality} ({modality_format}) is not available at time {frame_name}"
261
+ )
262
+ return results
263
+
264
+ # Load the modality from the video
265
+ loaded_modality = load_data(
266
+ Path(scene_root, video_file["file"]),
267
+ modality_format,
268
+ frame_key=frame["frame_name"] - video_file["start"],
269
+ )
270
+
271
+ if "bbox" in video_file:
272
+ loaded_modality = crop(loaded_modality, video_file["bbox"])
273
+
274
+ if loaded_modality is not None:
275
+ results[modality] = loaded_modality
276
+
277
+ if frame:
278
+ results[f"{modality}_fname"] = video_file["file"]
279
+ else:
280
+ modality_path = [v for k, v in modality_dict.items() if k != "format"][0]
281
+ if frame:
282
+ if modality_path in frame:
283
+ fname = frame[modality_path]
284
+ else:
285
+ fname = None
286
+ else:
287
+ fname = modality_path
288
+ if fname is not None:
289
+ loaded_modality = load_data(
290
+ Path(scene_root, fname),
291
+ modality_format,
292
+ frame_key=frame["frame_name"] if frame else None,
293
+ fmt=fmt,
294
+ )
295
+ results[modality] = loaded_modality
296
+ if frame:
297
+ results[f"{modality}_fname"] = frame[modality_path]
298
+ return results
299
+
300
+
301
+ def load_modality(
302
+ scene_root: Path | str,
303
+ modality_meta: dict[str, Any],
304
+ modality: str,
305
+ frame: dict[str, Any] | None = None,
306
+ fmt: str = "torch",
307
+ ) -> dict[str, Any]:
308
+ """
309
+ Loads modality data based on the provided metadata and updates the results dictionary.
310
+ This function navigates through the modality metadata to find the specified modality,
311
+ then loads the data for each modality found.
312
+
313
+ Args:
314
+ scene_root (str or Path): The root directory of the scene where the data is located.
315
+ modality_meta (dict): A nested dictionary containing metadata for various modalities.
316
+ modality (str): A string representing the path to the desired modality within the metadata,
317
+ using '/' as a separator for nested keys.
318
+ frame (dict, optional): A dictionary containing frame information. If provided, we are operating
319
+ on frame modalities, otherwise it is scene modalities.
320
+
321
+ Returns:
322
+ dict: A dictionary containing the loaded modality data.
323
+ """
324
+ results = {}
325
+ # support for nested modalities like "pred_depth/metric3dv2"
326
+ modality_keys = modality.split("/")
327
+ current_modality = modality_meta
328
+ for key in modality_keys:
329
+ try:
330
+ current_modality = current_modality[key]
331
+ except KeyError as err:
332
+ error_message = (
333
+ f"Modality '{err.args[0]}' not found in modalities metadata. "
334
+ f"Please verify the scene_meta.json and the provided modalities in {scene_root}."
335
+ )
336
+ logger.error(error_message)
337
+ raise KeyError(error_message) from err
338
+ if "format" in current_modality:
339
+ results = load_modality_data(
340
+ scene_root, results, current_modality, modality, frame, fmt=fmt
341
+ )
342
+ else:
343
+ # nested modality, return last by default
344
+ logger.warning("Nested modality, returning last by default")
345
+ key = next(reversed(current_modality.keys()))
346
+ results = load_modality_data(
347
+ scene_root, results, current_modality[key], modality, frame, fmt=fmt
348
+ )
349
+ return results
350
+
351
+
352
+ def load_frame(
353
+ scene_root: Path | str,
354
+ frame_key: int | str | float,
355
+ modalities: str | list[str] | None = None,
356
+ scene_meta: dict[str, Any] | None = None,
357
+ load_intrinsics: bool = True,
358
+ load_extrinsics: bool = True,
359
+ fmt: str = "torch",
360
+ interpolate: bool = False,
361
+ ) -> dict[str, Any]:
362
+ """
363
+ Load a single frame from a scene with specified modalities.
364
+
365
+ Args:
366
+ scene_root (str or Path): The root directory of the scene where the data is located.
367
+ frame_key (int or str or float): Either a string (frame name) or integer (frame index) or float (video timestamp).
368
+ modalities (str or list[str], optional): The modality or list of modalities to load.
369
+ If None, only basic frame information is loaded.
370
+ scene_meta (dict, optional): Dictionary containing scene metadata. If None, it will be loaded
371
+ from scene_meta.json in the scene_root.
372
+ interpolate (bool, optional): Whether to allow interpolating between frames. Defaults to False.
373
+
374
+ Returns:
375
+ dict: A dictionary containing the loaded frame data with the requested modalities.
376
+ """
377
+ scene_root = Path(scene_root)
378
+ if scene_meta is None:
379
+ scene_meta = _load_scene_meta(scene_root / "scene_meta.json")
380
+ frame = get_frame(scene_meta, frame_key)
381
+ # compact, standardized frame representation
382
+ wai_frame = {}
383
+ if load_extrinsics:
384
+ extrinsics = get_extrinsics(
385
+ scene_meta,
386
+ frame_key,
387
+ fmt=fmt,
388
+ )
389
+ if extrinsics is not None:
390
+ wai_frame["extrinsics"] = extrinsics
391
+ if load_intrinsics:
392
+ camera_model = frame.get("camera_model", scene_meta.get("camera_model"))
393
+ wai_frame["camera_model"] = camera_model
394
+ if camera_model == "PINHOLE":
395
+ wai_frame["intrinsics"] = get_intrinsics(scene_meta, frame_key, fmt=fmt)
396
+ elif camera_model in ["OPENCV", "OPENCV_FISHEYE"]:
397
+ # optional per-frame intrinsics
398
+ for camera_key in CAMERA_KEYS:
399
+ if camera_key in frame:
400
+ wai_frame[camera_key] = float(frame[camera_key])
401
+ elif camera_key in scene_meta:
402
+ wai_frame[camera_key] = float(scene_meta[camera_key])
403
+ else:
404
+ error_message = (
405
+ f"Camera model not supported: {camera_model} - "
406
+ f"Please verify scene_meta.json of scene: {scene_meta['dataset_name']}/{scene_meta['scene_name']}"
407
+ )
408
+ logger.error(error_message)
409
+ raise NotImplementedError(error_message)
410
+ wai_frame["w"] = frame.get("w", scene_meta["w"] if "w" in scene_meta else None)
411
+ wai_frame["h"] = frame.get("h", scene_meta["h"] if "h" in scene_meta else None)
412
+ wai_frame["frame_name"] = frame["frame_name"]
413
+ wai_frame["frame_idx"] = get_frame_index(scene_meta, frame_key)
414
+ wai_frame["_is_interpolated"] = frame["_is_interpolated"]
415
+
416
+ if modalities is not None:
417
+ if isinstance(modalities, str):
418
+ modalities = [modalities]
419
+ for modality in modalities:
420
+ # Handle regex patterns in modality
421
+ if any(char in modality for char in ".|*+?()[]{}^$\\"):
422
+ # This is a regex pattern
423
+ pattern = re.compile(modality)
424
+ matching_modalities = [
425
+ m for m in scene_meta["frame_modalities"] if pattern.match(m)
426
+ ]
427
+ if not matching_modalities:
428
+ raise ValueError(
429
+ f"No modalities match the pattern: {modality} in scene: {scene_root}"
430
+ )
431
+ # Use the first matching modality
432
+ modality = matching_modalities[0]
433
+ current_modalities = load_modality(
434
+ scene_root, scene_meta["frame_modalities"], modality, frame, fmt=fmt
435
+ )
436
+ wai_frame.update(current_modalities)
437
+
438
+ return wai_frame
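A usage sketch for `load_frame` follows; the scene path and modality names are hypothetical, and the import assumes this module is exposed as `mapanything.utils.wai.core`, as added in this commit.

from pathlib import Path

from mapanything.utils.wai.core import load_frame  # assumed module path (this file)

scene_root = Path("/data/wai/some_dataset/scene_0001")  # hypothetical scene directory

frame = load_frame(
    scene_root,
    frame_key=0,                    # frame index; a frame name or float timestamp also works
    modalities=["image", "depth"],  # must be listed under scene_meta["frame_modalities"]
    fmt="torch",
)

print(frame["frame_name"], frame["h"], frame["w"])
print(frame["extrinsics"].shape)                   # pose as returned by get_extrinsics
print(frame["image"].shape, frame["depth"].shape)  # tensors loaded via the modality handlers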
439
+
440
+
441
+ def set_frame(
442
+ scene_meta: dict[str, Any],
443
+ frame_key: int | str,
444
+ new_frame: dict[str, Any],
445
+ sort: bool = False,
446
+ ) -> dict[str, Any]:
447
+ """
448
+ Replace a frame in scene_meta with a new frame.
449
+
450
+ Args:
451
+ scene_meta: Dictionary containing scene metadata.
452
+ frame_key: Either a string (frame name) or integer (frame index).
453
+ new_frame: New frame data to replace the existing frame.
454
+ sort: If True, sort the keys in the new_frame dictionary.
455
+
456
+ Returns:
457
+ Updated scene_meta dictionary.
458
+ """
459
+ frame_idx = get_frame_index(scene_meta, frame_key)
460
+ if isinstance(frame_idx, float):
461
+ raise ValueError(
462
+ f"Setting frame for sub-frame frame_key is not supported: {frame_key} ({type(frame_key)})."
463
+ )
464
+ if sort:
465
+ new_frame = {k: new_frame[k] for k in sorted(new_frame)}
466
+ scene_meta["frames"][frame_idx] = new_frame
467
+ return scene_meta
468
+
469
+
470
+ def nest_modality(
471
+ frame_modalities: dict[str, Any],
472
+ modality_name: str,
473
+ ) -> dict[str, Any]:
474
+ """
475
+ Converts a flat modality structure into a nested one based on the modality name.
476
+
477
+ Args:
478
+ frame_modalities (dict): Dictionary containing frame modalities.
479
+ modality_name (str): The name of the modality to nest.
480
+
481
+ Returns:
482
+ dict: A dictionary with the nested modality structure.
483
+ """
484
+ frame_modality = {}
485
+ if modality_name in frame_modalities:
486
+ frame_modality = frame_modalities[modality_name]
487
+ if "frame_key" in frame_modality:
488
+ # required for backwards compatibility
489
+ # converting non-nested format into nested one based on name
490
+ modality_name = frame_modality["frame_key"].split("_")[0]
491
+ frame_modality = {modality_name: frame_modality}
492
+ return frame_modality
mapanything/utils/wai/intersection_check.py ADDED
@@ -0,0 +1,462 @@
1
+ import torch
2
+ from einops import rearrange, repeat
3
+ from tqdm import tqdm
4
+
5
+
6
+ def create_frustum_from_intrinsics(
7
+ intrinsics: torch.Tensor,
8
+ near: torch.Tensor | float,
9
+ far: torch.Tensor | float,
10
+ ) -> torch.Tensor:
11
+ r"""
12
+ Create a frustum from camera intrinsics.
13
+
14
+ Args:
15
+ intrinsics (torch.Tensor): Bx3x3 Intrinsics of cameras.
16
+ near (torch.Tensor or float): [B] Near plane distance.
17
+ far (torch.Tensor or float): [B] Far plane distance.
18
+
19
+ Returns:
20
+ frustum (torch.Tensor): Bx8x3 batch of frustum points following the order:
21
+ 5 ---------- 4
22
+ |\ /|
23
+ 6 \ / 7
24
+ \ 1 ---- 0 /
25
+ \| |/
26
+ 2 ---- 3
27
+ """
28
+
29
+ fx, fy = intrinsics[:, 0, 0], intrinsics[:, 1, 1]
30
+ cx, cy = intrinsics[:, 0, 2], intrinsics[:, 1, 2]
31
+
32
+ # Calculate the frustum half-extents at the near and far planes
33
+ near_x = near * (cx / fx)
34
+ near_y = near * (cy / fy)
35
+ far_x = far * (cx / fx)
36
+ far_y = far * (cy / fy)
37
+
38
+ # Define frustum vertices in camera space
39
+ near_plane = torch.stack(
40
+ [
41
+ torch.stack([near_x, near_y, near * torch.ones_like(near_x)], dim=-1),
42
+ torch.stack([-near_x, near_y, near * torch.ones_like(near_x)], dim=-1),
43
+ torch.stack([-near_x, -near_y, near * torch.ones_like(near_x)], dim=-1),
44
+ torch.stack([near_x, -near_y, near * torch.ones_like(near_x)], dim=-1),
45
+ ],
46
+ dim=1,
47
+ )
48
+
49
+ far_plane = torch.stack(
50
+ [
51
+ torch.stack([far_x, far_y, far * torch.ones_like(far_x)], dim=-1),
52
+ torch.stack([-far_x, far_y, far * torch.ones_like(far_x)], dim=-1),
53
+ torch.stack([-far_x, -far_y, far * torch.ones_like(far_x)], dim=-1),
54
+ torch.stack([far_x, -far_y, far * torch.ones_like(far_x)], dim=-1),
55
+ ],
56
+ dim=1,
57
+ )
58
+
59
+ return torch.cat([near_plane, far_plane], dim=1)
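As a quick numerical sanity check of the construction above (intrinsics chosen by hand, not from any dataset): with fx = fy = cx = cy = 100, the near-plane corners sit at near * (±cx/fx, ±cy/fy, 1) = (±1, ±1, 1) for near = 1, and the far-plane corners are the same directions scaled to far = 10.

import torch

# Hypothetical pinhole intrinsics (roughly a 90-degree FOV for a 200x200 image).
K = torch.tensor([[[100.0, 0.0, 100.0],
                   [0.0, 100.0, 100.0],
                   [0.0, 0.0, 1.0]]])

frustum = create_frustum_from_intrinsics(K, near=1.0, far=10.0)
print(frustum.shape)  # torch.Size([1, 8, 3])
print(frustum[0, 0])  # tensor([1., 1., 1.])    -> first near-plane corner
print(frustum[0, 4])  # tensor([10., 10., 10.]) -> corresponding far-plane corner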
60
+
61
+
62
+ def _frustum_to_triangles(frustum: torch.Tensor) -> torch.Tensor:
63
+ """
64
+ Convert frustum to triangles.
65
+
66
+ Args:
67
+ frustum (torch.Tensor): Bx8x3 batch of frustum points.
68
+
69
+ Returns:
70
+ frustum_triangles (torch.Tensor): Bx12x3x3 batch of frustum triangles (12 triangles per frustum).
71
+ """
72
+
73
+ triangle_inds = torch.tensor(
74
+ [
75
+ [0, 1, 2],
76
+ [0, 2, 3],
77
+ [0, 3, 7],
78
+ [0, 7, 4],
79
+ [1, 2, 6],
80
+ [1, 6, 5],
81
+ [1, 4, 5],
82
+ [1, 0, 4],
83
+ [2, 6, 7],
84
+ [2, 3, 7],
85
+ [6, 7, 4],
86
+ [6, 5, 4],
87
+ ]
88
+ )
89
+ frustum_triangles = frustum[:, triangle_inds]
90
+ return frustum_triangles
91
+
92
+
93
+ def segment_triangle_intersection_check(
94
+ start_points: torch.Tensor,
95
+ end_points: torch.Tensor,
96
+ triangles: torch.Tensor,
97
+ ) -> torch.Tensor:
98
+ """
99
+ Check if segments (lines with starting and end point) intersect triangles in 3D using the
100
+ Moller-Trumbore algorithm.
101
+
102
+ Args:
103
+ start_points (torch.Tensor): Bx3 Starting points of the segment.
104
+ end_points (torch.Tensor): Bx3 End points of the segment.
105
+ triangles (torch.Tensor): Bx3x3 Vertices of the triangles.
106
+
107
+ Returns:
108
+ intersects (torch.Tensor): B Boolean tensor indicating if each ray intersects its
109
+ corresponding triangle.
110
+ """
111
+ vertex0 = triangles[:, 0, :]
112
+ vertex1 = triangles[:, 1, :]
113
+ vertex2 = triangles[:, 2, :]
114
+ edge1 = vertex1 - vertex0
115
+ edge2 = vertex2 - vertex0
116
+ ray_vectors = end_points - start_points
117
+ max_lengths = torch.norm(ray_vectors, dim=1)
118
+ ray_vectors = ray_vectors / max_lengths[:, None]
119
+ h = torch.cross(ray_vectors, edge2, dim=1)
120
+ a = (edge1 * h).sum(dim=1)
121
+
122
+ epsilon = 1e-6
123
+ mask = torch.abs(a) > epsilon
124
+ f = torch.zeros_like(a)
125
+ f[mask] = 1.0 / a[mask]
126
+
127
+ s = start_points - vertex0
128
+ u = f * (s * h).sum(dim=1)
129
+ q = torch.cross(s, edge1, dim=1)
130
+ v = f * (ray_vectors * q).sum(dim=1)
131
+
132
+ t = f * (edge2 * q).sum(dim=1)
133
+
134
+ # Check conditions
135
+ intersects = (
136
+ (u >= 0)
137
+ & (u <= 1)
138
+ & (v >= 0)
139
+ & (u + v <= 1)
140
+ & (t >= epsilon)
141
+ & (t <= max_lengths)
142
+ )
143
+
144
+ return intersects
145
+
146
+
147
+ def triangle_intersection_check(
148
+ triangles1: torch.Tensor,
149
+ triangles2: torch.Tensor,
150
+ ) -> torch.Tensor:
151
+ """
152
+ Check if two triangles intersect.
153
+
154
+ Args:
155
+ triangles1 (torch.Tensor): Bx3x3 Vertices of the first batch of triangles.
156
+ triangles2 (torch.Tensor): Bx3x3 Vertices of the second batch of triangles.
157
+
158
+ Returns:
159
+ triangle_intersection (torch.Tensor): B Boolean tensor indicating if triangles intersect.
160
+ """
161
+ n = triangles1.shape[1]
162
+ start_points1 = rearrange(triangles1, "B N C -> (B N) C")
163
+ end_points1 = rearrange(
164
+ triangles1[:, torch.arange(1, n + 1) % n], "B N C -> (B N) C"
165
+ )
166
+
167
+ start_points2 = rearrange(triangles2, "B N C -> (B N) C")
168
+ end_points2 = rearrange(
169
+ triangles2[:, torch.arange(1, n + 1) % n], "B N C -> (B N) C"
170
+ )
171
+ intersection_1_2 = segment_triangle_intersection_check(
172
+ start_points1, end_points1, repeat(triangles2, "B N C -> (B N2) N C", N2=3)
173
+ )
174
+ intersection_2_1 = segment_triangle_intersection_check(
175
+ start_points2, end_points2, repeat(triangles1, "B N C -> (B N2) N C", N2=3)
176
+ )
177
+ triangle_intersection = torch.any(
178
+ rearrange(intersection_1_2, "(B N N2) -> B (N N2)", B=triangles1.shape[0], N=n),
179
+ dim=1,
180
+ ) | torch.any(
181
+ rearrange(intersection_2_1, "(B N N2) -> B (N N2)", B=triangles1.shape[0], N=n),
182
+ dim=1,
183
+ )
184
+ return triangle_intersection
185
+
186
+
187
+ def frustum_intersection_check(
188
+ frustums: torch.Tensor,
189
+ check_inside: bool = True,
190
+ chunk_size: int = 500,
191
+ device: str | None = None,
192
+ ) -> torch.Tensor:
193
+ """
194
+ Check if any pair of the frustums intersect with each other.
195
+
196
+ Args:
197
+ frustums (torch.Tensor): Bx8x3 batch of frustum points.
198
+ check_inside (bool): If True, also checks if one frustum is inside another.
199
+ Defaults to True.
200
+ chunk_size (int): Number of frustums processed per chunk.
201
+ Defaults to 500.
202
+ device (Optional[str]): Device to store the exhaustive frustum intersection matrix on.
203
+ Defaults to None.
204
+
205
+ Returns:
206
+ frustum_intersection (torch.Tensor): BxB tensor of Booleans indicating if any pair
207
+ of frustums intersect with each other.
208
+ """
209
+ B = frustums.shape[0]
210
+ if device is None:
211
+ device = frustums.device
212
+ frustum_triangles = _frustum_to_triangles(frustums)
213
+ T = frustum_triangles.shape[1]
214
+
215
+ # Perform frustum in frustum check if required
216
+ if check_inside:
217
+ frustum_intersection = frustums_in_frustum_check(
218
+ frustums=frustums, chunk_size=chunk_size, device=device
219
+ )
220
+ else:
221
+ frustum_intersection = torch.zeros((B, B), dtype=torch.bool, device=device)
222
+
223
+ # Check triangle intersections in chunks
224
+ for i in tqdm(range(0, B, chunk_size), desc="Checking triangle intersections"):
225
+ i_end = min(i + chunk_size, B)
226
+ chunk_i_size = i_end - i
227
+
228
+ for j in range(0, B, chunk_size):
229
+ j_end = min(j + chunk_size, B)
230
+ chunk_j_size = j_end - j
231
+
232
+ # Process all triangle pairs between the two chunks in a vectorized way
233
+ triangles_i = frustum_triangles[i:i_end] # [chunk_i, T, 3, 3]
234
+ triangles_j = frustum_triangles[j:j_end] # [chunk_j, T, 3, 3]
235
+
236
+ # Reshape to process all triangle pairs at once
237
+ tri_i = triangles_i.reshape(chunk_i_size * T, 3, 3)
238
+ tri_j = triangles_j.reshape(chunk_j_size * T, 3, 3)
239
+
240
+ # Expand for all pairs - explicitly specify dimensions instead of using ...
241
+ tri_i_exp = repeat(tri_i, "bt i j -> (bt bj_t) i j", bj_t=chunk_j_size * T)
242
+ tri_j_exp = repeat(tri_j, "bt i j -> (bi_t bt) i j", bi_t=chunk_i_size * T)
243
+
244
+ # Check intersection
245
+ batch_intersect = triangle_intersection_check(tri_i_exp, tri_j_exp)
246
+
247
+ # Reshape and check if any triangle pair intersects
248
+ batch_intersect = batch_intersect.reshape(chunk_i_size, T, chunk_j_size, T)
249
+ batch_intersect = batch_intersect.any(dim=(1, 3))
250
+
251
+ # Update result
252
+ frustum_intersection[i:i_end, j:j_end] |= batch_intersect.to(device)
253
+
254
+ return frustum_intersection
255
+
256
+
257
+ def ray_triangle_intersection_check(
258
+ ray_origins: torch.Tensor,
259
+ ray_vectors: torch.Tensor,
260
+ triangles: torch.Tensor,
261
+ max_lengths: torch.Tensor | None = None,
262
+ ) -> torch.Tensor:
263
+ """
264
+ Check if rays intersect triangles in 3D using the Moller-Trumbore algorithm, considering the
265
+ finite length of rays.
266
+
267
+ Args:
268
+ ray_origins (torch.Tensor): Bx3 Origins of the rays.
269
+ ray_vectors (torch.Tensor): Bx3 Direction vectors of the rays.
270
+ triangles (torch.Tensor): Bx3x3 Vertices of the triangles.
271
+ max_lengths (Optional[torch.Tensor]): B Maximum lengths of the rays.
272
+
273
+ Returns:
274
+ intersects (torch.Tensor): B Boolean tensor indicating if each ray intersects its
275
+ corresponding triangle.
276
+ """
277
+ vertex0 = triangles[:, 0, :]
278
+ vertex1 = triangles[:, 1, :]
279
+ vertex2 = triangles[:, 2, :]
280
+ edge1 = vertex1 - vertex0
281
+ edge2 = vertex2 - vertex0
282
+ h = torch.cross(ray_vectors, edge2, dim=1)
283
+ a = (edge1 * h).sum(dim=1)
284
+
285
+ epsilon = 1e-6
286
+ mask = torch.abs(a) > epsilon
287
+ f = torch.zeros_like(a)
288
+ f[mask] = 1.0 / a[mask]
289
+
290
+ s = ray_origins - vertex0
291
+ u = f * (s * h).sum(dim=1)
292
+ q = torch.cross(s, edge1, dim=1)
293
+ v = f * (ray_vectors * q).sum(dim=1)
294
+
295
+ t = f * (edge2 * q).sum(dim=1)
296
+
297
+ # Check conditions
298
+ intersects = (u >= 0) & (u <= 1) & (v >= 0) & (u + v <= 1) & (t >= epsilon)
299
+ if max_lengths is not None:
300
+ intersects &= t <= max_lengths
301
+
302
+ return intersects
303
+
304
+
305
+ #### Checks for frustums
306
+ def _frustum_to_planes(frustums: torch.Tensor) -> torch.Tensor:
307
+ r"""
308
+ Converts frustum parameters to plane representation.
309
+
310
+ Args:
311
+ frustums (torch.Tensor): Bx8x3 batch of frustum points following the order:
312
+ 5 ---------- 4
313
+ |\ /|
314
+ 6 \ / 7
315
+ \ 1 ---- 0 /
316
+ \| |/
317
+ 2 ---- 3
318
+
319
+ Returns:
320
+ planes (torch.Tensor): Bx6x4 where 6 represents the six frustum planes and
321
+ 4 represents plane parameters [a, b, c, d].
322
+ """
323
+ planes = []
324
+ for inds in [[0, 1, 3], [1, 6, 2], [0, 3, 7], [2, 6, 3], [0, 5, 1], [6, 5, 4]]:
325
+ normal = torch.cross(
326
+ frustums[:, inds[1]] - frustums[:, inds[0]],
327
+ frustums[:, inds[2]] - frustums[:, inds[0]],
328
+ dim=1,
329
+ )
330
+ normal = normal / torch.norm(normal, dim=1, keepdim=True)
331
+ d = -torch.sum(normal * frustums[:, inds[0]], dim=1, keepdim=True)
332
+ planes.append(torch.cat([normal, d], -1))
333
+ return torch.stack(planes, 1)
334
+
335
+
336
+ def points_in_frustum_check(
337
+ frustums: torch.Tensor,
338
+ points: torch.Tensor,
339
+ chunk_size: int | None = None,
340
+ device: str | None = None,
341
+ ):
342
+ """
343
+ Check if points are inside frustums.
344
+
345
+ Args:
346
+ frustums (torch.Tensor): Bx8x3 batch of frustum points.
347
+ points (torch.Tensor): BxNx3 batch of points.
348
+ chunk_size (Optional[int]): If set, process frustums in chunks of this size to limit memory usage. Defaults to None.
349
+ device (Optional[str]): Device to perform computation on. Defaults to None.
350
+
351
+ Returns:
352
+ inside (torch.Tensor): BxN batch of Booleans indicating if points are inside frustums.
353
+ """
354
+ if device is None:
355
+ device = frustums.device
356
+
357
+ if chunk_size is not None:
358
+ # Split computation into chunks to avoid OOM errors for large batch sizes
359
+ point_plane_direction = []
360
+ for chunk_idx in range(0, frustums.shape[0], chunk_size):
361
+ chunk_frustum_planes = _frustum_to_planes(
362
+ frustums[chunk_idx : chunk_idx + chunk_size]
363
+ )
364
+ # Bx6x4 tensor of plane parameters [a, b, c, d]
365
+ chunk_points = points[chunk_idx : chunk_idx + chunk_size]
366
+ chunk_point_plane_direction = torch.einsum(
367
+ "bij,bnj->bni", (chunk_frustum_planes[:, :, :-1], chunk_points)
368
+ ) + repeat(
369
+ chunk_frustum_planes[:, :, -1], "B P -> B N P", N=chunk_points.shape[1]
370
+ ) # BxNx6 tensor of signed point-to-plane distances
371
+ point_plane_direction.append(chunk_point_plane_direction.to(device))
372
+ point_plane_direction = torch.cat(point_plane_direction)
373
+ else:
374
+ # Convert frustums to planes
375
+ frustum_planes = _frustum_to_planes(
376
+ frustums
377
+ ) # Bx6x4 tensor of plane parameters [a, b, c, d]
378
+ # Compute dot product between each point and each plane
379
+ point_plane_direction = torch.einsum(
380
+ "bij,bnj->bni", (frustum_planes[:, :, :-1], points)
381
+ ) + repeat(frustum_planes[:, :, -1], "B P -> B N P", N=points.shape[1]).to(
382
+ device
383
+ ) # BxNx6 tensor of signed point-to-plane distances
384
+
385
+ inside = (point_plane_direction >= 0).all(-1)
386
+ return inside
387
+
388
+
389
+ def frustums_in_frustum_check(
390
+ frustums: torch.Tensor,
391
+ chunk_size: int,
392
+ device: str | None = None,
393
+ use_double_chunking: bool = True,
394
+ ):
395
+ """
396
+ Check if frustums are contained within other frustums.
397
+
398
+ Args:
399
+ frustums (torch.Tensor): Bx8x3 batch of frustum points.
400
+ chunk_size (int): Number of frustums processed per chunk,
401
+ used to limit peak memory usage.
402
+ device (Optional[str]): Device to store the exhaustive frustum containment matrix on.
403
+ Defaults to None.
404
+ use_double_chunking (bool): If True, use double chunking to avoid OOM errors.
405
+ Defaults to True.
406
+
407
+ Returns:
408
+ frustum_contained (torch.Tensor): BxB batch of Booleans indicating if frustums are inside
409
+ other frustums.
410
+ """
411
+ B = frustums.shape[0]
412
+ if device is None:
413
+ device = frustums.device
414
+
415
+ if use_double_chunking:
416
+ frustum_contained = torch.zeros((B, B), dtype=torch.bool, device=device)
417
+ # Check if frustums are containing each other by processing in chunks
418
+ for i in tqdm(range(0, B, chunk_size), desc="Checking frustum containment"):
419
+ i_end = min(i + chunk_size, B)
420
+ chunk_i_size = i_end - i
421
+
422
+ for j in range(0, B, chunk_size):
423
+ j_end = min(j + chunk_size, B)
424
+ chunk_j_size = j_end - j
425
+
426
+ # Process a chunk of frustums against another chunk
427
+ frustums_i = frustums[i:i_end]
428
+ frustums_j_vertices = frustums[
429
+ j:j_end, :1
430
+ ] # Just need one vertex to check containment
431
+
432
+ # Perform points in frustum check
433
+ contained = rearrange(
434
+ points_in_frustum_check(
435
+ repeat(frustums_i, "B ... -> (B B2) ...", B2=chunk_j_size),
436
+ repeat(
437
+ frustums_j_vertices, "B ... -> (B2 B) ...", B2=chunk_i_size
438
+ ),
439
+ )[:, 0],
440
+ "(B B2) -> B B2",
441
+ B=chunk_i_size,
442
+ ).to(device)
443
+
444
+ # Map results back to the full matrix
445
+ frustum_contained[i:i_end, j:j_end] |= contained
446
+ frustum_contained[j:j_end, i:i_end] |= contained.transpose(
447
+ 0, 1
448
+ ) # Symmetric relation
449
+ else:
450
+ # Perform points in frustum check with a single chunked loop
451
+ frustum_contained = rearrange(
452
+ points_in_frustum_check(
453
+ repeat(frustums, "B ... -> (B B2) ...", B2=B),
454
+ repeat(frustums[:, :1], "B ... -> (B2 B) ...", B2=B),
455
+ chunk_size=chunk_size,
456
+ )[:, 0],
457
+ "(B B2) -> B B2",
458
+ B=B,
459
+ ).to(device)
460
+ frustum_contained = frustum_contained | frustum_contained.T
461
+
462
+ return frustum_contained
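Putting the pieces of this module together, a pairwise-overlap sketch (batch size, intrinsics, and near/far values are arbitrary; in a real pipeline the frustum corners would first be transformed into a shared world frame using each camera's pose):

import torch

B = 4
K = torch.eye(3).repeat(B, 1, 1)
K[:, 0, 0] = 500.0  # fx
K[:, 1, 1] = 500.0  # fy
K[:, 0, 2] = 320.0  # cx
K[:, 1, 2] = 240.0  # cy

frustums = create_frustum_from_intrinsics(K, near=0.1, far=5.0)  # B x 8 x 3, camera space

overlap = frustum_intersection_check(frustums, check_inside=True, chunk_size=2)
print(overlap.shape)       # torch.Size([4, 4]) Boolean overlap matrix
print(overlap.diagonal())  # identical, untransformed frustums trivially overlap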
mapanything/utils/wai/io.py ADDED
@@ -0,0 +1,1373 @@
1
+ """
2
+ This utils script contains a port of the wai-core io methods for MapAnything.
3
+ """
4
+
5
+ import gzip
6
+ import io
7
+ import json
8
+ import logging
9
+ import os
10
+ from datetime import datetime
11
+ from pathlib import Path
12
+ from typing import Any, Callable, cast, IO, Literal, overload
13
+
14
+ os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
15
+ import cv2
16
+ import numpy as np
17
+ import torch
18
+ import trimesh
19
+ import yaml
20
+ from PIL import Image, PngImagePlugin
21
+ from plyfile import PlyData, PlyElement
22
+ from safetensors.torch import load_file as load_sft, save_file as save_sft
23
+ from torchvision.io import decode_image
24
+ from yaml import CLoader
25
+
26
+ from mapanything.utils.wai.ops import (
27
+ to_numpy,
28
+ )
29
+ from mapanything.utils.wai.semantics import (
30
+ apply_id_to_color_mapping,
31
+ INVALID_ID,
32
+ load_semantic_color_mapping,
33
+ )
34
+
35
+ # Try to use orjson for faster JSON processing
36
+ try:
37
+ import orjson
38
+ except ImportError:
39
+ orjson = None
40
+
41
+ logger = logging.getLogger(__name__)
42
+
43
+
44
+ @overload
45
+ def _load_readable(
46
+ fname: Path | str, load_as_string: Literal[True], **kwargs
47
+ ) -> str: ...
48
+ @overload
49
+ def _load_readable(
50
+ fname: Path | str, load_as_string: Literal[False] = False, **kwargs
51
+ ) -> dict: ...
52
+
53
+
54
+ def _load_readable(
55
+ fname: Path | str,
56
+ load_as_string: bool = False,
57
+ **kwargs,
58
+ ) -> Any | str:
59
+ """
60
+ Loads data from a human-readable file and will try to parse JSON or YAML files as a dict, list,
61
+ int, float, str, bool, or None object. Can optionally return the file contents as a string.
62
+
63
+ Args:
64
+ fname (str or Path): The filename to load data from.
65
+ load_as_string (bool, optional): Whether to return the loaded data as a string.
66
+ Defaults to False.
67
+
68
+ Returns:
69
+ The loaded data, which can be any type of object that can be represented in JSON or YAML.
70
+
71
+ Raises:
72
+ NotImplementedError: If the file suffix is not supported (i.e., not .json, .yaml, or .yml).
73
+ """
74
+ if load_as_string:
75
+ return _load_readable_string(fname, **kwargs)
76
+ else:
77
+ return _load_readable_structured(fname, **kwargs)
78
+
79
+
80
+ def _load_readable_structured(
81
+ fname: Path | str,
82
+ **kwargs,
83
+ ) -> Any:
84
+ """
85
+ Loads data from a human-readable file and will try to parse JSON or YAML files as a dict, list,
86
+ int, float, str, bool, or None object.
87
+
88
+ Args:
89
+ fname (str or Path): The filename to load data from.
90
+
91
+ Returns:
92
+ The loaded data, which can be any type of object that can be represented in JSON or YAML.
93
+
94
+ Raises:
95
+ NotImplementedError: If the file suffix is not supported (i.e., not .json, .yaml, or .yml).
96
+ """
97
+ fname = Path(fname)
98
+ if not fname.exists():
99
+ raise FileNotFoundError(f"File does not exist: {fname}")
100
+
101
+ if fname.suffix == ".json":
102
+ # Use binary mode for JSON files
103
+ with open(fname, mode="rb") as f:
104
+ # Use orjson if available, otherwise use standard JSON
105
+ if orjson:
106
+ return orjson.loads(f.read())
107
+ return json.load(f)
108
+
109
+ if fname.suffix in [".yaml", ".yml"]:
110
+ # Use text mode with UTF-8 encoding for YAML files
111
+ with open(fname, mode="r", encoding="utf-8") as f:
112
+ return yaml.load(f, Loader=CLoader)
113
+
114
+ raise NotImplementedError(f"Readable format not supported: {fname.suffix}")
115
+
116
+
117
+ def _load_readable_string(
118
+ fname: Path | str,
119
+ **kwargs,
120
+ ) -> str:
121
+ """
122
+ Loads data from a human-readable file as a string.
123
+
124
+ Args:
125
+ fname (str or Path): The filename to load data from.
126
+
127
+ Returns:
128
+ The file's contents, as a string.
129
+ """
130
+ fname = Path(fname)
131
+ if not fname.exists():
132
+ raise FileNotFoundError(f"File does not exist: {fname}")
133
+
134
+ with open(fname, mode="r", encoding="utf-8") as f:
135
+ contents = f.read()
136
+
137
+ return contents
138
+
139
+
140
+ def _store_readable(
141
+ fname: Path | str,
142
+ data: Any,
143
+ **kwargs,
144
+ ) -> int:
145
+ """
146
+ Stores data in a human-readable file (JSON or YAML).
147
+
148
+ Args:
149
+ fname (str or Path): The filename to store data in.
150
+ data: The data to store, which can be any type of object that can be represented in JSON or YAML.
151
+
152
+ Returns:
153
+ The number of bytes written to the file.
154
+
155
+ Raises:
156
+ NotImplementedError: If the file suffix is not supported (i.e., not .json, .yaml, or .yml).
157
+ """
158
+ fname = Path(fname)
159
+
160
+ # Create parent directory if it doesn't exist
161
+ os.makedirs(fname.parent, exist_ok=True)
162
+
163
+ if fname.suffix == ".json":
164
+ if orjson:
165
+ # Define the operation for orjson
166
+ with open(fname, mode="wb") as f:
167
+ return f.write(orjson.dumps(data, option=orjson.OPT_INDENT_2))
168
+ else:
169
+ # Define the operation for standard json
170
+ with open(fname, mode="w", encoding="utf-8") as f:
171
+ json.dump(data, f, indent=2)
172
+ return f.tell()
173
+
174
+ elif fname.suffix in [".yaml", ".yml"]:
175
+ # Define the operation for YAML files
176
+ with open(fname, mode="w", encoding="utf-8") as f:
177
+ yaml.dump(data, f)
178
+ return f.tell()
179
+ else:
180
+ raise NotImplementedError(f"Writable format not supported: {fname.suffix}")
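A minimal round-trip sketch for the readable-format helpers above (the temporary directory and contents are purely illustrative):

import tempfile
from pathlib import Path

meta = {"dataset_name": "demo", "scale": 0.5, "frames": [{"frame_name": "000000"}]}

with tempfile.TemporaryDirectory() as tmp:
    json_path = Path(tmp) / "scene_meta.json"
    n_bytes = _store_readable(json_path, meta)            # JSON via orjson if available
    loaded = _load_readable(json_path)                    # parsed back into a dict
    raw = _load_readable(json_path, load_as_string=True)  # same file as a raw string
    assert loaded == meta and isinstance(raw, str) and n_bytes > 0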
181
+
182
+
183
+ def get_processing_state(scene_root: Path | str) -> dict:
184
+ """
185
+ Retrieves the processing state of a scene.
186
+
187
+ Args:
188
+ scene_root (Path or str): The root directory of the scene.
189
+
190
+ Returns:
191
+ dict: A dictionary containing the processing state of the scene.
192
+ If no processing log exists, or reading it fails, an empty
193
+ dictionary is returned.
194
+ """
195
+ process_log_path = Path(scene_root) / "_process_log.json"
196
+
197
+ try:
198
+ return _load_readable_structured(process_log_path)
199
+ except FileNotFoundError:
200
+ logger.debug(f"Log file not found, returning empty dict: {process_log_path}")
201
+ return {}
202
+ except Exception:
203
+ logger.error(
204
+ f"Could not parse, returning empty dict: {process_log_path}", exc_info=True
205
+ )
206
+ return {}
207
+
208
+
209
+ def _write_exr(
210
+ fname: str | Path,
211
+ data: np.ndarray | torch.Tensor,
212
+ params: list | None = None,
213
+ **kwargs,
214
+ ) -> bool:
215
+ """
216
+ Writes an image as an EXR file using OpenCV.
217
+
218
+ Args:
219
+ fname (str or Path): The filename to save the image to.
220
+ data (numpy.ndarray, torch.Tensor): The image data to save. Must be a 2D or 3D array.
221
+ params (list, optional): A list of parameters to pass to OpenCV's imwrite function.
222
+ Defaults to None, which uses 32-bit with zip compression.
223
+
224
+ Returns:
225
+ bool: True if the image was saved successfully, False otherwise.
226
+
227
+ Raises:
228
+ ValueError: If the input data has less than two or more than three dimensions.
229
+
230
+ Notes:
231
+ Only 32-bit float (CV_32F) images can be saved.
232
+ For comparison of different compression methods, see P1732924327.
233
+ """
234
+ if Path(fname).suffix != ".exr":
235
+ raise ValueError(
236
+ f"Only filenames with suffix .exr allowed but received: {fname}"
237
+ )
238
+
239
+ ## Note: only 32-bit float (CV_32F) images can be saved
240
+ data_np = to_numpy(data, dtype=np.float32)
241
+ if (data_np.ndim > 3) or (data_np.ndim < 2):
242
+ raise ValueError(
243
+ f"Image needs to contain two or three dims but received: {data_np.shape}"
244
+ )
245
+
246
+ return cv2.imwrite(str(fname), data_np, params if params else [])
247
+
248
+
249
+ @overload
250
+ def _read_exr(fname: str | Path, fmt: Literal["np"], **kwargs) -> np.ndarray: ...
251
+ @overload
252
+ def _read_exr(fname: str | Path, fmt: Literal["PIL"], **kwargs) -> Image.Image: ...
253
+ @overload
254
+ def _read_exr(
255
+ fname: str | Path, fmt: Literal["torch"] = "torch", **kwargs
256
+ ) -> torch.Tensor: ...
257
+
258
+
259
+ def _read_exr(
260
+ fname: str | Path, fmt: Literal["np", "PIL", "torch"] = "torch", **kwargs
261
+ ) -> np.ndarray | torch.Tensor | Image.Image:
262
+ """
263
+ Reads an EXR image file using OpenCV.
264
+
265
+ Args:
266
+ fname (str or Path): The filename of the EXR image to read.
267
+ fmt (str): The format of the output data. Can be one of:
268
+ - "torch": Returns a PyTorch tensor.
269
+ - "np": Returns a NumPy array.
270
+ - "PIL": Returns a PIL Image object.
271
+ Defaults to "torch".
272
+
273
+ Returns:
274
+ The EXR image data in the specified output format.
275
+
276
+ Raises:
277
+ NotImplementedError: If the specified output format is not supported.
278
+ ValueError: If data shape is not supported, e.g. multi-channel PIL float images.
279
+
280
+ Notes:
281
+ The EXR image is read in its original format, without any conversion or rescaling.
282
+ """
283
+ data = cv2.imread(str(fname), cv2.IMREAD_UNCHANGED)
284
+ if data is None:
285
+ raise FileNotFoundError(f"Failed to read EXR file: {fname}")
286
+ if fmt == "torch":
287
+ # Convert to PyTorch tensor with float32 dtype
288
+ data = torch.from_numpy(data).float()
289
+ elif fmt == "np":
290
+ # Convert to NumPy array with float32 dtype
291
+ data = np.array(data, dtype=np.float32)
292
+ elif fmt == "PIL":
293
+ if data.ndim != 2:
294
+ raise ValueError("PIL does not support multi-channel EXR images")
295
+
296
+ # Convert to PIL Image object
297
+ data = Image.fromarray(data)
298
+ else:
299
+ raise NotImplementedError(f"fmt not supported: {fmt}")
300
+ return data
301
+
302
+
303
+ @overload
304
+ def _load_image(
305
+ fname: str | Path,
306
+ fmt: Literal["np"],
307
+ resize: tuple[int, int] | None = None,
308
+ **kwargs,
309
+ ) -> np.ndarray: ...
310
+ @overload
311
+ def _load_image(
312
+ fname: str | Path,
313
+ fmt: Literal["pil"],
314
+ resize: tuple[int, int] | None = None,
315
+ **kwargs,
316
+ ) -> Image.Image: ...
317
+ @overload
318
+ def _load_image(
319
+ fname: str | Path,
320
+ fmt: Literal["torch"] = "torch",
321
+ resize: tuple[int, int] | None = None,
322
+ **kwargs,
323
+ ) -> torch.Tensor: ...
324
+
325
+
326
+ def _load_image(
327
+ fname: str | Path,
328
+ fmt: Literal["np", "pil", "torch"] = "torch",
329
+ resize: tuple[int, int] | None = None,
330
+ **kwargs,
331
+ ) -> np.ndarray | torch.Tensor | Image.Image:
332
+ """
333
+ Loads an image from a file.
334
+
335
+ Args:
336
+ fname (str or Path): The filename to load the image from.
337
+ fmt (str): The format of the output data. Can be one of:
338
+ - "torch": Returns a PyTorch tensor with shape (C, H, W).
339
+ - "np": Returns a NumPy array with shape (H, W, C).
340
+ - "pil": Returns a PIL Image object.
341
+ Defaults to "torch".
342
+ resize (tuple, optional): A tuple of two integers representing the desired width and height of the image.
343
+ If None, the image is not resized. Defaults to None.
344
+
345
+ Returns:
346
+ The loaded image in the specified output format.
347
+
348
+ Raises:
349
+ NotImplementedError: If the specified output format is not supported.
350
+
351
+ Notes:
352
+ This function loads non-binary images in RGB mode and normalizes pixel values to the range [0, 1].
353
+ """
354
+
355
+ # Fastest way to load into torch tensor
356
+ if resize is None and fmt == "torch":
357
+ return decode_image(str(fname)).float() / 255.0
358
+
359
+ # Load using PIL
360
+ with open(fname, "rb") as f:
361
+ pil_image = Image.open(f)
362
+ pil_image.load()
363
+
364
+ if pil_image.mode not in ["RGB", "RGBA"]:
365
+ raise OSError(
366
+ f"Expected a RGB or RGBA image in {fname}, but instead found an image with mode {pil_image.mode}"
367
+ )
368
+
369
+ if resize is not None:
370
+ pil_image = pil_image.resize(resize)
371
+
372
+ if fmt == "torch":
373
+ return (
374
+ torch.from_numpy(np.array(pil_image)).permute(2, 0, 1).float() / 255.0
375
+ )
376
+ elif fmt == "np":
377
+ return np.array(pil_image, dtype=np.float32) / 255.0
378
+ elif fmt == "pil":
379
+ return pil_image
380
+ else:
381
+ raise NotImplementedError(f"Image format not supported: {fmt}")
382
+
383
+
384
+ def _store_image(
385
+ fname: str | Path, img_data: np.ndarray | torch.Tensor | Image.Image, **kwargs
386
+ ) -> None:
387
+ """
388
+ Stores an image in a file.
389
+
390
+ Args:
391
+ fname (str or Path): The filename to store the image in.
392
+ img_data (numpy.ndarray, torch.tensor or PIL.Image.Image): The image data to store.
393
+
394
+ Notes (for numpy.ndarray or torch.tensor inputs):
395
+ This function assumes that the input image data is in the range [0, 1], and has shape
396
+ (H, W, C), or (C, H, W) for PyTorch tensors, with C being 3 or 4.
397
+ It converts the image data to uint8 format and saves it as a compressed image file.
398
+ """
399
+ if isinstance(img_data, torch.Tensor):
400
+ if img_data.ndim != 3:
401
+ raise ValueError(f"Tensor needs to be 3D but received: {img_data.shape=}")
402
+
403
+ if img_data.shape[0] in [3, 4]:
404
+ # Convert to HWC format expected by pillow `Image.save` below
405
+ img_data = img_data.permute(1, 2, 0)
406
+
407
+ img_data = img_data.contiguous()
408
+
409
+ if isinstance(img_data, (np.ndarray, torch.Tensor)):
410
+ if img_data.shape[-1] not in [3, 4]:
411
+ raise ValueError(
412
+ f"Image must have 3 or 4 channels, but received: {img_data.shape=}"
413
+ )
414
+
415
+ img_data_np = to_numpy(img_data, dtype=np.float32)
416
+ img_data = Image.fromarray((255 * img_data_np).round().astype(np.uint8))
417
+
418
+ with open(fname, "wb") as f:
419
+ pil_kwargs = {
420
+ # Make PNGs faster to save using minimal compression
421
+ "optimize": False,
422
+ "compress_level": 1,
423
+ # Higher JPEG image quality
424
+ "quality": "high",
425
+ }
426
+ pil_kwargs.update(kwargs)
427
+ img_data.save(cast(IO[bytes], f), **pil_kwargs)
428
+
429
+
430
+ def _load_binary_mask(
431
+ fname: str | Path,
432
+ fmt: str = "torch",
433
+ resize: tuple[int, int] | None = None,
434
+ **kwargs,
435
+ ) -> np.ndarray | torch.Tensor | Image.Image:
436
+ """
437
+ Loads a binary image from a file.
438
+
439
+ Args:
440
+ fname (str or Path): The filename to load the binary image from.
441
+ fmt (str): The format of the output data. Can be one of:
442
+ - "torch": Returns a PyTorch Boolean tensor with shape H x W.
443
+ - "np": Returns a NumPy Boolean array with shape H x W.
444
+ - "pil": Returns a PIL Image object.
445
+ Defaults to "torch".
446
+ resize (tuple, optional): A tuple of two integers representing the desired width and height of the binary image.
447
+ If None, the image is not resized. Defaults to None.
448
+
449
+ Returns:
450
+ The loaded binary image in the specified output format.
451
+
452
+ Raises:
453
+ NotImplementedError: If the specified output format is not supported.
454
+ """
455
+ if fmt not in ["pil", "np", "torch"]:
456
+ raise NotImplementedError(f"Image format not supported: {fmt}")
457
+
458
+ with open(fname, "rb") as f:
459
+ pil_image = Image.open(f)
460
+ pil_image.load()
461
+
462
+ if pil_image.mode == "L":
463
+ pil_image = pil_image.convert("1")
464
+
465
+ elif pil_image.mode != "1":
466
+ raise OSError(
467
+ f"Expected a binary or grayscale image in {fname}, but instead found an image with mode {pil_image.mode}"
468
+ )
469
+
470
+ if resize is not None:
471
+ pil_image = pil_image.resize(resize)
472
+
473
+ if fmt == "pil":
474
+ return pil_image
475
+
476
+ mask = np.array(pil_image, copy=True)
477
+ return mask if fmt == "np" else torch.from_numpy(mask)
478
+
479
+
480
+ def _store_binary_mask(
481
+ fname: str | Path, img_data: np.ndarray | torch.Tensor | Image.Image, **kwargs
482
+ ) -> None:
483
+ """
484
+ Stores a binary image in a compressed image file.
485
+
486
+ Args:
487
+ fname (str or Path): The filename to store the binary image in.
488
+ img_data (numpy.ndarray, torch.tensor or PIL.Image.Image): The binary image data to store.
489
+ """
490
+ if isinstance(img_data, Image.Image):
491
+ if img_data.mode not in ["1", "L"]:
492
+ raise RuntimeError(
493
+ f'Expected a PIL image with mode "1" or "L", but instead got a PIL image with mode {img_data.mode}'
494
+ )
495
+ elif isinstance(img_data, np.ndarray) or isinstance(img_data, torch.Tensor):
496
+ if len(img_data.squeeze().shape) != 2:
497
+ raise RuntimeError(
498
+ f"Expected a PyTorch tensor or NumPy array with shape (H, W, 1), (1, H, W) or (H, W), but the shape is {img_data.shape}"
499
+ )
500
+ img_data = img_data.squeeze()
501
+ else:
502
+ raise NotImplementedError(f"Input format not supported: {type(img_data)}")
503
+
504
+ if not isinstance(img_data, Image.Image):
505
+ img_data = to_numpy(img_data, dtype=bool)
506
+ img_data = Image.fromarray(img_data)
507
+
508
+ img_data = img_data.convert("1")
509
+ with open(fname, "wb") as f:
510
+ img_data.save(f, compress_level=1, optimize=False)
511
+
512
+
513
+ def _load_sft(
514
+ fname: str | Path,
515
+ fmt: str = "torch",
516
+ **kwargs,
517
+ ) -> torch.Tensor:
518
+ """
519
+ Loads a tensor from a safetensor file.
520
+
521
+ Args:
522
+ fname (str | Path): The filename of the safetensor file to load.
523
+ fmt (str, optional): The format of the output data. Currently only "torch" is supported.
524
+ **kwargs: Additional keyword arguments (unused).
525
+
526
+ Returns:
527
+ torch.Tensor: The loaded tensor.
528
+
529
+ Raises:
530
+ AssertionError: If the file extension is not .sft or if fmt is not "torch".
531
+ """
532
+ assert Path(fname).suffix == ".sft", "Only .sft (safetensor) is supported"
533
+ assert fmt == "torch", "Only torch format is supported for latent"
534
+ out = load_sft(str(fname))
535
+ return out["latent"]
536
+
537
+
538
+ def _store_sft(fname: str | Path, data: torch.Tensor, **kwargs) -> None:
539
+ """
540
+ Stores a tensor to a safetensor file.
541
+
542
+ Args:
543
+ fname (str | Path): The filename to store the latent in.
544
+ data (torch.Tensor): The latent tensor to store.
545
+ **kwargs: Additional keyword arguments (unused).
546
+
547
+ Raises:
548
+ AssertionError: If the file extension is not .sft or if data is not a torch.Tensor.
549
+ """
550
+ assert Path(fname).suffix == ".sft", "Only .sft (safetensor) is supported"
551
+ assert isinstance(data, torch.Tensor)
552
+ save_sft(tensors={"latent": data}, filename=str(fname))
553
+
554
+
555
+ def _store_depth(fname: str | Path, data: np.ndarray | torch.Tensor, **kwargs) -> bool:
556
+ """
557
+ Stores a depth map in an EXR file.
558
+
559
+ Args:
560
+ fname (str or Path): The filename to save the depth map to.
561
+ data (numpy.ndarray, torch.tensor): The depth map to save.
562
+
563
+ Returns:
564
+ bool: True if the depth map was saved successfully, False otherwise.
565
+
566
+ Raises:
567
+ ValueError: If the input data does not have two dimensions after removing singleton dimensions.
568
+ """
569
+ data_np = to_numpy(data, dtype=np.float32)
570
+ data_np = data_np.squeeze() # remove all 1-dim entries
571
+ if data_np.ndim != 2:
572
+ raise ValueError(f"Depth image needs to be 2d, but received: {data_np.shape}")
573
+
574
+ if "params" in kwargs:
575
+ params = kwargs["params"]
576
+ else:
577
+ # use 16-bit with zip compression for depth maps
578
+ params = [
579
+ cv2.IMWRITE_EXR_TYPE,
580
+ cv2.IMWRITE_EXR_TYPE_HALF,
581
+ cv2.IMWRITE_EXR_COMPRESSION,
582
+ cv2.IMWRITE_EXR_COMPRESSION_ZIP,
583
+ ]
584
+
585
+ return _write_exr(fname, data_np, params=params)
586
+
587
+
588
+ def _load_depth(
589
+ fname: str | Path, fmt: str = "torch", **kwargs
590
+ ) -> np.ndarray | torch.Tensor | Image.Image:
591
+ """
592
+ Loads a depth image from an EXR file.
593
+
594
+ Args:
595
+ fname (str or Path): The filename of the EXR file to load.
596
+ fmt (str): The format of the output data. Can be one of:
597
+ - "torch": Returns a PyTorch tensor.
598
+ - "np": Returns a NumPy array.
599
+ - "PIL": Returns a PIL Image object.
600
+ Defaults to "torch".
601
+
602
+ Returns:
603
+ The loaded depth image in the specified output format.
604
+
605
+ Raises:
606
+ ValueError: If the loaded depth image does not have two dimensions.
607
+
608
+ Notes:
609
+ This function assumes that the EXR file contains a single-channel depth image.
610
+ """
611
+ data = _read_exr(fname, fmt)
612
+ if (fmt != "PIL") and (data.ndim != 2):
613
+ raise ValueError(f"Depth image needs to be 2D, but loaded: {data.shape}")
614
+ return data
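A small round-trip sketch for the depth helpers (values are synthetic; note that the default half-float EXR encoding only preserves depth approximately):

import tempfile
from pathlib import Path

import torch

depth = torch.rand(240, 320) * 10.0  # hypothetical metric depth map

with tempfile.TemporaryDirectory() as tmp:
    depth_path = Path(tmp) / "depth_000000.exr"
    assert _store_depth(depth_path, depth)          # half-float, zip-compressed EXR
    reloaded = _load_depth(depth_path, fmt="torch")
    assert reloaded.shape == depth.shape
    # Half precision keeps roughly 3 significant digits, so compare loosely.
    assert torch.allclose(reloaded, depth, rtol=1e-2, atol=1e-2)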
615
+
616
+
617
+ def _store_normals(
618
+ fname: str | Path, data: np.ndarray | torch.Tensor, **kwargs
619
+ ) -> bool:
620
+ """
621
+ Stores a normals image in an EXR file.
622
+
623
+ Args:
624
+ fname (str or Path): The filename to save the normals image to.
625
+ data (numpy.ndarray): The normals image data to save. Will be converted to a 32-bit float array.
626
+
627
+ Returns:
628
+ bool: True if the normals image was saved successfully, False otherwise.
629
+
630
+ Raises:
631
+ ValueError: If the input data has more than three dimensions after removing singleton dimensions.
632
+ ValueError: If the input data does not have exactly three channels.
633
+ ValueError: If the input data is not normalized (i.e., maximum absolute value exceeds 1).
634
+
635
+ Notes:
636
+ This function assumes that the input data is in HWC (height, width, channels) format.
637
+ If the input data is in CHW (channels, height, width) format, it will be automatically transposed to HWC.
638
+ """
639
+ data_np = to_numpy(data, dtype=np.float32)
640
+ data_np = data_np.squeeze() # remove all singleton dimensions
641
+
642
+ if data_np.ndim != 3:
643
+ raise ValueError(
644
+ f"Normals image needs to be 3-dim but received: {data_np.shape}"
645
+ )
646
+
647
+ if (data_np.shape[0] == 3) and (data_np.shape[2] != 3):
648
+ # ensure HWC format
649
+ data_np = data_np.transpose(1, 2, 0)
650
+
651
+ if data_np.shape[2] != 3:
652
+ raise ValueError(
653
+ f"Normals image needs have 3 channels but received: {data_np.shape}"
654
+ )
655
+
656
+ # We want to check that the norm values are either 1 (valid) or 0 (invalid values are 0s)
657
+ norm = np.linalg.norm(data_np, axis=-1)
658
+ is_one = np.isclose(norm, 1.0, atol=1e-3)
659
+ is_zero = np.isclose(norm, 0.0)
660
+ if not np.all([is_one | is_zero]):
661
+ raise ValueError("Normals image must be normalized")
662
+
663
+ return _write_exr(fname, data_np)
664
+
665
+
666
+ def _load_normals(
667
+ fname: str | Path, fmt: str = "torch", **kwargs
668
+ ) -> np.ndarray | torch.Tensor | Image.Image:
669
+ """
670
+ Loads a normals image from an EXR file.
671
+
672
+ Args:
673
+ fname (str or Path): The filename of the EXR file to load.
674
+ fmt (str): The format of the output data. Can be one of:
675
+ - "torch": Returns a PyTorch tensor.
676
+ - "np": Returns a NumPy array.
677
+ - "PIL": Returns a PIL Image object.
678
+ Defaults to "torch".
679
+
680
+ Returns:
681
+ The loaded normals image in the specified output format.
682
+
683
+ Raises:
684
+ ValueError: If the loaded normals image is not a 3-dimensional, 3-channel array.
685
+
686
+ Notes:
687
+ This function assumes that the EXR file contains a 3-channel normals image.
688
+ """
689
+ data = _read_exr(fname, fmt)
690
+
691
+ if data.ndim != 3:
692
+ raise ValueError(f"Normals image needs to be 3-dim but received: {data.shape}")
693
+
694
+ if data.shape[2] != 3:
695
+ raise ValueError(
696
+ f"Normals image needs have 3 channels but received: {data.shape}"
697
+ )
698
+
699
+ return data
700
+
701
+
702
+ def _load_numpy(fname: str | Path, allow_pickle: bool = False, **kwargs) -> np.ndarray:
703
+ """
704
+ Loads a NumPy array from a file.
705
+
706
+ Args:
707
+ fname (str or Path): The filename to load the NumPy array from.
708
+ allow_pickle (bool, optional): Whether to allow pickled objects in the NumPy file.
709
+ Defaults to False.
710
+
711
+ Returns:
712
+ numpy.ndarray: The loaded NumPy array.
713
+
714
+ Raises:
715
+ NotImplementedError: If the file suffix is not supported (i.e., not .npy or .npz).
716
+
717
+ Notes:
718
+ This function supports loading NumPy arrays from .npy and .npz files.
719
+ For .npz files, it assumes that the array is stored under the key "arr_0".
720
+ """
721
+ fname = Path(fname)
722
+ with open(fname, "rb") as fid:
723
+ if fname.suffix == ".npy":
724
+ return np.load(fid, allow_pickle=allow_pickle)
725
+ elif fname.suffix == ".npz":
726
+ return np.load(fid, allow_pickle=allow_pickle).get("arr_0")
727
+ else:
728
+ raise NotImplementedError(f"Numpy format not supported: {fname.suffix}")
729
+
730
+
731
+ def _store_numpy(fname: str | Path, data: np.ndarray, **kwargs) -> None:
732
+ """
733
+ Stores a NumPy array in a file.
734
+
735
+ Args:
736
+ fname (str or Path): The filename to store the NumPy array in.
737
+ data (numpy.ndarray): The NumPy array to store.
738
+
739
+ Raises:
740
+ NotImplementedError: If the file suffix is not supported (i.e., not .npy or .npz).
741
+
742
+ Notes:
743
+ This function supports storing NumPy arrays in .npy and .npz files.
744
+ For .npz files, it uses compression to reduce the file size.
745
+ """
746
+ fname = Path(fname)
747
+ with open(fname, "wb") as fid:
748
+ if fname.suffix == ".npy":
749
+ np.save(fid, data)
750
+ elif fname.suffix == ".npz":
751
+ np.savez_compressed(fid, arr_0=data)
752
+ else:
753
+ raise NotImplementedError(f"Numpy format not supported: {fname.suffix}")
754
+
755
+
756
+ def _load_ptz(fname: str | Path, **kwargs) -> torch.Tensor:
757
+ """
758
+ Loads a PyTorch tensor from a PTZ file.
759
+
760
+ Args:
761
+ fname (str or Path): The filename to load the tensor from.
762
+
763
+ Returns:
764
+ torch.Tensor: The loaded PyTorch tensor.
765
+
766
+ Notes:
767
+ This function assumes that the PTZ file contains a PyTorch tensor saved using `torch.save`.
768
+ If the tensor was saved in a different format, this function may fail.
769
+ """
770
+ with open(fname, "rb") as fid:
771
+ data = gzip.decompress(fid.read())
772
+ ## Note: if the following line fails, save PyTorch tensors in PTZ instead of NumPy
773
+ return torch.load(io.BytesIO(data), map_location="cpu", weights_only=True)
774
+
775
+
776
+ def _store_ptz(fname: str | Path, data: torch.Tensor, **kwargs) -> None:
777
+ """
778
+ Stores a PyTorch tensor in a PTZ file.
779
+
780
+ Args:
781
+ fname (str or Path): The filename to store the tensor in.
782
+ data (torch.Tensor): The PyTorch tensor to store.
783
+
784
+ Notes:
785
+ This function saves the tensor using `torch.save` and compresses it using gzip.
786
+ """
787
+ with open(fname, "wb") as fid:
788
+ with gzip.open(fid, "wb") as gfid:
789
+ torch.save(data, gfid)
790
+
791
+
792
+ def _store_mmap(fname: str | Path, data: np.ndarray | torch.Tensor, **kwargs) -> str:
793
+ """
794
+ Stores matrix-shaped data in a memory-mapped file.
795
+
796
+ Args:
797
+ fname (str or Path): The filename to store the data in.
798
+ data (numpy.ndarray): The matrix-shaped data to store.
799
+
800
+ Returns:
801
+ str: The name of the stored memory-mapped file.
802
+
803
+ Notes:
804
+ This function stores the data in a .npy file with a modified filename that includes the shape of the data.
805
+ The data is converted to float32 format before storing.
806
+ """
807
+ fname = Path(fname)
808
+ # add dimensions to the file name for loading
809
+ data_np = to_numpy(data, dtype=np.float32)
810
+ shape_string = "x".join([str(dim) for dim in data_np.shape])
811
+ mmap_name = f"{fname.stem}--{shape_string}.npy"
812
+ with open(fname.parent / mmap_name, "wb") as fid:
813
+ np.save(fid, data_np)
814
+ return mmap_name
815
+
816
+
817
+ def _load_mmap(fname: str | Path, **kwargs) -> np.memmap:
818
+ """
819
+ Loads matrix-shaped data from a memory-mapped file.
820
+
821
+ Args:
822
+ fname (str or Path): The filename of the memory-mapped file to load.
823
+
824
+ Returns:
825
+ numpy.memmap: A memory-mapped array containing the loaded data.
826
+
827
+ Notes:
828
+ This function assumes that the filename contains the shape of the data, separated by 'x' or ','.
829
+ It uses this information to create a memory-mapped array with the correct shape.
830
+ """
831
+ shape_string = Path(Path(fname).name.split("--")[1]).stem
832
+ shape = [int(dim) for dim in shape_string.replace(",", "x").split("x")]
833
+ with open(fname, "rb") as fid:
834
+ return np.memmap(fid, dtype=np.float32, mode="r", shape=shape, offset=128)
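The memory-mapped helpers encode the array shape in the file name itself, so a round trip looks like this (paths and data are illustrative):

import tempfile
from pathlib import Path

import numpy as np

points = np.random.rand(1000, 3).astype(np.float32)  # hypothetical point data

with tempfile.TemporaryDirectory() as tmp:
    # _store_mmap appends the shape to the stem, e.g. "points--1000x3.npy"
    mmap_name = _store_mmap(Path(tmp) / "points.npy", points)
    reloaded = _load_mmap(Path(tmp) / mmap_name)      # read-only np.memmap
    assert reloaded.shape == (1000, 3)
    assert np.allclose(np.asarray(reloaded), points)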
835
+
836
+
837
+ def _store_scene_meta(fname: Path | str, scene_meta: dict[str, Any], **kwargs) -> None:
838
+ """
839
+ Stores scene metadata in a readable file.
840
+
841
+ Args:
842
+ fname (str or Path): The filename to store the scene metadata in.
843
+ scene_meta (dict): The scene metadata to store.
844
+
845
+ Notes:
846
+ This function updates the "last_modified" field of the scene metadata to the current date and time before storing it.
847
+ It also removes the "frame_names" field from the scene metadata, as it is not necessary to store this information.
848
+ Creates a backup of the existing file before overwriting it.
849
+ """
850
+ # update the modified date
851
+ scene_meta["last_modified"] = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
852
+ if "frame_names" in scene_meta:
853
+ del scene_meta["frame_names"]
854
+
855
+ # create/overwrite backup
856
+ fname_path = Path(fname)
857
+ if fname_path.exists():
858
+ backup_fname = fname_path.parent / f"_{fname_path.stem}_backup.json"
859
+ if backup_fname.exists():
860
+ backup_fname.unlink()
861
+ fname_path.rename(backup_fname)
862
+
863
+ _store_readable(fname, scene_meta)
864
+
865
+
866
+ def _load_scene_meta(fname: Path | str, **kwargs) -> dict[str, Any]:
867
+ """
868
+ Loads scene metadata from a readable file.
869
+
870
+ Args:
871
+ fname (str or Path): The filename to load the scene metadata from.
872
+
873
+ Returns:
874
+ dict: The loaded scene metadata, including an additional "frame_names" field that maps frame names to their indices.
875
+
876
+ Notes:
877
+ This function creates the "frame_names" field in the scene metadata for efficient lookup of frame indices by name.
878
+ """
879
+ scene_meta = _load_readable_structured(fname)
880
+ # create the frame_name -> frame_idx for efficiency
881
+ scene_meta["frame_names"] = {
882
+ frame["frame_name"]: frame_idx
883
+ for frame_idx, frame in enumerate(scene_meta["frames"])
884
+ }
885
+ return scene_meta
886
+
887
+
888
+ def _load_labeled_image(
889
+ fname: str | Path,
890
+ fmt: str = "torch",
891
+ resize: tuple[int, int] | None = None,
892
+ **kwargs,
893
+ ) -> np.ndarray | torch.Tensor | Image.Image:
894
+ """
895
+ Loads a labeled image from a PNG file.
896
+
897
+ Args:
898
+ fname (str or Path): The filename to load the image from.
899
+ fmt (str): The format of the output data. Can be one of:
900
+ - "torch": Returns a PyTorch int32 tensor with shape (H, W).
901
+ - "np": Returns a NumPy int32 array with shape (H, W).
902
+ - "pil": Returns a PIL Image object.
903
+ Defaults to "torch".
904
+ resize (tuple, optional): A tuple of two integers representing the desired width and height of the image.
905
+ If None, the image is not resized. Defaults to None.
906
+
907
+ Returns:
908
+ The loaded image in the specified output format.
909
+
910
+ Raises:
911
+ NotImplementedError: If the specified output format is not supported.
912
+ RuntimeError: If the 'id_to_color_mapping' is missing in the PNG metadata.
913
+
914
+ Notes:
915
+ The function expects the PNG file to contain metadata with a key 'id_to_color_mapping',
916
+ which maps from label ids to tuples of RGB values.
917
+ """
918
+ with open(fname, "rb") as f:
919
+ pil_image = Image.open(f)
920
+ pil_image.load()
921
+ if pil_image.mode != "RGB":
922
+ raise OSError(
923
+ f"Expected an RGB image in {fname}, but instead found an image with mode {pil_image.mode}"
924
+ )
925
+
926
+ # Load id to RGB mapping
927
+ color_palette_json = pil_image.info.get("id_to_color_mapping", None)
928
+ if color_palette_json is None:
929
+ raise RuntimeError("'id_to_color_mapping' is missing in the PNG metadata.")
930
+ color_palette = json.loads(color_palette_json)
931
+ color_to_id_mapping = {
932
+ tuple(color): int(id) for id, color in color_palette.items()
933
+ }
934
+
935
+ if resize is not None:
936
+ pil_image = pil_image.resize(resize, Image.NEAREST)
937
+
938
+ if fmt == "pil":
939
+ return pil_image
940
+
941
+ # Reverse the color mapping: map from RGB colors to ids
942
+ img_data = np.array(pil_image)
943
+
944
+ # Create a lookup table for fast mapping
945
+ max_color_value = 256 # Assuming 8-bit per channel
946
+ lookup_table = np.full(
947
+ (max_color_value, max_color_value, max_color_value),
948
+ INVALID_ID,
949
+ dtype=np.int32,
950
+ )
951
+ for color, index in color_to_id_mapping.items():
952
+ lookup_table[color] = index
953
+ # Map colors to ids using the lookup table
954
+ img_data = lookup_table[img_data[..., 0], img_data[..., 1], img_data[..., 2]]
955
+
956
+ if fmt == "np":
957
+ return img_data
958
+ elif fmt == "torch":
959
+ return torch.from_numpy(img_data)
960
+ else:
961
+ raise NotImplementedError(f"Image format not supported: {fmt}")
962
+
963
+
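The color-to-id decoding above avoids a per-pixel dictionary lookup by building a dense 256x256x256 lookup table indexed by RGB value. A standalone sketch of that trick with a made-up two-color palette (the value -1 stands in for INVALID_ID):

    import numpy as np

    color_to_id = {(255, 0, 0): 1, (0, 255, 0): 2}          # hypothetical palette
    lut = np.full((256, 256, 256), -1, dtype=np.int32)      # -1 marks unknown colors
    for color, label_id in color_to_id.items():
        lut[color] = label_id

    rgb = np.array([[[255, 0, 0], [0, 255, 0], [7, 7, 7]]], dtype=np.uint8)
    labels = lut[rgb[..., 0], rgb[..., 1], rgb[..., 2]]      # -> [[1, 2, -1]]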
964
+ def _store_labeled_image(
965
+ fname: str | Path,
966
+ img_data: np.ndarray | torch.Tensor | Image.Image,
967
+ semantic_color_mapping: np.ndarray | None = None,
968
+ **kwargs,
969
+ ) -> None:
970
+ """
971
+ Stores a labeled image as a uint8 RGB PNG file.
972
+
973
+ Args:
974
+ fname (str or Path): The filename to store the image in.
975
+ img_data (numpy.ndarray, torch.Tensor or PIL.Image.Image): The per-pixel label ids to store.
976
+ semantic_color_mapping (np.ndarray): Optional, preloaded NumPy array of semantic colors.
977
+
978
+ Raises:
979
+ ValueError: If the file suffix is not supported (i.e., not .png).
980
+ RuntimeError: If the type of the image data is different from uint16, int16 or int32.
981
+
982
+ Notes:
983
+ The function takes an image with per-pixel label ids and converts it into an RGB image
984
+ using a specified mapping from label ids to RGB colors. The resulting image is saved as
985
+ a PNG file, with the mapping stored as metadata.
986
+ """
987
+ if Path(fname).suffix != ".png":
988
+ raise ValueError(
989
+ f"Only filenames with suffix .png allowed but received: {fname}"
990
+ )
991
+
992
+ if isinstance(img_data, Image.Image) and img_data.mode != "I;16":
993
+ raise RuntimeError(
994
+ f"The provided image does not seem to be a labeled image. The provided PIL image has mode {img_data.mode}."
995
+ )
996
+
997
+ if isinstance(img_data, np.ndarray) and img_data.dtype not in [
998
+ np.uint16,
999
+ np.int16,
1000
+ np.int32,
1001
+ ]:
1002
+ raise RuntimeError(
1003
+ f"The provided NumPy array has type {img_data.dtype} but the expected type is np.uint16, np.int16 or np.int32."
1004
+ )
1005
+
1006
+ if isinstance(img_data, torch.Tensor):
1007
+ if img_data.dtype not in [torch.uint16, torch.int16, torch.int32]:
1008
+ raise RuntimeError(
1009
+ f"The provided PyTorch tensor has type {img_data.dtype} but the expected type is torch.uint16, torch.int16 or torch.int32."
1010
+ )
1011
+ img_data = img_data.numpy()
1012
+
1013
+ if semantic_color_mapping is None:
1014
+ # Mapping from ids to colors not provided, load it now
1015
+ semantic_color_mapping = load_semantic_color_mapping()
1016
+
1017
+ img_data, color_palette = apply_id_to_color_mapping(
1018
+ img_data, semantic_color_mapping
1019
+ )
1020
+ pil_image = Image.fromarray(img_data, "RGB")
1021
+
1022
+ # Create a PngInfo object to store metadata
1023
+ meta = PngImagePlugin.PngInfo()
1024
+ meta.add_text("id_to_color_mapping", json.dumps(color_palette))
1025
+
1026
+ pil_image.save(fname, pnginfo=meta)
1027
+
1028
+
1029
+ def _load_generic_mesh(mesh_path: str | Path, **kwargs) -> trimesh.Trimesh:
1030
+ """Load mesh with the trimesh library.
1031
+
1032
+ Args:
1033
+ mesh_path (str): Path to the mesh file
1034
+
1035
+ Returns:
1036
+ The trimesh object from trimesh.load().
1037
+
1038
+ Raises:
1039
+ ValueError: If the file format is not supported.
1040
+ """
1041
+
1042
+ # needed to load big texture files
1043
+ Image.MAX_IMAGE_PIXELS = None
1044
+
1045
+ # load mesh with trimesh
1046
+ mesh_data = trimesh.load(mesh_path, process=False)
1047
+
1048
+ return mesh_data
1049
+
1050
+
1051
+ def _store_generic_mesh(
1052
+ file_path: str | Path, mesh_data: dict | trimesh.Trimesh, **kwargs
1053
+ ) -> None:
1054
+ """
1055
+ Dummy function for storing generic mesh data.
1056
+
1057
+ Args:
1058
+ file_path (str): The filename to store the mesh in.
1059
+ mesh_data (dict): Dictionary containing mesh data.
1060
+ **kwargs: Additional keyword arguments.
1061
+
1062
+ Raises:
1063
+ NotImplementedError: This function is not implemented yet.
1064
+ """
1065
+ raise NotImplementedError("Storing generic meshes is not implemented yet.")
1066
+
1067
+
1068
+ def _load_labeled_mesh(
1069
+ file_path: str | Path,
1070
+ fmt: str = "torch",
1071
+ palette: str = "rgb",
1072
+ **kwargs,
1073
+ ) -> dict | trimesh.Trimesh:
1074
+ """
1075
+ Loads a mesh from a labeled mesh file (PLY binary format).
1076
+
1077
+ Args:
1078
+ file_path (str): The path to the labeled mesh file (.ply).
1079
+ fmt (str): Output format of the mesh data. Can be one of:
1080
+ - "torch": Returns a dict of PyTorch tensors containing mesh data.
1081
+ - "np": Returns a dict of NumPy arrays containing mesh data.
1082
+ - "trimesh": Returns a trimesh mesh object.
1083
+ Defaults to "torch".
1084
+ palette (str): Output color of the trimesh mesh data. Can be one of:
1085
+ - "rgb": Colors the mesh with original rgb colors
1086
+ - "semantic_class": Colors the mesh with semantic class colors
1087
+ - "instance": Colors the mesh with semantic instance colors
1088
+ Applied only when fmt is "trimesh".
1089
+
1090
+ Returns:
1091
+ The loaded mesh in the specified output format.
1092
+
1093
+ Raises:
1094
+ NotImplementedError: If the specified output format is not supported.
1095
+
1096
+ Notes:
1097
+ This function reads a binary PLY file with vertex position, color, and optional
1098
+ semantic class and instance IDs. The faces are stored as lists of vertex indices.
1099
+ """
1100
+ # load data (NOTE: define known_list_len to enable faster read)
1101
+ ply_data = PlyData.read(file_path, known_list_len={"face": {"vertex_indices": 3}})
1102
+
1103
+ # get vertices
1104
+ vertex_data = ply_data["vertex"].data
1105
+ vertices = np.column_stack(
1106
+ (vertex_data["x"], vertex_data["y"], vertex_data["z"])
1107
+ ).astype(np.float32)
1108
+
1109
+ # initialize output data
1110
+ mesh_data = {}
1111
+ mesh_data["is_labeled_mesh"] = True
1112
+ mesh_data["vertices"] = vertices
1113
+
1114
+ # get faces if available
1115
+ if "face" in ply_data:
1116
+ faces = np.asarray(ply_data["face"].data["vertex_indices"]).astype(np.int32)
1117
+ mesh_data["faces"] = faces
1118
+
1119
+ # get rgb colors if available
1120
+ if all(color in vertex_data.dtype.names for color in ["red", "green", "blue"]):
1121
+ vertices_color = np.column_stack(
1122
+ (vertex_data["red"], vertex_data["green"], vertex_data["blue"])
1123
+ ).astype(np.uint8)
1124
+ mesh_data["vertices_color"] = vertices_color
1125
+
1126
+ # get vertices class and instance if available
1127
+ if "semantic_class_id" in vertex_data.dtype.names:
1128
+ vertices_class = vertex_data["semantic_class_id"].astype(np.int32)
1129
+ mesh_data["vertices_semantic_class_id"] = vertices_class
1130
+
1131
+ if "instance_id" in vertex_data.dtype.names:
1132
+ vertices_instance = vertex_data["instance_id"].astype(np.int32)
1133
+ mesh_data["vertices_instance_id"] = vertices_instance
1134
+
1135
+ # get class colors if available
1136
+ if all(
1137
+ color in vertex_data.dtype.names
1138
+ for color in [
1139
+ "semantic_class_red",
1140
+ "semantic_class_green",
1141
+ "semantic_class_blue",
1142
+ ]
1143
+ ):
1144
+ vertices_semantic_class_color = np.column_stack(
1145
+ (
1146
+ vertex_data["semantic_class_red"],
1147
+ vertex_data["semantic_class_green"],
1148
+ vertex_data["semantic_class_blue"],
1149
+ )
1150
+ ).astype(np.uint8)
1151
+ mesh_data["vertices_semantic_class_color"] = vertices_semantic_class_color
1152
+
1153
+ # get instance colors if available
1154
+ if all(
1155
+ color in vertex_data.dtype.names
1156
+ for color in ["instance_red", "instance_green", "instance_blue"]
1157
+ ):
1158
+ vertices_instance_color = np.column_stack(
1159
+ (
1160
+ vertex_data["instance_red"],
1161
+ vertex_data["instance_green"],
1162
+ vertex_data["instance_blue"],
1163
+ )
1164
+ ).astype(np.uint8)
1165
+ mesh_data["vertices_instance_color"] = vertices_instance_color
1166
+
1167
+ # convert data into output format (if needed)
1168
+ if fmt == "np":
1169
+ return mesh_data
1170
+ elif fmt == "torch":
1171
+ return {k: torch.tensor(v) for k, v in mesh_data.items()}
1172
+ elif fmt == "trimesh":
1173
+ trimesh_mesh = trimesh.Trimesh(
1174
+ vertices=mesh_data["vertices"], faces=mesh_data["faces"]
1175
+ )
1176
+ # color the mesh according to the palette
1177
+ if palette == "rgb":
1178
+ # original rgb colors
1179
+ if "vertices_color" in mesh_data:
1180
+ trimesh_mesh.visual.vertex_colors = mesh_data["vertices_color"]
1181
+ else:
1182
+ raise ValueError(
1183
+ f"Palette {palette} could not be applied. Missing vertices_color in mesh data."
1184
+ )
1185
+ elif palette == "semantic_class":
1186
+ # semantic class colors
1187
+ if "vertices_semantic_class_color" in mesh_data:
1188
+ trimesh_mesh.visual.vertex_colors = mesh_data[
1189
+ "vertices_semantic_class_color"
1190
+ ]
1191
+ else:
1192
+ raise ValueError(
1193
+ f"Palette {palette} could not be applied. Missing vertices_semantic_class_color in mesh data."
1194
+ )
1195
+ elif palette == "instance":
1196
+ # semantic instance colors
1197
+ if "vertices_instance_color" in mesh_data:
1198
+ trimesh_mesh.visual.vertex_colors = mesh_data["vertices_instance_color"]
1199
+ else:
1200
+ raise ValueError(
1201
+ f"Palette {palette} could not be applied. Missing vertices_instance_color in mesh data."
1202
+ )
1203
+ else:
1204
+ raise ValueError(f"Invalid palette: {palette}.")
1205
+ return trimesh_mesh
1206
+ else:
1207
+ raise NotImplementedError(f"Labeled mesh format not supported: {fmt}")
1208
+
1209
+
1210
+ def _store_labeled_mesh(file_path: str | Path, mesh_data: dict, **kwargs) -> None:
1211
+ """
1212
+ Stores a mesh in WAI format (PLY binary format).
1213
+
1214
+ Args:
1215
+ file_path (str): The filename to store the mesh in.
1216
+ mesh_data (dict): Dictionary containing mesh data with keys:
1217
+ - 'vertices' (numpy.ndarray): Array of vertex coordinates with shape (N, 3).
1218
+ - 'faces' (numpy.ndarray, optional): Array of face indices.
1219
+ - 'vertices_color' (numpy.ndarray, optional): Array of vertex colors with shape (N, 3).
1220
+ - 'vertices_semantic_class_id' (numpy.ndarray, optional): Array of semantic classes for each vertex with shape (N).
1221
+ - 'vertices_instance_id' (numpy.ndarray, optional): Array of instance IDs for each vertex with shape (N).
1222
+ - 'vertices_semantic_class_color' (numpy.ndarray, optional): Array of vertex semantic class colors with shape (N, 3).
1223
+ - 'vertices_instance_color' (numpy.ndarray, optional): Array of vertex instance colors with shape (N, 3).
1224
+
1225
+ Notes:
1226
+ This function writes a binary PLY file with vertex position, color, and optional
1227
+ semantic class and instance IDs. The faces are stored as lists of vertex indices.
1228
+ """
1229
+ # Validate input data
1230
+ if "vertices" not in mesh_data:
1231
+ raise ValueError("Mesh data must contain 'vertices'")
1232
+
1233
+ # create vertex data with properties
1234
+ vertex_dtype = [("x", "f4"), ("y", "f4"), ("z", "f4")]
1235
+ if "vertices_color" in mesh_data:
1236
+ vertex_dtype.extend([("red", "u1"), ("green", "u1"), ("blue", "u1")])
1237
+ if "vertices_semantic_class_id" in mesh_data:
1238
+ vertex_dtype.append(("semantic_class_id", "i4"))
1239
+ if "vertices_instance_id" in mesh_data:
1240
+ vertex_dtype.append(("instance_id", "i4"))
1241
+ if "vertices_semantic_class_color" in mesh_data:
1242
+ vertex_dtype.extend(
1243
+ [
1244
+ ("semantic_class_red", "u1"),
1245
+ ("semantic_class_green", "u1"),
1246
+ ("semantic_class_blue", "u1"),
1247
+ ]
1248
+ )
1249
+ if "vertices_instance_color" in mesh_data:
1250
+ vertex_dtype.extend(
1251
+ [("instance_red", "u1"), ("instance_green", "u1"), ("instance_blue", "u1")]
1252
+ )
1253
+ vertex_count = len(mesh_data["vertices"])
1254
+ vertex_data = np.zeros(vertex_count, dtype=vertex_dtype)
1255
+
1256
+ # vertex positions
1257
+ vertex_data["x"] = mesh_data["vertices"][:, 0]
1258
+ vertex_data["y"] = mesh_data["vertices"][:, 1]
1259
+ vertex_data["z"] = mesh_data["vertices"][:, 2]
1260
+
1261
+ # vertex colors
1262
+ if "vertices_color" in mesh_data:
1263
+ vertex_data["red"] = mesh_data["vertices_color"][:, 0]
1264
+ vertex_data["green"] = mesh_data["vertices_color"][:, 1]
1265
+ vertex_data["blue"] = mesh_data["vertices_color"][:, 2]
1266
+
1267
+ # vertex class
1268
+ if "vertices_semantic_class_id" in mesh_data:
1269
+ vertex_data["semantic_class_id"] = mesh_data["vertices_semantic_class_id"]
1270
+
1271
+ # vertex instance
1272
+ if "vertices_instance_id" in mesh_data:
1273
+ vertex_data["instance_id"] = mesh_data["vertices_instance_id"]
1274
+
1275
+ # vertex class colors
1276
+ if "vertices_semantic_class_color" in mesh_data:
1277
+ vertex_data["semantic_class_red"] = mesh_data["vertices_semantic_class_color"][
1278
+ :, 0
1279
+ ]
1280
+ vertex_data["semantic_class_green"] = mesh_data[
1281
+ "vertices_semantic_class_color"
1282
+ ][:, 1]
1283
+ vertex_data["semantic_class_blue"] = mesh_data["vertices_semantic_class_color"][
1284
+ :, 2
1285
+ ]
1286
+
1287
+ # vertex instance colors
1288
+ if "vertices_instance_color" in mesh_data:
1289
+ vertex_data["instance_red"] = mesh_data["vertices_instance_color"][:, 0]
1290
+ vertex_data["instance_green"] = mesh_data["vertices_instance_color"][:, 1]
1291
+ vertex_data["instance_blue"] = mesh_data["vertices_instance_color"][:, 2]
1292
+
1293
+ # initialize data to save
1294
+ vertex_element = PlyElement.describe(vertex_data, "vertex")
1295
+ data_to_save = [vertex_element]
1296
+
1297
+ # faces data
1298
+ if "faces" in mesh_data:
1299
+ face_dtype = [("vertex_indices", "i4", (3,))]
1300
+ face_data = np.zeros(len(mesh_data["faces"]), dtype=face_dtype)
1301
+ face_data["vertex_indices"] = mesh_data["faces"]
1302
+ face_element = PlyElement.describe(face_data, "face")
1303
+ data_to_save.append(face_element)
1304
+
1305
+ # Create and write a binary PLY file
1306
+ ply_data = PlyData(data_to_save, text=False)
1307
+ ply_data.write(file_path)
1308
+
1309
+
1310
+ def _get_method(
1311
+ fname: Path | str, format_type: str | None = None, load: bool = True
1312
+ ) -> Callable:
1313
+ """
1314
+ Returns a method for loading or storing data in a specific format.
1315
+
1316
+ Args:
1317
+ fname (str or Path): The filename to load or store data from/to.
1318
+ format_type (str, optional): The format of the data. If None, it will be inferred from the file extension.
1319
+ Defaults to None.
1320
+ load (bool, optional): Whether to return a method for loading or storing data.
1321
+ Defaults to True.
1322
+
1323
+ Returns:
1324
+ callable: A method for loading or storing data in the specified format.
1325
+
1326
+ Raises:
1327
+ ValueError: If the format cannot be inferred from the file extension.
1328
+ NotImplementedError: If the specified format is not supported.
1329
+
1330
+ Notes:
1331
+ This function supports various formats, including readable files (JSON, YAML), images, NumPy arrays,
1332
+ PyTorch tensors, memory-mapped files, and scene metadata.
1333
+ """
1334
+ fname = Path(fname)
1335
+ if format_type is None:
1336
+ # use default formats
1337
+ if fname.suffix in [".json", ".yaml", ".yml"]:
1338
+ format_type = "readable"
1339
+ elif fname.suffix in [".jpg", ".jpeg", ".png", ".webp"]:
1340
+ format_type = "image"
1341
+ elif fname.suffix in [".npy", ".npz"]:
1342
+ format_type = "numpy"
1343
+ elif fname.suffix == ".ptz":
1344
+ format_type = "ptz"
1345
+ elif fname.suffix == ".sft":
1346
+ format_type = "sft"
1347
+ elif fname.suffix == ".exr":
1348
+ format_type = "scalar"
1349
+ elif fname.suffix in [".glb", ".obj", ".ply"]:
1350
+ format_type = "mesh"
1351
+ else:
1352
+ raise ValueError(f"Cannot infer format for {fname}")
1353
+ methods = {
1354
+ "readable": (_load_readable, _store_readable),
1355
+ "scalar": (_read_exr, _write_exr),
1356
+ "image": (_load_image, _store_image),
1357
+ "binary": (_load_binary_mask, _store_binary_mask),
1358
+ "latent": (_load_sft, _store_sft),
1359
+ "depth": (_load_depth, _store_depth),
1360
+ "normals": (_load_normals, _store_normals),
1361
+ "numpy": (_load_numpy, _store_numpy),
1362
+ "ptz": (_load_ptz, _store_ptz),
1363
+ "sft": (_load_sft, _store_sft),
1364
+ "mmap": (_load_mmap, _store_mmap),
1365
+ "scene_meta": (_load_scene_meta, _store_scene_meta),
1366
+ "labeled_image": (_load_labeled_image, _store_labeled_image),
1367
+ "mesh": (_load_generic_mesh, _store_generic_mesh),
1368
+ "labeled_mesh": (_load_labeled_mesh, _store_labeled_mesh),
1369
+ }
1370
+ try:
1371
+ return methods[format_type][0 if load else 1]
1372
+ except KeyError as e:
1373
+ raise NotImplementedError(f"Format not supported: {format_type}") from e
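A short usage sketch for the dispatcher above: with format_type left as None, the loader or storer is picked from the file suffix; otherwise the explicit format key wins. The file paths and the depth_map tensor below are hypothetical.

    from pathlib import Path

    load_meta = _get_method("scene_meta.json", format_type="scene_meta", load=True)
    scene_meta = load_meta("scene_meta.json")              # dict with a frame_names index

    store_depth = _get_method(Path("depth/000001.exr"), "depth", load=False)
    store_depth(Path("depth/000001.exr"), depth_map)       # depth_map: hypothetical HxW tensor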
mapanything/utils/wai/m_ops.py ADDED
@@ -0,0 +1,346 @@
1
+ import numpy as np
2
+ import torch
3
+
4
+
5
+ def m_dot(
6
+ transform: torch.Tensor,
7
+ points: torch.Tensor | list,
8
+ maintain_shape: bool = False,
9
+ ) -> torch.Tensor | list:
10
+ """
11
+ Apply batch matrix multiplication between transform matrices and points.
12
+
13
+ Args:
14
+ transform: Batch of transformation matrices [..., 3/4, 3/4]
15
+ points: Batch of points [..., N, 3] or a list of points
16
+ maintain_shape: If True, preserves the original shape of points
17
+
18
+ Returns:
19
+ Transformed points with shape [..., N, 3] or a list of transformed points
20
+ """
21
+ if isinstance(points, list):
22
+ return [m_dot(t, p, maintain_shape) for t, p in zip(transform, points)]
23
+
24
+ # Store original shape and flatten batch dimensions
25
+ orig_shape = points.shape
26
+ batch_dims = points.shape[:-3]
27
+
28
+ # Reshape to standard batch format
29
+ transform_flat = transform.reshape(-1, transform.shape[-2], transform.shape[-1])
30
+ points_flat = points.reshape(transform_flat.shape[0], -1, points.shape[-1])
31
+
32
+ # Apply transformation
33
+ pts = torch.bmm(
34
+ transform_flat[:, :3, :3],
35
+ points_flat[..., :3].permute(0, 2, 1).to(transform_flat.dtype),
36
+ ).permute(0, 2, 1)
37
+
38
+ if transform.shape[-1] == 4:
39
+ pts = pts + transform_flat[:, :3, 3].unsqueeze(1)
40
+
41
+ # Restore original shape
42
+ if maintain_shape:
43
+ return pts.reshape(orig_shape)
44
+ else:
45
+ return pts.reshape(*batch_dims, -1, 3)
46
+
47
+
48
+ def m_unproject(
49
+ depth: torch.Tensor,
50
+ intrinsic: torch.Tensor,
51
+ cam2world: torch.Tensor = None,
52
+ img_grid: torch.Tensor = None,
53
+ valid: torch.Tensor = None,
54
+ H: int | None = None,
55
+ W: int | None = None,
56
+ img_feats: torch.Tensor = None,
57
+ maintain_shape: bool = False,
58
+ ) -> torch.Tensor:
59
+ """
60
+ Unproject 2D image points with depth values to 3D points in camera or world space.
61
+
62
+ Args:
63
+ depth: Depth values, either a tensor of shape ...xHxW or a float value
64
+ intrinsic: Camera intrinsic matrix of shape ...x3x3
65
+ cam2world: Optional camera-to-world transformation matrix of shape ...x4x4
66
+ img_grid: Optional pre-computed image grid. If None, will be created
67
+ valid: Optional mask for valid depth values or minimum depth threshold
68
+ H: Image height (required if depth is a scalar)
69
+ W: Image width (required if depth is a scalar)
70
+ img_feats: Optional image features to append to 3D points
71
+ maintain_shape: If True, preserves the original shape of points
72
+
73
+ Returns:
74
+ 3D points in camera or world space, with optional features appended
75
+ """
76
+ # Get device and shape information from intrinsic matrix
77
+ device = intrinsic.device
78
+ pre_shape = intrinsic.shape[:-2] # Batch dimensions
79
+
80
+ # Validate inputs
81
+ if isinstance(depth, (int, float)) and H is None:
82
+ raise ValueError("H must be provided if depth is a scalar")
83
+
84
+ # Determine image dimensions from depth if not provided
85
+ if isinstance(depth, torch.Tensor) and H is None:
86
+ H, W = depth.shape[-2:]
87
+
88
+ # Create image grid if not provided
89
+ if img_grid is None:
90
+ # Create coordinate grid with shape HxWx3 (last dimension is homogeneous)
91
+ img_grid = _create_image_grid(H, W, device)
92
+ # Add homogeneous coordinate
93
+ img_grid = torch.cat([img_grid, torch.ones_like(img_grid[..., :1])], -1)
94
+
95
+ # Expand img_grid to match batch dimensions of intrinsic
96
+ if img_grid.dim() <= intrinsic.dim():
97
+ img_grid = img_grid.unsqueeze(0)
98
+ img_grid = img_grid.expand(*pre_shape, *img_grid.shape[-3:])
99
+
100
+ # Handle valid mask or minimum depth threshold
101
+ depth_mask = None
102
+ if valid is not None:
103
+ if isinstance(valid, float):
104
+ # Create mask for minimum depth value
105
+ depth_mask = depth > valid
106
+ elif isinstance(valid, torch.Tensor):
107
+ depth_mask = valid
108
+
109
+ # Apply mask to image grid and other inputs
110
+ img_grid = masking(img_grid, depth_mask, dim=intrinsic.dim())
111
+ if not isinstance(depth, (int, float)):
112
+ depth = masking(depth, depth_mask, dim=intrinsic.dim() - 1)
113
+ if img_feats is not None:
114
+ img_feats = masking(img_feats, depth_mask, dim=intrinsic.dim() - 1)
115
+
116
+ # Unproject 2D points to 3D camera space
117
+ cam_pts: torch.Tensor = m_dot(
118
+ m_inverse_intrinsics(intrinsic),
119
+ img_grid[..., [1, 0, 2]],
120
+ maintain_shape=True,
121
+ )
122
+ # Scale by depth values
123
+ cam_pts = mult(cam_pts, depth.unsqueeze(-1))
124
+
125
+ # Transform to world space if cam2world is provided
126
+ if cam2world is not None:
127
+ cam_pts = m_dot(cam2world, cam_pts, maintain_shape=True)
128
+
129
+ # Append image features if provided
130
+ if img_feats is not None:
131
+ if isinstance(cam_pts, list):
132
+ if isinstance(cam_pts[0], list):
133
+ # Handle nested list case
134
+ result = []
135
+ for batch_idx, batch in enumerate(cam_pts):
136
+ batch_result = []
137
+ for view_idx, view in enumerate(batch):
138
+ batch_result.append(
139
+ torch.cat([view, img_feats[batch_idx][view_idx]], -1)
140
+ )
141
+ result.append(batch_result)
142
+ cam_pts = result
143
+ else:
144
+ # Handle single list case
145
+ cam_pts = [
146
+ torch.cat([pts, feats], -1)
147
+ for pts, feats in zip(cam_pts, img_feats)
148
+ ]
149
+ else:
150
+ # Handle tensor case
151
+ cam_pts = torch.cat([cam_pts, img_feats], -1)
152
+
153
+ if maintain_shape:
154
+ return cam_pts
155
+
156
+ # Flatten last dimension
157
+ return cam_pts.reshape(*pre_shape, -1, 3)
158
+
159
+
160
+ def m_project(
161
+ world_pts: torch.Tensor,
162
+ intrinsic: torch.Tensor,
163
+ world2cam: torch.Tensor | None = None,
164
+ maintain_shape: bool = False,
165
+ ) -> torch.Tensor:
166
+ """
167
+ Project 3D world points to 2D image coordinates.
168
+
169
+ Args:
170
+ world_pts: 3D points in world coordinates
171
+ intrinsic: Camera intrinsic matrix
172
+ world2cam: Optional transformation from world to camera coordinates
173
+ maintain_shape: If True, preserves the original shape of points
174
+
175
+ Returns:
176
+ Image points with coordinates in img_y,img_x,z order
177
+ """
178
+ # Transform points from world to camera space if world2cam is provided
179
+ cam_pts: torch.Tensor = world_pts
180
+ if world2cam is not None:
181
+ cam_pts = m_dot(world2cam, world_pts, maintain_shape=maintain_shape)
182
+
183
+ # Get shapes to properly expand intrinsics
184
+ shared_dims = intrinsic.shape[:-2]
185
+ extra_dims = cam_pts.shape[len(shared_dims) : -1]
186
+
187
+ # Expand intrinsics to match cam_pts shape
188
+ expanded_intrinsic = intrinsic.view(*shared_dims, *([1] * len(extra_dims)), 3, 3)
189
+ expanded_intrinsic = expanded_intrinsic.expand(*shared_dims, *extra_dims, 3, 3)
190
+
191
+ # Project points from camera space to image space
192
+ depth_abs = cam_pts[..., 2].abs().clamp(min=1e-5)
193
+ return torch.stack(
194
+ [
195
+ expanded_intrinsic[..., 1, 1] * cam_pts[..., 1] / depth_abs
196
+ + expanded_intrinsic[..., 1, 2],
197
+ expanded_intrinsic[..., 0, 0] * cam_pts[..., 0] / depth_abs
198
+ + expanded_intrinsic[..., 0, 2],
199
+ cam_pts[..., 2],
200
+ ],
201
+ -1,
202
+ )
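A toy round trip that makes the (img_y, img_x, z) convention above concrete: unprojecting a constant-depth map and reprojecting it should return the original pixel grid. The intrinsics and depth values are made up.

    import torch

    H, W = 4, 6
    K = torch.tensor([[100.0, 0.0, W / 2], [0.0, 100.0, H / 2], [0.0, 0.0, 1.0]])
    depth = torch.full((H, W), 2.0)

    cam_pts = m_unproject(depth, K)          # (H*W, 3) points in camera space
    img_pts = m_project(cam_pts, K)          # (H*W, 3) in (img_y, img_x, depth) order

    yy, xx = torch.meshgrid(torch.arange(H), torch.arange(W), indexing="ij")
    expected = torch.stack([yy, xx], -1).reshape(-1, 2).float()
    assert torch.allclose(img_pts[..., :2], expected, atol=1e-4)
    assert torch.allclose(img_pts[..., 2], torch.full((H * W,), 2.0))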
203
+
204
+
205
+ def in_image(
206
+ image_pts: torch.Tensor | list,
207
+ H: int,
208
+ W: int,
209
+ min_depth: float = 0.0,
210
+ ) -> torch.Tensor | list:
211
+ """
212
+ Check if image points are within the image boundaries.
213
+
214
+ Args:
215
+ image_pts: Image points in pixel coordinates
216
+ H: Image height
217
+ W: Image width
218
+ min_depth: Minimum valid depth
219
+
220
+ Returns:
221
+ Boolean mask indicating which points are within the image
222
+ """
223
+ is_list = isinstance(image_pts, list)
224
+ if is_list:
225
+ return [in_image(pts, H, W, min_depth=min_depth) for pts in image_pts]
226
+
227
+ in_image_mask = (
228
+ torch.all(image_pts >= 0, -1)
229
+ & (image_pts[..., 0] < H)
230
+ & (image_pts[..., 1] < W)
231
+ )
232
+ if (min_depth is not None) and image_pts.shape[-1] == 3:
233
+ in_image_mask &= image_pts[..., 2] > min_depth
234
+ return in_image_mask
235
+
236
+
237
+ def _create_image_grid(H: int, W: int, device: torch.device) -> torch.Tensor:
238
+ """
239
+ Create a coordinate grid for image pixels.
240
+
241
+ Args:
242
+ H: Image height
243
+ W: Image width
244
+ device: Computation device
245
+
246
+ Returns:
247
+ Image grid with shape HxWx3 (last dimension is homogeneous)
248
+ """
249
+ y_coords = torch.arange(H, device=device)
250
+ x_coords = torch.arange(W, device=device)
251
+
252
+ # Use meshgrid with indexing="ij" for correct orientation
253
+ y_grid, x_grid = torch.meshgrid(y_coords, x_coords, indexing="ij")
254
+
255
+ # Stack coordinates and add homogeneous coordinate
256
+ img_grid = torch.stack([y_grid, x_grid, torch.ones_like(y_grid)], dim=-1)
257
+
258
+ return img_grid
259
+
260
+
261
+ def masking(
262
+ X: torch.Tensor | list,
263
+ mask: torch.Tensor | list,
264
+ dim: int = 3,
265
+ ) -> torch.Tensor | list:
266
+ """
267
+ Apply a Boolean mask to tensor or list elements.
268
+ Handles nested structures by recursively applying the mask.
269
+
270
+ Args:
271
+ X: Input tensor or list to be masked
272
+ mask: Boolean mask to apply
273
+ dim: Dimension threshold for recursive processing
274
+
275
+ Returns:
276
+ Masked tensor or list with the same structure as input
277
+ """
278
+ if isinstance(X, list) or (isinstance(X, torch.Tensor) and X.dim() >= dim):
279
+ return [masking(x, m, dim) for x, m in zip(X, mask)]
280
+ return X[mask]
281
+
282
+
283
+ def m_inverse_intrinsics(intrinsics: torch.Tensor) -> torch.Tensor:
284
+ """
285
+ Compute the inverse of camera intrinsics matrices analytically.
286
+ This is much faster than using torch.inverse() for intrinsics matrices.
287
+
288
+ The intrinsics matrix has the form:
289
+ K = [fx s cx]
290
+ [0 fy cy]
291
+ [0 0 1]
292
+
293
+ And its inverse is:
294
+ K^-1 = [1/fx -s/(fx*fy) (s*cy-cx*fy)/(fx*fy)]
295
+ [0 1/fy -cy/fy ]
296
+ [0 0 1 ]
297
+
298
+ Args:
299
+ intrinsics: Camera intrinsics matrices of shape [..., 3, 3]
300
+
301
+ Returns:
302
+ Inverse intrinsics matrices of shape [..., 3, 3]
303
+ """
304
+ # Extract the components of the intrinsics matrix
305
+ fx = intrinsics[..., 0, 0]
306
+ s = intrinsics[..., 0, 1] # skew, usually 0
307
+ cx = intrinsics[..., 0, 2]
308
+ fy = intrinsics[..., 1, 1]
309
+ cy = intrinsics[..., 1, 2]
310
+
311
+ # Create output tensor with same shape and device
312
+ inv_intrinsics = torch.zeros_like(intrinsics)
313
+
314
+ # Compute the inverse analytically
315
+ inv_intrinsics[..., 0, 0] = 1.0 / fx
316
+ inv_intrinsics[..., 0, 1] = -s / (fx * fy)
317
+ inv_intrinsics[..., 0, 2] = (s * cy - cx * fy) / (fx * fy)
318
+ inv_intrinsics[..., 1, 1] = 1.0 / fy
319
+ inv_intrinsics[..., 1, 2] = -cy / fy
320
+ inv_intrinsics[..., 2, 2] = 1.0
321
+
322
+ return inv_intrinsics
323
+
324
+
325
+ def mult(
326
+ A: torch.Tensor | np.ndarray | list | float | int,
327
+ B: torch.Tensor | np.ndarray | list | float | int,
328
+ ) -> torch.Tensor | np.ndarray | list | float | int:
329
+ """
330
+ Multiply two objects with support for lists, tensors, arrays, and scalars.
331
+ Handles nested structures by recursively applying multiplication.
332
+
333
+ Args:
334
+ A: First operand (tensor, array, list, or scalar)
335
+ B: Second operand (tensor, array, list, or scalar)
336
+
337
+ Returns:
338
+ Result of multiplication with the same structure as inputs
339
+ """
340
+ if isinstance(A, list) and isinstance(B, (int, float)):
341
+ return [mult(a, B) for a in A]
342
+ if isinstance(B, list) and isinstance(A, (int, float)):
343
+ return [mult(A, b) for b in B]
344
+ if isinstance(A, list) and isinstance(B, list):
345
+ return [mult(a, b) for a, b in zip(A, B)]
346
+ return A * B
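As a quick sanity check of the closed-form intrinsics inverse derived above, a sketch with toy camera parameters (the numbers are made up):

    import torch

    K = torch.tensor([[500.0, 0.0, 320.0],
                      [0.0, 480.0, 240.0],
                      [0.0, 0.0, 1.0]]).repeat(2, 1, 1)     # toy batch of two cameras

    K_inv = m_inverse_intrinsics(K)
    assert torch.allclose(K_inv, torch.inverse(K), atol=1e-5)
    assert torch.allclose(K @ K_inv, torch.eye(3).expand(2, 3, 3), atol=1e-5)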
mapanything/utils/wai/ops.py ADDED
@@ -0,0 +1,368 @@
1
+ """
2
+ This utility module contains a port of wai-core ops methods for MapAnything.
3
+ """
4
+
5
+ import numpy as np
6
+ import torch
7
+ import torch.nn.functional as F
8
+ from PIL import Image
9
+
10
+
11
+ def to_numpy(
12
+ data: torch.Tensor | np.ndarray | int | float,
13
+ dtype: np.dtype | str | type = np.float32,
14
+ ) -> np.ndarray:
15
+ """
16
+ Convert data to a NumPy array with the specified dtype (default: float32).
17
+
18
+ This function handles conversion from NumPy arrays and PyTorch tensors to a NumPy array.
19
+
20
+ Args:
21
+ data: Input data (torch.Tensor, np.ndarray, or scalar)
22
+ dtype: Target data type (NumPy dtype, str, or type). Default: np.float32.
23
+
24
+ Returns:
25
+ Converted data as NumPy array with specified dtype.
26
+ """
27
+ # Set default dtype if not defined
28
+ assert dtype is not None, "dtype cannot be None"
29
+ dtype = np.dtype(dtype)
30
+
31
+ # Handle torch.Tensor
32
+ if isinstance(data, torch.Tensor):
33
+ return data.detach().cpu().numpy().astype(dtype)
34
+
35
+ # Handle numpy.ndarray
36
+ if isinstance(data, np.ndarray):
37
+ return data.astype(dtype)
38
+
39
+ # Handle scalar values
40
+ if isinstance(data, (int, float)):
41
+ return np.array(data, dtype=dtype)
42
+
43
+ raise NotImplementedError(f"Unsupported data type: {type(data)}")
44
+
45
+
46
+ def get_dtype_device(
47
+ data: torch.Tensor | np.ndarray | dict | list,
48
+ ) -> tuple[torch.dtype | np.dtype | None, torch.device | str | type | None]:
49
+ """
50
+ Determine the data type and device of the input data.
51
+
52
+ This function recursively inspects the input data and determines its data type
53
+ and device. It handles PyTorch tensors, NumPy arrays, dictionaries, and lists.
54
+
55
+ Args:
56
+ data: Input data (torch.Tensor, np.ndarray, dict, list, or other)
57
+
58
+ Returns:
59
+ tuple: (dtype, device) where:
60
+ - dtype: The data type (torch.dtype or np.dtype)
61
+ - device: The device (torch.device, 'cpu', 'cuda:X', or np.ndarray)
62
+
63
+ Raises:
64
+ ValueError: If tensors in a dictionary are on different CUDA devices
65
+ """
66
+ if isinstance(data, torch.Tensor):
67
+ return data.dtype, data.device
68
+
69
+ if isinstance(data, np.ndarray):
70
+ return data.dtype, np.ndarray
71
+
72
+ if isinstance(data, dict):
73
+ dtypes = {get_dtype_device(v)[0] for v in data.values()}
74
+ devices = {get_dtype_device(v)[1] for v in data.values()}
75
+ cuda_devices = {device for device in devices if str(device).startswith("cuda")}
76
+ cpu_devices = {device for device in devices if str(device).startswith("cpu")}
77
+ if (len(cuda_devices) > 0) or (len(cpu_devices) > 0):
78
+ # torch.tensor
79
+ dtype = torch.float
80
+ if all(dtype == torch.half for dtype in dtypes):
81
+ dtype = torch.half
82
+ device = None
83
+ if len(cuda_devices) > 1:
84
+ raise ValueError("All tensors must be on the same device")
85
+ if len(cuda_devices) == 1:
86
+ device = list(cuda_devices)[0]
87
+ if (device is None) and (len(cpu_devices) == 1):
88
+ device = list(cpu_devices)[0]
89
+ else:
90
+ dtype = np.float32
91
+ # Use float16 only if every value is float16
92
+ if all(dtype == np.float16 for dtype in dtypes):
93
+ dtype = np.float16
94
+ device = np.ndarray
95
+
96
+ elif isinstance(data, list):
97
+ if not data: # Handle empty list case
98
+ return None, None
99
+ dtype, device = get_dtype_device(data[0])
100
+
101
+ else:
102
+ return np.float32, np.ndarray
103
+
104
+ return dtype, device
105
+
106
+
107
+ def crop(
108
+ data: np.ndarray | torch.Tensor | Image.Image,
109
+ bbox: tuple[int, int, int, int] | tuple[int, int],
110
+ ) -> np.ndarray | torch.Tensor | Image.Image:
111
+ """
112
+ Crop data of different formats (numpy arrays, PyTorch tensors, PIL Images) to a target bounding box.
113
+
114
+ Args:
115
+ data: Input data to crop (numpy.ndarray, torch.Tensor, or PIL.Image.Image)
116
+ bbox: Bounding box as tuple (offset_height, offset_width, height, width) or tuple (height, width)
117
+
118
+ Returns:
119
+ Cropped data in the same format as the input
120
+ """
121
+ if len(bbox) == 4:
122
+ offset_height, offset_width, target_height, target_width = bbox
123
+ elif len(bbox) == 2:
124
+ target_height, target_width = bbox
125
+ offset_height, offset_width = 0, 0
126
+ else:
127
+ raise ValueError(f"Unsupported size length {len(bbox)}.")
128
+
129
+ end_height = offset_height + target_height
130
+ end_width = offset_width + target_width
131
+
132
+ if any([sz < 0 for sz in bbox]):
133
+ raise ValueError("Bounding box can't have negative values.")
134
+
135
+ if isinstance(data, np.ndarray):
136
+ if (
137
+ max(offset_height, end_height) > data.shape[0]
138
+ or max(offset_width, end_width) > data.shape[1]
139
+ ):
140
+ raise ValueError("Invalid bounding box.")
141
+ cropped_data = data[offset_height:end_height, offset_width:end_width, ...]
142
+ return cropped_data
143
+
144
+ # Handle PIL images
145
+ elif isinstance(data, Image.Image):
146
+ if (
147
+ max(offset_height, end_height) > data.size[1]
148
+ or max(offset_width, end_width) > data.size[0]
149
+ ):
150
+ raise ValueError("Invalid bounding box.")
151
+ return data.crop((offset_width, offset_height, end_width, end_height))
152
+
153
+ # Handle PyTorch tensors
154
+ elif isinstance(data, torch.Tensor):
155
+ if data.is_nested:
156
+ # special handling for nested tensors
157
+ return torch.stack([crop(nested_tensor, bbox) for nested_tensor in data])
158
+ if (
159
+ max(offset_height, end_height) > data.shape[-2]
160
+ or max(offset_width, end_width) > data.shape[-1]
161
+ ):
162
+ raise ValueError("Invalid bounding box.")
163
+ cropped_data = data[..., offset_height:end_height, offset_width:end_width]
164
+ return cropped_data
165
+ else:
166
+ raise TypeError(f"Unsupported data type '{type(data)}'.")
167
+
168
+
169
+ def stack(
170
+ data: list[
171
+ dict[str, torch.Tensor | np.ndarray]
172
+ | list[torch.Tensor | np.ndarray]
173
+ | tuple[torch.Tensor | np.ndarray]
174
+ ],
175
+ ) -> dict[str, torch.Tensor | np.ndarray] | list[torch.Tensor | np.ndarray]:
176
+ """
177
+ Stack a list of dictionaries into a single dictionary with stacked values.
178
+ Or when given a list of sublists, stack the sublists using torch or numpy stack
179
+ if the items are of equal size, or nested tensors if the items are PyTorch tensors
180
+ of different size.
181
+
182
+ This utility function is similar to PyTorch's collate function, but specifically
183
+ designed for stacking dictionaries of numpy arrays or PyTorch tensors.
184
+
185
+ Args:
186
+ data (list): A list of dictionaries with the same keys, where values are
187
+ either numpy arrays or PyTorch tensors.
188
+ OR
189
+ A list of sublists, where the values of the sublists are PyTorch tensors
190
+ or np arrays.
191
+
192
+ Returns:
193
+ dict: A dictionary with the same keys as input dictionaries, but with values
194
+ stacked along a new first dimension.
195
+ OR
196
+ list: If the input was a list with sublists, it returns a list with a stacked
197
+ output for each original input sublist.
198
+
199
+ Raises:
200
+ ValueError: If dictionaries in the list have inconsistent keys.
201
+ NotImplementedError: If input is not a list or contains non-dictionary elements.
202
+ """
203
+ if not isinstance(data, list):
204
+ raise NotImplementedError(f"Stack: Data type not supported: {data}")
205
+
206
+ if len(data) == 0:
207
+ return data
208
+
209
+ if all(isinstance(entry, dict) for entry in data):
210
+ stacked_data = {}
211
+ keys = list(data[0].keys())
212
+ if any(set(entry.keys()) != set(keys) for entry in data):
213
+ raise ValueError("Data not consistent for stacking")
214
+
215
+ for key in keys:
216
+ stacked_data[key] = []
217
+ for entry in data:
218
+ stacked_data[key].append(entry[key])
219
+
220
+ # stack it according to data format
221
+ if all(isinstance(v, np.ndarray) for v in stacked_data[key]):
222
+ stacked_data[key] = np.stack(stacked_data[key])
223
+ elif all(isinstance(v, torch.Tensor) for v in stacked_data[key]):
224
+ # Check if all tensors have the same shape
225
+ first_shape = stacked_data[key][0].shape
226
+ if all(tensor.shape == first_shape for tensor in stacked_data[key]):
227
+ stacked_data[key] = torch.stack(stacked_data[key])
228
+ else:
229
+ # Use nested tensors if shapes are not consistent
230
+ stacked_data[key] = torch.nested.nested_tensor(stacked_data[key])
231
+ return stacked_data
232
+
233
+ if all(isinstance(entry, list) for entry in data):
234
+ # new stacked data will be a list with one stacked output per sublist
235
+ stacked_data = []
236
+ for sublist in data:
237
+ # stack it according to data format
238
+ if all(isinstance(v, np.ndarray) for v in sublist):
239
+ stacked_data.append(np.stack(sublist))
240
+ elif all(isinstance(v, torch.Tensor) for v in sublist):
241
+ # Check if all tensors have the same shape
242
+ first_shape = sublist[0].shape
243
+ if all(tensor.shape == first_shape for tensor in sublist):
244
+ stacked_data.append(torch.stack(sublist))
245
+ else:
246
+ # Use nested tensors if shapes are not consistent
247
+ stacked_data.append(torch.nested.nested_tensor(sublist))
248
+ return stacked_data
249
+
250
+ raise NotImplementedError(f"Stack: Data type not supported: {data}")
251
+
252
+
253
+ def resize(
254
+ data: np.ndarray | torch.Tensor | Image.Image,
255
+ size: tuple[int, int] | int | None = None,
256
+ scale: float | None = None,
257
+ modality_format: str | None = None,
258
+ ) -> np.ndarray | torch.Tensor | Image.Image:
259
+ """
260
+ Resize data of different formats (numpy arrays, PyTorch tensors, PIL Images) to a target size.
261
+
262
+ Args:
263
+ data: Input data to resize (numpy.ndarray, torch.Tensor, or PIL.Image.Image)
264
+ size: Target size as tuple (height, width) or single int for long-side scaling
265
+ scale: Scale factor to apply to the original dimensions
266
+ modality_format: Type of data being resized ('depth', 'normals', or None)
267
+ Affects interpolation method used
268
+
269
+ Returns:
270
+ Resized data in the same format as the input
271
+
272
+ Raises:
273
+ ValueError: If neither size nor scale is provided, or if both are provided
274
+ TypeError: If data is not a supported type
275
+ """
276
+ # Validate input parameters
277
+ if size is not None and scale is not None:
278
+ raise ValueError("Only one of size or scale should be provided.")
279
+
280
+ # Calculate size from scale if needed
281
+ if size is None:
282
+ if scale is None:
283
+ raise ValueError("Either size or scale must be provided.")
284
+
285
+ size = (1, 1)
286
+ if isinstance(data, (np.ndarray, torch.Tensor)):
287
+ size = (int(data.shape[-2] * scale), int(data.shape[-1] * scale))
288
+ elif isinstance(data, Image.Image):
289
+ size = (int(data.size[1] * scale), int(data.size[0] * scale))
290
+ else:
291
+ raise TypeError(f"Unsupported data type '{type(data)}'.")
292
+
293
+ # Handle long-side scaling when size is a single integer
294
+ elif isinstance(size, int):
295
+ long_side = size
296
+ if isinstance(data, (np.ndarray, torch.Tensor)):
297
+ if isinstance(data, torch.Tensor) and data.is_nested:
298
+ raise ValueError(
299
+ "Long-side scaling is not supported for nested tensors, use a fixed size instead."
300
+ )
301
+ h, w = data.shape[-2], data.shape[-1]
302
+ elif isinstance(data, Image.Image):
303
+ w, h = data.size
304
+ else:
305
+ raise TypeError(f"Unsupported data type '{type(data)}'.")
306
+ if h > w:
307
+ size = (long_side, int(w * long_side / h))
308
+ else:
309
+ size = (int(h * long_side / w), long_side)
310
+
311
+ target_height, target_width = size
312
+
313
+ # Set interpolation method based on modality
314
+ if modality_format in ["depth", "normals"]:
315
+ interpolation = Image.Resampling.NEAREST
316
+ torch_interpolation = "nearest"
317
+ else:
318
+ interpolation = Image.Resampling.LANCZOS
319
+ torch_interpolation = "bilinear"
320
+
321
+ # Handle numpy arrays
322
+ if isinstance(data, np.ndarray):
323
+ pil_image = Image.fromarray(data)
324
+ resized_image = pil_image.resize((target_width, target_height), interpolation)
325
+ return np.array(resized_image)
326
+
327
+ # Handle PIL images
328
+ elif isinstance(data, Image.Image):
329
+ return data.resize((target_width, target_height), interpolation)
330
+
331
+ # Handle PyTorch tensors
332
+ elif isinstance(data, torch.Tensor):
333
+ if data.is_nested:
334
+ # special handling for nested tensors
335
+ return torch.stack(
336
+ [
337
+ resize(nested_tensor, size, scale, modality_format)
338
+ for nested_tensor in data
339
+ ]
340
+ )
341
+ original_dim = data.ndim
342
+ if original_dim == 2: # (H, W)
343
+ data = data.unsqueeze(0).unsqueeze(0) # Add channel and batch dimensions
344
+ elif original_dim == 3: # (C/B, H, W)
345
+ if modality_format == "depth":
346
+ data = data.unsqueeze(1) # add a channel dimension
347
+ else:
348
+ data = data.unsqueeze(0) # Add batch dimension
349
+ resized_tensor = F.interpolate(
350
+ data,
351
+ size=(target_height, target_width),
352
+ mode=torch_interpolation,
353
+ align_corners=False if torch_interpolation != "nearest" else None,
354
+ )
355
+ if original_dim == 2:
356
+ return resized_tensor.squeeze(0).squeeze(
357
+ 0
358
+ ) # Remove batch and channel dimensions
359
+ elif original_dim == 3:
360
+ if modality_format == "depth":
361
+ return resized_tensor.squeeze(1) # Remove channel dimension
362
+
363
+ return resized_tensor.squeeze(0) # Remove batch dimension
364
+ else:
365
+ return resized_tensor
366
+
367
+ else:
368
+ raise TypeError(f"Unsupported data type '{type(data)}'.")
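A small sketch of how stack() above behaves on a list of per-frame dicts: equal-shaped values are stacked densely, while mismatched shapes fall back to nested tensors. The shapes below are made up.

    import torch

    frames = [
        {"image": torch.rand(3, 128, 160), "intrinsics": torch.eye(3)},
        {"image": torch.rand(3, 96, 160), "intrinsics": torch.eye(3)},
    ]
    batch = stack(frames)
    assert batch["intrinsics"].shape == (2, 3, 3)   # same shape -> torch.stack
    assert batch["image"].is_nested                 # different H -> torch.nested.nested_tensor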
mapanything/utils/wai/scene_frame.py ADDED
@@ -0,0 +1,431 @@
1
+ import logging
2
+ import os
3
+ import random
4
+ import re
5
+ from pathlib import Path
6
+ from typing import Any
7
+
8
+ import numpy as np
9
+
10
+ from mapanything.utils.wai.io import (
11
+ _load_readable,
12
+ _load_scene_meta,
13
+ get_processing_state,
14
+ )
15
+
16
+ logger = logging.getLogger(__name__)
17
+
18
+
19
+ def get_scene_frame_names(
20
+ cfg: dict | object,
21
+ root: Path | str | None = None,
22
+ scene_frames_fn: str | None = None,
23
+ keyframes: bool = True,
24
+ ) -> dict[str, list[str | float]] | None:
25
+ """
26
+ Retrieve scene frame names based on configuration and optional parameters.
27
+
28
+ This function determines the scene frame names by resolving the scene frame file
29
+ and applying any necessary filters based on the provided configuration.
30
+
31
+ Args:
32
+ cfg: Configuration object containing settings and parameters.
33
+ root: Optional root directory path. If not provided, it will be fetched from cfg.
34
+ scene_frames_fn: Optional scene frames file name. If not provided, it will be fetched from cfg.
35
+ keyframes: Optional, used only for a video. If True (default), return only keyframes (with camera poses).
36
+
37
+ Returns:
38
+ A dictionary mapping scene names to their respective frame names.
39
+ """
40
+ scene_frames_fn = (
41
+ cfg.get("scene_frames_fn") if scene_frames_fn is None else scene_frames_fn
42
+ )
43
+ scene_frame_names = None
44
+ if scene_frames_fn is not None:
45
+ # load scene_frames based on scene_frame file
46
+ scene_frame_names = _resolve_scene_frames_fn(scene_frames_fn)
47
+
48
+ scene_names = get_scene_names(
49
+ cfg,
50
+ root=root,
51
+ scene_names=(
52
+ list(scene_frame_names.keys()) if scene_frame_names is not None else None
53
+ ),
54
+ )
55
+ scene_frame_names = _resolve_scene_frame_names(
56
+ cfg,
57
+ scene_names,
58
+ root=root,
59
+ scene_frame_names=scene_frame_names,
60
+ keyframes=keyframes,
61
+ )
62
+ return scene_frame_names
63
+
64
+
65
+ def get_scene_names(
66
+ cfg: dict | object,
67
+ root: Path | str | None = None,
68
+ scene_names: list[str] | None = None,
69
+ shuffle: bool = False,
70
+ ) -> list[str]:
71
+ """
72
+ Retrieve scene names based on the provided configuration and optional parameters.
73
+
74
+ This function determines the scene names by checking the root directory for subdirectories
75
+ and applying any necessary filters based on the provided configuration.
76
+
77
+ Args:
78
+ cfg: Configuration object containing settings and parameters.
79
+ root: Optional root directory path. If not provided, it will be fetched from cfg.
80
+ scene_names: Optional list of scene names. If not provided, it will be determined from the root directory.
81
+ shuffle: Optional bool. Defaults to False. If True, the list of scene names is returned in random order.
82
+
83
+ Returns:
84
+ A list of scene names after applying any filters specified in the configuration.
85
+ """
86
+ root = cfg.get("root") if root is None else root
87
+ if root is not None:
88
+ # Check if the root exists
89
+ if not Path(root).exists():
90
+ raise IOError(f"Root directory does not exist: {root}")
91
+
92
+ # Check if the root is a directory
93
+ if not Path(root).is_dir():
94
+ raise IOError(f"Root directory is not a directory: {root}")
95
+
96
+ if scene_names is None:
97
+ scene_filters = cfg.get("scene_filters")
98
+ if (
99
+ scene_filters
100
+ and len(scene_filters) == 1
101
+ and isinstance(scene_filters[0], list)
102
+ and all(isinstance(entry, str) for entry in scene_filters[0])
103
+ ):
104
+ # Shortcut the scene_names if the scene_filters is only a list of scene names
105
+ scene_names = scene_filters[0]
106
+ else:
107
+ # List all subdirectories in the root as scenes
108
+ scene_names = sorted(
109
+ [entry.name for entry in os.scandir(root) if entry.is_dir()]
110
+ )
111
+ # Filter scenes based on scene_filters
112
+ scene_names = _filter_scenes(root, scene_names, cfg.get("scene_filters"))
113
+
114
+ # shuffle the list if needed (in place)
115
+ if shuffle:
116
+ random.shuffle(scene_names)
117
+
118
+ return scene_names
119
+
120
+
121
+ def _filter_scenes(
122
+ root: Path | str,
123
+ scene_names: list[str],
124
+ scene_filters: tuple | list | None,
125
+ ) -> list[str]:
126
+ if scene_filters is None:
127
+ return scene_names
128
+
129
+ if not isinstance(scene_filters, (tuple, list)):
130
+ raise ValueError("scene_filters must be a list or tuple")
131
+
132
+ for scene_filter in scene_filters:
133
+ if scene_filter in [None, "all"]:
134
+ pass
135
+
136
+ elif isinstance(scene_filter, (tuple, list)):
137
+ if len(scene_filter) == 0:
138
+ raise ValueError("scene_filter cannot be empty")
139
+
140
+ elif all(isinstance(x, int) for x in scene_filter):
141
+ if len(scene_filter) == 2:
142
+ # start/end index
143
+ scene_names = scene_names[scene_filter[0] : scene_filter[1]]
144
+ elif len(scene_filter) == 3:
145
+ # start/end/step
146
+ scene_names = scene_names[
147
+ scene_filter[0] : scene_filter[1] : scene_filter[2]
148
+ ]
149
+ else:
150
+ # omegaconf conversion issue (converts strings to integers whenever possible)
151
+ if str(scene_filter[0]) in scene_names:
152
+ scene_names = [str(s) for s in scene_filter]
153
+ else:
154
+ raise ValueError(
155
+ "scene_filter format [start_idx, end_idx] or [start_idx, end_idx, step_size] or [scene_name1, scene_name2, ...]"
156
+ )
157
+
158
+ elif all(isinstance(x, str) for x in scene_filter):
159
+ # explicit scene names
160
+ if set(scene_filter).issubset(set(scene_names)):
161
+ scene_names = list(scene_filter)
162
+ else:
163
+ logger.warning(
164
+ f"Scene(s) not available: {set(scene_filter) - set(scene_names)}"
165
+ )
166
+ scene_names = list(set(scene_names) & set(scene_filter))
167
+ else:
168
+ raise TypeError(
169
+ f"Scene filter type not supported: {type(scene_filter)}"
170
+ )
171
+
172
+ elif isinstance(scene_filter, dict):
173
+ # reserved key words
174
+ if modality := scene_filter.get("exists"):
175
+ scene_names = [
176
+ scene_name
177
+ for scene_name in scene_names
178
+ if Path(root, scene_name, modality).exists()
179
+ ]
180
+
181
+ elif modality := scene_filter.get("exists_not"):
182
+ scene_names = [
183
+ scene_name
184
+ for scene_name in scene_names
185
+ if not Path(root, scene_name, modality).exists()
186
+ ]
187
+
188
+ elif process_filter := scene_filter.get("process_state"):
189
+ # filter for where <process_key> has <process_state>
190
+ (process_key, process_state) = process_filter
191
+ filtered_scene_names = []
192
+ for scene_name in scene_names:
193
+ # load processing state and check for
194
+ processing_state = get_processing_state(Path(root, scene_name))
195
+ if "*" in process_key: # regex matching
196
+ for process_name in processing_state:
197
+ if re.match(process_key, process_name):
198
+ process_key = process_name
199
+ break
200
+ if process_key not in processing_state:
201
+ continue
202
+ if processing_state[process_key]["state"] == process_state:
203
+ filtered_scene_names.append(scene_name)
204
+ scene_names = filtered_scene_names
205
+
206
+ elif process_filter := scene_filter.get("process_state_not"):
207
+ # filter for where <process_key> does not have <process_state>
208
+ (process_key, process_state) = process_filter
209
+ filtered_scene_names = []
210
+ for scene_name in scene_names:
211
+ # load processing state and check for
212
+ try:
213
+ processing_state = get_processing_state(Path(root, scene_name))
214
+ except Exception:
215
+ filtered_scene_names.append(scene_name)
216
+ continue
217
+ if "*" in process_key: # regex matching
218
+ for process_name in processing_state:
219
+ if re.match(process_key, process_name):
220
+ process_key = process_name
221
+ break
222
+ if (process_key not in processing_state) or (
223
+ processing_state[process_key]["state"] != process_state
224
+ ):
225
+ filtered_scene_names.append(scene_name)
226
+ scene_names = filtered_scene_names
227
+
228
+ else:
229
+ raise ValueError(f"Scene filter not supported: {scene_filter}")
230
+
231
+ elif isinstance(scene_filter, str):
232
+ # regex
233
+ scene_names = [
234
+ scene_name
235
+ for scene_name in scene_names
236
+ if re.fullmatch(scene_filter, scene_name)
237
+ ]
238
+ else:
239
+ raise ValueError(f"Scene filter not supported: {scene_filter}")
240
+
241
+ return scene_names
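The filters above compose left to right; a sketch of a configuration that get_scene_names would accept (the root path and scene naming pattern are hypothetical):

    cfg = {
        "root": "/data/wai/some_dataset",      # hypothetical dataset root
        "scene_filters": [
            [0, 100, 2],                       # keep every 2nd of the first 100 scenes
            {"exists": "depth"},               # keep scenes that contain a depth/ modality
            r"scene_\d{4}",                    # then a full-match regex on the names
        ],
    }
    scene_names = get_scene_names(cfg)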
242
+
243
+
244
+ def _resolve_scene_frames_fn(scene_frames_fn: str) -> dict[str, list[str] | None]:
+     # support for file list in forms of lists or dicts
+     # containing scene_names [-> frames]
+     scene_frames_list = _load_readable(scene_frames_fn)
+     scene_frame_names = {}
+
+     # TODO: The following code seems unreachable as scene_frames_list is always a dict
+     if isinstance(scene_frames_list, (list, tuple)):
+         for entry in scene_frames_list:
+             if isinstance(entry, (tuple, list)):
+                 if (
+                     (len(entry) != 2)
+                     or (not isinstance(entry[0], str))
+                     or (not isinstance(entry[1], list))
+                 ):
+                     raise NotImplementedError(
+                         "Only supports lists of [<scene_name>, [frame_names]]"
+                     )
+                 scene_frame_names[entry[0]] = entry[1]
+             elif isinstance(entry, str):
+                 scene_frame_names[entry] = None
+             elif isinstance(entry, dict):
+                 # scene_name -> frames
+                 raise NotImplementedError("Dict entry not supported yet")
+             else:
+                 raise IOError(f"File list contains an entry of wrong format: {entry}")
+
+     elif isinstance(scene_frames_list, dict):
+         # scene_name -> frames
+         for scene_name, frame in scene_frames_list.items():
+             if isinstance(frame, (tuple, list)):
+                 scene_frame_names[scene_name] = frame
+             elif isinstance(frame, dict):
+                 if "frame_names" in frame:
+                     scene_frame_names[scene_name] = frame["frame_names"]
+                 else:
+                     raise IOError(f"Scene frames format not supported: {frame}")
+             elif frame is None:
+                 scene_frame_names[scene_name] = frame
+             else:
+                 raise IOError(f"Scene frames format not supported: {frame}")
+
+     else:
+         raise IOError(f"Scene frames format not supported: {scene_frames_list}")
+
+     return scene_frame_names
+
+
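For reference, `_resolve_scene_frames_fn` accepts the scene/frame listing either as a list (plain scene names or `[scene_name, [frame_names]]` pairs) or as a dict mapping scene names to a frame list, a `{"frame_names": [...]}` dict, or null. Hypothetical file contents, in whatever format `_load_readable` can parse (e.g. JSON); the names are placeholders:

    ["scene_0001", ["scene_0002", ["frame_0000", "frame_0010"]]]
    {"scene_0001": null, "scene_0002": ["frame_0000", "frame_0010"], "scene_0003": {"frame_names": ["frame_0005"]}}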
+ def _resolve_scene_frame_names(
+     cfg: dict | object,
+     scene_names: list[str],
+     root: Path | str | None = None,
+     scene_frame_names: dict[str, list[str | float] | None] | None = None,
+     keyframes: bool = True,
+ ) -> dict[str, list[str]]:
+     root = cfg.get("root") if root is None else root
+     if scene_frame_names is not None:
+         # restrict to the additional scene-level prefiltering
+         scene_frame_names = {
+             scene_name: scene_frame_names[scene_name] for scene_name in scene_names
+         }
+         # dict already loaded, apply additional filters
+         for scene_name, frame_names in scene_frame_names.items():
+             if frame_names is None:
+                 scene_meta = _load_scene_meta(
+                     Path(
+                         root, scene_name, cfg.get("scene_meta_path", "scene_meta.json")
+                     )
+                 )
+                 frame_names = [frame["frame_name"] for frame in scene_meta["frames"]]
+                 # TODO: add some logic for video keyframes
+
+             scene_frame_names[scene_name] = _filter_frame_names(
+                 root, frame_names, scene_name, cfg.get("frame_filters")
+             )
+     else:
+         scene_frame_names = {}
+         for scene_name in scene_names:
+             scene_meta = _load_scene_meta(
+                 Path(root, scene_name, cfg.get("scene_meta_path", "scene_meta.json"))
+             )
+             if not keyframes:
+                 frame_names = get_video_frames(scene_meta)
+                 if frame_names is None:
+                     keyframes = True
+             if keyframes:
+                 frame_names = [frame["frame_name"] for frame in scene_meta["frames"]]
+             frame_names = _filter_frame_names(
+                 root, frame_names, scene_name, cfg.get("frame_filters")
+             )
+             scene_frame_names[scene_name] = frame_names
+     return scene_frame_names
+
+
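`_resolve_scene_frame_names` reads only a few keys from the dataset config: `root`, an optional `scene_meta_path` (defaulting to `scene_meta.json`), and optional `frame_filters` (see the examples after `_filter_frame_names` below). A hypothetical config fragment; the path is a placeholder:

    cfg = {
        "root": "/data/wai/scannetpp",          # placeholder dataset root
        "scene_meta_path": "scene_meta.json",
        "frame_filters": [[0, 1000, 5]],        # first 1000 keyframes, every 5th
    }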
+ def _filter_frame_names(
+     root: Path | str,
+     frame_names: list[str],
+     scene_name: str,
+     frame_filters: list | tuple | None,
+ ) -> list[str]:
+     if frame_filters is None:
+         return frame_names
+
+     if not isinstance(frame_filters, (tuple, list)):
+         raise ValueError("frame_filters must be a list or tuple")
+
+     for frame_filter in frame_filters:
+         if frame_filter in [None, "all"]:
+             pass
+
+         elif isinstance(frame_filter, (tuple, list)):
+             if len(frame_filter) == 0:
+                 raise ValueError("frame_filter cannot be empty")
+
+             if isinstance(frame_filter[0], int):
+                 if len(frame_filter) == 2:
+                     # start/end index
+                     frame_names = frame_names[frame_filter[0] : frame_filter[1]]
+
+                 elif len(frame_filter) == 3:
+                     # start/end/step
+                     frame_names = frame_names[
+                         frame_filter[0] : frame_filter[1] : frame_filter[2]
+                     ]
+
+                 else:
+                     raise ValueError(
+                         "frame_filter format must be [start_idx, end_idx] or [start_idx, end_idx, step_size]"
+                     )
+             else:
+                 raise TypeError(
+                     f"frame_filter[0] type not supported: {type(frame_filter[0])}"
+                 )
+
+         elif isinstance(frame_filter, str):
+             # reserved key words
+             if match := re.match("exists: (.+)", frame_filter):
+                 modality = match.group(1)
+                 frame_names = [
+                     frame_name
+                     for frame_name in frame_names
+                     if any(Path(root, scene_name, modality).glob(f"{frame_name}.*"))
+                 ]
+
+             elif match := re.match("!exists: (.+)", frame_filter):
+                 modality = match.group(1)
+                 frame_names = [
+                     frame_name
+                     for frame_name in frame_names
+                     if not any(Path(root, scene_name, modality).glob(f"{frame_name}.*"))
+                 ]
+
+             else:  # general regex
+                 frame_names = [
+                     frame_name
+                     for frame_name in frame_names
+                     if re.match(frame_filter, frame_name)
+                 ]
+
+         else:
+             raise ValueError(f"frame_filter type not supported: {type(frame_filter)}")
+
+     return frame_names
+
+
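The frame filters above are applied left to right: `None` / `"all"` is a no-op, integer lists slice the ordered frame list, `exists:` / `!exists:` keep frames with (or without) a file under the given modality folder, and any other string is matched as a regex against frame names. Illustrative values only; the modality and name pattern are placeholders:

    frame_filters = [
        [0, 1000, 10],       # frames 0..999, every 10th
        "exists: depth",     # keep frames that have <scene>/depth/<frame_name>.*
        r"frame_\d+",        # keep frames whose name matches the regex
    ]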
+ def get_video_frames(scene_meta: dict[str, Any]):
+     """
+     Return the names of video frames.
+
+     Args:
+         scene_meta: dictionary with scene_meta data.
+
+     Returns:
+         A list of video frame names, or None if the scene has no image modality.
+     """
+     image_modality = [mod for mod in scene_meta["frame_modalities"] if "image" in mod]
+     if len(image_modality) > 0:
+         image_modality = scene_meta["frame_modalities"][image_modality[0]]
+         if "chunks" in image_modality:
+             file_list = image_modality["chunks"]
+         else:
+             file_list = [image_modality]
+         frame_names = []
+         for chunk in file_list:
+             start, end, fps = chunk["start"], chunk["end"], chunk["fps"]
+             chunk_frame_names = np.arange(start, end, 1.0 / fps).tolist()
+             frame_names += chunk_frame_names
+         return frame_names
+     return None
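`get_video_frames` expands the image modality's chunk metadata into timestamps sampled at each chunk's fps. A minimal sketch with a made-up scene_meta containing one 1-second chunk at 10 fps:

    scene_meta = {"frame_modalities": {"image": {"chunks": [{"start": 0.0, "end": 1.0, "fps": 10}]}}}
    get_video_frames(scene_meta)  # -> [0.0, 0.1, ..., 0.9], i.e. np.arange(0.0, 1.0, 0.1).tolist()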
mapanything/utils/wai/semantics.py ADDED
@@ -0,0 +1,40 @@
+ """
+ This utils script contains a port of wai-core semantics methods for MapAnything.
+ """
+
+ import numpy as np
+ from PIL import Image
+
+ INVALID_ID = 0
+ INVALID_COLOR = (0, 0, 0)
+
+
+ def load_semantic_color_mapping(filename: str = "colors_fps_5k.npz") -> np.ndarray:
+     """Loads a precomputed colormap."""
+     from mapanything.utils.wai.core import WAI_COLORMAP_PATH
+
+     return np.load(WAI_COLORMAP_PATH / filename).get("arr_0")
+
+
+ def apply_id_to_color_mapping(
+     data_id: np.ndarray | Image.Image,
+     semantic_color_mapping: np.ndarray,
+ ) -> tuple[np.ndarray, dict[int, tuple[int, int, int]]]:
+     """Maps semantic class/instance IDs to RGB colors."""
+     if isinstance(data_id, Image.Image):
+         data_id = np.array(data_id)
+
+     max_color_id = semantic_color_mapping.shape[0] - 1
+     max_data_id = data_id.max()
+     if max_data_id > max_color_id:
+         raise ValueError("The provided color palette does not have enough colors!")
+
+     # Create a palette containing the id -> color mappings of the input data IDs
+     unique_indices = np.unique(data_id).tolist()
+     color_palette = {
+         index: semantic_color_mapping[index, :].tolist() for index in unique_indices
+     }
+
+     data_colors = semantic_color_mapping[data_id]
+
+     return data_colors, color_palette
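A minimal usage sketch of the two helpers above; the instance-ID map is made up, while the colormap file ships with this commit as colors_fps_5k.npz:

    import numpy as np
    from mapanything.utils.wai.semantics import (
        apply_id_to_color_mapping,
        load_semantic_color_mapping,
    )

    colormap = load_semantic_color_mapping()              # (N, 3) palette loaded from colors_fps_5k.npz
    instance_ids = np.zeros((480, 640), dtype=np.int64)   # hypothetical instance-ID map; 0 is INVALID_ID
    instance_ids[100:200, 100:200] = 7
    colors, palette = apply_id_to_color_mapping(instance_ids, colormap)
    # colors: (480, 640, 3) RGB array; palette: {0: [r, g, b], 7: [r, g, b]}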
requirements.txt CHANGED
@@ -18,4 +18,5 @@ einops
  requests
  psutil
  tqdm
+ safetensors
  uniception==0.1.4