nkeetha committed 19d7794 · 1 parent: 75d19ab

Update Model & Examples

app.py CHANGED
@@ -124,7 +124,9 @@ def run_model(
      # apply_mask: Whether to apply the non-ambiguous mask to the output. Defaults to True.
      # mask_edges: Whether to compute an edge mask based on normals and depth and apply it to the output. Defaults to True.
      # Use checkbox values - mask_edges is set to True by default since there's no UI control for it
-     outputs = model.infer(views, apply_mask=apply_mask, mask_edges=True)
+     outputs = model.infer(
+         views, apply_mask=apply_mask, mask_edges=True, memory_efficient_inference=False
+     )

      # Convert predictions to format expected by visualization
      predictions = {}
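
Note (not part of the diff): the new memory_efficient_inference keyword accepted by model.infer is simply left at False here; a caller could flip it on when a large number of views would otherwise exhaust GPU memory. A minimal sketch, assuming model and views are prepared exactly as in app.py above and using a hypothetical view-count threshold:

    # Sketch only: enable the new flag for large scenes (the threshold is an assumption).
    memory_efficient = len(views) > 32
    outputs = model.infer(
        views,
        apply_mask=apply_mask,
        mask_edges=True,
        memory_efficient_inference=memory_efficient,  # mini-batches the dense prediction head
    )
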
examples/Cat-Girl/Cat_Girl.png ADDED

Git LFS Details

  • SHA256: 57fa6d587d598e7a428e8997b86d5c3a06e0e18529bfad8bab78ae03a1f5820f
  • Pointer size: 132 Bytes
  • Size of remote file: 1.69 MB
examples/Downtown/Downtown.jpg ADDED

Git LFS Details

  • SHA256: b87a72df1e3a010f4003ea8c2e7c08d1c6277009d369d684be93a44fb3593a19
  • Pointer size: 130 Bytes
  • Size of remote file: 52.5 kB
examples/Office/Office.jpg ADDED

Git LFS Details

  • SHA256: 28767640002f93b703b24a34a6d75ca24b1ef093a19f52ef0f9d3b074ef68c61
  • Pointer size: 131 Bytes
  • Size of remote file: 198 kB
examples/Safari-Car/Safari_Car.jpg ADDED

Git LFS Details

  • SHA256: cc0b2cf1882f9ad3b0f284474a72ea7c30b5ba40101f9a5d7e899da70ea43d06
  • Pointer size: 131 Bytes
  • Size of remote file: 124 kB
mapanything/models/mapanything/model.py CHANGED
@@ -4,7 +4,7 @@ MapAnything model class defined using UniCeption modules.

  import warnings
  from functools import partial
- from typing import Any, Callable, Dict, List, Type, Union
+ from typing import Any, Callable, Dict, List, Tuple, Type, Union

  import torch
  import torch.nn as nn
@@ -1255,7 +1255,221 @@ class MapAnything(nn.Module, PyTorchModelHubMixin):

          return fused_all_encoder_features_across_views

-     def forward(self, views):
+     def _compute_adaptive_minibatch_size(
+         self,
+         memory_safety_factor: float = 0.95,
+     ) -> int:
+         """
+         Compute adaptive minibatch size based on available PyTorch memory.
+
+         Args:
+             memory_safety_factor: Safety factor to avoid OOM (0.95 = use 95% of available memory)
+
+         Returns:
+             Computed minibatch size
+         """
+         device = self.device
+
+         if device.type == "cuda":
+             # Get available GPU memory
+             torch.cuda.empty_cache()
+             available_memory = torch.cuda.mem_get_info()[0]  # Free memory in bytes
+             usable_memory = (
+                 available_memory * memory_safety_factor
+             )  # Use safety factor to avoid OOM
+         else:
+             # For non-CUDA devices, use conservative default
+             print(
+                 "Non-CUDA device detected. Using conservative default minibatch size of 1 for memory efficient dense prediction head inference."
+             )
+             return 1
+
+         # Determine minibatch size based on available memory
+         max_estimated_memory_per_sample = (
+             680 * 1024 * 1024
+         )  # 680 MB per sample (upper bound profiling using a 518 x 518 input)
+         computed_minibatch_size = int(usable_memory / max_estimated_memory_per_sample)
+         if computed_minibatch_size < 1:
+             computed_minibatch_size = 1
+
+         return computed_minibatch_size
+
+     def downstream_dense_head(
+         self,
+         dense_head_inputs: Union[torch.Tensor, List[torch.Tensor]],
+         img_shape: Tuple[int, int],
+     ):
+         """
+         Run the downstream dense prediction head
+         """
+         if self.pred_head_type == "linear":
+             dense_head_outputs = self.dense_head(
+                 PredictionHeadInput(last_feature=dense_head_inputs)
+             )
+             dense_final_outputs = self.dense_adaptor(
+                 AdaptorInput(
+                     adaptor_feature=dense_head_outputs.decoded_channels,
+                     output_shape_hw=img_shape,
+                 )
+             )
+         elif self.pred_head_type in ["dpt", "dpt+pose"]:
+             dense_head_outputs = self.dense_head(
+                 PredictionHeadLayeredInput(
+                     list_features=dense_head_inputs,
+                     target_output_shape=img_shape,
+                 )
+             )
+             dense_final_outputs = self.dense_adaptor(
+                 AdaptorInput(
+                     adaptor_feature=dense_head_outputs.decoded_channels,
+                     output_shape_hw=img_shape,
+                 )
+             )
+         else:
+             raise ValueError(
+                 f"Invalid pred_head_type: {self.pred_head_type}. Valid options: ['linear', 'dpt', 'dpt+pose']"
+             )
+
+         return dense_final_outputs
+
+     def downstream_head(
+         self,
+         dense_head_inputs: Union[torch.Tensor, List[torch.Tensor]],
+         scale_head_inputs: torch.Tensor,
+         img_shape: Tuple[int, int],
+         memory_efficient_inference: bool = False,
+     ):
+         """
+         Run Prediction Heads & Post-Process Outputs
+         """
+         # Get device
+         device = self.device
+
+         # Use mini-batch inference to run the dense prediction head (the memory bottleneck)
+         # This saves memory and is slower than running the dense prediction head in one go
+         if memory_efficient_inference:
+             # Obtain the batch size of the dense head inputs
+             if self.pred_head_type == "linear":
+                 batch_size = dense_head_inputs.shape[0]
+             elif self.pred_head_type in ["dpt", "dpt+pose"]:
+                 batch_size = dense_head_inputs[0].shape[0]
+             else:
+                 raise ValueError(
+                     f"Invalid pred_head_type: {self.pred_head_type}. Valid options: ['linear', 'dpt', 'dpt+pose']"
+                 )
+
+             # Compute the mini batch size and number of mini batches adaptively based on available memory
+             minibatch = self._compute_adaptive_minibatch_size()
+             num_batches = (batch_size + minibatch - 1) // minibatch
+
+             # Run prediction for each mini-batch
+             dense_final_outputs_list = []
+             pose_final_outputs_list = [] if self.pred_head_type == "dpt+pose" else None
+             for batch_idx in range(num_batches):
+                 start_idx = batch_idx * minibatch
+                 end_idx = min((batch_idx + 1) * minibatch, batch_size)
+
+                 # Get the inputs for the current mini-batch
+                 if self.pred_head_type == "linear":
+                     dense_head_inputs_batch = dense_head_inputs[start_idx:end_idx]
+                 elif self.pred_head_type in ["dpt", "dpt+pose"]:
+                     dense_head_inputs_batch = [
+                         x[start_idx:end_idx] for x in dense_head_inputs
+                     ]
+                 else:
+                     raise ValueError(
+                         f"Invalid pred_head_type: {self.pred_head_type}. Valid options: ['linear', 'dpt', 'dpt+pose']"
+                     )
+
+                 # Dense prediction (mini-batched)
+                 dense_final_outputs_batch = self.downstream_dense_head(
+                     dense_head_inputs_batch, img_shape
+                 )
+                 dense_final_outputs_list.append(dense_final_outputs_batch)
+
+                 # Pose prediction (mini-batched)
+                 if self.pred_head_type == "dpt+pose":
+                     pose_head_inputs_batch = dense_head_inputs[-1][start_idx:end_idx]
+                     pose_head_outputs_batch = self.pose_head(
+                         PredictionHeadInput(last_feature=pose_head_inputs_batch)
+                     )
+                     pose_final_outputs_batch = self.pose_adaptor(
+                         AdaptorInput(
+                             adaptor_feature=pose_head_outputs_batch.decoded_channels,
+                             output_shape_hw=img_shape,
+                         )
+                     )
+                     pose_final_outputs_list.append(pose_final_outputs_batch)
+
+             # Concatenate the dense prediction head outputs from all mini-batches
+             available_keys = dense_final_outputs_batch.__dict__.keys()
+             dense_pred_data_dict = {
+                 key: torch.cat(
+                     [getattr(output, key) for output in dense_final_outputs_list], dim=0
+                 )
+                 for key in available_keys
+             }
+             dense_final_outputs = dense_final_outputs_batch.__class__(
+                 **dense_pred_data_dict
+             )
+
+             # Concatenate the pose prediction head outputs from all mini-batches
+             pose_final_outputs = None
+             if self.pred_head_type == "dpt+pose":
+                 available_keys = pose_final_outputs_batch.__dict__.keys()
+                 pose_pred_data_dict = {
+                     key: torch.cat(
+                         [getattr(output, key) for output in pose_final_outputs_list],
+                         dim=0,
+                     )
+                     for key in available_keys
+                 }
+                 pose_final_outputs = pose_final_outputs_batch.__class__(
+                     **pose_pred_data_dict
+                 )
+
+             # Clear CUDA cache for better memory efficiency
+             if device.type == "cuda":
+                 torch.cuda.empty_cache()
+         else:
+             # Run prediction for all (batch_size * num_views) in one go
+             # Dense prediction
+             dense_final_outputs = self.downstream_dense_head(
+                 dense_head_inputs, img_shape
+             )
+
+             # Pose prediction
+             pose_final_outputs = None
+             if self.pred_head_type == "dpt+pose":
+                 pose_head_outputs = self.pose_head(
+                     PredictionHeadInput(last_feature=dense_head_inputs[-1])
+                 )
+                 pose_final_outputs = self.pose_adaptor(
+                     AdaptorInput(
+                         adaptor_feature=pose_head_outputs.decoded_channels,
+                         output_shape_hw=img_shape,
+                     )
+                 )
+
+         # Scale prediction is lightweight, so we can run it in one go
+         scale_head_output = self.scale_head(
+             PredictionHeadTokenInput(last_feature=scale_head_inputs)
+         )
+         scale_final_output = self.scale_adaptor(
+             AdaptorInput(
+                 adaptor_feature=scale_head_output.decoded_channels,
+                 output_shape_hw=img_shape,
+             )
+         )
+         scale_final_output = scale_final_output.value.squeeze(-1)  # (B, 1, 1) -> (B, 1)
+
+         # Clear CUDA cache for better memory efficiency
+         if memory_efficient_inference and device.type == "cuda":
+             torch.cuda.empty_cache()
+
+         return dense_final_outputs, pose_final_outputs, scale_final_output
+
+     def forward(self, views, memory_efficient_inference=False):
          """
          Forward pass performing the following operations:
          1. Encodes the N input views (images).
@@ -1279,6 +1493,7 @@ class MapAnything(nn.Module, PyTorchModelHubMixin):
              "camera_pose_quats" (tensor): Camera pose quaternions. Tensor of shape (B, 4). Camera pose is opencv (RDF) cam2world transformation.
              "camera_pose_trans" (tensor): Camera pose translations. Tensor of shape (B, 3). Camera pose is opencv (RDF) cam2world transformation.
              "is_metric_scale" (tensor): Boolean tensor indicating whether the geometric inputs are in metric scale or not. Tensor of shape (B, 1).
+             memory_efficient_inference (bool): Whether to use memory efficient inference or not. This runs the dense prediction head (the memory bottleneck) in a memory efficient manner. Default is False.

          Returns:
              List[dict]: A list containing the final outputs for all N views.
@@ -1376,72 +1591,25 @@ class MapAnything(nn.Module, PyTorchModelHubMixin):
                  f"Invalid pred_head_type: {self.pred_head_type}. Valid options: ['linear', 'dpt', 'dpt+pose']"
              )

-         # Downstream task prediction
          with torch.autocast("cuda", enabled=False):
-             # Run Prediction Heads & Post-Process Outputs
+             # Prepare inputs for the downstream heads
              if self.pred_head_type == "linear":
-                 dense_head_outputs = self.dense_head(
-                     PredictionHeadInput(last_feature=dense_head_inputs)
-                 )
-                 dense_final_outputs = self.dense_adaptor(
-                     AdaptorInput(
-                         adaptor_feature=dense_head_outputs.decoded_channels,
-                         output_shape_hw=img_shape,
-                     )
-                 )
-             elif self.pred_head_type == "dpt":
-                 dense_head_outputs = self.dense_head(
-                     PredictionHeadLayeredInput(
-                         list_features=dense_head_inputs_list,
-                         target_output_shape=img_shape,
-                     )
-                 )
-                 dense_final_outputs = self.dense_adaptor(
-                     AdaptorInput(
-                         adaptor_feature=dense_head_outputs.decoded_channels,
-                         output_shape_hw=img_shape,
-                     )
-                 )
-             elif self.pred_head_type == "dpt+pose":
-                 dense_head_outputs = self.dense_head(
-                     PredictionHeadLayeredInput(
-                         list_features=dense_head_inputs_list,
-                         target_output_shape=img_shape,
-                     )
-                 )
-                 dense_final_outputs = self.dense_adaptor(
-                     AdaptorInput(
-                         adaptor_feature=dense_head_outputs.decoded_channels,
-                         output_shape_hw=img_shape,
-                     )
-                 )
-                 pose_head_outputs = self.pose_head(
-                     PredictionHeadInput(last_feature=dense_head_inputs_list[-1])
-                 )
-                 pose_final_outputs = self.pose_adaptor(
-                     AdaptorInput(
-                         adaptor_feature=pose_head_outputs.decoded_channels,
-                         output_shape_hw=img_shape,
-                     )
-                 )
-             else:
-                 raise ValueError(
-                     f"Invalid pred_head_type: {self.pred_head_type}. Valid options: ['linear', 'dpt', 'dpt+pose']"
-                 )
-             scale_head_output = self.scale_head(
-                 PredictionHeadTokenInput(
-                     last_feature=final_info_sharing_multi_view_feat.additional_token_features
-                 )
+                 dense_head_inputs = dense_head_inputs
+             elif self.pred_head_type in ["dpt", "dpt+pose"]:
+                 dense_head_inputs = dense_head_inputs_list
+             scale_head_inputs = (
+                 final_info_sharing_multi_view_feat.additional_token_features
              )
-             scale_final_output = self.scale_adaptor(
-                 AdaptorInput(
-                     adaptor_feature=scale_head_output.decoded_channels,
-                     output_shape_hw=img_shape,
+
+             # Run the downstream heads
+             dense_final_outputs, pose_final_outputs, scale_final_output = (
+                 self.downstream_head(
+                     dense_head_inputs=dense_head_inputs,
+                     scale_head_inputs=scale_head_inputs,
+                     img_shape=img_shape,
+                     memory_efficient_inference=memory_efficient_inference,
                  )
              )
-             scale_final_output = scale_final_output.value.squeeze(
-                 -1
-             )  # (B, 1, 1) -> (B, 1)

          # Prepare the final scene representation for all views
          if self.scene_rep_type in [
@@ -1774,7 +1942,7 @@ class MapAnything(nn.Module, PyTorchModelHubMixin):
              "ray_dirs_prob": 1.0 if use_calibration else 0.0,
              "depth_prob": 1.0 if use_depth else 0.0,
              "cam_prob": 1.0 if use_pose else 0.0,
-             "sparse_depth_prob": 0.0,  # No sparsification during inference
+             "sparse_depth_prob": 0.0,
              "depth_scale_norm_all_prob": 0.0 if use_depth_scale else 1.0,
              "pose_scale_norm_all_prob": 0.0 if use_pose_scale else 1.0,
          }
@@ -1791,6 +1959,7 @@ class MapAnything(nn.Module, PyTorchModelHubMixin):
      def infer(
          self,
          views: List[Dict[str, Any]],
+         memory_efficient_inference: bool = False,
          use_amp: bool = True,
          amp_dtype: str = "bf16",
          apply_mask: bool = True,
@@ -1826,6 +1995,7 @@ class MapAnything(nn.Module, PyTorchModelHubMixin):
                  - 'idx': List[int] where length of list is B - index info for each view
                  - 'true_shape': List[tuple] where length of list is B - true shape info (H, W) for each view

+             memory_efficient_inference: Whether to use memory-efficient inference for dense prediction heads (trades off speed). Defaults to False.
              use_amp: Whether to use automatic mixed precision for faster inference. Defaults to True.
              amp_dtype: The dtype to use for mixed precision. Defaults to "bf16" (bfloat16). Options: "fp16", "bf16", "fp32".
              apply_mask: Whether to apply the non-ambiguous mask to the output. Defaults to True.
@@ -1915,7 +2085,9 @@ class MapAnything(nn.Module, PyTorchModelHubMixin):

          # Run the model
          with torch.autocast("cuda", enabled=bool(use_amp), dtype=amp_dtype):
-             preds = self.forward(processed_views)
+             preds = self.forward(
+                 processed_views, memory_efficient_inference=memory_efficient_inference
+             )

          # Post-process the model outputs
          preds = postprocess_model_outputs_for_inference(
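
To make the adaptive sizing in _compute_adaptive_minibatch_size concrete: the method divides 95% of the free CUDA memory by a profiled 680 MB-per-sample upper bound and clamps the result to at least 1. A standalone sketch of the same arithmetic (illustrative only; the helper name and the example free-memory values are assumptions, not part of the commit):

    # Reproduces the mini-batch sizing formula used by _compute_adaptive_minibatch_size.
    MAX_BYTES_PER_SAMPLE = 680 * 1024 * 1024  # 680 MB upper bound per 518 x 518 sample
    SAFETY_FACTOR = 0.95  # use 95% of free memory to avoid OOM

    def minibatch_size(free_bytes: int) -> int:
        # Usable memory divided by per-sample cost, clamped to at least 1
        return max(1, int(free_bytes * SAFETY_FACTOR / MAX_BYTES_PER_SAMPLE))

    print(minibatch_size(24 * 1024**3))   # -> 34 samples with 24 GiB free
    print(minibatch_size(512 * 1024**2))  # -> 1 when memory is very tight
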
mapanything/utils/hf_utils/visual_util.py CHANGED
@@ -159,7 +159,7 @@ def predictions_to_glb(
      as_mesh=True,
  ) -> trimesh.Scene:
      """
-     Converts VGGT predictions to a 3D scene represented as a GLB file.
+     Converts MapAnything predictions to a 3D scene represented as a GLB file.

      Args:
          predictions (dict): Dictionary containing model predictions with keys: