diff --git a/.gitattributes b/.gitattributes
index a6344aac8c09253b3b630fb776ae94478aa0275b..1b11ea33f9d33b64baf0a497b56b106066b79945 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -33,3 +33,10 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
+examples/**/*.jpg filter=lfs diff=lfs merge=lfs -text
+examples/**/*.jpeg filter=lfs diff=lfs merge=lfs -text
+examples/**/*.png filter=lfs diff=lfs merge=lfs -text
+examples/**/*.bmp filter=lfs diff=lfs merge=lfs -text
+examples/**/*.tiff filter=lfs diff=lfs merge=lfs -text
+examples/**/*.tif filter=lfs diff=lfs merge=lfs -text
+examples/* filter=lfs diff=lfs merge=lfs -text
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..2099be4577a2bec84191d0a74064332812f9a1fa
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,276 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+# For a library or package, you might want to ignore these files since the code is
+# intended to run in multiple environments; otherwise, check them in:
+# .python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# poetry
+# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+# This is especially recommended for binary packages to ensure reproducibility, and is more
+# commonly ignored for libraries.
+# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+#poetry.lock
+
+# pdm
+# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+#pdm.lock
+# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+# in version control.
+# https://pdm.fming.dev/#use-with-ide
+.pdm.toml
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Cython debug symbols
+cython_debug/
+
+# PyCharm
+# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+# be added to the global gitignore or merged into this project gitignore. For a PyCharm
+# project, it is recommended to ignore the entire .idea directory, or at least the following:
+# .idea/workspace.xml
+# .idea/tasks.xml
+# .idea/usage.statistics.xml
+# .idea/dictionaries
+# .idea/shelf
+
+# VS Code
+.vscode/
+*.code-workspace
+
+# Local History for Visual Studio Code
+.history/
+
+# Built Visual Studio Code Extensions
+*.vsix
+
+# Hugging Face specific
+# Model files (usually large binary files)
+*.bin
+*.safetensors
+*.h5
+*.onnx
+*.pkl
+*.pth
+*.pt
+*.ckpt
+*.pb
+*.tflite
+*.mlmodel
+
+# Hugging Face cache and tokens
+.cache/
+cache/
+**/cache/
+hf_token*
+.huggingface/
+transformers_cache/
+datasets_cache/
+input_images_*
+
+# Gradio temporary files
+gradio_cached_examples/
+flagged/
+
+# Data directories
+data/
+checkpoints/
+outputs/
+results/
+logs/
+tmp/
+temp/
+# examples/*/
+# /examples*.jpg
+# *.png
+# *.jpeg
+# examples/
+
+# OS generated files
+.DS_Store
+.DS_Store?
+._*
+.Spotlight-V100
+.Trashes
+ehthumbs.db
+Thumbs.db
+desktop.ini
+
+# Backup files
+*.bak
+*.swp
+*.swo
+*~
+
+# Compressed files
+*.7z
+*.dmg
+*.gz
+*.iso
+*.jar
+*.rar
+*.tar
+*.zip
+
+# IDE and editor files
+.idea/
+*.sublime-project
+*.sublime-workspace
+.vscode/settings.json
+.vscode/tasks.json
+.vscode/launch.json
+.vscode/extensions.json
+
+# Node modules (if any frontend components)
+node_modules/
+npm-debug.log*
+yarn-debug.log*
+yarn-error.log*
+
+# Docker
+Dockerfile*
+docker-compose*
+.dockerignore
+
+# MLOps and experiment tracking
+wandb/
+.neptune/
+mlruns/
+.mlflow/
+tensorboard_logs/
+
+# Secrets and configuration
+*.secret
+*.key
+config.ini
+.env.local
+.env.*.local
+secrets.json
diff --git a/.gradio/certificate.pem b/.gradio/certificate.pem
new file mode 100644
index 0000000000000000000000000000000000000000..b85c8037f6b60976b2546fdbae88312c5246d9a3
--- /dev/null
+++ b/.gradio/certificate.pem
@@ -0,0 +1,31 @@
+-----BEGIN CERTIFICATE-----
+MIIFazCCA1OgAwIBAgIRAIIQz7DSQONZRGPgu2OCiwAwDQYJKoZIhvcNAQELBQAw
+TzELMAkGA1UEBhMCVVMxKTAnBgNVBAoTIEludGVybmV0IFNlY3VyaXR5IFJlc2Vh
+cmNoIEdyb3VwMRUwEwYDVQQDEwxJU1JHIFJvb3QgWDEwHhcNMTUwNjA0MTEwNDM4
+WhcNMzUwNjA0MTEwNDM4WjBPMQswCQYDVQQGEwJVUzEpMCcGA1UEChMgSW50ZXJu
+ZXQgU2VjdXJpdHkgUmVzZWFyY2ggR3JvdXAxFTATBgNVBAMTDElTUkcgUm9vdCBY
+MTCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoCggIBAK3oJHP0FDfzm54rVygc
+h77ct984kIxuPOZXoHj3dcKi/vVqbvYATyjb3miGbESTtrFj/RQSa78f0uoxmyF+
+0TM8ukj13Xnfs7j/EvEhmkvBioZxaUpmZmyPfjxwv60pIgbz5MDmgK7iS4+3mX6U
+A5/TR5d8mUgjU+g4rk8Kb4Mu0UlXjIB0ttov0DiNewNwIRt18jA8+o+u3dpjq+sW
+T8KOEUt+zwvo/7V3LvSye0rgTBIlDHCNAymg4VMk7BPZ7hm/ELNKjD+Jo2FR3qyH
+B5T0Y3HsLuJvW5iB4YlcNHlsdu87kGJ55tukmi8mxdAQ4Q7e2RCOFvu396j3x+UC
+B5iPNgiV5+I3lg02dZ77DnKxHZu8A/lJBdiB3QW0KtZB6awBdpUKD9jf1b0SHzUv
+KBds0pjBqAlkd25HN7rOrFleaJ1/ctaJxQZBKT5ZPt0m9STJEadao0xAH0ahmbWn
+OlFuhjuefXKnEgV4We0+UXgVCwOPjdAvBbI+e0ocS3MFEvzG6uBQE3xDk3SzynTn
+jh8BCNAw1FtxNrQHusEwMFxIt4I7mKZ9YIqioymCzLq9gwQbooMDQaHWBfEbwrbw
+qHyGO0aoSCqI3Haadr8faqU9GY/rOPNk3sgrDQoo//fb4hVC1CLQJ13hef4Y53CI
+rU7m2Ys6xt0nUW7/vGT1M0NPAgMBAAGjQjBAMA4GA1UdDwEB/wQEAwIBBjAPBgNV
+HRMBAf8EBTADAQH/MB0GA1UdDgQWBBR5tFnme7bl5AFzgAiIyBpY9umbbjANBgkq
+hkiG9w0BAQsFAAOCAgEAVR9YqbyyqFDQDLHYGmkgJykIrGF1XIpu+ILlaS/V9lZL
+ubhzEFnTIZd+50xx+7LSYK05qAvqFyFWhfFQDlnrzuBZ6brJFe+GnY+EgPbk6ZGQ
+3BebYhtF8GaV0nxvwuo77x/Py9auJ/GpsMiu/X1+mvoiBOv/2X/qkSsisRcOj/KK
+NFtY2PwByVS5uCbMiogziUwthDyC3+6WVwW6LLv3xLfHTjuCvjHIInNzktHCgKQ5
+ORAzI4JMPJ+GslWYHb4phowim57iaztXOoJwTdwJx4nLCgdNbOhdjsnvzqvHu7Ur
+TkXWStAmzOVyyghqpZXjFaH3pO3JLF+l+/+sKAIuvtd7u+Nxe5AW0wdeRlN8NwdC
+jNPElpzVmbUq4JUagEiuTDkHzsxHpFKVK7q4+63SM1N95R1NbdWhscdCb+ZAJzVc
+oyi3B43njTOQ5yOf+1CceWxG1bQVs5ZufpsMljq4Ui0/1lvh+wjChP4kqKOJ2qxq
+4RgqsahDYVvTH9w7jXbyLeiNdd8XM2w9U/t7y0Ff/9yi0GE44Za4rF2LN9d11TPA
+mRGunUHBcnWEvgJBQl9nJEiU0Zsnvgc/ubhPgXRR4Xq37Z0j4r7g1SgEEzwxA57d
+emyPxgcYxn/eR44/KJ4EBs+lVDR3veyJm+kXQ99b21/+jh5Xos1AnX5iItreGCc=
+-----END CERTIFICATE-----
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..a889762ffcad48c01a92e3e9f3ba956d34967a2b
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,201 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright 2025 Meta, Nikhil Keetha
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
diff --git a/README.md b/README.md
index f4b141cdefe48e51dac3d904e2737d49f84892d0..fa42a21854aec982038d3554f04a6bc42805e4cf 100644
--- a/README.md
+++ b/README.md
@@ -1,8 +1,8 @@
---
-title: Map Anything
-emoji: 🏆
-colorFrom: yellow
-colorTo: pink
+title: Mapanything Gradio
+emoji: 🐠
+colorFrom: purple
+colorTo: green
sdk: gradio
sdk_version: 5.44.1
app_file: app.py
diff --git a/README_grad.md b/README_grad.md
new file mode 100644
index 0000000000000000000000000000000000000000..2df8d2357b0a95a41917648bf6ef8456a181e55b
--- /dev/null
+++ b/README_grad.md
@@ -0,0 +1,12 @@
+---
+title: Mapanything Gradio
+emoji: 🐠
+colorFrom: purple
+colorTo: green
+sdk: gradio
+sdk_version: 5.44.1
+app_file: app.py
+pinned: false
+---
+
+Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..ab365ee48ef8cb16680d46b07242048849512b38
--- /dev/null
+++ b/app.py
@@ -0,0 +1,1752 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+# conda activate hf3.10
+
+import base64
+import gc
+import os
+import shutil
+import sys
+import time
+from datetime import datetime
+
+import cv2
+import gradio as gr
+import numpy as np
+import spaces
+import torch
+from huggingface_hub import hf_hub_download
+
+sys.path.append("mapanything/")
+
+from hf_utils.css_and_html import (
+ get_acknowledgements_html,
+ get_description_html,
+ get_gradio_theme,
+ get_header_html,
+ GRADIO_CSS,
+ MEASURE_INSTRUCTIONS_HTML,
+)
+from hf_utils.vgg_geometry import unproject_depth_map_to_point_map
+from hf_utils.visual_util import predictions_to_glb
+from mapanything.models import init_model
+from mapanything.utils.geometry import depth_edge, normals_edge, points_to_normals
+from mapanything.utils.image import load_images, rgb
+from mapanything.utils.inference import loss_of_one_batch_multi_view
+
+
+def get_logo_base64():
+ """Convert WAI logo to base64 for embedding in HTML"""
+ logo_path = "examples/wai_logo/wai_logo.png"
+ try:
+ with open(logo_path, "rb") as img_file:
+ img_data = img_file.read()
+ base64_str = base64.b64encode(img_data).decode()
+ return f"data:image/png;base64,{base64_str}"
+ except FileNotFoundError:
+ return None
+
+
+print("Initializing and loading MapAnything model...")
+
+
+def load_hf_token():
+ """Load HuggingFace access token from local file"""
+ token_file_paths = [
+ "~/hf_token.txt",
+ ]
+
+ for token_path in token_file_paths:
+ if os.path.exists(token_path):
+ try:
+ with open(token_path, "r") as f:
+ token = f.read().strip()
+ print(f"Loaded HuggingFace token from: {token_path}")
+ return token
+ except Exception as e:
+ print(f"Error reading token from {token_path}: {e}")
+ continue
+
+ # Also try environment variable
+ # see https://huggingface.co/docs/hub/spaces-overview#managing-secrets on options
+ token = (
+ os.getenv("HF_TOKEN")
+ or os.getenv("HUGGING_FACE_HUB_TOKEN")
+ or os.getenv("HUGGING_FACE_MODEL_TOKEN")
+ )
+ if token:
+ print("Loaded HuggingFace token from environment variable")
+ return token
+
+ print(
+ "Warning: No HuggingFace token found. Model loading may fail for private repositories."
+ )
+ return None
+
+
+def init_hydra_config(config_path, overrides=None):
+ "Initialize Hydra config"
+ import hydra
+
+ config_dir = os.path.dirname(config_path)
+ config_name = os.path.basename(config_path).split(".")[0]
+ relative_path = os.path.relpath(config_dir, os.path.dirname(__file__))
+ hydra.core.global_hydra.GlobalHydra.instance().clear()
+ hydra.initialize(version_base=None, config_path=relative_path)
+ if overrides is not None:
+ cfg = hydra.compose(config_name=config_name, overrides=overrides)
+ else:
+ cfg = hydra.compose(config_name=config_name)
+ return cfg
+
+
+def init_inference_model(config, ckpt_path, device):
+ "Initialize the model for inference"
+ if isinstance(config, dict):
+ config_path = config["path"]
+ overrrides = config["config_overrides"]
+ model_args = init_hydra_config(config_path, overrides=overrrides)
+ model = init_model(model_args.model.model_str, model_args.model.model_config)
+ else:
+ config_path = config
+ model_args = init_hydra_config(config_path)
+ model = init_model(model_args.model_str, model_args.model_config)
+ model.to(device)
+ if ckpt_path is not None:
+ print("Loading model from: ", ckpt_path)
+
+ # Load HuggingFace token for private repositories
+ hf_token = load_hf_token()
+
+ # Try to download from HuggingFace Hub first if it's a HF URL
+ if "huggingface.co" in ckpt_path:
+ try:
+ # Extract repo_id and filename from URL
+ # URL format: https://huggingface.co/facebook/MapAnything/resolve/main/mapa_curri_24v_13d_48ipg_64g.pth
+ parts = ckpt_path.replace("https://huggingface.co/", "").split("/")
+ repo_id = f"{parts[0]}/{parts[1]}" # e.g., "facebook/MapAnything"
+ filename = "/".join(
+ parts[4:]
+ ) # e.g., "mapa_curri_24v_13d_48ipg_64g.pth"
+
+ print(f"Downloading from HuggingFace Hub: {repo_id}/{filename}")
+ local_file = hf_hub_download(
+ repo_id=repo_id,
+ filename=filename,
+ token=hf_token,
+ cache_dir=None, # Use default cache
+ )
+ ckpt = torch.load(local_file, map_location=device, weights_only=False)
+ except Exception as e:
+ print(f"HuggingFace Hub download failed: {e}")
+ print("Falling back to torch.hub.load_state_dict_from_url...")
+ # Fallback to original method
+ ckpt = torch.hub.load_state_dict_from_url(
+ ckpt_path, map_location=device
+ )
+ else:
+ # Use original method for non-HF URLs
+ ckpt = torch.hub.load_state_dict_from_url(ckpt_path, map_location=device)
+
+ print(model.load_state_dict(ckpt["model"], strict=False))
+ model.eval()
+ return model
+
+
+# MapAnything Configuration
+high_level_config = {
+ "path": "configs/train.yaml",
+ "config_overrides": [
+ "machine=aws",
+ "model=mapanything",
+ "model/task=images_only",
+ "model.encoder.uses_torch_hub=false",
+ ],
+ "checkpoint_path": "https://huggingface.co/facebook/MapAnything/resolve/main/mapa_curri_24v_13d_48ipg_64g.pth",
+ "trained_with_amp": True,
+ "trained_with_amp_dtype": "fp16",
+ "data_norm_type": "dinov2",
+ "patch_size": 14,
+ "resolution": 518,
+}
+
+# Initialize model - this will be done on GPU when needed
+model = None
+
+
+# -------------------------------------------------------------------------
+# 1) Core model inference
+# -------------------------------------------------------------------------
+@spaces.GPU(duration=120)
+def run_model(target_dir, model_placeholder):
+ """
+ Run the MapAnything model on images in the 'target_dir/images' folder and return predictions.
+ """
+ global model
+ print(f"Processing images from {target_dir}")
+
+ # Device check
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ device = torch.device(device)
+ # if not torch.cuda.is_available():
+ # raise ValueError("CUDA is not available. Check your environment.")
+
+ # Initialize model if not already done
+ if model is None:
+ print("Initializing MapAnything model...")
+ model = init_inference_model(
+ high_level_config, high_level_config["checkpoint_path"], device
+ )
+ else:
+ model = model.to(device)
+
+ model.eval()
+
+ # Load images using MapAnything's load_images function
+ print("Loading images...")
+ image_folder_path = os.path.join(target_dir, "images")
+ views = load_images(
+ image_folder_path,
+ resolution_set=high_level_config["resolution"],
+ verbose=False,
+ norm_type=high_level_config["data_norm_type"],
+ patch_size=high_level_config["patch_size"],
+ stride=1,
+ )
+
+ print(f"Loaded {len(views)} images")
+ if len(views) == 0:
+ raise ValueError("No images found. Check your upload.")
+
+ # Run inference using MapAnything's inference function
+ print("Running MapAnything inference...")
+ with torch.no_grad():
+ pred_result = loss_of_one_batch_multi_view(
+ views,
+ model,
+ None,
+ device,
+ use_amp=high_level_config["trained_with_amp"],
+ amp_dtype=high_level_config["trained_with_amp_dtype"],
+ )
+
+ # Convert predictions to format expected by visualization
+ predictions = {}
+
+ # Initialize lists for the required keys
+ extrinsic_list = []
+ intrinsic_list = []
+ world_points_list = []
+ depth_maps_list = []
+ images_list = []
+ confidence_list = []
+ final_mask_list = []
+
+ # Check if confidence data is available
+ has_confidence = False
+ for view_idx, view in enumerate(views):
+ view_key = f"pred{view_idx + 1}"
+ if view_key in pred_result and "conf" in pred_result[view_key]:
+ has_confidence = True
+ break
+
+ # Extract predictions for each view
+ for view_idx, view in enumerate(views):
+ # Get image for colors
+ image = rgb(view["img"], norm_type=high_level_config["data_norm_type"])
+
+ view_key = f"pred{view_idx + 1}"
+ if view_key in pred_result:
+ pred_pts3d = pred_result[view_key]["pts3d"][0].cpu().numpy()
+
+ # Get confidence data if available
+ confidence_map = None
+ if "conf" in pred_result[view_key]:
+ confidence_map = pred_result[view_key]["conf"][0].cpu().numpy()
+
+ # Compute final_mask just like in visualize_raw_inference_output function
+ # Create the prediction mask based on parameters
+ pred_mask = None
+ use_gt_mask_on_pred = False # Set based on your requirements
+ use_pred_mask = True # Set based on your requirements
+ use_non_ambi_mask = True # Set based on your requirements
+ use_conf_mask = False # Set based on your requirements
+ conf_percentile = 10 # Set based on your requirements
+ use_edge_mask = True # Set based on your requirements
+ pts_edge_tol = 5 # Set based on your requirements
+ depth_edge_rtol = 0.03 # Set based on your requirements
+
+ if use_pred_mask:
+ # Get non ambiguous mask if available and requested
+ has_non_ambiguous_mask = (
+ "non_ambiguous_mask" in pred_result[view_key] and use_non_ambi_mask
+ )
+ if has_non_ambiguous_mask:
+ non_ambiguous_mask = (
+ pred_result[view_key]["non_ambiguous_mask"][0].cpu().numpy()
+ )
+ pred_mask = non_ambiguous_mask
+
+ # Get confidence mask if available and requested
+ has_conf = "conf" in pred_result[view_key] and use_conf_mask
+ if has_conf:
+ confidences = pred_result[view_key]["conf"][0].cpu()
+ percentile_threshold = torch.quantile(
+ confidences, conf_percentile / 100.0
+ )
+ conf_mask = confidences > percentile_threshold
+ conf_mask = conf_mask.numpy()
+ if pred_mask is not None:
+ pred_mask = pred_mask & conf_mask
+ else:
+ pred_mask = conf_mask
+
+ # Apply edge mask if requested
+ if use_edge_mask and pred_mask is not None:
+ if "cam_quats" not in pred_result[view_key]:
+ # For direct point prediction
+ # Compute normals and edge mask
+ normals, normals_mask = points_to_normals(
+ pred_pts3d, mask=pred_mask
+ )
+ edge_mask = ~(
+ normals_edge(normals, tol=pts_edge_tol, mask=normals_mask)
+ )
+ else:
+ # For ray-based prediction
+ ray_depth = pred_result[view_key]["depth_along_ray"][0].cpu()
+ local_pts3d = (
+ pred_result[view_key]["ray_directions"][0].cpu() * ray_depth
+ )
+ depth_z = local_pts3d[..., 2].numpy()
+
+ # Compute normals and edge mask
+ normals, normals_mask = points_to_normals(
+ pred_pts3d, mask=pred_mask
+ )
+ edge_mask = ~(
+ depth_edge(depth_z, rtol=depth_edge_rtol, mask=pred_mask)
+ & normals_edge(normals, tol=pts_edge_tol, mask=normals_mask)
+ )
+ if pred_mask is not None:
+ pred_mask = pred_mask & edge_mask
+
+ # Determine final mask to use (like in visualize_raw_inference_output)
+ final_mask = None
+ valid_mask = np.ones_like(
+ pred_pts3d[..., 0], dtype=bool
+ ) # Create dummy valid_mask for app.py context
+
+ if use_gt_mask_on_pred:
+ final_mask = valid_mask
+ if use_pred_mask and pred_mask is not None:
+ final_mask = final_mask & pred_mask
+ elif use_pred_mask and pred_mask is not None:
+ final_mask = pred_mask
+ else:
+ final_mask = np.ones_like(valid_mask, dtype=bool)
+
+ # Check if we have camera pose and intrinsics data
+ if "cam_quats" in pred_result[view_key]:
+ # Get decoupled quantities (like in visualize_raw_custom_data_inference_output)
+ cam_quats = pred_result[view_key]["cam_quats"][0].cpu()
+ cam_trans = pred_result[view_key]["cam_trans"][0].cpu()
+ ray_directions = pred_result[view_key]["ray_directions"][0].cpu()
+ ray_depth = pred_result[view_key]["depth_along_ray"][0].cpu()
+
+ # Convert the quantities
+ from mapanything.utils.geometry import (
+ quaternion_to_rotation_matrix,
+ recover_pinhole_intrinsics_from_ray_directions,
+ )
+
+ cam_rot = quaternion_to_rotation_matrix(cam_quats)
+ cam_pose = torch.eye(4)
+ cam_pose[:3, :3] = cam_rot
+ cam_pose[:3, 3] = cam_trans
+ cam_pose = np.linalg.inv(cam_pose)
+ cam_intrinsics = recover_pinhole_intrinsics_from_ray_directions(
+ ray_directions, use_geometric_calculation=True
+ )
+
+ # Compute depth as in app_map.py
+ local_pts3d = ray_directions * ray_depth
+ depth_z = local_pts3d[..., 2]
+
+ # Convert to numpy and extract 3x4 extrinsic (remove bottom row)
+ extrinsic = cam_pose[:3, :4].numpy() # Shape: (3, 4)
+ intrinsic = cam_intrinsics.numpy() # Shape: (3, 3)
+ depth_z = depth_z.numpy() # Shape: (H, W)
+ else:
+ # Use dummy values if camera info not available
+ # extrinsic: (3, 4) - [R|t] matrix
+ extrinsic = np.eye(3, 4) # Identity rotation, zero translation
+ # intrinsic: (3, 3) - camera intrinsic matrix
+ intrinsic = np.eye(3)
+ # depth_z: (H, W) - dummy depth values
+ depth_z = np.zeros_like(pred_pts3d[..., 0])
+
+ # Append to lists
+ extrinsic_list.append(extrinsic)
+ intrinsic_list.append(intrinsic)
+ world_points_list.append(pred_pts3d)
+ depth_maps_list.append(depth_z)
+ images_list.append(image[0]) # Add image to list
+ final_mask_list.append(final_mask) # Add final_mask to list
+
+ # Add confidence data (or None if not available)
+ if confidence_map is not None:
+ confidence_list.append(confidence_map)
+ elif has_confidence:
+ # If some views have confidence but this one doesn't, add dummy confidence
+ confidence_list.append(np.ones_like(depth_z))
+
+ # Convert lists to numpy arrays with required shapes
+ # extrinsic: (S, 3, 4) - batch of camera extrinsic matrices
+ predictions["extrinsic"] = np.stack(extrinsic_list, axis=0)
+
+ # intrinsic: (S, 3, 3) - batch of camera intrinsic matrices
+ predictions["intrinsic"] = np.stack(intrinsic_list, axis=0)
+
+ # world_points: (S, H, W, 3) - batch of 3D world points
+ predictions["world_points"] = np.stack(world_points_list, axis=0)
+
+ # depth: (S, H, W, 1) or (S, H, W) - batch of depth maps
+ depth_maps = np.stack(depth_maps_list, axis=0)
+ # Add channel dimension if needed to match (S, H, W, 1) format
+ if len(depth_maps.shape) == 3:
+ depth_maps = depth_maps[..., np.newaxis]
+ predictions["depth"] = depth_maps
+
+ # images: (S, H, W, 3) - batch of input images
+ predictions["images"] = np.stack(images_list, axis=0)
+
+ # confidence: (S, H, W) - batch of confidence maps (only if available)
+ if confidence_list:
+ predictions["confidence"] = np.stack(confidence_list, axis=0)
+
+ # final_mask: (S, H, W) - batch of final masks for filtering
+ predictions["final_mask"] = np.stack(final_mask_list, axis=0)
+
+ world_points = unproject_depth_map_to_point_map(
+ depth_maps, predictions["extrinsic"], predictions["intrinsic"]
+ )
+ predictions["world_points_from_depth"] = world_points
+
+ # Process data for visualization tabs (depth, normal, measure)
+ processed_data = process_predictions_for_visualization(
+ pred_result, views, high_level_config
+ )
+
+ # Clean up
+ torch.cuda.empty_cache()
+
+ return predictions, processed_data
+
+
+def update_view_selectors(processed_data):
+ """Update view selector dropdowns based on available views"""
+ if processed_data is None or len(processed_data) == 0:
+ choices = ["View 1"]
+ else:
+ num_views = len(processed_data)
+ choices = [f"View {i + 1}" for i in range(num_views)]
+
+ return (
+ gr.Dropdown(choices=choices, value=choices[0]), # depth_view_selector
+ gr.Dropdown(choices=choices, value=choices[0]), # normal_view_selector
+ gr.Dropdown(choices=choices, value=choices[0]), # measure_view_selector
+ )
+
+
+def get_view_data_by_index(processed_data, view_index):
+ """Get view data by index, handling bounds"""
+ if processed_data is None or len(processed_data) == 0:
+ return None
+
+ view_keys = list(processed_data.keys())
+ if view_index < 0 or view_index >= len(view_keys):
+ view_index = 0
+
+ return processed_data[view_keys[view_index]]
+
+
+def update_depth_view(processed_data, view_index, conf_thres=None):
+ """Update depth view for a specific view index with optional confidence filtering"""
+ view_data = get_view_data_by_index(processed_data, view_index)
+ if view_data is None or view_data["depth"] is None:
+ return None
+
+ # Use confidence filtering if available
+ confidence = view_data.get("confidence")
+ return colorize_depth(
+ view_data["depth"], confidence=confidence, conf_thres=conf_thres
+ )
+
+
+def update_normal_view(processed_data, view_index, conf_thres=None):
+ """Update normal view for a specific view index with optional confidence filtering"""
+ view_data = get_view_data_by_index(processed_data, view_index)
+ if view_data is None or view_data["normal"] is None:
+ return None
+
+ # Use confidence filtering if available
+ confidence = view_data.get("confidence")
+ return colorize_normal(
+ view_data["normal"], confidence=confidence, conf_thres=conf_thres
+ )
+
+
+def update_measure_view(processed_data, view_index):
+ """Update measure view for a specific view index"""
+ view_data = get_view_data_by_index(processed_data, view_index)
+ if view_data is None:
+ return None, [] # image, measure_points
+ return view_data["image"], []
+
+
+def navigate_depth_view(
+ processed_data, current_selector_value, direction, conf_thres=None
+):
+ """Navigate depth view (direction: -1 for previous, +1 for next)"""
+ if processed_data is None or len(processed_data) == 0:
+ return "View 1", None
+
+ # Parse current view number
+ try:
+ current_view = int(current_selector_value.split()[1]) - 1
+ except:
+ current_view = 0
+
+ num_views = len(processed_data)
+ new_view = (current_view + direction) % num_views
+
+ new_selector_value = f"View {new_view + 1}"
+ depth_vis = update_depth_view(processed_data, new_view, conf_thres=conf_thres)
+
+ return new_selector_value, depth_vis
+
+
+def navigate_normal_view(
+ processed_data, current_selector_value, direction, conf_thres=None
+):
+ """Navigate normal view (direction: -1 for previous, +1 for next)"""
+ if processed_data is None or len(processed_data) == 0:
+ return "View 1", None
+
+ # Parse current view number
+ try:
+ current_view = int(current_selector_value.split()[1]) - 1
+ except:
+ current_view = 0
+
+ num_views = len(processed_data)
+ new_view = (current_view + direction) % num_views
+
+ new_selector_value = f"View {new_view + 1}"
+ normal_vis = update_normal_view(processed_data, new_view, conf_thres=conf_thres)
+
+ return new_selector_value, normal_vis
+
+
+def navigate_measure_view(processed_data, current_selector_value, direction):
+ """Navigate measure view (direction: -1 for previous, +1 for next)"""
+ if processed_data is None or len(processed_data) == 0:
+ return "View 1", None, []
+
+ # Parse current view number
+ try:
+ current_view = int(current_selector_value.split()[1]) - 1
+ except:
+ current_view = 0
+
+ num_views = len(processed_data)
+ new_view = (current_view + direction) % num_views
+
+ new_selector_value = f"View {new_view + 1}"
+ measure_image, measure_points = update_measure_view(processed_data, new_view)
+
+ return new_selector_value, measure_image, measure_points
+
+
+def populate_visualization_tabs(processed_data, conf_thres=None):
+ """Populate the depth, normal, and measure tabs with processed data"""
+ if processed_data is None or len(processed_data) == 0:
+ return None, None, None, []
+
+ # Use update functions to ensure confidence filtering is applied from the start
+ depth_vis = update_depth_view(processed_data, 0, conf_thres=conf_thres)
+ normal_vis = update_normal_view(processed_data, 0, conf_thres=conf_thres)
+ measure_img, _ = update_measure_view(processed_data, 0)
+
+ return depth_vis, normal_vis, measure_img, []
+
+
+# -------------------------------------------------------------------------
+# 2) Handle uploaded video/images --> produce target_dir + images
+# -------------------------------------------------------------------------
+def handle_uploads(input_video, input_images, s_time_interval=1.0):
+ """
+ Create a new 'target_dir' + 'images' subfolder, and place user-uploaded
+ images or extracted frames from video into it. Return (target_dir, image_paths).
+ """
+ start_time = time.time()
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ # Create a unique folder name
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
+ target_dir = f"input_images_{timestamp}"
+ target_dir_images = os.path.join(target_dir, "images")
+
+ # Clean up if somehow that folder already exists
+ if os.path.exists(target_dir):
+ shutil.rmtree(target_dir)
+ os.makedirs(target_dir)
+ os.makedirs(target_dir_images)
+
+ image_paths = []
+
+ # --- Handle images ---
+ if input_images is not None:
+ for file_data in input_images:
+ if isinstance(file_data, dict) and "name" in file_data:
+ file_path = file_data["name"]
+ else:
+ file_path = file_data
+ dst_path = os.path.join(target_dir_images, os.path.basename(file_path))
+ shutil.copy(file_path, dst_path)
+ image_paths.append(dst_path)
+
+ # --- Handle video ---
+ if input_video is not None:
+ if isinstance(input_video, dict) and "name" in input_video:
+ video_path = input_video["name"]
+ else:
+ video_path = input_video
+
+ vs = cv2.VideoCapture(video_path)
+ fps = vs.get(cv2.CAP_PROP_FPS)
+ frame_interval = int(fps * s_time_interval) # 1 frame/sec
+
+ count = 0
+ video_frame_num = 0
+ while True:
+ gotit, frame = vs.read()
+ if not gotit:
+ break
+ count += 1
+ if count % frame_interval == 0:
+ image_path = os.path.join(
+ target_dir_images, f"{video_frame_num:06}.png"
+ )
+ cv2.imwrite(image_path, frame)
+ image_paths.append(image_path)
+ video_frame_num += 1
+
+ # Sort final images for gallery
+ image_paths = sorted(image_paths)
+
+ end_time = time.time()
+ print(
+ f"Files copied to {target_dir_images}; took {end_time - start_time:.3f} seconds"
+ )
+ return target_dir, image_paths
+
+
+# -------------------------------------------------------------------------
+# 3) Update gallery on upload
+# -------------------------------------------------------------------------
+def update_gallery_on_upload(input_video, input_images, s_time_interval=1.0):
+ """
+ Whenever user uploads or changes files, immediately handle them
+ and show in the gallery. Return (target_dir, image_paths).
+ If nothing is uploaded, returns "None" and empty list.
+ """
+ if not input_video and not input_images:
+ return None, None, None, None
+ target_dir, image_paths = handle_uploads(input_video, input_images, s_time_interval)
+ return (
+ None,
+ target_dir,
+ image_paths,
+ "Upload complete. Click 'Reconstruct' to begin 3D processing.",
+ )
+
+
+# -------------------------------------------------------------------------
+# 4) Reconstruction: uses the target_dir plus any viz parameters
+# -------------------------------------------------------------------------
+@spaces.GPU(duration=120)
+def gradio_demo(
+ target_dir,
+ conf_thres=3.0,
+ frame_filter="All",
+ show_cam=True,
+ filter_sky=False,
+ filter_black_bg=False,
+ filter_white_bg=False,
+ mask_ambiguous=False,
+):
+ """
+ Perform reconstruction using the already-created target_dir/images.
+ """
+ if not os.path.isdir(target_dir) or target_dir == "None":
+ return None, "No valid target directory found. Please upload first.", None, None
+
+ start_time = time.time()
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ # Always use Pointmap Branch for MapAnything
+ prediction_mode = "Pointmap Branch"
+
+ # Prepare frame_filter dropdown
+ target_dir_images = os.path.join(target_dir, "images")
+ all_files = (
+ sorted(os.listdir(target_dir_images))
+ if os.path.isdir(target_dir_images)
+ else []
+ )
+ all_files = [f"{i}: {filename}" for i, filename in enumerate(all_files)]
+ frame_filter_choices = ["All"] + all_files
+
+ print("Running MapAnything model...")
+ with torch.no_grad():
+ predictions, processed_data = run_model(target_dir, None)
+
+ # Save predictions
+ prediction_save_path = os.path.join(target_dir, "predictions.npz")
+ np.savez(prediction_save_path, **predictions)
+
+ # Handle None frame_filter
+ if frame_filter is None:
+ frame_filter = "All"
+
+ # Build a GLB file name
+ glbfile = os.path.join(
+ target_dir,
+ f"glbscene_{conf_thres}_{frame_filter.replace('.', '_').replace(':', '').replace(' ', '_')}_cam{show_cam}_sky{filter_sky}_black{filter_black_bg}_white{filter_white_bg}_mask{mask_ambiguous}_pred{prediction_mode.replace(' ', '_')}.glb",
+ )
+
+ # Convert predictions to GLB
+ glbscene = predictions_to_glb(
+ predictions,
+ conf_thres=conf_thres,
+ filter_by_frames=frame_filter,
+ show_cam=show_cam,
+ target_dir=target_dir,
+ prediction_mode=prediction_mode,
+ mask_sky=filter_sky,
+ mask_black_bg=filter_black_bg,
+ mask_white_bg=filter_white_bg,
+ mask_ambiguous=mask_ambiguous,
+ )
+ glbscene.export(file_obj=glbfile)
+
+ # Cleanup
+ del predictions
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ end_time = time.time()
+ print(f"Total time: {end_time - start_time:.2f} seconds")
+ log_msg = (
+ f"Reconstruction Success ({len(all_files)} frames). Waiting for visualization."
+ )
+
+ # Populate visualization tabs with processed data
+ depth_vis, normal_vis, measure_img, measure_pts = populate_visualization_tabs(
+ processed_data, conf_thres=conf_thres
+ )
+
+ # Update view selectors based on available views
+ depth_selector, normal_selector, measure_selector = update_view_selectors(
+ processed_data
+ )
+
+ return (
+ glbfile,
+ log_msg,
+ gr.Dropdown(choices=frame_filter_choices, value=frame_filter, interactive=True),
+ processed_data,
+ depth_vis,
+ normal_vis,
+ measure_img,
+ "", # measure_text (empty initially)
+ depth_selector,
+ normal_selector,
+ measure_selector,
+ )
+
+
+# -------------------------------------------------------------------------
+# 5) Helper functions for UI resets + re-visualization
+# -------------------------------------------------------------------------
+def apply_confidence_filtering(data, confidence, conf_thres):
+ """Apply confidence filtering to data arrays"""
+ if confidence is None or data is None:
+ return data
+
+ # Convert confidence threshold from percentage to confidence value
+ conf_threshold = np.percentile(confidence, conf_thres)
+ conf_mask = (confidence >= conf_threshold) & (confidence > 1e-5)
+
+ # conf_mask = confidence >= (conf_thres)
+
+ # Apply mask to data
+ if len(data.shape) == 3: # 3D data (H, W, C)
+ filtered_data = data.copy()
+ for c in range(data.shape[2]):
+ filtered_data[:, :, c] = np.where(conf_mask, data[:, :, c], 0)
+ elif len(data.shape) == 2: # 2D data (H, W)
+ filtered_data = np.where(conf_mask, data, 0)
+ else:
+ filtered_data = data
+
+ return filtered_data
+
+
+def colorize_depth(depth_map, confidence=None, conf_thres=None):
+ """Convert depth map to colorized visualization with optional confidence filtering"""
+ if depth_map is None:
+ return None
+
+ # Apply confidence filtering if available
+ if confidence is not None and conf_thres is not None:
+ depth_map = apply_confidence_filtering(depth_map, confidence, conf_thres)
+
+ # Normalize depth to 0-1 range
+ depth_normalized = depth_map.copy()
+ valid_mask = depth_normalized > 0
+ if valid_mask.sum() > 0:
+ valid_depths = depth_normalized[valid_mask]
+ p5 = np.percentile(valid_depths, 5)
+ p95 = np.percentile(valid_depths, 95)
+
+ depth_normalized[valid_mask] = (depth_normalized[valid_mask] - p5) / (p95 - p5)
+
+ # Apply colormap
+ import matplotlib.pyplot as plt
+
+ colormap = plt.cm.turbo_r
+ # colormap = plt.cm.plasma
+ # colormap = plt.cm.viridis
+ colored = colormap(depth_normalized)
+ colored = (colored[:, :, :3] * 255).astype(np.uint8)
+
+ # Set invalid pixels to white
+ colored[~valid_mask] = [255, 255, 255]
+
+ return colored
+
+
+def colorize_normal(normal_map, confidence=None, conf_thres=None):
+ """Convert normal map to colorized visualization with optional confidence filtering"""
+ if normal_map is None:
+ return None
+
+ # Apply confidence filtering if available
+ if confidence is not None and conf_thres is not None:
+ normal_map = apply_confidence_filtering(normal_map, confidence, conf_thres)
+
+ # Normalize normals to [0, 1] range for visualization
+ normal_vis = (normal_map + 1.0) / 2.0
+ normal_vis = (normal_vis * 255).astype(np.uint8)
+
+ return normal_vis
+
+
+def process_predictions_for_visualization(pred_result, views, high_level_config):
+ """Extract depth, normal, and 3D points from predictions for visualization"""
+ processed_data = {}
+
+ # Check if confidence data is available in any view
+ has_confidence_data = False
+ for view_idx, view in enumerate(views):
+ view_key = f"pred{view_idx + 1}"
+ if view_key in pred_result and "conf" in pred_result[view_key]:
+ has_confidence_data = True
+ break
+
+ # Process each view
+ for view_idx, view in enumerate(views):
+ view_key = f"pred{view_idx + 1}"
+ if view_key not in pred_result:
+ continue
+
+ # Get image
+ image = rgb(view["img"], norm_type=high_level_config["data_norm_type"])
+
+ # Get predicted points
+ pred_pts3d = pred_result[view_key]["pts3d"][0].cpu().numpy()
+
+ # Initialize data for this view
+ view_data = {
+ "image": image[0],
+ "points3d": pred_pts3d,
+ "depth": None,
+ "normal": None,
+ "mask": None,
+ "confidence": None,
+ "has_confidence": has_confidence_data,
+ }
+
+ # Get confidence data if available
+ if "conf" in pred_result[view_key]:
+ confidence = pred_result[view_key]["conf"][0].cpu().numpy()
+ view_data["confidence"] = confidence
+
+ # Get masks if available
+ has_non_ambiguous_mask = "non_ambiguous_mask" in pred_result[view_key]
+ if has_non_ambiguous_mask:
+ view_data["mask"] = (
+ pred_result[view_key]["non_ambiguous_mask"][0].cpu().numpy()
+ )
+
+ # Extract depth and camera info if available
+ if "cam_quats" in pred_result[view_key]:
+ ray_directions = pred_result[view_key]["ray_directions"][0].cpu()
+ ray_depth = pred_result[view_key]["depth_along_ray"][0].cpu()
+
+ # Compute depth
+ local_pts3d = ray_directions * ray_depth
+ depth_z = local_pts3d[..., 2].numpy()
+ view_data["depth"] = depth_z
+
+ # Compute normals if we have valid points
+ if has_non_ambiguous_mask:
+ try:
+ normals, _ = points_to_normals(pred_pts3d, mask=view_data["mask"])
+ view_data["normal"] = normals
+ except:
+ # If normal computation fails, skip it
+ pass
+
+ processed_data[view_idx] = view_data
+
+ return processed_data
+
+
+def reset_measure(processed_data):
+ """Reset measure points"""
+ if processed_data is None or len(processed_data) == 0:
+ return None, [], ""
+
+ # Return the first view image
+ first_view = list(processed_data.values())[0]
+ return first_view["image"], [], ""
+
+
+def measure(
+ processed_data, measure_points, current_view_selector, event: gr.SelectData
+):
+ """Handle measurement on images"""
+ try:
+ print(f"Measure function called with selector: {current_view_selector}")
+
+ if processed_data is None or len(processed_data) == 0:
+ return None, [], "No data available"
+
+ # Use the currently selected view instead of always using the first view
+ try:
+ current_view_index = int(current_view_selector.split()[1]) - 1
+ except:
+ current_view_index = 0
+
+ print(f"Using view index: {current_view_index}")
+
+ # Get view data safely
+ if current_view_index < 0 or current_view_index >= len(processed_data):
+ current_view_index = 0
+
+ view_keys = list(processed_data.keys())
+ current_view = processed_data[view_keys[current_view_index]]
+
+ if current_view is None:
+ return None, [], "No view data available"
+
+ point2d = event.index[0], event.index[1]
+ print(f"Clicked point: {point2d}")
+
+ measure_points.append(point2d)
+
+ # Get image and ensure it's valid
+ image = current_view["image"]
+ if image is None:
+ return None, [], "No image available"
+
+ image = image.copy()
+ points3d = current_view["points3d"]
+
+ # Ensure image is in uint8 format for proper cv2 operations
+ try:
+ if image.dtype != np.uint8:
+ if image.max() <= 1.0:
+ # Image is in [0, 1] range, convert to [0, 255]
+ image = (image * 255).astype(np.uint8)
+ else:
+ # Image is already in [0, 255] range
+ image = image.astype(np.uint8)
+ except Exception as e:
+ print(f"Image conversion error: {e}")
+ return None, [], f"Image conversion error: {e}"
+
+ # Draw circles for points
+ try:
+ for p in measure_points:
+ if 0 <= p[0] < image.shape[1] and 0 <= p[1] < image.shape[0]:
+ image = cv2.circle(
+ image, p, radius=5, color=(255, 0, 0), thickness=2
+ )
+ except Exception as e:
+ print(f"Drawing error: {e}")
+ return None, [], f"Drawing error: {e}"
+
+ depth_text = ""
+ try:
+ for i, p in enumerate(measure_points):
+ if (
+ current_view["depth"] is not None
+ and 0 <= p[1] < current_view["depth"].shape[0]
+ and 0 <= p[0] < current_view["depth"].shape[1]
+ ):
+ d = current_view["depth"][p[1], p[0]]
+ depth_text += f"- **P{i + 1} depth: {d:.2f}m.**\n"
+ else:
+ # Use Z coordinate of 3D points if depth not available
+ if (
+ points3d is not None
+ and 0 <= p[1] < points3d.shape[0]
+ and 0 <= p[0] < points3d.shape[1]
+ ):
+ z = points3d[p[1], p[0], 2]
+ depth_text += f"- **P{i + 1} Z-coord: {z:.2f}m.**\n"
+ except Exception as e:
+ print(f"Depth text error: {e}")
+ depth_text = f"Error computing depth: {e}\n"
+
+ if len(measure_points) == 2:
+ try:
+ point1, point2 = measure_points
+ # Draw line
+ if (
+ 0 <= point1[0] < image.shape[1]
+ and 0 <= point1[1] < image.shape[0]
+ and 0 <= point2[0] < image.shape[1]
+ and 0 <= point2[1] < image.shape[0]
+ ):
+ image = cv2.line(
+ image, point1, point2, color=(255, 0, 0), thickness=2
+ )
+
+ # Compute 3D distance
+ distance_text = "- **Distance: Unable to compute**"
+ if (
+ points3d is not None
+ and 0 <= point1[1] < points3d.shape[0]
+ and 0 <= point1[0] < points3d.shape[1]
+ and 0 <= point2[1] < points3d.shape[0]
+ and 0 <= point2[0] < points3d.shape[1]
+ ):
+ try:
+ p1_3d = points3d[point1[1], point1[0]]
+ p2_3d = points3d[point2[1], point2[0]]
+ distance = np.linalg.norm(p1_3d - p2_3d)
+ distance_text = f"- **Distance: {distance:.2f}m**"
+ except Exception as e:
+ print(f"Distance computation error: {e}")
+ distance_text = f"- **Distance computation error: {e}**"
+
+ measure_points = []
+ text = depth_text + distance_text
+ print(f"Measurement complete: {text}")
+ return [image, measure_points, text]
+ except Exception as e:
+ print(f"Final measurement error: {e}")
+ return None, [], f"Measurement error: {e}"
+ else:
+ print(f"Single point measurement: {depth_text}")
+ return [image, measure_points, depth_text]
+
+ except Exception as e:
+ print(f"Overall measure function error: {e}")
+ return None, [], f"Measure function error: {e}"
+
+
+def clear_fields():
+ """
+ Clears the 3D viewer, the stored target_dir, and empties the gallery.
+ """
+ return None
+
+
+def update_log():
+ """
+ Display a quick log message while waiting.
+ """
+ return "Loading and Reconstructing..."
+
+
+def update_visualization(
+ target_dir,
+ conf_thres,
+ frame_filter,
+ show_cam,
+ is_example,
+ filter_sky=False,
+ filter_black_bg=False,
+ filter_white_bg=False,
+ mask_ambiguous=False,
+):
+ """
+ Reload saved predictions from npz, create (or reuse) the GLB for new parameters,
+ and return it for the 3D viewer. If is_example == "True", skip.
+ """
+
+ # If it's an example click, skip as requested
+ if is_example == "True":
+ return (
+ gr.update(),
+ "No reconstruction available. Please click the Reconstruct button first.",
+ )
+
+ if not target_dir or target_dir == "None" or not os.path.isdir(target_dir):
+ return (
+ gr.update(),
+ "No reconstruction available. Please click the Reconstruct button first.",
+ )
+
+ predictions_path = os.path.join(target_dir, "predictions.npz")
+ if not os.path.exists(predictions_path):
+ return (
+ gr.update(),
+ f"No reconstruction available at {predictions_path}. Please run 'Reconstruct' first.",
+ )
+
+ loaded = np.load(predictions_path, allow_pickle=True)
+ predictions = {key: loaded[key] for key in loaded.keys()}
+
+ # Always use Pointmap Branch for MapAnything
+ prediction_mode = "Pointmap Branch"
+
+ glbfile = os.path.join(
+ target_dir,
+ f"glbscene_{conf_thres}_{frame_filter.replace('.', '_').replace(':', '').replace(' ', '_')}_cam{show_cam}_sky{filter_sky}_black{filter_black_bg}_white{filter_white_bg}_pred{prediction_mode.replace(' ', '_')}.glb",
+ )
+
+ if not os.path.exists(glbfile):
+ glbscene = predictions_to_glb(
+ predictions,
+ conf_thres=conf_thres,
+ filter_by_frames=frame_filter,
+ show_cam=show_cam,
+ target_dir=target_dir,
+ prediction_mode=prediction_mode,
+ mask_sky=filter_sky,
+ mask_black_bg=filter_black_bg,
+ mask_white_bg=filter_white_bg,
+ mask_ambiguous=mask_ambiguous,
+ )
+ glbscene.export(file_obj=glbfile)
+
+ return (
+ glbfile,
+ "Visualization updated.",
+ )
+
+
+# -------------------------------------------------------------------------
+# Example scene functions
+# -------------------------------------------------------------------------
+def get_scene_info(examples_dir):
+ """Get information about scenes in the examples directory"""
+ import glob
+
+ scenes = []
+ if not os.path.exists(examples_dir):
+ return scenes
+
+ for scene_folder in sorted(os.listdir(examples_dir)):
+ scene_path = os.path.join(examples_dir, scene_folder)
+ if os.path.isdir(scene_path):
+ # Find all image files in the scene folder
+ image_extensions = ["*.jpg", "*.jpeg", "*.png", "*.bmp", "*.tiff", "*.tif"]
+ image_files = []
+ for ext in image_extensions:
+ image_files.extend(glob.glob(os.path.join(scene_path, ext)))
+ image_files.extend(glob.glob(os.path.join(scene_path, ext.upper())))
+
+ if image_files:
+ # Sort images and get the first one for thumbnail
+ image_files = sorted(image_files)
+ first_image = image_files[0]
+ num_images = len(image_files)
+
+ scenes.append(
+ {
+ "name": scene_folder,
+ "path": scene_path,
+ "thumbnail": first_image,
+ "num_images": num_images,
+ "image_files": image_files,
+ }
+ )
+
+ return scenes
+
+
+def load_example_scene(scene_name, examples_dir="examples"):
+ """Load a scene from examples directory"""
+ scenes = get_scene_info(examples_dir)
+
+ # Find the selected scene
+ selected_scene = None
+ for scene in scenes:
+ if scene["name"] == scene_name:
+ selected_scene = scene
+ break
+
+ if selected_scene is None:
+ return None, None, None, "Scene not found"
+
+ # Create target directory and copy images
+ target_dir, image_paths = handle_uploads(None, selected_scene["image_files"])
+
+ return (
+ None, # Clear reconstruction output
+ target_dir, # Set target directory
+ image_paths, # Set gallery
+ f"Loaded scene '{scene_name}' with {selected_scene['num_images']} images. Click 'Reconstruct' to begin 3D processing.",
+ )
+
+
+# -------------------------------------------------------------------------
+# 6) Build Gradio UI
+# -------------------------------------------------------------------------
+theme = get_gradio_theme()
+
+with gr.Blocks(theme=theme, css=GRADIO_CSS) as demo:
+ # State variables for the tabbed interface
+ is_example = gr.Textbox(label="is_example", visible=False, value="None")
+ num_images = gr.Textbox(label="num_images", visible=False, value="None")
+ processed_data_state = gr.State(value=None)
+ measure_points_state = gr.State(value=[])
+ current_view_index = gr.State(value=0) # Track current view index for navigation
+
+ gr.HTML(get_header_html(get_logo_base64()))
+ gr.HTML(get_description_html())
+
+ target_dir_output = gr.Textbox(label="Target Dir", visible=False, value="None")
+
+ with gr.Row():
+ with gr.Column(scale=2):
+ input_video = gr.Video(label="Upload Video", interactive=True)
+ s_time_interval = gr.Slider(
+ minimum=0.1,
+ maximum=5.0,
+ value=1.0,
+ step=0.1,
+ label="Sample time interval (take a sample every x sec.)",
+ interactive=True,
+ visible=True,
+ )
+ input_images = gr.File(
+ file_count="multiple", label="Upload Images", interactive=True
+ )
+
+ image_gallery = gr.Gallery(
+ label="Preview",
+ columns=4,
+ height="300px",
+ show_download_button=True,
+ object_fit="contain",
+ preview=True,
+ )
+
+ with gr.Column(scale=4):
+ with gr.Column():
+ gr.Markdown("**3D Reconstruction (Point Cloud and Camera Poses)**")
+ log_output = gr.Markdown(
+ "Please upload a video or images, then click Reconstruct.",
+ elem_classes=["custom-log"],
+ )
+
+ # Add tabbed interface similar to MoGe
+ with gr.Tabs():
+ with gr.Tab("3D View"):
+ reconstruction_output = gr.Model3D(
+ height=520,
+ zoom_speed=0.5,
+ pan_speed=0.5,
+ clear_color=[0.0, 0.0, 0.0, 0.0],
+ key="persistent_3d_viewer",
+ elem_id="reconstruction_3d_viewer",
+ )
+ with gr.Tab("Depth"):
+ with gr.Row(elem_classes=["navigation-row"]):
+ prev_depth_btn = gr.Button("◀ Previous", size="sm", scale=1)
+ depth_view_selector = gr.Dropdown(
+ choices=["View 1"],
+ value="View 1",
+ label="Select View",
+ scale=2,
+ interactive=True,
+ allow_custom_value=True,
+ )
+ next_depth_btn = gr.Button("Next ▶", size="sm", scale=1)
+ depth_map = gr.Image(
+ type="numpy",
+ label="Colorized Depth Map",
+ format="png",
+ interactive=False,
+ )
+ with gr.Tab("Normal"):
+ with gr.Row(elem_classes=["navigation-row"]):
+ prev_normal_btn = gr.Button(
+ "◀ Previous", size="sm", scale=1
+ )
+ normal_view_selector = gr.Dropdown(
+ choices=["View 1"],
+ value="View 1",
+ label="Select View",
+ scale=2,
+ interactive=True,
+ allow_custom_value=True,
+ )
+ next_normal_btn = gr.Button("Next ▶", size="sm", scale=1)
+ normal_map = gr.Image(
+ type="numpy",
+ label="Normal Map",
+ format="png",
+ interactive=False,
+ )
+ with gr.Tab("Measure"):
+ gr.Markdown(MEASURE_INSTRUCTIONS_HTML)
+ with gr.Row(elem_classes=["navigation-row"]):
+ prev_measure_btn = gr.Button(
+ "◀ Previous", size="sm", scale=1
+ )
+ measure_view_selector = gr.Dropdown(
+ choices=["View 1"],
+ value="View 1",
+ label="Select View",
+ scale=2,
+ interactive=True,
+ allow_custom_value=True,
+ )
+ next_measure_btn = gr.Button("Next ▶", size="sm", scale=1)
+ measure_image = gr.Image(
+ type="numpy",
+ show_label=False,
+ format="webp",
+ interactive=False,
+ sources=[],
+ )
+ measure_text = gr.Markdown("")
+
+ with gr.Row():
+ submit_btn = gr.Button("Reconstruct", scale=1, variant="primary")
+ clear_btn = gr.ClearButton(
+ [
+ input_video,
+ input_images,
+ reconstruction_output,
+ log_output,
+ target_dir_output,
+ image_gallery,
+ ],
+ scale=1,
+ )
+
+ with gr.Row():
+ conf_thres = gr.Slider(
+ minimum=0,
+ maximum=100,
+ value=0,
+ step=0.1,
+ label="Confidence Threshold (%), only shown in depth and normals",
+ )
+ frame_filter = gr.Dropdown(
+ choices=["All"], value="All", label="Show Points from Frame"
+ )
+ with gr.Column():
+ show_cam = gr.Checkbox(label="Show Camera", value=True)
+ filter_sky = gr.Checkbox(
+ label="Filter Sky (using skyseg.onnx)", value=False
+ )
+ filter_black_bg = gr.Checkbox(
+ label="Filter Black Background", value=False
+ )
+ filter_white_bg = gr.Checkbox(
+ label="Filter White Background", value=False
+ )
+ mask_ambiguous = gr.Checkbox(label="Mask Ambiguous", value=True)
+
+ # ---------------------- Example Scenes Section ----------------------
+ gr.Markdown("## Example Scenes")
+ gr.Markdown("Click any thumbnail to load the scene for reconstruction.")
+
+ # Get scene information
+ scenes = get_scene_info("examples")
+
+ # Create thumbnail grid (4 columns, N rows)
+ if scenes:
+ for i in range(0, len(scenes), 4): # Process 4 scenes per row
+ with gr.Row():
+ for j in range(4):
+ scene_idx = i + j
+ if scene_idx < len(scenes):
+ scene = scenes[scene_idx]
+ with gr.Column(scale=1, elem_classes=["clickable-thumbnail"]):
+ # Clickable thumbnail
+ scene_img = gr.Image(
+ value=scene["thumbnail"],
+ height=150,
+ interactive=False,
+ show_label=False,
+ elem_id=f"scene_thumb_{scene['name']}",
+ sources=[],
+ )
+
+ # Scene name and image count as text below thumbnail
+ gr.Markdown(
+ f"**{scene['name']}** \n {scene['num_images']} images",
+ elem_classes=["scene-info"],
+ )
+
+ # Connect thumbnail click to load scene
+ scene_img.select(
+ fn=lambda name=scene["name"]: load_example_scene(name),
+ outputs=[
+ reconstruction_output,
+ target_dir_output,
+ image_gallery,
+ log_output,
+ ],
+ )
+ else:
+ # Empty column to maintain grid structure
+ with gr.Column(scale=1):
+ pass
+
+ # -------------------------------------------------------------------------
+ # "Reconstruct" button logic:
+ # - Clear fields
+ # - Update log
+ # - gradio_demo(...) with the existing target_dir
+ # - Then set is_example = "False"
+ # -------------------------------------------------------------------------
+ submit_btn.click(fn=clear_fields, inputs=[], outputs=[reconstruction_output]).then(
+ fn=update_log, inputs=[], outputs=[log_output]
+ ).then(
+ fn=gradio_demo,
+ inputs=[
+ target_dir_output,
+ conf_thres,
+ frame_filter,
+ show_cam,
+ filter_sky,
+ filter_black_bg,
+ filter_white_bg,
+ mask_ambiguous,
+ ],
+ outputs=[
+ reconstruction_output,
+ log_output,
+ frame_filter,
+ processed_data_state,
+ depth_map,
+ normal_map,
+ measure_image,
+ measure_text,
+ depth_view_selector,
+ normal_view_selector,
+ measure_view_selector,
+ ],
+ ).then(
+ fn=lambda: "False",
+ inputs=[],
+ outputs=[is_example], # set is_example to "False"
+ )
+
+ # -------------------------------------------------------------------------
+ # Real-time Visualization Updates
+ # -------------------------------------------------------------------------
+ def update_all_visualizations_on_conf_change(
+ processed_data,
+ depth_selector,
+ normal_selector,
+ conf_thres_val,
+ target_dir,
+ frame_filter,
+ show_cam,
+ is_example,
+ ):
+ """Update 3D view and all tabs when confidence threshold changes"""
+
+ # Update 3D pointcloud visualization
+ glb_file, log_msg = update_visualization(
+ target_dir,
+ conf_thres_val,
+ frame_filter,
+ show_cam,
+ is_example,
+ )
+
+ # Update depth and normal tabs with new confidence threshold
+ depth_vis = None
+ normal_vis = None
+
+ if processed_data is not None:
+ # Get current view indices from selectors
+ try:
+ depth_view_idx = (
+ int(depth_selector.split()[1]) - 1 if depth_selector else 0
+ )
+ except:
+ depth_view_idx = 0
+
+ try:
+ normal_view_idx = (
+ int(normal_selector.split()[1]) - 1 if normal_selector else 0
+ )
+ except:
+ normal_view_idx = 0
+
+ # Update visualizations with new confidence threshold
+ depth_vis = update_depth_view(
+ processed_data, depth_view_idx, conf_thres=conf_thres_val
+ )
+ normal_vis = update_normal_view(
+ processed_data, normal_view_idx, conf_thres=conf_thres_val
+ )
+
+ return glb_file, log_msg, depth_vis, normal_vis
+
+ conf_thres.change(
+ fn=update_all_visualizations_on_conf_change,
+ inputs=[
+ processed_data_state,
+ depth_view_selector,
+ normal_view_selector,
+ conf_thres,
+ target_dir_output,
+ frame_filter,
+ show_cam,
+ is_example,
+ ],
+ outputs=[reconstruction_output, log_output, depth_map, normal_map],
+ )
+ frame_filter.change(
+ update_visualization,
+ [
+ target_dir_output,
+ conf_thres,
+ frame_filter,
+ show_cam,
+ is_example,
+ ],
+ [reconstruction_output, log_output],
+ )
+ show_cam.change(
+ update_visualization,
+ [
+ target_dir_output,
+ conf_thres,
+ frame_filter,
+ show_cam,
+ is_example,
+ ],
+ [reconstruction_output, log_output],
+ )
+ filter_sky.change(
+ update_visualization,
+ [
+ target_dir_output,
+ conf_thres,
+ frame_filter,
+ show_cam,
+ is_example,
+ filter_sky,
+ filter_black_bg,
+ filter_white_bg,
+ mask_ambiguous,
+ ],
+ [reconstruction_output, log_output],
+ )
+ filter_black_bg.change(
+ update_visualization,
+ [
+ target_dir_output,
+ conf_thres,
+ frame_filter,
+ show_cam,
+ is_example,
+ filter_sky,
+ filter_black_bg,
+ filter_white_bg,
+ mask_ambiguous,
+ ],
+ [reconstruction_output, log_output],
+ )
+ filter_white_bg.change(
+ update_visualization,
+ [
+ target_dir_output,
+ conf_thres,
+ frame_filter,
+ show_cam,
+ is_example,
+ filter_sky,
+ filter_black_bg,
+ filter_white_bg,
+ mask_ambiguous,
+ ],
+ [reconstruction_output, log_output],
+ )
+ mask_ambiguous.change(
+ update_visualization,
+ [
+ target_dir_output,
+ conf_thres,
+ frame_filter,
+ show_cam,
+ is_example,
+ filter_sky,
+ filter_black_bg,
+ filter_white_bg,
+ mask_ambiguous,
+ ],
+ [reconstruction_output, log_output],
+ )
+
+ # -------------------------------------------------------------------------
+ # Auto-update gallery whenever user uploads or changes their files
+ # -------------------------------------------------------------------------
+ input_video.change(
+ fn=update_gallery_on_upload,
+ inputs=[input_video, input_images, s_time_interval],
+ outputs=[reconstruction_output, target_dir_output, image_gallery, log_output],
+ )
+ input_images.change(
+ fn=update_gallery_on_upload,
+ inputs=[input_video, input_images, s_time_interval],
+ outputs=[reconstruction_output, target_dir_output, image_gallery, log_output],
+ )
+
+ # -------------------------------------------------------------------------
+ # Measure tab functionality
+ # -------------------------------------------------------------------------
+ measure_image.select(
+ fn=measure,
+ inputs=[processed_data_state, measure_points_state, measure_view_selector],
+ outputs=[measure_image, measure_points_state, measure_text],
+ )
+
+ # -------------------------------------------------------------------------
+ # Navigation functionality for Depth, Normal, and Measure tabs
+ # -------------------------------------------------------------------------
+
+ # Depth tab navigation
+ prev_depth_btn.click(
+ fn=lambda processed_data, current_selector, conf_thres_val: navigate_depth_view(
+ processed_data, current_selector, -1, conf_thres=conf_thres_val
+ ),
+ inputs=[processed_data_state, depth_view_selector, conf_thres],
+ outputs=[depth_view_selector, depth_map],
+ )
+
+ next_depth_btn.click(
+ fn=lambda processed_data, current_selector, conf_thres_val: navigate_depth_view(
+ processed_data, current_selector, 1, conf_thres=conf_thres_val
+ ),
+ inputs=[processed_data_state, depth_view_selector, conf_thres],
+ outputs=[depth_view_selector, depth_map],
+ )
+
+ depth_view_selector.change(
+ fn=lambda processed_data, selector_value, conf_thres_val: (
+ update_depth_view(
+ processed_data,
+ int(selector_value.split()[1]) - 1,
+ conf_thres=conf_thres_val,
+ )
+ if selector_value
+ else None
+ ),
+ inputs=[processed_data_state, depth_view_selector, conf_thres],
+ outputs=[depth_map],
+ )
+
+ # Normal tab navigation
+ prev_normal_btn.click(
+ fn=lambda processed_data,
+ current_selector,
+ conf_thres_val: navigate_normal_view(
+ processed_data, current_selector, -1, conf_thres=conf_thres_val
+ ),
+ inputs=[processed_data_state, normal_view_selector, conf_thres],
+ outputs=[normal_view_selector, normal_map],
+ )
+
+ next_normal_btn.click(
+ fn=lambda processed_data,
+ current_selector,
+ conf_thres_val: navigate_normal_view(
+ processed_data, current_selector, 1, conf_thres=conf_thres_val
+ ),
+ inputs=[processed_data_state, normal_view_selector, conf_thres],
+ outputs=[normal_view_selector, normal_map],
+ )
+
+ normal_view_selector.change(
+ fn=lambda processed_data, selector_value, conf_thres_val: (
+ update_normal_view(
+ processed_data,
+ int(selector_value.split()[1]) - 1,
+ conf_thres=conf_thres_val,
+ )
+ if selector_value
+ else None
+ ),
+ inputs=[processed_data_state, normal_view_selector, conf_thres],
+ outputs=[normal_map],
+ )
+
+ # Measure tab navigation
+ prev_measure_btn.click(
+ fn=lambda processed_data, current_selector: navigate_measure_view(
+ processed_data, current_selector, -1
+ ),
+ inputs=[processed_data_state, measure_view_selector],
+ outputs=[measure_view_selector, measure_image, measure_points_state],
+ )
+
+ next_measure_btn.click(
+ fn=lambda processed_data, current_selector: navigate_measure_view(
+ processed_data, current_selector, 1
+ ),
+ inputs=[processed_data_state, measure_view_selector],
+ outputs=[measure_view_selector, measure_image, measure_points_state],
+ )
+
+ measure_view_selector.change(
+ fn=lambda processed_data, selector_value: (
+ update_measure_view(processed_data, int(selector_value.split()[1]) - 1)
+ if selector_value
+ else (None, [])
+ ),
+ inputs=[processed_data_state, measure_view_selector],
+ outputs=[measure_image, measure_points_state],
+ )
+
+ # -------------------------------------------------------------------------
+ # Acknowledgement section
+ # -------------------------------------------------------------------------
+ gr.HTML(get_acknowledgements_html())
+
+ demo.queue(max_size=20).launch(show_error=True, share=True, ssr_mode=False)
diff --git a/app_interactive.py b/app_interactive.py
new file mode 100644
index 0000000000000000000000000000000000000000..f635d1d83698c0332c5bb641222a34c71aee854f
--- /dev/null
+++ b/app_interactive.py
@@ -0,0 +1,9 @@
+import gradio as gr
+
+
+def greet(name):
+ return "Hello " + name + "!!"
+
+
+demo = gr.Interface(fn=greet, inputs="text", outputs="text")
+demo.launch()
diff --git a/configs/calibration_benchmark.yaml b/configs/calibration_benchmark.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..78e6714cc164afbe8a39eee0c23c80267ad708a1
--- /dev/null
+++ b/configs/calibration_benchmark.yaml
@@ -0,0 +1,23 @@
+defaults:
+ - machine: aws
+ - model: default
+ - dataset: default
+ - _self_
+
+output_dir: ${hydra:run.dir}
+root_data_dir: ${machine.root_data_dir}
+mapanything_dataset_metadata_dir: ${machine.mapanything_dataset_metadata_dir}
+root_pretrained_checkpoints_dir: ${machine.root_pretrained_checkpoints_dir}
+root_experiments_dir: ${machine.root_experiments_dir}
+root_uniception_pretrained_checkpoints_dir: ${machine.root_uniception_pretrained_checkpoints_dir}
+
+### Benchmarking args
+seed: 0
+# Disable CUDNN Benchmark (Disable for variable resolution & number of view training)
+disable_cudnn_benchmark: true
+# Batch size for inference (Metrics are computed per multi-view set and averaged, not per batch of multi-view sets)
+batch_size: 20
+# Use mixed precision for inference
+amp: 1
+# Floating point type to use for mixed precision
+amp_dtype: "bf16"
diff --git a/configs/dataset/ase_wai/default.yaml b/configs/dataset/ase_wai/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..78448df144ecb19398c761f707b5e264dcaaae29
--- /dev/null
+++ b/configs/dataset/ase_wai/default.yaml
@@ -0,0 +1,3 @@
+defaults:
+ - train: default
+ - val: default
diff --git a/configs/dataset/ase_wai/train/default.yaml b/configs/dataset/ase_wai/train/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..2c77c0b49e50d015695f61db5f2ec4fd42fc8ca8
--- /dev/null
+++ b/configs/dataset/ase_wai/train/default.yaml
@@ -0,0 +1,26 @@
+dataset_str:
+ "ASEWAI(
+ split='${dataset.ase_wai.train.split}',
+ resolution=${dataset.ase_wai.train.dataset_resolution},
+ principal_point_centered=${dataset.ase_wai.train.principal_point_centered},
+ aug_crop=${dataset.ase_wai.train.aug_crop},
+ transform='${dataset.ase_wai.train.transform}',
+ data_norm_type='${dataset.ase_wai.train.data_norm_type}',
+ ROOT='${dataset.ase_wai.train.ROOT}',
+ dataset_metadata_dir='${dataset.ase_wai.train.dataset_metadata_dir}',
+ overfit_num_sets=${dataset.ase_wai.train.overfit_num_sets},
+ variable_num_views=${dataset.ase_wai.train.variable_num_views},
+ num_views=${dataset.ase_wai.train.num_views},
+ covisibility_thres=${dataset.ase_wai.train.covisibility_thres})"
+split: 'train'
+dataset_resolution: ${dataset.resolution_train}
+principal_point_centered: ${dataset.principal_point_centered}
+aug_crop: 16
+transform: 'colorjitter+grayscale+gaublur'
+data_norm_type: ${model.data_norm_type}
+ROOT: ${root_data_dir}/ase
+dataset_metadata_dir: ${mapanything_dataset_metadata_dir}
+overfit_num_sets: null
+variable_num_views: ${dataset.train.variable_num_views}
+num_views: ${dataset.num_views}
+covisibility_thres: 0.25
diff --git a/configs/dataset/ase_wai/val/default.yaml b/configs/dataset/ase_wai/val/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..ee2a92715e80edb49535765455ee30d6f782fb2c
--- /dev/null
+++ b/configs/dataset/ase_wai/val/default.yaml
@@ -0,0 +1,26 @@
+dataset_str:
+ "ASEWAI(
+ split='${dataset.ase_wai.val.split}',
+ resolution=${dataset.ase_wai.val.dataset_resolution},
+ principal_point_centered=${dataset.ase_wai.val.principal_point_centered},
+ seed=${dataset.ase_wai.val.seed},
+ transform='${dataset.ase_wai.val.transform}',
+ data_norm_type='${dataset.ase_wai.val.data_norm_type}',
+ ROOT='${dataset.ase_wai.val.ROOT}',
+ dataset_metadata_dir='${dataset.ase_wai.val.dataset_metadata_dir}',
+ overfit_num_sets=${dataset.ase_wai.val.overfit_num_sets},
+ variable_num_views=${dataset.ase_wai.val.variable_num_views},
+ num_views=${dataset.ase_wai.val.num_views},
+ covisibility_thres=${dataset.ase_wai.val.covisibility_thres})"
+split: 'val'
+dataset_resolution: ${dataset.resolution_val_ase}
+principal_point_centered: ${dataset.principal_point_centered}
+seed: 777
+transform: 'imgnorm'
+data_norm_type: ${model.data_norm_type}
+ROOT: ${root_data_dir}/ase
+dataset_metadata_dir: ${mapanything_dataset_metadata_dir}
+overfit_num_sets: null
+variable_num_views: ${dataset.val.variable_num_views}
+num_views: ${dataset.num_views}
+covisibility_thres: 0.25
diff --git a/configs/dataset/bedlam_wai/default.yaml b/configs/dataset/bedlam_wai/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..78448df144ecb19398c761f707b5e264dcaaae29
--- /dev/null
+++ b/configs/dataset/bedlam_wai/default.yaml
@@ -0,0 +1,3 @@
+defaults:
+ - train: default
+ - val: default
diff --git a/configs/dataset/bedlam_wai/train/default.yaml b/configs/dataset/bedlam_wai/train/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..11dc8db66e605f9f42a33677c1bcba7236ca3ef7
--- /dev/null
+++ b/configs/dataset/bedlam_wai/train/default.yaml
@@ -0,0 +1,26 @@
+dataset_str:
+ "BedlamWAI(
+ split='${dataset.bedlam_wai.train.split}',
+ resolution=${dataset.bedlam_wai.train.dataset_resolution},
+ principal_point_centered=${dataset.bedlam_wai.train.principal_point_centered},
+ aug_crop=${dataset.bedlam_wai.train.aug_crop},
+ transform='${dataset.bedlam_wai.train.transform}',
+ data_norm_type='${dataset.bedlam_wai.train.data_norm_type}',
+ ROOT='${dataset.bedlam_wai.train.ROOT}',
+ dataset_metadata_dir='${dataset.bedlam_wai.train.dataset_metadata_dir}',
+ overfit_num_sets=${dataset.bedlam_wai.train.overfit_num_sets},
+ variable_num_views=${dataset.bedlam_wai.train.variable_num_views},
+ num_views=${dataset.bedlam_wai.train.num_views},
+ covisibility_thres=${dataset.bedlam_wai.train.covisibility_thres})"
+split: 'train'
+dataset_resolution: ${dataset.resolution_train}
+principal_point_centered: ${dataset.principal_point_centered}
+aug_crop: 16
+transform: 'colorjitter+grayscale+gaublur'
+data_norm_type: ${model.data_norm_type}
+ROOT: ${root_data_dir}/bedlam
+dataset_metadata_dir: ${mapanything_dataset_metadata_dir}
+overfit_num_sets: null
+variable_num_views: ${dataset.train.variable_num_views}
+num_views: ${dataset.num_views}
+covisibility_thres: 0.25
diff --git a/configs/dataset/bedlam_wai/val/default.yaml b/configs/dataset/bedlam_wai/val/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..8d1471050f84b32fa1858a4d11ad4dc798c0f002
--- /dev/null
+++ b/configs/dataset/bedlam_wai/val/default.yaml
@@ -0,0 +1,26 @@
+dataset_str:
+ "BedlamWAI(
+ split='${dataset.bedlam_wai.val.split}',
+ resolution=${dataset.bedlam_wai.val.dataset_resolution},
+ principal_point_centered=${dataset.bedlam_wai.val.principal_point_centered},
+ seed=${dataset.bedlam_wai.val.seed},
+ transform='${dataset.bedlam_wai.val.transform}',
+ data_norm_type='${dataset.bedlam_wai.val.data_norm_type}',
+ ROOT='${dataset.bedlam_wai.val.ROOT}',
+ dataset_metadata_dir='${dataset.bedlam_wai.val.dataset_metadata_dir}',
+ overfit_num_sets=${dataset.bedlam_wai.val.overfit_num_sets},
+ variable_num_views=${dataset.bedlam_wai.val.variable_num_views},
+ num_views=${dataset.bedlam_wai.val.num_views},
+ covisibility_thres=${dataset.bedlam_wai.val.covisibility_thres})"
+split: 'val'
+dataset_resolution: ${dataset.resolution_val_bedlam}
+principal_point_centered: ${dataset.principal_point_centered}
+seed: 777
+transform: 'imgnorm'
+data_norm_type: ${model.data_norm_type}
+ROOT: ${root_data_dir}/bedlam
+dataset_metadata_dir: ${mapanything_dataset_metadata_dir}
+overfit_num_sets: null
+variable_num_views: ${dataset.val.variable_num_views}
+num_views: ${dataset.num_views}
+covisibility_thres: 0.25
diff --git a/configs/dataset/benchmark_512_eth3d_snpp_tav2.yaml b/configs/dataset/benchmark_512_eth3d_snpp_tav2.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..0bc7a5fd13da943404ada22c02028c3667e1cfc0
--- /dev/null
+++ b/configs/dataset/benchmark_512_eth3d_snpp_tav2.yaml
@@ -0,0 +1,20 @@
+defaults:
+ - default
+
+# Number of views parameter for the multi-view datasets
+num_views: 2
+
+# Test Resolution
+resolution_test_eth3d: ${dataset.resolution_options.512_1_52_ar}
+resolution_test_scannetpp: ${dataset.resolution_options.512_1_52_ar}
+resolution_test_tav2_wb: ${dataset.resolution_options.512_1_00_ar}
+
+# Test Set
+# Sample 10 multi-view sets from each scene
+# ETH3D: 13 scenes
+# ScanNet++V2: 30 scenes
+# TartanAirV2-WB: 5 scenes
+test_dataset:
+ "+ 130 @ ${dataset.eth3d_wai.test.dataset_str}
+ + 300 @ ${dataset.scannetpp_wai.test.dataset_str}
+ + 50 @ ${dataset.tav2_wb_wai.test.dataset_str}"
diff --git a/configs/dataset/benchmark_512_snpp_tav2.yaml b/configs/dataset/benchmark_512_snpp_tav2.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..4be5baf2584b8557cef2d80d1e8e41bc9b4689e8
--- /dev/null
+++ b/configs/dataset/benchmark_512_snpp_tav2.yaml
@@ -0,0 +1,17 @@
+defaults:
+ - default
+
+# Number of views parameter for the multi-view datasets
+num_views: 2
+
+# Test Resolution
+resolution_test_scannetpp: ${dataset.resolution_options.512_1_52_ar}
+resolution_test_tav2_wb: ${dataset.resolution_options.512_1_00_ar}
+
+# Test Set
+# Sample 10 multi-view sets from each scene
+# ScanNet++V2: 30 scenes
+# TartanAirV2-WB: 5 scenes
+test_dataset:
+ "+ 300 @ ${dataset.scannetpp_wai.test.dataset_str}
+ + 50 @ ${dataset.tav2_wb_wai.test.dataset_str}"
diff --git a/configs/dataset/benchmark_518_eth3d_snpp_tav2.yaml b/configs/dataset/benchmark_518_eth3d_snpp_tav2.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..a5251facb27c98ce7625829012e3f289cd11ed12
--- /dev/null
+++ b/configs/dataset/benchmark_518_eth3d_snpp_tav2.yaml
@@ -0,0 +1,20 @@
+defaults:
+ - default
+
+# Number of views parameter for the multi-view datasets
+num_views: 2
+
+# Test Resolution
+resolution_test_eth3d: ${dataset.resolution_options.518_1_52_ar}
+resolution_test_scannetpp: ${dataset.resolution_options.518_1_52_ar}
+resolution_test_tav2_wb: ${dataset.resolution_options.518_1_00_ar}
+
+# Test Set
+# Sample 10 multi-view sets from each scene
+# ETH3D: 13 scenes
+# ScanNet++V2: 30 scenes
+# TartanAirV2-WB: 5 scenes
+test_dataset:
+ "+ 130 @ ${dataset.eth3d_wai.test.dataset_str}
+ + 300 @ ${dataset.scannetpp_wai.test.dataset_str}
+ + 50 @ ${dataset.tav2_wb_wai.test.dataset_str}"
diff --git a/configs/dataset/benchmark_518_snpp_tav2.yaml b/configs/dataset/benchmark_518_snpp_tav2.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..4b15f825631ce0421857902f570c33812c6236f2
--- /dev/null
+++ b/configs/dataset/benchmark_518_snpp_tav2.yaml
@@ -0,0 +1,17 @@
+defaults:
+ - default
+
+# Number of views parameter for the multi-view datasets
+num_views: 2
+
+# Test Resolution
+resolution_test_scannetpp: ${dataset.resolution_options.518_1_52_ar}
+resolution_test_tav2_wb: ${dataset.resolution_options.518_1_00_ar}
+
+# Test Set
+# Sample 10 multi-view sets from each scene
+# ScanNet++V2: 30 scenes
+# TartanAirV2-WB: 5 scenes
+test_dataset:
+ "+ 300 @ ${dataset.scannetpp_wai.test.dataset_str}
+ + 50 @ ${dataset.tav2_wb_wai.test.dataset_str}"
diff --git a/configs/dataset/benchmark_sv_calib_518_many_ar_eth3d_snpp_tav2.yaml b/configs/dataset/benchmark_sv_calib_518_many_ar_eth3d_snpp_tav2.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..0d6266d248ef2d02b11b462ce922f0dfb8b52c91
--- /dev/null
+++ b/configs/dataset/benchmark_sv_calib_518_many_ar_eth3d_snpp_tav2.yaml
@@ -0,0 +1,20 @@
+defaults:
+ - default
+
+# Number of views parameter for the multi-view datasets
+num_views: 1
+
+# Test Resolution
+resolution_test_eth3d: ${dataset.resolution_options.518_many_ar}
+resolution_test_scannetpp: ${dataset.resolution_options.518_many_ar}
+resolution_test_tav2_wb: ${dataset.resolution_options.518_many_ar}
+
+# Test Set
+# Sample 20 frames from each scene
+# ETH3D: 13 scenes
+# ScanNet++V2: 30 scenes
+# TartanAirV2-WB: 5 scenes
+test_dataset:
+ "+ 260 @ ${dataset.eth3d_wai.test.dataset_str}
+ + 600 @ ${dataset.scannetpp_wai.test.dataset_str}
+ + 100 @ ${dataset.tav2_wb_wai.test.dataset_str}"
diff --git a/configs/dataset/blendedmvs_wai/default.yaml b/configs/dataset/blendedmvs_wai/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..78448df144ecb19398c761f707b5e264dcaaae29
--- /dev/null
+++ b/configs/dataset/blendedmvs_wai/default.yaml
@@ -0,0 +1,3 @@
+defaults:
+ - train: default
+ - val: default
diff --git a/configs/dataset/blendedmvs_wai/train/default.yaml b/configs/dataset/blendedmvs_wai/train/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..39391d2da22bf69b6e8cc3d860c278d812f08944
--- /dev/null
+++ b/configs/dataset/blendedmvs_wai/train/default.yaml
@@ -0,0 +1,26 @@
+dataset_str:
+ "BlendedMVSWAI(
+ split='${dataset.blendedmvs_wai.train.split}',
+ resolution=${dataset.blendedmvs_wai.train.dataset_resolution},
+ principal_point_centered=${dataset.blendedmvs_wai.train.principal_point_centered},
+ aug_crop=${dataset.blendedmvs_wai.train.aug_crop},
+ transform='${dataset.blendedmvs_wai.train.transform}',
+ data_norm_type='${dataset.blendedmvs_wai.train.data_norm_type}',
+ ROOT='${dataset.blendedmvs_wai.train.ROOT}',
+ dataset_metadata_dir='${dataset.blendedmvs_wai.train.dataset_metadata_dir}',
+ overfit_num_sets=${dataset.blendedmvs_wai.train.overfit_num_sets},
+ variable_num_views=${dataset.blendedmvs_wai.train.variable_num_views},
+ num_views=${dataset.blendedmvs_wai.train.num_views},
+ covisibility_thres=${dataset.blendedmvs_wai.train.covisibility_thres})"
+split: 'train'
+dataset_resolution: ${dataset.resolution_train}
+principal_point_centered: ${dataset.principal_point_centered}
+aug_crop: 16
+transform: 'colorjitter+grayscale+gaublur'
+data_norm_type: ${model.data_norm_type}
+ROOT: ${root_data_dir}/blendedmvs
+dataset_metadata_dir: ${mapanything_dataset_metadata_dir}
+overfit_num_sets: null
+variable_num_views: ${dataset.train.variable_num_views}
+num_views: ${dataset.num_views}
+covisibility_thres: 0.25
diff --git a/configs/dataset/blendedmvs_wai/val/default.yaml b/configs/dataset/blendedmvs_wai/val/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..bbb876e5c41a4859d7c8f602b8ac37c7142a868b
--- /dev/null
+++ b/configs/dataset/blendedmvs_wai/val/default.yaml
@@ -0,0 +1,26 @@
+dataset_str:
+ "BlendedMVSWAI(
+ split='${dataset.blendedmvs_wai.val.split}',
+ resolution=${dataset.blendedmvs_wai.val.dataset_resolution},
+ principal_point_centered=${dataset.blendedmvs_wai.val.principal_point_centered},
+ seed=${dataset.blendedmvs_wai.val.seed},
+ transform='${dataset.blendedmvs_wai.val.transform}',
+ data_norm_type='${dataset.blendedmvs_wai.val.data_norm_type}',
+ ROOT='${dataset.blendedmvs_wai.val.ROOT}',
+ dataset_metadata_dir='${dataset.blendedmvs_wai.val.dataset_metadata_dir}',
+ overfit_num_sets=${dataset.blendedmvs_wai.val.overfit_num_sets},
+ variable_num_views=${dataset.blendedmvs_wai.val.variable_num_views},
+ num_views=${dataset.blendedmvs_wai.val.num_views},
+ covisibility_thres=${dataset.blendedmvs_wai.val.covisibility_thres})"
+split: 'val'
+dataset_resolution: ${dataset.resolution_val_blendedmvs}
+principal_point_centered: ${dataset.principal_point_centered}
+seed: 777
+transform: 'imgnorm'
+data_norm_type: ${model.data_norm_type}
+ROOT: ${root_data_dir}/blendedmvs
+dataset_metadata_dir: ${mapanything_dataset_metadata_dir}
+overfit_num_sets: null
+variable_num_views: ${dataset.val.variable_num_views}
+num_views: ${dataset.num_views}
+covisibility_thres: 0.25
diff --git a/configs/dataset/default.yaml b/configs/dataset/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..84b4954bca4b0ac2127f7eafdc896ae4c46ce1b7
--- /dev/null
+++ b/configs/dataset/default.yaml
@@ -0,0 +1,45 @@
+defaults:
+ - resolution_options: default
+ - ase_wai: default
+ - bedlam_wai: default
+ - blendedmvs_wai: default
+ - dl3dv_wai: default
+ - dtu_wai: default
+ - dynamicreplica_wai: default
+ - eth3d_wai: default
+ - gta_sfm_wai: default
+ - matrixcity_wai: default
+ - megadepth_wai: default
+ - mpsd_wai: default
+ - mvs_synth_wai: default
+ - paralleldomain4d_wai: default
+ - sailvos3d_wai: default
+ - scannetpp_wai: default
+ - spring_wai: default
+ - structured3d_wai: default
+ - tav2_wb_wai: default
+ - unrealstereo4k_wai: default
+ - xrooms_wai: default
+
+# Training Set, For example: BlendedMVS(split='train', resolution=(512, 384), transform=...)
+train_dataset: ???
+# Validation Set
+test_dataset: "[null]"
+# Number of workers for dataloader
+num_workers: 12
+# Default resolution for training
+resolution_train: ???
+# Default resolution for validation
+resolution_val: ???
+# Number of views parameter for multi-view datasets
+num_views: 2
+# Use a centered principal point for all images
+principal_point_centered: false
+# Default config for multi-view datasets
+train:
+ # If True, the number of views can vary from batch to batch. The maximum number of views is num_views and minimum is 2. (On by default for N-view training)
+ variable_num_views: true
+val:
+ variable_num_views: false
+test:
+ variable_num_views: false
diff --git a/configs/dataset/dl3dv_wai/default.yaml b/configs/dataset/dl3dv_wai/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..78448df144ecb19398c761f707b5e264dcaaae29
--- /dev/null
+++ b/configs/dataset/dl3dv_wai/default.yaml
@@ -0,0 +1,3 @@
+defaults:
+ - train: default
+ - val: default
diff --git a/configs/dataset/dl3dv_wai/train/default.yaml b/configs/dataset/dl3dv_wai/train/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..af9610c79e25da5e6638043bfa0b8cec5acd4666
--- /dev/null
+++ b/configs/dataset/dl3dv_wai/train/default.yaml
@@ -0,0 +1,28 @@
+dataset_str:
+ "DL3DVWAI(
+ split='${dataset.dl3dv_wai.train.split}',
+ resolution=${dataset.dl3dv_wai.train.dataset_resolution},
+ principal_point_centered=${dataset.dl3dv_wai.train.principal_point_centered},
+ aug_crop=${dataset.dl3dv_wai.train.aug_crop},
+ transform='${dataset.dl3dv_wai.train.transform}',
+ data_norm_type='${dataset.dl3dv_wai.train.data_norm_type}',
+ ROOT='${dataset.dl3dv_wai.train.ROOT}',
+ dataset_metadata_dir='${dataset.dl3dv_wai.train.dataset_metadata_dir}',
+ overfit_num_sets=${dataset.dl3dv_wai.train.overfit_num_sets},
+ variable_num_views=${dataset.dl3dv_wai.train.variable_num_views},
+ num_views=${dataset.dl3dv_wai.train.num_views},
+ covisibility_thres=${dataset.dl3dv_wai.train.covisibility_thres},
+ mvs_confidence_filter_thres=${dataset.dl3dv_wai.train.mvs_confidence_filter_thres})"
+split: 'train'
+dataset_resolution: ${dataset.resolution_train}
+principal_point_centered: ${dataset.principal_point_centered}
+aug_crop: 16
+transform: 'colorjitter+grayscale+gaublur'
+data_norm_type: ${model.data_norm_type}
+ROOT: ${root_data_dir}/dl3dv
+dataset_metadata_dir: ${mapanything_dataset_metadata_dir}
+overfit_num_sets: null
+variable_num_views: ${dataset.train.variable_num_views}
+num_views: ${dataset.num_views}
+covisibility_thres: 0.25
+mvs_confidence_filter_thres: 0.25
diff --git a/configs/dataset/dl3dv_wai/val/default.yaml b/configs/dataset/dl3dv_wai/val/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..1e4ba1e12ee503c5e1c867f375d682096dff5fdb
--- /dev/null
+++ b/configs/dataset/dl3dv_wai/val/default.yaml
@@ -0,0 +1,28 @@
+dataset_str:
+ "DL3DVWAI(
+ split='${dataset.dl3dv_wai.val.split}',
+ resolution=${dataset.dl3dv_wai.val.dataset_resolution},
+ principal_point_centered=${dataset.dl3dv_wai.val.principal_point_centered},
+ seed=${dataset.dl3dv_wai.val.seed},
+ transform='${dataset.dl3dv_wai.val.transform}',
+ data_norm_type='${dataset.dl3dv_wai.val.data_norm_type}',
+ ROOT='${dataset.dl3dv_wai.val.ROOT}',
+ dataset_metadata_dir='${dataset.dl3dv_wai.val.dataset_metadata_dir}',
+ overfit_num_sets=${dataset.dl3dv_wai.val.overfit_num_sets},
+ variable_num_views=${dataset.dl3dv_wai.val.variable_num_views},
+ num_views=${dataset.dl3dv_wai.val.num_views},
+ covisibility_thres=${dataset.dl3dv_wai.val.covisibility_thres},
+ mvs_confidence_filter_thres=${dataset.dl3dv_wai.val.mvs_confidence_filter_thres})"
+split: 'val'
+dataset_resolution: ${dataset.resolution_val_dl3dv}
+principal_point_centered: ${dataset.principal_point_centered}
+seed: 777
+transform: 'imgnorm'
+data_norm_type: ${model.data_norm_type}
+ROOT: ${root_data_dir}/dl3dv
+dataset_metadata_dir: ${mapanything_dataset_metadata_dir}
+overfit_num_sets: null
+variable_num_views: ${dataset.val.variable_num_views}
+num_views: ${dataset.num_views}
+covisibility_thres: 0.25
+mvs_confidence_filter_thres: 0.25
diff --git a/configs/dataset/dtu_wai/default.yaml b/configs/dataset/dtu_wai/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..b1278dcc74c8a2ee16b87a31ebabca50234ab9fa
--- /dev/null
+++ b/configs/dataset/dtu_wai/default.yaml
@@ -0,0 +1,2 @@
+defaults:
+ - test: default
diff --git a/configs/dataset/dtu_wai/test/default.yaml b/configs/dataset/dtu_wai/test/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..7910a3aae6da7cef8a26ca77932b090d998cf8bb
--- /dev/null
+++ b/configs/dataset/dtu_wai/test/default.yaml
@@ -0,0 +1,22 @@
+dataset_str:
+ "DTUWAI(
+ resolution=${dataset.dtu_wai.test.dataset_resolution},
+ principal_point_centered=${dataset.dtu_wai.test.principal_point_centered},
+ seed=${dataset.dtu_wai.test.seed},
+ transform='${dataset.dtu_wai.test.transform}',
+ data_norm_type='${dataset.dtu_wai.test.data_norm_type}',
+ ROOT='${dataset.dtu_wai.test.ROOT}',
+ dataset_metadata_dir='${dataset.dtu_wai.test.dataset_metadata_dir}',
+ variable_num_views=${dataset.dtu_wai.test.variable_num_views},
+ num_views=${dataset.dtu_wai.test.num_views},
+ covisibility_thres=${dataset.dtu_wai.test.covisibility_thres})"
+dataset_resolution: ${dataset.resolution_test_dtu}
+principal_point_centered: ${dataset.principal_point_centered}
+seed: 777
+transform: 'imgnorm'
+data_norm_type: ${model.data_norm_type}
+ROOT: ${root_data_dir}/dtu
+dataset_metadata_dir: ${mapanything_dataset_metadata_dir}
+variable_num_views: ${dataset.test.variable_num_views}
+num_views: ${dataset.num_views}
+covisibility_thres: 0.25
diff --git a/configs/dataset/dynamicreplica_wai/default.yaml b/configs/dataset/dynamicreplica_wai/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..78448df144ecb19398c761f707b5e264dcaaae29
--- /dev/null
+++ b/configs/dataset/dynamicreplica_wai/default.yaml
@@ -0,0 +1,3 @@
+defaults:
+ - train: default
+ - val: default
diff --git a/configs/dataset/dynamicreplica_wai/train/default.yaml b/configs/dataset/dynamicreplica_wai/train/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..aa8b82c5692d494da325bbe7419e9a4d31db7c5f
--- /dev/null
+++ b/configs/dataset/dynamicreplica_wai/train/default.yaml
@@ -0,0 +1,26 @@
+dataset_str:
+ "DynamicReplicaWAI(
+ split='${dataset.dynamicreplica_wai.train.split}',
+ resolution=${dataset.dynamicreplica_wai.train.dataset_resolution},
+ principal_point_centered=${dataset.dynamicreplica_wai.train.principal_point_centered},
+ aug_crop=${dataset.dynamicreplica_wai.train.aug_crop},
+ transform='${dataset.dynamicreplica_wai.train.transform}',
+ data_norm_type='${dataset.dynamicreplica_wai.train.data_norm_type}',
+ ROOT='${dataset.dynamicreplica_wai.train.ROOT}',
+ dataset_metadata_dir='${dataset.dynamicreplica_wai.train.dataset_metadata_dir}',
+ overfit_num_sets=${dataset.dynamicreplica_wai.train.overfit_num_sets},
+ variable_num_views=${dataset.dynamicreplica_wai.train.variable_num_views},
+ num_views=${dataset.dynamicreplica_wai.train.num_views},
+ covisibility_thres=${dataset.dynamicreplica_wai.train.covisibility_thres})"
+split: 'train'
+dataset_resolution: ${dataset.resolution_train}
+principal_point_centered: ${dataset.principal_point_centered}
+aug_crop: 16
+transform: 'colorjitter+grayscale+gaublur'
+data_norm_type: ${model.data_norm_type}
+ROOT: ${root_data_dir}/dynamicreplica
+dataset_metadata_dir: ${mapanything_dataset_metadata_dir}
+overfit_num_sets: null
+variable_num_views: ${dataset.train.variable_num_views}
+num_views: ${dataset.num_views}
+covisibility_thres: 0.25
diff --git a/configs/dataset/dynamicreplica_wai/val/default.yaml b/configs/dataset/dynamicreplica_wai/val/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..7f50be2705174dad31dc4d1be1f965a4419f36f3
--- /dev/null
+++ b/configs/dataset/dynamicreplica_wai/val/default.yaml
@@ -0,0 +1,26 @@
+dataset_str:
+ "DynamicReplicaWAI(
+ split='${dataset.dynamicreplica_wai.val.split}',
+ resolution=${dataset.dynamicreplica_wai.val.dataset_resolution},
+ principal_point_centered=${dataset.dynamicreplica_wai.val.principal_point_centered},
+ seed=${dataset.dynamicreplica_wai.val.seed},
+ transform='${dataset.dynamicreplica_wai.val.transform}',
+ data_norm_type='${dataset.dynamicreplica_wai.val.data_norm_type}',
+ ROOT='${dataset.dynamicreplica_wai.val.ROOT}',
+ dataset_metadata_dir='${dataset.dynamicreplica_wai.val.dataset_metadata_dir}',
+ overfit_num_sets=${dataset.dynamicreplica_wai.val.overfit_num_sets},
+ variable_num_views=${dataset.dynamicreplica_wai.val.variable_num_views},
+ num_views=${dataset.dynamicreplica_wai.val.num_views},
+ covisibility_thres=${dataset.dynamicreplica_wai.val.covisibility_thres})"
+split: 'val'
+dataset_resolution: ${dataset.resolution_val_dynamicreplica}
+principal_point_centered: ${dataset.principal_point_centered}
+seed: 777
+transform: 'imgnorm'
+data_norm_type: ${model.data_norm_type}
+ROOT: ${root_data_dir}/dynamicreplica
+dataset_metadata_dir: ${mapanything_dataset_metadata_dir}
+overfit_num_sets: null
+variable_num_views: ${dataset.val.variable_num_views}
+num_views: ${dataset.num_views}
+covisibility_thres: 0.25
diff --git a/configs/dataset/eth3d_wai/default.yaml b/configs/dataset/eth3d_wai/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..b1278dcc74c8a2ee16b87a31ebabca50234ab9fa
--- /dev/null
+++ b/configs/dataset/eth3d_wai/default.yaml
@@ -0,0 +1,2 @@
+defaults:
+ - test: default
diff --git a/configs/dataset/eth3d_wai/test/default.yaml b/configs/dataset/eth3d_wai/test/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..0848ef4ab87819185c059567cf66d36dabc152d4
--- /dev/null
+++ b/configs/dataset/eth3d_wai/test/default.yaml
@@ -0,0 +1,22 @@
+dataset_str:
+ "ETH3DWAI(
+ resolution=${dataset.eth3d_wai.test.dataset_resolution},
+ principal_point_centered=${dataset.eth3d_wai.test.principal_point_centered},
+ seed=${dataset.eth3d_wai.test.seed},
+ transform='${dataset.eth3d_wai.test.transform}',
+ data_norm_type='${dataset.eth3d_wai.test.data_norm_type}',
+ ROOT='${dataset.eth3d_wai.test.ROOT}',
+ dataset_metadata_dir='${dataset.eth3d_wai.test.dataset_metadata_dir}',
+ variable_num_views=${dataset.eth3d_wai.test.variable_num_views},
+ num_views=${dataset.eth3d_wai.test.num_views},
+ covisibility_thres=${dataset.eth3d_wai.test.covisibility_thres})"
+dataset_resolution: ${dataset.resolution_test_eth3d}
+principal_point_centered: ${dataset.principal_point_centered}
+seed: 777
+transform: 'imgnorm'
+data_norm_type: ${model.data_norm_type}
+ROOT: ${root_data_dir}/eth3d
+dataset_metadata_dir: ${mapanything_dataset_metadata_dir}
+variable_num_views: ${dataset.test.variable_num_views}
+num_views: ${dataset.num_views}
+covisibility_thres: 0.025
diff --git a/configs/dataset/gta_sfm_wai/default.yaml b/configs/dataset/gta_sfm_wai/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..78448df144ecb19398c761f707b5e264dcaaae29
--- /dev/null
+++ b/configs/dataset/gta_sfm_wai/default.yaml
@@ -0,0 +1,3 @@
+defaults:
+ - train: default
+ - val: default
diff --git a/configs/dataset/gta_sfm_wai/train/default.yaml b/configs/dataset/gta_sfm_wai/train/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..971b4f4a78f3207d38f667860becf035269a46a6
--- /dev/null
+++ b/configs/dataset/gta_sfm_wai/train/default.yaml
@@ -0,0 +1,26 @@
+dataset_str:
+ "GTASfMWAI(
+ split='${dataset.gta_sfm_wai.train.split}',
+ resolution=${dataset.gta_sfm_wai.train.dataset_resolution},
+ principal_point_centered=${dataset.gta_sfm_wai.train.principal_point_centered},
+ aug_crop=${dataset.gta_sfm_wai.train.aug_crop},
+ transform='${dataset.gta_sfm_wai.train.transform}',
+ data_norm_type='${dataset.gta_sfm_wai.train.data_norm_type}',
+ ROOT='${dataset.gta_sfm_wai.train.ROOT}',
+ dataset_metadata_dir='${dataset.gta_sfm_wai.train.dataset_metadata_dir}',
+ overfit_num_sets=${dataset.gta_sfm_wai.train.overfit_num_sets},
+ variable_num_views=${dataset.gta_sfm_wai.train.variable_num_views},
+ num_views=${dataset.gta_sfm_wai.train.num_views},
+ covisibility_thres=${dataset.gta_sfm_wai.train.covisibility_thres})"
+split: 'train'
+dataset_resolution: ${dataset.resolution_train}
+principal_point_centered: ${dataset.principal_point_centered}
+aug_crop: 16
+transform: 'colorjitter+grayscale+gaublur'
+data_norm_type: ${model.data_norm_type}
+ROOT: ${root_data_dir}/gta_sfm
+dataset_metadata_dir: ${mapanything_dataset_metadata_dir}
+overfit_num_sets: null
+variable_num_views: ${dataset.train.variable_num_views}
+num_views: ${dataset.num_views}
+covisibility_thres: 0.25
diff --git a/configs/dataset/gta_sfm_wai/val/default.yaml b/configs/dataset/gta_sfm_wai/val/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..430ac9e292dc1059dcccbbab6bdf82b0f46f391e
--- /dev/null
+++ b/configs/dataset/gta_sfm_wai/val/default.yaml
@@ -0,0 +1,26 @@
+dataset_str:
+ "GTASfMWAI(
+ split='${dataset.gta_sfm_wai.val.split}',
+ resolution=${dataset.gta_sfm_wai.val.dataset_resolution},
+ principal_point_centered=${dataset.gta_sfm_wai.val.principal_point_centered},
+ seed=${dataset.gta_sfm_wai.val.seed},
+ transform='${dataset.gta_sfm_wai.val.transform}',
+ data_norm_type='${dataset.gta_sfm_wai.val.data_norm_type}',
+ ROOT='${dataset.gta_sfm_wai.val.ROOT}',
+ dataset_metadata_dir='${dataset.gta_sfm_wai.val.dataset_metadata_dir}',
+ overfit_num_sets=${dataset.gta_sfm_wai.val.overfit_num_sets},
+ variable_num_views=${dataset.gta_sfm_wai.val.variable_num_views},
+ num_views=${dataset.gta_sfm_wai.val.num_views},
+ covisibility_thres=${dataset.gta_sfm_wai.val.covisibility_thres})"
+split: 'val'
+dataset_resolution: ${dataset.resolution_val_gta_sfm}
+principal_point_centered: ${dataset.principal_point_centered}
+seed: 777
+transform: 'imgnorm'
+data_norm_type: ${model.data_norm_type}
+ROOT: ${root_data_dir}/gta_sfm
+dataset_metadata_dir: ${mapanything_dataset_metadata_dir}
+overfit_num_sets: null
+variable_num_views: ${dataset.val.variable_num_views}
+num_views: ${dataset.num_views}
+covisibility_thres: 0.25
diff --git a/configs/dataset/matrixcity_wai/default.yaml b/configs/dataset/matrixcity_wai/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..78448df144ecb19398c761f707b5e264dcaaae29
--- /dev/null
+++ b/configs/dataset/matrixcity_wai/default.yaml
@@ -0,0 +1,3 @@
+defaults:
+ - train: default
+ - val: default
diff --git a/configs/dataset/matrixcity_wai/train/default.yaml b/configs/dataset/matrixcity_wai/train/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..ca7412ba48e10620b969af8311bb9e511bf5e437
--- /dev/null
+++ b/configs/dataset/matrixcity_wai/train/default.yaml
@@ -0,0 +1,26 @@
+dataset_str:
+ "MatrixCityWAI(
+ split='${dataset.matrixcity_wai.train.split}',
+ resolution=${dataset.matrixcity_wai.train.dataset_resolution},
+ principal_point_centered=${dataset.matrixcity_wai.train.principal_point_centered},
+ aug_crop=${dataset.matrixcity_wai.train.aug_crop},
+ transform='${dataset.matrixcity_wai.train.transform}',
+ data_norm_type='${dataset.matrixcity_wai.train.data_norm_type}',
+ ROOT='${dataset.matrixcity_wai.train.ROOT}',
+ dataset_metadata_dir='${dataset.matrixcity_wai.train.dataset_metadata_dir}',
+ overfit_num_sets=${dataset.matrixcity_wai.train.overfit_num_sets},
+ variable_num_views=${dataset.matrixcity_wai.train.variable_num_views},
+ num_views=${dataset.matrixcity_wai.train.num_views},
+ covisibility_thres=${dataset.matrixcity_wai.train.covisibility_thres})"
+split: 'train'
+dataset_resolution: ${dataset.resolution_train}
+principal_point_centered: ${dataset.principal_point_centered}
+aug_crop: 16
+transform: 'colorjitter+grayscale+gaublur'
+data_norm_type: ${model.data_norm_type}
+ROOT: ${root_data_dir}/matrixcity
+dataset_metadata_dir: ${mapanything_dataset_metadata_dir}
+overfit_num_sets: null
+variable_num_views: ${dataset.train.variable_num_views}
+num_views: ${dataset.num_views}
+covisibility_thres: 0.25
diff --git a/configs/dataset/matrixcity_wai/val/default.yaml b/configs/dataset/matrixcity_wai/val/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..64a73059704da9e721b18b949ec3487b436c3607
--- /dev/null
+++ b/configs/dataset/matrixcity_wai/val/default.yaml
@@ -0,0 +1,26 @@
+dataset_str:
+ "MatrixCityWAI(
+ split='${dataset.matrixcity_wai.val.split}',
+ resolution=${dataset.matrixcity_wai.val.dataset_resolution},
+ principal_point_centered=${dataset.matrixcity_wai.val.principal_point_centered},
+ seed=${dataset.matrixcity_wai.val.seed},
+ transform='${dataset.matrixcity_wai.val.transform}',
+ data_norm_type='${dataset.matrixcity_wai.val.data_norm_type}',
+ ROOT='${dataset.matrixcity_wai.val.ROOT}',
+ dataset_metadata_dir='${dataset.matrixcity_wai.val.dataset_metadata_dir}',
+ overfit_num_sets=${dataset.matrixcity_wai.val.overfit_num_sets},
+ variable_num_views=${dataset.matrixcity_wai.val.variable_num_views},
+ num_views=${dataset.matrixcity_wai.val.num_views},
+ covisibility_thres=${dataset.matrixcity_wai.val.covisibility_thres})"
+split: 'val'
+dataset_resolution: ${dataset.resolution_val_matrixcity}
+principal_point_centered: ${dataset.principal_point_centered}
+seed: 777
+transform: 'imgnorm'
+data_norm_type: ${model.data_norm_type}
+ROOT: ${root_data_dir}/matrixcity
+dataset_metadata_dir: ${mapanything_dataset_metadata_dir}
+overfit_num_sets: null
+variable_num_views: ${dataset.val.variable_num_views}
+num_views: ${dataset.num_views}
+covisibility_thres: 0.25
diff --git a/configs/dataset/megadepth_wai/default.yaml b/configs/dataset/megadepth_wai/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..78448df144ecb19398c761f707b5e264dcaaae29
--- /dev/null
+++ b/configs/dataset/megadepth_wai/default.yaml
@@ -0,0 +1,3 @@
+defaults:
+ - train: default
+ - val: default
diff --git a/configs/dataset/megadepth_wai/train/default.yaml b/configs/dataset/megadepth_wai/train/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..901443908e42340cb2bbdaa2fd4b0614c1748003
--- /dev/null
+++ b/configs/dataset/megadepth_wai/train/default.yaml
@@ -0,0 +1,26 @@
+dataset_str:
+ "MegaDepthWAI(
+ split='${dataset.megadepth_wai.train.split}',
+ resolution=${dataset.megadepth_wai.train.dataset_resolution},
+ principal_point_centered=${dataset.megadepth_wai.train.principal_point_centered},
+ aug_crop=${dataset.megadepth_wai.train.aug_crop},
+ transform='${dataset.megadepth_wai.train.transform}',
+ data_norm_type='${dataset.megadepth_wai.train.data_norm_type}',
+ ROOT='${dataset.megadepth_wai.train.ROOT}',
+ dataset_metadata_dir='${dataset.megadepth_wai.train.dataset_metadata_dir}',
+ overfit_num_sets=${dataset.megadepth_wai.train.overfit_num_sets},
+ variable_num_views=${dataset.megadepth_wai.train.variable_num_views},
+ num_views=${dataset.megadepth_wai.train.num_views},
+ covisibility_thres=${dataset.megadepth_wai.train.covisibility_thres})"
+split: 'train'
+dataset_resolution: ${dataset.resolution_train}
+principal_point_centered: ${dataset.principal_point_centered}
+aug_crop: 16
+transform: 'colorjitter+grayscale+gaublur'
+data_norm_type: ${model.data_norm_type}
+ROOT: ${root_data_dir}/megadepth
+dataset_metadata_dir: ${mapanything_dataset_metadata_dir}
+overfit_num_sets: null
+variable_num_views: ${dataset.train.variable_num_views}
+num_views: ${dataset.num_views}
+covisibility_thres: 0.25
diff --git a/configs/dataset/megadepth_wai/val/default.yaml b/configs/dataset/megadepth_wai/val/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..b37ff6fc3966d54b40ed40cb66f1b0979598a897
--- /dev/null
+++ b/configs/dataset/megadepth_wai/val/default.yaml
@@ -0,0 +1,26 @@
+dataset_str:
+ "MegaDepthWAI(
+ split='${dataset.megadepth_wai.val.split}',
+ resolution=${dataset.megadepth_wai.val.dataset_resolution},
+ principal_point_centered=${dataset.megadepth_wai.val.principal_point_centered},
+ seed=${dataset.megadepth_wai.val.seed},
+ transform='${dataset.megadepth_wai.val.transform}',
+ data_norm_type='${dataset.megadepth_wai.val.data_norm_type}',
+ ROOT='${dataset.megadepth_wai.val.ROOT}',
+ dataset_metadata_dir='${dataset.megadepth_wai.val.dataset_metadata_dir}',
+ overfit_num_sets=${dataset.megadepth_wai.val.overfit_num_sets},
+ variable_num_views=${dataset.megadepth_wai.val.variable_num_views},
+ num_views=${dataset.megadepth_wai.val.num_views},
+ covisibility_thres=${dataset.megadepth_wai.val.covisibility_thres})"
+split: 'val'
+dataset_resolution: ${dataset.resolution_val_megadepth}
+principal_point_centered: ${dataset.principal_point_centered}
+seed: 777
+transform: 'imgnorm'
+data_norm_type: ${model.data_norm_type}
+ROOT: ${root_data_dir}/megadepth
+dataset_metadata_dir: ${mapanything_dataset_metadata_dir}
+overfit_num_sets: null
+variable_num_views: ${dataset.val.variable_num_views}
+num_views: ${dataset.num_views}
+covisibility_thres: 0.25
diff --git a/configs/dataset/megatrain_11d_se_518_many_ar_48ipg_64g.yaml b/configs/dataset/megatrain_11d_se_518_many_ar_48ipg_64g.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..a5c5f087b72828f98ae5406f871e763145dbd7b0
--- /dev/null
+++ b/configs/dataset/megatrain_11d_se_518_many_ar_48ipg_64g.yaml
@@ -0,0 +1,53 @@
+defaults:
+ - default
+
+# Number of views parameter for the multi-view datasets
+num_views: 4
+
+train:
+ # If True, the number of views can vary from batch to batch. The maximum number of views is num_views and minimum is 2. (On by default for N-view training)
+ variable_num_views: true
+
+# Train Resolution
+resolution_train: ${dataset.resolution_options.518_many_ar}
+
+# Validation Resolution
+resolution_val_ase: ${dataset.resolution_options.518_1_00_ar}
+resolution_val_dl3dv: ${dataset.resolution_options.518_1_77_ar}
+resolution_val_dynamicreplica: ${dataset.resolution_options.518_1_77_ar}
+resolution_val_mpsd: ${dataset.resolution_options.518_1_77_ar}
+resolution_val_mvs_synth: ${dataset.resolution_options.518_1_77_ar}
+resolution_val_paralleldomain4d: ${dataset.resolution_options.518_1_33_ar}
+resolution_val_sailvos3d: ${dataset.resolution_options.518_1_52_ar}
+resolution_val_scannetpp: ${dataset.resolution_options.518_1_52_ar}
+resolution_val_spring: ${dataset.resolution_options.518_1_77_ar}
+resolution_val_tav2_wb: ${dataset.resolution_options.518_1_00_ar}
+resolution_val_unrealstereo4k: ${dataset.resolution_options.518_1_77_ar}
+
+# Training Set
+train_dataset:
+ "+ 2_450_000 @ ${dataset.ase_wai.train.dataset_str}
+ + 250_000 @ ${dataset.dl3dv_wai.train.dataset_str}
+ + 12_400 @ ${dataset.dynamicreplica_wai.train.dataset_str}
+ + 1_675_000 @ ${dataset.mpsd_wai.train.dataset_str}
+ + 3_000 @ ${dataset.mvs_synth_wai.train.dataset_str}
+ + 36_000 @ ${dataset.paralleldomain4d_wai.train.dataset_str}
+ + 4_000 @ ${dataset.sailvos3d_wai.train.dataset_str}
+ + 22_600 @ ${dataset.scannetpp_wai.train.dataset_str}
+ + 800 @ ${dataset.spring_wai.train.dataset_str}
+ + 4_000 @ ${dataset.tav2_wb_wai.train.dataset_str}
+ + 200 @ ${dataset.unrealstereo4k_wai.train.dataset_str}"
+
+# Validation Set
+test_dataset:
+ "+ 4_000 @ ${dataset.ase_wai.val.dataset_str}
+ + 4_000 @ ${dataset.dl3dv_wai.val.dataset_str}
+ + 4_000 @ ${dataset.dynamicreplica_wai.val.dataset_str}
+ + 4_000 @ ${dataset.mpsd_wai.val.dataset_str}
+ + 4_000 @ ${dataset.mvs_synth_wai.val.dataset_str}
+ + 4_000 @ ${dataset.paralleldomain4d_wai.val.dataset_str}
+ + 4_000 @ ${dataset.sailvos3d_wai.val.dataset_str}
+ + 4_000 @ ${dataset.scannetpp_wai.val.dataset_str}
+ + 500 @ ${dataset.spring_wai.val.dataset_str}
+ + 4_000 @ ${dataset.tav2_wb_wai.val.dataset_str}
+ + 500 @ ${dataset.unrealstereo4k_wai.val.dataset_str}"
diff --git a/configs/dataset/megatrain_12d_518_many_ar_24ipg_16g.yaml b/configs/dataset/megatrain_12d_518_many_ar_24ipg_16g.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..e97bbb2c38a81334f7a57eb93854774c4df01b78
--- /dev/null
+++ b/configs/dataset/megatrain_12d_518_many_ar_24ipg_16g.yaml
@@ -0,0 +1,56 @@
+defaults:
+ - default
+
+# Number of views parameter for the multi-view datasets
+num_views: 4
+
+train:
+ # If True, the number of views can vary from batch to batch. The maximum number of views is num_views and minimum is 2. (On by default for N-view training)
+ variable_num_views: true
+
+# Train Resolution
+resolution_train: ${dataset.resolution_options.518_many_ar}
+
+# Validation Resolution
+resolution_val_ase: ${dataset.resolution_options.518_1_00_ar}
+resolution_val_blendedmvs: ${dataset.resolution_options.518_1_33_ar}
+resolution_val_dynamicreplica: ${dataset.resolution_options.518_1_77_ar}
+resolution_val_megadepth: ${dataset.resolution_options.518_1_52_ar}
+resolution_val_mpsd: ${dataset.resolution_options.518_1_77_ar}
+resolution_val_mvs_synth: ${dataset.resolution_options.518_1_77_ar}
+resolution_val_paralleldomain4d: ${dataset.resolution_options.518_1_33_ar}
+resolution_val_sailvos3d: ${dataset.resolution_options.518_1_52_ar}
+resolution_val_scannetpp: ${dataset.resolution_options.518_1_52_ar}
+resolution_val_spring: ${dataset.resolution_options.518_1_77_ar}
+resolution_val_tav2_wb: ${dataset.resolution_options.518_1_00_ar}
+resolution_val_unrealstereo4k: ${dataset.resolution_options.518_1_77_ar}
+
+# Training Set
+train_dataset:
+ "+ 58_000 @ ${dataset.ase_wai.train.dataset_str}
+ + 58_000 @ ${dataset.blendedmvs_wai.train.dataset_str}
+ + 45_000 @ ${dataset.dynamicreplica_wai.train.dataset_str}
+ + 58_000 @ ${dataset.megadepth_wai.train.dataset_str}
+ + 58_000 @ ${dataset.mpsd_wai.train.dataset_str}
+ + 58_000 @ ${dataset.mvs_synth_wai.train.dataset_str}
+ + 58_000 @ ${dataset.paralleldomain4d_wai.train.dataset_str}
+ + 58_000 @ ${dataset.sailvos3d_wai.train.dataset_str}
+ + 58_000 @ ${dataset.scannetpp_wai.train.dataset_str}
+ + 2_000 @ ${dataset.spring_wai.train.dataset_str}
+ + 58_000 @ ${dataset.tav2_wb_wai.train.dataset_str}
+ + 5_500 @ ${dataset.unrealstereo4k_wai.train.dataset_str}"
+
+# Validation Set
+test_dataset:
+ "+ 4_000 @ ${dataset.ase_wai.val.dataset_str}
+ + 4_000 @ ${dataset.blendedmvs_wai.val.dataset_str}
+ + 4_000 @ ${dataset.dynamicreplica_wai.val.dataset_str}
+ + 4_000 @ ${dataset.megadepth_wai.val.dataset_str}
+ + 4_000 @ ${dataset.mpsd_wai.val.dataset_str}
+ + 4_000 @ ${dataset.mvs_synth_wai.val.dataset_str}
+ + 4_000 @ ${dataset.paralleldomain4d_wai.val.dataset_str}
+ + 4_000 @ ${dataset.sailvos3d_wai.val.dataset_str}
+ + 4_000 @ ${dataset.scannetpp_wai.val.dataset_str}
+ + 500 @ ${dataset.spring_wai.val.dataset_str}
+ + 4_000 @ ${dataset.tav2_wb_wai.val.dataset_str}
+ + 500 @ ${dataset.unrealstereo4k_wai.val.dataset_str}"
diff --git a/configs/dataset/megatrain_13d_512_many_ar_24ipg_16g.yaml b/configs/dataset/megatrain_13d_512_many_ar_24ipg_16g.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..8e23dd5cca53addeb5515e213b8918b9de8d20da
--- /dev/null
+++ b/configs/dataset/megatrain_13d_512_many_ar_24ipg_16g.yaml
@@ -0,0 +1,59 @@
+defaults:
+ - default
+
+# Number of views parameter for the multi-view datasets
+num_views: 4
+
+train:
+ # If True, the number of views can vary from batch to batch. The maximum number of views is num_views and minimum is 2. (On by default for N-view training)
+ variable_num_views: true
+
+# Train Resolution
+resolution_train: ${dataset.resolution_options.512_many_ar}
+
+# Validation Resolution
+resolution_val_ase: ${dataset.resolution_options.512_1_00_ar}
+resolution_val_blendedmvs: ${dataset.resolution_options.512_1_33_ar}
+resolution_val_dl3dv: ${dataset.resolution_options.512_1_77_ar}
+resolution_val_dynamicreplica: ${dataset.resolution_options.512_1_77_ar}
+resolution_val_megadepth: ${dataset.resolution_options.512_1_52_ar}
+resolution_val_mpsd: ${dataset.resolution_options.512_1_77_ar}
+resolution_val_mvs_synth: ${dataset.resolution_options.512_1_77_ar}
+resolution_val_paralleldomain4d: ${dataset.resolution_options.512_1_33_ar}
+resolution_val_sailvos3d: ${dataset.resolution_options.512_1_52_ar}
+resolution_val_scannetpp: ${dataset.resolution_options.512_1_52_ar}
+resolution_val_spring: ${dataset.resolution_options.512_1_77_ar}
+resolution_val_tav2_wb: ${dataset.resolution_options.512_1_00_ar}
+resolution_val_unrealstereo4k: ${dataset.resolution_options.512_1_77_ar}
+
+# Training Set
+train_dataset:
+ "+ 52_500 @ ${dataset.ase_wai.train.dataset_str}
+ + 52_500 @ ${dataset.blendedmvs_wai.train.dataset_str}
+ + 52_500 @ ${dataset.dl3dv_wai.train.dataset_str}
+ + 40_000 @ ${dataset.dynamicreplica_wai.train.dataset_str}
+ + 52_500 @ ${dataset.megadepth_wai.train.dataset_str}
+ + 52_500 @ ${dataset.mpsd_wai.train.dataset_str}
+ + 52_500 @ ${dataset.mvs_synth_wai.train.dataset_str}
+ + 52_500 @ ${dataset.paralleldomain4d_wai.train.dataset_str}
+ + 52_500 @ ${dataset.sailvos3d_wai.train.dataset_str}
+ + 52_500 @ ${dataset.scannetpp_wai.train.dataset_str}
+ + 2_000 @ ${dataset.spring_wai.train.dataset_str}
+ + 52_500 @ ${dataset.tav2_wb_wai.train.dataset_str}
+ + 5_500 @ ${dataset.unrealstereo4k_wai.train.dataset_str}"
+
+# Validation Set
+test_dataset:
+ "+ 4_000 @ ${dataset.ase_wai.val.dataset_str}
+ + 4_000 @ ${dataset.blendedmvs_wai.val.dataset_str}
+ + 4_000 @ ${dataset.dl3dv_wai.val.dataset_str}
+ + 4_000 @ ${dataset.dynamicreplica_wai.val.dataset_str}
+ + 4_000 @ ${dataset.megadepth_wai.val.dataset_str}
+ + 4_000 @ ${dataset.mpsd_wai.val.dataset_str}
+ + 4_000 @ ${dataset.mvs_synth_wai.val.dataset_str}
+ + 4_000 @ ${dataset.paralleldomain4d_wai.val.dataset_str}
+ + 4_000 @ ${dataset.sailvos3d_wai.val.dataset_str}
+ + 4_000 @ ${dataset.scannetpp_wai.val.dataset_str}
+ + 500 @ ${dataset.spring_wai.val.dataset_str}
+ + 4_000 @ ${dataset.tav2_wb_wai.val.dataset_str}
+ + 500 @ ${dataset.unrealstereo4k_wai.val.dataset_str}"
diff --git a/configs/dataset/megatrain_13d_518_many_ar_24ipg_16g.yaml b/configs/dataset/megatrain_13d_518_many_ar_24ipg_16g.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..187bc543689db308d17000869933bc88e41027db
--- /dev/null
+++ b/configs/dataset/megatrain_13d_518_many_ar_24ipg_16g.yaml
@@ -0,0 +1,59 @@
+defaults:
+ - default
+
+# Number of views parameter for the multi-view datasets
+num_views: 4
+
+train:
+ # If True, the number of views can vary from batch to batch. The maximum number of views is num_views and minimum is 2. (On by default for N-view training)
+ variable_num_views: true
+
+# Train Resolution
+resolution_train: ${dataset.resolution_options.518_many_ar}
+
+# Validation Resolution
+resolution_val_ase: ${dataset.resolution_options.518_1_00_ar}
+resolution_val_blendedmvs: ${dataset.resolution_options.518_1_33_ar}
+resolution_val_dl3dv: ${dataset.resolution_options.518_1_77_ar}
+resolution_val_dynamicreplica: ${dataset.resolution_options.518_1_77_ar}
+resolution_val_megadepth: ${dataset.resolution_options.518_1_52_ar}
+resolution_val_mpsd: ${dataset.resolution_options.518_1_77_ar}
+resolution_val_mvs_synth: ${dataset.resolution_options.518_1_77_ar}
+resolution_val_paralleldomain4d: ${dataset.resolution_options.518_1_33_ar}
+resolution_val_sailvos3d: ${dataset.resolution_options.518_1_52_ar}
+resolution_val_scannetpp: ${dataset.resolution_options.518_1_52_ar}
+resolution_val_spring: ${dataset.resolution_options.518_1_77_ar}
+resolution_val_tav2_wb: ${dataset.resolution_options.518_1_00_ar}
+resolution_val_unrealstereo4k: ${dataset.resolution_options.518_1_77_ar}
+
+# Training Set
+train_dataset:
+ "+ 52_500 @ ${dataset.ase_wai.train.dataset_str}
+ + 52_500 @ ${dataset.blendedmvs_wai.train.dataset_str}
+ + 52_500 @ ${dataset.dl3dv_wai.train.dataset_str}
+ + 40_000 @ ${dataset.dynamicreplica_wai.train.dataset_str}
+ + 52_500 @ ${dataset.megadepth_wai.train.dataset_str}
+ + 52_500 @ ${dataset.mpsd_wai.train.dataset_str}
+ + 52_500 @ ${dataset.mvs_synth_wai.train.dataset_str}
+ + 52_500 @ ${dataset.paralleldomain4d_wai.train.dataset_str}
+ + 52_500 @ ${dataset.sailvos3d_wai.train.dataset_str}
+ + 52_500 @ ${dataset.scannetpp_wai.train.dataset_str}
+ + 2_000 @ ${dataset.spring_wai.train.dataset_str}
+ + 52_500 @ ${dataset.tav2_wb_wai.train.dataset_str}
+ + 5_500 @ ${dataset.unrealstereo4k_wai.train.dataset_str}"
+
+# Validation Set
+test_dataset:
+ "+ 4_000 @ ${dataset.ase_wai.val.dataset_str}
+ + 4_000 @ ${dataset.blendedmvs_wai.val.dataset_str}
+ + 4_000 @ ${dataset.dl3dv_wai.val.dataset_str}
+ + 4_000 @ ${dataset.dynamicreplica_wai.val.dataset_str}
+ + 4_000 @ ${dataset.megadepth_wai.val.dataset_str}
+ + 4_000 @ ${dataset.mpsd_wai.val.dataset_str}
+ + 4_000 @ ${dataset.mvs_synth_wai.val.dataset_str}
+ + 4_000 @ ${dataset.paralleldomain4d_wai.val.dataset_str}
+ + 4_000 @ ${dataset.sailvos3d_wai.val.dataset_str}
+ + 4_000 @ ${dataset.scannetpp_wai.val.dataset_str}
+ + 500 @ ${dataset.spring_wai.val.dataset_str}
+ + 4_000 @ ${dataset.tav2_wb_wai.val.dataset_str}
+ + 500 @ ${dataset.unrealstereo4k_wai.val.dataset_str}"
diff --git a/configs/dataset/megatrain_13d_518_many_ar_48ipg_64g.yaml b/configs/dataset/megatrain_13d_518_many_ar_48ipg_64g.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..199214082801e43e36b1f70930a8b3a2d24cfd3b
--- /dev/null
+++ b/configs/dataset/megatrain_13d_518_many_ar_48ipg_64g.yaml
@@ -0,0 +1,59 @@
+defaults:
+ - default
+
+# Number of views parameter for the multi-view datasets
+num_views: 4
+
+train:
+ # If True, the number of views can vary from batch to batch. The maximum number of views is num_views and minimum is 2. (On by default for N-view training)
+ variable_num_views: true
+
+# Train Resolution
+resolution_train: ${dataset.resolution_options.518_many_ar}
+
+# Validation Resolution
+resolution_val_ase: ${dataset.resolution_options.518_1_00_ar}
+resolution_val_blendedmvs: ${dataset.resolution_options.518_1_33_ar}
+resolution_val_dl3dv: ${dataset.resolution_options.518_1_77_ar}
+resolution_val_dynamicreplica: ${dataset.resolution_options.518_1_77_ar}
+resolution_val_megadepth: ${dataset.resolution_options.518_1_52_ar}
+resolution_val_mpsd: ${dataset.resolution_options.518_1_77_ar}
+resolution_val_mvs_synth: ${dataset.resolution_options.518_1_77_ar}
+resolution_val_paralleldomain4d: ${dataset.resolution_options.518_1_33_ar}
+resolution_val_sailvos3d: ${dataset.resolution_options.518_1_52_ar}
+resolution_val_scannetpp: ${dataset.resolution_options.518_1_52_ar}
+resolution_val_spring: ${dataset.resolution_options.518_1_77_ar}
+resolution_val_tav2_wb: ${dataset.resolution_options.518_1_00_ar}
+resolution_val_unrealstereo4k: ${dataset.resolution_options.518_1_77_ar}
+
+# Training Set
+train_dataset:
+ "+ 420_000 @ ${dataset.ase_wai.train.dataset_str}
+ + 420_000 @ ${dataset.blendedmvs_wai.train.dataset_str}
+ + 420_000 @ ${dataset.dl3dv_wai.train.dataset_str}
+ + 320_000 @ ${dataset.dynamicreplica_wai.train.dataset_str}
+ + 420_000 @ ${dataset.megadepth_wai.train.dataset_str}
+ + 420_000 @ ${dataset.mpsd_wai.train.dataset_str}
+ + 420_000 @ ${dataset.mvs_synth_wai.train.dataset_str}
+ + 420_000 @ ${dataset.paralleldomain4d_wai.train.dataset_str}
+ + 420_000 @ ${dataset.sailvos3d_wai.train.dataset_str}
+ + 420_000 @ ${dataset.scannetpp_wai.train.dataset_str}
+ + 16_000 @ ${dataset.spring_wai.train.dataset_str}
+ + 420_000 @ ${dataset.tav2_wb_wai.train.dataset_str}
+ + 44_000 @ ${dataset.unrealstereo4k_wai.train.dataset_str}"
+
+# Validation Set
+test_dataset:
+ "+ 4_000 @ ${dataset.ase_wai.val.dataset_str}
+ + 4_000 @ ${dataset.blendedmvs_wai.val.dataset_str}
+ + 4_000 @ ${dataset.dl3dv_wai.val.dataset_str}
+ + 4_000 @ ${dataset.dynamicreplica_wai.val.dataset_str}
+ + 4_000 @ ${dataset.megadepth_wai.val.dataset_str}
+ + 4_000 @ ${dataset.mpsd_wai.val.dataset_str}
+ + 4_000 @ ${dataset.mvs_synth_wai.val.dataset_str}
+ + 4_000 @ ${dataset.paralleldomain4d_wai.val.dataset_str}
+ + 4_000 @ ${dataset.sailvos3d_wai.val.dataset_str}
+ + 4_000 @ ${dataset.scannetpp_wai.val.dataset_str}
+ + 500 @ ${dataset.spring_wai.val.dataset_str}
+ + 4_000 @ ${dataset.tav2_wb_wai.val.dataset_str}
+ + 500 @ ${dataset.unrealstereo4k_wai.val.dataset_str}"
diff --git a/configs/dataset/megatrain_6d_518_many_ar_48ipg_64g.yaml b/configs/dataset/megatrain_6d_518_many_ar_48ipg_64g.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..add073ff63121c292c35867704eaf08b237d23ef
--- /dev/null
+++ b/configs/dataset/megatrain_6d_518_many_ar_48ipg_64g.yaml
@@ -0,0 +1,38 @@
+defaults:
+ - default
+
+# Number of views parameter for the multi-view datasets
+num_views: 4
+
+train:
+ # If True, the number of views can vary from batch to batch. The maximum number of views is num_views and minimum is 2. (On by default for N-view training)
+ variable_num_views: true
+
+# Train Resolution
+resolution_train: ${dataset.resolution_options.518_many_ar}
+
+# Validation Resolution
+resolution_val_blendedmvs: ${dataset.resolution_options.518_1_33_ar}
+resolution_val_mpsd: ${dataset.resolution_options.518_1_77_ar}
+resolution_val_scannetpp: ${dataset.resolution_options.518_1_52_ar}
+resolution_val_spring: ${dataset.resolution_options.518_1_77_ar}
+resolution_val_tav2_wb: ${dataset.resolution_options.518_1_00_ar}
+resolution_val_unrealstereo4k: ${dataset.resolution_options.518_1_77_ar}
+
+# Training Set
+train_dataset:
+ "+ 1_120_000 @ ${dataset.blendedmvs_wai.train.dataset_str}
+ + 1_120_000 @ ${dataset.mpsd_wai.train.dataset_str}
+ + 1_120_000 @ ${dataset.scannetpp_wai.train.dataset_str}
+ + 44_000 @ ${dataset.spring_wai.train.dataset_str}
+ + 1_120_000 @ ${dataset.tav2_wb_wai.train.dataset_str}
+ + 116_000 @ ${dataset.unrealstereo4k_wai.train.dataset_str}"
+
+# Validation Set
+test_dataset:
+ "+ 4_000 @ ${dataset.blendedmvs_wai.val.dataset_str}
+ + 4_000 @ ${dataset.mpsd_wai.val.dataset_str}
+ + 4_000 @ ${dataset.scannetpp_wai.val.dataset_str}
+ + 500 @ ${dataset.spring_wai.val.dataset_str}
+ + 4_000 @ ${dataset.tav2_wb_wai.val.dataset_str}
+ + 500 @ ${dataset.unrealstereo4k_wai.val.dataset_str}"
diff --git a/configs/dataset/megatrain_6d_518_many_ar_48ipg_8g.yaml b/configs/dataset/megatrain_6d_518_many_ar_48ipg_8g.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..31dc47a184bb839be495af040263561eba0f2ba3
--- /dev/null
+++ b/configs/dataset/megatrain_6d_518_many_ar_48ipg_8g.yaml
@@ -0,0 +1,38 @@
+defaults:
+ - default
+
+# Number of views parameter for the multi-view datasets
+num_views: 4
+
+train:
+ # If True, the number of views can vary from batch to batch. The maximum number of views is num_views and minimum is 2. (On by default for N-view training)
+ variable_num_views: true
+
+# Train Resolution
+resolution_train: ${dataset.resolution_options.518_many_ar}
+
+# Validation Resolution
+resolution_val_blendedmvs: ${dataset.resolution_options.518_1_33_ar}
+resolution_val_mpsd: ${dataset.resolution_options.518_1_77_ar}
+resolution_val_scannetpp: ${dataset.resolution_options.518_1_52_ar}
+resolution_val_spring: ${dataset.resolution_options.518_1_77_ar}
+resolution_val_tav2_wb: ${dataset.resolution_options.518_1_00_ar}
+resolution_val_unrealstereo4k: ${dataset.resolution_options.518_1_77_ar}
+
+# Training Set
+train_dataset:
+ "+ 140_000 @ ${dataset.blendedmvs_wai.train.dataset_str}
+ + 140_000 @ ${dataset.mpsd_wai.train.dataset_str}
+ + 140_000 @ ${dataset.scannetpp_wai.train.dataset_str}
+ + 5_500 @ ${dataset.spring_wai.train.dataset_str}
+ + 140_000 @ ${dataset.tav2_wb_wai.train.dataset_str}
+ + 14_500 @ ${dataset.unrealstereo4k_wai.train.dataset_str}"
+
+# Validation Set
+test_dataset:
+ "+ 4_000 @ ${dataset.blendedmvs_wai.val.dataset_str}
+ + 4_000 @ ${dataset.mpsd_wai.val.dataset_str}
+ + 4_000 @ ${dataset.scannetpp_wai.val.dataset_str}
+ + 500 @ ${dataset.spring_wai.val.dataset_str}
+ + 4_000 @ ${dataset.tav2_wb_wai.val.dataset_str}
+ + 500 @ ${dataset.unrealstereo4k_wai.val.dataset_str}"
diff --git a/configs/dataset/mpsd_wai/default.yaml b/configs/dataset/mpsd_wai/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..78448df144ecb19398c761f707b5e264dcaaae29
--- /dev/null
+++ b/configs/dataset/mpsd_wai/default.yaml
@@ -0,0 +1,3 @@
+defaults:
+ - train: default
+ - val: default
diff --git a/configs/dataset/mpsd_wai/train/default.yaml b/configs/dataset/mpsd_wai/train/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..2ad2e3ca2aec9e6b12c01a8e3a4b6ff0a635b95c
--- /dev/null
+++ b/configs/dataset/mpsd_wai/train/default.yaml
@@ -0,0 +1,26 @@
+dataset_str:
+ "MPSDWAI(
+ split='${dataset.mpsd_wai.train.split}',
+ resolution=${dataset.mpsd_wai.train.dataset_resolution},
+ principal_point_centered=${dataset.mpsd_wai.train.principal_point_centered},
+ aug_crop=${dataset.mpsd_wai.train.aug_crop},
+ transform='${dataset.mpsd_wai.train.transform}',
+ data_norm_type='${dataset.mpsd_wai.train.data_norm_type}',
+ ROOT='${dataset.mpsd_wai.train.ROOT}',
+ dataset_metadata_dir='${dataset.mpsd_wai.train.dataset_metadata_dir}',
+ overfit_num_sets=${dataset.mpsd_wai.train.overfit_num_sets},
+ variable_num_views=${dataset.mpsd_wai.train.variable_num_views},
+ num_views=${dataset.mpsd_wai.train.num_views},
+ covisibility_thres=${dataset.mpsd_wai.train.covisibility_thres})"
+split: 'train'
+dataset_resolution: ${dataset.resolution_train}
+principal_point_centered: ${dataset.principal_point_centered}
+aug_crop: 16
+transform: 'colorjitter+grayscale+gaublur'
+data_norm_type: ${model.data_norm_type}
+ROOT: ${root_data_dir}/mpsd
+dataset_metadata_dir: ${mapanything_dataset_metadata_dir}
+overfit_num_sets: null
+variable_num_views: ${dataset.train.variable_num_views}
+num_views: ${dataset.num_views}
+covisibility_thres: 0.15
diff --git a/configs/dataset/mpsd_wai/val/default.yaml b/configs/dataset/mpsd_wai/val/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..cec643ba5bf830fe64cb2cbc6ef83ce28913d0ff
--- /dev/null
+++ b/configs/dataset/mpsd_wai/val/default.yaml
@@ -0,0 +1,26 @@
+dataset_str:
+ "MPSDWAI(
+ split='${dataset.mpsd_wai.val.split}',
+ resolution=${dataset.mpsd_wai.val.dataset_resolution},
+ principal_point_centered=${dataset.mpsd_wai.val.principal_point_centered},
+ seed=${dataset.mpsd_wai.val.seed},
+ transform='${dataset.mpsd_wai.val.transform}',
+ data_norm_type='${dataset.mpsd_wai.val.data_norm_type}',
+ ROOT='${dataset.mpsd_wai.val.ROOT}',
+ dataset_metadata_dir='${dataset.mpsd_wai.val.dataset_metadata_dir}',
+ overfit_num_sets=${dataset.mpsd_wai.val.overfit_num_sets},
+ variable_num_views=${dataset.mpsd_wai.val.variable_num_views},
+ num_views=${dataset.mpsd_wai.val.num_views},
+ covisibility_thres=${dataset.mpsd_wai.val.covisibility_thres})"
+split: 'val'
+dataset_resolution: ${dataset.resolution_val_mpsd}
+principal_point_centered: ${dataset.principal_point_centered}
+seed: 777
+transform: 'imgnorm'
+data_norm_type: ${model.data_norm_type}
+ROOT: ${root_data_dir}/mpsd
+dataset_metadata_dir: ${mapanything_dataset_metadata_dir}
+overfit_num_sets: null
+variable_num_views: ${dataset.val.variable_num_views}
+num_views: ${dataset.num_views}
+covisibility_thres: 0.15
diff --git a/configs/dataset/mvs_synth_wai/default.yaml b/configs/dataset/mvs_synth_wai/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..78448df144ecb19398c761f707b5e264dcaaae29
--- /dev/null
+++ b/configs/dataset/mvs_synth_wai/default.yaml
@@ -0,0 +1,3 @@
+defaults:
+ - train: default
+ - val: default
diff --git a/configs/dataset/mvs_synth_wai/train/default.yaml b/configs/dataset/mvs_synth_wai/train/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..3f325cfb8d9e34e8364c3652467bfc97e9ea4b70
--- /dev/null
+++ b/configs/dataset/mvs_synth_wai/train/default.yaml
@@ -0,0 +1,26 @@
+dataset_str:
+ "MVSSynthWAI(
+ split='${dataset.mvs_synth_wai.train.split}',
+ resolution=${dataset.mvs_synth_wai.train.dataset_resolution},
+ principal_point_centered=${dataset.mvs_synth_wai.train.principal_point_centered},
+ aug_crop=${dataset.mvs_synth_wai.train.aug_crop},
+ transform='${dataset.mvs_synth_wai.train.transform}',
+ data_norm_type='${dataset.mvs_synth_wai.train.data_norm_type}',
+ ROOT='${dataset.mvs_synth_wai.train.ROOT}',
+ dataset_metadata_dir='${dataset.mvs_synth_wai.train.dataset_metadata_dir}',
+ overfit_num_sets=${dataset.mvs_synth_wai.train.overfit_num_sets},
+ variable_num_views=${dataset.mvs_synth_wai.train.variable_num_views},
+ num_views=${dataset.mvs_synth_wai.train.num_views},
+ covisibility_thres=${dataset.mvs_synth_wai.train.covisibility_thres})"
+split: 'train'
+dataset_resolution: ${dataset.resolution_train}
+principal_point_centered: ${dataset.principal_point_centered}
+aug_crop: 16
+transform: 'colorjitter+grayscale+gaublur'
+data_norm_type: ${model.data_norm_type}
+ROOT: ${root_data_dir}/mvs_synth
+dataset_metadata_dir: ${mapanything_dataset_metadata_dir}
+overfit_num_sets: null
+variable_num_views: ${dataset.train.variable_num_views}
+num_views: ${dataset.num_views}
+covisibility_thres: 0.25
diff --git a/configs/dataset/mvs_synth_wai/val/default.yaml b/configs/dataset/mvs_synth_wai/val/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..f84630347c54b059ec5afb3c1aecf9a883769259
--- /dev/null
+++ b/configs/dataset/mvs_synth_wai/val/default.yaml
@@ -0,0 +1,26 @@
+dataset_str:
+ "MVSSynthWAI(
+ split='${dataset.mvs_synth_wai.val.split}',
+ resolution=${dataset.mvs_synth_wai.val.dataset_resolution},
+ principal_point_centered=${dataset.mvs_synth_wai.val.principal_point_centered},
+ seed=${dataset.mvs_synth_wai.val.seed},
+ transform='${dataset.mvs_synth_wai.val.transform}',
+ data_norm_type='${dataset.mvs_synth_wai.val.data_norm_type}',
+ ROOT='${dataset.mvs_synth_wai.val.ROOT}',
+ dataset_metadata_dir='${dataset.mvs_synth_wai.val.dataset_metadata_dir}',
+ overfit_num_sets=${dataset.mvs_synth_wai.val.overfit_num_sets},
+ variable_num_views=${dataset.mvs_synth_wai.val.variable_num_views},
+ num_views=${dataset.mvs_synth_wai.val.num_views},
+ covisibility_thres=${dataset.mvs_synth_wai.val.covisibility_thres})"
+split: 'val'
+dataset_resolution: ${dataset.resolution_val_mvs_synth}
+principal_point_centered: ${dataset.principal_point_centered}
+seed: 777
+transform: 'imgnorm'
+data_norm_type: ${model.data_norm_type}
+ROOT: ${root_data_dir}/mvs_synth
+dataset_metadata_dir: ${mapanything_dataset_metadata_dir}
+overfit_num_sets: null
+variable_num_views: ${dataset.val.variable_num_views}
+num_views: ${dataset.num_views}
+covisibility_thres: 0.25
diff --git a/configs/dataset/paralleldomain4d_wai/default.yaml b/configs/dataset/paralleldomain4d_wai/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..78448df144ecb19398c761f707b5e264dcaaae29
--- /dev/null
+++ b/configs/dataset/paralleldomain4d_wai/default.yaml
@@ -0,0 +1,3 @@
+defaults:
+ - train: default
+ - val: default
diff --git a/configs/dataset/paralleldomain4d_wai/train/default.yaml b/configs/dataset/paralleldomain4d_wai/train/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..14d3f6046ec5ccf83f489c5905093e94720101bd
--- /dev/null
+++ b/configs/dataset/paralleldomain4d_wai/train/default.yaml
@@ -0,0 +1,26 @@
+dataset_str:
+ "ParallelDomain4DWAI(
+ split='${dataset.paralleldomain4d_wai.train.split}',
+ resolution=${dataset.paralleldomain4d_wai.train.dataset_resolution},
+ principal_point_centered=${dataset.paralleldomain4d_wai.train.principal_point_centered},
+ aug_crop=${dataset.paralleldomain4d_wai.train.aug_crop},
+ transform='${dataset.paralleldomain4d_wai.train.transform}',
+ data_norm_type='${dataset.paralleldomain4d_wai.train.data_norm_type}',
+ ROOT='${dataset.paralleldomain4d_wai.train.ROOT}',
+ dataset_metadata_dir='${dataset.paralleldomain4d_wai.train.dataset_metadata_dir}',
+ overfit_num_sets=${dataset.paralleldomain4d_wai.train.overfit_num_sets},
+ variable_num_views=${dataset.paralleldomain4d_wai.train.variable_num_views},
+ num_views=${dataset.paralleldomain4d_wai.train.num_views},
+ covisibility_thres=${dataset.paralleldomain4d_wai.train.covisibility_thres})"
+split: 'train'
+dataset_resolution: ${dataset.resolution_train}
+principal_point_centered: ${dataset.principal_point_centered}
+aug_crop: 16
+transform: 'colorjitter+grayscale+gaublur'
+data_norm_type: ${model.data_norm_type}
+ROOT: ${root_data_dir}/paralleldomain4d
+dataset_metadata_dir: ${mapanything_dataset_metadata_dir}
+overfit_num_sets: null
+variable_num_views: ${dataset.train.variable_num_views}
+num_views: ${dataset.num_views}
+covisibility_thres: 0.25
diff --git a/configs/dataset/paralleldomain4d_wai/val/default.yaml b/configs/dataset/paralleldomain4d_wai/val/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..d9409e28d5ba686866286411e820a376fadbb645
--- /dev/null
+++ b/configs/dataset/paralleldomain4d_wai/val/default.yaml
@@ -0,0 +1,26 @@
+dataset_str:
+ "ParallelDomain4DWAI(
+ split='${dataset.paralleldomain4d_wai.val.split}',
+ resolution=${dataset.paralleldomain4d_wai.val.dataset_resolution},
+ principal_point_centered=${dataset.paralleldomain4d_wai.val.principal_point_centered},
+ seed=${dataset.paralleldomain4d_wai.val.seed},
+ transform='${dataset.paralleldomain4d_wai.val.transform}',
+ data_norm_type='${dataset.paralleldomain4d_wai.val.data_norm_type}',
+ ROOT='${dataset.paralleldomain4d_wai.val.ROOT}',
+ dataset_metadata_dir='${dataset.paralleldomain4d_wai.val.dataset_metadata_dir}',
+ overfit_num_sets=${dataset.paralleldomain4d_wai.val.overfit_num_sets},
+ variable_num_views=${dataset.paralleldomain4d_wai.val.variable_num_views},
+ num_views=${dataset.paralleldomain4d_wai.val.num_views},
+ covisibility_thres=${dataset.paralleldomain4d_wai.val.covisibility_thres})"
+split: 'val'
+dataset_resolution: ${dataset.resolution_val_paralleldomain4d}
+principal_point_centered: ${dataset.principal_point_centered}
+seed: 777
+transform: 'imgnorm'
+data_norm_type: ${model.data_norm_type}
+ROOT: ${root_data_dir}/paralleldomain4d
+dataset_metadata_dir: ${mapanything_dataset_metadata_dir}
+overfit_num_sets: null
+variable_num_views: ${dataset.val.variable_num_views}
+num_views: ${dataset.num_views}
+covisibility_thres: 0.25
diff --git a/configs/dataset/resolution_options/default.yaml b/configs/dataset/resolution_options/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..7adb12ff5a911bc004bd0e1ab832992708b95a4b
--- /dev/null
+++ b/configs/dataset/resolution_options/default.yaml
@@ -0,0 +1,77 @@
+518_many_ar: '[(518, 518), (518, 392), (518, 336), (518, 294), (518, 252), (518, 168), (392, 518), (336, 518), (294, 518), (252, 518)]'
+518_many_landscape_ar: '[(518, 518), (518, 392), (518, 336), (518, 294), (518, 252), (518, 168)]'
+518_many_non_square_landscape_ar: '[(518, 392), (518, 336), (518, 294), (518, 252), (518, 168)]'
+518_0_50_ar: (252, 518) # 1:2
+518_0_56_ar: (294, 518) # 9:16
+518_0_66_ar: (336, 518) # 2:3
+518_0_75_ar: (392, 518) # 3:4
+518_1_00_ar: (518, 518) # 1:1
+518_1_33_ar: (518, 392) # 4:3
+518_1_52_ar: (518, 336) # 3:2
+518_1_77_ar: (518, 294) # 16:9
+518_2_00_ar: (518, 252) # 2:1
+518_3_20_ar: (518, 168) # 3.2:1
+512_many_ar: '[(512, 512), (512, 384), (512, 336), (512, 288), (512, 256), (512, 160), (384, 512), (336, 512), (288, 512), (256, 512)]'
+512_many_landscape_ar: '[(512, 512), (512, 384), (512, 336), (512, 288), (512, 256), (512, 160)]'
+512_many_non_square_landscape_ar: '[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)]'
+512_0_50_ar: (256, 512)
+512_0_56_ar: (288, 512)
+512_0_66_ar: (336, 512)
+512_0_75_ar: (384, 512)
+512_1_00_ar: (512, 512)
+512_1_33_ar: (512, 384)
+512_1_52_ar: (512, 336)
+512_1_77_ar: (512, 288)
+512_2_00_ar: (512, 256)
+512_3_20_ar: (512, 160)
+504_many_ar: '[(504, 504), (504, 378), (504, 322), (504, 280), (504, 238), (504, 154), (378, 504), (322, 504), (280, 504), (238, 504)]'
+504_many_landscape_ar: '[(504, 504), (504, 378), (504, 322), (504, 280), (504, 238), (504, 154)]'
+504_many_non_square_landscape_ar: '[(504, 378), (504, 322), (504, 280), (504, 238), (504, 154)]'
+504_0_50_ar: (238, 504)
+504_0_56_ar: (280, 504)
+504_0_66_ar: (322, 504)
+504_0_75_ar: (378, 504)
+504_1_00_ar: (504, 504)
+504_1_33_ar: (504, 378)
+504_1_52_ar: (504, 322)
+504_1_77_ar: (504, 280)
+504_2_00_ar: (504, 238)
+504_3_20_ar: (504, 154)
+448_many_ar: '[(448, 448), (448, 336), (448, 294), (448, 252), (448, 224), (448, 140), (336, 448), (294, 448), (252, 448), (224, 448)]'
+448_many_landscape_ar: '[(448, 448), (448, 336), (448, 294), (448, 252), (448, 224), (448, 140)]'
+448_many_non_square_landscape_ar: '[(448, 336), (448, 294), (448, 252), (448, 224), (448, 140)]'
+448_0_50_ar: (224, 448)
+448_0_56_ar: (252, 448)
+448_0_66_ar: (294, 448)
+448_0_75_ar: (336, 448)
+448_1_00_ar: (448, 448)
+448_1_33_ar: (448, 336)
+448_1_52_ar: (448, 294)
+448_1_77_ar: (448, 252)
+448_2_00_ar: (448, 224)
+448_3_20_ar: (448, 140)
+224_many_ar_14ps: '[(224, 224), (224, 168), (224, 154), (224, 126), (224, 112), (224, 70), (168, 224), (154, 224), (126, 224), (112, 224)]'
+224_many_landscape_ar_14ps: '[(224, 224), (224, 168), (224, 154), (224, 126), (224, 112), (224, 70)]'
+224_many_non_square_landscape_ar_14ps: '[(224, 168), (224, 154), (224, 126), (224, 112), (224, 70)]'
+224_0_50_ar_14ps: (112, 224)
+224_0_56_ar_14ps: (126, 224)
+224_0_66_ar_14ps: (154, 224)
+224_0_75_ar_14ps: (168, 224)
+224_1_00_ar: (224, 224)
+224_1_33_ar_14ps: (224, 168)
+224_1_52_ar_14ps: (224, 154)
+224_1_77_ar_14ps: (224, 126)
+224_2_00_ar_14ps: (224, 112)
+224_3_20_ar_14ps: (224, 70)
+224_many_ar_16ps: '[(224, 224), (224, 176), (224, 160), (224, 128), (224, 112), (224, 80), (176, 224), (160, 224), (128, 224), (112, 224)]'
+224_many_landscape_ar_16ps: '[(224, 224), (224, 176), (224, 160), (224, 128), (224, 112), (224, 80)]'
+224_many_non_square_landscape_ar_16ps: '[(224, 176), (224, 160), (224, 128), (224, 112), (224, 80)]'
+224_0_50_ar_16ps: (112, 224)
+224_0_56_ar_16ps: (128, 224)
+224_0_66_ar_16ps: (160, 224)
+224_0_75_ar_16ps: (176, 224)
+224_1_33_ar_16ps: (224, 176)
+224_1_52_ar_16ps: (224, 160)
+224_1_77_ar_16ps: (224, 128)
+224_2_00_ar_16ps: (224, 112)
+224_3_20_ar_16ps: (224, 80)
diff --git a/configs/dataset/sailvos3d_wai/default.yaml b/configs/dataset/sailvos3d_wai/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..78448df144ecb19398c761f707b5e264dcaaae29
--- /dev/null
+++ b/configs/dataset/sailvos3d_wai/default.yaml
@@ -0,0 +1,3 @@
+defaults:
+ - train: default
+ - val: default
diff --git a/configs/dataset/sailvos3d_wai/train/default.yaml b/configs/dataset/sailvos3d_wai/train/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..2187deaedc5af41b3a7c32d7371f24b78028f8ac
--- /dev/null
+++ b/configs/dataset/sailvos3d_wai/train/default.yaml
@@ -0,0 +1,26 @@
+dataset_str:
+ "SAILVOS3DWAI(
+ split='${dataset.sailvos3d_wai.train.split}',
+ resolution=${dataset.sailvos3d_wai.train.dataset_resolution},
+ principal_point_centered=${dataset.sailvos3d_wai.train.principal_point_centered},
+ aug_crop=${dataset.sailvos3d_wai.train.aug_crop},
+ transform='${dataset.sailvos3d_wai.train.transform}',
+ data_norm_type='${dataset.sailvos3d_wai.train.data_norm_type}',
+ ROOT='${dataset.sailvos3d_wai.train.ROOT}',
+ dataset_metadata_dir='${dataset.sailvos3d_wai.train.dataset_metadata_dir}',
+ overfit_num_sets=${dataset.sailvos3d_wai.train.overfit_num_sets},
+ variable_num_views=${dataset.sailvos3d_wai.train.variable_num_views},
+ num_views=${dataset.sailvos3d_wai.train.num_views},
+ covisibility_thres=${dataset.sailvos3d_wai.train.covisibility_thres})"
+split: 'train'
+dataset_resolution: ${dataset.resolution_train}
+principal_point_centered: ${dataset.principal_point_centered}
+aug_crop: 16
+transform: 'colorjitter+grayscale+gaublur'
+data_norm_type: ${model.data_norm_type}
+ROOT: ${root_data_dir}/sailvos3d
+dataset_metadata_dir: ${mapanything_dataset_metadata_dir}
+overfit_num_sets: null
+variable_num_views: ${dataset.train.variable_num_views}
+num_views: ${dataset.num_views}
+covisibility_thres: 0.25
diff --git a/configs/dataset/sailvos3d_wai/val/default.yaml b/configs/dataset/sailvos3d_wai/val/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..f5936d324143f03abc8d100061e1b38217b2d35d
--- /dev/null
+++ b/configs/dataset/sailvos3d_wai/val/default.yaml
@@ -0,0 +1,26 @@
+dataset_str:
+ "SAILVOS3DWAI(
+ split='${dataset.sailvos3d_wai.val.split}',
+ resolution=${dataset.sailvos3d_wai.val.dataset_resolution},
+ principal_point_centered=${dataset.sailvos3d_wai.val.principal_point_centered},
+ seed=${dataset.sailvos3d_wai.val.seed},
+ transform='${dataset.sailvos3d_wai.val.transform}',
+ data_norm_type='${dataset.sailvos3d_wai.val.data_norm_type}',
+ ROOT='${dataset.sailvos3d_wai.val.ROOT}',
+ dataset_metadata_dir='${dataset.sailvos3d_wai.val.dataset_metadata_dir}',
+ overfit_num_sets=${dataset.sailvos3d_wai.val.overfit_num_sets},
+ variable_num_views=${dataset.sailvos3d_wai.val.variable_num_views},
+ num_views=${dataset.sailvos3d_wai.val.num_views},
+ covisibility_thres=${dataset.sailvos3d_wai.val.covisibility_thres})"
+split: 'val'
+dataset_resolution: ${dataset.resolution_val_sailvos3d}
+principal_point_centered: ${dataset.principal_point_centered}
+seed: 777
+transform: 'imgnorm'
+data_norm_type: ${model.data_norm_type}
+ROOT: ${root_data_dir}/sailvos3d
+dataset_metadata_dir: ${mapanything_dataset_metadata_dir}
+overfit_num_sets: null
+variable_num_views: ${dataset.val.variable_num_views}
+num_views: ${dataset.num_views}
+covisibility_thres: 0.25
diff --git a/configs/dataset/scannetpp_wai/default.yaml b/configs/dataset/scannetpp_wai/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..d760ece911feb086e706395c16cb1eb86d758d79
--- /dev/null
+++ b/configs/dataset/scannetpp_wai/default.yaml
@@ -0,0 +1,4 @@
+defaults:
+ - train: default
+ - val: default
+ - test: default
diff --git a/configs/dataset/scannetpp_wai/test/default.yaml b/configs/dataset/scannetpp_wai/test/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..a14c936aac97853cfebfcdc6671247facd77d291
--- /dev/null
+++ b/configs/dataset/scannetpp_wai/test/default.yaml
@@ -0,0 +1,24 @@
+dataset_str:
+ "ScanNetPPWAI(
+ split='${dataset.scannetpp_wai.test.split}',
+ resolution=${dataset.scannetpp_wai.test.dataset_resolution},
+ principal_point_centered=${dataset.scannetpp_wai.test.principal_point_centered},
+ seed=${dataset.scannetpp_wai.test.seed},
+ transform='${dataset.scannetpp_wai.test.transform}',
+ data_norm_type='${dataset.scannetpp_wai.test.data_norm_type}',
+ ROOT='${dataset.scannetpp_wai.test.ROOT}',
+ dataset_metadata_dir='${dataset.scannetpp_wai.test.dataset_metadata_dir}',
+ variable_num_views=${dataset.scannetpp_wai.test.variable_num_views},
+ num_views=${dataset.scannetpp_wai.test.num_views},
+ covisibility_thres=${dataset.scannetpp_wai.test.covisibility_thres})"
+split: 'test'
+dataset_resolution: ${dataset.resolution_test_scannetpp}
+principal_point_centered: ${dataset.principal_point_centered}
+seed: 777
+transform: 'imgnorm'
+data_norm_type: ${model.data_norm_type}
+ROOT: ${root_data_dir}/scannetppv2
+dataset_metadata_dir: ${mapanything_dataset_metadata_dir}
+variable_num_views: ${dataset.test.variable_num_views}
+num_views: ${dataset.num_views}
+covisibility_thres: 0.25
diff --git a/configs/dataset/scannetpp_wai/train/default.yaml b/configs/dataset/scannetpp_wai/train/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..9cad920ccb90ad492fdff199413fb9c569bd4bac
--- /dev/null
+++ b/configs/dataset/scannetpp_wai/train/default.yaml
@@ -0,0 +1,26 @@
+dataset_str:
+ "ScanNetPPWAI(
+ split='${dataset.scannetpp_wai.train.split}',
+ resolution=${dataset.scannetpp_wai.train.dataset_resolution},
+ principal_point_centered=${dataset.scannetpp_wai.train.principal_point_centered},
+ aug_crop=${dataset.scannetpp_wai.train.aug_crop},
+ transform='${dataset.scannetpp_wai.train.transform}',
+ data_norm_type='${dataset.scannetpp_wai.train.data_norm_type}',
+ ROOT='${dataset.scannetpp_wai.train.ROOT}',
+ dataset_metadata_dir='${dataset.scannetpp_wai.train.dataset_metadata_dir}',
+ overfit_num_sets=${dataset.scannetpp_wai.train.overfit_num_sets},
+ variable_num_views=${dataset.scannetpp_wai.train.variable_num_views},
+ num_views=${dataset.scannetpp_wai.train.num_views},
+ covisibility_thres=${dataset.scannetpp_wai.train.covisibility_thres})"
+split: 'train'
+dataset_resolution: ${dataset.resolution_train}
+principal_point_centered: ${dataset.principal_point_centered}
+aug_crop: 16
+transform: 'colorjitter+grayscale+gaublur'
+data_norm_type: ${model.data_norm_type}
+ROOT: ${root_data_dir}/scannetppv2
+dataset_metadata_dir: ${mapanything_dataset_metadata_dir}
+overfit_num_sets: null
+variable_num_views: ${dataset.train.variable_num_views}
+num_views: ${dataset.num_views}
+covisibility_thres: 0.25
diff --git a/configs/dataset/scannetpp_wai/val/default.yaml b/configs/dataset/scannetpp_wai/val/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..a829383ecdcf50c2fee841e9dd8be61b2bd4f599
--- /dev/null
+++ b/configs/dataset/scannetpp_wai/val/default.yaml
@@ -0,0 +1,26 @@
+dataset_str:
+ "ScanNetPPWAI(
+ split='${dataset.scannetpp_wai.val.split}',
+ resolution=${dataset.scannetpp_wai.val.dataset_resolution},
+ principal_point_centered=${dataset.scannetpp_wai.val.principal_point_centered},
+ seed=${dataset.scannetpp_wai.val.seed},
+ transform='${dataset.scannetpp_wai.val.transform}',
+ data_norm_type='${dataset.scannetpp_wai.val.data_norm_type}',
+ ROOT='${dataset.scannetpp_wai.val.ROOT}',
+ dataset_metadata_dir='${dataset.scannetpp_wai.val.dataset_metadata_dir}',
+ overfit_num_sets=${dataset.scannetpp_wai.val.overfit_num_sets},
+ variable_num_views=${dataset.scannetpp_wai.val.variable_num_views},
+ num_views=${dataset.scannetpp_wai.val.num_views},
+ covisibility_thres=${dataset.scannetpp_wai.val.covisibility_thres})"
+split: 'val'
+dataset_resolution: ${dataset.resolution_val_scannetpp}
+principal_point_centered: ${dataset.principal_point_centered}
+seed: 777
+transform: 'imgnorm'
+data_norm_type: ${model.data_norm_type}
+ROOT: ${root_data_dir}/scannetppv2
+dataset_metadata_dir: ${mapanything_dataset_metadata_dir}
+overfit_num_sets: null
+variable_num_views: ${dataset.val.variable_num_views}
+num_views: ${dataset.num_views}
+covisibility_thres: 0.25
diff --git a/configs/dataset/spring_wai/default.yaml b/configs/dataset/spring_wai/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..78448df144ecb19398c761f707b5e264dcaaae29
--- /dev/null
+++ b/configs/dataset/spring_wai/default.yaml
@@ -0,0 +1,3 @@
+defaults:
+ - train: default
+ - val: default
diff --git a/configs/dataset/spring_wai/train/default.yaml b/configs/dataset/spring_wai/train/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..b073ccb8aa70267c06ce127a901b7e3fe47773db
--- /dev/null
+++ b/configs/dataset/spring_wai/train/default.yaml
@@ -0,0 +1,26 @@
+dataset_str:
+ "SpringWAI(
+ split='${dataset.spring_wai.train.split}',
+ resolution=${dataset.spring_wai.train.dataset_resolution},
+ principal_point_centered=${dataset.spring_wai.train.principal_point_centered},
+ aug_crop=${dataset.spring_wai.train.aug_crop},
+ transform='${dataset.spring_wai.train.transform}',
+ data_norm_type='${dataset.spring_wai.train.data_norm_type}',
+ ROOT='${dataset.spring_wai.train.ROOT}',
+ dataset_metadata_dir='${dataset.spring_wai.train.dataset_metadata_dir}',
+ overfit_num_sets=${dataset.spring_wai.train.overfit_num_sets},
+ variable_num_views=${dataset.spring_wai.train.variable_num_views},
+ num_views=${dataset.spring_wai.train.num_views},
+ covisibility_thres=${dataset.spring_wai.train.covisibility_thres})"
+split: 'train'
+dataset_resolution: ${dataset.resolution_train}
+principal_point_centered: ${dataset.principal_point_centered}
+aug_crop: 16
+transform: 'colorjitter+grayscale+gaublur'
+data_norm_type: ${model.data_norm_type}
+ROOT: ${root_data_dir}/spring
+dataset_metadata_dir: ${mapanything_dataset_metadata_dir}
+overfit_num_sets: null
+variable_num_views: ${dataset.train.variable_num_views}
+num_views: ${dataset.num_views}
+covisibility_thres: 0.25
diff --git a/configs/dataset/spring_wai/val/default.yaml b/configs/dataset/spring_wai/val/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..92a28220e58b62cd46e5ed772fb9e4f690f4ad8e
--- /dev/null
+++ b/configs/dataset/spring_wai/val/default.yaml
@@ -0,0 +1,26 @@
+dataset_str:
+ "SpringWAI(
+ split='${dataset.spring_wai.val.split}',
+ resolution=${dataset.spring_wai.val.dataset_resolution},
+ principal_point_centered=${dataset.spring_wai.val.principal_point_centered},
+ seed=${dataset.spring_wai.val.seed},
+ transform='${dataset.spring_wai.val.transform}',
+ data_norm_type='${dataset.spring_wai.val.data_norm_type}',
+ ROOT='${dataset.spring_wai.val.ROOT}',
+ dataset_metadata_dir='${dataset.spring_wai.val.dataset_metadata_dir}',
+ overfit_num_sets=${dataset.spring_wai.val.overfit_num_sets},
+ variable_num_views=${dataset.spring_wai.val.variable_num_views},
+ num_views=${dataset.spring_wai.val.num_views},
+ covisibility_thres=${dataset.spring_wai.val.covisibility_thres})"
+split: 'val'
+dataset_resolution: ${dataset.resolution_val_spring}
+principal_point_centered: ${dataset.principal_point_centered}
+seed: 777
+transform: 'imgnorm'
+data_norm_type: ${model.data_norm_type}
+ROOT: ${root_data_dir}/spring
+dataset_metadata_dir: ${mapanything_dataset_metadata_dir}
+overfit_num_sets: null
+variable_num_views: ${dataset.val.variable_num_views}
+num_views: ${dataset.num_views}
+covisibility_thres: 0.25
diff --git a/configs/dataset/structured3d_wai/default.yaml b/configs/dataset/structured3d_wai/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..78448df144ecb19398c761f707b5e264dcaaae29
--- /dev/null
+++ b/configs/dataset/structured3d_wai/default.yaml
@@ -0,0 +1,3 @@
+defaults:
+ - train: default
+ - val: default
diff --git a/configs/dataset/structured3d_wai/train/default.yaml b/configs/dataset/structured3d_wai/train/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..8556d92cf9f9e830b35c797778ca63f79bc31a56
--- /dev/null
+++ b/configs/dataset/structured3d_wai/train/default.yaml
@@ -0,0 +1,26 @@
+dataset_str:
+ "Structured3DWAI(
+ split='${dataset.structured3d_wai.train.split}',
+ resolution=${dataset.structured3d_wai.train.dataset_resolution},
+ principal_point_centered=${dataset.structured3d_wai.train.principal_point_centered},
+ aug_crop=${dataset.structured3d_wai.train.aug_crop},
+ transform='${dataset.structured3d_wai.train.transform}',
+ data_norm_type='${dataset.structured3d_wai.train.data_norm_type}',
+ ROOT='${dataset.structured3d_wai.train.ROOT}',
+ dataset_metadata_dir='${dataset.structured3d_wai.train.dataset_metadata_dir}',
+ overfit_num_sets=${dataset.structured3d_wai.train.overfit_num_sets},
+ variable_num_views=${dataset.structured3d_wai.train.variable_num_views},
+ num_views=${dataset.structured3d_wai.train.num_views},
+ covisibility_thres=${dataset.structured3d_wai.train.covisibility_thres})"
+split: 'train'
+dataset_resolution: ${dataset.resolution_train}
+principal_point_centered: ${dataset.principal_point_centered}
+aug_crop: 16
+transform: 'colorjitter+grayscale+gaublur'
+data_norm_type: ${model.data_norm_type}
+ROOT: ${root_data_dir}/structured3d
+dataset_metadata_dir: ${mapanything_dataset_metadata_dir}
+overfit_num_sets: null
+variable_num_views: ${dataset.train.variable_num_views}
+num_views: ${dataset.num_views}
+covisibility_thres: 0.25
diff --git a/configs/dataset/structured3d_wai/val/default.yaml b/configs/dataset/structured3d_wai/val/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..396a399b95993b4c20e452aa77f8e43d7177205b
--- /dev/null
+++ b/configs/dataset/structured3d_wai/val/default.yaml
@@ -0,0 +1,26 @@
+dataset_str:
+ "Structured3DWAI(
+ split='${dataset.structured3d_wai.val.split}',
+ resolution=${dataset.structured3d_wai.val.dataset_resolution},
+ principal_point_centered=${dataset.structured3d_wai.val.principal_point_centered},
+ seed=${dataset.structured3d_wai.val.seed},
+ transform='${dataset.structured3d_wai.val.transform}',
+ data_norm_type='${dataset.structured3d_wai.val.data_norm_type}',
+ ROOT='${dataset.structured3d_wai.val.ROOT}',
+ dataset_metadata_dir='${dataset.structured3d_wai.val.dataset_metadata_dir}',
+ overfit_num_sets=${dataset.structured3d_wai.val.overfit_num_sets},
+ variable_num_views=${dataset.structured3d_wai.val.variable_num_views},
+ num_views=${dataset.structured3d_wai.val.num_views},
+ covisibility_thres=${dataset.structured3d_wai.val.covisibility_thres})"
+split: 'val'
+dataset_resolution: ${dataset.resolution_val_structured3d}
+principal_point_centered: ${dataset.principal_point_centered}
+seed: 777
+transform: 'imgnorm'
+data_norm_type: ${model.data_norm_type}
+ROOT: ${root_data_dir}/structured3d
+dataset_metadata_dir: ${mapanything_dataset_metadata_dir}
+overfit_num_sets: null
+variable_num_views: ${dataset.val.variable_num_views}
+num_views: ${dataset.num_views}
+covisibility_thres: 0.25
diff --git a/configs/dataset/tav2_wb_wai/default.yaml b/configs/dataset/tav2_wb_wai/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..d760ece911feb086e706395c16cb1eb86d758d79
--- /dev/null
+++ b/configs/dataset/tav2_wb_wai/default.yaml
@@ -0,0 +1,4 @@
+defaults:
+ - train: default
+ - val: default
+ - test: default
diff --git a/configs/dataset/tav2_wb_wai/test/default.yaml b/configs/dataset/tav2_wb_wai/test/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..a616954c15ea1e4ec4dbe4083efb33b5349de2ff
--- /dev/null
+++ b/configs/dataset/tav2_wb_wai/test/default.yaml
@@ -0,0 +1,24 @@
+dataset_str:
+ "TartanAirV2WBWAI(
+ split='${dataset.tav2_wb_wai.test.split}',
+ resolution=${dataset.tav2_wb_wai.test.dataset_resolution},
+ principal_point_centered=${dataset.tav2_wb_wai.test.principal_point_centered},
+ seed=${dataset.tav2_wb_wai.test.seed},
+ transform='${dataset.tav2_wb_wai.test.transform}',
+ data_norm_type='${dataset.tav2_wb_wai.test.data_norm_type}',
+ ROOT='${dataset.tav2_wb_wai.test.ROOT}',
+ dataset_metadata_dir='${dataset.tav2_wb_wai.test.dataset_metadata_dir}',
+ variable_num_views=${dataset.tav2_wb_wai.test.variable_num_views},
+ num_views=${dataset.tav2_wb_wai.test.num_views},
+ covisibility_thres=${dataset.tav2_wb_wai.test.covisibility_thres})"
+split: 'test'
+dataset_resolution: ${dataset.resolution_test_tav2_wb}
+principal_point_centered: ${dataset.principal_point_centered}
+seed: 777
+transform: 'imgnorm'
+data_norm_type: ${model.data_norm_type}
+ROOT: ${root_data_dir}/tav2_wb
+dataset_metadata_dir: ${mapanything_dataset_metadata_dir}
+variable_num_views: ${dataset.test.variable_num_views}
+num_views: ${dataset.num_views}
+covisibility_thres: 0.25
diff --git a/configs/dataset/tav2_wb_wai/train/default.yaml b/configs/dataset/tav2_wb_wai/train/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..e897aefbcc774792bffc98bb430e8082dd1a163e
--- /dev/null
+++ b/configs/dataset/tav2_wb_wai/train/default.yaml
@@ -0,0 +1,26 @@
+dataset_str:
+ "TartanAirV2WBWAI(
+ split='${dataset.tav2_wb_wai.train.split}',
+ resolution=${dataset.tav2_wb_wai.train.dataset_resolution},
+ principal_point_centered=${dataset.tav2_wb_wai.train.principal_point_centered},
+ aug_crop=${dataset.tav2_wb_wai.train.aug_crop},
+ transform='${dataset.tav2_wb_wai.train.transform}',
+ data_norm_type='${dataset.tav2_wb_wai.train.data_norm_type}',
+ ROOT='${dataset.tav2_wb_wai.train.ROOT}',
+ dataset_metadata_dir='${dataset.tav2_wb_wai.train.dataset_metadata_dir}',
+ overfit_num_sets=${dataset.tav2_wb_wai.train.overfit_num_sets},
+ variable_num_views=${dataset.tav2_wb_wai.train.variable_num_views},
+ num_views=${dataset.tav2_wb_wai.train.num_views},
+ covisibility_thres=${dataset.tav2_wb_wai.train.covisibility_thres})"
+split: 'train'
+dataset_resolution: ${dataset.resolution_train}
+principal_point_centered: ${dataset.principal_point_centered}
+aug_crop: 16
+transform: 'colorjitter+grayscale+gaublur'
+data_norm_type: ${model.data_norm_type}
+ROOT: ${root_data_dir}/tav2_wb
+dataset_metadata_dir: ${mapanything_dataset_metadata_dir}
+overfit_num_sets: null
+variable_num_views: ${dataset.train.variable_num_views}
+num_views: ${dataset.num_views}
+covisibility_thres: 0.25
diff --git a/configs/dataset/tav2_wb_wai/val/default.yaml b/configs/dataset/tav2_wb_wai/val/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..97b04c1a4778079a38dec5631708ecf97a84e496
--- /dev/null
+++ b/configs/dataset/tav2_wb_wai/val/default.yaml
@@ -0,0 +1,26 @@
+dataset_str:
+ "TartanAirV2WBWAI(
+ split='${dataset.tav2_wb_wai.val.split}',
+ resolution=${dataset.tav2_wb_wai.val.dataset_resolution},
+ principal_point_centered=${dataset.tav2_wb_wai.val.principal_point_centered},
+ seed=${dataset.tav2_wb_wai.val.seed},
+ transform='${dataset.tav2_wb_wai.val.transform}',
+ data_norm_type='${dataset.tav2_wb_wai.val.data_norm_type}',
+ ROOT='${dataset.tav2_wb_wai.val.ROOT}',
+ dataset_metadata_dir='${dataset.tav2_wb_wai.val.dataset_metadata_dir}',
+ overfit_num_sets=${dataset.tav2_wb_wai.val.overfit_num_sets},
+ variable_num_views=${dataset.tav2_wb_wai.val.variable_num_views},
+ num_views=${dataset.tav2_wb_wai.val.num_views},
+ covisibility_thres=${dataset.tav2_wb_wai.val.covisibility_thres})"
+split: 'val'
+dataset_resolution: ${dataset.resolution_val_tav2_wb}
+principal_point_centered: ${dataset.principal_point_centered}
+seed: 777
+transform: 'imgnorm'
+data_norm_type: ${model.data_norm_type}
+ROOT: ${root_data_dir}/tav2_wb
+dataset_metadata_dir: ${mapanything_dataset_metadata_dir}
+overfit_num_sets: null
+variable_num_views: ${dataset.val.variable_num_views}
+num_views: ${dataset.num_views}
+covisibility_thres: 0.25
diff --git a/configs/dataset/unrealstereo4k_wai/default.yaml b/configs/dataset/unrealstereo4k_wai/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..78448df144ecb19398c761f707b5e264dcaaae29
--- /dev/null
+++ b/configs/dataset/unrealstereo4k_wai/default.yaml
@@ -0,0 +1,3 @@
+defaults:
+ - train: default
+ - val: default
diff --git a/configs/dataset/unrealstereo4k_wai/train/default.yaml b/configs/dataset/unrealstereo4k_wai/train/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..3ecce8dde40fbdfa33da5f3e9d6a2631b9ccb2d4
--- /dev/null
+++ b/configs/dataset/unrealstereo4k_wai/train/default.yaml
@@ -0,0 +1,26 @@
+dataset_str:
+ "UnrealStereo4KWAI(
+ split='${dataset.unrealstereo4k_wai.train.split}',
+ resolution=${dataset.unrealstereo4k_wai.train.dataset_resolution},
+ principal_point_centered=${dataset.unrealstereo4k_wai.train.principal_point_centered},
+ aug_crop=${dataset.unrealstereo4k_wai.train.aug_crop},
+ transform='${dataset.unrealstereo4k_wai.train.transform}',
+ data_norm_type='${dataset.unrealstereo4k_wai.train.data_norm_type}',
+ ROOT='${dataset.unrealstereo4k_wai.train.ROOT}',
+ dataset_metadata_dir='${dataset.unrealstereo4k_wai.train.dataset_metadata_dir}',
+ overfit_num_sets=${dataset.unrealstereo4k_wai.train.overfit_num_sets},
+ variable_num_views=${dataset.unrealstereo4k_wai.train.variable_num_views},
+ num_views=${dataset.unrealstereo4k_wai.train.num_views},
+ covisibility_thres=${dataset.unrealstereo4k_wai.train.covisibility_thres})"
+split: 'train'
+dataset_resolution: ${dataset.resolution_train}
+principal_point_centered: ${dataset.principal_point_centered}
+aug_crop: 16
+transform: 'colorjitter+grayscale+gaublur'
+data_norm_type: ${model.data_norm_type}
+ROOT: ${root_data_dir}/unrealstereo4k
+dataset_metadata_dir: ${mapanything_dataset_metadata_dir}
+overfit_num_sets: null
+variable_num_views: ${dataset.train.variable_num_views}
+num_views: ${dataset.num_views}
+covisibility_thres: 0.25
diff --git a/configs/dataset/unrealstereo4k_wai/val/default.yaml b/configs/dataset/unrealstereo4k_wai/val/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..4570290d7259f26925234b208f92243f40c34fb3
--- /dev/null
+++ b/configs/dataset/unrealstereo4k_wai/val/default.yaml
@@ -0,0 +1,26 @@
+dataset_str:
+ "UnrealStereo4KWAI(
+ split='${dataset.unrealstereo4k_wai.val.split}',
+ resolution=${dataset.unrealstereo4k_wai.val.dataset_resolution},
+ principal_point_centered=${dataset.unrealstereo4k_wai.val.principal_point_centered},
+ seed=${dataset.unrealstereo4k_wai.val.seed},
+ transform='${dataset.unrealstereo4k_wai.val.transform}',
+ data_norm_type='${dataset.unrealstereo4k_wai.val.data_norm_type}',
+ ROOT='${dataset.unrealstereo4k_wai.val.ROOT}',
+ dataset_metadata_dir='${dataset.unrealstereo4k_wai.val.dataset_metadata_dir}',
+ overfit_num_sets=${dataset.unrealstereo4k_wai.val.overfit_num_sets},
+ variable_num_views=${dataset.unrealstereo4k_wai.val.variable_num_views},
+ num_views=${dataset.unrealstereo4k_wai.val.num_views},
+ covisibility_thres=${dataset.unrealstereo4k_wai.val.covisibility_thres})"
+split: 'val'
+dataset_resolution: ${dataset.resolution_val_unrealstereo4k}
+principal_point_centered: ${dataset.principal_point_centered}
+seed: 777
+transform: 'imgnorm'
+data_norm_type: ${model.data_norm_type}
+ROOT: ${root_data_dir}/unrealstereo4k
+dataset_metadata_dir: ${mapanything_dataset_metadata_dir}
+overfit_num_sets: null
+variable_num_views: ${dataset.val.variable_num_views}
+num_views: ${dataset.num_views}
+covisibility_thres: 0.25
diff --git a/configs/dataset/xrooms_wai/default.yaml b/configs/dataset/xrooms_wai/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..78448df144ecb19398c761f707b5e264dcaaae29
--- /dev/null
+++ b/configs/dataset/xrooms_wai/default.yaml
@@ -0,0 +1,3 @@
+defaults:
+ - train: default
+ - val: default
diff --git a/configs/dataset/xrooms_wai/train/default.yaml b/configs/dataset/xrooms_wai/train/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..2a6131e36392f0efe537dbdf3b6767c83d7b9a3b
--- /dev/null
+++ b/configs/dataset/xrooms_wai/train/default.yaml
@@ -0,0 +1,26 @@
+dataset_str:
+ "XRoomsWAI(
+ split='${dataset.xrooms_wai.train.split}',
+ resolution=${dataset.xrooms_wai.train.dataset_resolution},
+ principal_point_centered=${dataset.xrooms_wai.train.principal_point_centered},
+ aug_crop=${dataset.xrooms_wai.train.aug_crop},
+ transform='${dataset.xrooms_wai.train.transform}',
+ data_norm_type='${dataset.xrooms_wai.train.data_norm_type}',
+ ROOT='${dataset.xrooms_wai.train.ROOT}',
+ dataset_metadata_dir='${dataset.xrooms_wai.train.dataset_metadata_dir}',
+ overfit_num_sets=${dataset.xrooms_wai.train.overfit_num_sets},
+ variable_num_views=${dataset.xrooms_wai.train.variable_num_views},
+ num_views=${dataset.xrooms_wai.train.num_views},
+ covisibility_thres=${dataset.xrooms_wai.train.covisibility_thres})"
+split: 'train'
+dataset_resolution: ${dataset.resolution_train}
+principal_point_centered: ${dataset.principal_point_centered}
+aug_crop: 16
+transform: 'colorjitter+grayscale+gaublur'
+data_norm_type: ${model.data_norm_type}
+ROOT: ${root_data_dir}/xrooms
+dataset_metadata_dir: ${mapanything_dataset_metadata_dir}
+overfit_num_sets: null
+variable_num_views: ${dataset.train.variable_num_views}
+num_views: ${dataset.num_views}
+covisibility_thres: 0.25
diff --git a/configs/dataset/xrooms_wai/val/default.yaml b/configs/dataset/xrooms_wai/val/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..90044d43e2cc0f3f6f67cd6e73e27c4898f88d30
--- /dev/null
+++ b/configs/dataset/xrooms_wai/val/default.yaml
@@ -0,0 +1,26 @@
+dataset_str:
+ "XRoomsWAI(
+ split='${dataset.xrooms_wai.val.split}',
+ resolution=${dataset.xrooms_wai.val.dataset_resolution},
+ principal_point_centered=${dataset.xrooms_wai.val.principal_point_centered},
+ seed=${dataset.xrooms_wai.val.seed},
+ transform='${dataset.xrooms_wai.val.transform}',
+ data_norm_type='${dataset.xrooms_wai.val.data_norm_type}',
+ ROOT='${dataset.xrooms_wai.val.ROOT}',
+ dataset_metadata_dir='${dataset.xrooms_wai.val.dataset_metadata_dir}',
+ overfit_num_sets=${dataset.xrooms_wai.val.overfit_num_sets},
+ variable_num_views=${dataset.xrooms_wai.val.variable_num_views},
+ num_views=${dataset.xrooms_wai.val.num_views},
+ covisibility_thres=${dataset.xrooms_wai.val.covisibility_thres})"
+split: 'val'
+dataset_resolution: ${dataset.resolution_val_xrooms}
+principal_point_centered: ${dataset.principal_point_centered}
+seed: 777
+transform: 'imgnorm'
+data_norm_type: ${model.data_norm_type}
+ROOT: ${root_data_dir}/xrooms
+dataset_metadata_dir: ${mapanything_dataset_metadata_dir}
+overfit_num_sets: null
+variable_num_views: ${dataset.val.variable_num_views}
+num_views: ${dataset.num_views}
+covisibility_thres: 0.25
diff --git a/configs/dense_n_view_benchmark.yaml b/configs/dense_n_view_benchmark.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..53f63f54d13bb2e421c47f2e5e2b3d23c746ff12
--- /dev/null
+++ b/configs/dense_n_view_benchmark.yaml
@@ -0,0 +1,23 @@
+defaults:
+ - machine: aws
+ - model: default
+ - dataset: default
+ - _self_
+
+output_dir: ${hydra:run.dir}
+root_data_dir: ${machine.root_data_dir}
+mapanything_dataset_metadata_dir: ${machine.mapanything_dataset_metadata_dir}
+root_pretrained_checkpoints_dir: ${machine.root_pretrained_checkpoints_dir}
+root_experiments_dir: ${machine.root_experiments_dir}
+root_uniception_pretrained_checkpoints_dir: ${machine.root_uniception_pretrained_checkpoints_dir}
+
+### Benchmarking args
+seed: 0
+# Disable CUDNN Benchmark (Disable for variable resolution & number of view training)
+disable_cudnn_benchmark: true
+# Batch size for inference (Metrics are computed per multi-view set and averaged, not per batch of multi-view sets)
+batch_size: 10
+# Use mixed precision for inference
+amp: 1
+# Floating point type to use for mixed precision
+amp_dtype: "bf16"
diff --git a/configs/distributed/default.yaml b/configs/distributed/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..780d4ea77de057e5c554aa8cf25f3e7618740b2a
--- /dev/null
+++ b/configs/distributed/default.yaml
@@ -0,0 +1,6 @@
+# Distributed Training Params
+# Number of distributed processes
+world_size: 1
+local_rank: -1
+# Url used to set up distributed training
+dist_url: 'env://'
diff --git a/configs/loss/conf_pm_mask_loss.yaml b/configs/loss/conf_pm_mask_loss.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..2d55324da8ff5c2b946a606de2640bd7f5160ee9
--- /dev/null
+++ b/configs/loss/conf_pm_mask_loss.yaml
@@ -0,0 +1,4 @@
+# Training Loss
+train_criterion: "ConfLoss(Regr3D(RobustRegressionLoss(alpha=0.5, scaling_c=0.05), norm_mode='?avg_dis', loss_in_log=True), alpha=0.2) + 0.3 * NonAmbiguousMaskLoss(BCELoss())"
+# Validation Loss
+test_criterion: "ExcludeTopNPercentPixelLoss(Regr3D(RobustRegressionLoss(alpha=0.5, scaling_c=0.05), norm_mode='?avg_dis', flatten_across_image_only=True, loss_in_log=True), top_n_percent=5, apply_to_real_data_only=True) + 0.3 * NonAmbiguousMaskLoss(BCELoss())"
diff --git a/configs/loss/conf_pm_mask_scale_loss.yaml b/configs/loss/conf_pm_mask_scale_loss.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..eb186d734746083fa5bb327fa3c67d9a497eccfa
--- /dev/null
+++ b/configs/loss/conf_pm_mask_scale_loss.yaml
@@ -0,0 +1,4 @@
+# Training Loss
+train_criterion: "ConfLoss(PointsPlusScaleRegr3D(RobustRegressionLoss(alpha=0.5, scaling_c=0.05), norm_mode='avg_dis', flatten_across_image_only=True, loss_in_log=True), alpha=0.2) + 0.3 * NonAmbiguousMaskLoss(BCELoss())"
+# Validation Loss
+test_criterion: "ExcludeTopNPercentPixelLoss(PointsPlusScaleRegr3D(RobustRegressionLoss(alpha=0.5, scaling_c=0.05), norm_mode='avg_dis', flatten_across_image_only=True, loss_in_log=True), top_n_percent=5, apply_to_real_data_only=True) + 0.3 * NonAmbiguousMaskLoss(BCELoss())"
diff --git a/configs/loss/default.yaml b/configs/loss/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..e345afed2a90fa4dc0c6d36e51a835f4451f0329
--- /dev/null
+++ b/configs/loss/default.yaml
@@ -0,0 +1,6 @@
+# Training Loss, For example: "ConfLoss(Regr3D(L21, norm_mode='avg_dis'), alpha=0.2)"
+train_criterion: ""
+# Validation Loss, For example:
+# "Regr3D_ScaleShiftInv(L21, norm_mode='?avg_dis', ambiguous_loss_value=0)" (DUSt3R)
+# "Regr3D(L21, norm_mode='?avg_dis', ambiguous_loss_value=2)" (MASt3R)
+test_criterion: ""
diff --git a/configs/loss/entangled_metric_loss.yaml b/configs/loss/entangled_metric_loss.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..b5107d58bfb4f34cd0bb983a6bbe8fcc27a87ee3
--- /dev/null
+++ b/configs/loss/entangled_metric_loss.yaml
@@ -0,0 +1,4 @@
+# Training Loss
+train_criterion: "ConfAndExcludeTopNPercentPixelLoss(FactoredGeometryRegr3DPlusNormalGMLoss(RobustRegressionLoss(alpha=0.5, scaling_c=0.05), norm_mode='?avg_dis', depth_type_for_loss='depth_along_ray', loss_in_log=True, flatten_across_image_only=True, compute_pairwise_relative_pose_loss=False, compute_world_frame_points_loss=True, apply_normal_and_gm_loss_to_synthetic_data_only=True, normal_loss_weight=3.0, gm_loss_weight=3.0), conf_alpha=0.2, top_n_percent=5, apply_to_real_data_only=True, conf_loss_set_indices=[0], exclude_loss_set_indices=[1, 2]) + 0.3 * NonAmbiguousMaskLoss(BCELoss())"
+# Validation Loss
+test_criterion: "ExcludeTopNPercentPixelLoss(FactoredGeometryRegr3DPlusNormalGMLoss(RobustRegressionLoss(alpha=0.5, scaling_c=0.05), norm_mode='?avg_dis', depth_type_for_loss='depth_along_ray', loss_in_log=True, flatten_across_image_only=True, compute_pairwise_relative_pose_loss=False, compute_world_frame_points_loss=True, apply_normal_and_gm_loss_to_synthetic_data_only=True, normal_loss_weight=3.0, gm_loss_weight=3.0), top_n_percent=5, apply_to_real_data_only=True, loss_set_indices=[0, 1, 2]) + 0.3 * NonAmbiguousMaskLoss(BCELoss())"
diff --git a/configs/loss/no_depth_loss.yaml b/configs/loss/no_depth_loss.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..eaa5c251136bb987a6475ec5effef6dc623b7ea2
--- /dev/null
+++ b/configs/loss/no_depth_loss.yaml
@@ -0,0 +1,4 @@
+# Training Loss
+train_criterion: "ConfAndExcludeTopNPercentPixelLoss(FactoredGeometryScaleRegr3DPlusNormalGMLoss(RobustRegressionLoss(alpha=0.5, scaling_c=0.05), norm_mode='avg_dis', depth_type_for_loss='depth_along_ray', loss_in_log=True, flatten_across_image_only=True, compute_pairwise_relative_pose_loss=False, compute_world_frame_points_loss=True, apply_normal_and_gm_loss_to_synthetic_data_only=True, normal_loss_weight=3.0, gm_loss_weight=3.0, depth_loss_weight=0.0), conf_alpha=0.2, top_n_percent=5, apply_to_real_data_only=True, conf_loss_set_indices=[0], exclude_loss_set_indices=[1, 2]) + 0.3 * NonAmbiguousMaskLoss(BCELoss())"
+# Validation Loss
+test_criterion: "ExcludeTopNPercentPixelLoss(FactoredGeometryScaleRegr3DPlusNormalGMLoss(RobustRegressionLoss(alpha=0.5, scaling_c=0.05), norm_mode='avg_dis', depth_type_for_loss='depth_along_ray', loss_in_log=True, flatten_across_image_only=True, compute_pairwise_relative_pose_loss=False, compute_world_frame_points_loss=True, apply_normal_and_gm_loss_to_synthetic_data_only=True, normal_loss_weight=3.0, gm_loss_weight=3.0, depth_loss_weight=0.0), top_n_percent=5, apply_to_real_data_only=True, loss_set_indices=[0, 1, 2]) + 0.3 * NonAmbiguousMaskLoss(BCELoss())"
diff --git a/configs/loss/no_log_scaling.yaml b/configs/loss/no_log_scaling.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..8e9e6b68279a4b869675fbc11a63c8c564c0b4bb
--- /dev/null
+++ b/configs/loss/no_log_scaling.yaml
@@ -0,0 +1,4 @@
+# Training Loss
+train_criterion: "ConfAndExcludeTopNPercentPixelLoss(FactoredGeometryScaleRegr3DPlusNormalGMLoss(RobustRegressionLoss(alpha=0.5, scaling_c=0.05), norm_mode='avg_dis', depth_type_for_loss='depth_along_ray', loss_in_log=False, flatten_across_image_only=True, compute_pairwise_relative_pose_loss=False, compute_world_frame_points_loss=True, apply_normal_and_gm_loss_to_synthetic_data_only=True, normal_loss_weight=3.0, gm_loss_weight=3.0), conf_alpha=0.2, top_n_percent=5, apply_to_real_data_only=True, conf_loss_set_indices=[0], exclude_loss_set_indices=[1, 2]) + 0.3 * NonAmbiguousMaskLoss(BCELoss())"
+# Validation Loss
+test_criterion: "ExcludeTopNPercentPixelLoss(FactoredGeometryScaleRegr3DPlusNormalGMLoss(RobustRegressionLoss(alpha=0.5, scaling_c=0.05), norm_mode='avg_dis', depth_type_for_loss='depth_along_ray', loss_in_log=False, flatten_across_image_only=True, compute_pairwise_relative_pose_loss=False, compute_world_frame_points_loss=True, apply_normal_and_gm_loss_to_synthetic_data_only=True, normal_loss_weight=3.0, gm_loss_weight=3.0), top_n_percent=5, apply_to_real_data_only=True, loss_set_indices=[0, 1, 2]) + 0.3 * NonAmbiguousMaskLoss(BCELoss())"
diff --git a/configs/loss/no_points_loss.yaml b/configs/loss/no_points_loss.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..7126a4692c90439c1a4a81f056a364a704f1b4b4
--- /dev/null
+++ b/configs/loss/no_points_loss.yaml
@@ -0,0 +1,4 @@
+# Training Loss
+train_criterion: "ConfAndExcludeTopNPercentPixelLoss(FactoredGeometryScaleRegr3DPlusNormalGMLoss(RobustRegressionLoss(alpha=0.5, scaling_c=0.05), norm_mode='avg_dis', depth_type_for_loss='depth_along_ray', loss_in_log=True, flatten_across_image_only=True, compute_pairwise_relative_pose_loss=False, compute_world_frame_points_loss=False, apply_normal_and_gm_loss_to_synthetic_data_only=True, normal_loss_weight=3.0, gm_loss_weight=3.0, cam_frame_points_loss_weight=0.0), conf_alpha=0.2, top_n_percent=5, apply_to_real_data_only=True, conf_loss_set_indices=[1]) + 0.3 * NonAmbiguousMaskLoss(BCELoss())"
+# Validation Loss
+test_criterion: "ExcludeTopNPercentPixelLoss(FactoredGeometryScaleRegr3DPlusNormalGMLoss(RobustRegressionLoss(alpha=0.5, scaling_c=0.05), norm_mode='avg_dis', depth_type_for_loss='depth_along_ray', loss_in_log=True, flatten_across_image_only=True, compute_pairwise_relative_pose_loss=False, compute_world_frame_points_loss=False, apply_normal_and_gm_loss_to_synthetic_data_only=True, normal_loss_weight=3.0, gm_loss_weight=3.0, cam_frame_points_loss_weight=0.0), top_n_percent=5, apply_to_real_data_only=True, loss_set_indices=[1]) + 0.3 * NonAmbiguousMaskLoss(BCELoss())"
diff --git a/configs/loss/no_pose_loss.yaml b/configs/loss/no_pose_loss.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..5515eccdce79909e0c6ff40dd8e4ab4356011c5f
--- /dev/null
+++ b/configs/loss/no_pose_loss.yaml
@@ -0,0 +1,4 @@
+# Training Loss
+train_criterion: "ConfAndExcludeTopNPercentPixelLoss(FactoredGeometryScaleRegr3DPlusNormalGMLoss(RobustRegressionLoss(alpha=0.5, scaling_c=0.05), norm_mode='avg_dis', depth_type_for_loss='depth_along_ray', loss_in_log=True, flatten_across_image_only=True, compute_pairwise_relative_pose_loss=False, compute_world_frame_points_loss=True, apply_normal_and_gm_loss_to_synthetic_data_only=True, normal_loss_weight=3.0, gm_loss_weight=3.0, pose_quats_loss_weight=0.0, pose_trans_loss_weight=0.0), conf_alpha=0.2, top_n_percent=5, apply_to_real_data_only=True, conf_loss_set_indices=[0], exclude_loss_set_indices=[1, 2]) + 0.3 * NonAmbiguousMaskLoss(BCELoss())"
+# Validation Loss
+test_criterion: "ExcludeTopNPercentPixelLoss(FactoredGeometryScaleRegr3DPlusNormalGMLoss(RobustRegressionLoss(alpha=0.5, scaling_c=0.05), norm_mode='avg_dis', depth_type_for_loss='depth_along_ray', loss_in_log=True, flatten_across_image_only=True, compute_pairwise_relative_pose_loss=False, compute_world_frame_points_loss=True, apply_normal_and_gm_loss_to_synthetic_data_only=True, normal_loss_weight=3.0, gm_loss_weight=3.0, pose_quats_loss_weight=0.0, pose_trans_loss_weight=0.0), top_n_percent=5, apply_to_real_data_only=True, loss_set_indices=[0, 1, 2]) + 0.3 * NonAmbiguousMaskLoss(BCELoss())"
diff --git a/configs/loss/no_ray_dirs_loss.yaml b/configs/loss/no_ray_dirs_loss.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..4493b31c269bd4667bec579513a7dcb70868b08e
--- /dev/null
+++ b/configs/loss/no_ray_dirs_loss.yaml
@@ -0,0 +1,4 @@
+# Training Loss
+train_criterion: "ConfAndExcludeTopNPercentPixelLoss(FactoredGeometryScaleRegr3DPlusNormalGMLoss(RobustRegressionLoss(alpha=0.5, scaling_c=0.05), norm_mode='avg_dis', depth_type_for_loss='depth_along_ray', loss_in_log=True, flatten_across_image_only=True, compute_pairwise_relative_pose_loss=False, compute_world_frame_points_loss=True, apply_normal_and_gm_loss_to_synthetic_data_only=True, normal_loss_weight=3.0, gm_loss_weight=3.0, ray_directions_loss_weight=0.0), conf_alpha=0.2, top_n_percent=5, apply_to_real_data_only=True, conf_loss_set_indices=[0], exclude_loss_set_indices=[1, 2]) + 0.3 * NonAmbiguousMaskLoss(BCELoss())"
+# Validation Loss
+test_criterion: "ExcludeTopNPercentPixelLoss(FactoredGeometryScaleRegr3DPlusNormalGMLoss(RobustRegressionLoss(alpha=0.5, scaling_c=0.05), norm_mode='avg_dis', depth_type_for_loss='depth_along_ray', loss_in_log=True, flatten_across_image_only=True, compute_pairwise_relative_pose_loss=False, compute_world_frame_points_loss=True, apply_normal_and_gm_loss_to_synthetic_data_only=True, normal_loss_weight=3.0, gm_loss_weight=3.0, ray_directions_loss_weight=0.0), top_n_percent=5, apply_to_real_data_only=True, loss_set_indices=[0, 1, 2]) + 0.3 * NonAmbiguousMaskLoss(BCELoss())"
diff --git a/configs/loss/no_robust_loss.yaml b/configs/loss/no_robust_loss.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..94afee0e092a53707134429d35335b8a9745ee81
--- /dev/null
+++ b/configs/loss/no_robust_loss.yaml
@@ -0,0 +1,4 @@
+# Training Loss
+train_criterion: "ConfAndExcludeTopNPercentPixelLoss(FactoredGeometryScaleRegr3DPlusNormalGMLoss(L2Loss(), norm_mode='avg_dis', depth_type_for_loss='depth_along_ray', loss_in_log=True, flatten_across_image_only=True, compute_pairwise_relative_pose_loss=False, compute_world_frame_points_loss=True, apply_normal_and_gm_loss_to_synthetic_data_only=True, normal_loss_weight=1.0, gm_loss_weight=1.0), conf_alpha=0.2, top_n_percent=5, apply_to_real_data_only=True, conf_loss_set_indices=[0], exclude_loss_set_indices=[1, 2]) + 0.1 * NonAmbiguousMaskLoss(BCELoss())"
+# Validation Loss
+test_criterion: "ExcludeTopNPercentPixelLoss(FactoredGeometryScaleRegr3DPlusNormalGMLoss(L2Loss(), norm_mode='avg_dis', depth_type_for_loss='depth_along_ray', loss_in_log=True, flatten_across_image_only=True, compute_pairwise_relative_pose_loss=False, compute_world_frame_points_loss=True, apply_normal_and_gm_loss_to_synthetic_data_only=True, normal_loss_weight=1.0, gm_loss_weight=1.0), top_n_percent=5, apply_to_real_data_only=True, loss_set_indices=[0, 1, 2]) + 0.1 * NonAmbiguousMaskLoss(BCELoss())"
diff --git a/configs/loss/overall_disentangled_loss.yaml b/configs/loss/overall_disentangled_loss.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..60348d9add9433b451bc1cd59e64d2f1ae773dbf
--- /dev/null
+++ b/configs/loss/overall_disentangled_loss.yaml
@@ -0,0 +1,4 @@
+# Training Loss
+train_criterion: "ConfLoss(DisentangledFactoredGeometryScaleRegr3DPlusNormalGMLoss(RobustRegressionLoss(alpha=0.5, scaling_c=0.05), norm_mode='avg_dis', depth_type_for_loss='depth_along_ray', loss_in_log=True, flatten_across_image_only=True, apply_normal_and_gm_loss_to_synthetic_data_only=True, normal_loss_weight=3.0, gm_loss_weight=3.0), alpha=0.2, loss_set_indices=[0]) + 0.3 * NonAmbiguousMaskLoss(BCELoss())"
+# Validation Loss
+test_criterion: "ExcludeTopNPercentPixelLoss(DisentangledFactoredGeometryScaleRegr3DPlusNormalGMLoss(RobustRegressionLoss(alpha=0.5, scaling_c=0.05), norm_mode='avg_dis', depth_type_for_loss='depth_along_ray', loss_in_log=True, flatten_across_image_only=True, apply_normal_and_gm_loss_to_synthetic_data_only=True, normal_loss_weight=3.0, gm_loss_weight=3.0), top_n_percent=5, apply_to_real_data_only=True, loss_set_indices=[0]) + 0.3 * NonAmbiguousMaskLoss(BCELoss())"
diff --git a/configs/loss/overall_loss.yaml b/configs/loss/overall_loss.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..fcbc8dc28f99c31a7c193bb209d0f4128b7b99a3
--- /dev/null
+++ b/configs/loss/overall_loss.yaml
@@ -0,0 +1,4 @@
+# Training Loss
+train_criterion: "ConfAndExcludeTopNPercentPixelLoss(FactoredGeometryScaleRegr3DPlusNormalGMLoss(RobustRegressionLoss(alpha=0.5, scaling_c=0.05), norm_mode='avg_dis', depth_type_for_loss='depth_along_ray', loss_in_log=True, flatten_across_image_only=True, compute_pairwise_relative_pose_loss=False, compute_world_frame_points_loss=True, apply_normal_and_gm_loss_to_synthetic_data_only=True, normal_loss_weight=3.0, gm_loss_weight=3.0), conf_alpha=0.2, top_n_percent=5, apply_to_real_data_only=True, conf_loss_set_indices=[0], exclude_loss_set_indices=[1, 2]) + 0.3 * NonAmbiguousMaskLoss(BCELoss())"
+# Validation Loss
+test_criterion: "ExcludeTopNPercentPixelLoss(FactoredGeometryScaleRegr3DPlusNormalGMLoss(RobustRegressionLoss(alpha=0.5, scaling_c=0.05), norm_mode='avg_dis', depth_type_for_loss='depth_along_ray', loss_in_log=True, flatten_across_image_only=True, compute_pairwise_relative_pose_loss=False, compute_world_frame_points_loss=True, apply_normal_and_gm_loss_to_synthetic_data_only=True, normal_loss_weight=3.0, gm_loss_weight=3.0), top_n_percent=5, apply_to_real_data_only=True, loss_set_indices=[0, 1, 2]) + 0.3 * NonAmbiguousMaskLoss(BCELoss())"
diff --git a/configs/loss/overall_loss_weigh_pm_higher.yaml b/configs/loss/overall_loss_weigh_pm_higher.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..23706c64548095601553ebd2eff6919821760d61
--- /dev/null
+++ b/configs/loss/overall_loss_weigh_pm_higher.yaml
@@ -0,0 +1,4 @@
+# Training Loss
+train_criterion: "ConfAndExcludeTopNPercentPixelLoss(FactoredGeometryScaleRegr3DPlusNormalGMLoss(RobustRegressionLoss(alpha=0.5, scaling_c=0.05), norm_mode='avg_dis', depth_type_for_loss='depth_along_ray', loss_in_log=True, flatten_across_image_only=True, compute_pairwise_relative_pose_loss=False, compute_world_frame_points_loss=True, apply_normal_and_gm_loss_to_synthetic_data_only=True, cam_frame_points_loss_weight=0.1, depth_loss_weight=0.1, ray_directions_loss_weight=0.1, pose_quats_loss_weight=0.1, pose_trans_loss_weight=0.1, scale_loss_weight=0.1, world_frame_points_loss_weight=1, normal_loss_weight=0.3, gm_loss_weight=0.3), conf_alpha=0.2, top_n_percent=5, apply_to_real_data_only=True, conf_loss_set_indices=[0], exclude_loss_set_indices=[1, 2]) + 0.03 * NonAmbiguousMaskLoss(BCELoss())"
+# Validation Loss
+test_criterion: "ExcludeTopNPercentPixelLoss(FactoredGeometryScaleRegr3DPlusNormalGMLoss(RobustRegressionLoss(alpha=0.5, scaling_c=0.05), norm_mode='avg_dis', depth_type_for_loss='depth_along_ray', loss_in_log=True, flatten_across_image_only=True, compute_pairwise_relative_pose_loss=False, compute_world_frame_points_loss=True, apply_normal_and_gm_loss_to_synthetic_data_only=True, cam_frame_points_loss_weight=0.1, depth_loss_weight=0.1, ray_directions_loss_weight=0.1, pose_quats_loss_weight=0.1, pose_trans_loss_weight=0.1, scale_loss_weight=0.1, world_frame_points_loss_weight=1, normal_loss_weight=0.3, gm_loss_weight=0.3), top_n_percent=5, apply_to_real_data_only=True, loss_set_indices=[0, 1, 2]) + 0.03 * NonAmbiguousMaskLoss(BCELoss())"
diff --git a/configs/loss/up_to_scale_loss.yaml b/configs/loss/up_to_scale_loss.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..f2cc22d132ce3209db0c8f34d3e625cf734ee861
--- /dev/null
+++ b/configs/loss/up_to_scale_loss.yaml
@@ -0,0 +1,4 @@
+# Training Loss
+train_criterion: "ConfAndExcludeTopNPercentPixelLoss(FactoredGeometryRegr3DPlusNormalGMLoss(RobustRegressionLoss(alpha=0.5, scaling_c=0.05), norm_mode='avg_dis', depth_type_for_loss='depth_along_ray', loss_in_log=True, flatten_across_image_only=True, compute_pairwise_relative_pose_loss=False, compute_world_frame_points_loss=True, apply_normal_and_gm_loss_to_synthetic_data_only=True, normal_loss_weight=3.0, gm_loss_weight=3.0), conf_alpha=0.2, top_n_percent=5, apply_to_real_data_only=True, conf_loss_set_indices=[0], exclude_loss_set_indices=[1, 2]) + 0.3 * NonAmbiguousMaskLoss(BCELoss())"
+# Validation Loss
+test_criterion: "ExcludeTopNPercentPixelLoss(FactoredGeometryRegr3DPlusNormalGMLoss(RobustRegressionLoss(alpha=0.5, scaling_c=0.05), norm_mode='avg_dis', depth_type_for_loss='depth_along_ray', loss_in_log=True, flatten_across_image_only=True, compute_pairwise_relative_pose_loss=False, compute_world_frame_points_loss=True, apply_normal_and_gm_loss_to_synthetic_data_only=True, normal_loss_weight=3.0, gm_loss_weight=3.0), top_n_percent=5, apply_to_real_data_only=True, loss_set_indices=[0, 1, 2]) + 0.3 * NonAmbiguousMaskLoss(BCELoss())"
diff --git a/configs/loss/vggt_loss.yaml b/configs/loss/vggt_loss.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..f42d7de190624a128d71d8c028b2787b48a0b201
--- /dev/null
+++ b/configs/loss/vggt_loss.yaml
@@ -0,0 +1,4 @@
+# Training Loss
+train_criterion: "ConfAndExcludeTopNPercentPixelLoss(FactoredGeometryRegr3DPlusNormalGMLoss(RobustRegressionLoss(alpha=0.5, scaling_c=0.05), norm_mode='avg_dis', depth_type_for_loss='depth_z', loss_in_log=True, flatten_across_image_only=True, compute_pairwise_relative_pose_loss=False, compute_world_frame_points_loss=True, apply_normal_and_gm_loss_to_synthetic_data_only=True, normal_loss_weight=3.0, gm_loss_weight=3.0), conf_alpha=0.2, top_n_percent=5, apply_to_real_data_only=True, conf_loss_set_indices=[0], exclude_loss_set_indices=[1, 2])"
+# Validation Loss
+test_criterion: "ExcludeTopNPercentPixelLoss(FactoredGeometryRegr3DPlusNormalGMLoss(RobustRegressionLoss(alpha=0.5, scaling_c=0.05), norm_mode='avg_dis', depth_type_for_loss='depth_z', loss_in_log=True, flatten_across_image_only=True, compute_pairwise_relative_pose_loss=False, compute_world_frame_points_loss=True, apply_normal_and_gm_loss_to_synthetic_data_only=True, normal_loss_weight=3.0, gm_loss_weight=3.0), top_n_percent=5, apply_to_real_data_only=True, loss_set_indices=[0, 1, 2])"
diff --git a/configs/machine/aws.yaml b/configs/machine/aws.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..584639ff57977dcb505b43e866443ac8df155295
--- /dev/null
+++ b/configs/machine/aws.yaml
@@ -0,0 +1,13 @@
+defaults:
+ - default
+
+# Root directory containing all datasets
+root_data_dir: "/fsx/xrtech/data"
+# Dataset metadata directory
+mapanything_dataset_metadata_dir: "/fsx/nkeetha/mapanything_dataset_metadata"
+# Root directory containing pretrained checkpoints for custom models
+root_pretrained_checkpoints_dir: "/fsx/nkeetha/mapanything_checkpoints"
+# Root directory to log experiments
+root_experiments_dir: "/fsx/nkeetha/experiments"
+# Root directory containing UniCeption pretrained checkpoints
+root_uniception_pretrained_checkpoints_dir: "/fsx/nkeetha/uniception_checkpoints"
diff --git a/configs/machine/default.yaml b/configs/machine/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..d89f6d360340a9eb79a12707ce19b07aba3a083a
--- /dev/null
+++ b/configs/machine/default.yaml
@@ -0,0 +1,10 @@
+# Root directory containing all datasets
+root_data_dir: ???
+# Dataset metadata directory
+mapanything_dataset_metadata_dir: ???
+# Root directory containing pretrained checkpoints for custom models
+root_pretrained_checkpoints_dir: ???
+# Root directory to log experiments
+root_experiments_dir: ???
+# Root directory containing UniCeption pretrained checkpoints
+root_uniception_pretrained_checkpoints_dir: ???
diff --git a/configs/machine/psc.yaml b/configs/machine/psc.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..1e529be838f2f9838e428bb25244b7b4c94aefa3
--- /dev/null
+++ b/configs/machine/psc.yaml
@@ -0,0 +1,13 @@
+defaults:
+ - default
+
+# Root directory containing all datasets
+root_data_dir: "/ocean/projects/cis220039p/shared/datasets"
+# Dataset metadata directory
+mapanything_dataset_metadata_dir: "/ocean/projects/cis220039p/shared/mapanything_dataset_metadata"
+# Root directory containing pretrained checkpoints for custom models
+root_pretrained_checkpoints_dir: "/ocean/projects/cis220039p/nkeetha/code/AnyMap/checkpoints"
+# Root directory to log experiments
+root_experiments_dir: "/ocean/projects/cis220039p/nkeetha/experiments"
+# Root directory containing UniCeption pretrained checkpoints
+root_uniception_pretrained_checkpoints_dir: "/ocean/projects/cis220039p/nkeetha/code/AnyMap/UniCeption/checkpoints"
diff --git a/configs/machine/psc_yuchen.yaml b/configs/machine/psc_yuchen.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..d071afa04eadb2cd35b4f41bc12c1e8f820fe27c
--- /dev/null
+++ b/configs/machine/psc_yuchen.yaml
@@ -0,0 +1,13 @@
+defaults:
+ - default
+
+# Root directory containing all datasets
+root_data_dir: "/ocean/projects/cis220039p/shared/datasets"
+# Dataset metadata directory
+mapanything_dataset_metadata_dir: "/ocean/projects/cis220039p/shared/mapanything_dataset_metadata"
+# Root directory containing pretrained checkpoints for custom models
+root_pretrained_checkpoints_dir: "/jet/home/yzhang25/AnyMap/checkpoints"
+# Root directory to log experiments
+root_experiments_dir: "/jet/home/yzhang25/AnyMap/outputs"
+# Root directory containing UniCeption pretrained checkpoints
+root_uniception_pretrained_checkpoints_dir: "/ocean/projects/cis220039p/shared/uniception/checkpoints/"
diff --git a/configs/machine/xri_dgx.yaml b/configs/machine/xri_dgx.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..ba77beedfdd736d6a6d7c236ff3a1b9033ea8b24
--- /dev/null
+++ b/configs/machine/xri_dgx.yaml
@@ -0,0 +1,13 @@
+defaults:
+ - default
+
+# Root directory containing all datasets
+root_data_dir: "/mnt/xri_mapsresearch/data/nkeetha"
+# Dataset metadata directory
+mapanything_dataset_metadata_dir: "/mnt/xri_mapsresearch/data/nkeetha/mapanything_dataset_metadata"
+# Root directory containing pretrained checkpoints for custom models
+root_pretrained_checkpoints_dir: "/mnt/xri_mapsresearch/code/nkeetha/AnyMap/checkpoints"
+# Root directory to log experiments
+root_experiments_dir: "/mnt/xri_mapsresearch/experiments/nkeetha"
+# Root directory containing UniCeption pretrained checkpoints
+root_uniception_pretrained_checkpoints_dir: "/mnt/xri_mapsresearch/code/nkeetha/AnyMap/UniCeption/checkpoints"
diff --git a/configs/model/anycalib.yaml b/configs/model/anycalib.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..76688024dc56b3cbcef486b09efb1b7a1231b94b
--- /dev/null
+++ b/configs/model/anycalib.yaml
@@ -0,0 +1,11 @@
+# String for model factory
+model_str: "anycalib"
+# Model config
+model_config:
+ name: "anycalib"
+# Image Normalization Type
+data_norm_type: "identity"
+# AnyCalib checkpoint is already loaded in the inference wrapper
+pretrained: null
+# Torch hub force reload
+torch_hub_force_reload: false
diff --git a/configs/model/default.yaml b/configs/model/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..0926e7bf93b61f81bc4397ec2a16331d10f756be
--- /dev/null
+++ b/configs/model/default.yaml
@@ -0,0 +1,16 @@
+# String for model factory (Options: "mapanything", "mapanything_ablations", "modular_dust3r", "vggt", "pi3")
+model_str: ???
+# Model config
+model_config:
+ # Path to pretrained model checkpoint
+ pretrained_checkpoint_path: null
+ # Load specific submodules from the checkpoint
+ load_specific_pretrained_submodules: False
+ # List of submodules to load from the checkpoint (if load_specific_pretrained_submodules is True)
+ specific_pretrained_submodules: []
+# Path of a starting checkpoint (to enable backward compatibility with original DUSt3R class)
+pretrained: null
+# Image normalization type
+data_norm_type: ???
+# Torch hub force reload
+torch_hub_force_reload: False
diff --git a/configs/model/dust3r.yaml b/configs/model/dust3r.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..5496bb311883ea3db42f69d9f22a5208b26d28f9
--- /dev/null
+++ b/configs/model/dust3r.yaml
@@ -0,0 +1,23 @@
+# String for model factory
+model_str: "dust3r"
+# Model config
+model_config:
+ name: "dust3r"
+ # Checkpoint path
+ ckpt_path: "${root_pretrained_checkpoints_dir}/DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth"
+ # Scene graph for BA
+ scene_graph: "complete"
+ # Pairwise inference batch size
+ inference_batch_size: 32
+ # Global optim schedule
+ global_optim_schedule: "cosine"
+ # Global optim lr
+ global_optim_lr: 0.01
+ # Number of iterations for global optimization
+ global_optim_niter: 300
+# Image Normalization Type
+data_norm_type: "dust3r"
+# DUSt3R checkpoint is already loaded in the inference wrapper
+pretrained: null
+# Torch hub force reload
+torch_hub_force_reload: false
diff --git a/configs/model/encoder/croco_512.yaml b/configs/model/encoder/croco_512.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..37d15cb7ac75f83836b7d1fed956a398b61af9df
--- /dev/null
+++ b/configs/model/encoder/croco_512.yaml
@@ -0,0 +1,16 @@
+# UniCeption encoder string used for selecting encoder class (python3 -m uniception.models.encoders.list)
+encoder_str: "croco"
+# Name of the encoder
+name: "croco_512"
+# Data normalization type
+data_norm_type: "croco"
+# Patch embedding class
+patch_embed_cls: "PatchEmbedDust3R"
+# Image size
+img_size: [512, 512] # This parameter has no influence for PatchEmbedDust3R
+# Path to the pretrained encoder checkpoint
+pretrained_checkpoint_path: '${machine.root_uniception_pretrained_checkpoints_dir}/encoders/CroCo_Encoder_224.pth'
+# Override attributes in the pretrained checkpoint
+override_checkpoint_attributes: True
+# Flag to indicate whether model class uses torch hub
+uses_torch_hub: False
diff --git a/configs/model/encoder/croco_512_data_norm_dust3r.yaml b/configs/model/encoder/croco_512_data_norm_dust3r.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..e98d9d3b91af5fe8ab35256ee8f0fe36a49b9890
--- /dev/null
+++ b/configs/model/encoder/croco_512_data_norm_dust3r.yaml
@@ -0,0 +1,16 @@
+# UniCeption encoder string used for selecting encoder class (python3 -m uniception.models.encoders.list)
+encoder_str: "croco"
+# Name of the encoder
+name: "croco_512_img_norm_dust3r"
+# Data normalization type
+data_norm_type: "dust3r"
+# Patch embedding class
+patch_embed_cls: "PatchEmbedDust3R"
+# Image size
+img_size: [512, 512] # This parameter has no influence for PatchEmbedDust3R
+# Path to the pretrained encoder checkpoint
+pretrained_checkpoint_path: '${machine.root_uniception_pretrained_checkpoints_dir}/encoders/CroCo_Encoder_224.pth'
+# Override attributes in the pretrained checkpoint
+override_checkpoint_attributes: True
+# Flag to indicate whether model class uses torch hub
+uses_torch_hub: False
diff --git a/configs/model/encoder/dinov2_large.yaml b/configs/model/encoder/dinov2_large.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..a15ad5aa58cfd7da159b4713d206d54ac0a8f9a7
--- /dev/null
+++ b/configs/model/encoder/dinov2_large.yaml
@@ -0,0 +1,14 @@
+# UniCeption encoder string used for selecting encoder class (python3 -m uniception.models.encoders.list)
+encoder_str: "dinov2"
+# Name of the encoder
+name: "dinov2_large"
+# Data normalization type
+data_norm_type: "dinov2"
+# ViT size
+size: "large"
+# Registers
+with_registers: False
+# Flag to indicate whether model class uses torch hub
+uses_torch_hub: True
+# Flag to indicate whether to use gradient checkpointing for encoder
+gradient_checkpointing: False
diff --git a/configs/model/encoder/radio_v2_5_large.yaml b/configs/model/encoder/radio_v2_5_large.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..95ce4103aa426bcb3a7be5b48184778d935afac2
--- /dev/null
+++ b/configs/model/encoder/radio_v2_5_large.yaml
@@ -0,0 +1,10 @@
+# UniCeption encoder string used for selecting encoder class (python3 -m uniception.models.encoders.list)
+encoder_str: "radio"
+# Name of the encoder
+name: "radio_v2.5-large"
+# Data normalization type
+data_norm_type: "radio"
+# Model version
+model_version: "radio_v2.5-l"
+# Flag to indicate whether model class uses torch hub
+uses_torch_hub: True
diff --git a/configs/model/info_sharing/aat_ifr_24_layers.yaml b/configs/model/info_sharing/aat_ifr_24_layers.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..34f787a24dc3d3facbdc6e387046f1eafbe2d099
--- /dev/null
+++ b/configs/model/info_sharing/aat_ifr_24_layers.yaml
@@ -0,0 +1,22 @@
+# Model type (Options: ["cross_attention", "global_attention", "alternating_attention"])
+model_type: "alternating_attention"
+# Model class type (Options: ["no_intermediate_features", "intermediate_features"])
+model_return_type: "intermediate_features"
+# Custom positional encoding (Options: ["RoPEfreq"], Callable Function, null)
+custom_positional_encoding: null
+# Module arguments
+module_args:
+ # Name of the info sharing module
+ name: "aat_24_layers_ifr"
+ # Indices of the intermediate features to be shared (indices start from 0)
+ indices: [11, 17]
+ # Normalize intermediate features
+ norm_intermediate: True
+ # Size string
+ size: "24_layers"
+ # Depth
+ depth: 24
+ # Distinguish Reference and Non-Reference Views
+ distinguish_ref_and_non_ref_views: True
+ # Flag to indicate whether to use gradient checkpointing
+ gradient_checkpointing: False
diff --git a/configs/model/info_sharing/aat_ifr_24_layers_escaling.yaml b/configs/model/info_sharing/aat_ifr_24_layers_escaling.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..63705b14ff02c25d6a58f59791a10daa479f907f
--- /dev/null
+++ b/configs/model/info_sharing/aat_ifr_24_layers_escaling.yaml
@@ -0,0 +1,24 @@
+# Model type (Options: ["cross_attention", "global_attention", "alternating_attention"])
+model_type: "alternating_attention"
+# Model class type (Options: ["no_intermediate_features", "intermediate_features"])
+model_return_type: "intermediate_features"
+# Custom positional encoding (Options: ["RoPEfreq"], Callable Function, null)
+custom_positional_encoding: null
+# Module arguments
+module_args:
+ # Name of the info sharing module
+ name: "aat_24_layers_ifr"
+ # Indices of the intermediate features to be shared (indices start from 0)
+ indices: [11, 17]
+ # Normalize intermediate features
+ norm_intermediate: True
+ # Size string
+ size: "24_layers"
+ # Depth
+ depth: 24
+ # Distinguish Reference and Non-Reference Views
+ distinguish_ref_and_non_ref_views: True
+ # Scale Entropy in Attention
+ use_entropy_scaling: True
+ # Flag to indicate whether to use gradient checkpointing
+ gradient_checkpointing: False
diff --git a/configs/model/info_sharing/aat_ifr_24_layers_no_ref_view.yaml b/configs/model/info_sharing/aat_ifr_24_layers_no_ref_view.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..4abf1d168eca3ccd6c7a8ca2717c1aee1c0e927b
--- /dev/null
+++ b/configs/model/info_sharing/aat_ifr_24_layers_no_ref_view.yaml
@@ -0,0 +1,22 @@
+# Model type (Options: ["cross_attention", "global_attention", "alternating_attention"])
+model_type: "alternating_attention"
+# Model class type (Options: ["no_intermediate_features", "intermediate_features"])
+model_return_type: "intermediate_features"
+# Custom positional encoding (Options: ["RoPEfreq"], Callable Function, null)
+custom_positional_encoding: null
+# Module arguments
+module_args:
+ # Name of the info sharing module
+ name: "aat_24_layers_ifr_no_ref_view"
+ # Indices of the intermediate features to be shared (indices start from 0)
+ indices: [11, 17]
+ # Normalize intermediate features
+ norm_intermediate: True
+ # Size string
+ size: "24_layers"
+ # Depth
+ depth: 24
+ # Distinguish Reference and Non-Reference Views
+ distinguish_ref_and_non_ref_views: False
+ # Flag to indicate whether to use gradient checkpointing
+ gradient_checkpointing: False
diff --git a/configs/model/info_sharing/aat_ifr_24_layers_w_view_pe.yaml b/configs/model/info_sharing/aat_ifr_24_layers_w_view_pe.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..ef2fb6ccae11b740d2ea081181893147d4ee415c
--- /dev/null
+++ b/configs/model/info_sharing/aat_ifr_24_layers_w_view_pe.yaml
@@ -0,0 +1,26 @@
+# Model type (Options: ["cross_attention", "global_attention", "alternating_attention"])
+model_type: "alternating_attention"
+# Model class type (Options: ["no_intermediate_features", "intermediate_features"])
+model_return_type: "intermediate_features"
+# Custom positional encoding (Options: ["RoPEfreq"], Callable Function, null)
+custom_positional_encoding: null
+# Module arguments
+module_args:
+ # Name of the info sharing module
+ name: "aat_24_layers_ifr_w_view_pe"
+ # Indices of the intermediate features to be shared (indices start from 0)
+ indices: [11, 17]
+ # Normalize intermediate features
+ norm_intermediate: True
+ # Size string
+ size: "24_layers"
+ # Depth
+ depth: 24
+ # Distinguish Reference and Non-Reference Views
+ distinguish_ref_and_non_ref_views: True
+ # Flag to indicate whether to use gradient checkpointing
+ gradient_checkpointing: False
+ # Maximum number of views for positional encoding
+ max_num_views_for_pe: 1000
+ # Use random indices within range (1, max_num_views_for_pe) for positional encoding of non reference views
+ use_rand_idx_pe_for_non_reference_views: True
diff --git a/configs/model/info_sharing/aat_ifr_48_layers.yaml b/configs/model/info_sharing/aat_ifr_48_layers.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..7249706a5b5031e71218604df44cd377a90dc37b
--- /dev/null
+++ b/configs/model/info_sharing/aat_ifr_48_layers.yaml
@@ -0,0 +1,26 @@
+# Model type (Options: ["cross_attention", "global_attention", "alternating_attention"])
+model_type: "alternating_attention"
+# Model class type (Options: ["no_intermediate_features", "intermediate_features"])
+model_return_type: "intermediate_features"
+# Custom positional encoding (Options: ["RoPEfreq"], Callable Function, null)
+custom_positional_encoding: null
+# Module arguments
+module_args:
+ # Name of the info sharing module
+ name: "aat_48_layers_ifr"
+ # Indices of the intermediate features to be shared (indices start from 0)
+ indices: [11, 23, 35]
+ # Normalize intermediate features
+ norm_intermediate: True
+ # Size string
+ size: "48_layers"
+ # Depth (this includes both frame-wise and gloabl attention layers)
+ depth: 48
+ # Feature dim (similar to ViT-Large)
+ dim: 1024
+ # Number of heads (similar to ViT-Large)
+ num_heads: 16
+ # Distinguish Reference and Non-Reference Views
+ distinguish_ref_and_non_ref_views: True
+ # Flag to indicate whether to use gradient checkpointing
+ gradient_checkpointing: False
diff --git a/configs/model/info_sharing/aat_ifr_48_layers_escaling.yaml b/configs/model/info_sharing/aat_ifr_48_layers_escaling.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..fc8ca1da63b0a523ea537d47515974473820ef9e
--- /dev/null
+++ b/configs/model/info_sharing/aat_ifr_48_layers_escaling.yaml
@@ -0,0 +1,28 @@
+# Model type (Options: ["cross_attention", "global_attention", "alternating_attention"])
+model_type: "alternating_attention"
+# Model class type (Options: ["no_intermediate_features", "intermediate_features"])
+model_return_type: "intermediate_features"
+# Custom positional encoding (Options: ["RoPEfreq"], Callable Function, null)
+custom_positional_encoding: null
+# Module arguments
+module_args:
+ # Name of the info sharing module
+ name: "aat_48_layers_ifr"
+ # Indices of the intermediate features to be shared (indices start from 0)
+ indices: [11, 23, 35]
+ # Normalize intermediate features
+ norm_intermediate: True
+ # Size string
+ size: "48_layers"
+ # Depth (this includes both frame-wise and gloabl attention layers)
+ depth: 48
+ # Feature dim (similar to ViT-Large)
+ dim: 1024
+ # Number of heads (similar to ViT-Large)
+ num_heads: 16
+ # Distinguish Reference and Non-Reference Views
+ distinguish_ref_and_non_ref_views: True
+ # Scale Entropy in Attention
+ use_entropy_scaling: True
+ # Flag to indicate whether to use gradient checkpointing
+ gradient_checkpointing: False
diff --git a/configs/model/info_sharing/aat_ifr_48_layers_no_ref_view.yaml b/configs/model/info_sharing/aat_ifr_48_layers_no_ref_view.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..68456c32359c48f5e626a2ef9ecb0eb08faa2d54
--- /dev/null
+++ b/configs/model/info_sharing/aat_ifr_48_layers_no_ref_view.yaml
@@ -0,0 +1,26 @@
+# Model type (Options: ["cross_attention", "global_attention", "alternating_attention"])
+model_type: "alternating_attention"
+# Model class type (Options: ["no_intermediate_features", "intermediate_features"])
+model_return_type: "intermediate_features"
+# Custom positional encoding (Options: ["RoPEfreq"], Callable Function, null)
+custom_positional_encoding: null
+# Module arguments
+module_args:
+ # Name of the info sharing module
+ name: "aat_48_layers_ifr_no_ref_view"
+ # Indices of the intermediate features to be shared (indices start from 0)
+ indices: [11, 23, 35]
+ # Normalize intermediate features
+ norm_intermediate: True
+ # Size string
+ size: "48_layers"
+ # Depth (this includes both frame-wise and gloabl attention layers)
+ depth: 48
+ # Feature dim (similar to ViT-Large)
+ dim: 1024
+ # Number of heads (similar to ViT-Large)
+ num_heads: 16
+ # Distinguish Reference and Non-Reference Views
+ distinguish_ref_and_non_ref_views: False
+ # Flag to indicate whether to use gradient checkpointing
+ gradient_checkpointing: False
diff --git a/configs/model/info_sharing/cat_ifr_dust3r.yaml b/configs/model/info_sharing/cat_ifr_dust3r.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..e52e1d37cefc92d7121bb6a4076198fef8a72ee6
--- /dev/null
+++ b/configs/model/info_sharing/cat_ifr_dust3r.yaml
@@ -0,0 +1,18 @@
+# Model type (Options: ["cross_attention", "global_attention", "alternating_attention"])
+model_type: "cross_attention"
+# Model class type (Options: ["no_intermediate_features", "intermediate_features"])
+model_return_type: "intermediate_features"
+# Custom positional encoding (Options: ["RoPEfreq"], Callable Function, null)
+custom_positional_encoding: "RoPE100"
+# Module arguments
+module_args:
+ # Name of the info sharing module
+ name: "base_cat_ifr_dust3r"
+ # Number of views
+ num_views: 2
+ # Indices of the intermediate features to be shared (indices start from 0)
+ indices: [5, 8]
+ # Normalize intermediate features
+ norm_intermediate: False
+ # Load CroCo cross-attention transformer for DUSt3R Init
+ pretrained_checkpoint_path: '${machine.root_uniception_pretrained_checkpoints_dir}/info_sharing/cross_attn_transformer/Two_View_Cross_Attention_Transformer_CroCo.pth'
diff --git a/configs/model/info_sharing/gat_ifr_24_layers.yaml b/configs/model/info_sharing/gat_ifr_24_layers.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..60423714da681f71ca2088f101f4b9fa734b254b
--- /dev/null
+++ b/configs/model/info_sharing/gat_ifr_24_layers.yaml
@@ -0,0 +1,24 @@
+# Model type (Options: ["cross_attention", "global_attention", "alternating_attention"])
+model_type: "global_attention"
+# Model class type (Options: ["no_intermediate_features", "intermediate_features"])
+model_return_type: "intermediate_features"
+# Custom positional encoding (Options: ["RoPEfreq"], Callable Function, null)
+custom_positional_encoding: null
+# Module arguments
+module_args:
+ # Name of the info sharing module
+ name: "gat_24_layers_ifr"
+ # Maximum number of views for positional encoding
+ max_num_views: 1000
+ # Use random indices within range (1, max_num_views) for positional encoding of non reference views
+ use_rand_idx_pe_for_non_reference_views: True
+ # Indices of the intermediate features to be shared (indices start from 0)
+ indices: [11, 17]
+ # Normalize intermediate features
+ norm_intermediate: True
+ # Size string
+ size: "24_layers"
+ # Depth
+ depth: 24
+ # Flag to indicate whether to use gradient checkpointing
+ gradient_checkpointing: False
diff --git a/configs/model/info_sharing/gat_ifr_24_layers_escaling.yaml b/configs/model/info_sharing/gat_ifr_24_layers_escaling.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..6143b5b37adde56a6a742b01a9b0d4fc0824fa35
--- /dev/null
+++ b/configs/model/info_sharing/gat_ifr_24_layers_escaling.yaml
@@ -0,0 +1,26 @@
+# Model type (Options: ["cross_attention", "global_attention", "alternating_attention"])
+model_type: "global_attention"
+# Model class type (Options: ["no_intermediate_features", "intermediate_features"])
+model_return_type: "intermediate_features"
+# Custom positional encoding (Options: ["RoPEfreq"], Callable Function, null)
+custom_positional_encoding: null
+# Module arguments
+module_args:
+ # Name of the info sharing module
+ name: "gat_24_layers_ifr"
+ # Maximum number of views for positional encoding
+ max_num_views: 1000
+ # Use random indices within range (1, max_num_views) for positional encoding of non reference views
+ use_rand_idx_pe_for_non_reference_views: True
+ # Indices of the intermediate features to be shared (indices start from 0)
+ indices: [11, 17]
+ # Normalize intermediate features
+ norm_intermediate: True
+ # Size string
+ size: "24_layers"
+ # Depth
+ depth: 24
+ # Scale Entropy in Attention
+ use_entropy_scaling: True
+ # Flag to indicate whether to use gradient checkpointing
+ gradient_checkpointing: False
diff --git a/configs/model/mapanything.yaml b/configs/model/mapanything.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..a69c44ec9cd9d684e2dc0764365d462fcc095030
--- /dev/null
+++ b/configs/model/mapanything.yaml
@@ -0,0 +1,18 @@
+defaults:
+ - default
+ - encoder: dinov2_large
+ - info_sharing: aat_ifr_24_layers
+ - pred_head: dpt_pose_scale
+ - task: images_only
+
+# String for model factory
+model_str: "mapanything"
+# Model config
+model_config:
+ name: "mapanything"
+ encoder_config: ${model.encoder}
+ info_sharing_config: ${model.info_sharing}
+ pred_head_config: ${model.pred_head}
+ geometric_input_config: ${model.task}
+# Image Normalization Type
+data_norm_type: ${model.encoder.data_norm_type}
diff --git a/configs/model/mapanything_ablations.yaml b/configs/model/mapanything_ablations.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..4b8dc0d7a29ca93bdd2fe34660ea5c5a883e7d60
--- /dev/null
+++ b/configs/model/mapanything_ablations.yaml
@@ -0,0 +1,18 @@
+defaults:
+ - default
+ - encoder: dinov2_large
+ - info_sharing: aat_ifr_24_layers
+ - pred_head: dpt_pose
+ - task: images_only
+
+# String for model factory
+model_str: "mapanything_ablations"
+# Model config
+model_config:
+ name: "mapanything_ablations"
+ encoder_config: ${model.encoder}
+ info_sharing_config: ${model.info_sharing}
+ pred_head_config: ${model.pred_head}
+ geometric_input_config: ${model.task}
+# Image Normalization Type
+data_norm_type: ${model.encoder.data_norm_type}
diff --git a/configs/model/mapanything_inference.yaml b/configs/model/mapanything_inference.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..a081212dc00ee6b2e105254346129e430531f642
--- /dev/null
+++ b/configs/model/mapanything_inference.yaml
@@ -0,0 +1,18 @@
+defaults:
+ - default
+ - encoder: dinov2_large
+ - info_sharing: aat_ifr_24_layers_escaling
+ - pred_head: dpt_pose_scale
+ - task: images_only
+
+# String for model factory
+model_str: "mapanything"
+# Model config
+model_config:
+ name: "mapanything"
+ encoder_config: ${model.encoder}
+ info_sharing_config: ${model.info_sharing}
+ pred_head_config: ${model.pred_head}
+ geometric_input_config: ${model.task}
+# Image Normalization Type
+data_norm_type: ${model.encoder.data_norm_type}
diff --git a/configs/model/mapanything_large.yaml b/configs/model/mapanything_large.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..371e2fe6513fee6f5a0619aff5fff0c515a4d8d9
--- /dev/null
+++ b/configs/model/mapanything_large.yaml
@@ -0,0 +1,18 @@
+defaults:
+ - default
+ - encoder: dinov2_large
+ - info_sharing: aat_ifr_48_layers
+ - pred_head: dpt_pose_scale
+ - task: images_only
+
+# String for model factory
+model_str: "mapanything"
+# Model config
+model_config:
+ name: "mapanything"
+ encoder_config: ${model.encoder}
+ info_sharing_config: ${model.info_sharing}
+ pred_head_config: ${model.pred_head}
+ geometric_input_config: ${model.task}
+# Image Normalization Type
+data_norm_type: ${model.encoder.data_norm_type}
diff --git a/configs/model/mapanything_large_inference.yaml b/configs/model/mapanything_large_inference.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..2e57bf7ddbcc416d4e637276202e60515535e0e1
--- /dev/null
+++ b/configs/model/mapanything_large_inference.yaml
@@ -0,0 +1,18 @@
+defaults:
+ - default
+ - encoder: dinov2_large
+ - info_sharing: aat_ifr_48_layers_escaling
+ - pred_head: dpt_pose_scale
+ - task: images_only
+
+# String for model factory
+model_str: "mapanything"
+# Model config
+model_config:
+ name: "mapanything"
+ encoder_config: ${model.encoder}
+ info_sharing_config: ${model.info_sharing}
+ pred_head_config: ${model.pred_head}
+ geometric_input_config: ${model.task}
+# Image Normalization Type
+data_norm_type: ${model.encoder.data_norm_type}
diff --git a/configs/model/mast3r.yaml b/configs/model/mast3r.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..dc315b3a24d0748a2b671a6d29e10dda9d446b3b
--- /dev/null
+++ b/configs/model/mast3r.yaml
@@ -0,0 +1,15 @@
+# String for model factory
+model_str: "mast3r"
+# Model config
+model_config:
+ name: "mast3r"
+ # Checkpoint path
+ ckpt_path: "${root_pretrained_checkpoints_dir}/MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric.pth"
+ # Cache dir
+ cache_dir: "${root_pretrained_checkpoints_dir}/mast3r_cache"
+# Image Normalization Type
+data_norm_type: "dust3r"
+# MASt3R checkpoint is already loaded in the inference wrapper
+pretrained: null
+# Torch hub force reload
+torch_hub_force_reload: false
diff --git a/configs/model/metric_dust3r.yaml b/configs/model/metric_dust3r.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..329a3cb9378b163049d898d1e54b778d9d2a4b0f
--- /dev/null
+++ b/configs/model/metric_dust3r.yaml
@@ -0,0 +1,23 @@
+# String for model factory
+model_str: "dust3r"
+# Model config
+model_config:
+ name: "metric_dust3r"
+ # Checkpoint path
+ ckpt_path: "${root_pretrained_checkpoints_dir}/MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric.pth"
+ # Scene graph for BA
+ scene_graph: "complete"
+ # Pairwise inference batch size
+ inference_batch_size: 32
+ # Global optim schedule
+ global_optim_schedule: "cosine"
+ # Global optim lr
+ global_optim_lr: 0.01
+ # Number of iterations for global optimization
+ global_optim_niter: 300
+# Image Normalization Type
+data_norm_type: "dust3r"
+# DUSt3R checkpoint is already loaded in the inference wrapper
+pretrained: null
+# Torch hub force reload
+torch_hub_force_reload: false
diff --git a/configs/model/modular_dust3r_512_dpt.yaml b/configs/model/modular_dust3r_512_dpt.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..8a0ffe75a5af160b0f93e0790bf3788cd1fc1455
--- /dev/null
+++ b/configs/model/modular_dust3r_512_dpt.yaml
@@ -0,0 +1,16 @@
+defaults:
+ - default
+ - encoder: croco_512_data_norm_dust3r
+ - info_sharing: cat_ifr_dust3r
+ - pred_head: dpt
+
+# String for model factory
+model_str: "modular_dust3r"
+# Model config
+model_config:
+ name: "dust3r_512_dpt"
+ encoder_config: ${model.encoder}
+ info_sharing_config: ${model.info_sharing}
+ pred_head_config: ${model.pred_head}
+# Image Normalization Type
+data_norm_type: "dust3r"
diff --git a/configs/model/moge_1.yaml b/configs/model/moge_1.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..4db9bcf8e7a5637393a71f5a9ae9bf0003b72d90
--- /dev/null
+++ b/configs/model/moge_1.yaml
@@ -0,0 +1,17 @@
+# String for model factory
+model_str: "moge"
+# Model config
+model_config:
+ name: "moge-1"
+ # MoGe pre-trained model checkpoint string
+ model_string: "Ruicheng/moge-vitl"
+ # Load custom checkpoint
+ load_custom_ckpt: false
+ # Custom checkpoint path
+ custom_ckpt_path: null
+# Image Normalization Type
+data_norm_type: "identity"
+# MoGe checkpoint is already loaded in the inference wrapper
+pretrained: null
+# Torch hub force reload
+torch_hub_force_reload: false
diff --git a/configs/model/moge_2.yaml b/configs/model/moge_2.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..0b048344b8cd34011ae17fdb7d543a5adce1f403
--- /dev/null
+++ b/configs/model/moge_2.yaml
@@ -0,0 +1,17 @@
+# String for model factory
+model_str: "moge"
+# Model config
+model_config:
+ name: "moge-2"
+ # MoGe pre-trained model checkpoint string
+ model_string: "Ruicheng/moge-2-vitl"
+ # Load custom checkpoint
+ load_custom_ckpt: false
+ # Custom checkpoint path
+ custom_ckpt_path: null
+# Image Normalization Type
+data_norm_type: "identity"
+# MoGe checkpoint is already loaded in the inference wrapper
+pretrained: null
+# Torch hub force reload
+torch_hub_force_reload: false
diff --git a/configs/model/must3r.yaml b/configs/model/must3r.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..09bc01eadb2853a1aea3ed316bfae3f2a5ad0371
--- /dev/null
+++ b/configs/model/must3r.yaml
@@ -0,0 +1,15 @@
+# String for model factory
+model_str: "must3r"
+# Model config
+model_config:
+ name: "must3r"
+ # Checkpoint path
+ ckpt_path: "${root_pretrained_checkpoints_dir}/MUSt3R_512.pth"
+ # Retrieval Checkpoint path
+ retrieval_ckpt_path: "${root_pretrained_checkpoints_dir}/MUSt3R_512_retrieval_trainingfree.pth"
+# Image Normalization Type
+data_norm_type: "dust3r"
+# MASt3R checkpoint is already loaded in the inference wrapper
+pretrained: null
+# Torch hub force reload
+torch_hub_force_reload: false
diff --git a/configs/model/pi3.yaml b/configs/model/pi3.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..138e1393ffd1996f3412d6d91c775e1769c5c8c6
--- /dev/null
+++ b/configs/model/pi3.yaml
@@ -0,0 +1,13 @@
+# String for model factory
+model_str: "pi3"
+# Model config
+model_config:
+ name: "pi3"
+ # Load pre-trained weights
+ load_pretrained_weights: true
+# Image Normalization Type
+data_norm_type: "identity"
+# Pi3 checkpoint is already loaded in the inference wrapper
+pretrained: null
+# Torch hub force reload
+torch_hub_force_reload: False
diff --git a/configs/model/pow3r.yaml b/configs/model/pow3r.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..e7141fc206bb51d62ff03093409fa221d204bfee
--- /dev/null
+++ b/configs/model/pow3r.yaml
@@ -0,0 +1,19 @@
+defaults:
+ - default
+ - task: images_only
+
+# String for model factory
+model_str: "pow3r"
+# Model config
+model_config:
+ name: "pow3r"
+ # Checkpoint path
+ ckpt_path: "${root_pretrained_checkpoints_dir}/Pow3R_ViTLarge_BaseDecoder_512_linear.pth"
+ # Geometric input config
+ geometric_input_config: ${model.task}
+# Image Normalization Type
+data_norm_type: "dust3r"
+# Pow3R checkpoint is already loaded in the inference wrapper
+pretrained: null
+# Torch hub force reload
+torch_hub_force_reload: false
diff --git a/configs/model/pow3r_ba.yaml b/configs/model/pow3r_ba.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..45c4dbca68f54d79f519e7757cd1490599c4839d
--- /dev/null
+++ b/configs/model/pow3r_ba.yaml
@@ -0,0 +1,29 @@
+defaults:
+ - default
+ - task: images_only
+
+# String for model factory
+model_str: "pow3r_ba"
+# Model config
+model_config:
+ name: "pow3r_ba"
+ # Checkpoint path
+ ckpt_path: "${root_pretrained_checkpoints_dir}/Pow3R_ViTLarge_BaseDecoder_512_linear.pth"
+ # Geometric input config
+ geometric_input_config: ${model.task}
+ # Scene graph for BA
+ scene_graph: "complete"
+ # Pairwise inference batch size
+ inference_batch_size: 32
+ # Global optim schedule
+ global_optim_schedule: "cosine"
+ # Global optim lr
+ global_optim_lr: 0.01
+ # Number of iterations for global optimization
+ global_optim_niter: 300
+# Image Normalization Type
+data_norm_type: "dust3r"
+# Pow3R checkpoint is already loaded in the inference wrapper
+pretrained: null
+# Torch hub force reload
+torch_hub_force_reload: false
diff --git a/configs/model/pred_head/adaptor_config/campointmap_pose_confidence_mask.yaml b/configs/model/pred_head/adaptor_config/campointmap_pose_confidence_mask.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..c98de936666ba0c0f469df99d75935f34f492cc4
--- /dev/null
+++ b/configs/model/pred_head/adaptor_config/campointmap_pose_confidence_mask.yaml
@@ -0,0 +1,22 @@
+# Camera Frame Pointmap + Global Camera Pose (Trans + Quats) + Confidence + Mask
+input_dim: 5
+scene_rep_dim: 3
+type: "campointmap+pose+confidence+mask"
+scene_rep_type: "campointmap+pose"
+dense_pred_init_dict:
+ name: "campointmap+pose+confidence+mask+scale"
+ pointmap_mode: "z_exp"
+ pointmap_vmin: ${special_float:"-inf"}
+ pointmap_vmax: ${special_float:"inf"}
+ confidence_type: "exp"
+ confidence_vmin: 1
+ confidence_vmax: ${special_float:"inf"}
+pose_pred_init_dict:
+ name: "campointmap+pose+confidence+mask+scale"
+ cam_trans_mode: "linear"
+ cam_trans_vmin: ${special_float:"-inf"}
+ cam_trans_vmax: ${special_float:"inf"}
+ quaternions_mode: "linear"
+ quaternions_normalize: true
+ quaternions_vmin: ${special_float:"-inf"}
+ quaternions_vmax: ${special_float:"inf"}
diff --git a/configs/model/pred_head/adaptor_config/campointmap_pose_confidence_mask_scale.yaml b/configs/model/pred_head/adaptor_config/campointmap_pose_confidence_mask_scale.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..5ef92a540e7e2979c73541ed532ee225b0df3c11
--- /dev/null
+++ b/configs/model/pred_head/adaptor_config/campointmap_pose_confidence_mask_scale.yaml
@@ -0,0 +1,27 @@
+# Camera Frame Pointmap + Global Camera Pose (Trans + Quats) + Confidence + Mask + Scene-wide Metric Scaling Factor
+input_dim: 5
+scene_rep_dim: 3
+type: "campointmap+pose+confidence+mask"
+scene_rep_type: "campointmap+pose"
+dense_pred_init_dict:
+ name: "campointmap+pose+confidence+mask+scale"
+ pointmap_mode: "z_exp"
+ pointmap_vmin: ${special_float:"-inf"}
+ pointmap_vmax: ${special_float:"inf"}
+ confidence_type: "exp"
+ confidence_vmin: 1
+ confidence_vmax: ${special_float:"inf"}
+pose_pred_init_dict:
+ name: "campointmap+pose+confidence+mask+scale"
+ cam_trans_mode: "linear"
+ cam_trans_vmin: ${special_float:"-inf"}
+ cam_trans_vmax: ${special_float:"inf"}
+ quaternions_mode: "linear"
+ quaternions_normalize: true
+ quaternions_vmin: ${special_float:"-inf"}
+ quaternions_vmax: ${special_float:"inf"}
+scale_pred_init_dict:
+ name: "campointmap+pose+confidence+mask+scale"
+ mode: "exp"
+ vmin: 1e-08
+ vmax: ${special_float:"inf"}
diff --git a/configs/model/pred_head/adaptor_config/pointmap_confidence.yaml b/configs/model/pred_head/adaptor_config/pointmap_confidence.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..96b60f9a47407a048c1f4af698f5de693eb1d109
--- /dev/null
+++ b/configs/model/pred_head/adaptor_config/pointmap_confidence.yaml
@@ -0,0 +1,13 @@
+# Pointmap + Confidence
+input_dim: 4
+scene_rep_dim: 3
+type: "pointmap+confidence"
+scene_rep_type: "pointmap"
+init_dict:
+ name: "pointmap+confidence"
+ pointmap_mode: "exp"
+ pointmap_vmin: ${special_float:"-inf"}
+ pointmap_vmax: ${special_float:"inf"}
+ confidence_type: "exp"
+ confidence_vmin: 1
+ confidence_vmax: ${special_float:"inf"}
diff --git a/configs/model/pred_head/adaptor_config/pointmap_confidence_mask.yaml b/configs/model/pred_head/adaptor_config/pointmap_confidence_mask.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..833df59a2a8d941788cbf9ef6e988cd8e0cf5e6b
--- /dev/null
+++ b/configs/model/pred_head/adaptor_config/pointmap_confidence_mask.yaml
@@ -0,0 +1,13 @@
+# Pointmap + Confidence + Mask
+input_dim: 5
+scene_rep_dim: 3
+type: "pointmap+confidence+mask"
+scene_rep_type: "pointmap"
+init_dict:
+ name: "pointmap+confidence+mask"
+ pointmap_mode: "exp"
+ pointmap_vmin: ${special_float:"-inf"}
+ pointmap_vmax: ${special_float:"inf"}
+ confidence_type: "exp"
+ confidence_vmin: 1
+ confidence_vmax: ${special_float:"inf"}
diff --git a/configs/model/pred_head/adaptor_config/pointmap_confidence_mask_scale.yaml b/configs/model/pred_head/adaptor_config/pointmap_confidence_mask_scale.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..4230d94989652d7b5ee779bebd7d8ac52d8e3cd7
--- /dev/null
+++ b/configs/model/pred_head/adaptor_config/pointmap_confidence_mask_scale.yaml
@@ -0,0 +1,18 @@
+# Pointmap + Confidence + Mask + Scene-wide Metric Scaling Factor
+input_dim: 5
+scene_rep_dim: 3
+type: "pointmap+confidence+mask"
+scene_rep_type: "pointmap"
+init_dict:
+ name: "pointmap+confidence+mask"
+ pointmap_mode: "exp"
+ pointmap_vmin: ${special_float:"-inf"}
+ pointmap_vmax: ${special_float:"inf"}
+ confidence_type: "exp"
+ confidence_vmin: 1
+ confidence_vmax: ${special_float:"inf"}
+scale_pred_init_dict:
+ name: "campointmap+pose+confidence+mask+scale"
+ mode: "exp"
+ vmin: 1e-08
+ vmax: ${special_float:"inf"}
diff --git a/configs/model/pred_head/adaptor_config/pointmap_factored_raydirs_depth_pose_confidence_mask_scale.yaml b/configs/model/pred_head/adaptor_config/pointmap_factored_raydirs_depth_pose_confidence_mask_scale.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..2f93a518ba2322cc8e6915abfa4b1cdf00fb2dbf
--- /dev/null
+++ b/configs/model/pred_head/adaptor_config/pointmap_factored_raydirs_depth_pose_confidence_mask_scale.yaml
@@ -0,0 +1,39 @@
+# Global Pointmaps + Ray Directions on Unit Sphere + Depth along Ray + Global Camera Pose (Trans + Quats) + Confidence + Mask + Global Metric Scaling Factor
+input_dim: 9
+scene_rep_dim: 7
+type: "pointmap+raydirs+depth+pose+confidence+mask"
+scene_rep_type: "pointmap+raydirs+depth+pose"
+dense_pred_init_dict:
+ name: "pointmap+raydirs+depth+pose+confidence+mask+scale"
+ pointmap_mode: "exp"
+ pointmap_vmin: ${special_float:"-inf"}
+ pointmap_vmax: ${special_float:"inf"}
+ ray_directions_mode: "linear"
+ ray_directions_normalize_to_unit_sphere: true
+ ray_directions_normalize_to_unit_image_plane: false
+ ray_directions_vmin: ${special_float:"-inf"}
+ ray_directions_vmax: ${special_float:"inf"}
+ ray_directions_clamp_min_of_z_dir: false
+ ray_directions_z_dir_min: ${special_float:"-inf"}
+ depth_mode: "exp"
+ depth_vmin: 0
+ depth_vmax: ${special_float:"inf"}
+ confidence_type: "exp"
+ confidence_vmin: 1
+ confidence_vmax: ${special_float:"inf"}
+pose_pred_init_dict:
+ name: "raydirs+depth+pose+confidence+mask+scale"
+ cam_trans_mode: "linear"
+ cam_trans_vmin: ${special_float:"-inf"}
+ cam_trans_vmax: ${special_float:"inf"}
+ quaternions_mode: "linear"
+ quaternions_normalize: true
+ quaternions_vmin: ${special_float:"-inf"}
+ quaternions_vmax: ${special_float:"inf"}
+scale_pred_init_dict:
+ name: "raydirs+depth+pose+confidence+mask+scale"
+ mode: "exp"
+ vmin: 1e-08
+ vmax: ${special_float:"inf"}
+# Flag to decide what representaion to use for global pointmaps
+use_factored_predictions_for_global_pointmaps: true
diff --git a/configs/model/pred_head/adaptor_config/pointmap_raydirs_depth_pose_confidence_mask_scale.yaml b/configs/model/pred_head/adaptor_config/pointmap_raydirs_depth_pose_confidence_mask_scale.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..bfb2408fd39ab13bd619b3b503e29f9900433a61
--- /dev/null
+++ b/configs/model/pred_head/adaptor_config/pointmap_raydirs_depth_pose_confidence_mask_scale.yaml
@@ -0,0 +1,39 @@
+# Global Pointmaps + Ray Directions on Unit Sphere + Depth along Ray + Global Camera Pose (Trans + Quats) + Confidence + Mask + Global Metric Scaling Factor
+input_dim: 9
+scene_rep_dim: 7
+type: "pointmap+raydirs+depth+pose+confidence+mask"
+scene_rep_type: "pointmap+raydirs+depth+pose"
+dense_pred_init_dict:
+ name: "pointmap+raydirs+depth+pose+confidence+mask+scale"
+ pointmap_mode: "exp"
+ pointmap_vmin: ${special_float:"-inf"}
+ pointmap_vmax: ${special_float:"inf"}
+ ray_directions_mode: "linear"
+ ray_directions_normalize_to_unit_sphere: true
+ ray_directions_normalize_to_unit_image_plane: false
+ ray_directions_vmin: ${special_float:"-inf"}
+ ray_directions_vmax: ${special_float:"inf"}
+ ray_directions_clamp_min_of_z_dir: false
+ ray_directions_z_dir_min: ${special_float:"-inf"}
+ depth_mode: "exp"
+ depth_vmin: 0
+ depth_vmax: ${special_float:"inf"}
+ confidence_type: "exp"
+ confidence_vmin: 1
+ confidence_vmax: ${special_float:"inf"}
+pose_pred_init_dict:
+ name: "raydirs+depth+pose+confidence+mask+scale"
+ cam_trans_mode: "linear"
+ cam_trans_vmin: ${special_float:"-inf"}
+ cam_trans_vmax: ${special_float:"inf"}
+ quaternions_mode: "linear"
+ quaternions_normalize: true
+ quaternions_vmin: ${special_float:"-inf"}
+ quaternions_vmax: ${special_float:"inf"}
+scale_pred_init_dict:
+ name: "raydirs+depth+pose+confidence+mask+scale"
+ mode: "exp"
+ vmin: 1e-08
+ vmax: ${special_float:"inf"}
+# Flag to decide what representaion to use for global pointmaps
+use_factored_predictions_for_global_pointmaps: false
diff --git a/configs/model/pred_head/adaptor_config/raydirs_depth_pose_confidence_mask.yaml b/configs/model/pred_head/adaptor_config/raydirs_depth_pose_confidence_mask.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..ee83ea0b976e3d173302d167b51404446b950b11
--- /dev/null
+++ b/configs/model/pred_head/adaptor_config/raydirs_depth_pose_confidence_mask.yaml
@@ -0,0 +1,29 @@
+# Ray Directions on Unit Sphere + Depth along Ray + Global Camera Pose (Trans + Quats) + Confidence + Mask
+input_dim: 6
+scene_rep_dim: 4
+type: "raydirs+depth+pose+confidence+mask"
+scene_rep_type: "raydirs+depth+pose"
+dense_pred_init_dict:
+ name: "raydirs+depth+pose+confidence+mask"
+ ray_directions_mode: "linear"
+ ray_directions_normalize_to_unit_sphere: true
+ ray_directions_normalize_to_unit_image_plane: false
+ ray_directions_vmin: ${special_float:"-inf"}
+ ray_directions_vmax: ${special_float:"inf"}
+ ray_directions_clamp_min_of_z_dir: false
+ ray_directions_z_dir_min: ${special_float:"-inf"}
+ depth_mode: "exp"
+ depth_vmin: 0
+ depth_vmax: ${special_float:"inf"}
+ confidence_type: "exp"
+ confidence_vmin: 1
+ confidence_vmax: ${special_float:"inf"}
+pose_pred_init_dict:
+ name: "raydirs+depth+pose+confidence+mask"
+ cam_trans_mode: "linear"
+ cam_trans_vmin: ${special_float:"-inf"}
+ cam_trans_vmax: ${special_float:"inf"}
+ quaternions_mode: "linear"
+ quaternions_normalize: true
+ quaternions_vmin: ${special_float:"-inf"}
+ quaternions_vmax: ${special_float:"inf"}
diff --git a/configs/model/pred_head/adaptor_config/raydirs_depth_pose_confidence_mask_scale.yaml b/configs/model/pred_head/adaptor_config/raydirs_depth_pose_confidence_mask_scale.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..4a939ad9eeed4131990941976bde8124bfbeba15
--- /dev/null
+++ b/configs/model/pred_head/adaptor_config/raydirs_depth_pose_confidence_mask_scale.yaml
@@ -0,0 +1,34 @@
+# Ray Directions on Unit Sphere + Depth along Ray + Global Camera Pose (Trans + Quats) + Confidence + Mask + Global Metric Scaling Factor
+input_dim: 6
+scene_rep_dim: 4
+type: "raydirs+depth+pose+confidence+mask"
+scene_rep_type: "raydirs+depth+pose"
+dense_pred_init_dict:
+ name: "raydirs+depth+pose+confidence+mask+scale"
+ ray_directions_mode: "linear"
+ ray_directions_normalize_to_unit_sphere: true
+ ray_directions_normalize_to_unit_image_plane: false
+ ray_directions_vmin: ${special_float:"-inf"}
+ ray_directions_vmax: ${special_float:"inf"}
+ ray_directions_clamp_min_of_z_dir: false
+ ray_directions_z_dir_min: ${special_float:"-inf"}
+ depth_mode: "exp"
+ depth_vmin: 0
+ depth_vmax: ${special_float:"inf"}
+ confidence_type: "exp"
+ confidence_vmin: 1
+ confidence_vmax: ${special_float:"inf"}
+pose_pred_init_dict:
+ name: "raydirs+depth+pose+confidence+mask+scale"
+ cam_trans_mode: "linear"
+ cam_trans_vmin: ${special_float:"-inf"}
+ cam_trans_vmax: ${special_float:"inf"}
+ quaternions_mode: "linear"
+ quaternions_normalize: true
+ quaternions_vmin: ${special_float:"-inf"}
+ quaternions_vmax: ${special_float:"inf"}
+scale_pred_init_dict:
+ name: "raydirs+depth+pose+confidence+mask+scale"
+ mode: "exp"
+ vmin: 1e-08
+ vmax: ${special_float:"inf"}
diff --git a/configs/model/pred_head/dpt.yaml b/configs/model/pred_head/dpt.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..2d05fe0b2695236470388b522d0ed70fcf603b1d
--- /dev/null
+++ b/configs/model/pred_head/dpt.yaml
@@ -0,0 +1,15 @@
+defaults:
+ - adaptor_config: pointmap_confidence
+
+type: "dpt"
+feature_head:
+ feature_dim: 256
+ hooks: [0, 1, 2, 3]
+ checkpoint_gradient: ${model.pred_head.gradient_checkpointing}
+regressor_head:
+ output_dim: ${model.pred_head.adaptor_config.input_dim}
+ checkpoint_gradient: ${model.pred_head.gradient_checkpointing}
+adaptor_type: ${model.pred_head.adaptor_config.type}
+adaptor: ${model.pred_head.adaptor_config.init_dict}
+# Flag to indicate whether to use gradient checkpointing
+gradient_checkpointing: False
diff --git a/configs/model/pred_head/dpt_pose.yaml b/configs/model/pred_head/dpt_pose.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..24d50ae5dbfa1ada961b0afde5422ce9dac97a8b
--- /dev/null
+++ b/configs/model/pred_head/dpt_pose.yaml
@@ -0,0 +1,19 @@
+defaults:
+ - adaptor_config: raydirs_depth_pose_confidence_mask
+
+type: "dpt+pose"
+feature_head:
+ feature_dim: 256
+ hooks: [0, 1, 2, 3]
+ checkpoint_gradient: ${model.pred_head.gradient_checkpointing}
+regressor_head:
+ output_dim: ${model.pred_head.adaptor_config.input_dim}
+ checkpoint_gradient: ${model.pred_head.gradient_checkpointing}
+pose_head:
+ num_resconv_block: 2
+ rot_representation_dim: 4
+adaptor_type: ${model.pred_head.adaptor_config.type}
+dpt_adaptor: ${model.pred_head.adaptor_config.dense_pred_init_dict}
+pose_adaptor: ${model.pred_head.adaptor_config.pose_pred_init_dict}
+# Flag to indicate whether to use gradient checkpointing
+gradient_checkpointing: False
diff --git a/configs/model/pred_head/dpt_pose_scale.yaml b/configs/model/pred_head/dpt_pose_scale.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..068dc181712e4d92f2e1e1e088f363ed698c2582
--- /dev/null
+++ b/configs/model/pred_head/dpt_pose_scale.yaml
@@ -0,0 +1,22 @@
+defaults:
+ - adaptor_config: raydirs_depth_pose_confidence_mask_scale
+
+type: "dpt+pose"
+feature_head:
+ feature_dim: 256
+ hooks: [0, 1, 2, 3]
+ checkpoint_gradient: ${model.pred_head.gradient_checkpointing}
+regressor_head:
+ output_dim: ${model.pred_head.adaptor_config.input_dim}
+ checkpoint_gradient: ${model.pred_head.gradient_checkpointing}
+pose_head:
+ num_resconv_block: 2
+ rot_representation_dim: 4
+scale_head:
+ output_dim: 1
+adaptor_type: ${model.pred_head.adaptor_config.type}
+dpt_adaptor: ${model.pred_head.adaptor_config.dense_pred_init_dict}
+pose_adaptor: ${model.pred_head.adaptor_config.pose_pred_init_dict}
+scale_adaptor: ${model.pred_head.adaptor_config.scale_pred_init_dict}
+# Flag to indicate whether to use gradient checkpointing
+gradient_checkpointing: False
diff --git a/configs/model/pred_head/dpt_scale.yaml b/configs/model/pred_head/dpt_scale.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..286eb589259cd07c6ff0d2c24e5a58a6d1c748c5
--- /dev/null
+++ b/configs/model/pred_head/dpt_scale.yaml
@@ -0,0 +1,18 @@
+defaults:
+ - adaptor_config: pointmap_confidence_mask_scale
+
+type: "dpt"
+feature_head:
+ feature_dim: 256
+ hooks: [0, 1, 2, 3]
+ checkpoint_gradient: ${model.pred_head.gradient_checkpointing}
+regressor_head:
+ output_dim: ${model.pred_head.adaptor_config.input_dim}
+ checkpoint_gradient: ${model.pred_head.gradient_checkpointing}
+scale_head:
+ output_dim: 1
+adaptor_type: ${model.pred_head.adaptor_config.type}
+adaptor: ${model.pred_head.adaptor_config.init_dict}
+scale_adaptor: ${model.pred_head.adaptor_config.scale_pred_init_dict}
+# Flag to indicate whether to use gradient checkpointing
+gradient_checkpointing: False
diff --git a/configs/model/task/aug_training.yaml b/configs/model/task/aug_training.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..1dfae15dcd58a86c6cef932a31ca6d74ba760aa4
--- /dev/null
+++ b/configs/model/task/aug_training.yaml
@@ -0,0 +1,23 @@
+defaults:
+ - default
+
+# Overall Probability of Geometric Inputs
+overall_prob: 0.9
+# Dropout Probability of Geometric Inputs (for each sample across batch size and number of views)
+dropout_prob: 0.05
+# Probability of Geometric Inputs with Ray Directions
+ray_dirs_prob: 0.5
+# Probability of Geometric Inputs with Depths
+depth_prob: 0.5
+# Probability of Geometric Inputs with Camera Poses
+cam_prob: 0.5
+# Probability of sparsely sampling the high quality gt depth
+sparse_depth_prob: 0.5
+# Percentage of the valid depth to remove if the probability of using sparse depth is greater than 0 (Range: [0, 1])
+sparsification_removal_percent: 0.9
+# Probability for skipping input of the metric scale quantities for the input metric high quality gt depth
+# If 0, the metric scale quantities will be provided as input to the model for all the metric scale conditionings
+depth_scale_norm_all_prob: 0.05
+# Probability for skipping input of the metric scale quantities for the input metric pose
+# If 0, the metric scale quantities will be provided as input to the model for all the metric scale conditionings
+pose_scale_norm_all_prob: 0.05
diff --git a/configs/model/task/calibrated_sfm.yaml b/configs/model/task/calibrated_sfm.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..8d24c0033e82e7db23c7c6c40428d5c295380c8b
--- /dev/null
+++ b/configs/model/task/calibrated_sfm.yaml
@@ -0,0 +1,23 @@
+defaults:
+ - default
+
+# Overall Probability of Geometric Inputs
+overall_prob: 1
+# Dropout Probability of Geometric Inputs (for each sample across batch size and number of views)
+dropout_prob: 0
+# Probability of Geometric Inputs with Ray Directions
+ray_dirs_prob: 1
+# Probability of Geometric Inputs with Depths
+depth_prob: 0
+# Probability of Geometric Inputs with Camera Poses
+cam_prob: 0
+# Probability of sparsely sampling the high quality gt depth
+sparse_depth_prob: 0
+# Percentage of the valid depth to remove if the probability of using sparse depth is greater than 0 (Range: [0, 1])
+sparsification_removal_percent: 0
+# Probability for skipping input of the metric scale norm factor for the input metric high quality gt depth
+# If 0, the metric scale norm factor will be provided as input to the model for all the metric scale conditionings
+depth_scale_norm_all_prob: 0
+# Probability for skipping input of the metric scale norm factor for the input metric pose
+# If 0, the metric scale norm factor will be provided as input to the model for all the metric scale conditionings
+pose_scale_norm_all_prob: 0
diff --git a/configs/model/task/default.yaml b/configs/model/task/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..8fdc9a98f4931a014dda25939aa44eadbdd614ed
--- /dev/null
+++ b/configs/model/task/default.yaml
@@ -0,0 +1,27 @@
+# Ray Directions Encoder Config
+ray_dirs_encoder_config:
+ name: "ray_dirs_encoder"
+ in_chans: 3
+ encoder_str: "dense_rep_encoder"
+ apply_pe: false
+# Depth Encoder Config
+depth_encoder_config:
+ name: "depth_encoder"
+ in_chans: 1
+ encoder_str: "dense_rep_encoder"
+ apply_pe: false
+# Cam Rotation (Quats) Encoder Config
+cam_rot_encoder_config:
+ name: "cam_rot_quats_encoder"
+ in_chans: 4
+ encoder_str: "global_rep_encoder"
+# Cam Translation Encoder Config
+cam_trans_encoder_config:
+ name: "cam_trans_encoder"
+ in_chans: 3
+ encoder_str: "global_rep_encoder"
+# Scale Encoder Config
+scale_encoder_config:
+ name: "scale_encoder"
+ in_chans: 1
+ encoder_str: "global_rep_encoder"
diff --git a/configs/model/task/depth_completion.yaml b/configs/model/task/depth_completion.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..0461a2c7fce649a4b1fe4e55b80955375197a88e
--- /dev/null
+++ b/configs/model/task/depth_completion.yaml
@@ -0,0 +1,23 @@
+defaults:
+ - default
+
+# Overall Probability of Geometric Inputs
+overall_prob: 1
+# Dropout Probability of Geometric Inputs (for each sample across batch size and number of views)
+dropout_prob: 0
+# Probability of Geometric Inputs with Ray Directions
+ray_dirs_prob: 1
+# Probability of Geometric Inputs with Depths
+depth_prob: 1
+# Probability of Geometric Inputs with Camera Poses
+cam_prob: 1
+# Probability of sparsely sampling the high quality gt depth
+sparse_depth_prob: 1
+# Percentage of the valid depth to remove if the probability of using sparse depth is greater than 0 (Range: [0, 1])
+sparsification_removal_percent: 0.9
+# Probability for skipping input of the metric scale norm factor for the input metric high quality gt depth
+# If 0, the metric scale norm factor will be provided as input to the model for all the metric scale conditionings
+depth_scale_norm_all_prob: 0
+# Probability for skipping input of the metric scale norm factor for the input metric pose
+# If 0, the metric scale norm factor will be provided as input to the model for all the metric scale conditionings
+pose_scale_norm_all_prob: 0
diff --git a/configs/model/task/images_only.yaml b/configs/model/task/images_only.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..99607dbd43ffc1687a665cfc32fa67e2d5b5908d
--- /dev/null
+++ b/configs/model/task/images_only.yaml
@@ -0,0 +1,23 @@
+defaults:
+ - default
+
+# Overall Probability of Geometric Inputs
+overall_prob: 0
+# Dropout Probability of Geometric Inputs (for each sample across batch size and number of views)
+dropout_prob: 1
+# Probability of Geometric Inputs with Ray Directions
+ray_dirs_prob: 0
+# Probability of Geometric Inputs with Depths
+depth_prob: 0
+# Probability of Geometric Inputs with Camera Poses
+cam_prob: 0
+# Probability of sparsely sampling the high quality gt depth
+sparse_depth_prob: 0
+# Percentage of the valid depth to remove if the probability of using sparse depth is greater than 0 (Range: [0, 1])
+sparsification_removal_percent: 0
+# Probability for skipping input of the metric scale norm factor for the input metric high quality gt depth
+# If 0, the metric scale norm factor will be provided as input to the model for all the metric scale conditionings
+depth_scale_norm_all_prob: 0
+# Probability for skipping input of the metric scale norm factor for the input metric pose
+# If 0, the metric scale norm factor will be provided as input to the model for all the metric scale conditionings
+pose_scale_norm_all_prob: 0
diff --git a/configs/model/task/mvs.yaml b/configs/model/task/mvs.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..8cd2850ba2cad9d925a268639d09a5b19158f323
--- /dev/null
+++ b/configs/model/task/mvs.yaml
@@ -0,0 +1,23 @@
+defaults:
+ - default
+
+# Overall Probability of Geometric Inputs
+overall_prob: 1
+# Dropout Probability of Geometric Inputs (for each sample across batch size and number of views)
+dropout_prob: 0
+# Probability of Geometric Inputs with Ray Directions
+ray_dirs_prob: 1
+# Probability of Geometric Inputs with Depths
+depth_prob: 0
+# Probability of Geometric Inputs with Camera Poses
+cam_prob: 1
+# Probability of sparsely sampling the high quality gt depth
+sparse_depth_prob: 0
+# Percentage of the valid depth to remove if the probability of using sparse depth is greater than 0 (Range: [0, 1])
+sparsification_removal_percent: 0
+# Probability for skipping input of the metric scale norm factor for the input metric high quality gt depth
+# If 0, the metric scale norm factor will be provided as input to the model for all the metric scale conditionings
+depth_scale_norm_all_prob: 0
+# Probability for skipping input of the metric scale norm factor for the input metric pose
+# If 0, the metric scale norm factor will be provided as input to the model for all the metric scale conditionings
+pose_scale_norm_all_prob: 0
diff --git a/configs/model/task/mvs_non_metric.yaml b/configs/model/task/mvs_non_metric.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..ba80d5b4694e1c2798191d4949dea22688187185
--- /dev/null
+++ b/configs/model/task/mvs_non_metric.yaml
@@ -0,0 +1,23 @@
+defaults:
+ - default
+
+# Overall Probability of Geometric Inputs
+overall_prob: 1
+# Dropout Probability of Geometric Inputs (for each sample across batch size and number of views)
+dropout_prob: 0
+# Probability of Geometric Inputs with Ray Directions
+ray_dirs_prob: 1
+# Probability of Geometric Inputs with Depths
+depth_prob: 0
+# Probability of Geometric Inputs with Camera Poses
+cam_prob: 1
+# Probability of sparsely sampling the high quality gt depth
+sparse_depth_prob: 0
+# Percentage of the valid depth to remove if the probability of using sparse depth is greater than 0 (Range: [0, 1])
+sparsification_removal_percent: 0
+# Probability for skipping input of the metric scale norm factor for the input metric high quality gt depth
+# If 0, the metric scale norm factor will be provided as input to the model for all the metric scale conditionings
+depth_scale_norm_all_prob: 0
+# Probability for skipping input of the metric scale norm factor for the input metric pose
+# If 0, the metric scale norm factor will be provided as input to the model for all the metric scale conditionings
+pose_scale_norm_all_prob: 1
diff --git a/configs/model/task/mvs_training.yaml b/configs/model/task/mvs_training.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..eb6b4db744cb16eb7ea351082ffde5a6edd434cb
--- /dev/null
+++ b/configs/model/task/mvs_training.yaml
@@ -0,0 +1,23 @@
+defaults:
+ - default
+
+# Overall Probability of Geometric Inputs
+overall_prob: 1
+# Dropout Probability of Geometric Inputs (for each sample across batch size and number of views)
+dropout_prob: 0
+# Probability of Geometric Inputs with Ray Directions
+ray_dirs_prob: 1
+# Probability of Geometric Inputs with Depths
+depth_prob: 0
+# Probability of Geometric Inputs with Camera Poses
+cam_prob: 1
+# Probability of sparsely sampling the high quality gt depth
+sparse_depth_prob: 0
+# Percentage of the valid depth to remove if the probability of using sparse depth is greater than 0 (Range: [0, 1])
+sparsification_removal_percent: 0
+# Probability for skipping input of the metric scale quantities for the input metric high quality gt depth
+# If 0, the metric scale quantities will be provided as input to the model for all the metric scale conditionings
+depth_scale_norm_all_prob: 0
+# Probability for skipping input of the metric scale quantities for the input metric pose
+# If 0, the metric scale quantities will be provided as input to the model for all the metric scale conditionings
+pose_scale_norm_all_prob: 0.05
diff --git a/configs/model/task/non_metric_poses_metric_depth.yaml b/configs/model/task/non_metric_poses_metric_depth.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..1bc224e89b0df9c44de0676bb4a085c1276abc4a
--- /dev/null
+++ b/configs/model/task/non_metric_poses_metric_depth.yaml
@@ -0,0 +1,23 @@
+defaults:
+ - default
+
+# Overall Probability of Geometric Inputs
+overall_prob: 1
+# Dropout Probability of Geometric Inputs (for each sample across batch size and number of views)
+dropout_prob: 0
+# Probability of Geometric Inputs with Ray Directions
+ray_dirs_prob: 1
+# Probability of Geometric Inputs with Depths
+depth_prob: 1
+# Probability of Geometric Inputs with Camera Poses
+cam_prob: 1
+# Probability of sparsely sampling the high quality gt depth
+sparse_depth_prob: 0
+# Percentage of the valid depth to remove if the probability of using sparse depth is greater than 0 (Range: [0, 1])
+sparsification_removal_percent: 0
+# Probability for skipping input of the metric scale norm factor for the input metric high quality gt depth
+# If 0, the metric scale norm factor will be provided as input to the model for all the metric scale conditionings
+depth_scale_norm_all_prob: 0
+# Probability for skipping input of the metric scale norm factor for the input metric pose
+# If 0, the metric scale norm factor will be provided as input to the model for all the metric scale conditionings
+pose_scale_norm_all_prob: 1
diff --git a/configs/model/task/non_metric_poses_metric_depth_sparse.yaml b/configs/model/task/non_metric_poses_metric_depth_sparse.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..9c92ed95b830a78cf0224dacbb139fcc14e6b2f8
--- /dev/null
+++ b/configs/model/task/non_metric_poses_metric_depth_sparse.yaml
@@ -0,0 +1,23 @@
+defaults:
+ - default
+
+# Overall Probability of Geometric Inputs
+overall_prob: 1
+# Dropout Probability of Geometric Inputs (for each sample across batch size and number of views)
+dropout_prob: 0
+# Probability of Geometric Inputs with Ray Directions
+ray_dirs_prob: 1
+# Probability of Geometric Inputs with Depths
+depth_prob: 1
+# Probability of Geometric Inputs with Camera Poses
+cam_prob: 1
+# Probability of sparsely sampling the high quality gt depth
+sparse_depth_prob: 1
+# Percentage of the valid depth to remove if the probability of using sparse depth is greater than 0 (Range: [0, 1])
+sparsification_removal_percent: 0.9
+# Probability for skipping input of the metric scale norm factor for the input metric high quality gt depth
+# If 0, the metric scale norm factor will be provided as input to the model for all the metric scale conditionings
+depth_scale_norm_all_prob: 0
+# Probability for skipping input of the metric scale norm factor for the input metric pose
+# If 0, the metric scale norm factor will be provided as input to the model for all the metric scale conditionings
+pose_scale_norm_all_prob: 1
diff --git a/configs/model/task/non_metric_poses_non_metric_depth.yaml b/configs/model/task/non_metric_poses_non_metric_depth.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..9a5e9902c59379e6e9c8e9b2384a363f58e510bf
--- /dev/null
+++ b/configs/model/task/non_metric_poses_non_metric_depth.yaml
@@ -0,0 +1,23 @@
+defaults:
+ - default
+
+# Overall Probability of Geometric Inputs
+overall_prob: 1
+# Dropout Probability of Geometric Inputs (for each sample across batch size and number of views)
+dropout_prob: 0
+# Probability of Geometric Inputs with Ray Directions
+ray_dirs_prob: 1
+# Probability of Geometric Inputs with Depths
+depth_prob: 1
+# Probability of Geometric Inputs with Camera Poses
+cam_prob: 1
+# Probability of sparsely sampling the high quality gt depth
+sparse_depth_prob: 0
+# Percentage of the valid depth to remove if the probability of using sparse depth is greater than 0 (Range: [0, 1])
+sparsification_removal_percent: 0
+# Probability for skipping input of the metric scale norm factor for the input metric high quality gt depth
+# If 0, the metric scale norm factor will be provided as input to the model for all the metric scale conditionings
+depth_scale_norm_all_prob: 1
+# Probability for skipping input of the metric scale norm factor for the input metric pose
+# If 0, the metric scale norm factor will be provided as input to the model for all the metric scale conditionings
+pose_scale_norm_all_prob: 1
diff --git a/configs/model/task/non_metric_poses_non_metric_depth_sparse.yaml b/configs/model/task/non_metric_poses_non_metric_depth_sparse.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..5ee97881825a37347ad476f6c35b46639de7e296
--- /dev/null
+++ b/configs/model/task/non_metric_poses_non_metric_depth_sparse.yaml
@@ -0,0 +1,23 @@
+defaults:
+ - default
+
+# Overall Probability of Geometric Inputs
+overall_prob: 1
+# Dropout Probability of Geometric Inputs (for each sample across batch size and number of views)
+dropout_prob: 0
+# Probability of Geometric Inputs with Ray Directions
+ray_dirs_prob: 1
+# Probability of Geometric Inputs with Depths
+depth_prob: 1
+# Probability of Geometric Inputs with Camera Poses
+cam_prob: 1
+# Probability of sparsely sampling the high quality gt depth
+sparse_depth_prob: 1
+# Percentage of the valid depth to remove if the probability of using sparse depth is greater than 0 (Range: [0, 1])
+sparsification_removal_percent: 0.9
+# Probability for skipping input of the metric scale norm factor for the input metric high quality gt depth
+# If 0, the metric scale norm factor will be provided as input to the model for all the metric scale conditionings
+depth_scale_norm_all_prob: 1
+# Probability for skipping input of the metric scale norm factor for the input metric pose
+# If 0, the metric scale norm factor will be provided as input to the model for all the metric scale conditionings
+pose_scale_norm_all_prob: 1
diff --git a/configs/model/task/pass_through.yaml b/configs/model/task/pass_through.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..97aa6245875bbd587019d73c16edf169e489d425
--- /dev/null
+++ b/configs/model/task/pass_through.yaml
@@ -0,0 +1,23 @@
+defaults:
+ - default
+
+# Overall Probability of Geometric Inputs
+overall_prob: 1
+# Dropout Probability of Geometric Inputs (for each sample across batch size and number of views)
+dropout_prob: 0
+# Probability of Geometric Inputs with Ray Directions
+ray_dirs_prob: 1
+# Probability of Geometric Inputs with Depths
+depth_prob: 1
+# Probability of Geometric Inputs with Camera Poses
+cam_prob: 1
+# Probability of sparsely sampling the high quality gt depth
+sparse_depth_prob: 0
+# Percentage of the valid depth to remove if the probability of using sparse depth is greater than 0 (Range: [0, 1])
+sparsification_removal_percent: 0
+# Probability for skipping input of the metric scale norm factor for the input metric high quality gt depth
+# If 0, the metric scale norm factor will be provided as input to the model for all the metric scale conditionings
+depth_scale_norm_all_prob: 0
+# Probability for skipping input of the metric scale norm factor for the input metric pose
+# If 0, the metric scale norm factor will be provided as input to the model for all the metric scale conditionings
+pose_scale_norm_all_prob: 0
diff --git a/configs/model/task/posed_sfm.yaml b/configs/model/task/posed_sfm.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..7c86724a0d08256b91c74ecc625ed46f2a700467
--- /dev/null
+++ b/configs/model/task/posed_sfm.yaml
@@ -0,0 +1,23 @@
+defaults:
+ - default
+
+# Overall Probability of Geometric Inputs
+overall_prob: 1
+# Dropout Probability of Geometric Inputs (for each sample across batch size and number of views)
+dropout_prob: 0
+# Probability of Geometric Inputs with Ray Directions
+ray_dirs_prob: 0
+# Probability of Geometric Inputs with Depths
+depth_prob: 0
+# Probability of Geometric Inputs with Camera Poses
+cam_prob: 1
+# Probability of sparsely sampling the high quality gt depth
+sparse_depth_prob: 0
+# Percentage of the valid depth to remove if the probability of using sparse depth is greater than 0 (Range: [0, 1])
+sparsification_removal_percent: 0
+# Probability for skipping input of the metric scale norm factor for the input metric high quality gt depth
+# If 0, the metric scale norm factor will be provided as input to the model for all the metric scale conditionings
+depth_scale_norm_all_prob: 0
+# Probability for skipping input of the metric scale norm factor for the input metric pose
+# If 0, the metric scale norm factor will be provided as input to the model for all the metric scale conditionings
+pose_scale_norm_all_prob: 0
diff --git a/configs/model/task/posed_sfm_non_metric.yaml b/configs/model/task/posed_sfm_non_metric.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..62b540e366656503f984ed7053c453ad7bd1386c
--- /dev/null
+++ b/configs/model/task/posed_sfm_non_metric.yaml
@@ -0,0 +1,23 @@
+defaults:
+ - default
+
+# Overall Probability of Geometric Inputs
+overall_prob: 1
+# Dropout Probability of Geometric Inputs (for each sample across batch size and number of views)
+dropout_prob: 0
+# Probability of Geometric Inputs with Ray Directions
+ray_dirs_prob: 0
+# Probability of Geometric Inputs with Depths
+depth_prob: 0
+# Probability of Geometric Inputs with Camera Poses
+cam_prob: 1
+# Probability of sparsely sampling the high quality gt depth
+sparse_depth_prob: 0
+# Percentage of the valid depth to remove if the probability of using sparse depth is greater than 0 (Range: [0, 1])
+sparsification_removal_percent: 0
+# Probability for skipping input of the metric scale norm factor for the input metric high quality gt depth
+# If 0, the metric scale norm factor will be provided as input to the model for all the metric scale conditionings
+depth_scale_norm_all_prob: 0
+# Probability for skipping input of the metric scale norm factor for the input metric pose
+# If 0, the metric scale norm factor will be provided as input to the model for all the metric scale conditionings
+pose_scale_norm_all_prob: 1
diff --git a/configs/model/task/registration.yaml b/configs/model/task/registration.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..e1f7fc206dc81b91aa27f99d304f4386ffb705bb
--- /dev/null
+++ b/configs/model/task/registration.yaml
@@ -0,0 +1,23 @@
+defaults:
+ - default
+
+# Overall Probability of Geometric Inputs
+overall_prob: 1
+# Dropout Probability of Geometric Inputs (for each sample across batch size and number of views)
+dropout_prob: 0
+# Probability of Geometric Inputs with Ray Directions
+ray_dirs_prob: 1
+# Probability of Geometric Inputs with Depths
+depth_prob: 1
+# Probability of Geometric Inputs with Camera Poses
+cam_prob: 0
+# Probability of sparsely sampling the high quality gt depth
+sparse_depth_prob: 0
+# Percentage of the valid depth to remove if the probability of using sparse depth is greater than 0 (Range: [0, 1])
+sparsification_removal_percent: 0
+# Probability for skipping input of the metric scale norm factor for the input metric high quality gt depth
+# If 0, the metric scale norm factor will be provided as input to the model for all the metric scale conditionings
+depth_scale_norm_all_prob: 0
+# Probability for skipping input of the metric scale norm factor for the input metric pose
+# If 0, the metric scale norm factor will be provided as input to the model for all the metric scale conditionings
+pose_scale_norm_all_prob: 0
diff --git a/configs/model/task/registration_sparse.yaml b/configs/model/task/registration_sparse.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..548d43b39ace8f8f70b81e98cc9fd46144a43bc6
--- /dev/null
+++ b/configs/model/task/registration_sparse.yaml
@@ -0,0 +1,23 @@
+defaults:
+ - default
+
+# Overall Probability of Geometric Inputs
+overall_prob: 1
+# Dropout Probability of Geometric Inputs (for each sample across batch size and number of views)
+dropout_prob: 0
+# Probability of Geometric Inputs with Ray Directions
+ray_dirs_prob: 1
+# Probability of Geometric Inputs with Depths
+depth_prob: 1
+# Probability of Geometric Inputs with Camera Poses
+cam_prob: 0
+# Probability of sparsely sampling the high quality gt depth
+sparse_depth_prob: 1
+# Percentage of the valid depth to remove if the probability of using sparse depth is greater than 0 (Range: [0, 1])
+sparsification_removal_percent: 0.9
+# Probability for skipping input of the metric scale norm factor for the input metric high quality gt depth
+# If 0, the metric scale norm factor will be provided as input to the model for all the metric scale conditionings
+depth_scale_norm_all_prob: 0
+# Probability for skipping input of the metric scale norm factor for the input metric pose
+# If 0, the metric scale norm factor will be provided as input to the model for all the metric scale conditionings
+pose_scale_norm_all_prob: 0
diff --git a/configs/model/task/registration_training.yaml b/configs/model/task/registration_training.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..0676132925de37e46560550a917a83a777662beb
--- /dev/null
+++ b/configs/model/task/registration_training.yaml
@@ -0,0 +1,23 @@
+defaults:
+ - default
+
+# Overall Probability of Geometric Inputs
+overall_prob: 1
+# Dropout Probability of Geometric Inputs (for each sample across batch size and number of views)
+dropout_prob: 0
+# Probability of Geometric Inputs with Ray Directions
+ray_dirs_prob: 1
+# Probability of Geometric Inputs with Depths
+depth_prob: 1
+# Probability of Geometric Inputs with Camera Poses
+cam_prob: 0
+# Probability of sparsely sampling the high quality gt depth
+sparse_depth_prob: 0.5
+# Percentage of the valid depth to remove if the probability of using sparse depth is greater than 0 (Range: [0, 1])
+sparsification_removal_percent: 0.9
+# Probability for skipping input of the metric scale quantities for the input metric high quality gt depth
+# If 0, the metric scale quantities will be provided as input to the model for all the metric scale conditionings
+depth_scale_norm_all_prob: 0.05
+# Probability for skipping input of the metric scale quantities for the input metric pose
+# If 0, the metric scale quantities will be provided as input to the model for all the metric scale conditionings
+pose_scale_norm_all_prob: 0
diff --git a/configs/model/vggt.yaml b/configs/model/vggt.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..1911b3979b708da810046b5bdd59fa6addb25368
--- /dev/null
+++ b/configs/model/vggt.yaml
@@ -0,0 +1,17 @@
+# String for model factory
+model_str: "vggt"
+# Model config
+model_config:
+ name: "vggt"
+ # Load pre-trained weights
+ load_pretrained_weights: true
+ # Load custom checkpoint
+ load_custom_ckpt: false
+ # Custom checkpoint path
+ custom_ckpt_path: null
+# Image Normalization Type
+data_norm_type: "identity"
+# VGGT checkpoint is already loaded in the inference wrapper
+pretrained: null
+# Torch hub force reload
+torch_hub_force_reload: false
diff --git a/configs/model/vggt_commercial.yaml b/configs/model/vggt_commercial.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..2a5dea00b52828f553daa3502f9474ed378d1192
--- /dev/null
+++ b/configs/model/vggt_commercial.yaml
@@ -0,0 +1,17 @@
+# String for model factory
+model_str: "vggt"
+# Model config
+model_config:
+ name: "vggt"
+ # Load pre-trained weights
+ load_pretrained_weights: true
+ # Load custom checkpoint
+ load_custom_ckpt: true
+ # Custom checkpoint path
+ custom_ckpt_path: "${root_pretrained_checkpoints_dir}/vggt_1B_commercial.pt"
+# Image Normalization Type
+data_norm_type: "identity"
+# VGGT checkpoint is already loaded in the inference wrapper
+pretrained: null
+# Torch hub force reload
+torch_hub_force_reload: false
diff --git a/configs/model/vggt_non_pretrained.yaml b/configs/model/vggt_non_pretrained.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..a67a996aa754c7fe3ec53fa66bdad45fdc3db736
--- /dev/null
+++ b/configs/model/vggt_non_pretrained.yaml
@@ -0,0 +1,17 @@
+# String for model factory
+model_str: "vggt"
+# Model config
+model_config:
+ name: "vggt"
+ # Load pre-trained weights
+ load_pretrained_weights: false
+ # Load custom checkpoint
+ load_custom_ckpt: false
+ # Custom checkpoint path
+ custom_ckpt_path: null
+# Image Normalization Type
+data_norm_type: "identity"
+# VGGT checkpoint is already loaded in the inference wrapper
+pretrained: null
+# Torch hub force reload
+torch_hub_force_reload: false
diff --git a/configs/rmvd_benchmark.yaml b/configs/rmvd_benchmark.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..b4042f433e9eabcf9fde4221d81c443495af8a64
--- /dev/null
+++ b/configs/rmvd_benchmark.yaml
@@ -0,0 +1,34 @@
+defaults:
+ - machine: aws
+ - model: default
+ - dataset: default
+ - _self_
+
+# Path Settings
+output_dir: ${hydra:run.dir}
+root_data_dir: ${machine.root_data_dir}
+mapanything_dataset_metadata_dir: ${machine.mapanything_dataset_metadata_dir}
+root_pretrained_checkpoints_dir: ${machine.root_pretrained_checkpoints_dir}
+root_experiments_dir: ${machine.root_experiments_dir}
+root_uniception_pretrained_checkpoints_dir: ${machine.root_uniception_pretrained_checkpoints_dir}
+
+### Benchmarking args
+seed: 0
+# Disable CUDNN Benchmark (Disable for variable resolution & number of view training)
+disable_cudnn_benchmark: true
+# Batch size for inference (Metrics are computed per multi-view set and averaged, not per batch of multi-view sets)
+batch_size: 10
+# Use mixed precision for inference
+amp: 1
+# Floating point type to use for mixed precision
+amp_dtype: "bf16"
+# Choose from eth3d, kitti, scannet
+eval_dataset: eth3d
+# Choose from img, img+intrinsics, img+intrinsics+pose
+evaluation_conditioning: img
+# Choose from "median", "none"
+evaluation_alignment: median
+# Choose from "multi_view", "single_view"
+evaluation_views: multi_view
+# Resolution to inference the selected model.
+evaluation_resolution: ${dataset.resolution_options.518_1_33_ar}
diff --git a/configs/train.yaml b/configs/train.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..a8a4ad870c3a0fb88eb8671b95fd77be254c0931
--- /dev/null
+++ b/configs/train.yaml
@@ -0,0 +1,15 @@
+defaults:
+ - machine: aws
+ - model: default
+ - dataset: default
+ - loss: default
+ - train_params: default
+ - distributed: default
+ - _self_
+
+output_dir: ${hydra:run.dir}
+root_data_dir: ${machine.root_data_dir}
+mapanything_dataset_metadata_dir: ${machine.mapanything_dataset_metadata_dir}
+root_pretrained_checkpoints_dir: ${machine.root_pretrained_checkpoints_dir}
+root_experiments_dir: ${machine.root_experiments_dir}
+root_uniception_pretrained_checkpoints_dir: ${machine.root_uniception_pretrained_checkpoints_dir}
diff --git a/configs/train_params/default.yaml b/configs/train_params/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..d059c9d5ee6d4d0feb69ef6a5d3cef2014a037e1
--- /dev/null
+++ b/configs/train_params/default.yaml
@@ -0,0 +1,41 @@
+# Random Seed
+seed: 0
+# Maximum number of images per GPU (changes based on available GPU memory)
+max_num_of_imgs_per_gpu: 48
+# Accumulate gradient iterations (for increasing the effective batch size under memory constraints)
+accum_iter: 1
+# Maximum number of epochs for the scheduler
+epochs: 100
+## Default Optimizer parameters
+# Learning rate (absolute lr)
+lr: 0.0001
+# Lower lr bound for cyclic schedulers that hit 0
+min_lr: 1e-06
+# Epochs to warmup LR
+warmup_epochs: 10
+# Weight decay
+weight_decay: 0.05
+# LR schedule type
+schedule_type: "linear_warmup_half_cycle_cosine_decay"
+# Warn if model params are not in the below submodule_configs
+warn_not_in_submodule: False
+# Optimizer parameters specific to submodules
+submodule_configs: {}
+# Use Automatic Mixed Precision for pretraining
+amp: 1
+# Floating point type to use for mixed precision training
+amp_dtype: "bf16"
+# Disable CUDNN Benchmark (Disable for variable resolution & number of view training)
+disable_cudnn_benchmark: true
+# Freeze the validation samples across all epochs
+freeze_val_samples_across_all_epochs: true
+# Test loss evaluation frequency
+eval_freq: 1
+# Frequency (number of epochs) to save checkpoint in checkpoint-last.pth
+save_freq: 1
+# Frequency (number of epochs) to save checkpoint in checkpoint-%d.pth
+keep_freq: 10
+# Frequence (number of iterations) to print infos while training (includes tensorboard logging)
+print_freq: 20
+# Resume Training from last checkpoint
+resume: True
diff --git a/configs/train_params/finetune_with_lower_encoder_lr.yaml b/configs/train_params/finetune_with_lower_encoder_lr.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..1dda66b8c03670e8cdbee34aa0f1f0f0fb44d29a
--- /dev/null
+++ b/configs/train_params/finetune_with_lower_encoder_lr.yaml
@@ -0,0 +1,16 @@
+defaults:
+ - default
+
+# Use 20x lower lr for finetuning
+lr: 5e-06
+min_lr: 5e-08
+
+# Optimizer parameters specific to submodules
+submodule_configs:
+ # Encoder
+ encoder:
+ lr: 2.5e-07
+ min_lr: 2.5e-09
+ warmup_epochs: ${train_params.warmup_epochs}
+ weight_decay: ${train_params.weight_decay}
+ schedule_type: ${train_params.schedule_type}
diff --git a/configs/train_params/finetune_with_lower_encoder_lr_64g.yaml b/configs/train_params/finetune_with_lower_encoder_lr_64g.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..f593dd3a43c6808bd65d5d34b8c0c5ad3253d28e
--- /dev/null
+++ b/configs/train_params/finetune_with_lower_encoder_lr_64g.yaml
@@ -0,0 +1,16 @@
+defaults:
+ - default
+
+# Use 20x lower lr for finetuning
+lr: 1e-05
+min_lr: 1e-07
+
+# Optimizer parameters specific to submodules
+submodule_configs:
+ # Encoder
+ encoder:
+ lr: 5e-07
+ min_lr: 5e-09
+ warmup_epochs: ${train_params.warmup_epochs}
+ weight_decay: ${train_params.weight_decay}
+ schedule_type: ${train_params.schedule_type}
diff --git a/configs/train_params/freeze_encoder.yaml b/configs/train_params/freeze_encoder.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..25e7508544f98564678bd863db491de856e010e4
--- /dev/null
+++ b/configs/train_params/freeze_encoder.yaml
@@ -0,0 +1,8 @@
+defaults:
+ - default
+
+# Optimizer parameters specific to submodules
+submodule_configs:
+ # Encoder
+ encoder:
+ lr: 0
diff --git a/configs/train_params/lower_encoder_lr.yaml b/configs/train_params/lower_encoder_lr.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..4a1ae06dae85f31174b5103c056eb4d3a163e2a8
--- /dev/null
+++ b/configs/train_params/lower_encoder_lr.yaml
@@ -0,0 +1,12 @@
+defaults:
+ - default
+
+# Optimizer parameters specific to submodules
+submodule_configs:
+ # Encoder
+ encoder:
+ lr: 5e-06
+ min_lr: 5e-08
+ warmup_epochs: ${train_params.warmup_epochs}
+ weight_decay: ${train_params.weight_decay}
+ schedule_type: ${train_params.schedule_type}
diff --git a/configs/train_params/lower_encoder_lr_64g.yaml b/configs/train_params/lower_encoder_lr_64g.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..4364232c4d65b6117a0ad25adace93f2444af002
--- /dev/null
+++ b/configs/train_params/lower_encoder_lr_64g.yaml
@@ -0,0 +1,16 @@
+defaults:
+ - default
+
+# Use 2x higher lr for 8x higher effective batch size
+lr: 2e-04
+min_lr: 2e-07
+
+# Optimizer parameters specific to submodules
+submodule_configs:
+ # Encoder
+ encoder:
+ lr: 1e-05
+ min_lr: 1e-08
+ warmup_epochs: ${train_params.warmup_epochs}
+ weight_decay: ${train_params.weight_decay}
+ schedule_type: ${train_params.schedule_type}
diff --git a/configs/train_params/vggt_finetune.yaml b/configs/train_params/vggt_finetune.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..d130d4f12ad4d8e2de2165bfeb058addd0b668a8
--- /dev/null
+++ b/configs/train_params/vggt_finetune.yaml
@@ -0,0 +1,16 @@
+defaults:
+ - default
+
+# Use 10x lower lr for finetuning
+lr: 1e-05
+min_lr: 1e-07
+
+# Optimizer parameters specific to submodules
+submodule_configs:
+ # DINOv2
+ model.aggregator.patch_embed:
+ lr: 5e-07
+ min_lr: 5e-09
+ warmup_epochs: ${train_params.warmup_epochs}
+ weight_decay: ${train_params.weight_decay}
+ schedule_type: ${train_params.schedule_type}
diff --git a/examples/basketball/img_0.jpg b/examples/basketball/img_0.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..71fe099baa36135d1cdf6398012ed6478f119792
--- /dev/null
+++ b/examples/basketball/img_0.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:db1d9f3f776373de1c0509a3bea4a0d52c00b28ef5893e16a4993a9565125805
+size 963893
diff --git a/examples/basketball/img_1.jpg b/examples/basketball/img_1.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..ab97ba77ceffb8f25e42febad13953c1e4ee28cb
--- /dev/null
+++ b/examples/basketball/img_1.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:83678bd8cd6df7ea9c778ce3c3c3b87970e1368483a333fd71555d3a17725215
+size 942240
diff --git a/examples/desk/530554609_3367433673396747_2161028887770608277_n.jpg b/examples/desk/530554609_3367433673396747_2161028887770608277_n.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..860d48bfb94775f745555f09626f5953bff017d5
--- /dev/null
+++ b/examples/desk/530554609_3367433673396747_2161028887770608277_n.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:08ddc149e61544d32e6c1455a34de10a6eecdd92422e0ae2a57d6f251b33c215
+size 2457571
diff --git a/examples/desk/532328457_1311198870420578_2167456836351167380_n.jpg b/examples/desk/532328457_1311198870420578_2167456836351167380_n.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..229b44dd39409d5eebc7c1bafbab441077814e84
--- /dev/null
+++ b/examples/desk/532328457_1311198870420578_2167456836351167380_n.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a186834d16c2ac2b4ea2340024b5bc19e4b3316f814e73ccd17bfbc271789b20
+size 2573673
diff --git a/examples/dino/528883410_1456464302336597_4114529568612559572_n.jpg b/examples/dino/528883410_1456464302336597_4114529568612559572_n.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..133aaef7fc0ed74e903828b1d833f91d1dc0dc1a
--- /dev/null
+++ b/examples/dino/528883410_1456464302336597_4114529568612559572_n.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:12d801d8a1740147d94cf270eb5cd3d7df0e582377024ab8a2f788fa6662212d
+size 2076363
diff --git a/examples/dino/530182709_1122456693282934_3373468492106282632_n.jpg b/examples/dino/530182709_1122456693282934_3373468492106282632_n.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..38a7b85bbdb4fac006ac6b1991593e085c211e42
--- /dev/null
+++ b/examples/dino/530182709_1122456693282934_3373468492106282632_n.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:41d2d932869ed402248dc414c2392b1f102a72714016f4911cdc6f0c04485089
+size 3035800
diff --git a/examples/dino/532847807_1055021109949229_8315548832183031452_n.jpg b/examples/dino/532847807_1055021109949229_8315548832183031452_n.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..8b4125aa34202a8a37fa36f6a762cccd06f8dd56
--- /dev/null
+++ b/examples/dino/532847807_1055021109949229_8315548832183031452_n.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:308a101a66665a7bef9c0727e0125d7c01a1245a81d6dd69c839eb1d442e22e4
+size 3369507
diff --git a/examples/grindelwald/img_1.jpg b/examples/grindelwald/img_1.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..584d6505b7ebe67ef5405b34991c80fc08b53160
--- /dev/null
+++ b/examples/grindelwald/img_1.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1b801531b8d5b9558b0b1549c6ef823540fabfb98cbcbc705a72f075a72f46c2
+size 3268789
diff --git a/examples/grindelwald/img_2.jpg b/examples/grindelwald/img_2.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..c96f3c110c64c9b2f34e4b036172f80f5d8889bd
--- /dev/null
+++ b/examples/grindelwald/img_2.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4eec8b1558c18697787366e69fcefc7a1695513f62909116364963685767f67c
+size 4607606
diff --git a/examples/grindelwald/img_3.jpg b/examples/grindelwald/img_3.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..5328c2b840371418cfa5e873c23167fc92beb361
--- /dev/null
+++ b/examples/grindelwald/img_3.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6dac8aaf2140b5772566a19c2290161f25832f5c066bc5a7c3ca61f9e8ce16c4
+size 3245650
diff --git a/examples/lion_mirror/480388255_654178757132498_624999253317032198_n.jpg b/examples/lion_mirror/480388255_654178757132498_624999253317032198_n.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..b4aff44cd16683ef47f0095702829f0f00871618
--- /dev/null
+++ b/examples/lion_mirror/480388255_654178757132498_624999253317032198_n.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:76c2b822d93efb3cd6ac5ea8e90395f1010aad6676e76c92d921f926f90517b4
+size 4701684
diff --git a/examples/lion_mirror/481118149_1175121204292619_2597291685520277061_n.jpg b/examples/lion_mirror/481118149_1175121204292619_2597291685520277061_n.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..bcc57e266c17770b3465b09805f980ac57d4b6ac
--- /dev/null
+++ b/examples/lion_mirror/481118149_1175121204292619_2597291685520277061_n.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1e388a80cd978dcb6c363806c6390a934f584e73fe1e5bb8c6d78e35d62409b3
+size 5160182
diff --git a/examples/lion_mirror/481163520_1685649178709736_797454981910081980_n.jpg b/examples/lion_mirror/481163520_1685649178709736_797454981910081980_n.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..1ce8d34f85130283a4c05f25b78654eda3287b2e
--- /dev/null
+++ b/examples/lion_mirror/481163520_1685649178709736_797454981910081980_n.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5238da3e47e2ba76513bc078c0f3bd4e09ac1669c775d4014da31326b6e47c22
+size 4768419
diff --git a/examples/lion_mirror/482451618_4385713848322602_4193708748704841166_n.jpg b/examples/lion_mirror/482451618_4385713848322602_4193708748704841166_n.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..628051cc41d1e918edb674f3ac43b1d6cdda42bb
--- /dev/null
+++ b/examples/lion_mirror/482451618_4385713848322602_4193708748704841166_n.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:450266e8fd34147a41d23853b3d0cead7f438dcd981a9a1d569f423b7639a1a9
+size 4912917
diff --git a/examples/llama/530552299_1350142206628921_7256209652527343353_n.jpg b/examples/llama/530552299_1350142206628921_7256209652527343353_n.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..b0cfcb24f32123dc82622268f114816c976a1453
--- /dev/null
+++ b/examples/llama/530552299_1350142206628921_7256209652527343353_n.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:867e7d0a2d2bd3911fe76180bb07ae0b26fac6de9775c04f4d9a55962123a88e
+size 3876746
diff --git a/examples/llama/532531242_1263522274894680_7977300346885266196_n.jpg b/examples/llama/532531242_1263522274894680_7977300346885266196_n.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..b08ad0beb0d3176b64fa8eea69de50026f7ef9a1
--- /dev/null
+++ b/examples/llama/532531242_1263522274894680_7977300346885266196_n.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:35f4e2081af140777ff71503a14d24522bbe59de3f9ccb4ccdd6fa1d41fc0552
+size 3636302
diff --git a/examples/museum/00027.jpg b/examples/museum/00027.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..72cf086219601609dea71947f0e88d762c6692bb
--- /dev/null
+++ b/examples/museum/00027.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:87596ce4acb9069f15e13091002bca0f1cba47617d8f5d5c8f210221a904d50d
+size 652223
diff --git a/examples/museum/00028.jpg b/examples/museum/00028.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..cffb57f6d1ed553b055617ac65f08e794c4b82a0
--- /dev/null
+++ b/examples/museum/00028.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:21ece45312fa5fd519b87d3008a51628b92bf09ffab36d8f2502eb39176df743
+size 674259
diff --git a/examples/museum/00029.jpg b/examples/museum/00029.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..11f2831133609bd4a74499f359513a133916172d
--- /dev/null
+++ b/examples/museum/00029.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e73811433866dc010c126719396429a151270a7ec0cded797a54120f115b3582
+size 643434
diff --git a/examples/night_temple/480811030_2684126631974802_260568564136361360_n.jpg b/examples/night_temple/480811030_2684126631974802_260568564136361360_n.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..4d9bed7e9c40387b56b60cbdc77fc53f0751573e
--- /dev/null
+++ b/examples/night_temple/480811030_2684126631974802_260568564136361360_n.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f718ca85a2772943f044534d405b50837f2a76156da384f9dae068e01001f132
+size 2461264
diff --git a/examples/night_temple/482405589_3949803958672748_8316540969946364052_n.jpg b/examples/night_temple/482405589_3949803958672748_8316540969946364052_n.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..0feb163182f29329fde8814902c66f9bf5a075db
--- /dev/null
+++ b/examples/night_temple/482405589_3949803958672748_8316540969946364052_n.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4b0d5e6e0a52f7020b208aaac9882c2216b72ea2fa6f3e312fa53b995e55d642
+size 2908582
diff --git a/examples/painting/oil.png b/examples/painting/oil.png
new file mode 100644
index 0000000000000000000000000000000000000000..ac7202311214b9de61e5aba2f1785d6195232050
--- /dev/null
+++ b/examples/painting/oil.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9191a3f73f37ff811231aef529f3e8306e1200fea9faa26b7e87cf3799b499cf
+size 4398004
diff --git a/examples/scenic/001.jpg b/examples/scenic/001.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..b2ae42437e8ef9d9557e9cc9714ddb4d7d6a322a
--- /dev/null
+++ b/examples/scenic/001.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9fde736637e4c3f2cc01456c71b63a1540f0f82979134985f5bf2567c2400e6d
+size 301432
diff --git a/examples/scenic/002.jpg b/examples/scenic/002.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..f948ec6b6ba7825585e30fcb103e97f0b097254f
--- /dev/null
+++ b/examples/scenic/002.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7c2ac14d1aac3c12c97489c6dd41149650867942b227ef17389f4a2da7279fa9
+size 298193
diff --git a/examples/scenic_mono/533229144_768752402788769_5743558040836355722_n.jpg b/examples/scenic_mono/533229144_768752402788769_5743558040836355722_n.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..df91421e1e6aa57bc1eb66c7366e7c837336068e
--- /dev/null
+++ b/examples/scenic_mono/533229144_768752402788769_5743558040836355722_n.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b09482a2b7e18d272205988732aa1ec7d74c26f7d523961457ae05edc3909bbb
+size 274384
diff --git a/examples/wai_logo/wai_logo.png b/examples/wai_logo/wai_logo.png
new file mode 100644
index 0000000000000000000000000000000000000000..312018467d59569671614e80a1422840b6ad8858
--- /dev/null
+++ b/examples/wai_logo/wai_logo.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c1284af033fd4b75b3f17598177e77806c3c8d96a5e67f723bf97df395f5a0b7
+size 183073
diff --git a/hf_utils/css_and_html.py b/hf_utils/css_and_html.py
new file mode 100644
index 0000000000000000000000000000000000000000..f1acd196ba61c21ed7aa6ac2075d21ea15d428e3
--- /dev/null
+++ b/hf_utils/css_and_html.py
@@ -0,0 +1,238 @@
+"""
+CSS and HTML content for the MapAnything Gradio application.
+This module contains all the CSS styles and HTML content blocks
+used in the Gradio interface.
+"""
+
+# CSS Styles for the Gradio interface
+GRADIO_CSS = """
+.custom-log * {
+ font-style: italic;
+ font-size: 22px !important;
+ background-image: linear-gradient(120deg, #ffb366 0%, #ffa366 60%, #ff9966 100%);
+ -webkit-background-clip: text;
+ background-clip: text;
+ font-weight: bold !important;
+ color: transparent !important;
+ text-align: center !important;
+}
+
+.example-log * {
+ font-style: italic;
+ font-size: 16px !important;
+ background-image: linear-gradient(120deg, #ffb366 0%, #ffa366 60%, #ff9966 100%);
+ -webkit-background-clip: text;
+ background-clip: text;
+ color: transparent !important;
+}
+
+#my_radio .wrap {
+ display: flex;
+ flex-wrap: nowrap;
+ justify-content: center;
+ align-items: center;
+}
+
+#my_radio .wrap label {
+ display: flex;
+ width: 50%;
+ justify-content: center;
+ align-items: center;
+ margin: 0;
+ padding: 10px 0;
+ box-sizing: border-box;
+}
+
+/* Align navigation buttons with dropdown bottom */
+.navigation-row {
+ display: flex !important;
+ align-items: flex-end !important;
+ gap: 8px !important;
+}
+
+.navigation-row > div:nth-child(1),
+.navigation-row > div:nth-child(3) {
+ align-self: flex-end !important;
+}
+
+.navigation-row > div:nth-child(2) {
+ flex: 1 !important;
+}
+
+/* Make thumbnails clickable with pointer cursor */
+.clickable-thumbnail img {
+ cursor: pointer !important;
+}
+
+.clickable-thumbnail:hover img {
+ cursor: pointer !important;
+ opacity: 0.8;
+ transition: opacity 0.3s ease;
+}
+
+/* Make thumbnail containers narrower horizontally */
+.clickable-thumbnail {
+ padding: 5px 2px !important;
+ margin: 0 2px !important;
+}
+
+.clickable-thumbnail .image-container {
+ margin: 0 !important;
+ padding: 0 !important;
+}
+
+.scene-info {
+ text-align: center !important;
+ padding: 5px 2px !important;
+ margin: 0 !important;
+}
+"""
+
+
+def get_header_html(logo_base64=None):
+ """
+ Generate the main header HTML with logo and title.
+
+ Args:
+ logo_base64 (str, optional): Base64 encoded logo image
+
+ Returns:
+ str: HTML string for the header
+ """
+ logo_style = "display: none;" if not logo_base64 else ""
+ logo_src = logo_base64 or ""
+
+ return f"""
+
+

+
MapAnything: 3D Scene Reconstruction
+
+
+ 🌟 GitHub Repository |
+ 🚀 Project Page
+
+ """
+
+
+def get_description_html():
+ """
+ Generate the main description and getting started HTML.
+
+ Returns:
+ str: HTML string for the description
+ """
+ return """
+
+
Upload a video or a set of images to create a 3D reconstruction of a scene or object. MapAnything takes these images and generates 3D point clouds directly from multi-view images.
+
This demo demonstrates the image input configuration only. For other input configuration, please check out the code in our Github repo.
+
+
Getting Started:
+
+ - Upload Your Data: Use the "Upload Video" or "Upload Images" buttons on the left to provide your input. Videos will be automatically split into individual frames (one frame per second).
+ - Preview: Your uploaded images will appear in the gallery on the left.
+ - Reconstruct: Click the "Reconstruct" button to start the 3D reconstruction process.
+ - Visualize: The 3D reconstruction will appear in the viewer on the right. You can rotate, pan, and zoom to explore the model, and download the GLB file. Note the visualization of 3D points may be slow for a large number of input images.
+ -
+ Adjust Visualization (Optional):
+ After reconstruction, you can fine-tune the visualization using the options below
+
+ (click to expand):
+
+ - Confidence Threshold: Adjust the filtering of depth/normals based on confidence.
+ - Show Points from Frame: Select specific frames to display in the point cloud.
+ - Show Camera: Toggle the display of estimated camera positions.
+ - Filter Sky / Filter Black Background: Remove sky or black-background points.
+
+
+
+
+
Please note: Our model itself usually only needs less than 1 second to reconstruct a scene. However, visualizing 3D points may take tens of seconds due to third-party rendering, which is independent of MapAnything's processing time. Please be patient or, for faster visualization, use a local machine to run our demo from our GitHub repository.
+
+ """
+
+
+def get_acknowledgements_html():
+ """
+ Generate the acknowledgements section HTML.
+
+ Returns:
+ str: HTML string for the acknowledgements
+ """
+ return """
+
+
+
Acknowledgements
+
This site builds upon code from:
+
+
We extend our gratitude to these projects for their valuable contributions to the research community.
+
+ """
+
+
+def get_gradio_theme():
+ """
+ Get the configured Gradio theme.
+
+ Returns:
+ gr.themes.Base: Configured Gradio theme
+ """
+ import gradio as gr
+
+ return gr.themes.Base(
+ primary_hue=gr.themes.Color(
+ c100="#ffedd5",
+ c200="#ffddb3",
+ c300="rgba(242.78125, 182.89427563548466, 120.32579495614034, 1)",
+ c400="#fb923c",
+ c50="#fff7ed",
+ c500="#f97316",
+ c600="#ea580c",
+ c700="#c2410c",
+ c800="#9a3412",
+ c900="#7c2d12",
+ c950="#6c2e12",
+ ),
+ secondary_hue="amber",
+ )
+
+
+# Example scene thumbnail grid CSS (if needed separately)
+THUMBNAIL_CSS = """
+/* Make thumbnails clickable with pointer cursor */
+.clickable-thumbnail img {
+ cursor: pointer !important;
+}
+
+.clickable-thumbnail:hover img {
+ cursor: pointer !important;
+ opacity: 0.8;
+ transition: opacity 0.3s ease;
+}
+
+/* Make thumbnail containers narrower horizontally */
+.clickable-thumbnail {
+ padding: 5px 2px !important;
+ margin: 0 2px !important;
+}
+
+.clickable-thumbnail .image-container {
+ margin: 0 !important;
+ padding: 0 !important;
+}
+
+.scene-info {
+ text-align: center !important;
+ padding: 5px 2px !important;
+ margin: 0 !important;
+}
+"""
+
+
+# Measure tab instructions HTML
+MEASURE_INSTRUCTIONS_HTML = """
+### Click on the image to measure the distance between two points.
+"""
diff --git a/hf_utils/vgg_geometry.py b/hf_utils/vgg_geometry.py
new file mode 100644
index 0000000000000000000000000000000000000000..8ebd25dbc6cac6b0095956524c4f0628410dd5cb
--- /dev/null
+++ b/hf_utils/vgg_geometry.py
@@ -0,0 +1,166 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import os
+import torch
+import numpy as np
+
+
+def unproject_depth_map_to_point_map(
+ depth_map: np.ndarray, extrinsics_cam: np.ndarray, intrinsics_cam: np.ndarray
+) -> np.ndarray:
+ """
+ Unproject a batch of depth maps to 3D world coordinates.
+
+ Args:
+ depth_map (np.ndarray): Batch of depth maps of shape (S, H, W, 1) or (S, H, W)
+ extrinsics_cam (np.ndarray): Batch of camera extrinsic matrices of shape (S, 3, 4)
+ intrinsics_cam (np.ndarray): Batch of camera intrinsic matrices of shape (S, 3, 3)
+
+ Returns:
+ np.ndarray: Batch of 3D world coordinates of shape (S, H, W, 3)
+ """
+ if isinstance(depth_map, torch.Tensor):
+ depth_map = depth_map.cpu().numpy()
+ if isinstance(extrinsics_cam, torch.Tensor):
+ extrinsics_cam = extrinsics_cam.cpu().numpy()
+ if isinstance(intrinsics_cam, torch.Tensor):
+ intrinsics_cam = intrinsics_cam.cpu().numpy()
+
+ world_points_list = []
+ for frame_idx in range(depth_map.shape[0]):
+ cur_world_points, _, _ = depth_to_world_coords_points(
+ depth_map[frame_idx].squeeze(-1), extrinsics_cam[frame_idx], intrinsics_cam[frame_idx]
+ )
+ world_points_list.append(cur_world_points)
+ world_points_array = np.stack(world_points_list, axis=0)
+
+ return world_points_array
+
+
+def depth_to_world_coords_points(
+ depth_map: np.ndarray,
+ extrinsic: np.ndarray,
+ intrinsic: np.ndarray,
+ eps=1e-8,
+) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
+ """
+ Convert a depth map to world coordinates.
+
+ Args:
+ depth_map (np.ndarray): Depth map of shape (H, W).
+ intrinsic (np.ndarray): Camera intrinsic matrix of shape (3, 3).
+ extrinsic (np.ndarray): Camera extrinsic matrix of shape (3, 4). OpenCV camera coordinate convention, cam from world.
+
+ Returns:
+ tuple[np.ndarray, np.ndarray]: World coordinates (H, W, 3) and valid depth mask (H, W).
+ """
+ if depth_map is None:
+ return None, None, None
+
+ # Valid depth mask
+ point_mask = depth_map > eps
+
+ # Convert depth map to camera coordinates
+ cam_coords_points = depth_to_cam_coords_points(depth_map, intrinsic)
+
+ # Multiply with the inverse of extrinsic matrix to transform to world coordinates
+ # extrinsic_inv is 4x4 (note closed_form_inverse_OpenCV is batched, the output is (N, 4, 4))
+ cam_to_world_extrinsic = closed_form_inverse_se3(extrinsic[None])[0]
+
+ R_cam_to_world = cam_to_world_extrinsic[:3, :3]
+ t_cam_to_world = cam_to_world_extrinsic[:3, 3]
+
+ # Apply the rotation and translation to the camera coordinates
+ world_coords_points = np.dot(cam_coords_points, R_cam_to_world.T) + t_cam_to_world # HxWx3, 3x3 -> HxWx3
+ # world_coords_points = np.einsum("ij,hwj->hwi", R_cam_to_world, cam_coords_points) + t_cam_to_world
+
+ return world_coords_points, cam_coords_points, point_mask
+
+
+def depth_to_cam_coords_points(depth_map: np.ndarray, intrinsic: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
+ """
+ Convert a depth map to camera coordinates.
+
+ Args:
+ depth_map (np.ndarray): Depth map of shape (H, W).
+ intrinsic (np.ndarray): Camera intrinsic matrix of shape (3, 3).
+
+ Returns:
+ tuple[np.ndarray, np.ndarray]: Camera coordinates (H, W, 3)
+ """
+ H, W = depth_map.shape
+ assert intrinsic.shape == (3, 3), "Intrinsic matrix must be 3x3"
+ assert intrinsic[0, 1] == 0 and intrinsic[1, 0] == 0, "Intrinsic matrix must have zero skew"
+
+ # Intrinsic parameters
+ fu, fv = intrinsic[0, 0], intrinsic[1, 1]
+ cu, cv = intrinsic[0, 2], intrinsic[1, 2]
+
+ # Generate grid of pixel coordinates
+ u, v = np.meshgrid(np.arange(W), np.arange(H))
+
+ # Unproject to camera coordinates
+ x_cam = (u - cu) * depth_map / fu
+ y_cam = (v - cv) * depth_map / fv
+ z_cam = depth_map
+
+ # Stack to form camera coordinates
+ cam_coords = np.stack((x_cam, y_cam, z_cam), axis=-1).astype(np.float32)
+
+ return cam_coords
+
+
+def closed_form_inverse_se3(se3, R=None, T=None):
+ """
+ Compute the inverse of each 4x4 (or 3x4) SE3 matrix in a batch.
+
+ If `R` and `T` are provided, they must correspond to the rotation and translation
+ components of `se3`. Otherwise, they will be extracted from `se3`.
+
+ Args:
+ se3: Nx4x4 or Nx3x4 array or tensor of SE3 matrices.
+ R (optional): Nx3x3 array or tensor of rotation matrices.
+ T (optional): Nx3x1 array or tensor of translation vectors.
+
+ Returns:
+ Inverted SE3 matrices with the same type and device as `se3`.
+
+ Shapes:
+ se3: (N, 4, 4)
+ R: (N, 3, 3)
+ T: (N, 3, 1)
+ """
+ # Check if se3 is a numpy array or a torch tensor
+ is_numpy = isinstance(se3, np.ndarray)
+
+ # Validate shapes
+ if se3.shape[-2:] != (4, 4) and se3.shape[-2:] != (3, 4):
+ raise ValueError(f"se3 must be of shape (N,4,4), got {se3.shape}.")
+
+ # Extract R and T if not provided
+ if R is None:
+ R = se3[:, :3, :3] # (N,3,3)
+ if T is None:
+ T = se3[:, :3, 3:] # (N,3,1)
+
+ # Transpose R
+ if is_numpy:
+ # Compute the transpose of the rotation for NumPy
+ R_transposed = np.transpose(R, (0, 2, 1))
+ # -R^T t for NumPy
+ top_right = -np.matmul(R_transposed, T)
+ inverted_matrix = np.tile(np.eye(4), (len(R), 1, 1))
+ else:
+ R_transposed = R.transpose(1, 2) # (N,3,3)
+ top_right = -torch.bmm(R_transposed, T) # (N,3,1)
+ inverted_matrix = torch.eye(4, 4)[None].repeat(len(R), 1, 1)
+ inverted_matrix = inverted_matrix.to(R.dtype).to(R.device)
+
+ inverted_matrix[:, :3, :3] = R_transposed
+ inverted_matrix[:, :3, 3:] = top_right
+
+ return inverted_matrix
diff --git a/hf_utils/visual_util.py b/hf_utils/visual_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..d56d7cc3b99936a14ffb16cee4d2acdfb5780dfa
--- /dev/null
+++ b/hf_utils/visual_util.py
@@ -0,0 +1,501 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import copy
+import os
+
+import cv2
+import matplotlib
+import numpy as np
+import requests
+import trimesh
+from scipy.spatial.transform import Rotation
+
+
+def predictions_to_glb(
+ predictions,
+ conf_thres=50.0,
+ filter_by_frames="all",
+ mask_black_bg=False,
+ mask_white_bg=False,
+ show_cam=True,
+ mask_sky=False,
+ target_dir=None,
+ prediction_mode="Predicted Pointmap",
+ mask_ambiguous=False,
+) -> trimesh.Scene:
+ """
+ Converts VGGT predictions to a 3D scene represented as a GLB file.
+
+ Args:
+ predictions (dict): Dictionary containing model predictions with keys:
+ - world_points: 3D point coordinates (S, H, W, 3)
+ - world_points_conf: Confidence scores (S, H, W)
+ - images: Input images (S, H, W, 3)
+ - extrinsic: Camera extrinsic matrices (S, 3, 4)
+ conf_thres (float): Percentage of low-confidence points to filter out (default: 50.0)
+ filter_by_frames (str): Frame filter specification (default: "all")
+ mask_black_bg (bool): Mask out black background pixels (default: False)
+ mask_white_bg (bool): Mask out white background pixels (default: False)
+ show_cam (bool): Include camera visualization (default: True)
+ mask_sky (bool): Apply sky segmentation mask (default: False)
+ target_dir (str): Output directory for intermediate files (default: None)
+ prediction_mode (str): Prediction mode selector (default: "Predicted Pointmap")
+
+ Returns:
+ trimesh.Scene: Processed 3D scene containing point cloud and cameras
+
+ Raises:
+ ValueError: If input predictions structure is invalid
+ """
+ if not isinstance(predictions, dict):
+ raise ValueError("predictions must be a dictionary")
+
+ if conf_thres is None:
+ conf_thres = 10.0
+
+ print("Building GLB scene")
+ selected_frame_idx = None
+ if filter_by_frames != "all" and filter_by_frames != "All":
+ try:
+ # Extract the index part before the colon
+ selected_frame_idx = int(filter_by_frames.split(":")[0])
+ except (ValueError, IndexError):
+ pass
+
+ if "Pointmap" in prediction_mode:
+ print("Using Pointmap Branch")
+ if "world_points" in predictions:
+ # import ipdb
+
+ # ipdb.set_trace()
+
+ pred_world_points = predictions[
+ "world_points"
+ ] # No batch dimension to remove
+ pred_world_points_conf = predictions.get(
+ "confidence", np.ones_like(pred_world_points[..., 0])
+ )
+ else:
+ print(
+ "Warning: world_points not found in predictions, falling back to depth-based points"
+ )
+ pred_world_points = predictions["world_points_from_depth"]
+ pred_world_points_conf = predictions.get(
+ "depth_conf", np.ones_like(pred_world_points[..., 0])
+ )
+ else:
+ print("Using Depthmap and Camera Branch")
+ pred_world_points = predictions["world_points_from_depth"]
+ pred_world_points_conf = predictions.get(
+ "depth_conf", np.ones_like(pred_world_points[..., 0])
+ )
+
+ # Get images from predictions
+ images = predictions["images"]
+ # Use extrinsic matrices instead of pred_extrinsic_list
+ camera_matrices = predictions["extrinsic"]
+
+ if mask_sky:
+ if target_dir is not None:
+ import onnxruntime
+
+ skyseg_session = None
+ target_dir_images = target_dir + "/images"
+ image_list = sorted(os.listdir(target_dir_images))
+ sky_mask_list = []
+
+ # Get the shape of pred_world_points_conf to match
+ S, H, W = (
+ pred_world_points_conf.shape
+ if hasattr(pred_world_points_conf, "shape")
+ else (len(images), images.shape[1], images.shape[2])
+ )
+
+ # Download skyseg.onnx if it doesn't exist
+ if not os.path.exists("skyseg.onnx"):
+ print("Downloading skyseg.onnx...")
+ download_file_from_url(
+ "https://huggingface.co/JianyuanWang/skyseg/resolve/main/skyseg.onnx",
+ "skyseg.onnx",
+ )
+
+ for i, image_name in enumerate(image_list):
+ image_filepath = os.path.join(target_dir_images, image_name)
+ mask_filepath = os.path.join(target_dir, "sky_masks", image_name)
+
+ # Check if mask already exists
+ if os.path.exists(mask_filepath):
+ # Load existing mask
+ sky_mask = cv2.imread(mask_filepath, cv2.IMREAD_GRAYSCALE)
+ else:
+ # Generate new mask
+ if skyseg_session is None:
+ skyseg_session = onnxruntime.InferenceSession("skyseg.onnx")
+ sky_mask = segment_sky(
+ image_filepath, skyseg_session, mask_filepath
+ )
+
+ # Resize mask to match H×W if needed
+ if sky_mask.shape[0] != H or sky_mask.shape[1] != W:
+ sky_mask = cv2.resize(sky_mask, (W, H))
+
+ sky_mask_list.append(sky_mask)
+
+ # Convert list to numpy array with shape S×H×W
+ sky_mask_array = np.array(sky_mask_list)
+
+ # Apply sky mask to confidence scores
+ sky_mask_binary = (sky_mask_array > 0.1).astype(np.float32)
+ pred_world_points_conf = pred_world_points_conf * sky_mask_binary
+
+ if selected_frame_idx is not None:
+ pred_world_points = pred_world_points[selected_frame_idx][None]
+ pred_world_points_conf = pred_world_points_conf[selected_frame_idx][None]
+ images = images[selected_frame_idx][None]
+ camera_matrices = camera_matrices[selected_frame_idx][None]
+
+ vertices_3d = pred_world_points.reshape(-1, 3)
+ # Handle different image formats - check if images need transposing
+ if images.ndim == 4 and images.shape[1] == 3: # NCHW format
+ colors_rgb = np.transpose(images, (0, 2, 3, 1))
+ else: # Assume already in NHWC format
+ colors_rgb = images
+ colors_rgb = (colors_rgb.reshape(-1, 3) * 255).astype(np.uint8)
+
+ conf = pred_world_points_conf.reshape(-1)
+ # Convert percentage threshold to actual confidence value
+ if conf_thres == 0.0:
+ conf_threshold = 0.0
+ else:
+ conf_threshold = np.percentile(conf, conf_thres)
+
+ conf_mask = (conf >= conf_threshold) & (conf > 1e-5)
+ final_mask = predictions["final_mask"].reshape(-1)
+
+ if mask_black_bg:
+ black_bg_mask = colors_rgb.sum(axis=1) >= 16 / 255.0
+ conf_mask = conf_mask & black_bg_mask
+
+ if mask_white_bg:
+ # Filter out white background pixels (RGB values close to white)
+ # Consider pixels white if all RGB values are above 240
+ white_bg_mask = (
+ (colors_rgb[:, 0] > 240 / 255.0)
+ & (colors_rgb[:, 1] > 240 / 255.0)
+ & (colors_rgb[:, 2] > 240 / 255.0)
+ )
+ conf_mask = conf_mask & white_bg_mask
+
+ # Use final_mask with conf_mask when mask_ambiguous is checked
+ if mask_ambiguous:
+ conf_mask = conf_mask & final_mask
+
+ vertices_3d = vertices_3d[conf_mask].copy()
+ colors_rgb = colors_rgb[conf_mask].copy()
+
+ if vertices_3d is None or np.asarray(vertices_3d).size == 0:
+ vertices_3d = np.array([[1, 0, 0]])
+ colors_rgb = np.array([[255, 255, 255]])
+ scene_scale = 1
+ else:
+ # Calculate the 5th and 95th percentiles along each axis
+ lower_percentile = np.percentile(vertices_3d, 5, axis=0)
+ upper_percentile = np.percentile(vertices_3d, 95, axis=0)
+
+ # Calculate the diagonal length of the percentile bounding box
+ scene_scale = np.linalg.norm(upper_percentile - lower_percentile)
+
+ colormap = matplotlib.colormaps.get_cmap("gist_rainbow")
+
+ # Initialize a 3D scene
+ scene_3d = trimesh.Scene()
+
+ # Add point cloud data to the scene
+ point_cloud_data = trimesh.PointCloud(vertices=vertices_3d, colors=colors_rgb)
+
+ scene_3d.add_geometry(point_cloud_data)
+
+ # Prepare 4x4 matrices for camera extrinsics
+ num_cameras = len(camera_matrices)
+ extrinsics_matrices = np.zeros((num_cameras, 4, 4))
+ extrinsics_matrices[:, :3, :4] = camera_matrices
+ extrinsics_matrices[:, 3, 3] = 1
+
+ if show_cam:
+ # Add camera models to the scene
+ for i in range(num_cameras):
+ world_to_camera = extrinsics_matrices[i]
+ camera_to_world = np.linalg.inv(world_to_camera)
+ rgba_color = colormap(i / num_cameras)
+ current_color = tuple(int(255 * x) for x in rgba_color[:3])
+
+ integrate_camera_into_scene(
+ scene_3d, camera_to_world, current_color, scene_scale
+ )
+
+ # Align scene to the observation of the first camera
+ scene_3d = apply_scene_alignment(scene_3d, extrinsics_matrices)
+
+ print("GLB Scene built")
+ return scene_3d
+
+
+def integrate_camera_into_scene(
+ scene: trimesh.Scene,
+ transform: np.ndarray,
+ face_colors: tuple,
+ scene_scale: float,
+):
+ """
+ Integrates a fake camera mesh into the 3D scene.
+
+ Args:
+ scene (trimesh.Scene): The 3D scene to add the camera model.
+ transform (np.ndarray): Transformation matrix for camera positioning.
+ face_colors (tuple): Color of the camera face.
+ scene_scale (float): Scale of the scene.
+ """
+
+ cam_width = scene_scale * 0.05
+ cam_height = scene_scale * 0.1
+
+ # Create cone shape for camera
+ rot_45_degree = np.eye(4)
+ rot_45_degree[:3, :3] = Rotation.from_euler("z", 45, degrees=True).as_matrix()
+ rot_45_degree[2, 3] = -cam_height
+
+ opengl_transform = get_opengl_conversion_matrix()
+ # Combine transformations
+ complete_transform = transform @ opengl_transform @ rot_45_degree
+ camera_cone_shape = trimesh.creation.cone(cam_width, cam_height, sections=4)
+
+ # Generate mesh for the camera
+ slight_rotation = np.eye(4)
+ slight_rotation[:3, :3] = Rotation.from_euler("z", 2, degrees=True).as_matrix()
+
+ vertices_combined = np.concatenate(
+ [
+ camera_cone_shape.vertices,
+ 0.95 * camera_cone_shape.vertices,
+ transform_points(slight_rotation, camera_cone_shape.vertices),
+ ]
+ )
+ vertices_transformed = transform_points(complete_transform, vertices_combined)
+
+ mesh_faces = compute_camera_faces(camera_cone_shape)
+
+ # Add the camera mesh to the scene
+ camera_mesh = trimesh.Trimesh(vertices=vertices_transformed, faces=mesh_faces)
+ camera_mesh.visual.face_colors[:, :3] = face_colors
+ scene.add_geometry(camera_mesh)
+
+
+def apply_scene_alignment(
+ scene_3d: trimesh.Scene, extrinsics_matrices: np.ndarray
+) -> trimesh.Scene:
+ """
+ Aligns the 3D scene based on the extrinsics of the first camera.
+
+ Args:
+ scene_3d (trimesh.Scene): The 3D scene to be aligned.
+ extrinsics_matrices (np.ndarray): Camera extrinsic matrices.
+
+ Returns:
+ trimesh.Scene: Aligned 3D scene.
+ """
+ # Set transformations for scene alignment
+ opengl_conversion_matrix = get_opengl_conversion_matrix()
+
+ # Rotation matrix for alignment (180 degrees around the y-axis)
+ align_rotation = np.eye(4)
+ align_rotation[:3, :3] = Rotation.from_euler("y", 0, degrees=True).as_matrix()
+
+ # Apply transformation
+ initial_transformation = (
+ np.linalg.inv(extrinsics_matrices[0])
+ @ opengl_conversion_matrix
+ @ align_rotation
+ )
+ scene_3d.apply_transform(initial_transformation)
+ return scene_3d
+
+
+def get_opengl_conversion_matrix() -> np.ndarray:
+ """
+ Constructs and returns the OpenGL conversion matrix.
+
+ Returns:
+ numpy.ndarray: A 4x4 OpenGL conversion matrix.
+ """
+ # Create an identity matrix
+ matrix = np.identity(4)
+
+ # Flip the y and z axes
+ matrix[1, 1] = -1
+ matrix[2, 2] = -1
+
+ return matrix
+
+
+def transform_points(
+ transformation: np.ndarray, points: np.ndarray, dim: int = None
+) -> np.ndarray:
+ """
+ Applies a 4x4 transformation to a set of points.
+
+ Args:
+ transformation (np.ndarray): Transformation matrix.
+ points (np.ndarray): Points to be transformed.
+ dim (int, optional): Dimension for reshaping the result.
+
+ Returns:
+ np.ndarray: Transformed points.
+ """
+ points = np.asarray(points)
+ initial_shape = points.shape[:-1]
+ dim = dim or points.shape[-1]
+
+ # Apply transformation
+ transformation = transformation.swapaxes(
+ -1, -2
+ ) # Transpose the transformation matrix
+ points = points @ transformation[..., :-1, :] + transformation[..., -1:, :]
+
+ # Reshape the result
+ result = points[..., :dim].reshape(*initial_shape, dim)
+ return result
+
+
+def compute_camera_faces(cone_shape: trimesh.Trimesh) -> np.ndarray:
+ """
+ Computes the faces for the camera mesh.
+
+ Args:
+ cone_shape (trimesh.Trimesh): The shape of the camera cone.
+
+ Returns:
+ np.ndarray: Array of faces for the camera mesh.
+ """
+ # Create pseudo cameras
+ faces_list = []
+ num_vertices_cone = len(cone_shape.vertices)
+
+ for face in cone_shape.faces:
+ if 0 in face:
+ continue
+ v1, v2, v3 = face
+ v1_offset, v2_offset, v3_offset = face + num_vertices_cone
+ v1_offset_2, v2_offset_2, v3_offset_2 = face + 2 * num_vertices_cone
+
+ faces_list.extend(
+ [
+ (v1, v2, v2_offset),
+ (v1, v1_offset, v3),
+ (v3_offset, v2, v3),
+ (v1, v2, v2_offset_2),
+ (v1, v1_offset_2, v3),
+ (v3_offset_2, v2, v3),
+ ]
+ )
+
+ faces_list += [(v3, v2, v1) for v1, v2, v3 in faces_list]
+ return np.array(faces_list)
+
+
+def segment_sky(image_path, onnx_session, mask_filename=None):
+ """
+ Segments sky from an image using an ONNX model.
+ Thanks for the great model provided by https://github.com/xiongzhu666/Sky-Segmentation-and-Post-processing
+
+ Args:
+ image_path: Path to input image
+ onnx_session: ONNX runtime session with loaded model
+ mask_filename: Path to save the output mask
+
+ Returns:
+ np.ndarray: Binary mask where 255 indicates non-sky regions
+ """
+
+ assert mask_filename is not None
+ image = cv2.imread(image_path)
+
+ result_map = run_skyseg(onnx_session, [320, 320], image)
+ # resize the result_map to the original image size
+ result_map_original = cv2.resize(result_map, (image.shape[1], image.shape[0]))
+
+ # Fix: Invert the mask so that 255 = non-sky, 0 = sky
+ # The model outputs low values for sky, high values for non-sky
+ output_mask = np.zeros_like(result_map_original)
+ output_mask[result_map_original < 32] = 255 # Use threshold of 32
+
+ os.makedirs(os.path.dirname(mask_filename), exist_ok=True)
+ cv2.imwrite(mask_filename, output_mask)
+ return output_mask
+
+
+def run_skyseg(onnx_session, input_size, image):
+ """
+ Runs sky segmentation inference using ONNX model.
+
+ Args:
+ onnx_session: ONNX runtime session
+ input_size: Target size for model input (width, height)
+ image: Input image in BGR format
+
+ Returns:
+ np.ndarray: Segmentation mask
+ """
+
+ # Pre process:Resize, BGR->RGB, Transpose, PyTorch standardization, float32 cast
+ temp_image = copy.deepcopy(image)
+ resize_image = cv2.resize(temp_image, dsize=(input_size[0], input_size[1]))
+ x = cv2.cvtColor(resize_image, cv2.COLOR_BGR2RGB)
+ x = np.array(x, dtype=np.float32)
+ mean = [0.485, 0.456, 0.406]
+ std = [0.229, 0.224, 0.225]
+ x = (x / 255 - mean) / std
+ x = x.transpose(2, 0, 1)
+ x = x.reshape(-1, 3, input_size[0], input_size[1]).astype("float32")
+
+ # Inference
+ input_name = onnx_session.get_inputs()[0].name
+ output_name = onnx_session.get_outputs()[0].name
+ onnx_result = onnx_session.run([output_name], {input_name: x})
+
+ # Post process
+ onnx_result = np.array(onnx_result).squeeze()
+ min_value = np.min(onnx_result)
+ max_value = np.max(onnx_result)
+ onnx_result = (onnx_result - min_value) / (max_value - min_value)
+ onnx_result *= 255
+ onnx_result = onnx_result.astype("uint8")
+
+ return onnx_result
+
+
+def download_file_from_url(url, filename):
+ """Downloads a file from a Hugging Face model repo, handling redirects."""
+ try:
+ # Get the redirect URL
+ response = requests.get(url, allow_redirects=False)
+ response.raise_for_status() # Raise HTTPError for bad requests (4xx or 5xx)
+
+ if response.status_code == 302: # Expecting a redirect
+ redirect_url = response.headers["Location"]
+ response = requests.get(redirect_url, stream=True)
+ response.raise_for_status()
+ else:
+ print(f"Unexpected status code: {response.status_code}")
+ return
+
+ with open(filename, "wb") as f:
+ for chunk in response.iter_content(chunk_size=8192):
+ f.write(chunk)
+ print(f"Downloaded {filename} successfully.")
+
+ except requests.exceptions.RequestException as e:
+ print(f"Error downloading file: {e}")
diff --git a/mapanything/datasets/__init__.py b/mapanything/datasets/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ca5b413385e67e07ae394402431ce78a3b480300
--- /dev/null
+++ b/mapanything/datasets/__init__.py
@@ -0,0 +1,178 @@
+"""
+MapAnything Datasets
+"""
+
+import torch
+
+from mapanything.datasets.wai.ase import ASEWAI # noqa
+from mapanything.datasets.wai.bedlam import BedlamWAI # noqa
+from mapanything.datasets.wai.blendedmvs import BlendedMVSWAI # noqa
+from mapanything.datasets.wai.dl3dv import DL3DVWAI # noqa
+from mapanything.datasets.wai.dtu import DTUWAI # noqa
+from mapanything.datasets.wai.dynamicreplica import DynamicReplicaWAI # noqa
+from mapanything.datasets.wai.eth3d import ETH3DWAI # noqa
+from mapanything.datasets.wai.gta_sfm import GTASfMWAI # noqa
+from mapanything.datasets.wai.matrixcity import MatrixCityWAI # noqa
+from mapanything.datasets.wai.megadepth import MegaDepthWAI # noqa
+from mapanything.datasets.wai.mpsd import MPSDWAI # noqa
+from mapanything.datasets.wai.mvs_synth import MVSSynthWAI # noqa
+from mapanything.datasets.wai.paralleldomain4d import ParallelDomain4DWAI # noqa
+from mapanything.datasets.wai.sailvos3d import SAILVOS3DWAI # noqa
+from mapanything.datasets.wai.scannetpp import ScanNetPPWAI # noqa
+from mapanything.datasets.wai.spring import SpringWAI # noqa
+from mapanything.datasets.wai.structured3d import Structured3DWAI # noqa
+from mapanything.datasets.wai.tav2_wb import TartanAirV2WBWAI # noqa
+from mapanything.datasets.wai.unrealstereo4k import UnrealStereo4KWAI # noqa
+from mapanything.datasets.wai.xrooms import XRoomsWAI # noqa
+from mapanything.utils.train_tools import get_rank, get_world_size
+
+
+def get_test_data_loader(
+ dataset, batch_size, num_workers=8, shuffle=False, drop_last=False, pin_mem=True
+):
+ "Get simple PyTorch dataloader corresponding to the testing dataset"
+ # PyTorch dataset
+ if isinstance(dataset, str):
+ dataset = eval(dataset)
+
+ world_size = get_world_size()
+ rank = get_rank()
+
+ if torch.distributed.is_initialized():
+ sampler = torch.utils.data.DistributedSampler(
+ dataset,
+ num_replicas=world_size,
+ rank=rank,
+ shuffle=shuffle,
+ drop_last=drop_last,
+ )
+ elif shuffle:
+ sampler = torch.utils.data.RandomSampler(dataset)
+ else:
+ sampler = torch.utils.data.SequentialSampler(dataset)
+
+ data_loader = torch.utils.data.DataLoader(
+ dataset,
+ sampler=sampler,
+ batch_size=batch_size,
+ num_workers=num_workers,
+ pin_memory=pin_mem,
+ drop_last=drop_last,
+ )
+
+ return data_loader
+
+
+def get_test_many_ar_data_loader(
+ dataset, batch_size, num_workers=8, drop_last=False, pin_mem=True
+):
+ "Get PyTorch dataloader corresponding to the testing dataset that supports many aspect ratios"
+ # PyTorch dataset
+ if isinstance(dataset, str):
+ dataset = eval(dataset)
+
+ world_size = get_world_size()
+ rank = get_rank()
+
+ # Get BatchedMultiFeatureRandomSampler
+ sampler = dataset.make_sampler(
+ batch_size,
+ shuffle=True,
+ world_size=world_size,
+ rank=rank,
+ drop_last=drop_last,
+ use_dynamic_sampler=False,
+ )
+
+ # Init the data laoder
+ data_loader = torch.utils.data.DataLoader(
+ dataset,
+ sampler=sampler,
+ batch_size=batch_size,
+ num_workers=num_workers,
+ pin_memory=pin_mem,
+ drop_last=drop_last,
+ )
+
+ return data_loader
+
+
+class DynamicBatchDatasetWrapper:
+ """
+ Wrapper dataset that handles DynamicBatchedMultiFeatureRandomSampler output.
+
+ The dynamic sampler returns batches (lists of tuples) instead of individual samples.
+ This wrapper ensures that the underlying dataset's __getitem__ method gets called
+ with individual tuples as expected.
+ """
+
+ def __init__(self, dataset):
+ self.dataset = dataset
+
+ def __getitem__(self, batch_indices):
+ """
+ Handle batch of indices from DynamicBatchedMultiFeatureRandomSampler.
+
+ Args:
+ batch_indices: List of tuples like [(sample_idx, feat_idx_1, feat_idx_2, ...), ...]
+
+ Returns:
+ List of samples from the underlying dataset
+ """
+ if isinstance(batch_indices, (list, tuple)) and len(batch_indices) > 0:
+ # If it's a batch (list of tuples), process each item
+ if isinstance(batch_indices[0], (list, tuple)):
+ return [self.dataset[idx] for idx in batch_indices]
+ else:
+ # Single tuple, call dataset directly
+ return self.dataset[batch_indices]
+ else:
+ # Fallback for single index
+ return self.dataset[batch_indices]
+
+ def __len__(self):
+ return len(self.dataset)
+
+ def __getattr__(self, name):
+ # Delegate all other attributes to the wrapped dataset
+ return getattr(self.dataset, name)
+
+
+def get_train_data_loader(
+ dataset,
+ max_num_of_imgs_per_gpu,
+ num_workers=8,
+ shuffle=True,
+ drop_last=True,
+ pin_mem=True,
+):
+ "Dynamic PyTorch dataloader corresponding to the training dataset"
+ # PyTorch dataset
+ if isinstance(dataset, str):
+ dataset = eval(dataset)
+
+ world_size = get_world_size()
+ rank = get_rank()
+
+ # Get DynamicBatchedMultiFeatureRandomSampler
+ batch_sampler = dataset.make_sampler(
+ shuffle=shuffle,
+ world_size=world_size,
+ rank=rank,
+ drop_last=drop_last,
+ max_num_of_images_per_gpu=max_num_of_imgs_per_gpu,
+ use_dynamic_sampler=True,
+ )
+
+ # Wrap the dataset to handle batch format from dynamic sampler
+ wrapped_dataset = DynamicBatchDatasetWrapper(dataset)
+
+ # Init the dynamic data loader
+ data_loader = torch.utils.data.DataLoader(
+ wrapped_dataset,
+ batch_sampler=batch_sampler,
+ num_workers=num_workers,
+ pin_memory=pin_mem,
+ )
+
+ return data_loader
diff --git a/mapanything/datasets/base/__init__.py b/mapanything/datasets/base/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/mapanything/datasets/base/base_dataset.py b/mapanything/datasets/base/base_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..0a184d200d98b181f68a358ae8dea73c3e3b072f
--- /dev/null
+++ b/mapanything/datasets/base/base_dataset.py
@@ -0,0 +1,692 @@
+"""
+Base class for MapAnything datasets.
+"""
+
+from typing import List, Tuple, Union
+
+import numpy as np
+import PIL
+import torch
+import torchvision.transforms as tvf
+from scipy.spatial.transform import Rotation
+
+from mapanything.datasets.base.easy_dataset import EasyDataset
+from mapanything.utils.cropping import (
+ bbox_from_intrinsics_in_out,
+ camera_matrix_of_crop,
+ crop_image_and_other_optional_info,
+ rescale_image_and_other_optional_info,
+)
+from mapanything.utils.geometry import (
+ depthmap_to_camera_coordinates,
+ get_absolute_pointmaps_and_rays_info,
+)
+from uniception.models.encoders.image_normalizations import IMAGE_NORMALIZATION_DICT
+
+
+class BaseDataset(EasyDataset):
+ """
+ Define all basic options.
+
+ Usage:
+ class MyDataset(BaseDataset):
+ def _get_views(self, idx):
+ views = []
+ views.append(dict(img=, ...))
+ return views
+ """
+
+ def __init__(
+ self,
+ num_views: int,
+ variable_num_views: bool = False,
+ split: str = None,
+ covisibility_thres: float = None,
+ resolution: Union[int, Tuple[int, int], List[Tuple[int, int]]] = None,
+ principal_point_centered: bool = False,
+ transform: str = None,
+ data_norm_type: str = None,
+ aug_crop: int = 0,
+ seed: int = None,
+ max_num_retries: int = 5,
+ ):
+ """
+ PyTorch dataset for multi-view images sampled from scenes, where the images form a single connected component.
+
+ Args:
+ num_views (int): Number of views.
+ variable_num_views (bool): If True, the number of views can vary from batch to batch. The maximum number of views is num_views and minimum is 2.
+ On by default for N-view train dataloader (hydra config).
+ split (str): 'train', 'val', 'test', etc.
+ covisibility_thres (float): Covisibility (%) threshold to determine if another image is a neighbor or not
+ resolution (int or tuple or list of tuples): Resolution of the images
+ principal_point_centered (bool): If True, the principal point is centered in the image.
+ transform (str): Transform to apply to the images. Options:
+ - 'colorjitter+grayscale+gaublur':
+ tvf.Compose([
+ tvf.RandomApply([tvf.ColorJittter(0.3, 0.4, 0.2, 0.1)], p=0.75),
+ tvf.RandomGrayscale(p=0.05),
+ tvf.RandomApply([tvf.GaussianBlur(5, sigma=(0.1, 1.0))], p=0.05),
+ ]) after ImgNorm
+ - 'colorjitter': tvf.ColorJittter(0.5, 0.5, 0.5, 0.1) after ImgNorm
+ - 'imgnorm': ImgNorm only
+ data_norm_type (str): Image normalization type.
+ For options, see UniCeption image normalization dict.
+ aug_crop (int): Augment crop. If int greater than 0, indicates the number of pixels to increase in target resolution.
+ seed (int): Seed for the random number generator.
+ max_num_retries (int): Maximum number of retries for loading a different sample from the dataset, if provided idx fails.
+ """
+ self.num_views = num_views
+ self.variable_num_views = variable_num_views
+ self.num_views_min = 2
+ self.split = split
+ self.covisibility_thres = covisibility_thres
+ self._set_resolutions(resolution)
+ self.principal_point_centered = principal_point_centered
+
+ # Update the number of views if necessary and make it a list if variable_num_views is True
+ if self.variable_num_views and self.num_views > self.num_views_min:
+ self.num_views = list(range(self.num_views_min, self.num_views + 1))
+
+ # Initialize the image normalization type
+ if data_norm_type in IMAGE_NORMALIZATION_DICT.keys():
+ self.data_norm_type = data_norm_type
+ image_norm = IMAGE_NORMALIZATION_DICT[data_norm_type]
+ ImgNorm = tvf.Compose(
+ [
+ tvf.ToTensor(),
+ tvf.Normalize(mean=image_norm.mean, std=image_norm.std),
+ ]
+ )
+ elif data_norm_type == "identity":
+ self.data_norm_type = data_norm_type
+ ImgNorm = tvf.Compose([tvf.ToTensor()])
+ else:
+ raise ValueError(
+ f"Unknown data_norm_type: {data_norm_type}. Available options: identity or {list(IMAGE_NORMALIZATION_DICT.keys())}"
+ )
+
+ # Initialize torchvision transforms
+ if transform == "imgnorm":
+ self.transform = ImgNorm
+ elif transform == "colorjitter":
+ self.transform = tvf.Compose([tvf.ColorJitter(0.5, 0.5, 0.5, 0.1), ImgNorm])
+ elif transform == "colorjitter+grayscale+gaublur":
+ self.transform = tvf.Compose(
+ [
+ tvf.RandomApply([tvf.ColorJitter(0.3, 0.4, 0.2, 0.1)], p=0.75),
+ tvf.RandomGrayscale(p=0.05),
+ tvf.RandomApply([tvf.GaussianBlur(5, sigma=(0.1, 1.0))], p=0.05),
+ ImgNorm,
+ ]
+ )
+ else:
+ raise ValueError(
+ 'Unknown transform. Available options: "imgnorm", "colorjitter", "colorjitter+grayscale+gaublur"'
+ )
+
+ # Initialize the augmentation parameters
+ self.aug_crop = aug_crop
+
+ # Initialize the seed for the random number generator
+ self.seed = seed
+ self._seed_offset = 0
+
+ # Initialize the maximum number of retries for loading a different sample from the dataset, if the first idx fails
+ self.max_num_retries = max_num_retries
+
+ # Initialize the dataset type flags
+ self.is_metric_scale = False # by default a dataset is not metric scale, subclasses can overwrite this
+ self.is_synthetic = False # by default a dataset is not synthetic, subclasses can overwrite this
+
+ def _load_data(self):
+ self.scenes = []
+ self.num_of_scenes = len(self.scenes)
+
+ def __len__(self):
+ "Length of the dataset is determined by the number of scenes in the dataset split"
+ return self.num_of_scenes
+
+ def get_stats(self):
+ "Get the number of scenes in the dataset split"
+ return f"{self.num_of_scenes} scenes"
+
+ def __repr__(self):
+ resolutions_str = "[" + ";".join(f"{w}x{h}" for w, h in self._resolutions) + "]"
+ return (
+ f"""{type(self).__name__}({self.get_stats()},
+ {self.num_views=}
+ {self.split=},
+ {self.seed=},
+ resolutions={resolutions_str},
+ {self.transform=})""".replace("self.", "")
+ .replace("\n", "")
+ .replace(" ", "")
+ )
+
+ def _get_views(self, idx, num_views_to_sample, resolution):
+ raise NotImplementedError()
+
+ def _set_seed_offset(self, idx):
+ """
+ Set the seed offset. This is directly added to self.seed when setting the random seed.
+ """
+ self._seed_offset = idx
+
+ def _set_resolutions(self, resolutions):
+ assert resolutions is not None, "undefined resolution"
+
+ if isinstance(resolutions, int):
+ resolutions = [resolutions]
+ elif isinstance(resolutions, tuple):
+ resolutions = [resolutions]
+ elif isinstance(resolutions, list):
+ assert all(isinstance(res, tuple) for res in resolutions), (
+ f"Bad type for {resolutions=}, should be int or tuple of ints or list of tuples of ints"
+ )
+ else:
+ raise ValueError(
+ f"Bad type for {resolutions=}, should be int or tuple of ints or list of tuples of ints"
+ )
+
+ self._resolutions = []
+ for resolution in resolutions:
+ if isinstance(resolution, int):
+ width = height = resolution
+ else:
+ width, height = resolution
+ assert isinstance(width, int), (
+ f"Bad type for {width=} {type(width)=}, should be int"
+ )
+ assert isinstance(height, int), (
+ f"Bad type for {height=} {type(height)=}, should be int"
+ )
+ self._resolutions.append((width, height))
+
+ def _crop_resize_if_necessary(
+ self,
+ image,
+ resolution,
+ depthmap,
+ intrinsics,
+ additional_quantities=None,
+ ):
+ """
+ Process an image by downsampling and cropping as needed to match the target resolution.
+
+ This method performs the following operations:
+ 1. Converts the image to PIL.Image if necessary
+ 2. Crops the image centered on the principal point if requested
+ 3. Downsamples the image using high-quality Lanczos filtering
+ 4. Performs final cropping to match the target resolution
+
+ Args:
+ image (numpy.ndarray or PIL.Image.Image): Input image to be processed
+ resolution (tuple): Target resolution as (width, height)
+ depthmap (numpy.ndarray): Depth map corresponding to the image
+ intrinsics (numpy.ndarray): Camera intrinsics matrix (3x3)
+ additional_quantities (dict, optional): Additional image-related data to be processed
+ alongside the main image with nearest interpolation. Defaults to None.
+
+ Returns:
+ tuple: Processed image, depthmap, and updated intrinsics matrix.
+ If additional_quantities is provided, it returns those as well.
+ """
+ if not isinstance(image, PIL.Image.Image):
+ image = PIL.Image.fromarray(image)
+
+ # Cropping centered on the principal point if necessary
+ if self.principal_point_centered:
+ W, H = image.size
+ cx, cy = intrinsics[:2, 2].round().astype(int)
+ if cx < 0 or cx >= W or cy < 0 or cy >= H:
+ # Skip centered cropping if principal point is outside image bounds
+ pass
+ else:
+ min_margin_x = min(cx, W - cx)
+ min_margin_y = min(cy, H - cy)
+ left, top = cx - min_margin_x, cy - min_margin_y
+ right, bottom = cx + min_margin_x, cy + min_margin_y
+ crop_bbox = (left, top, right, bottom)
+ # Only perform the centered crop if the crop_bbox is larger than the target resolution
+ crop_width = right - left
+ crop_height = bottom - top
+ if crop_width > resolution[0] and crop_height > resolution[1]:
+ image, depthmap, intrinsics, additional_quantities = (
+ crop_image_and_other_optional_info(
+ image=image,
+ crop_bbox=crop_bbox,
+ depthmap=depthmap,
+ camera_intrinsics=intrinsics,
+ additional_quantities=additional_quantities,
+ )
+ )
+
+ # Get the target resolution for re-scaling
+ target_rescale_resolution = np.array(resolution)
+ if self.aug_crop > 1:
+ target_rescale_resolution += self._rng.integers(0, self.aug_crop)
+
+ # High-quality Lanczos down-scaling if necessary
+ image, depthmap, intrinsics, additional_quantities = (
+ rescale_image_and_other_optional_info(
+ image=image,
+ output_resolution=target_rescale_resolution,
+ depthmap=depthmap,
+ camera_intrinsics=intrinsics,
+ additional_quantities_to_be_resized_with_nearest=additional_quantities,
+ )
+ )
+
+ # Actual cropping (if necessary)
+ new_intrinsics = camera_matrix_of_crop(
+ input_camera_matrix=intrinsics,
+ input_resolution=image.size,
+ output_resolution=resolution,
+ offset_factor=0.5,
+ )
+ crop_bbox = bbox_from_intrinsics_in_out(
+ input_camera_matrix=intrinsics,
+ output_camera_matrix=new_intrinsics,
+ output_resolution=resolution,
+ )
+ image, depthmap, new_intrinsics, additional_quantities = (
+ crop_image_and_other_optional_info(
+ image=image,
+ crop_bbox=crop_bbox,
+ depthmap=depthmap,
+ camera_intrinsics=intrinsics,
+ additional_quantities=additional_quantities,
+ )
+ )
+
+ # Return the output
+ if additional_quantities is not None:
+ return image, depthmap, new_intrinsics, additional_quantities
+ else:
+ return image, depthmap, new_intrinsics
+
+ def _random_walk_sampling(
+ self,
+ scene_pairwise_covisibility,
+ num_of_samples,
+ max_retries=4,
+ use_bidirectional_covis=True,
+ ):
+ """
+ Randomly samples S indices from an N x N covisbility matrix by forming adjacency edges such that the resulting subgraph (given by the indices) is connected.
+ If the current node has no new unvisited neighbors, backtracking occurs.
+ Retries with different starting indices if the desired number of samples is not reached, excluding previously visited components.
+
+ Args:
+ scene_pairwise_covisibility : np.ndarray (mmap)
+ N x N covisibility matrix for the scene, where N is the number of views in the scene.
+ num_of_samples : int
+ The desired number of nodes to sample (num_of_samples < N).
+ max_retries : int
+ The maximum number of retries with different starting indices.
+ use_bidirectional_covis : bool
+ Whether to compute bidirectional covisibility by averaging row and column values.
+ If False, uses only row access (faster for large memory-mapped arrays).
+ Defaults to True.
+
+ Returns:
+ np.ndarray
+ An array of sampled indices forming a connected subgraph.
+ """
+ excluded_nodes = set()
+ best_walk = [] # To keep track of the best walk found
+ for _ in range(max_retries):
+ visited = set()
+ walk = [] # List to store the random walk sampling order
+ stack = [] # Stack for backtracking
+
+ # Choose a random starting index that is not in the excluded set
+ all_nodes = set(range(len(scene_pairwise_covisibility)))
+ available_nodes = list(all_nodes - excluded_nodes)
+ if not available_nodes:
+ break # No more nodes to try
+ start = self._rng.choice(available_nodes)
+ walk.append(start)
+ visited.add(start)
+ stack.append(start)
+
+ # Continue until we have sampled S indices or all expandable nodes are exhausted
+ while len(walk) < num_of_samples and stack:
+ current = stack[-1]
+ # Get the pairwise covisibility for the current node
+ if use_bidirectional_covis:
+ # Use bidirectional covisibility (slower for large memory-mapped arrays)
+ pairwise_covisibility = (
+ scene_pairwise_covisibility[current, :]
+ + scene_pairwise_covisibility[:, current].T
+ ) / 2
+ else:
+ # Use only row access (faster for large memory-mapped arrays)
+ pairwise_covisibility = scene_pairwise_covisibility[current, :]
+ # Normalize the covisibility using self covisibility
+ pairwise_covisibility = pairwise_covisibility / (
+ pairwise_covisibility[current] + 1e-8
+ )
+ # Assign overlap score of zero to self-pairs
+ pairwise_covisibility[current] = 0
+ # Threshold the covisibility to get adjacency list for the current node
+ adjacency_list_for_current = (
+ pairwise_covisibility > self.covisibility_thres
+ ).astype(int)
+ adjacency_list_for_current = np.flatnonzero(adjacency_list_for_current)
+ # Get all unvisited neighbors
+ candidates = [
+ idx for idx in adjacency_list_for_current if idx not in visited
+ ] # Remove visited nodes
+ if candidates:
+ # Randomly select one of the unvisited overlapping neighbors
+ next_node = self._rng.choice(candidates)
+ walk.append(next_node)
+ visited.add(next_node)
+ stack.append(next_node)
+ else:
+ # If no unvisited neighbor is available, backtrack
+ stack.pop()
+
+ # Update the best walk if the current walk is larger
+ if len(walk) > len(best_walk):
+ best_walk = walk
+
+ # If we have enough samples, return the result
+ if len(walk) >= num_of_samples:
+ return np.array(walk)
+
+ # Add all visited nodes to the excluded set
+ excluded_nodes.update(visited)
+
+ # If all retries are exhausted and we still don't have enough samples, return the best walk found
+ return np.array(best_walk)
+
+ def _sample_view_indices(
+ self,
+ num_views_to_sample,
+ num_views_in_scene,
+ scene_pairwise_covisibility,
+ use_bidirectional_covis=True,
+ ):
+ """
+ Sample view indices from a scene based on the adjacency list and the number of views to sample.
+
+ Args:
+ num_views_to_sample (int): Number of views to sample.
+ num_views_in_scene (int): Total number of views available in the scene.
+ scene_pairwise_covisibility (np.ndarray): N x N covisibility matrix for the scene, where N is the number of views in the scene.
+ use_bidirectional_covis (bool): Whether to compute bidirectional covisibility by averaging row and column values.
+ If False, uses only row access (faster for large memory-mapped arrays).
+
+ Returns:
+ numpy.ndarray: Array of sampled view indices.
+ """
+ if num_views_to_sample == num_views_in_scene:
+ # Select all views in the scene
+ view_indices = self._rng.permutation(num_views_in_scene)
+ elif num_views_to_sample > num_views_in_scene:
+ # Select all views in the scene and repeat them to get the desired number of views
+ view_indices = self._rng.choice(
+ num_views_in_scene, size=num_views_to_sample, replace=True
+ )
+ else:
+ # Select a subset of single component connected views in the scene using random walk sampling
+ view_indices = self._random_walk_sampling(
+ scene_pairwise_covisibility,
+ num_views_to_sample,
+ use_bidirectional_covis=use_bidirectional_covis,
+ )
+ # If the required num of views can't be obtained even with 4 retries, repeat existing indices to get the desired number of views
+ if len(view_indices) < num_views_to_sample:
+ view_indices = self._rng.choice(
+ view_indices, size=num_views_to_sample, replace=True
+ )
+
+ return view_indices
+
+ def _getitem_fn(self, idx):
+ if isinstance(idx, tuple):
+ # The idx is a tuple if specifying the aspect-ratio or/and the number of views
+ if isinstance(self.num_views, int):
+ idx, ar_idx = idx
+ else:
+ idx, ar_idx, num_views_to_sample_idx = idx
+ else:
+ assert len(self._resolutions) == 1
+ assert isinstance(self.num_views, int)
+ ar_idx = 0
+
+ # Setup the rng
+ if self.seed: # reseed for each _getitem_fn
+ # Leads to deterministic sampling where repeating self.seed and self._seed_offset yields the same multi-view set again
+ # Scenes will be repeated if size of dataset is artificially increased using "N @" or "N *"
+ # When scenes are repeated, self._seed_offset is increased to ensure new multi-view sets
+ # This is useful for evaluation if the number of dataset scenes is < N, yet we want unique multi-view sets each iter
+ self._rng = np.random.default_rng(seed=self.seed + self._seed_offset + idx)
+ elif not hasattr(self, "_rng"):
+ seed = torch.initial_seed() # this is different for each dataloader process
+ self._rng = np.random.default_rng(seed=seed)
+
+ # Get the views for the given index and check that the number of views is correct
+ resolution = self._resolutions[ar_idx]
+ if isinstance(self.num_views, int):
+ num_views_to_sample = self.num_views
+ else:
+ num_views_to_sample = self.num_views[num_views_to_sample_idx]
+ views = self._get_views(idx, num_views_to_sample, resolution)
+ if isinstance(self.num_views, int):
+ assert len(views) == self.num_views
+ else:
+ assert len(views) in self.num_views
+
+ for v, view in enumerate(views):
+ # Store the index and other metadata
+ view["idx"] = (idx, ar_idx, v)
+ view["is_metric_scale"] = self.is_metric_scale
+ view["is_synthetic"] = self.is_synthetic
+
+ # Check the depth, intrinsics, and pose data (also other data if present)
+ assert "camera_intrinsics" in view
+ assert "camera_pose" in view
+ assert np.isfinite(view["camera_pose"]).all(), (
+ f"NaN or infinite values in camera pose for view {view_name(view)}"
+ )
+ assert np.isfinite(view["depthmap"]).all(), (
+ f"NaN or infinite values in depthmap for view {view_name(view)}"
+ )
+ assert "valid_mask" not in view
+ assert "pts3d" not in view, (
+ f"pts3d should not be there, they will be computed afterwards based on intrinsics+depthmap for view {view_name(view)}"
+ )
+ if "prior_depth_z" in view:
+ assert np.isfinite(view["prior_depth_z"]).all(), (
+ f"NaN or infinite values in prior_depth_z for view {view_name(view)}"
+ )
+ if "non_ambiguous_mask" in view:
+ assert np.isfinite(view["non_ambiguous_mask"]).all(), (
+ f"NaN or infinite values in non_ambiguous_mask for view {view_name(view)}"
+ )
+
+ # Encode the image
+ width, height = view["img"].size
+ view["true_shape"] = np.int32((height, width))
+ view["img"] = self.transform(view["img"])
+ view["data_norm_type"] = self.data_norm_type
+
+ # Compute the pointmaps, raymap and depth along ray
+ (
+ pts3d,
+ valid_mask,
+ ray_origins_world,
+ ray_directions_world,
+ depth_along_ray,
+ ray_directions_cam,
+ pts3d_cam,
+ ) = get_absolute_pointmaps_and_rays_info(**view)
+ view["pts3d"] = pts3d
+ view["valid_mask"] = valid_mask & np.isfinite(pts3d).all(axis=-1)
+ view["depth_along_ray"] = depth_along_ray
+ view["ray_directions_cam"] = ray_directions_cam
+ view["pts3d_cam"] = pts3d_cam
+
+ # Compute the prior depth along ray if present
+ if "prior_depth_z" in view:
+ prior_pts3d, _ = depthmap_to_camera_coordinates(
+ view["prior_depth_z"], view["camera_intrinsics"]
+ )
+ view["prior_depth_along_ray"] = np.linalg.norm(prior_pts3d, axis=-1)
+ view["prior_depth_along_ray"] = view["prior_depth_along_ray"][..., None]
+ del view["prior_depth_z"]
+
+ # Convert ambiguous mask dtype to match valid mask dtype
+ if "non_ambiguous_mask" in view:
+ view["non_ambiguous_mask"] = view["non_ambiguous_mask"].astype(
+ view["valid_mask"].dtype
+ )
+ else:
+ ambiguous_mask = view["depthmap"] < 0
+ view["non_ambiguous_mask"] = ~ambiguous_mask
+ view["non_ambiguous_mask"] = view["non_ambiguous_mask"].astype(
+ view["valid_mask"].dtype
+ )
+
+ # Check all datatypes
+ for key, val in view.items():
+ res, err_msg = is_good_type(val)
+ assert res, f"{err_msg} with {key}={val} for view {view_name(view)}"
+
+ # Check shapes
+ assert view["depthmap"].shape == view["img"].shape[1:]
+ assert view["depthmap"].shape == view["pts3d"].shape[:2]
+ assert view["depthmap"].shape == view["valid_mask"].shape
+ assert view["depthmap"].shape == view["depth_along_ray"].shape[:2]
+ assert view["depthmap"].shape == view["ray_directions_cam"].shape[:2]
+ assert view["depthmap"].shape == view["pts3d_cam"].shape[:2]
+ if "prior_depth_along_ray" in view:
+ assert view["depthmap"].shape == view["prior_depth_along_ray"].shape[:2]
+ if "non_ambiguous_mask" in view:
+ assert view["depthmap"].shape == view["non_ambiguous_mask"].shape
+
+ # Expand the last dimennsion of the depthmap
+ view["depthmap"] = view["depthmap"][..., None]
+
+ # Append RNG state to the views, this allows to check whether the RNG is in the same state each time
+ view["rng"] = int.from_bytes(self._rng.bytes(4), "big")
+
+ # Compute and store the quaternions and translation for the camera poses
+ # Notation is (x, y, z, w) for quaternions
+ # This also ensures that the camera poses have a positive determinant (right-handed coordinate system)
+ view["camera_pose_quats"] = (
+ Rotation.from_matrix(view["camera_pose"][:3, :3])
+ .as_quat()
+ .astype(view["camera_pose"].dtype)
+ )
+ view["camera_pose_trans"] = view["camera_pose"][:3, 3].astype(
+ view["camera_pose"].dtype
+ )
+
+ # Check the pointmaps, rays, depth along ray, and camera pose quaternions and translation to ensure they are finite
+ assert np.isfinite(view["pts3d"]).all(), (
+ f"NaN in pts3d for view {view_name(view)}"
+ )
+ assert np.isfinite(view["valid_mask"]).all(), (
+ f"NaN in valid_mask for view {view_name(view)}"
+ )
+ assert np.isfinite(view["depth_along_ray"]).all(), (
+ f"NaN in depth_along_ray for view {view_name(view)}"
+ )
+ assert np.isfinite(view["ray_directions_cam"]).all(), (
+ f"NaN in ray_directions_cam for view {view_name(view)}"
+ )
+ assert np.isfinite(view["pts3d_cam"]).all(), (
+ f"NaN in pts3d_cam for view {view_name(view)}"
+ )
+ assert np.isfinite(view["camera_pose_quats"]).all(), (
+ f"NaN in camera_pose_quats for view {view_name(view)}"
+ )
+ assert np.isfinite(view["camera_pose_trans"]).all(), (
+ f"NaN in camera_pose_trans for view {view_name(view)}"
+ )
+ if "prior_depth_along_ray" in view:
+ assert np.isfinite(view["prior_depth_along_ray"]).all(), (
+ f"NaN in prior_depth_along_ray for view {view_name(view)}"
+ )
+
+ return views
+
+ def __getitem__(self, idx):
+ if self.max_num_retries == 0:
+ return self._getitem_fn(idx)
+
+ num_retries = 0
+ while num_retries <= self.max_num_retries:
+ try:
+ return self._getitem_fn(idx)
+ except Exception as e:
+ scene_idx = idx[0] if isinstance(idx, tuple) else idx
+ print(
+ f"Error in {type(self).__name__}.__getitem__ for scene_idx={scene_idx}: {e}"
+ )
+
+ if num_retries >= self.max_num_retries:
+ print(
+ f"Max retries ({self.max_num_retries}) reached, raising the exception"
+ )
+ raise e
+
+ # Retry with a different scene index
+ num_retries += 1
+ if isinstance(idx, tuple):
+ # The scene index is the first element of the tuple
+ idx_list = list(idx)
+ idx_list[0] = np.random.randint(0, len(self))
+ idx = tuple(idx_list)
+ else:
+ # The scene index is idx
+ idx = np.random.randint(0, len(self))
+ scene_idx = idx[0] if isinstance(idx, tuple) else idx
+ print(
+ f"Retrying with scene_idx={scene_idx} ({num_retries} of {self.max_num_retries})"
+ )
+
+
+def is_good_type(v):
+ """
+ Check if a value has an acceptable data type for processing in the dataset.
+
+ Args:
+ v: The value to check.
+
+ Returns:
+ tuple: A tuple containing:
+ - bool: True if the type is acceptable, False otherwise.
+ - str or None: Error message if the type is not acceptable, None otherwise.
+ """
+ if isinstance(v, (str, int, tuple)):
+ return True, None
+ if v.dtype not in (np.float32, torch.float32, bool, np.int32, np.int64, np.uint8):
+ return False, f"bad {v.dtype=}"
+ return True, None
+
+
+def view_name(view, batch_index=None):
+ """
+ Generate a string identifier for a view based on its dataset, label, and instance.
+
+ Args:
+ view (dict): Dictionary containing view information with 'dataset', 'label', and 'instance' keys.
+ batch_index (int, optional): Index to select from batched data. Defaults to None.
+
+ Returns:
+ str: A formatted string in the form "dataset/label/instance".
+ """
+
+ def sel(x):
+ return x[batch_index] if batch_index not in (None, slice(None)) else x
+
+ db = sel(view["dataset"])
+ label = sel(view["label"])
+ instance = sel(view["instance"])
+ return f"{db}/{label}/{instance}"
diff --git a/mapanything/datasets/base/batched_sampler.py b/mapanything/datasets/base/batched_sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..322a66bfadd30b88ea6d14b1f8241dc570263846
--- /dev/null
+++ b/mapanything/datasets/base/batched_sampler.py
@@ -0,0 +1,426 @@
+"""
+Utilities for random sampling under a single or multiple constraints
+
+References: DUSt3R
+"""
+
+import numpy as np
+import torch
+
+
+def round_by(total, multiple, up=False):
+ """
+ Round a number to the nearest multiple of another number.
+
+ Args:
+ total (int): The number to round
+ multiple (int): The multiple to round to
+ up (bool, optional): Whether to round up. Defaults to False.
+
+ Returns:
+ int: The rounded number
+ """
+ if up:
+ total = total + multiple - 1
+ return (total // multiple) * multiple
+
+
+class BatchedRandomSampler:
+ """
+ Random sampling under a constraint: each sample in the batch has the same feature,
+ which is chosen randomly from a known pool of 'features' for each batch.
+
+ For instance, the 'feature' could be the image aspect-ratio.
+
+ The index returned is a tuple (sample_idx, feat_idx).
+ This sampler ensures that each series of `batch_size` indices has the same `feat_idx`.
+ """
+
+ def __init__(
+ self, dataset, batch_size, pool_size, world_size=1, rank=0, drop_last=True
+ ):
+ """
+ Args:
+ dataset: Dataset to sample from
+ batch_size: Number of samples per batch
+ pool_size: Integer representing the size of feature pool
+ world_size: Number of distributed processes
+ rank: Rank of the current process
+ drop_last: Whether to drop the last incomplete batch
+ """
+ self.batch_size = batch_size
+ self.pool_size = pool_size
+
+ self.len_dataset = N = len(dataset)
+ self.total_size = round_by(N, batch_size * world_size) if drop_last else N
+ assert world_size == 1 or drop_last, (
+ "must drop the last batch in distributed mode"
+ )
+
+ # Distributed sampler
+ self.world_size = world_size
+ self.rank = rank
+ self.epoch = None
+
+ def __len__(self):
+ """
+ Get the length of the sampler.
+
+ Returns:
+ int: The number of samples in the sampler for the current process
+ """
+ return self.total_size // self.world_size
+
+ def set_epoch(self, epoch):
+ """
+ Set the epoch for this sampler.
+
+ This should be called before each epoch to ensure proper shuffling of the data.
+
+ Args:
+ epoch (int): The current epoch number
+ """
+ self.epoch = epoch
+
+ def __iter__(self):
+ """
+ Iterator over the indices.
+
+ This method generates random indices for each batch, ensuring that all samples
+ within a batch have the same feature index for the given feature pool.
+
+ Yields:
+ tuple: A tuple containing (sample_idx, feat_idx)
+ """
+ # Prepare RNG
+ if self.epoch is None:
+ assert self.world_size == 1 and self.rank == 0, (
+ "use set_epoch() if distributed mode is used"
+ )
+ seed = int(torch.empty((), dtype=torch.int64).random_().item())
+ else:
+ seed = self.epoch + 777
+ rng = np.random.default_rng(seed=seed)
+
+ # Random indices (will restart from 0 if not drop_last)
+ sample_idxs = np.arange(self.total_size)
+ rng.shuffle(sample_idxs)
+
+ # Random feat_idxs (same across each batch)
+ n_batches = (self.total_size + self.batch_size - 1) // self.batch_size
+ feat_idxs = rng.integers(self.pool_size, size=n_batches)
+ feat_idxs = np.broadcast_to(feat_idxs[:, None], (n_batches, self.batch_size))
+ feat_idxs = feat_idxs.ravel()[: self.total_size]
+
+ # Put them together
+ idxs = np.c_[sample_idxs, feat_idxs] # shape = (total_size, 2)
+
+ # Distributed sampler: we select a subset of batches
+ # Make sure the slice for each node is aligned with batch_size
+ size_per_proc = self.batch_size * (
+ (self.total_size + self.world_size * self.batch_size - 1)
+ // (self.world_size * self.batch_size)
+ )
+ idxs = idxs[self.rank * size_per_proc : (self.rank + 1) * size_per_proc]
+
+ yield from (tuple(idx) for idx in idxs)
+
+
+class BatchedMultiFeatureRandomSampler:
+ """
+ Random sampling under multiple constraints: each sample in the batch has the same features,
+ which are chosen randomly from known pools of 'features' for each batch.
+
+ For instance, the 'features' could be the image aspect-ratio and scene type.
+
+ The index returned is a tuple (sample_idx, feat_idx_1, feat_idx_2, ...).
+ This sampler ensures that each series of `batch_size` indices has the same feature indices.
+ """
+
+ def __init__(
+ self, dataset, batch_size, pool_sizes, world_size=1, rank=0, drop_last=True
+ ):
+ """
+ Args:
+ dataset: Dataset to sample from
+ batch_size: Number of samples per batch
+ pool_sizes: List of integers representing the size of each feature pool
+ world_size: Number of distributed processes
+ rank: Rank of the current process
+ drop_last: Whether to drop the last incomplete batch
+ """
+ self.batch_size = batch_size
+ self.pool_sizes = pool_sizes if isinstance(pool_sizes, list) else [pool_sizes]
+
+ self.len_dataset = N = len(dataset)
+ self.total_size = round_by(N, batch_size * world_size) if drop_last else N
+ assert world_size == 1 or drop_last, (
+ "must drop the last batch in distributed mode"
+ )
+
+ # Distributed sampler
+ self.world_size = world_size
+ self.rank = rank
+ self.epoch = None
+
+ def __len__(self):
+ """
+ Get the length of the sampler.
+
+ Returns:
+ int: The number of samples in the sampler for the current process
+ """
+ return self.total_size // self.world_size
+
+ def set_epoch(self, epoch):
+ """
+ Set the epoch for this sampler.
+
+ This should be called before each epoch to ensure proper shuffling of the data.
+
+ Args:
+ epoch (int): The current epoch number
+ """
+ self.epoch = epoch
+
+ def __iter__(self):
+ """
+ Iterator over the indices.
+
+ This method generates random indices for each batch, ensuring that all samples
+ within a batch have the same feature indices for multiple features.
+
+ Yields:
+ tuple: A tuple containing (sample_idx, feat_idx_1, feat_idx_2, ...)
+ """
+ # Prepare RNG
+ if self.epoch is None:
+ assert self.world_size == 1 and self.rank == 0, (
+ "use set_epoch() if distributed mode is used"
+ )
+ seed = int(torch.empty((), dtype=torch.int64).random_().item())
+ else:
+ seed = self.epoch + 777
+ rng = np.random.default_rng(seed=seed)
+
+ # Random indices (will restart from 0 if not drop_last)
+ sample_idxs = np.arange(self.total_size)
+ rng.shuffle(sample_idxs)
+
+ # Random feat_idxs (same across each batch)
+ n_batches = (self.total_size + self.batch_size - 1) // self.batch_size
+
+ # Generate feature indices for each feature pool
+ all_feat_idxs = []
+ for pool_size in self.pool_sizes:
+ feat_idxs = rng.integers(pool_size, size=n_batches)
+ feat_idxs = np.broadcast_to(
+ feat_idxs[:, None], (n_batches, self.batch_size)
+ )
+ feat_idxs = feat_idxs.ravel()[: self.total_size]
+ all_feat_idxs.append(feat_idxs)
+
+ # Put them together
+ idxs = np.column_stack(
+ [sample_idxs] + all_feat_idxs
+ ) # shape = (total_size, 1 + len(pool_sizes))
+
+ # Distributed sampler: we select a subset of batches
+ # Make sure the slice for each node is aligned with batch_size
+ size_per_proc = self.batch_size * (
+ (self.total_size + self.world_size * self.batch_size - 1)
+ // (self.world_size * self.batch_size)
+ )
+ idxs = idxs[self.rank * size_per_proc : (self.rank + 1) * size_per_proc]
+
+ yield from (tuple(idx) for idx in idxs)
+
+
+class DynamicBatchedMultiFeatureRandomSampler:
+ """
+ Random sampling under multiple constraints with dynamic batch size:
+ each sample in the batch has the same features, which are chosen randomly
+ from known pools of 'features' for each batch.
+
+ The batch size is dynamically determined based on a specified feature index,
+ using a direct mapping from feature values to batch sizes.
+
+ For instance, if one of the features is the number of images in a multi-view set,
+ you can specify different batch sizes for different numbers of images to optimize
+ GPU memory usage. This is achieved by using the feature_to_batch_size_map parameter
+ to directly specify what batch size to use for each feature value.
+
+ The returned index is a list of tuples [(sample_idx, feat_idx_1, feat_idx_2, ...), ...].
+ """
+
+ def __init__(
+ self,
+ dataset,
+ pool_sizes,
+ scaling_feature_idx=0,
+ feature_to_batch_size_map=None,
+ world_size=1,
+ rank=0,
+ drop_last=True,
+ ):
+ """
+ Args:
+ dataset: Dataset to sample from
+ pool_sizes: List of integers representing the size of each feature pool
+ scaling_feature_idx: Index of the feature to use for determining batch size (0-based index into pool_sizes)
+ feature_to_batch_size_map: Optional function or dict that maps feature values directly to batch sizes.
+ For example, if the feature represents number of views, this maps number of views
+ to appropriate batch size that can fit in GPU memory.
+ If None, uses a default batch size of 1 for all feature values.
+ world_size: Number of distributed processes
+ rank: Rank of the current process
+ drop_last: Whether to drop the last incomplete batch
+ """
+ self.pool_sizes = pool_sizes if isinstance(pool_sizes, list) else [pool_sizes]
+ self.scaling_feature_idx = scaling_feature_idx
+
+ # Ensure scaling_feature_idx is valid
+ if scaling_feature_idx < 0 or scaling_feature_idx >= len(self.pool_sizes):
+ raise ValueError(
+ f"scaling_feature_idx must be between 0 and {len(self.pool_sizes) - 1}"
+ )
+
+ # Set up mapping from feature values to batch sizes
+ self.feature_to_batch_size_map = feature_to_batch_size_map
+ if self.feature_to_batch_size_map is None:
+ # Default: batch size of 1 for all feature values
+ self.feature_to_batch_size_map = {
+ i: 1 for i in range(self.pool_sizes[scaling_feature_idx])
+ }
+
+ self.len_dataset = N = len(dataset)
+
+ # We don't know the exact batch size yet, so we use a large number for total_size
+ # This will be adjusted during iteration
+ self.total_size = N
+
+ # Distributed sampler
+ self.world_size = world_size
+ self.rank = rank
+ self.epoch = None
+ self.drop_last = drop_last
+
+ def __len__(self):
+ """
+ Get the approximate length of the sampler.
+
+ Since batch size varies, this is an estimate based on the largest batch size
+ in the mapping, which provides a lower bound on the number of batches.
+
+ Returns:
+ int: The estimated minimum number of samples in the sampler for the current process
+ """
+ # Find the largest batch size in the mapping
+ if callable(self.feature_to_batch_size_map):
+ # If it's a function, sample some values to find the maximum
+ batch_sizes = [
+ self.feature_to_batch_size_map(i)
+ for i in range(self.pool_sizes[self.scaling_feature_idx])
+ ]
+ max_batch_size = max(batch_sizes)
+ else:
+ # If it's a dict or similar, find the maximum directly
+ max_batch_size = max(self.feature_to_batch_size_map.values())
+
+ # Ensure minimum batch size of 1
+ max_batch_size = max(1, max_batch_size)
+
+ # Estimate total batches using the largest batch size
+ # This gives a lower bound on the number of batches
+ total_batches = self.total_size // max_batch_size
+ if not self.drop_last and self.total_size % max_batch_size > 0:
+ total_batches += 1
+
+ # Distribute among processes
+ return total_batches // self.world_size
+
+ def set_epoch(self, epoch):
+ """
+ Set the epoch for this sampler.
+
+ This should be called before each epoch to ensure proper shuffling of the data.
+
+ Args:
+ epoch (int): The current epoch number
+ """
+ self.epoch = epoch
+
+ def __iter__(self):
+ """
+ Iterator over the indices with dynamic batch sizes.
+
+ This method generates random indices for each batch, ensuring that all samples
+ within a batch have the same feature indices for multiple features.
+ The batch size is determined directly from the feature_to_batch_size_map.
+
+ The iterator enforces the length returned by __len__() by stopping after
+ exactly that many batches have been yielded for this process.
+
+ Yields:
+ list of tuples: A batch of tuples, each containing (sample_idx, feat_idx_1, feat_idx_2, ...)
+ """
+ # Prepare RNG
+ if self.epoch is None:
+ assert self.world_size == 1 and self.rank == 0, (
+ "use set_epoch() if distributed mode is used"
+ )
+ seed = int(torch.empty((), dtype=torch.int64).random_().item())
+ else:
+ seed = self.epoch + 777
+ rng = np.random.default_rng(seed=seed)
+
+ # Random indices for the entire dataset
+ sample_idxs = np.arange(self.total_size)
+ rng.shuffle(sample_idxs)
+
+ # Get the target number of batches for this process (enforce strict length)
+ target_batches_for_process = len(self)
+ batches_yielded_for_process = 0
+
+ # Process indices in batches with dynamic sizing
+ idx = 0
+ batch_idx = 0 # Track batch index for even distribution
+ while idx < len(sample_idxs) and (
+ batches_yielded_for_process < target_batches_for_process
+ ):
+ # Randomly select feature indices for this batch
+ feat_idxs = [rng.integers(pool_size) for pool_size in self.pool_sizes]
+
+ # Get the scaling feature value
+ scaling_feat = feat_idxs[self.scaling_feature_idx]
+
+ # Get the batch size directly from the mapping
+ if callable(self.feature_to_batch_size_map):
+ batch_size = self.feature_to_batch_size_map(scaling_feat)
+ else:
+ batch_size = self.feature_to_batch_size_map.get(scaling_feat, 1)
+
+ # Ensure minimum batch size of 1
+ batch_size = max(1, batch_size)
+
+ # Ensure we don't go beyond available samples
+ remaining = len(sample_idxs) - idx
+ if remaining < batch_size:
+ if self.drop_last:
+ break
+ batch_size = remaining
+
+ # Create batch with consistent feature indices
+ batch = []
+ for i in range(batch_size):
+ if idx + i < len(sample_idxs):
+ sample_idx = sample_idxs[idx + i]
+ batch.append(tuple([sample_idx] + feat_idxs))
+
+ # Distribute batches among processes in round-robin fashion
+ if len(batch) > 0 and (batch_idx % self.world_size == self.rank):
+ yield batch
+ batches_yielded_for_process += 1
+
+ batch_idx += 1 # Increment batch index
+ idx += batch_size
diff --git a/mapanything/datasets/base/easy_dataset.py b/mapanything/datasets/base/easy_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..0ac429bf808762e4f83f04af50f562e885a3dd61
--- /dev/null
+++ b/mapanything/datasets/base/easy_dataset.py
@@ -0,0 +1,473 @@
+"""
+Base dataset class that enables easy resizing and combining
+
+References: DUSt3R
+"""
+
+import numpy as np
+
+from mapanything.datasets.base.batched_sampler import (
+ BatchedMultiFeatureRandomSampler,
+ DynamicBatchedMultiFeatureRandomSampler,
+)
+
+
+class EasyDataset:
+ """
+ Dataset that can be easily resized and combined.
+
+ Examples:
+ ---------
+ 2 * dataset ==> Duplicate each element 2x
+
+ 10 @ dataset ==> Set the size to 10 (random sampling, duplicates if necessary)
+
+ Dataset1 + Dataset2 ==> Concatenate datasets
+ """
+
+ def __add__(self, other):
+ """
+ Concatenate this dataset with another dataset.
+
+ Args:
+ other (EasyDataset): Another dataset to concatenate with this one
+
+ Returns:
+ CatDataset: A new dataset that is the concatenation of this dataset and the other
+ """
+ return CatDataset([self, other])
+
+ def __rmul__(self, factor):
+ """
+ Multiply the dataset by a factor, duplicating each element.
+
+ Args:
+ factor (int): Number of times to duplicate each element
+
+ Returns:
+ MulDataset: A new dataset with each element duplicated 'factor' times
+ """
+ return MulDataset(factor, self)
+
+ def __rmatmul__(self, factor):
+ """
+ Resize the dataset to a specific size using random sampling.
+
+ Args:
+ factor (int): The new size of the dataset
+
+ Returns:
+ ResizedDataset: A new dataset with the specified size
+ """
+ return ResizedDataset(factor, self)
+
+ def set_epoch(self, epoch):
+ """
+ Set the current epoch for all constituent datasets.
+
+ Args:
+ epoch (int): The current epoch number
+ """
+ pass # nothing to do by default
+
+ def make_sampler(
+ self,
+ batch_size=None,
+ shuffle=True,
+ world_size=1,
+ rank=0,
+ drop_last=True,
+ max_num_of_images_per_gpu=None,
+ use_dynamic_sampler=True,
+ ):
+ """
+ Create a sampler for this dataset.
+
+ Args:
+ batch_size (int, optional): Number of samples per batch (used for non-dynamic sampler). Defaults to None.
+ shuffle (bool, optional): Whether to shuffle the dataset. Defaults to True.
+ world_size (int, optional): Number of distributed processes. Defaults to 1.
+ rank (int, optional): Rank of the current process. Defaults to 0.
+ drop_last (bool, optional): Whether to drop the last incomplete batch. Defaults to True.
+ max_num_of_images_per_gpu (int, optional): Maximum number of images per GPU for dynamic batching. Defaults to None.
+ use_dynamic_sampler (bool, optional): Whether to use the dynamic sampler. Defaults to True.
+
+ Returns:
+ DynamicBatchedMultiFeatureRandomSampler or BatchedMultiFeatureRandomSampler: A sampler for this dataset
+
+ Raises:
+ NotImplementedError: If shuffle is False
+ ValueError: If num_views has an invalid type or required parameters are missing
+ """
+ if not (shuffle):
+ raise NotImplementedError() # cannot deal yet
+
+ if isinstance(self.num_views, int):
+ num_of_aspect_ratios = len(self._resolutions)
+ feature_pool_sizes = [num_of_aspect_ratios]
+ scaling_feature_idx = 0 # Use aspect ratio as scaling feature
+ elif isinstance(self.num_views, list):
+ num_of_aspect_ratios = len(self._resolutions)
+ num_of_num_views = len(self.num_views)
+ feature_pool_sizes = [num_of_aspect_ratios, num_of_num_views]
+ scaling_feature_idx = 1 # Use num_views as scaling feature
+ else:
+ raise ValueError(
+ f"Bad type for {self.num_views=}, should be int or list of ints"
+ )
+
+ if use_dynamic_sampler:
+ if max_num_of_images_per_gpu is None:
+ raise ValueError(
+ "max_num_of_images_per_gpu must be provided when using dynamic sampler"
+ )
+
+ # Create feature-to-batch-size mapping
+ if isinstance(self.num_views, list):
+ # Map num_views_idx to batch size: max(1, max_num_of_images_per_gpu // (num_views_idx + dataset.num_views_min))
+ feature_to_batch_size_map = {}
+ for num_views_idx, num_views in enumerate(self.num_views):
+ batch_size_for_multi_view_sets = max(
+ 1, max_num_of_images_per_gpu // num_views
+ )
+ feature_to_batch_size_map[num_views_idx] = (
+ batch_size_for_multi_view_sets
+ )
+ else:
+ # For fixed num_views, use a simple mapping
+ feature_to_batch_size_map = {
+ 0: max(1, max_num_of_images_per_gpu // self.num_views)
+ }
+
+ return DynamicBatchedMultiFeatureRandomSampler(
+ self,
+ pool_sizes=feature_pool_sizes,
+ scaling_feature_idx=scaling_feature_idx,
+ feature_to_batch_size_map=feature_to_batch_size_map,
+ world_size=world_size,
+ rank=rank,
+ drop_last=drop_last,
+ )
+ else:
+ if batch_size is None:
+ raise ValueError(
+ "batch_size must be provided when not using dynamic sampler"
+ )
+
+ return BatchedMultiFeatureRandomSampler(
+ self,
+ batch_size,
+ feature_pool_sizes,
+ world_size=world_size,
+ rank=rank,
+ drop_last=drop_last,
+ )
+
+
+class MulDataset(EasyDataset):
+ """Artifically augmenting the size of a dataset."""
+
+ multiplicator: int
+
+ def __init__(self, multiplicator, dataset):
+ """
+ Initialize a dataset that artificially augments the size of another dataset.
+
+ Args:
+ multiplicator (int): Factor by which to multiply the dataset size
+ dataset (EasyDataset): The dataset to augment
+ """
+ assert isinstance(multiplicator, int) and multiplicator > 0
+ self.multiplicator = multiplicator
+ self.dataset = dataset
+
+ def __len__(self):
+ """
+ Get the length of the dataset.
+
+ Returns:
+ int: The number of samples in the dataset
+ """
+ return self.multiplicator * len(self.dataset)
+
+ def __repr__(self):
+ """
+ Get a string representation of the dataset.
+
+ Returns:
+ str: String representation showing the multiplication factor and the original dataset
+ """
+ return f"{self.multiplicator}*{repr(self.dataset)}"
+
+ def __getitem__(self, idx):
+ """
+ Get an item from the dataset.
+
+ Args:
+ idx: Index or tuple of indices to retrieve
+
+ Returns:
+ The item at the specified index from the original dataset
+ """
+ if isinstance(idx, tuple):
+ other = idx[1:]
+ idx = idx[0]
+ new_idx = (idx // self.multiplicator, *other)
+ return self.dataset[new_idx]
+ else:
+ return self.dataset[idx // self.multiplicator]
+
+ @property
+ def _resolutions(self):
+ """
+ Get the resolutions of the dataset.
+
+ Returns:
+ The resolutions from the original dataset
+ """
+ return self.dataset._resolutions
+
+ @property
+ def num_views(self):
+ """
+ Get the number of views used for the dataset.
+
+ Returns:
+ int or list: The number of views parameter from the original dataset
+ """
+ return self.dataset.num_views
+
+
+class ResizedDataset(EasyDataset):
+ """Artifically changing the size of a dataset."""
+
+ new_size: int
+
+ def __init__(self, new_size, dataset):
+ """
+ Initialize a dataset with an artificially changed size.
+
+ Args:
+ new_size (int): The new size of the dataset
+ dataset (EasyDataset): The original dataset
+ """
+ assert isinstance(new_size, int) and new_size > 0
+ self.new_size = new_size
+ self.dataset = dataset
+
+ def __len__(self):
+ """
+ Get the length of the dataset.
+
+ Returns:
+ int: The new size of the dataset
+ """
+ return self.new_size
+
+ def __repr__(self):
+ """
+ Get a string representation of the dataset.
+
+ Returns:
+ str: String representation showing the new size and the original dataset
+ """
+ size_str = str(self.new_size)
+ for i in range((len(size_str) - 1) // 3):
+ sep = -4 * i - 3
+ size_str = size_str[:sep] + "_" + size_str[sep:]
+ return f"{size_str} @ {repr(self.dataset)}"
+
+ def set_epoch(self, epoch):
+ """
+ Set the current epoch and generate a new random mapping of indices.
+
+ This method must be called before using __getitem__.
+
+ Args:
+ epoch (int): The current epoch number
+ """
+ # This random shuffle only depends on the epoch
+ rng = np.random.default_rng(seed=epoch + 777)
+
+ # Shuffle all indices
+ perm = rng.permutation(len(self.dataset))
+
+ # Calculate how many repetitions we need
+ num_repetitions = 1 + (len(self) - 1) // len(self.dataset)
+
+ # Rotary extension until target size is met
+ shuffled_idxs = np.concatenate([perm] * num_repetitions)
+ self._idxs_mapping = shuffled_idxs[: self.new_size]
+
+ # Generate the seed offset for each repetition
+ # This is needed to ensure we see unique samples when we repeat a scene
+ seed_offset_per_repetition = [
+ np.full(len(self.dataset), i) for i in range(num_repetitions)
+ ]
+ seed_offset_idxs = np.concatenate(seed_offset_per_repetition)
+ self._idxs_seed_offset = seed_offset_idxs[: self.new_size]
+
+ assert len(self._idxs_mapping) == self.new_size
+ assert len(self._idxs_seed_offset) == self.new_size
+
+ def __getitem__(self, idx):
+ """
+ Get an item from the dataset.
+
+ Args:
+ idx: Index or tuple of indices to retrieve
+
+ Returns:
+ The item at the mapped index from the original dataset
+
+ Raises:
+ AssertionError: If set_epoch has not been called
+ """
+ assert hasattr(self, "_idxs_mapping"), (
+ "You need to call dataset.set_epoch() to use ResizedDataset.__getitem__()"
+ )
+ if isinstance(idx, tuple):
+ other = idx[1:]
+ idx = idx[0]
+ self.dataset._set_seed_offset(self._idxs_seed_offset[idx])
+ new_idx = (self._idxs_mapping[idx], *other)
+ return self.dataset[new_idx]
+ else:
+ self.dataset._set_seed_offset(self._idxs_seed_offset[idx])
+ return self.dataset[self._idxs_mapping[idx]]
+
+ @property
+ def _resolutions(self):
+ """
+ Get the resolutions of the dataset.
+
+ Returns:
+ The resolutions from the original dataset
+ """
+ return self.dataset._resolutions
+
+ @property
+ def num_views(self):
+ """
+ Get the number of views used for the dataset.
+
+ Returns:
+ int or list: The number of views parameter from the original dataset
+ """
+ return self.dataset.num_views
+
+
+class CatDataset(EasyDataset):
+ """Concatenation of several datasets"""
+
+ def __init__(self, datasets):
+ """
+ Initialize a dataset that is a concatenation of several datasets.
+
+ Args:
+ datasets (list): List of EasyDataset instances to concatenate
+ """
+ for dataset in datasets:
+ assert isinstance(dataset, EasyDataset)
+ self.datasets = datasets
+ self._cum_sizes = np.cumsum([len(dataset) for dataset in datasets])
+
+ def __len__(self):
+ """
+ Get the length of the concatenated dataset.
+
+ Returns:
+ int: Total number of samples across all datasets
+ """
+ return self._cum_sizes[-1]
+
+ def __repr__(self):
+ """
+ Get a string representation of the concatenated dataset.
+
+ Returns:
+ str: String representation showing all concatenated datasets joined by '+'
+ """
+ # Remove uselessly long transform
+ return " + ".join(
+ repr(dataset).replace(
+ ",transform=Compose( ToTensor() Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)))",
+ "",
+ )
+ for dataset in self.datasets
+ )
+
+ def set_epoch(self, epoch):
+ """
+ Set the current epoch for all constituent datasets.
+
+ Args:
+ epoch (int): The current epoch number
+ """
+ for dataset in self.datasets:
+ dataset.set_epoch(epoch)
+
+ def __getitem__(self, idx):
+ """
+ Get an item from the concatenated dataset.
+
+ Args:
+ idx: Index or tuple of indices to retrieve
+
+ Returns:
+ The item at the specified index from the appropriate constituent dataset
+
+ Raises:
+ IndexError: If the index is out of range
+ """
+ other = None
+ if isinstance(idx, tuple):
+ other = idx[1:]
+ idx = idx[0]
+
+ if not (0 <= idx < len(self)):
+ raise IndexError()
+
+ db_idx = np.searchsorted(self._cum_sizes, idx, "right")
+ dataset = self.datasets[db_idx]
+ new_idx = idx - (self._cum_sizes[db_idx - 1] if db_idx > 0 else 0)
+
+ if other is not None:
+ new_idx = (new_idx, *other)
+ return dataset[new_idx]
+
+ @property
+ def _resolutions(self):
+ """
+ Get the resolutions of the dataset.
+
+ Returns:
+ The resolutions from the first dataset (all datasets must have the same resolutions)
+
+ Raises:
+ AssertionError: If datasets have different resolutions
+ """
+ resolutions = self.datasets[0]._resolutions
+ for dataset in self.datasets[1:]:
+ assert tuple(dataset._resolutions) == tuple(resolutions), (
+ "All datasets must have the same resolutions"
+ )
+ return resolutions
+
+ @property
+ def num_views(self):
+ """
+ Get the number of views used for the dataset.
+
+ Returns:
+ int or list: The number of views parameter from the first dataset
+
+ Raises:
+ AssertionError: If datasets have different num_views
+ """
+ num_views = self.datasets[0].num_views
+ for dataset in self.datasets[1:]:
+ assert dataset.num_views == num_views, (
+ "All datasets must have the same num_views and variable_num_views parameters"
+ )
+ return num_views
diff --git a/mapanything/datasets/utils/__init__.py b/mapanything/datasets/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/mapanything/datasets/utils/data_splits.py b/mapanything/datasets/utils/data_splits.py
new file mode 100644
index 0000000000000000000000000000000000000000..2a69ec6df5a98cd21e15c9de1791be49903d8383
--- /dev/null
+++ b/mapanything/datasets/utils/data_splits.py
@@ -0,0 +1,1741 @@
+"""
+Modules containing dataset split information
+"""
+
+
+class BlendedMVSSplits:
+ """
+ This class contains the information about the BlendedMVS dataset splits.
+ """
+
+ def __init__(self):
+ """
+ The splits are generated using the following logic:
+ # Get all seqls and seqhs using self.blendedmvs_info.all_sequences
+ all_sequences = self.blendedmvs_info.all_sequences
+ all_seqls = [int(seq[8:], 16) for seq in all_sequences]
+ all_seqhs = [int(seq[:8], 16) for seq in all_sequences]
+ # Split the seqls (& corresponding seqhs) using the DUSt3R train/val split logic
+ if split is None:
+ selection = slice(None)
+ elif split in ["train", "overfit"]:
+ # select 90% of all scenes
+ selection = [(seql % 10) > 0 for seql in all_seqls]
+ elif split == "val":
+ # select 10% of all scenes
+ selection = [(seql % 10) == 0 for seql in all_seqls]
+ else:
+ raise ValueError(f"Unknown split {split}, must be None, train, val or overfit")
+ # Filter sequences based on the selection
+ selected_seqls = [seql for seql, sel in zip(all_seqls, selection) if sel]
+ selected_seqhs = [seqh for seqh, sel in zip(all_seqhs, selection) if sel]
+ # Put them back into sequence names f"{seqh:08x}{seql:016x}"
+ sequence_names = [f"{seqh:08x}{seql:016x}" for seqh, seql in zip(selected_seqhs, selected_seqls)]
+ # Remove invalid sequence names which don't exist in self.blendedmvs_info.sequences
+ valid_sequences = set(self.blendedmvs_info.sequences)
+ valid_sequence_names = [name for name in sequence_names if name in valid_sequences]
+ """
+ # All the 502 sequences in the dataset (totals to 115k images)
+ self.all_scenes = [
+ "000000000000000000000000",
+ "00000000000000000000000a",
+ "00000000000000000000000b",
+ "00000000000000000000000c",
+ "00000000000000000000000d",
+ "00000000000000000000000e",
+ "00000000000000000000000f",
+ "000000000000000000000001",
+ "00000000000000000000001a",
+ "00000000000000000000001b",
+ "00000000000000000000001d",
+ "000000000000000000000002",
+ "000000000000000000000003",
+ "000000000000000000000004",
+ "000000000000000000000005",
+ "5a2a95f032a1c655cfe3de62",
+ "5a2af22b32a1c655cfe46013",
+ "5a2ba6de32a1c655cfe51b79",
+ "5a3b9731e24cd76dad1a5f1b",
+ "5a3ca9cb270f0e3f14d0eddb",
+ "5a3cb4e4270f0e3f14d12f43",
+ "5a03e732454a8a7ec672776c",
+ "5a3f4aba5889373fbbc5d3b5",
+ "5a4a38dad38c8a075495b5d2",
+ "5a5a1e48d62c7a12d5d00e47",
+ "5a6b1c418d100c2f8fdc4411",
+ "5a6feeb54a7fbc3f874f9db7",
+ "5a7cb1d6fe5c0d6fb53e64fb",
+ "5a7d3db14989e929563eb153",
+ "5a8aa0fab18050187cbe060e",
+ "5a9e5df65baeef72b4a021cd",
+ "5a48ba95c7dab83a7d7b44ed",
+ "5a48c4e9c7dab83a7d7b5cc7",
+ "5a48d4b2c7dab83a7d7b9851",
+ "5a69c47d0d5d0a7f3b2e9752",
+ "5a77b46b318efe6c6736e68a",
+ "5a355c271b63f53d5970f362",
+ "5a489fb1c7dab83a7d7b1070",
+ "5a533e8034d7582116e34209",
+ "5a562fc7425d0f5186314725",
+ "5a572fd9fc597b0478a81d14",
+ "5a588a8193ac3d233f77fbca",
+ "5a618c72784780334bc1972d",
+ "5a752d42acc41e2423f17674",
+ "5a969eea91dfc339a9a3ad2c",
+ "5a8315f624b8e938486e0bd8",
+ "5a57542f333d180827dfc132",
+ "5a0271884e62597cdee0d0eb",
+ "5a6400933d809f1d8200af15",
+ "5a6464143d809f1d8208c43c",
+ "5a563183425d0f5186314855",
+ "5aa0f9d7a9efce63548c69a1",
+ "5aa0f478a9efce63548c1cb4",
+ "5aa7db90bfdd572271e95246",
+ "5aa235f64a17b335eeaf9609",
+ "5aa515e613d42d091d29d300",
+ "5aa1196ea9efce63548ed649",
+ "5aaadd4cbc13235570d178a7",
+ "5ab6af12ac4291329b1072ab",
+ "5ab7e00aac4291329b15864d",
+ "5ab8b8e029f5351f7f2ccf59",
+ "5ab74bf2ac4291329b11e879",
+ "5ab85f1dac4291329b17cb50",
+ "5ab8713ba3799a1d138bd69a",
+ "5abc2506b53b042ead637d86",
+ "5acc7459a7853c4b5ebbef59",
+ "5acf8ca0f3d8a750097e4b15",
+ "5adc6bd52430a05ecb2ffb85",
+ "5ae2e9c5fe405c5076abc6b2",
+ "5af02e904c8216544b4ab5a2",
+ "5af28cea59bc705737003253",
+ "5af545d0559359053d25dcf5",
+ "5afacb69ab00705d0cefdd5b",
+ "5b2c67b5e0878c381608b8d8",
+ "5b3b2b9e8d46a939f933fdc0",
+ "5b3b353d8d46a939f93524b9",
+ "5b6e716d67b396324c2d77cb",
+ "5b6eff8b67b396324c5b2672",
+ "5b7a3890fc8fcf6781e2593a",
+ "5b21e18c58e2823a67a10dd8",
+ "5b60fa0c764f146feef84df0",
+ "5b69cc0cb44b61786eb959bf",
+ "5b78e57afc8fcf6781d0c3ba",
+ "5b192eb2170cf166458ff886",
+ "5b558a928bbfb62204e77ba2",
+ "5b864d850d072a699b32f4ae",
+ "5b908d3dc6ab78485f3d24a9",
+ "5b950c71608de421b1e7318f",
+ "5b4933abf2b5f44e95de482a",
+ "5b08286b2775267d5b0634ba",
+ "5b37189a35304b6f75e7583e",
+ "5b271079e0878c3816dacca4",
+ "5b22269758e2823a67a3bd03",
+ "5b62647143840965efc0dbde",
+ "5ba19a8a360c7c30c1c169df",
+ "5ba75d79d76ffa2c86cf2f05",
+ "5bb7a08aea1cfa39f1a947ab",
+ "5bb8a49aea1cfa39f1aa7f75",
+ "5bbb6eb2ea1cfa39f1af7e0c",
+ "5bc5f0e896b66a2cd8f9bd36",
+ "5bccd6beca24970bce448134",
+ "5bce7ac9ca24970bce4934b6",
+ "5bcf979a6d5f586b95c258cd",
+ "5bd43b4ba6b28b1ee86b92dd",
+ "5be3a5fb8cfdd56947f6b67c",
+ "5be3ae47f44e235bdbbc9771",
+ "5be4ab93870d330ff2dce134",
+ "5be47bf9b18881428d8fbc1d",
+ "5be883a4f98cee15019d5b83",
+ "5bea87f4abd34c35e1860ab5",
+ "5beb6e66abd34c35e18e66b9",
+ "5bf3a82cd439231948877aed",
+ "5bf7d63575c26f32dbf7413b",
+ "5bf17c0fd439231948355385",
+ "5bf26cbbd43923194854b270",
+ "5bf03590d4392319481971dc",
+ "5bf18642c50e6f7f8bdbd492",
+ "5bf21799d43923194842c001",
+ "5bfc9d5aec61ca1dd69132a2",
+ "5bfd0f32ec61ca1dd69dc77b",
+ "5bfe5ae0fe0ea555e6a969ca",
+ "5bff3c5cfe0ea555e6bcbf3a",
+ "5c0d13b795da9479e12e2ee9",
+ "5c1af2e2bee9a723c963d019",
+ "5c1b1500bee9a723c96c3e78",
+ "5c1dbf200843bc542d8ef8c4",
+ "5c1f33f1d33e1f2e4aa6dda4",
+ "5c2b3ed5e611832e8aed46bf",
+ "5c20ca3a0843bc542d94e3e2",
+ "5c062d84a96e33018ff6f0a6",
+ "5c189f2326173c3a09ed7ef3",
+ "5c1892f726173c3a09ea9aeb",
+ "5c34300a73a8df509add216d",
+ "5c34529873a8df509ae57b58",
+ "000000000000000000000006",
+ "000000000000000000000007",
+ "000000000000000000000008",
+ "000000000000000000000009",
+ "000000000000000000000010",
+ "000000000000000000000011",
+ "000000000000000000000012",
+ "000000000000000000000015",
+ "000000000000000000000016",
+ "000000000000000000000017",
+ "000000000000000000000018",
+ "000000000000000000000019",
+ "56d73ba74bd29b8c35abade2",
+ "56f34064e296120e10484dc4",
+ "57a4a7bb6b9272286e26dc18",
+ "57f8d9bbe73f6760f10e916a",
+ "58a0a2f33d0b4542479a11b1",
+ "58a0dd1a3d0b4542479a28f3",
+ "58a1a7914a4d262a170b1101",
+ "58a1bc804a4d262a170b2f01",
+ "58a1d9d14a4d262a170b58fe",
+ "58a01dea38486e3c98475871",
+ "58a1f5d74a4d262a170b65fc",
+ "58a2a09e156b87103d3d668c",
+ "58a2d9c3156b87103d3da90f",
+ "58a3ccb0156b87103d3e4332",
+ "58a3f2f8156b87103d3e5838",
+ "58a3f6c0156b87103d3e5971",
+ "58a3fc95156b87103d3e5d9b",
+ "58a07ce53d0b45424799fdde",
+ "58a07f233d0b45424799ffe7",
+ "58a44df2156b87103d3ee239",
+ "58a164f73d0b4542479a7a8e",
+ "58a0365e38486e3c984783eb",
+ "58a439cf156b87103d3ec885",
+ "58a464aa156b87103d3eec04",
+ "58a4452f156b87103d3ed55b",
+ "58a160983d0b4542479a7347",
+ "58a186444a4d262a170ae3ae",
+ "58a285424a4d262a170baf3e",
+ "58a41819156b87103d3e92a5",
+ "58a44463156b87103d3ed45e",
+ "58a47552156b87103d3f00a4",
+ "58c4bb4f4a69c55606122be4",
+ "58c6451e4a69c556061894f1",
+ "58ca7014affdfd07c70a95ce",
+ "58cf4771d0f5fb221defe6da",
+ "58d36897f387231e6c929903",
+ "58eaf1513353456af3a1682a",
+ "58f7f7299f5b5647873cb110",
+ "58f73e7c9f5b56478738929f",
+ "59a8f851597729752c31e7e0",
+ "59a452bf9b460239aa5d1c72",
+ "59a9619a825418241fb88191",
+ "59acd2f4b891807f439c8992",
+ "59bf97fe7e7b31545da34439",
+ "59c1c3e2fd6e3d4ead9f1013",
+ "59d2657f82ca7774b1ec081d",
+ "59da1fb88a126011d0394ae9",
+ "59e75a2ca9e91f2c5526005d",
+ "59e864b2a9e91f2c5529325f",
+ "59ecfd02e225f6492d20fcc9",
+ "59f37f74b45be2233001ba18",
+ "59f70ab1e5c5d366af29bf3e",
+ "59f87d0bfa6280566fb38c9a",
+ "59f363a8b45be22330016cad",
+ "564a27b26d07883f460d8ab0",
+ "565fb1dead14d4154dae2b94",
+ "567a0fb0a825d2fb79ac9a20",
+ "569b92eb826bcba945ca002b",
+ "576fefa017ce5a16397e87fd",
+ "584a7333fe3cb463906c9fe6",
+ "584aa8e9fe3cb463906cc7d0",
+ "584ad76bfe3cb463906ce6dc",
+ "584af003fe3cb463906d0e9b",
+ "584b9a747072670e72bfc49d",
+ "584b671f7072670e72bfaaf8",
+ "584b81747072670e72bfbbfd",
+ "584ba35f7072670e72bfca4d",
+ "584ba5977072670e72bfcc2d",
+ "584bc53c7072670e72bfe85f",
+ "584bc3997072670e72bfe58d",
+ "584bc4407072670e72bfe665",
+ "584bd5587072670e72bffe39",
+ "584bdadf7072670e72c0005c",
+ "584be5ed7072670e72c007b3",
+ "584c9ad27072670e72c060c5",
+ "584c9cc67072670e72c063a1",
+ "584c58b77072670e72c03990",
+ "584cea557072670e72c07fb4",
+ "584d19d47072670e72c0c6c0",
+ "584dfe467072670e72c1665a",
+ "584e875c7072670e72c1ec94",
+ "584e05667072670e72c17167",
+ "584f94e87072670e72c2d3f7",
+ "584fdffd7072670e72c32dc7",
+ "584fe07f7072670e72c32e59",
+ "585a2a71b338a62ad50138dc",
+ "585a206ab338a62ad501298f",
+ "585a217cb338a62ad5012b38",
+ "585b34afb338a62ad501e836",
+ "585bb25fc49c8507c3ce7812",
+ "585bbe55c49c8507c3ce81cd",
+ "585d6c8a2a57cc11d4920a1e",
+ "585e54c72a57cc11d492f71a",
+ "585e34302a57cc11d492be30",
+ "585ee0632a57cc11d4933608",
+ "585f9661712e2761468dabca",
+ "585ffe9a712e2761468df643",
+ "586a37ec9d1b5e34c28184fc",
+ "586a515a9d1b5e34c281b431",
+ "586a94939d1b5e34c2823b5d",
+ "586abc689d1b5e34c2826360",
+ "586b0e219d1b5e34c2828862",
+ "586b3db89d1b5e34c282cd52",
+ "586b4c459d1b5e34c282e66d",
+ "586b7d7d9d1b5e34c283359e",
+ "586b8f149d1b5e34c283497c",
+ "586b8f629d1b5e34c28349d6",
+ "586c4c4d9d1b5e34c28391a1",
+ "586c5b5b9d1b5e34c2839a5b",
+ "586c9fdf9d1b5e34c283b657",
+ "586c48329d1b5e34c2838e80",
+ "586caab99d1b5e34c283c213",
+ "586cd0779d1b5e34c28403a7",
+ "586d6d249d1b5e34c284b80e",
+ "586d8a029d1b5e34c284c948",
+ "586d55af9d1b5e34c284a999",
+ "586d07869d1b5e34c2842e5b",
+ "586d27489d1b5e34c28453af",
+ "586df9849d1b5e34c28506de",
+ "586e279c9d1b5e34c2852180",
+ "587bc5ec2366dd5d06e262c1",
+ "587c1abf2366dd5d06e28901",
+ "587c03f12366dd5d06e27722",
+ "587c19da2366dd5d06e2877b",
+ "587c31b92366dd5d06e2a9dc",
+ "587c87d02366dd5d06e2f989",
+ "587c97a52366dd5d06e30a96",
+ "587c45192366dd5d06e2c0eb",
+ "587cec702366dd5d06e37862",
+ "587cef0a2366dd5d06e379e3",
+ "587db5872366dd5d06e3e0af",
+ "587e2b1d2366dd5d06e41af0",
+ "587e2ea62366dd5d06e41f2e",
+ "587e5cb52366dd5d06e4486e",
+ "587eb1822366dd5d06e45f29",
+ "587f365d2366dd5d06e4906e",
+ "588a9c5fec4d5a1c088ec350",
+ "588a34cfec4d5a1c088ea8d1",
+ "588ab5bdec4d5a1c088ed60f",
+ "588aff9d90414422fbe7885a",
+ "588b20d290414422fbe79f40",
+ "588c08d590414422fbe8200b",
+ "588c203d90414422fbe8319e",
+ "588c989a90414422fbe86d96",
+ "588ca09d90414422fbe871a1",
+ "588cce2190414422fbe88520",
+ "588cd5ef90414422fbe8875c",
+ "588cf0ad90414422fbe8a20f",
+ "588e0d8c90414422fbe8f8b2",
+ "588e01c490414422fbe8ee2a",
+ "588e35e690414422fbe90a53",
+ "588f017e90414422fbe9b74b",
+ "588f095190414422fbe9c1ee",
+ "589aca717dc3d323d55671c4",
+ "589af2c97dc3d323d55691e8",
+ "589b49ea7dc3d323d556d9b4",
+ "589b04287dc3d323d556a185",
+ "589bf6a57dc3d323d55743ab",
+ "589c3c497dc3d323d5578468",
+ "589c3c577dc3d323d5578480",
+ "589c300f7dc3d323d5577926",
+ "589c24527dc3d323d5577126",
+ "589c35457dc3d323d5577d8d",
+ "589ca6a6b896147a1b73aff7",
+ "589d1e1fb896147a1b73ee5b",
+ "589d5c58b896147a1b742256",
+ "589d95538fa2cf375df3317b",
+ "589df0ffb504a864ad63521a",
+ "589ea316b504a864ad639a2b",
+ "589ec97cb504a864ad63adc3",
+ "589f214338486e3c9846f123",
+ "589fdfe738486e3c984736cf",
+ "590c2d70336bb52a190be886",
+ "590f91851225725be9e25d4e",
+ "591a467a6109e14d4f09b776",
+ "591cf3033162411cf9047f37",
+ "591ea44850991c70dc99a207",
+ "599aa591d5b41f366fed0d58",
+ "5643df56138263b51db1b5f3",
+ "5644bdac138263b51db9f669",
+ "5692a4c2adafac1f14201821",
+ "5850d4f97072670e72c425d6",
+ "5854c405804be105852330fe",
+ "5855a4fc804be1058523bd75",
+ "5856ac15804be105852419d8",
+ "5856ae8b804be10585241bae",
+ "5856b460804be10585242059",
+ "5857aa5ab338a62ad5ff4dbe",
+ "5857acf8b338a62ad5ff5107",
+ "5858db6cb338a62ad500103b",
+ "5858dbcab338a62ad5001081",
+ "5859d84fb338a62ad500e5cf",
+ "5861d8ea712e2761468f3cb3",
+ "5863edf8712e27614690cce0",
+ "5864a935712e2761469111b4",
+ "5864b076712e27614691197e",
+ "5864da88712e276146913d8b",
+ "5865f4a8712e27614691e39b",
+ "5867a434833dfe3f7b88edaf",
+ "5868cd15833dfe3f7b89bfa3",
+ "5880b3692366dd5d06e5d534",
+ "5880e3422366dd5d06e5ff8e",
+ "5880f0ef2366dd5d06e6166e",
+ "5881d2bfb6844814c136a119",
+ "5881f11d8ce2c2754d0714c3",
+ "5881fee18ce2c2754d0723f8",
+ "5882cda2b116682b4adebd25",
+ "5882d58fb116682b4adec7db",
+ "5884c256932ba84fbed70bf5",
+ "5884cc13932ba84fbed71ec4",
+ "5885bc5296fa095e0671a7f0",
+ "5886d14cb791366d617a362c",
+ "5888becfc02346100f4b0b21",
+ "5888e408c02346100f4b1a29",
+ "5889da66ec4d5a1c088e5187",
+ "5889e344ec4d5a1c088e59be",
+ "5889e754ec4d5a1c088e60ba",
+ "5890c16b90414422fbeb0262",
+ "5891d8ae9a8c0314c5cd30ab",
+ "5891d0479a8c0314c5cd2abd",
+ "5891ecf19a8c0314c5cd490a",
+ "5892c0cd9a8c0314c5cdc977",
+ "5894ab309a8c0314c5cee57d",
+ "5895a6a89a8c0314c5cfca7c",
+ "5895b8c29a8c0314c5cfd051",
+ "5895d38f9a8c0314c5cfe50c",
+ "5895f2329a8c0314c5d00117",
+ "5896bb989a8c0314c5d086b6",
+ "5896ebf39a8c0314c5d0a8c4",
+ "5898b1bac9dccc22987b7f74",
+ "5898b6ffc9dccc22987b8a03",
+ "5898b31cc9dccc22987b82ec",
+ "5898bbaac9dccc22987b8eba",
+ "5899cfa6b76d7a3780a4cb64",
+ "5899e5dcb76d7a3780a4ecc1",
+ "5947b62af1b45630bd0c2a02",
+ "57102be2877e1421026358af",
+ "57153d4031bb9900425bde85",
+ "57177cd7fb8d93461afc4527",
+ "58497cdf97b73e0b090c4273",
+ "58500b007072670e72c35588",
+ "58510bf97072670e72c46ddf",
+ "58522bd56789802282f2ecb3",
+ "58524a2e0e7012308944bcf3",
+ "58524a080e7012308944bcbf",
+ "58524c1d0e7012308944bfda",
+ "58524f170e7012308944c200",
+ "58529a4e0e70123089454c6f",
+ "58551bdf804be1058523556d",
+ "58568c9a804be10585240b03",
+ "58574b35804be105852455fd",
+ "58577c60b338a62ad5ff1564",
+ "58592d69b338a62ad5007a74",
+ "58598db2b338a62ad500bc38",
+ "58625f42712e2761468fb44c",
+ "58651bcc712e2761469166dc",
+ "58660e79712e27614691fe3d",
+ "58669aad712e27614692834c",
+ "58669c02712e27614692851a",
+ "58676c36833dfe3f7b88b7f2",
+ "58678b2d833dfe3f7b88e244",
+ "58790c82ce911104a3467c88",
+ "58800b0b2366dd5d06e5312d",
+ "58805eac2366dd5d06e56460",
+ "58806e422366dd5d06e57bb6",
+ "58831d060db9bf59bf8ab98b",
+ "58851ebb932ba84fbed7abad",
+ "58871dc3b791366d617a55ff",
+ "58873cabb791366d617a65a7",
+ "58873d44b791366d617a65dd",
+ "58888b3dc02346100f4af665",
+ "58897f62c02346100f4b8ee6",
+ "58933bac9a8c0314c5ce3508",
+ "58938e6d9a8c0314c5ce726f",
+ "58951cb49a8c0314c5cf4d5e",
+ "58970fd09a8c0314c5d0e383",
+ "58977ef09a8c0314c5d17b26",
+ "59056e6760bb961de55f3501",
+ "59071f2e5a6dbd3af4130f98",
+ "59102c811225725be9e64149",
+ "59338e76772c3e6384afbb15",
+ "59350ca084b7f26bf5ce6eb8",
+ "59397e493a87372f2c9e882b",
+ "59521e0b9096412211c2aa9d",
+ "59817e4a1bd4b175e7038d19",
+ "567884f58d2828b95e3c8eba",
+ "585559d9804be10585238ddf",
+ "585834cdb338a62ad5ffab4d",
+ "586082d8712e2761468e2877",
+ "586133c2712e2761468ecfe3",
+ "586281d2712e2761468fcaa2",
+ "586316e5712e276146903c4d",
+ "586326ad712e276146904571",
+ "586375c9712e276146907429",
+ "586389c9712e276146908da6",
+ "586496fa712e2761469108e7",
+ "586669c6712e27614692597a",
+ "586913a49d1b5e34c2808b02",
+ "586922da9d1b5e34c2809ff3",
+ "588185d8dfb7a15588a114a3",
+ "588305ed0db9bf59bf8a8c80",
+ "588315c60db9bf59bf8aa928",
+ "588332ee0db9bf59bf8ae9c3",
+ "588457b8932ba84fbed69942",
+ "588519d5932ba84fbed7a04a",
+ "588824d1b791366d617adeef",
+ "588857f6c02346100f4ac09f",
+ "589145ef90414422fbeb2e08",
+ "589433fa9a8c0314c5ce9656",
+ "589765d39a8c0314c5d16b12",
+ "5851165f7072670e72c4860d",
+ "5859341ab338a62ad500848d",
+ "5862388b712e2761468f84aa",
+ "5863915b712e276146909135",
+ "5866445b712e27614692383e",
+ "5866500d712e2761469240fd",
+ "5867785a833dfe3f7b88c764",
+ "5867969c833dfe3f7b88e8bc",
+ "5868040c833dfe3f7b8934f7",
+ "5880675a2366dd5d06e570ca",
+ "5882372c8ce2c2754d076af0",
+ "5883535e932ba84fbed5ad07",
+ "5888358cb791366d617af69d",
+ "5890330d90414422fbeaa0cb",
+ "5897076e9a8c0314c5d0d31b",
+ "5940564ec2d9527ab869f7e2",
+ "5947719bf1b45630bd096665",
+ "5948194ff1b45630bd0f47e3",
+ "5950206a41b158666ac50506",
+ "5983012d1bd4b175e70c985a",
+ "58586810b338a62ad5ffc20c",
+ "58592046b338a62ad5006b33",
+ "58592854b338a62ad500750a",
+ "58596531b338a62ad500aace",
+ "58818685dfb7a15588a11626",
+ "58829563f42b1d3ee3ec835f",
+ "58894345c02346100f4b51ca",
+ "585289980e7012308945276a",
+ "585369770e7012308945c709",
+ "585373640e7012308945cab9",
+ "588230658ce2c2754d076728",
+ "589388059a8c0314c5ce718b",
+ "595979485ec6a95e86a58c8d",
+ "5841206219d291325678ca90",
+ "58563650804be1058523da55",
+ "58564084804be1058523e116",
+ "58636467712e27614690661f",
+ "58647495712e27614690f36d",
+ "58654563712e276146918643",
+ "58664251712e276146923738",
+ "588084032366dd5d06e59e82",
+ "588159582366dd5d06e66877",
+ "5890279190414422fbea9734",
+ "5890523090414422fbeab3f0",
+ "5890641690414422fbeabbe7",
+ "585203546789802282f2aaf5",
+ ]
+
+ # Final sequences to be used after filtering (some of the sequences have incorrect/low quality depth)
+ # Generally water bodies like lakes have incorrect depth
+ # Filtered out sequences:
+ # "5692a4c2adafac1f14201821" # Incorrect Depth
+ # "5864a935712e2761469111b4" # Noisy Depth and artifacts near horizon
+ # "59f87d0bfa6280566fb38c9a" # Object-centric, noise with background and sometimes in front of object
+ # "58a44463156b87103d3ed45e" # Very noisy depth in background
+ # "5c2b3ed5e611832e8aed46bf" # Depth occluded by artifacts
+ # "5bf03590d4392319481971dc" # Depth occluded by artifacts
+ # "00000000000000000000001a" # Largely incomplete depth
+ # "00000000000000000000000c" # Imprecise depth for buildings
+ # "000000000000000000000000" # Incorrect depth for planar terrain
+ self.scenes = [
+ "00000000000000000000000a",
+ "00000000000000000000000b",
+ "00000000000000000000000d",
+ "00000000000000000000000e",
+ "00000000000000000000000f",
+ "000000000000000000000001",
+ "00000000000000000000001b",
+ "00000000000000000000001d",
+ "000000000000000000000002",
+ "000000000000000000000003",
+ "000000000000000000000004",
+ "000000000000000000000005",
+ "5a2a95f032a1c655cfe3de62",
+ "5a2af22b32a1c655cfe46013",
+ "5a2ba6de32a1c655cfe51b79",
+ "5a3b9731e24cd76dad1a5f1b",
+ "5a3ca9cb270f0e3f14d0eddb",
+ "5a3cb4e4270f0e3f14d12f43",
+ "5a03e732454a8a7ec672776c",
+ "5a3f4aba5889373fbbc5d3b5",
+ "5a4a38dad38c8a075495b5d2",
+ "5a5a1e48d62c7a12d5d00e47",
+ "5a6b1c418d100c2f8fdc4411",
+ "5a6feeb54a7fbc3f874f9db7",
+ "5a7cb1d6fe5c0d6fb53e64fb",
+ "5a7d3db14989e929563eb153",
+ "5a8aa0fab18050187cbe060e",
+ "5a9e5df65baeef72b4a021cd",
+ "5a48ba95c7dab83a7d7b44ed",
+ "5a48c4e9c7dab83a7d7b5cc7",
+ "5a48d4b2c7dab83a7d7b9851",
+ "5a69c47d0d5d0a7f3b2e9752",
+ "5a77b46b318efe6c6736e68a",
+ "5a355c271b63f53d5970f362",
+ "5a489fb1c7dab83a7d7b1070",
+ "5a533e8034d7582116e34209",
+ "5a562fc7425d0f5186314725",
+ "5a572fd9fc597b0478a81d14",
+ "5a588a8193ac3d233f77fbca",
+ "5a618c72784780334bc1972d",
+ "5a752d42acc41e2423f17674",
+ "5a969eea91dfc339a9a3ad2c",
+ "5a8315f624b8e938486e0bd8",
+ "5a57542f333d180827dfc132",
+ "5a0271884e62597cdee0d0eb",
+ "5a6400933d809f1d8200af15",
+ "5a6464143d809f1d8208c43c",
+ "5a563183425d0f5186314855",
+ "5aa0f9d7a9efce63548c69a1",
+ "5aa0f478a9efce63548c1cb4",
+ "5aa7db90bfdd572271e95246",
+ "5aa235f64a17b335eeaf9609",
+ "5aa515e613d42d091d29d300",
+ "5aa1196ea9efce63548ed649",
+ "5aaadd4cbc13235570d178a7",
+ "5ab6af12ac4291329b1072ab",
+ "5ab7e00aac4291329b15864d",
+ "5ab8b8e029f5351f7f2ccf59",
+ "5ab74bf2ac4291329b11e879",
+ "5ab85f1dac4291329b17cb50",
+ "5ab8713ba3799a1d138bd69a",
+ "5abc2506b53b042ead637d86",
+ "5acc7459a7853c4b5ebbef59",
+ "5acf8ca0f3d8a750097e4b15",
+ "5adc6bd52430a05ecb2ffb85",
+ "5ae2e9c5fe405c5076abc6b2",
+ "5af02e904c8216544b4ab5a2",
+ "5af28cea59bc705737003253",
+ "5af545d0559359053d25dcf5",
+ "5afacb69ab00705d0cefdd5b",
+ "5b2c67b5e0878c381608b8d8",
+ "5b3b2b9e8d46a939f933fdc0",
+ "5b3b353d8d46a939f93524b9",
+ "5b6e716d67b396324c2d77cb",
+ "5b6eff8b67b396324c5b2672",
+ "5b7a3890fc8fcf6781e2593a",
+ "5b21e18c58e2823a67a10dd8",
+ "5b60fa0c764f146feef84df0",
+ "5b69cc0cb44b61786eb959bf",
+ "5b78e57afc8fcf6781d0c3ba",
+ "5b192eb2170cf166458ff886",
+ "5b558a928bbfb62204e77ba2",
+ "5b864d850d072a699b32f4ae",
+ "5b908d3dc6ab78485f3d24a9",
+ "5b950c71608de421b1e7318f",
+ "5b4933abf2b5f44e95de482a",
+ "5b08286b2775267d5b0634ba",
+ "5b37189a35304b6f75e7583e",
+ "5b271079e0878c3816dacca4",
+ "5b22269758e2823a67a3bd03",
+ "5b62647143840965efc0dbde",
+ "5ba19a8a360c7c30c1c169df",
+ "5ba75d79d76ffa2c86cf2f05",
+ "5bb7a08aea1cfa39f1a947ab",
+ "5bb8a49aea1cfa39f1aa7f75",
+ "5bbb6eb2ea1cfa39f1af7e0c",
+ "5bc5f0e896b66a2cd8f9bd36",
+ "5bccd6beca24970bce448134",
+ "5bce7ac9ca24970bce4934b6",
+ "5bcf979a6d5f586b95c258cd",
+ "5bd43b4ba6b28b1ee86b92dd",
+ "5be3a5fb8cfdd56947f6b67c",
+ "5be3ae47f44e235bdbbc9771",
+ "5be4ab93870d330ff2dce134",
+ "5be47bf9b18881428d8fbc1d",
+ "5be883a4f98cee15019d5b83",
+ "5bea87f4abd34c35e1860ab5",
+ "5beb6e66abd34c35e18e66b9",
+ "5bf3a82cd439231948877aed",
+ "5bf7d63575c26f32dbf7413b",
+ "5bf17c0fd439231948355385",
+ "5bf26cbbd43923194854b270",
+ "5bf18642c50e6f7f8bdbd492",
+ "5bf21799d43923194842c001",
+ "5bfc9d5aec61ca1dd69132a2",
+ "5bfd0f32ec61ca1dd69dc77b",
+ "5bfe5ae0fe0ea555e6a969ca",
+ "5bff3c5cfe0ea555e6bcbf3a",
+ "5c0d13b795da9479e12e2ee9",
+ "5c1af2e2bee9a723c963d019",
+ "5c1b1500bee9a723c96c3e78",
+ "5c1dbf200843bc542d8ef8c4",
+ "5c1f33f1d33e1f2e4aa6dda4",
+ "5c20ca3a0843bc542d94e3e2",
+ "5c062d84a96e33018ff6f0a6",
+ "5c189f2326173c3a09ed7ef3",
+ "5c1892f726173c3a09ea9aeb",
+ "5c34300a73a8df509add216d",
+ "5c34529873a8df509ae57b58",
+ "000000000000000000000006",
+ "000000000000000000000007",
+ "000000000000000000000008",
+ "000000000000000000000009",
+ "000000000000000000000010",
+ "000000000000000000000011",
+ "000000000000000000000012",
+ "000000000000000000000015",
+ "000000000000000000000016",
+ "000000000000000000000017",
+ "000000000000000000000018",
+ "000000000000000000000019",
+ "56d73ba74bd29b8c35abade2",
+ "56f34064e296120e10484dc4",
+ "57a4a7bb6b9272286e26dc18",
+ "57f8d9bbe73f6760f10e916a",
+ "58a0a2f33d0b4542479a11b1",
+ "58a0dd1a3d0b4542479a28f3",
+ "58a1a7914a4d262a170b1101",
+ "58a1bc804a4d262a170b2f01",
+ "58a1d9d14a4d262a170b58fe",
+ "58a01dea38486e3c98475871",
+ "58a1f5d74a4d262a170b65fc",
+ "58a2a09e156b87103d3d668c",
+ "58a2d9c3156b87103d3da90f",
+ "58a3ccb0156b87103d3e4332",
+ "58a3f2f8156b87103d3e5838",
+ "58a3f6c0156b87103d3e5971",
+ "58a3fc95156b87103d3e5d9b",
+ "58a07ce53d0b45424799fdde",
+ "58a07f233d0b45424799ffe7",
+ "58a44df2156b87103d3ee239",
+ "58a164f73d0b4542479a7a8e",
+ "58a0365e38486e3c984783eb",
+ "58a439cf156b87103d3ec885",
+ "58a464aa156b87103d3eec04",
+ "58a4452f156b87103d3ed55b",
+ "58a160983d0b4542479a7347",
+ "58a186444a4d262a170ae3ae",
+ "58a285424a4d262a170baf3e",
+ "58a41819156b87103d3e92a5",
+ "58a47552156b87103d3f00a4",
+ "58c4bb4f4a69c55606122be4",
+ "58c6451e4a69c556061894f1",
+ "58ca7014affdfd07c70a95ce",
+ "58cf4771d0f5fb221defe6da",
+ "58d36897f387231e6c929903",
+ "58eaf1513353456af3a1682a",
+ "58f7f7299f5b5647873cb110",
+ "58f73e7c9f5b56478738929f",
+ "59a8f851597729752c31e7e0",
+ "59a452bf9b460239aa5d1c72",
+ "59a9619a825418241fb88191",
+ "59acd2f4b891807f439c8992",
+ "59bf97fe7e7b31545da34439",
+ "59c1c3e2fd6e3d4ead9f1013",
+ "59d2657f82ca7774b1ec081d",
+ "59da1fb88a126011d0394ae9",
+ "59e75a2ca9e91f2c5526005d",
+ "59e864b2a9e91f2c5529325f",
+ "59ecfd02e225f6492d20fcc9",
+ "59f37f74b45be2233001ba18",
+ "59f70ab1e5c5d366af29bf3e",
+ "59f363a8b45be22330016cad",
+ "564a27b26d07883f460d8ab0",
+ "565fb1dead14d4154dae2b94",
+ "567a0fb0a825d2fb79ac9a20",
+ "569b92eb826bcba945ca002b",
+ "576fefa017ce5a16397e87fd",
+ "584a7333fe3cb463906c9fe6",
+ "584aa8e9fe3cb463906cc7d0",
+ "584ad76bfe3cb463906ce6dc",
+ "584af003fe3cb463906d0e9b",
+ "584b9a747072670e72bfc49d",
+ "584b671f7072670e72bfaaf8",
+ "584b81747072670e72bfbbfd",
+ "584ba35f7072670e72bfca4d",
+ "584ba5977072670e72bfcc2d",
+ "584bc53c7072670e72bfe85f",
+ "584bc3997072670e72bfe58d",
+ "584bc4407072670e72bfe665",
+ "584bd5587072670e72bffe39",
+ "584bdadf7072670e72c0005c",
+ "584be5ed7072670e72c007b3",
+ "584c9ad27072670e72c060c5",
+ "584c9cc67072670e72c063a1",
+ "584c58b77072670e72c03990",
+ "584cea557072670e72c07fb4",
+ "584d19d47072670e72c0c6c0",
+ "584dfe467072670e72c1665a",
+ "584e875c7072670e72c1ec94",
+ "584e05667072670e72c17167",
+ "584f94e87072670e72c2d3f7",
+ "584fdffd7072670e72c32dc7",
+ "584fe07f7072670e72c32e59",
+ "585a2a71b338a62ad50138dc",
+ "585a206ab338a62ad501298f",
+ "585a217cb338a62ad5012b38",
+ "585b34afb338a62ad501e836",
+ "585bb25fc49c8507c3ce7812",
+ "585bbe55c49c8507c3ce81cd",
+ "585d6c8a2a57cc11d4920a1e",
+ "585e54c72a57cc11d492f71a",
+ "585e34302a57cc11d492be30",
+ "585ee0632a57cc11d4933608",
+ "585f9661712e2761468dabca",
+ "585ffe9a712e2761468df643",
+ "586a37ec9d1b5e34c28184fc",
+ "586a515a9d1b5e34c281b431",
+ "586a94939d1b5e34c2823b5d",
+ "586abc689d1b5e34c2826360",
+ "586b0e219d1b5e34c2828862",
+ "586b3db89d1b5e34c282cd52",
+ "586b4c459d1b5e34c282e66d",
+ "586b7d7d9d1b5e34c283359e",
+ "586b8f149d1b5e34c283497c",
+ "586b8f629d1b5e34c28349d6",
+ "586c4c4d9d1b5e34c28391a1",
+ "586c5b5b9d1b5e34c2839a5b",
+ "586c9fdf9d1b5e34c283b657",
+ "586c48329d1b5e34c2838e80",
+ "586caab99d1b5e34c283c213",
+ "586cd0779d1b5e34c28403a7",
+ "586d6d249d1b5e34c284b80e",
+ "586d8a029d1b5e34c284c948",
+ "586d55af9d1b5e34c284a999",
+ "586d07869d1b5e34c2842e5b",
+ "586d27489d1b5e34c28453af",
+ "586df9849d1b5e34c28506de",
+ "586e279c9d1b5e34c2852180",
+ "587bc5ec2366dd5d06e262c1",
+ "587c1abf2366dd5d06e28901",
+ "587c03f12366dd5d06e27722",
+ "587c19da2366dd5d06e2877b",
+ "587c31b92366dd5d06e2a9dc",
+ "587c87d02366dd5d06e2f989",
+ "587c97a52366dd5d06e30a96",
+ "587c45192366dd5d06e2c0eb",
+ "587cec702366dd5d06e37862",
+ "587cef0a2366dd5d06e379e3",
+ "587db5872366dd5d06e3e0af",
+ "587e2b1d2366dd5d06e41af0",
+ "587e2ea62366dd5d06e41f2e",
+ "587e5cb52366dd5d06e4486e",
+ "587eb1822366dd5d06e45f29",
+ "587f365d2366dd5d06e4906e",
+ "588a9c5fec4d5a1c088ec350",
+ "588a34cfec4d5a1c088ea8d1",
+ "588ab5bdec4d5a1c088ed60f",
+ "588aff9d90414422fbe7885a",
+ "588b20d290414422fbe79f40",
+ "588c08d590414422fbe8200b",
+ "588c203d90414422fbe8319e",
+ "588c989a90414422fbe86d96",
+ "588ca09d90414422fbe871a1",
+ "588cce2190414422fbe88520",
+ "588cd5ef90414422fbe8875c",
+ "588cf0ad90414422fbe8a20f",
+ "588e0d8c90414422fbe8f8b2",
+ "588e01c490414422fbe8ee2a",
+ "588e35e690414422fbe90a53",
+ "588f017e90414422fbe9b74b",
+ "588f095190414422fbe9c1ee",
+ "589aca717dc3d323d55671c4",
+ "589af2c97dc3d323d55691e8",
+ "589b49ea7dc3d323d556d9b4",
+ "589b04287dc3d323d556a185",
+ "589bf6a57dc3d323d55743ab",
+ "589c3c497dc3d323d5578468",
+ "589c3c577dc3d323d5578480",
+ "589c300f7dc3d323d5577926",
+ "589c24527dc3d323d5577126",
+ "589c35457dc3d323d5577d8d",
+ "589ca6a6b896147a1b73aff7",
+ "589d1e1fb896147a1b73ee5b",
+ "589d5c58b896147a1b742256",
+ "589d95538fa2cf375df3317b",
+ "589df0ffb504a864ad63521a",
+ "589ea316b504a864ad639a2b",
+ "589ec97cb504a864ad63adc3",
+ "589f214338486e3c9846f123",
+ "589fdfe738486e3c984736cf",
+ "590c2d70336bb52a190be886",
+ "590f91851225725be9e25d4e",
+ "591a467a6109e14d4f09b776",
+ "591cf3033162411cf9047f37",
+ "591ea44850991c70dc99a207",
+ "599aa591d5b41f366fed0d58",
+ "5643df56138263b51db1b5f3",
+ "5644bdac138263b51db9f669",
+ "5850d4f97072670e72c425d6",
+ "5854c405804be105852330fe",
+ "5855a4fc804be1058523bd75",
+ "5856ac15804be105852419d8",
+ "5856ae8b804be10585241bae",
+ "5856b460804be10585242059",
+ "5857aa5ab338a62ad5ff4dbe",
+ "5857acf8b338a62ad5ff5107",
+ "5858db6cb338a62ad500103b",
+ "5858dbcab338a62ad5001081",
+ "5859d84fb338a62ad500e5cf",
+ "5861d8ea712e2761468f3cb3",
+ "5863edf8712e27614690cce0",
+ "5864b076712e27614691197e",
+ "5864da88712e276146913d8b",
+ "5865f4a8712e27614691e39b",
+ "5867a434833dfe3f7b88edaf",
+ "5868cd15833dfe3f7b89bfa3",
+ "5880b3692366dd5d06e5d534",
+ "5880e3422366dd5d06e5ff8e",
+ "5880f0ef2366dd5d06e6166e",
+ "5881d2bfb6844814c136a119",
+ "5881f11d8ce2c2754d0714c3",
+ "5881fee18ce2c2754d0723f8",
+ "5882cda2b116682b4adebd25",
+ "5882d58fb116682b4adec7db",
+ "5884c256932ba84fbed70bf5",
+ "5884cc13932ba84fbed71ec4",
+ "5885bc5296fa095e0671a7f0",
+ "5886d14cb791366d617a362c",
+ "5888becfc02346100f4b0b21",
+ "5888e408c02346100f4b1a29",
+ "5889da66ec4d5a1c088e5187",
+ "5889e344ec4d5a1c088e59be",
+ "5889e754ec4d5a1c088e60ba",
+ "5890c16b90414422fbeb0262",
+ "5891d8ae9a8c0314c5cd30ab",
+ "5891d0479a8c0314c5cd2abd",
+ "5891ecf19a8c0314c5cd490a",
+ "5892c0cd9a8c0314c5cdc977",
+ "5894ab309a8c0314c5cee57d",
+ "5895a6a89a8c0314c5cfca7c",
+ "5895b8c29a8c0314c5cfd051",
+ "5895d38f9a8c0314c5cfe50c",
+ "5895f2329a8c0314c5d00117",
+ "5896bb989a8c0314c5d086b6",
+ "5896ebf39a8c0314c5d0a8c4",
+ "5898b1bac9dccc22987b7f74",
+ "5898b6ffc9dccc22987b8a03",
+ "5898b31cc9dccc22987b82ec",
+ "5898bbaac9dccc22987b8eba",
+ "5899cfa6b76d7a3780a4cb64",
+ "5899e5dcb76d7a3780a4ecc1",
+ "5947b62af1b45630bd0c2a02",
+ "57102be2877e1421026358af",
+ "57153d4031bb9900425bde85",
+ "57177cd7fb8d93461afc4527",
+ "58497cdf97b73e0b090c4273",
+ "58500b007072670e72c35588",
+ "58510bf97072670e72c46ddf",
+ "58522bd56789802282f2ecb3",
+ "58524a2e0e7012308944bcf3",
+ "58524a080e7012308944bcbf",
+ "58524c1d0e7012308944bfda",
+ "58524f170e7012308944c200",
+ "58529a4e0e70123089454c6f",
+ "58551bdf804be1058523556d",
+ "58568c9a804be10585240b03",
+ "58574b35804be105852455fd",
+ "58577c60b338a62ad5ff1564",
+ "58592d69b338a62ad5007a74",
+ "58598db2b338a62ad500bc38",
+ "58625f42712e2761468fb44c",
+ "58651bcc712e2761469166dc",
+ "58660e79712e27614691fe3d",
+ "58669aad712e27614692834c",
+ "58669c02712e27614692851a",
+ "58676c36833dfe3f7b88b7f2",
+ "58678b2d833dfe3f7b88e244",
+ "58790c82ce911104a3467c88",
+ "58800b0b2366dd5d06e5312d",
+ "58805eac2366dd5d06e56460",
+ "58806e422366dd5d06e57bb6",
+ "58831d060db9bf59bf8ab98b",
+ "58851ebb932ba84fbed7abad",
+ "58871dc3b791366d617a55ff",
+ "58873cabb791366d617a65a7",
+ "58873d44b791366d617a65dd",
+ "58888b3dc02346100f4af665",
+ "58897f62c02346100f4b8ee6",
+ "58933bac9a8c0314c5ce3508",
+ "58938e6d9a8c0314c5ce726f",
+ "58951cb49a8c0314c5cf4d5e",
+ "58970fd09a8c0314c5d0e383",
+ "58977ef09a8c0314c5d17b26",
+ "59056e6760bb961de55f3501",
+ "59071f2e5a6dbd3af4130f98",
+ "59102c811225725be9e64149",
+ "59338e76772c3e6384afbb15",
+ "59350ca084b7f26bf5ce6eb8",
+ "59397e493a87372f2c9e882b",
+ "59521e0b9096412211c2aa9d",
+ "59817e4a1bd4b175e7038d19",
+ "567884f58d2828b95e3c8eba",
+ "585559d9804be10585238ddf",
+ "585834cdb338a62ad5ffab4d",
+ "586082d8712e2761468e2877",
+ "586133c2712e2761468ecfe3",
+ "586281d2712e2761468fcaa2",
+ "586316e5712e276146903c4d",
+ "586326ad712e276146904571",
+ "586375c9712e276146907429",
+ "586389c9712e276146908da6",
+ "586496fa712e2761469108e7",
+ "586669c6712e27614692597a",
+ "586913a49d1b5e34c2808b02",
+ "586922da9d1b5e34c2809ff3",
+ "588185d8dfb7a15588a114a3",
+ "588305ed0db9bf59bf8a8c80",
+ "588315c60db9bf59bf8aa928",
+ "588332ee0db9bf59bf8ae9c3",
+ "588457b8932ba84fbed69942",
+ "588519d5932ba84fbed7a04a",
+ "588824d1b791366d617adeef",
+ "588857f6c02346100f4ac09f",
+ "589145ef90414422fbeb2e08",
+ "589433fa9a8c0314c5ce9656",
+ "589765d39a8c0314c5d16b12",
+ "5851165f7072670e72c4860d",
+ "5859341ab338a62ad500848d",
+ "5862388b712e2761468f84aa",
+ "5863915b712e276146909135",
+ "5866445b712e27614692383e",
+ "5866500d712e2761469240fd",
+ "5867785a833dfe3f7b88c764",
+ "5867969c833dfe3f7b88e8bc",
+ "5868040c833dfe3f7b8934f7",
+ "5880675a2366dd5d06e570ca",
+ "5882372c8ce2c2754d076af0",
+ "5883535e932ba84fbed5ad07",
+ "5888358cb791366d617af69d",
+ "5890330d90414422fbeaa0cb",
+ "5897076e9a8c0314c5d0d31b",
+ "5940564ec2d9527ab869f7e2",
+ "5947719bf1b45630bd096665",
+ "5948194ff1b45630bd0f47e3",
+ "5950206a41b158666ac50506",
+ "5983012d1bd4b175e70c985a",
+ "58586810b338a62ad5ffc20c",
+ "58592046b338a62ad5006b33",
+ "58592854b338a62ad500750a",
+ "58596531b338a62ad500aace",
+ "58818685dfb7a15588a11626",
+ "58829563f42b1d3ee3ec835f",
+ "58894345c02346100f4b51ca",
+ "585289980e7012308945276a",
+ "585369770e7012308945c709",
+ "585373640e7012308945cab9",
+ "588230658ce2c2754d076728",
+ "589388059a8c0314c5ce718b",
+ "595979485ec6a95e86a58c8d",
+ "5841206219d291325678ca90",
+ "58563650804be1058523da55",
+ "58564084804be1058523e116",
+ "58636467712e27614690661f",
+ "58647495712e27614690f36d",
+ "58654563712e276146918643",
+ "58664251712e276146923738",
+ "588084032366dd5d06e59e82",
+ "588159582366dd5d06e66877",
+ "5890279190414422fbea9734",
+ "5890523090414422fbeab3f0",
+ "5890641690414422fbeabbe7",
+ "585203546789802282f2aaf5",
+ ]
+
+ # Train set sequences after filtering
+ self.train_split_scenes = [
+ "00000000000000000000000b",
+ "00000000000000000000000d",
+ "00000000000000000000000e",
+ "00000000000000000000000f",
+ "000000000000000000000001",
+ "00000000000000000000001b",
+ "00000000000000000000001d",
+ "000000000000000000000002",
+ "000000000000000000000003",
+ "000000000000000000000004",
+ "000000000000000000000005",
+ "5a2a95f032a1c655cfe3de62",
+ "5a2af22b32a1c655cfe46013",
+ "5a2ba6de32a1c655cfe51b79",
+ "5a3b9731e24cd76dad1a5f1b",
+ "5a3ca9cb270f0e3f14d0eddb",
+ "5a3cb4e4270f0e3f14d12f43",
+ "5a03e732454a8a7ec672776c",
+ "5a3f4aba5889373fbbc5d3b5",
+ "5a5a1e48d62c7a12d5d00e47",
+ "5a6b1c418d100c2f8fdc4411",
+ "5a6feeb54a7fbc3f874f9db7",
+ "5a7cb1d6fe5c0d6fb53e64fb",
+ "5a7d3db14989e929563eb153",
+ "5a8aa0fab18050187cbe060e",
+ "5a9e5df65baeef72b4a021cd",
+ "5a48ba95c7dab83a7d7b44ed",
+ "5a48c4e9c7dab83a7d7b5cc7",
+ "5a48d4b2c7dab83a7d7b9851",
+ "5a69c47d0d5d0a7f3b2e9752",
+ "5a77b46b318efe6c6736e68a",
+ "5a355c271b63f53d5970f362",
+ "5a533e8034d7582116e34209",
+ "5a562fc7425d0f5186314725",
+ "5a618c72784780334bc1972d",
+ "5a752d42acc41e2423f17674",
+ "5a969eea91dfc339a9a3ad2c",
+ "5a8315f624b8e938486e0bd8",
+ "5a57542f333d180827dfc132",
+ "5a0271884e62597cdee0d0eb",
+ "5a6400933d809f1d8200af15",
+ "5a6464143d809f1d8208c43c",
+ "5a563183425d0f5186314855",
+ "5aa0f9d7a9efce63548c69a1",
+ "5aa7db90bfdd572271e95246",
+ "5aa235f64a17b335eeaf9609",
+ "5aa515e613d42d091d29d300",
+ "5aa1196ea9efce63548ed649",
+ "5aaadd4cbc13235570d178a7",
+ "5ab6af12ac4291329b1072ab",
+ "5ab7e00aac4291329b15864d",
+ "5ab8b8e029f5351f7f2ccf59",
+ "5ab74bf2ac4291329b11e879",
+ "5ab85f1dac4291329b17cb50",
+ "5ab8713ba3799a1d138bd69a",
+ "5abc2506b53b042ead637d86",
+ "5acc7459a7853c4b5ebbef59",
+ "5acf8ca0f3d8a750097e4b15",
+ "5adc6bd52430a05ecb2ffb85",
+ "5af02e904c8216544b4ab5a2",
+ "5af28cea59bc705737003253",
+ "5af545d0559359053d25dcf5",
+ "5afacb69ab00705d0cefdd5b",
+ "5b3b2b9e8d46a939f933fdc0",
+ "5b3b353d8d46a939f93524b9",
+ "5b6e716d67b396324c2d77cb",
+ "5b6eff8b67b396324c5b2672",
+ "5b7a3890fc8fcf6781e2593a",
+ "5b60fa0c764f146feef84df0",
+ "5b69cc0cb44b61786eb959bf",
+ "5b78e57afc8fcf6781d0c3ba",
+ "5b192eb2170cf166458ff886",
+ "5b558a928bbfb62204e77ba2",
+ "5b908d3dc6ab78485f3d24a9",
+ "5b950c71608de421b1e7318f",
+ "5b08286b2775267d5b0634ba",
+ "5b271079e0878c3816dacca4",
+ "5b22269758e2823a67a3bd03",
+ "5b62647143840965efc0dbde",
+ "5ba19a8a360c7c30c1c169df",
+ "5ba75d79d76ffa2c86cf2f05",
+ "5bb7a08aea1cfa39f1a947ab",
+ "5bb8a49aea1cfa39f1aa7f75",
+ "5bbb6eb2ea1cfa39f1af7e0c",
+ "5bce7ac9ca24970bce4934b6",
+ "5bcf979a6d5f586b95c258cd",
+ "5bd43b4ba6b28b1ee86b92dd",
+ "5be3a5fb8cfdd56947f6b67c",
+ "5be3ae47f44e235bdbbc9771",
+ "5be4ab93870d330ff2dce134",
+ "5be47bf9b18881428d8fbc1d",
+ "5be883a4f98cee15019d5b83",
+ "5bea87f4abd34c35e1860ab5",
+ "5beb6e66abd34c35e18e66b9",
+ "5bf3a82cd439231948877aed",
+ "5bf7d63575c26f32dbf7413b",
+ "5bf17c0fd439231948355385",
+ "5bf21799d43923194842c001",
+ "5bfd0f32ec61ca1dd69dc77b",
+ "5bfe5ae0fe0ea555e6a969ca",
+ "5c0d13b795da9479e12e2ee9",
+ "5c1af2e2bee9a723c963d019",
+ "5c1b1500bee9a723c96c3e78",
+ "5c1dbf200843bc542d8ef8c4",
+ "5c20ca3a0843bc542d94e3e2",
+ "5c062d84a96e33018ff6f0a6",
+ "5c189f2326173c3a09ed7ef3",
+ "5c1892f726173c3a09ea9aeb",
+ "5c34300a73a8df509add216d",
+ "000000000000000000000006",
+ "000000000000000000000007",
+ "000000000000000000000008",
+ "000000000000000000000009",
+ "000000000000000000000010",
+ "000000000000000000000011",
+ "000000000000000000000012",
+ "000000000000000000000015",
+ "000000000000000000000016",
+ "000000000000000000000017",
+ "000000000000000000000018",
+ "000000000000000000000019",
+ "56d73ba74bd29b8c35abade2",
+ "56f34064e296120e10484dc4",
+ "57a4a7bb6b9272286e26dc18",
+ "57f8d9bbe73f6760f10e916a",
+ "58a0a2f33d0b4542479a11b1",
+ "58a0dd1a3d0b4542479a28f3",
+ "58a1a7914a4d262a170b1101",
+ "58a1bc804a4d262a170b2f01",
+ "58a1d9d14a4d262a170b58fe",
+ "58a01dea38486e3c98475871",
+ "58a1f5d74a4d262a170b65fc",
+ "58a2a09e156b87103d3d668c",
+ "58a2d9c3156b87103d3da90f",
+ "58a3ccb0156b87103d3e4332",
+ "58a3f2f8156b87103d3e5838",
+ "58a3f6c0156b87103d3e5971",
+ "58a3fc95156b87103d3e5d9b",
+ "58a07ce53d0b45424799fdde",
+ "58a07f233d0b45424799ffe7",
+ "58a44df2156b87103d3ee239",
+ "58a164f73d0b4542479a7a8e",
+ "58a0365e38486e3c984783eb",
+ "58a439cf156b87103d3ec885",
+ "58a464aa156b87103d3eec04",
+ "58a4452f156b87103d3ed55b",
+ "58a160983d0b4542479a7347",
+ "58a285424a4d262a170baf3e",
+ "58a41819156b87103d3e92a5",
+ "58a47552156b87103d3f00a4",
+ "58c4bb4f4a69c55606122be4",
+ "58c6451e4a69c556061894f1",
+ "58ca7014affdfd07c70a95ce",
+ "58cf4771d0f5fb221defe6da",
+ "58d36897f387231e6c929903",
+ "58eaf1513353456af3a1682a",
+ "58f73e7c9f5b56478738929f",
+ "59a8f851597729752c31e7e0",
+ "59a452bf9b460239aa5d1c72",
+ "59a9619a825418241fb88191",
+ "59bf97fe7e7b31545da34439",
+ "59c1c3e2fd6e3d4ead9f1013",
+ "59d2657f82ca7774b1ec081d",
+ "59da1fb88a126011d0394ae9",
+ "59e75a2ca9e91f2c5526005d",
+ "59e864b2a9e91f2c5529325f",
+ "59ecfd02e225f6492d20fcc9",
+ "59f37f74b45be2233001ba18",
+ "59f70ab1e5c5d366af29bf3e",
+ "59f363a8b45be22330016cad",
+ "564a27b26d07883f460d8ab0",
+ "565fb1dead14d4154dae2b94",
+ "569b92eb826bcba945ca002b",
+ "576fefa017ce5a16397e87fd",
+ "584a7333fe3cb463906c9fe6",
+ "584aa8e9fe3cb463906cc7d0",
+ "584af003fe3cb463906d0e9b",
+ "584b9a747072670e72bfc49d",
+ "584b671f7072670e72bfaaf8",
+ "584b81747072670e72bfbbfd",
+ "584ba35f7072670e72bfca4d",
+ "584ba5977072670e72bfcc2d",
+ "584bc53c7072670e72bfe85f",
+ "584bc3997072670e72bfe58d",
+ "584bc4407072670e72bfe665",
+ "584bd5587072670e72bffe39",
+ "584bdadf7072670e72c0005c",
+ "584be5ed7072670e72c007b3",
+ "584c9ad27072670e72c060c5",
+ "584c9cc67072670e72c063a1",
+ "584cea557072670e72c07fb4",
+ "584d19d47072670e72c0c6c0",
+ "584dfe467072670e72c1665a",
+ "584e875c7072670e72c1ec94",
+ "584e05667072670e72c17167",
+ "584f94e87072670e72c2d3f7",
+ "584fdffd7072670e72c32dc7",
+ "584fe07f7072670e72c32e59",
+ "585a2a71b338a62ad50138dc",
+ "585a206ab338a62ad501298f",
+ "585a217cb338a62ad5012b38",
+ "585b34afb338a62ad501e836",
+ "585bb25fc49c8507c3ce7812",
+ "585bbe55c49c8507c3ce81cd",
+ "585d6c8a2a57cc11d4920a1e",
+ "585e54c72a57cc11d492f71a",
+ "585e34302a57cc11d492be30",
+ "585ee0632a57cc11d4933608",
+ "585f9661712e2761468dabca",
+ "585ffe9a712e2761468df643",
+ "586a37ec9d1b5e34c28184fc",
+ "586a515a9d1b5e34c281b431",
+ "586a94939d1b5e34c2823b5d",
+ "586abc689d1b5e34c2826360",
+ "586b0e219d1b5e34c2828862",
+ "586b3db89d1b5e34c282cd52",
+ "586b4c459d1b5e34c282e66d",
+ "586b7d7d9d1b5e34c283359e",
+ "586b8f149d1b5e34c283497c",
+ "586b8f629d1b5e34c28349d6",
+ "586c4c4d9d1b5e34c28391a1",
+ "586c5b5b9d1b5e34c2839a5b",
+ "586c9fdf9d1b5e34c283b657",
+ "586caab99d1b5e34c283c213",
+ "586cd0779d1b5e34c28403a7",
+ "586d6d249d1b5e34c284b80e",
+ "586d8a029d1b5e34c284c948",
+ "586d55af9d1b5e34c284a999",
+ "586d07869d1b5e34c2842e5b",
+ "586d27489d1b5e34c28453af",
+ "586e279c9d1b5e34c2852180",
+ "587bc5ec2366dd5d06e262c1",
+ "587c1abf2366dd5d06e28901",
+ "587c03f12366dd5d06e27722",
+ "587c19da2366dd5d06e2877b",
+ "587c31b92366dd5d06e2a9dc",
+ "587c87d02366dd5d06e2f989",
+ "587c97a52366dd5d06e30a96",
+ "587c45192366dd5d06e2c0eb",
+ "587cec702366dd5d06e37862",
+ "587cef0a2366dd5d06e379e3",
+ "587db5872366dd5d06e3e0af",
+ "587e2b1d2366dd5d06e41af0",
+ "587e2ea62366dd5d06e41f2e",
+ "587e5cb52366dd5d06e4486e",
+ "587eb1822366dd5d06e45f29",
+ "587f365d2366dd5d06e4906e",
+ "588a9c5fec4d5a1c088ec350",
+ "588a34cfec4d5a1c088ea8d1",
+ "588ab5bdec4d5a1c088ed60f",
+ "588aff9d90414422fbe7885a",
+ "588b20d290414422fbe79f40",
+ "588c08d590414422fbe8200b",
+ "588c203d90414422fbe8319e",
+ "588c989a90414422fbe86d96",
+ "588ca09d90414422fbe871a1",
+ "588cce2190414422fbe88520",
+ "588cd5ef90414422fbe8875c",
+ "588cf0ad90414422fbe8a20f",
+ "588e01c490414422fbe8ee2a",
+ "588e35e690414422fbe90a53",
+ "588f017e90414422fbe9b74b",
+ "588f095190414422fbe9c1ee",
+ "589aca717dc3d323d55671c4",
+ "589af2c97dc3d323d55691e8",
+ "589b49ea7dc3d323d556d9b4",
+ "589b04287dc3d323d556a185",
+ "589bf6a57dc3d323d55743ab",
+ "589c3c497dc3d323d5578468",
+ "589c3c577dc3d323d5578480",
+ "589c24527dc3d323d5577126",
+ "589c35457dc3d323d5577d8d",
+ "589ca6a6b896147a1b73aff7",
+ "589d1e1fb896147a1b73ee5b",
+ "589d5c58b896147a1b742256",
+ "589d95538fa2cf375df3317b",
+ "589df0ffb504a864ad63521a",
+ "589ea316b504a864ad639a2b",
+ "589ec97cb504a864ad63adc3",
+ "589f214338486e3c9846f123",
+ "589fdfe738486e3c984736cf",
+ "590c2d70336bb52a190be886",
+ "591a467a6109e14d4f09b776",
+ "591cf3033162411cf9047f37",
+ "591ea44850991c70dc99a207",
+ "599aa591d5b41f366fed0d58",
+ "5643df56138263b51db1b5f3",
+ "5644bdac138263b51db9f669",
+ "5850d4f97072670e72c425d6",
+ "5854c405804be105852330fe",
+ "5855a4fc804be1058523bd75",
+ "5856ac15804be105852419d8",
+ "5856ae8b804be10585241bae",
+ "5856b460804be10585242059",
+ "5857aa5ab338a62ad5ff4dbe",
+ "5857acf8b338a62ad5ff5107",
+ "5858db6cb338a62ad500103b",
+ "5858dbcab338a62ad5001081",
+ "5859d84fb338a62ad500e5cf",
+ "5861d8ea712e2761468f3cb3",
+ "5863edf8712e27614690cce0",
+ "5864b076712e27614691197e",
+ "5864da88712e276146913d8b",
+ "5865f4a8712e27614691e39b",
+ "5867a434833dfe3f7b88edaf",
+ "5868cd15833dfe3f7b89bfa3",
+ "5880b3692366dd5d06e5d534",
+ "5880e3422366dd5d06e5ff8e",
+ "5880f0ef2366dd5d06e6166e",
+ "5881d2bfb6844814c136a119",
+ "5881f11d8ce2c2754d0714c3",
+ "5881fee18ce2c2754d0723f8",
+ "5882cda2b116682b4adebd25",
+ "5882d58fb116682b4adec7db",
+ "5884c256932ba84fbed70bf5",
+ "5884cc13932ba84fbed71ec4",
+ "5885bc5296fa095e0671a7f0",
+ "5886d14cb791366d617a362c",
+ "5888becfc02346100f4b0b21",
+ "5888e408c02346100f4b1a29",
+ "5889da66ec4d5a1c088e5187",
+ "5889e754ec4d5a1c088e60ba",
+ "5890c16b90414422fbeb0262",
+ "5891d8ae9a8c0314c5cd30ab",
+ "5891d0479a8c0314c5cd2abd",
+ "5891ecf19a8c0314c5cd490a",
+ "5892c0cd9a8c0314c5cdc977",
+ "5894ab309a8c0314c5cee57d",
+ "5895a6a89a8c0314c5cfca7c",
+ "5895b8c29a8c0314c5cfd051",
+ "5895d38f9a8c0314c5cfe50c",
+ "5895f2329a8c0314c5d00117",
+ "5896bb989a8c0314c5d086b6",
+ "5896ebf39a8c0314c5d0a8c4",
+ "5898b1bac9dccc22987b7f74",
+ "5898b6ffc9dccc22987b8a03",
+ "5898bbaac9dccc22987b8eba",
+ "5899cfa6b76d7a3780a4cb64",
+ "5899e5dcb76d7a3780a4ecc1",
+ "57102be2877e1421026358af",
+ "57153d4031bb9900425bde85",
+ "57177cd7fb8d93461afc4527",
+ "58497cdf97b73e0b090c4273",
+ "58500b007072670e72c35588",
+ "58510bf97072670e72c46ddf",
+ "58522bd56789802282f2ecb3",
+ "58524a2e0e7012308944bcf3",
+ "58524a080e7012308944bcbf",
+ "58524c1d0e7012308944bfda",
+ "58524f170e7012308944c200",
+ "58529a4e0e70123089454c6f",
+ "58551bdf804be1058523556d",
+ "58568c9a804be10585240b03",
+ "58574b35804be105852455fd",
+ "58577c60b338a62ad5ff1564",
+ "58592d69b338a62ad5007a74",
+ "58625f42712e2761468fb44c",
+ "58651bcc712e2761469166dc",
+ "58660e79712e27614691fe3d",
+ "58669aad712e27614692834c",
+ "58676c36833dfe3f7b88b7f2",
+ "58678b2d833dfe3f7b88e244",
+ "58800b0b2366dd5d06e5312d",
+ "58805eac2366dd5d06e56460",
+ "58806e422366dd5d06e57bb6",
+ "58831d060db9bf59bf8ab98b",
+ "58851ebb932ba84fbed7abad",
+ "58871dc3b791366d617a55ff",
+ "58873cabb791366d617a65a7",
+ "58873d44b791366d617a65dd",
+ "58888b3dc02346100f4af665",
+ "58933bac9a8c0314c5ce3508",
+ "58938e6d9a8c0314c5ce726f",
+ "58951cb49a8c0314c5cf4d5e",
+ "58970fd09a8c0314c5d0e383",
+ "58977ef09a8c0314c5d17b26",
+ "59056e6760bb961de55f3501",
+ "59071f2e5a6dbd3af4130f98",
+ "59102c811225725be9e64149",
+ "59338e76772c3e6384afbb15",
+ "59350ca084b7f26bf5ce6eb8",
+ "59397e493a87372f2c9e882b",
+ "59521e0b9096412211c2aa9d",
+ "59817e4a1bd4b175e7038d19",
+ "567884f58d2828b95e3c8eba",
+ "585559d9804be10585238ddf",
+ "585834cdb338a62ad5ffab4d",
+ "586082d8712e2761468e2877",
+ "586133c2712e2761468ecfe3",
+ "586281d2712e2761468fcaa2",
+ "586316e5712e276146903c4d",
+ "586326ad712e276146904571",
+ "586375c9712e276146907429",
+ "586389c9712e276146908da6",
+ "586496fa712e2761469108e7",
+ "586669c6712e27614692597a",
+ "586913a49d1b5e34c2808b02",
+ "586922da9d1b5e34c2809ff3",
+ "588185d8dfb7a15588a114a3",
+ "588315c60db9bf59bf8aa928",
+ "588332ee0db9bf59bf8ae9c3",
+ "588519d5932ba84fbed7a04a",
+ "588824d1b791366d617adeef",
+ "588857f6c02346100f4ac09f",
+ "589145ef90414422fbeb2e08",
+ "589433fa9a8c0314c5ce9656",
+ "589765d39a8c0314c5d16b12",
+ "5851165f7072670e72c4860d",
+ "5859341ab338a62ad500848d",
+ "5863915b712e276146909135",
+ "5866445b712e27614692383e",
+ "5866500d712e2761469240fd",
+ "5867785a833dfe3f7b88c764",
+ "5867969c833dfe3f7b88e8bc",
+ "5868040c833dfe3f7b8934f7",
+ "5882372c8ce2c2754d076af0",
+ "5883535e932ba84fbed5ad07",
+ "5888358cb791366d617af69d",
+ "5890330d90414422fbeaa0cb",
+ "5897076e9a8c0314c5d0d31b",
+ "5940564ec2d9527ab869f7e2",
+ "5947719bf1b45630bd096665",
+ "5948194ff1b45630bd0f47e3",
+ "5950206a41b158666ac50506",
+ "5983012d1bd4b175e70c985a",
+ "58586810b338a62ad5ffc20c",
+ "58592046b338a62ad5006b33",
+ "58592854b338a62ad500750a",
+ "58596531b338a62ad500aace",
+ "58818685dfb7a15588a11626",
+ "58829563f42b1d3ee3ec835f",
+ "58894345c02346100f4b51ca",
+ "585289980e7012308945276a",
+ "585369770e7012308945c709",
+ "585373640e7012308945cab9",
+ "588230658ce2c2754d076728",
+ "589388059a8c0314c5ce718b",
+ "595979485ec6a95e86a58c8d",
+ "5841206219d291325678ca90",
+ "58563650804be1058523da55",
+ "58564084804be1058523e116",
+ "58636467712e27614690661f",
+ "58647495712e27614690f36d",
+ "58654563712e276146918643",
+ "58664251712e276146923738",
+ "588084032366dd5d06e59e82",
+ "588159582366dd5d06e66877",
+ "5890279190414422fbea9734",
+ "5890641690414422fbeabbe7",
+ "585203546789802282f2aaf5",
+ ]
+
+ # Validation set sequences after filtering
+ self.val_split_scenes = [
+ "00000000000000000000000a",
+ "5a4a38dad38c8a075495b5d2",
+ "5a489fb1c7dab83a7d7b1070",
+ "5a572fd9fc597b0478a81d14",
+ "5a588a8193ac3d233f77fbca",
+ "5aa0f478a9efce63548c1cb4",
+ "5ae2e9c5fe405c5076abc6b2",
+ "5b2c67b5e0878c381608b8d8",
+ "5b21e18c58e2823a67a10dd8",
+ "5b864d850d072a699b32f4ae",
+ "5b4933abf2b5f44e95de482a",
+ "5b37189a35304b6f75e7583e",
+ "5bc5f0e896b66a2cd8f9bd36",
+ "5bccd6beca24970bce448134",
+ "5bf26cbbd43923194854b270",
+ "5bf18642c50e6f7f8bdbd492",
+ "5bfc9d5aec61ca1dd69132a2",
+ "5bff3c5cfe0ea555e6bcbf3a",
+ "5c1f33f1d33e1f2e4aa6dda4",
+ "5c34529873a8df509ae57b58",
+ "58a186444a4d262a170ae3ae",
+ "58f7f7299f5b5647873cb110",
+ "59acd2f4b891807f439c8992",
+ "567a0fb0a825d2fb79ac9a20",
+ "584ad76bfe3cb463906ce6dc",
+ "584c58b77072670e72c03990",
+ "586c48329d1b5e34c2838e80",
+ "586df9849d1b5e34c28506de",
+ "588e0d8c90414422fbe8f8b2",
+ "589c300f7dc3d323d5577926",
+ "590f91851225725be9e25d4e",
+ "5889e344ec4d5a1c088e59be",
+ "5898b31cc9dccc22987b82ec",
+ "5947b62af1b45630bd0c2a02",
+ "58598db2b338a62ad500bc38",
+ "58669c02712e27614692851a",
+ "58790c82ce911104a3467c88",
+ "58897f62c02346100f4b8ee6",
+ "588305ed0db9bf59bf8a8c80",
+ "588457b8932ba84fbed69942",
+ "5862388b712e2761468f84aa",
+ "5880675a2366dd5d06e570ca",
+ "5890523090414422fbeab3f0",
+ ]
+
+
+class TartanAirV2Splits:
+ """
+ This class contains the information about the splits of the TartanAir V2 dataset.
+ """
+
+ def __init__(self):
+ """
+ Splits of environments with unique geometry selected based on TartanVO & UFM splits.
+ """
+ # Apart from the below 2 splits, all other TAv2 scenes are in the train split
+ # Val split
+ self.val_split_scenes = ["EndofTheWorld", "HongKong", "WesternDesertTown"]
+
+ # Test split
+ self.test_split_scenes = [
+ "DesertGasStation",
+ "OldScandinavia",
+ "PolarSciFi",
+ "Sewerage",
+ "Supermarket",
+ ]
+
+
+class MegaDepthSplits:
+ """
+ This class contains the information about the splits of the MegaDepth dataset.
+ """
+
+ def __init__(self):
+ """
+ Validation split is based on scenes used in DUSt3R.
+ """
+ self.val_split_scenes = ["0015_0", "0015_1", "0022_0"]
+
+
+class SpringSplits:
+ """
+ This class contains the information about the splits of the Spring dataset.
+ """
+
+ def __init__(self):
+ self.val_split_scenes = ["0013", "0023", "0037"]
+
+
+class MPSDSplits:
+ """
+ This class contains the information about the splits of the MPSD dataset.
+ """
+
+ def __init__(self):
+ """
+ Train & Validation split numpy files containing folder names are generated during preprocessing of MPSD dataset.
+ Load the numpy files to get the list of scenes in the train & validation split.
+ A 95% (Train) & 5% (Validation) split is used.
+ """
+ self.train_split_scenes = "load_numpy_file_with_train_scenes"
+ self.val_split_scenes = "load_numpy_file_with_val_scenes"
+
+
+class ScanNetPPSplits:
+ """
+ This class contains the information about the splits of the ScanNetPP dataset.
+ """
+
+ def __init__(self):
+ """
+ Validation & Test split only contains scenes from ScanNet++V2 to prevent data leak with other methods such as DUSt3R during benchmarking.
+
+ Following logic was used to generate the splits:
+ # Select 80%, 10%, 10% of the scenes for train, val, test respectively from ScanNet++ V2 (~300 scene subset; excluding V1 scenes)
+ snpp_v2_test_scenes = np.random.choice(
+ snpp_v2_processed_scenes, size=int(0.1 * len(snpp_v2_processed_scenes)), replace=False
+ )
+ remaining_scenes = [scene for scene in snpp_v2_processed_scenes if scene not in snpp_v2_test_scenes]
+ snpp_v2_val_scenes = np.random.choice(
+ remaining_scenes, size=int(0.1 * len(snpp_v2_processed_scenes)), replace=False
+ )
+ snpp_v2_train_scenes = [
+ scene for scene in remaining_scenes if scene not in snpp_v2_val_scenes and scene not in snpp_v2_test_scenes
+ ]
+ """
+ # Validation Scenes
+ self.val_split_scenes = [
+ "1c7a683c92",
+ "2a1b555966",
+ "3a43c7b8d2",
+ "4aef651da7",
+ "06bc6d1b24",
+ "7f22d5ef1b",
+ "7f77abce34",
+ "8ea517a2fc",
+ "29c7afafed",
+ "41eb967018",
+ "77b40ce601",
+ "086f09d6e3",
+ "307e3262f1",
+ "639f2c4d5a",
+ "894dbd41f1",
+ "898a7dfd0c",
+ "2779f8f9e2",
+ "151178afd7",
+ "182932a4f3",
+ "635852d56e",
+ "9906136b57",
+ "af112b8903",
+ "b0f057c684",
+ "b37177e6c8",
+ "b119249da7",
+ "be8367fcbe",
+ "c8fc01c453",
+ "e1fb8626c8",
+ "e2caaaf5b5",
+ "fe3fc057a1",
+ ]
+
+ # Test Scenes
+ self.test_split_scenes = [
+ "0e900bcc5c",
+ "0eba3981c9",
+ "1cbb105c6a",
+ "3c8d535d49",
+ "5d902f1593",
+ "6bd39ac392",
+ "6c14d5fd01",
+ "7c31a42404",
+ "9bfbc75700",
+ "13b4efaf62",
+ "062e5a23a6",
+ "95b9971d01",
+ "246fe09e98",
+ "637a27d04b",
+ "725b8f0cba",
+ "413085a827",
+ "696317583f",
+ "a4c043ac48",
+ "a9e4791c7e",
+ "b0b004c40f",
+ "c3bc5e82c5",
+ "c31ebd4b22",
+ "cba701332a",
+ "cc5ea8026c",
+ "cec8312f4e",
+ "e3b3b0d0c7",
+ "e667e09fe6",
+ "eaa6c90310",
+ "f9397af4cb",
+ "fb893ffaf3",
+ ]
+
+
+class DL3DV10KSplits:
+ """
+ This class contains the information about the splits of the DL3DV-10K dataset.
+ We use the official benchmark split as the val split.
+ """
+
+ def __init__(self):
+ """
+ Validation split is based on DL3DV-Benchmark.
+ """
+ self.val_split_scenes = [
+ "load https://huggingface.co/datasets/DL3DV/DL3DV-Benchmark/raw/main/benchmark-meta.csv \
+ & https://raw.githubusercontent.com/DL3DV-10K/Dataset/main/cache/DL3DV-valid.csv"
+ ]
+
+
+class DTUSplits:
+ """
+ This class contains the information about the splits of the DTU dataset.
+ """
+
+ def __init__(self):
+ """
+ All scenes are in the test split.
+ """
+ self.test_split_scenes = "all"
+
+
+class ETH3DSplits:
+ """
+ This class contains the information about the splits of the ETH3D dataset.
+ """
+
+ def __init__(self):
+ """
+ All scenes are in the test split.
+ """
+ self.test_split_scenes = "all"
diff --git a/mapanything/datasets/wai/__init__.py b/mapanything/datasets/wai/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/mapanything/datasets/wai/ase.py b/mapanything/datasets/wai/ase.py
new file mode 100644
index 0000000000000000000000000000000000000000..0da1cae1fc056c9be9475973b8431e8c335f1442
--- /dev/null
+++ b/mapanything/datasets/wai/ase.py
@@ -0,0 +1,289 @@
+"""
+ASE Dataset using WAI format data.
+"""
+
+import os
+
+import numpy as np
+
+from mapanything.datasets.base.base_dataset import BaseDataset
+from wai import load_data, load_frame
+
+
+class ASEWAI(BaseDataset):
+ """
+ ASE dataset containing large diversity of synthetic indoor scenes.
+ """
+
+ def __init__(
+ self,
+ *args,
+ ROOT,
+ dataset_metadata_dir,
+ split,
+ overfit_num_sets=None,
+ sample_specific_scene: bool = False,
+ specific_scene_name: str = None,
+ **kwargs,
+ ):
+ """
+ Initialize the dataset attributes.
+ Args:
+ ROOT: Root directory of the dataset.
+ dataset_metadata_dir: Path to the dataset metadata directory.
+ split: Dataset split (train, val, test).
+ overfit_num_sets: If None, use all sets. Else, the dataset will be truncated to this number of sets.
+ sample_specific_scene: Whether to sample a specific scene from the dataset.
+ specific_scene_name: Name of the specific scene to sample.
+ """
+ # Initialize the dataset attributes
+ super().__init__(*args, **kwargs)
+ self.ROOT = ROOT
+ self.dataset_metadata_dir = dataset_metadata_dir
+ self.split = split
+ self.overfit_num_sets = overfit_num_sets
+ self.sample_specific_scene = sample_specific_scene
+ self.specific_scene_name = specific_scene_name
+ self._load_data()
+
+ # Define the dataset type flags
+ self.is_metric_scale = True
+ self.is_synthetic = True
+
+ def _load_data(self):
+ "Load the precomputed dataset metadata"
+ # Load the dataset metadata corresponding to the split
+ split_metadata_path = os.path.join(
+ self.dataset_metadata_dir,
+ self.split,
+ f"ase_scene_list_{self.split}.npy",
+ )
+ split_scene_list = np.load(split_metadata_path, allow_pickle=True)
+
+ # Get the list of all scenes
+ if not self.sample_specific_scene:
+ self.scenes = list(split_scene_list)
+ else:
+ self.scenes = [self.specific_scene_name]
+ self.num_of_scenes = len(self.scenes)
+
+ def _get_views(self, sampled_idx, num_views_to_sample, resolution):
+ # Get the scene name of the sampled index
+ scene_index = sampled_idx
+ scene_name = self.scenes[scene_index]
+
+ # Get the metadata corresponding to the scene
+ scene_root = os.path.join(self.ROOT, scene_name)
+ scene_meta = load_data(
+ os.path.join(scene_root, "scene_meta.json"), "scene_meta"
+ )
+ scene_file_names = list(scene_meta["frame_names"].keys())
+ num_views_in_scene = len(scene_file_names)
+
+ # Load the scene pairwise covisibility mmap
+ covisibility_version_key = "v0"
+ covisibility_map_dir = os.path.join(
+ scene_root, "covisibility", covisibility_version_key
+ )
+ # Assumes only npy file in directory is covisbility map
+ covisibility_map_name = next(
+ f for f in os.listdir(covisibility_map_dir) if f.endswith(".npy")
+ )
+ covisibility_map_path = os.path.join(
+ scene_root, "covisibility", covisibility_version_key, covisibility_map_name
+ )
+ pairwise_covisibility = load_data(covisibility_map_path, "mmap")
+
+ # Get the indices of the N views in the scene
+ view_indices = self._sample_view_indices(
+ num_views_to_sample, num_views_in_scene, pairwise_covisibility
+ )
+
+ # Get the views corresponding to the selected view indices
+ views = []
+ for view_index in view_indices:
+ # Load the data corresponding to the view
+ view_file_name = scene_file_names[view_index]
+ view_data = load_frame(
+ scene_root,
+ view_file_name,
+ modalities=["image", "depth"],
+ scene_meta=scene_meta,
+ )
+
+ # Convert necessary data to numpy
+ image = view_data["image"].permute(1, 2, 0).numpy()
+ image = (image * 255).astype(np.uint8)
+ depthmap = view_data["depth"].numpy().astype(np.float32)
+ intrinsics = view_data["intrinsics"].numpy().astype(np.float32)
+ c2w_pose = view_data["extrinsics"].numpy().astype(np.float32)
+
+ # Ensure that the depthmap has all valid values
+ depthmap = np.nan_to_num(depthmap, nan=0.0, posinf=0.0, neginf=0.0)
+
+ # Resize the data to match the desired resolution
+ image, depthmap, intrinsics = self._crop_resize_if_necessary(
+ image=image,
+ resolution=resolution,
+ depthmap=depthmap,
+ intrinsics=intrinsics,
+ additional_quantities=None,
+ )
+
+ # Append the view dictionary to the list of views
+ views.append(
+ dict(
+ img=image,
+ depthmap=depthmap,
+ camera_pose=c2w_pose, # cam2world
+ camera_intrinsics=intrinsics,
+ dataset="ASE",
+ label=scene_name,
+ instance=os.path.join("images", str(view_file_name)),
+ )
+ )
+
+ return views
+
+
+def get_parser():
+ import argparse
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument("-rd", "--root_dir", default="/fsx/xrtech/data/ase", type=str)
+ parser.add_argument(
+ "-dmd",
+ "--dataset_metadata_dir",
+ default="/fsx/nkeetha/mapanything_dataset_metadata",
+ type=str,
+ )
+ parser.add_argument(
+ "-nv",
+ "--num_of_views",
+ default=2,
+ type=int,
+ )
+ parser.add_argument("--viz", action="store_true")
+
+ return parser
+
+
+if __name__ == "__main__":
+ import rerun as rr
+ from tqdm import tqdm
+
+ from mapanything.datasets.base.base_dataset import view_name
+ from mapanything.utils.image import rgb
+ from mapanything.utils.viz import script_add_rerun_args
+
+ parser = get_parser()
+ script_add_rerun_args(
+ parser
+ ) # Options: --headless, --connect, --serve, --addr, --save, --stdout
+ args = parser.parse_args()
+
+ dataset = ASEWAI(
+ num_views=args.num_of_views,
+ split="train",
+ covisibility_thres=0.25,
+ ROOT=args.root_dir,
+ dataset_metadata_dir=args.dataset_metadata_dir,
+ resolution=(518, 518),
+ aug_crop=16,
+ transform="colorjitter+grayscale+gaublur",
+ data_norm_type="dinov2",
+ )
+ # dataset = ASEWAI(
+ # num_views=args.num_of_views,
+ # split="val",
+ # covisibility_thres=0.25,
+ # ROOT=args.root_dir,
+ # dataset_metadata_dir=args.dataset_metadata_dir,
+ # resolution=(518, 518),
+ # seed=777,
+ # transform="imgnorm",
+ # data_norm_type="dinov2",
+ # )
+ print(dataset.get_stats())
+
+ if args.viz:
+ rr.script_setup(args, "ASE_Dataloader")
+ rr.set_time("stable_time", sequence=0)
+ rr.log("world", rr.ViewCoordinates.RDF, static=True)
+
+ sampled_indices = np.random.choice(len(dataset), size=len(dataset), replace=False)
+
+ for num, idx in enumerate(tqdm(sampled_indices)):
+ views = dataset[idx]
+ assert len(views) == args.num_of_views
+ sample_name = f"{idx}"
+ for view_idx in range(args.num_of_views):
+ sample_name += f" {view_name(views[view_idx])}"
+ print(sample_name)
+ for view_idx in range(args.num_of_views):
+ image = rgb(
+ views[view_idx]["img"], norm_type=views[view_idx]["data_norm_type"]
+ )
+ depthmap = views[view_idx]["depthmap"]
+ pose = views[view_idx]["camera_pose"]
+ intrinsics = views[view_idx]["camera_intrinsics"]
+ pts3d = views[view_idx]["pts3d"]
+ valid_mask = views[view_idx]["valid_mask"]
+ if "non_ambiguous_mask" in views[view_idx]:
+ non_ambiguous_mask = views[view_idx]["non_ambiguous_mask"]
+ else:
+ non_ambiguous_mask = None
+ if "prior_depth_along_ray" in views[view_idx]:
+ prior_depth_along_ray = views[view_idx]["prior_depth_along_ray"]
+ else:
+ prior_depth_along_ray = None
+ if args.viz:
+ rr.set_time("stable_time", sequence=num)
+ base_name = f"world/view_{view_idx}"
+ pts_name = f"world/view_{view_idx}_pointcloud"
+ # Log camera info and loaded data
+ height, width = image.shape[0], image.shape[1]
+ rr.log(
+ base_name,
+ rr.Transform3D(
+ translation=pose[:3, 3],
+ mat3x3=pose[:3, :3],
+ ),
+ )
+ rr.log(
+ f"{base_name}/pinhole",
+ rr.Pinhole(
+ image_from_camera=intrinsics,
+ height=height,
+ width=width,
+ camera_xyz=rr.ViewCoordinates.RDF,
+ ),
+ )
+ rr.log(
+ f"{base_name}/pinhole/rgb",
+ rr.Image(image),
+ )
+ rr.log(
+ f"{base_name}/pinhole/depth",
+ rr.DepthImage(depthmap),
+ )
+ if prior_depth_along_ray is not None:
+ rr.log(
+ f"prior_depth_along_ray_{view_idx}",
+ rr.DepthImage(prior_depth_along_ray),
+ )
+ if non_ambiguous_mask is not None:
+ rr.log(
+ f"{base_name}/pinhole/non_ambiguous_mask",
+ rr.SegmentationImage(non_ambiguous_mask.astype(int)),
+ )
+ # Log points in 3D
+ filtered_pts = pts3d[valid_mask]
+ filtered_pts_col = image[valid_mask]
+ rr.log(
+ pts_name,
+ rr.Points3D(
+ positions=filtered_pts.reshape(-1, 3),
+ colors=filtered_pts_col.reshape(-1, 3),
+ ),
+ )
diff --git a/mapanything/datasets/wai/bedlam.py b/mapanything/datasets/wai/bedlam.py
new file mode 100644
index 0000000000000000000000000000000000000000..3d9e6f8956b67f1f3058c9687f293fd06adcba12
--- /dev/null
+++ b/mapanything/datasets/wai/bedlam.py
@@ -0,0 +1,309 @@
+"""
+Bedlam Dataset using WAI format data.
+"""
+
+import os
+
+import numpy as np
+
+from mapanything.datasets.base.base_dataset import BaseDataset
+from wai import load_data, load_frame
+
+
+class BedlamWAI(BaseDataset):
+ """
+ Bedlam dataset containing diverse synthetic scenes with humans.
+ """
+
+ def __init__(
+ self,
+ *args,
+ ROOT,
+ dataset_metadata_dir,
+ split,
+ overfit_num_sets=None,
+ sample_specific_scene: bool = False,
+ specific_scene_name: str = None,
+ **kwargs,
+ ):
+ """
+ Initialize the dataset attributes.
+ Args:
+ ROOT: Root directory of the dataset.
+ dataset_metadata_dir: Path to the dataset metadata directory.
+ split: Dataset split (train, val, test).
+ overfit_num_sets: If None, use all sets. Else, the dataset will be truncated to this number of sets.
+ sample_specific_scene: Whether to sample a specific scene from the dataset.
+ specific_scene_name: Name of the specific scene to sample.
+ """
+ # Initialize the dataset attributes
+ super().__init__(*args, **kwargs)
+ self.ROOT = ROOT
+ self.dataset_metadata_dir = dataset_metadata_dir
+ self.split = split
+ self.overfit_num_sets = overfit_num_sets
+ self.sample_specific_scene = sample_specific_scene
+ self.specific_scene_name = specific_scene_name
+ self._load_data()
+
+ # Define the dataset type flags
+ self.is_metric_scale = True
+ self.is_synthetic = True
+
+ def _load_data(self):
+ "Load the precomputed dataset metadata"
+ # Load the dataset metadata corresponding to the split
+ split_metadata_path = os.path.join(
+ self.dataset_metadata_dir,
+ self.split,
+ f"bedlam_scene_list_{self.split}.npy",
+ )
+ split_scene_list = np.load(split_metadata_path, allow_pickle=True)
+
+ # Get the list of all scenes
+ if not self.sample_specific_scene:
+ self.scenes = list(split_scene_list)
+ else:
+ self.scenes = [self.specific_scene_name]
+ self.num_of_scenes = len(self.scenes)
+
+ def _get_views(self, sampled_idx, num_views_to_sample, resolution):
+ # Get the scene name of the sampled index
+ scene_index = sampled_idx
+ scene_name = self.scenes[scene_index]
+
+ # Get the metadata corresponding to the scene
+ scene_root = os.path.join(self.ROOT, scene_name)
+ scene_meta = load_data(
+ os.path.join(scene_root, "scene_meta.json"), "scene_meta"
+ )
+ scene_file_names = list(scene_meta["frame_names"].keys())
+ num_views_in_scene = len(scene_file_names)
+
+ # Load the scene pairwise covisibility mmap
+ covisibility_version_key = "v0"
+ covisibility_map_dir = os.path.join(
+ scene_root, "covisibility", covisibility_version_key
+ )
+ # Assumes only npy file in directory is covisbility map
+ covisibility_map_name = next(
+ f for f in os.listdir(covisibility_map_dir) if f.endswith(".npy")
+ )
+ covisibility_map_path = os.path.join(
+ scene_root, "covisibility", covisibility_version_key, covisibility_map_name
+ )
+ pairwise_covisibility = load_data(covisibility_map_path, "mmap")
+
+ # Get the indices of the N views in the scene
+ # Bedlam scenes have very large number of images
+ # Thus, we use unidirectional covis for faster access
+ view_indices = self._sample_view_indices(
+ num_views_to_sample,
+ num_views_in_scene,
+ pairwise_covisibility,
+ use_bidirectional_covis=False,
+ )
+
+ # Get the views corresponding to the selected view indices
+ views = []
+ for view_index in view_indices:
+ # Load the data corresponding to the view
+ view_file_name = scene_file_names[view_index]
+ view_data = load_frame(
+ scene_root,
+ view_file_name,
+ modalities=["image", "depth"],
+ scene_meta=scene_meta,
+ )
+
+ # Convert necessary data to numpy
+ image = view_data["image"].permute(1, 2, 0).numpy()
+ image = image[:, :, :3] # RGBA to RGB
+ image = (image * 255).astype(np.uint8)
+ depthmap = view_data["depth"].numpy().astype(np.float32)
+ intrinsics = view_data["intrinsics"].numpy().astype(np.float32)
+ c2w_pose = view_data["extrinsics"].numpy().astype(np.float32)
+
+ # Ensure that the depthmap has all valid values
+ depthmap = np.nan_to_num(depthmap, nan=0.0, posinf=0.0, neginf=0.0)
+
+ # Get the non ambiguous mask (zero depth pixels are sky or ambiguous)
+ non_ambiguous_mask = (depthmap > 0).astype(int)
+
+ # Mask out the outlier depth (see through window or horizon depth)
+ percentile_depth = np.percentile(depthmap, 95)
+ depthmap[depthmap > percentile_depth] = 0
+
+ # Resize the data to match the desired resolution
+ additional_quantities_to_resize = [non_ambiguous_mask]
+ image, depthmap, intrinsics, additional_quantities_to_resize = (
+ self._crop_resize_if_necessary(
+ image=image,
+ resolution=resolution,
+ depthmap=depthmap,
+ intrinsics=intrinsics,
+ additional_quantities=additional_quantities_to_resize,
+ )
+ )
+ non_ambiguous_mask = additional_quantities_to_resize[0]
+
+ # Append the view dictionary to the list of views
+ views.append(
+ dict(
+ img=image,
+ depthmap=depthmap,
+ camera_pose=c2w_pose, # cam2world
+ camera_intrinsics=intrinsics,
+ non_ambiguous_mask=non_ambiguous_mask,
+ dataset="Bedlam",
+ label=scene_name,
+ instance=os.path.join("images", str(view_file_name)),
+ )
+ )
+
+ return views
+
+
+def get_parser():
+ import argparse
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "-rd", "--root_dir", default="/fsx/xrtech/data/bedlam", type=str
+ )
+ parser.add_argument(
+ "-dmd",
+ "--dataset_metadata_dir",
+ default="/fsx/nkeetha/mapanything_dataset_metadata",
+ type=str,
+ )
+ parser.add_argument(
+ "-nv",
+ "--num_of_views",
+ default=2,
+ type=int,
+ )
+ parser.add_argument("--viz", action="store_true")
+
+ return parser
+
+
+if __name__ == "__main__":
+ import rerun as rr
+ from tqdm import tqdm
+
+ from mapanything.datasets.base.base_dataset import view_name
+ from mapanything.utils.image import rgb
+ from mapanything.utils.viz import script_add_rerun_args
+
+ parser = get_parser()
+ script_add_rerun_args(
+ parser
+ ) # Options: --headless, --connect, --serve, --addr, --save, --stdout
+ args = parser.parse_args()
+
+ dataset = BedlamWAI(
+ num_views=args.num_of_views,
+ split="train",
+ covisibility_thres=0.25,
+ ROOT=args.root_dir,
+ dataset_metadata_dir=args.dataset_metadata_dir,
+ resolution=(518, 294),
+ aug_crop=16,
+ transform="colorjitter+grayscale+gaublur",
+ data_norm_type="dinov2",
+ )
+ # dataset = BedlamWAI(
+ # num_views=args.num_of_views,
+ # split="val",
+ # covisibility_thres=0.25,
+ # ROOT=args.root_dir,
+ # dataset_metadata_dir=args.dataset_metadata_dir,
+ # resolution=(518, 294),
+ # seed=777,
+ # transform="imgnorm",
+ # data_norm_type="dinov2",
+ # )
+ print(dataset.get_stats())
+
+ if args.viz:
+ rr.script_setup(args, "Bedlam_Dataloader")
+ rr.set_time("stable_time", sequence=0)
+ rr.log("world", rr.ViewCoordinates.RDF, static=True)
+
+ sampled_indices = np.random.choice(len(dataset), size=len(dataset), replace=False)
+
+ for num, idx in enumerate(tqdm(sampled_indices)):
+ views = dataset[idx]
+ assert len(views) == args.num_of_views
+ sample_name = f"{idx}"
+ for view_idx in range(args.num_of_views):
+ sample_name += f" {view_name(views[view_idx])}"
+ print(sample_name)
+ for view_idx in range(args.num_of_views):
+ image = rgb(
+ views[view_idx]["img"], norm_type=views[view_idx]["data_norm_type"]
+ )
+ depthmap = views[view_idx]["depthmap"]
+ pose = views[view_idx]["camera_pose"]
+ intrinsics = views[view_idx]["camera_intrinsics"]
+ pts3d = views[view_idx]["pts3d"]
+ valid_mask = views[view_idx]["valid_mask"]
+ if "non_ambiguous_mask" in views[view_idx]:
+ non_ambiguous_mask = views[view_idx]["non_ambiguous_mask"]
+ else:
+ non_ambiguous_mask = None
+ if "prior_depth_along_ray" in views[view_idx]:
+ prior_depth_along_ray = views[view_idx]["prior_depth_along_ray"]
+ else:
+ prior_depth_along_ray = None
+ if args.viz:
+ rr.set_time("stable_time", sequence=num)
+ base_name = f"world/view_{view_idx}"
+ pts_name = f"world/view_{view_idx}_pointcloud"
+ # Log camera info and loaded data
+ height, width = image.shape[0], image.shape[1]
+ rr.log(
+ base_name,
+ rr.Transform3D(
+ translation=pose[:3, 3],
+ mat3x3=pose[:3, :3],
+ ),
+ )
+ rr.log(
+ f"{base_name}/pinhole",
+ rr.Pinhole(
+ image_from_camera=intrinsics,
+ height=height,
+ width=width,
+ camera_xyz=rr.ViewCoordinates.RDF,
+ ),
+ )
+ rr.log(
+ f"{base_name}/pinhole/rgb",
+ rr.Image(image),
+ )
+ rr.log(
+ f"{base_name}/pinhole/depth",
+ rr.DepthImage(depthmap),
+ )
+ if prior_depth_along_ray is not None:
+ rr.log(
+ f"prior_depth_along_ray_{view_idx}",
+ rr.DepthImage(prior_depth_along_ray),
+ )
+ if non_ambiguous_mask is not None:
+ rr.log(
+ f"{base_name}/pinhole/non_ambiguous_mask",
+ rr.SegmentationImage(non_ambiguous_mask.astype(int)),
+ )
+ # Log points in 3D
+ filtered_pts = pts3d[valid_mask]
+ filtered_pts_col = image[valid_mask]
+ rr.log(
+ pts_name,
+ rr.Points3D(
+ positions=filtered_pts.reshape(-1, 3),
+ colors=filtered_pts_col.reshape(-1, 3),
+ ),
+ )
diff --git a/mapanything/datasets/wai/blendedmvs.py b/mapanything/datasets/wai/blendedmvs.py
new file mode 100644
index 0000000000000000000000000000000000000000..706ee6b073d364b8c785853d8f9109157e46476b
--- /dev/null
+++ b/mapanything/datasets/wai/blendedmvs.py
@@ -0,0 +1,315 @@
+"""
+BlendedMVS Dataset using WAI format data.
+"""
+
+import os
+
+import cv2
+import numpy as np
+
+from mapanything.datasets.base.base_dataset import BaseDataset
+from wai import load_data, load_frame
+
+
+class BlendedMVSWAI(BaseDataset):
+ """
+ BlendedMVS dataset containing object-centric and birds-eye-view scenes.
+ """
+
+ def __init__(
+ self,
+ *args,
+ ROOT,
+ dataset_metadata_dir,
+ split,
+ overfit_num_sets=None,
+ sample_specific_scene: bool = False,
+ specific_scene_name: str = None,
+ **kwargs,
+ ):
+ """
+ Initialize the dataset attributes.
+ Args:
+ ROOT: Root directory of the dataset.
+ dataset_metadata_dir: Path to the dataset metadata directory.
+ split: Dataset split (train, val, test).
+ overfit_num_sets: If None, use all sets. Else, the dataset will be truncated to this number of sets.
+ sample_specific_scene: Whether to sample a specific scene from the dataset.
+ specific_scene_name: Name of the specific scene to sample.
+ """
+ # Initialize the dataset attributes
+ super().__init__(*args, **kwargs)
+ self.ROOT = ROOT
+ self.dataset_metadata_dir = dataset_metadata_dir
+ self.split = split
+ self.overfit_num_sets = overfit_num_sets
+ self.sample_specific_scene = sample_specific_scene
+ self.specific_scene_name = specific_scene_name
+ self._load_data()
+
+ # Define the dataset type flags
+ self.is_metric_scale = False
+ self.is_synthetic = False
+
+ def _load_data(self):
+ "Load the precomputed dataset metadata"
+ # Load the dataset metadata corresponding to the split
+ split_metadata_path = os.path.join(
+ self.dataset_metadata_dir,
+ self.split,
+ f"blendedmvs_scene_list_{self.split}.npy",
+ )
+ split_scene_list = np.load(split_metadata_path, allow_pickle=True)
+
+ # Get the list of all scenes
+ if not self.sample_specific_scene:
+ self.scenes = list(split_scene_list)
+ else:
+ self.scenes = [self.specific_scene_name]
+ self.num_of_scenes = len(self.scenes)
+
+ def _get_views(self, sampled_idx, num_views_to_sample, resolution):
+ # Get the scene name of the sampled index
+ scene_index = sampled_idx
+ scene_name = self.scenes[scene_index]
+
+ # Get the metadata corresponding to the scene
+ scene_root = os.path.join(self.ROOT, scene_name)
+ scene_meta = load_data(
+ os.path.join(scene_root, "scene_meta.json"), "scene_meta"
+ )
+ scene_file_names = list(scene_meta["frame_names"].keys())
+ num_views_in_scene = len(scene_file_names)
+
+ # Load the scene pairwise covisibility mmap
+ covisibility_version_key = "v0"
+ covisibility_map_dir = os.path.join(
+ scene_root, "covisibility", covisibility_version_key
+ )
+ # Assumes only npy file in directory is covisbility map
+ covisibility_map_name = next(
+ f for f in os.listdir(covisibility_map_dir) if f.endswith(".npy")
+ )
+ covisibility_map_path = os.path.join(
+ scene_root, "covisibility", covisibility_version_key, covisibility_map_name
+ )
+ pairwise_covisibility = load_data(covisibility_map_path, "mmap")
+
+ # Get the indices of the N views in the scene
+ view_indices = self._sample_view_indices(
+ num_views_to_sample, num_views_in_scene, pairwise_covisibility
+ )
+
+ # Get the views corresponding to the selected view indices
+ views = []
+ for view_index in view_indices:
+ # Load the data corresponding to the view
+ view_file_name = scene_file_names[view_index]
+ view_data = load_frame(
+ scene_root,
+ view_file_name,
+ modalities=["image", "depth"],
+ # modalities=["image", "depth", "pred_mask/moge2"],
+ scene_meta=scene_meta,
+ )
+ ### HOTFIX: Load required additional masks manually
+ ### Remove once stability issue with scene_meta is fixed
+ mask_path = os.path.join(
+ scene_root, "moge", "v0", "mask", "moge2", f"{view_file_name}.png"
+ )
+ view_data["pred_mask/moge2"] = load_data(mask_path, "binary")
+
+ # Convert necessary data to numpy
+ image = view_data["image"].permute(1, 2, 0).numpy()
+ image = (image * 255).astype(np.uint8)
+ depthmap = view_data["depth"].numpy().astype(np.float32)
+ intrinsics = view_data["intrinsics"].numpy().astype(np.float32)
+ c2w_pose = view_data["extrinsics"].numpy().astype(np.float32)
+
+ # Ensure that the depthmap has all valid values
+ depthmap = np.nan_to_num(depthmap, nan=0.0, posinf=0.0, neginf=0.0)
+
+ # Get the non_ambiguous_mask and ensure it matches image resolution
+ non_ambiguous_mask = view_data["pred_mask/moge2"].numpy().astype(int)
+ non_ambiguous_mask = cv2.resize(
+ non_ambiguous_mask,
+ (image.shape[1], image.shape[0]),
+ interpolation=cv2.INTER_NEAREST,
+ )
+
+ # Mask out the GT depth using the non_ambiguous_mask
+ depthmap = np.where(non_ambiguous_mask, depthmap, 0)
+
+ # Resize the data to match the desired resolution
+ additional_quantities_to_resize = [non_ambiguous_mask]
+ image, depthmap, intrinsics, additional_quantities_to_resize = (
+ self._crop_resize_if_necessary(
+ image=image,
+ resolution=resolution,
+ depthmap=depthmap,
+ intrinsics=intrinsics,
+ additional_quantities=additional_quantities_to_resize,
+ )
+ )
+ non_ambiguous_mask = additional_quantities_to_resize[0]
+
+ # Append the view dictionary to the list of views
+ views.append(
+ dict(
+ img=image,
+ depthmap=depthmap,
+ camera_pose=c2w_pose, # cam2world
+ camera_intrinsics=intrinsics,
+ non_ambiguous_mask=non_ambiguous_mask,
+ dataset="BlendedMVS",
+ label=scene_name,
+ instance=os.path.join("images", str(view_file_name)),
+ )
+ )
+
+ return views
+
+
+def get_parser():
+ import argparse
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "-rd", "--root_dir", default="/fsx/xrtech/data/blendedmvs", type=str
+ )
+ parser.add_argument(
+ "-dmd",
+ "--dataset_metadata_dir",
+ default="/fsx/nkeetha/mapanything_dataset_metadata",
+ type=str,
+ )
+ parser.add_argument(
+ "-nv",
+ "--num_of_views",
+ default=2,
+ type=int,
+ )
+ parser.add_argument("--viz", action="store_true")
+
+ return parser
+
+
+if __name__ == "__main__":
+ import rerun as rr
+ from tqdm import tqdm
+
+ from mapanything.datasets.base.base_dataset import view_name
+ from mapanything.utils.image import rgb
+ from mapanything.utils.viz import script_add_rerun_args
+
+ parser = get_parser()
+ script_add_rerun_args(
+ parser
+ ) # Options: --headless, --connect, --serve, --addr, --save, --stdout
+ args = parser.parse_args()
+
+ dataset = BlendedMVSWAI(
+ num_views=args.num_of_views,
+ split="train",
+ covisibility_thres=0.25,
+ ROOT=args.root_dir,
+ dataset_metadata_dir=args.dataset_metadata_dir,
+ resolution=(518, 392),
+ aug_crop=16,
+ transform="colorjitter+grayscale+gaublur",
+ data_norm_type="dinov2",
+ )
+ # dataset = BlendedMVSWAI(
+ # num_views=args.num_of_views,
+ # split="val",
+ # covisibility_thres=0.25,
+ # ROOT=args.root_dir,
+ # dataset_metadata_dir=args.dataset_metadata_dir,
+ # resolution=(518, 392),
+ # seed=777,
+ # transform="imgnorm",
+ # data_norm_type="dinov2",
+ # )
+ print(dataset.get_stats())
+
+ if args.viz:
+ rr.script_setup(args, "BlendedMVS_Dataloader")
+ rr.set_time("stable_time", sequence=0)
+ rr.log("world", rr.ViewCoordinates.RDF, static=True)
+
+ sampled_indices = np.random.choice(len(dataset), size=10, replace=False)
+
+ for num, idx in enumerate(tqdm(sampled_indices)):
+ views = dataset[idx]
+ assert len(views) == args.num_of_views
+ sample_name = f"{idx}"
+ for view_idx in range(args.num_of_views):
+ sample_name += f" {view_name(views[view_idx])}"
+ print(sample_name)
+ for view_idx in range(args.num_of_views):
+ image = rgb(
+ views[view_idx]["img"], norm_type=views[view_idx]["data_norm_type"]
+ )
+ depthmap = views[view_idx]["depthmap"]
+ pose = views[view_idx]["camera_pose"]
+ intrinsics = views[view_idx]["camera_intrinsics"]
+ pts3d = views[view_idx]["pts3d"]
+ valid_mask = views[view_idx]["valid_mask"]
+ if "non_ambiguous_mask" in views[view_idx]:
+ non_ambiguous_mask = views[view_idx]["non_ambiguous_mask"]
+ else:
+ non_ambiguous_mask = None
+ if "prior_depth_along_ray" in views[view_idx]:
+ prior_depth_along_ray = views[view_idx]["prior_depth_along_ray"]
+ else:
+ prior_depth_along_ray = None
+ if args.viz:
+ rr.set_time("stable_time", sequence=num)
+ base_name = f"world/view_{view_idx}"
+ pts_name = f"world/view_{view_idx}_pointcloud"
+ # Log camera info and loaded data
+ height, width = image.shape[0], image.shape[1]
+ rr.log(
+ base_name,
+ rr.Transform3D(
+ translation=pose[:3, 3],
+ mat3x3=pose[:3, :3],
+ ),
+ )
+ rr.log(
+ f"{base_name}/pinhole",
+ rr.Pinhole(
+ image_from_camera=intrinsics,
+ height=height,
+ width=width,
+ camera_xyz=rr.ViewCoordinates.RDF,
+ ),
+ )
+ rr.log(
+ f"{base_name}/pinhole/rgb",
+ rr.Image(image),
+ )
+ rr.log(
+ f"{base_name}/pinhole/depth",
+ rr.DepthImage(depthmap),
+ )
+ if prior_depth_along_ray is not None:
+ rr.log(
+ f"prior_depth_along_ray_{view_idx}",
+ rr.DepthImage(prior_depth_along_ray),
+ )
+ if non_ambiguous_mask is not None:
+ rr.log(
+ f"{base_name}/pinhole/non_ambiguous_mask",
+ rr.SegmentationImage(non_ambiguous_mask.astype(int)),
+ )
+ # Log points in 3D
+ filtered_pts = pts3d[valid_mask]
+ filtered_pts_col = image[valid_mask]
+ rr.log(
+ pts_name,
+ rr.Points3D(
+ positions=filtered_pts.reshape(-1, 3),
+ colors=filtered_pts_col.reshape(-1, 3),
+ ),
+ )
diff --git a/mapanything/datasets/wai/dl3dv.py b/mapanything/datasets/wai/dl3dv.py
new file mode 100644
index 0000000000000000000000000000000000000000..ddfba3b53ea2f973ec867f5050b1f601d4aaa6a3
--- /dev/null
+++ b/mapanything/datasets/wai/dl3dv.py
@@ -0,0 +1,372 @@
+"""
+DL3DV Dataset using WAI format data.
+"""
+
+import os
+
+import cv2
+import numpy as np
+
+from mapanything.datasets.base.base_dataset import BaseDataset
+from mapanything.utils.cropping import (
+ rescale_image_and_other_optional_info,
+ resize_with_nearest_interpolation_to_match_aspect_ratio,
+)
+from wai import load_data, load_frame
+
+
+class DL3DVWAI(BaseDataset):
+ """
+ DL3DV dataset containing over 10k in-the-wild and indoor scenes.
+ """
+
+ def __init__(
+ self,
+ *args,
+ ROOT,
+ dataset_metadata_dir,
+ split,
+ overfit_num_sets=None,
+ sample_specific_scene: bool = False,
+ specific_scene_name: str = None,
+ mvs_confidence_filter_thres: float = 0.25,
+ **kwargs,
+ ):
+ """
+ Initialize the dataset attributes.
+ Args:
+ ROOT: Root directory of the dataset.
+ dataset_metadata_dir: Path to the dataset metadata directory.
+ split: Dataset split (train, val, test).
+ overfit_num_sets: If None, use all sets. Else, the dataset will be truncated to this number of sets.
+ sample_specific_scene: Whether to sample a specific scene from the dataset.
+ specific_scene_name: Name of the specific scene to sample.
+ mvs_confidence_filter_thres: Confidence threshold to filter MVS depth. Defaults to 0.25.
+ """
+ # Initialize the dataset attributes
+ super().__init__(*args, **kwargs)
+ self.ROOT = ROOT
+ self.dataset_metadata_dir = dataset_metadata_dir
+ self.split = split
+ self.overfit_num_sets = overfit_num_sets
+ self.sample_specific_scene = sample_specific_scene
+ self.specific_scene_name = specific_scene_name
+ self.mvs_confidence_filter_thres = mvs_confidence_filter_thres
+ self._load_data()
+
+ # Define the dataset type flags
+ self.is_metric_scale = False
+ self.is_synthetic = False
+
+ def _load_data(self):
+ "Load the precomputed dataset metadata"
+ # Load the dataset metadata corresponding to the split
+ split_metadata_path = os.path.join(
+ self.dataset_metadata_dir,
+ self.split,
+ f"dl3dv_scene_list_{self.split}.npy",
+ )
+ split_scene_list = np.load(split_metadata_path, allow_pickle=True)
+
+ # Get the list of all scenes
+ if not self.sample_specific_scene:
+ self.scenes = list(split_scene_list)
+ else:
+ self.scenes = [self.specific_scene_name]
+ self.num_of_scenes = len(self.scenes)
+
+ def _get_views(self, sampled_idx, num_views_to_sample, resolution):
+ # Get the scene name of the sampled index
+ scene_index = sampled_idx
+ scene_name = self.scenes[scene_index]
+
+ # Get the metadata corresponding to the scene
+ scene_root = os.path.join(self.ROOT, scene_name)
+ scene_meta = load_data(
+ os.path.join(scene_root, "scene_meta.json"), "scene_meta"
+ )
+ scene_file_names = list(scene_meta["frame_names"].keys())
+ num_views_in_scene = len(scene_file_names)
+
+ # Load the scene pairwise covisibility mmap
+ covisibility_version_key = "v0_mvsa_based"
+ covisibility_map_dir = os.path.join(
+ scene_root, "covisibility", covisibility_version_key
+ )
+ # Assumes only npy file in directory is covisbility map
+ covisibility_map_name = next(
+ f for f in os.listdir(covisibility_map_dir) if f.endswith(".npy")
+ )
+ covisibility_map_path = os.path.join(
+ scene_root, "covisibility", covisibility_version_key, covisibility_map_name
+ )
+ pairwise_covisibility = load_data(covisibility_map_path, "mmap")
+
+ # Get the indices of the N views in the scene
+ view_indices = self._sample_view_indices(
+ num_views_to_sample, num_views_in_scene, pairwise_covisibility
+ )
+
+ # Get the views corresponding to the selected view indices
+ views = []
+ for view_index in view_indices:
+ # Load the data corresponding to the view
+ view_file_name = scene_file_names[view_index]
+ view_data = load_frame(
+ scene_root,
+ view_file_name,
+ modalities=["image"],
+ # modalities=[
+ # "image",
+ # "pred_depth/mvsanywhere",
+ # "pred_mask/moge2",
+ # "depth_confidence/mvsanywhere",
+ # ],
+ scene_meta=scene_meta,
+ )
+ ### HOTFIX: Load required additional modalities manually
+ ### Remove once stability issue with scene_meta is fixed
+ mvs_depth_path = os.path.join(
+ scene_root, "mvsanywhere", "v0", "depth", f"{view_file_name}.exr"
+ )
+ mvs_conf_path = os.path.join(
+ scene_root,
+ "mvsanywhere",
+ "v0",
+ "depth_confidence",
+ f"{view_file_name}.exr",
+ )
+ mask_path = os.path.join(
+ scene_root, "moge", "v0", "mask", "moge2", f"{view_file_name}.png"
+ )
+ view_data["pred_depth/mvsanywhere"] = load_data(mvs_depth_path, "depth")
+ view_data["depth_confidence/mvsanywhere"] = load_data(
+ mvs_conf_path, "scalar"
+ )
+ view_data["pred_mask/moge2"] = load_data(mask_path, "binary")
+
+ # Convert necessary data to numpy
+ image = view_data["image"].permute(1, 2, 0).numpy()
+ image = (image * 255).astype(np.uint8)
+ depthmap = view_data["pred_depth/mvsanywhere"].numpy().astype(np.float32)
+ intrinsics = view_data["intrinsics"].numpy().astype(np.float32)
+ c2w_pose = view_data["extrinsics"].numpy().astype(np.float32)
+
+ # Ensure that the depthmap has all valid values
+ depthmap = np.nan_to_num(depthmap, nan=0.0, posinf=0.0, neginf=0.0)
+
+ # Get the dimensions of the original image
+ img_h, img_w = image.shape[:2]
+
+ # Resize depth to match image aspect ratio while ensuring that depth resolution doesn't increase
+ depthmap, target_depth_h, target_depth_w = (
+ resize_with_nearest_interpolation_to_match_aspect_ratio(
+ input_data=depthmap, img_h=img_h, img_w=img_w
+ )
+ )
+
+ # Now resize the image and update intrinsics to match the resized depth
+ image, _, intrinsics, _ = rescale_image_and_other_optional_info(
+ image=image,
+ output_resolution=(target_depth_w, target_depth_h),
+ depthmap=None,
+ camera_intrinsics=intrinsics,
+ )
+ image = np.array(image)
+
+ # Get the depth confidence map and mask out the MVS depth
+ confidence_map = view_data["depth_confidence/mvsanywhere"].numpy()
+ confidence_mask = (
+ confidence_map > self.mvs_confidence_filter_thres
+ ).astype(int)
+ confidence_mask = cv2.resize(
+ confidence_mask,
+ (image.shape[1], image.shape[0]),
+ interpolation=cv2.INTER_NEAREST,
+ )
+ depthmap = np.where(confidence_mask, depthmap, 0)
+
+ # Get the non_ambiguous_mask and ensure it matches image resolution
+ non_ambiguous_mask = view_data["pred_mask/moge2"].numpy().astype(int)
+ non_ambiguous_mask = cv2.resize(
+ non_ambiguous_mask,
+ (image.shape[1], image.shape[0]),
+ interpolation=cv2.INTER_NEAREST,
+ )
+
+ # Mask out the GT depth using the non_ambiguous_mask
+ depthmap = np.where(non_ambiguous_mask, depthmap, 0)
+
+ # Resize the data to match the desired resolution
+ additional_quantities_to_resize = [non_ambiguous_mask]
+ image, depthmap, intrinsics, additional_quantities_to_resize = (
+ self._crop_resize_if_necessary(
+ image=image,
+ resolution=resolution,
+ depthmap=depthmap,
+ intrinsics=intrinsics,
+ additional_quantities=additional_quantities_to_resize,
+ )
+ )
+ non_ambiguous_mask = additional_quantities_to_resize[0]
+
+ # Append the view dictionary to the list of views
+ views.append(
+ dict(
+ img=image,
+ depthmap=depthmap,
+ camera_pose=c2w_pose, # cam2world
+ camera_intrinsics=intrinsics,
+ non_ambiguous_mask=non_ambiguous_mask,
+ dataset="DL3DV",
+ label=scene_name,
+ instance=os.path.join("images", str(view_file_name)),
+ )
+ )
+
+ return views
+
+
+def get_parser():
+ import argparse
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument("-rd", "--root_dir", default="/fsx/xrtech/data/dl3dv", type=str)
+ parser.add_argument(
+ "-dmd",
+ "--dataset_metadata_dir",
+ default="/fsx/nkeetha/mapanything_dataset_metadata",
+ type=str,
+ )
+ parser.add_argument(
+ "-nv",
+ "--num_of_views",
+ default=2,
+ type=int,
+ )
+ parser.add_argument("--viz", action="store_true")
+
+ return parser
+
+
+if __name__ == "__main__":
+ import rerun as rr
+ from tqdm import tqdm
+
+ from mapanything.datasets.base.base_dataset import view_name
+ from mapanything.utils.image import rgb
+ from mapanything.utils.viz import script_add_rerun_args
+
+ parser = get_parser()
+ script_add_rerun_args(
+ parser
+ ) # Options: --headless, --connect, --serve, --addr, --save, --stdout
+ args = parser.parse_args()
+
+ dataset = DL3DVWAI(
+ num_views=args.num_of_views,
+ split="train",
+ covisibility_thres=0.25,
+ ROOT=args.root_dir,
+ dataset_metadata_dir=args.dataset_metadata_dir,
+ mvs_confidence_filter_thres=0.25,
+ resolution=(518, 294),
+ aug_crop=16,
+ transform="colorjitter+grayscale+gaublur",
+ data_norm_type="dinov2",
+ )
+ # dataset = DL3DVWAI(
+ # num_views=args.num_of_views,
+ # split="val",
+ # covisibility_thres=0.25,
+ # ROOT=args.root_dir,
+ # dataset_metadata_dir=args.dataset_metadata_dir,
+ # mvs_confidence_filter_thres=0.25,
+ # resolution=(518, 294),
+ # seed=777,
+ # transform="imgnorm",
+ # data_norm_type="dinov2",
+ # )
+ print(dataset.get_stats())
+
+ if args.viz:
+ rr.script_setup(args, "DL3DV_Dataloader")
+ rr.set_time("stable_time", sequence=0)
+ rr.log("world", rr.ViewCoordinates.RDF, static=True)
+
+ sampled_indices = np.random.choice(len(dataset), size=len(dataset), replace=False)
+
+ for num, idx in enumerate(tqdm(sampled_indices)):
+ views = dataset[idx]
+ assert len(views) == args.num_of_views
+ sample_name = f"{idx}"
+ for view_idx in range(args.num_of_views):
+ sample_name += f" {view_name(views[view_idx])}"
+ print(sample_name)
+ for view_idx in range(args.num_of_views):
+ image = rgb(
+ views[view_idx]["img"], norm_type=views[view_idx]["data_norm_type"]
+ )
+ depthmap = views[view_idx]["depthmap"]
+ pose = views[view_idx]["camera_pose"]
+ intrinsics = views[view_idx]["camera_intrinsics"]
+ pts3d = views[view_idx]["pts3d"]
+ valid_mask = views[view_idx]["valid_mask"]
+ if "non_ambiguous_mask" in views[view_idx]:
+ non_ambiguous_mask = views[view_idx]["non_ambiguous_mask"]
+ else:
+ non_ambiguous_mask = None
+ if "prior_depth_along_ray" in views[view_idx]:
+ prior_depth_along_ray = views[view_idx]["prior_depth_along_ray"]
+ else:
+ prior_depth_along_ray = None
+ if args.viz:
+ rr.set_time("stable_time", sequence=num)
+ base_name = f"world/view_{view_idx}"
+ pts_name = f"world/view_{view_idx}_pointcloud"
+ # Log camera info and loaded data
+ height, width = image.shape[0], image.shape[1]
+ rr.log(
+ base_name,
+ rr.Transform3D(
+ translation=pose[:3, 3],
+ mat3x3=pose[:3, :3],
+ ),
+ )
+ rr.log(
+ f"{base_name}/pinhole",
+ rr.Pinhole(
+ image_from_camera=intrinsics,
+ height=height,
+ width=width,
+ camera_xyz=rr.ViewCoordinates.RDF,
+ ),
+ )
+ rr.log(
+ f"{base_name}/pinhole/rgb",
+ rr.Image(image),
+ )
+ rr.log(
+ f"{base_name}/pinhole/depth",
+ rr.DepthImage(depthmap),
+ )
+ if prior_depth_along_ray is not None:
+ rr.log(
+ f"prior_depth_along_ray_{view_idx}",
+ rr.DepthImage(prior_depth_along_ray),
+ )
+ if non_ambiguous_mask is not None:
+ rr.log(
+ f"{base_name}/pinhole/non_ambiguous_mask",
+ rr.SegmentationImage(non_ambiguous_mask.astype(int)),
+ )
+ # Log points in 3D
+ filtered_pts = pts3d[valid_mask]
+ filtered_pts_col = image[valid_mask]
+ rr.log(
+ pts_name,
+ rr.Points3D(
+ positions=filtered_pts.reshape(-1, 3),
+ colors=filtered_pts_col.reshape(-1, 3),
+ ),
+ )
diff --git a/mapanything/datasets/wai/dtu.py b/mapanything/datasets/wai/dtu.py
new file mode 100644
index 0000000000000000000000000000000000000000..14c65b4b899d0594a5c8b06ab8948fa22a8d9625
--- /dev/null
+++ b/mapanything/datasets/wai/dtu.py
@@ -0,0 +1,272 @@
+"""
+DTU Dataset using WAI format data.
+"""
+
+import os
+
+import numpy as np
+
+from mapanything.datasets.base.base_dataset import BaseDataset
+from wai import load_data, load_frame
+
+
+class DTUWAI(BaseDataset):
+ """
+ DTU dataset containing high-quality multi-view stereo object scans.
+ """
+
+ def __init__(
+ self,
+ *args,
+ ROOT,
+ dataset_metadata_dir,
+ overfit_num_sets=None,
+ sample_specific_scene: bool = False,
+ specific_scene_name: str = None,
+ **kwargs,
+ ):
+ """
+ Initialize the dataset attributes.
+ Args:
+ ROOT: Root directory of the dataset.
+ dataset_metadata_dir: Path to the dataset metadata directory.
+ overfit_num_sets: If None, use all sets. Else, the dataset will be truncated to this number of sets.
+ sample_specific_scene: Whether to sample a specific scene from the dataset.
+ specific_scene_name: Name of the specific scene to sample.
+ """
+ # Initialize the dataset attributes
+ super().__init__(*args, **kwargs)
+ self.ROOT = ROOT
+ self.dataset_metadata_dir = dataset_metadata_dir
+ self.split = "test"
+ self.overfit_num_sets = overfit_num_sets
+ self.sample_specific_scene = sample_specific_scene
+ self.specific_scene_name = specific_scene_name
+ self._load_data()
+
+ # Define the dataset type flags
+ self.is_metric_scale = False
+ self.is_synthetic = False
+
+ def _load_data(self):
+ "Load the precomputed dataset metadata"
+ # Load the dataset metadata corresponding to the split
+ split_metadata_path = os.path.join(
+ self.dataset_metadata_dir,
+ self.split,
+ f"dtu_scene_list_{self.split}.npy",
+ )
+ split_scene_list = np.load(split_metadata_path, allow_pickle=True)
+
+ # Get the list of all scenes
+ if not self.sample_specific_scene:
+ self.scenes = list(split_scene_list)
+ else:
+ self.scenes = [self.specific_scene_name]
+ self.num_of_scenes = len(self.scenes)
+
+ def _get_views(self, sampled_idx, num_views_to_sample, resolution):
+ # Get the scene name of the sampled index
+ scene_index = sampled_idx
+ scene_name = self.scenes[scene_index]
+
+ # Get the metadata corresponding to the scene
+ scene_root = os.path.join(self.ROOT, scene_name)
+ scene_meta = load_data(
+ os.path.join(scene_root, "scene_meta.json"), "scene_meta"
+ )
+ scene_file_names = list(scene_meta["frame_names"].keys())
+ num_views_in_scene = len(scene_file_names)
+
+ # Load the scene pairwise covisibility mmap
+ covisibility_version_key = "v0"
+ covisibility_map_dir = os.path.join(
+ scene_root, "covisibility", covisibility_version_key
+ )
+ # Assumes only npy file in directory is covisbility map
+ covisibility_map_name = next(
+ f for f in os.listdir(covisibility_map_dir) if f.endswith(".npy")
+ )
+ covisibility_map_path = os.path.join(
+ scene_root, "covisibility", covisibility_version_key, covisibility_map_name
+ )
+ pairwise_covisibility = load_data(covisibility_map_path, "mmap")
+
+ # Get the indices of the N views in the scene
+ view_indices = self._sample_view_indices(
+ num_views_to_sample, num_views_in_scene, pairwise_covisibility
+ )
+
+ # Get the views corresponding to the selected view indices
+ views = []
+ for view_index in view_indices:
+ # Load the data corresponding to the view
+ view_file_name = scene_file_names[view_index]
+ view_data = load_frame(
+ scene_root,
+ view_file_name,
+ modalities=["image", "depth"],
+ scene_meta=scene_meta,
+ )
+
+ # Convert necessary data to numpy
+ image = view_data["image"].permute(1, 2, 0).numpy()
+ image = (image * 255).astype(np.uint8)
+ depthmap = view_data["depth"].numpy().astype(np.float32)
+ intrinsics = view_data["intrinsics"].numpy().astype(np.float32)
+ c2w_pose = view_data["extrinsics"].numpy().astype(np.float32)
+
+ # Resize the data to match the desired resolution
+ image, depthmap, intrinsics = self._crop_resize_if_necessary(
+ image=image,
+ resolution=resolution,
+ depthmap=depthmap,
+ intrinsics=intrinsics,
+ additional_quantities=None,
+ )
+
+ # Append the view dictionary to the list of views
+ views.append(
+ dict(
+ img=image,
+ depthmap=depthmap,
+ camera_pose=c2w_pose, # cam2world
+ camera_intrinsics=intrinsics,
+ dataset="DTU",
+ label=scene_name,
+ instance=os.path.join("images", str(view_file_name)),
+ )
+ )
+
+ return views
+
+
+def get_parser():
+ import argparse
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument("-rd", "--root_dir", default="/fsx/xrtech/data/dtu", type=str)
+ parser.add_argument(
+ "-dmd",
+ "--dataset_metadata_dir",
+ default="/fsx/nkeetha/mapanything_dataset_metadata",
+ type=str,
+ )
+ parser.add_argument(
+ "-nv",
+ "--num_of_views",
+ default=2,
+ type=int,
+ )
+ parser.add_argument("--viz", action="store_true")
+
+ return parser
+
+
+if __name__ == "__main__":
+ import rerun as rr
+ from tqdm import tqdm
+
+ from mapanything.datasets.base.base_dataset import view_name
+ from mapanything.utils.image import rgb
+ from mapanything.utils.viz import script_add_rerun_args
+
+ parser = get_parser()
+ script_add_rerun_args(
+ parser
+ ) # Options: --headless, --connect, --serve, --addr, --save, --stdout
+ args = parser.parse_args()
+
+ dataset = DTUWAI(
+ num_views=args.num_of_views,
+ covisibility_thres=0.25,
+ ROOT=args.root_dir,
+ dataset_metadata_dir=args.dataset_metadata_dir,
+ resolution=(518, 392),
+ seed=777,
+ transform="imgnorm",
+ data_norm_type="dinov2",
+ )
+ print(dataset.get_stats())
+
+ if args.viz:
+ rr.script_setup(args, "DTU_Dataloader")
+ rr.set_time("stable_time", sequence=0)
+ rr.log("world", rr.ViewCoordinates.RDF, static=True)
+
+ sampled_indices = np.random.choice(len(dataset), size=len(dataset), replace=False)
+
+ for num, idx in enumerate(tqdm(sampled_indices)):
+ views = dataset[idx]
+ assert len(views) == args.num_of_views
+ sample_name = f"{idx}"
+ for view_idx in range(args.num_of_views):
+ sample_name += f" {view_name(views[view_idx])}"
+ print(sample_name)
+ for view_idx in range(args.num_of_views):
+ image = rgb(
+ views[view_idx]["img"], norm_type=views[view_idx]["data_norm_type"]
+ )
+ depthmap = views[view_idx]["depthmap"]
+ pose = views[view_idx]["camera_pose"]
+ intrinsics = views[view_idx]["camera_intrinsics"]
+ pts3d = views[view_idx]["pts3d"]
+ valid_mask = views[view_idx]["valid_mask"]
+ if "non_ambiguous_mask" in views[view_idx]:
+ non_ambiguous_mask = views[view_idx]["non_ambiguous_mask"]
+ else:
+ non_ambiguous_mask = None
+ if "prior_depth_along_ray" in views[view_idx]:
+ prior_depth_along_ray = views[view_idx]["prior_depth_along_ray"]
+ else:
+ prior_depth_along_ray = None
+ if args.viz:
+ rr.set_time("stable_time", sequence=num)
+ base_name = f"world/view_{view_idx}"
+ pts_name = f"world/view_{view_idx}_pointcloud"
+ # Log camera info and loaded data
+ height, width = image.shape[0], image.shape[1]
+ rr.log(
+ base_name,
+ rr.Transform3D(
+ translation=pose[:3, 3],
+ mat3x3=pose[:3, :3],
+ ),
+ )
+ rr.log(
+ f"{base_name}/pinhole",
+ rr.Pinhole(
+ image_from_camera=intrinsics,
+ height=height,
+ width=width,
+ camera_xyz=rr.ViewCoordinates.RDF,
+ ),
+ )
+ rr.log(
+ f"{base_name}/pinhole/rgb",
+ rr.Image(image),
+ )
+ rr.log(
+ f"{base_name}/pinhole/depth",
+ rr.DepthImage(depthmap),
+ )
+ if prior_depth_along_ray is not None:
+ rr.log(
+ f"prior_depth_along_ray_{view_idx}",
+ rr.DepthImage(prior_depth_along_ray),
+ )
+ if non_ambiguous_mask is not None:
+ rr.log(
+ f"{base_name}/pinhole/non_ambiguous_mask",
+ rr.SegmentationImage(non_ambiguous_mask.astype(int)),
+ )
+ # Log points in 3D
+ filtered_pts = pts3d[valid_mask]
+ filtered_pts_col = image[valid_mask]
+ rr.log(
+ pts_name,
+ rr.Points3D(
+ positions=filtered_pts.reshape(-1, 3),
+ colors=filtered_pts_col.reshape(-1, 3),
+ ),
+ )
diff --git a/mapanything/datasets/wai/dynamicreplica.py b/mapanything/datasets/wai/dynamicreplica.py
new file mode 100644
index 0000000000000000000000000000000000000000..b0e2c2f3001a448f4627522549ea2d1433c0d542
--- /dev/null
+++ b/mapanything/datasets/wai/dynamicreplica.py
@@ -0,0 +1,292 @@
+"""
+Dynamic Replica Dataset using WAI format data.
+"""
+
+import os
+
+import numpy as np
+
+from mapanything.datasets.base.base_dataset import BaseDataset
+from wai import load_data, load_frame
+
+
+class DynamicReplicaWAI(BaseDataset):
+ """
+ Dynamic Replica dataset containing synthetic scenes with humans and animals.
+ """
+
+ def __init__(
+ self,
+ *args,
+ ROOT,
+ dataset_metadata_dir,
+ split,
+ overfit_num_sets=None,
+ sample_specific_scene: bool = False,
+ specific_scene_name: str = None,
+ **kwargs,
+ ):
+ """
+ Initialize the dataset attributes.
+ Args:
+ ROOT: Root directory of the dataset.
+ dataset_metadata_dir: Path to the dataset metadata directory.
+ split: Dataset split (train, val, test).
+ overfit_num_sets: If None, use all sets. Else, the dataset will be truncated to this number of sets.
+ sample_specific_scene: Whether to sample a specific scene from the dataset.
+ specific_scene_name: Name of the specific scene to sample.
+ """
+ # Initialize the dataset attributes
+ super().__init__(*args, **kwargs)
+ self.ROOT = ROOT
+ self.dataset_metadata_dir = dataset_metadata_dir
+ self.split = split
+ self.overfit_num_sets = overfit_num_sets
+ self.sample_specific_scene = sample_specific_scene
+ self.specific_scene_name = specific_scene_name
+ self._load_data()
+
+ # Define the dataset type flags
+ self.is_metric_scale = True
+ self.is_synthetic = True
+
+ def _load_data(self):
+ "Load the precomputed dataset metadata"
+ # Load the dataset metadata corresponding to the split
+ split_metadata_path = os.path.join(
+ self.dataset_metadata_dir,
+ self.split,
+ f"dynamicreplica_scene_list_{self.split}.npy",
+ )
+ split_scene_list = np.load(split_metadata_path, allow_pickle=True)
+
+ # Get the list of all scenes
+ if not self.sample_specific_scene:
+ self.scenes = list(split_scene_list)
+ else:
+ self.scenes = [self.specific_scene_name]
+ self.num_of_scenes = len(self.scenes)
+
+ def _get_views(self, sampled_idx, num_views_to_sample, resolution):
+ # Get the scene name of the sampled index
+ scene_index = sampled_idx
+ scene_name = self.scenes[scene_index]
+
+ # Get the metadata corresponding to the scene
+ scene_root = os.path.join(self.ROOT, scene_name)
+ scene_meta = load_data(
+ os.path.join(scene_root, "scene_meta.json"), "scene_meta"
+ )
+ scene_file_names = list(scene_meta["frame_names"].keys())
+ num_views_in_scene = len(scene_file_names)
+
+ # Load the scene pairwise covisibility mmap
+ covisibility_version_key = "v0"
+ covisibility_map_dir = os.path.join(
+ scene_root, "covisibility", covisibility_version_key
+ )
+ # Assumes only npy file in directory is covisbility map
+ covisibility_map_name = next(
+ f for f in os.listdir(covisibility_map_dir) if f.endswith(".npy")
+ )
+ covisibility_map_path = os.path.join(
+ scene_root, "covisibility", covisibility_version_key, covisibility_map_name
+ )
+ pairwise_covisibility = load_data(covisibility_map_path, "mmap")
+
+ # Get the indices of the N views in the scene
+ view_indices = self._sample_view_indices(
+ num_views_to_sample, num_views_in_scene, pairwise_covisibility
+ )
+
+ # Get the views corresponding to the selected view indices
+ views = []
+ for view_index in view_indices:
+ # Load the data corresponding to the view
+ view_file_name = scene_file_names[view_index]
+ view_data = load_frame(
+ scene_root,
+ view_file_name,
+ modalities=["image", "depth"],
+ scene_meta=scene_meta,
+ )
+
+ # Convert necessary data to numpy
+ image = view_data["image"].permute(1, 2, 0).numpy()
+ image = image[:, :, :3] # RGBA to RGB
+ image = (image * 255).astype(np.uint8)
+ depthmap = view_data["depth"].numpy().astype(np.float32)
+ intrinsics = view_data["intrinsics"].numpy().astype(np.float32)
+ c2w_pose = view_data["extrinsics"].numpy().astype(np.float32)
+
+ # Ensure that the depthmap has all valid values
+ depthmap = np.nan_to_num(depthmap, nan=0.0, posinf=0.0, neginf=0.0)
+
+ # Resize the data to match the desired resolution
+ image, depthmap, intrinsics = self._crop_resize_if_necessary(
+ image=image,
+ resolution=resolution,
+ depthmap=depthmap,
+ intrinsics=intrinsics,
+ additional_quantities=None,
+ )
+
+ # Append the view dictionary to the list of views
+ views.append(
+ dict(
+ img=image,
+ depthmap=depthmap,
+ camera_pose=c2w_pose, # cam2world
+ camera_intrinsics=intrinsics,
+ dataset="DynamicReplica",
+ label=scene_name,
+ instance=os.path.join("images", str(view_file_name)),
+ )
+ )
+
+ return views
+
+
+def get_parser():
+ import argparse
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "-rd", "--root_dir", default="/fsx/xrtech/data/dynamicreplica", type=str
+ )
+ parser.add_argument(
+ "-dmd",
+ "--dataset_metadata_dir",
+ default="/fsx/nkeetha/mapanything_dataset_metadata",
+ type=str,
+ )
+ parser.add_argument(
+ "-nv",
+ "--num_of_views",
+ default=2,
+ type=int,
+ )
+ parser.add_argument("--viz", action="store_true")
+
+ return parser
+
+
+if __name__ == "__main__":
+ import rerun as rr
+ from tqdm import tqdm
+
+ from mapanything.datasets.base.base_dataset import view_name
+ from mapanything.utils.image import rgb
+ from mapanything.utils.viz import script_add_rerun_args
+
+ parser = get_parser()
+ script_add_rerun_args(
+ parser
+ ) # Options: --headless, --connect, --serve, --addr, --save, --stdout
+ args = parser.parse_args()
+
+ dataset = DynamicReplicaWAI(
+ num_views=args.num_of_views,
+ split="train",
+ covisibility_thres=0.25,
+ ROOT=args.root_dir,
+ dataset_metadata_dir=args.dataset_metadata_dir,
+ resolution=(518, 294),
+ aug_crop=16,
+ transform="colorjitter+grayscale+gaublur",
+ data_norm_type="dinov2",
+ )
+ # dataset = DynamicReplicaWAI(
+ # num_views=args.num_of_views,
+ # split="val",
+ # covisibility_thres=0.25,
+ # ROOT=args.root_dir,
+ # dataset_metadata_dir=args.dataset_metadata_dir,
+ # resolution=(518, 294),
+ # seed=777,
+ # transform="imgnorm",
+ # data_norm_type="dinov2",
+ # )
+ print(dataset.get_stats())
+
+ if args.viz:
+ rr.script_setup(args, "DynamicReplica_Dataloader")
+ rr.set_time("stable_time", sequence=0)
+ rr.log("world", rr.ViewCoordinates.RDF, static=True)
+
+ sampled_indices = np.random.choice(len(dataset), size=len(dataset), replace=False)
+
+ for num, idx in enumerate(tqdm(sampled_indices)):
+ views = dataset[idx]
+ assert len(views) == args.num_of_views
+ sample_name = f"{idx}"
+ for view_idx in range(args.num_of_views):
+ sample_name += f" {view_name(views[view_idx])}"
+ print(sample_name)
+ for view_idx in range(args.num_of_views):
+ image = rgb(
+ views[view_idx]["img"], norm_type=views[view_idx]["data_norm_type"]
+ )
+ depthmap = views[view_idx]["depthmap"]
+ pose = views[view_idx]["camera_pose"]
+ intrinsics = views[view_idx]["camera_intrinsics"]
+ pts3d = views[view_idx]["pts3d"]
+ valid_mask = views[view_idx]["valid_mask"]
+ if "non_ambiguous_mask" in views[view_idx]:
+ non_ambiguous_mask = views[view_idx]["non_ambiguous_mask"]
+ else:
+ non_ambiguous_mask = None
+ if "prior_depth_along_ray" in views[view_idx]:
+ prior_depth_along_ray = views[view_idx]["prior_depth_along_ray"]
+ else:
+ prior_depth_along_ray = None
+ if args.viz:
+ rr.set_time("stable_time", sequence=num)
+ base_name = f"world/view_{view_idx}"
+ pts_name = f"world/view_{view_idx}_pointcloud"
+ # Log camera info and loaded data
+ height, width = image.shape[0], image.shape[1]
+ rr.log(
+ base_name,
+ rr.Transform3D(
+ translation=pose[:3, 3],
+ mat3x3=pose[:3, :3],
+ ),
+ )
+ rr.log(
+ f"{base_name}/pinhole",
+ rr.Pinhole(
+ image_from_camera=intrinsics,
+ height=height,
+ width=width,
+ camera_xyz=rr.ViewCoordinates.RDF,
+ ),
+ )
+ rr.log(
+ f"{base_name}/pinhole/rgb",
+ rr.Image(image),
+ )
+ rr.log(
+ f"{base_name}/pinhole/depth",
+ rr.DepthImage(depthmap),
+ )
+ if prior_depth_along_ray is not None:
+ rr.log(
+ f"prior_depth_along_ray_{view_idx}",
+ rr.DepthImage(prior_depth_along_ray),
+ )
+ if non_ambiguous_mask is not None:
+ rr.log(
+ f"{base_name}/pinhole/non_ambiguous_mask",
+ rr.SegmentationImage(non_ambiguous_mask.astype(int)),
+ )
+ # Log points in 3D
+ filtered_pts = pts3d[valid_mask]
+ filtered_pts_col = image[valid_mask]
+ rr.log(
+ pts_name,
+ rr.Points3D(
+ positions=filtered_pts.reshape(-1, 3),
+ colors=filtered_pts_col.reshape(-1, 3),
+ ),
+ )
diff --git a/mapanything/datasets/wai/eth3d.py b/mapanything/datasets/wai/eth3d.py
new file mode 100644
index 0000000000000000000000000000000000000000..efc1a184fcd0d5b8ae3621bd6fe3373a97e83710
--- /dev/null
+++ b/mapanything/datasets/wai/eth3d.py
@@ -0,0 +1,272 @@
+"""
+ETH3D Dataset using WAI format data.
+"""
+
+import os
+
+import numpy as np
+
+from mapanything.datasets.base.base_dataset import BaseDataset
+from wai import load_data, load_frame
+
+
+class ETH3DWAI(BaseDataset):
+ """
+ ETH3D dataset containing high-quality outdoor and indoor scans of the ETH Zurich campus.
+ """
+
+ def __init__(
+ self,
+ *args,
+ ROOT,
+ dataset_metadata_dir,
+ overfit_num_sets=None,
+ sample_specific_scene: bool = False,
+ specific_scene_name: str = None,
+ **kwargs,
+ ):
+ """
+ Initialize the dataset attributes.
+ Args:
+ ROOT: Root directory of the dataset.
+ dataset_metadata_dir: Path to the dataset metadata directory.
+ overfit_num_sets: If None, use all sets. Else, the dataset will be truncated to this number of sets.
+ sample_specific_scene: Whether to sample a specific scene from the dataset.
+ specific_scene_name: Name of the specific scene to sample.
+ """
+ # Initialize the dataset attributes
+ super().__init__(*args, **kwargs)
+ self.ROOT = ROOT
+ self.dataset_metadata_dir = dataset_metadata_dir
+ self.split = "test"
+ self.overfit_num_sets = overfit_num_sets
+ self.sample_specific_scene = sample_specific_scene
+ self.specific_scene_name = specific_scene_name
+ self._load_data()
+
+ # Define the dataset type flags
+ self.is_metric_scale = True
+ self.is_synthetic = False
+
+ def _load_data(self):
+ "Load the precomputed dataset metadata"
+ # Load the dataset metadata corresponding to the split
+ split_metadata_path = os.path.join(
+ self.dataset_metadata_dir,
+ self.split,
+ f"eth3d_scene_list_{self.split}.npy",
+ )
+ split_scene_list = np.load(split_metadata_path, allow_pickle=True)
+
+ # Get the list of all scenes
+ if not self.sample_specific_scene:
+ self.scenes = list(split_scene_list)
+ else:
+ self.scenes = [self.specific_scene_name]
+ self.num_of_scenes = len(self.scenes)
+
+ def _get_views(self, sampled_idx, num_views_to_sample, resolution):
+ # Get the scene name of the sampled index
+ scene_index = sampled_idx
+ scene_name = self.scenes[scene_index]
+
+ # Get the metadata corresponding to the scene
+ scene_root = os.path.join(self.ROOT, scene_name)
+ scene_meta = load_data(
+ os.path.join(scene_root, "scene_meta.json"), "scene_meta"
+ )
+ scene_file_names = list(scene_meta["frame_names"].keys())
+ num_views_in_scene = len(scene_file_names)
+
+ # Load the scene pairwise covisibility mmap
+ covisibility_version_key = "v0"
+ covisibility_map_dir = os.path.join(
+ scene_root, "covisibility", covisibility_version_key
+ )
+ # Assumes only npy file in directory is covisbility map
+ covisibility_map_name = next(
+ f for f in os.listdir(covisibility_map_dir) if f.endswith(".npy")
+ )
+ covisibility_map_path = os.path.join(
+ scene_root, "covisibility", covisibility_version_key, covisibility_map_name
+ )
+ pairwise_covisibility = load_data(covisibility_map_path, "mmap")
+
+ # Get the indices of the N views in the scene
+ view_indices = self._sample_view_indices(
+ num_views_to_sample, num_views_in_scene, pairwise_covisibility
+ )
+
+ # Get the views corresponding to the selected view indices
+ views = []
+ for view_index in view_indices:
+ # Load the data corresponding to the view
+ view_file_name = scene_file_names[view_index]
+ view_data = load_frame(
+ scene_root,
+ view_file_name,
+ modalities=["image", "depth"],
+ scene_meta=scene_meta,
+ )
+
+ # Convert necessary data to numpy
+ image = view_data["image"].permute(1, 2, 0).numpy()
+ image = (image * 255).astype(np.uint8)
+ depthmap = view_data["depth"].numpy().astype(np.float32)
+ intrinsics = view_data["intrinsics"].numpy().astype(np.float32)
+ c2w_pose = view_data["extrinsics"].numpy().astype(np.float32)
+
+ # Resize the data to match the desired resolution
+ image, depthmap, intrinsics = self._crop_resize_if_necessary(
+ image=image,
+ resolution=resolution,
+ depthmap=depthmap,
+ intrinsics=intrinsics,
+ additional_quantities=None,
+ )
+
+ # Append the view dictionary to the list of views
+ views.append(
+ dict(
+ img=image,
+ depthmap=depthmap,
+ camera_pose=c2w_pose, # cam2world
+ camera_intrinsics=intrinsics,
+ dataset="ETH3D",
+ label=scene_name,
+ instance=os.path.join("images", str(view_file_name)),
+ )
+ )
+
+ return views
+
+
+def get_parser():
+ import argparse
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument("-rd", "--root_dir", default="/fsx/xrtech/data/eth3d", type=str)
+ parser.add_argument(
+ "-dmd",
+ "--dataset_metadata_dir",
+ default="/fsx/nkeetha/mapanything_dataset_metadata",
+ type=str,
+ )
+ parser.add_argument(
+ "-nv",
+ "--num_of_views",
+ default=2,
+ type=int,
+ )
+ parser.add_argument("--viz", action="store_true")
+
+ return parser
+
+
+if __name__ == "__main__":
+ import rerun as rr
+ from tqdm import tqdm
+
+ from mapanything.datasets.base.base_dataset import view_name
+ from mapanything.utils.image import rgb
+ from mapanything.utils.viz import script_add_rerun_args
+
+ parser = get_parser()
+ script_add_rerun_args(
+ parser
+ ) # Options: --headless, --connect, --serve, --addr, --save, --stdout
+ args = parser.parse_args()
+
+ dataset = ETH3DWAI(
+ num_views=args.num_of_views,
+ covisibility_thres=0.025,
+ ROOT=args.root_dir,
+ dataset_metadata_dir=args.dataset_metadata_dir,
+ resolution=(518, 336),
+ seed=777,
+ transform="imgnorm",
+ data_norm_type="dinov2",
+ )
+ print(dataset.get_stats())
+
+ if args.viz:
+ rr.script_setup(args, "ETH3D_Dataloader")
+ rr.set_time("stable_time", sequence=0)
+ rr.log("world", rr.ViewCoordinates.RDF, static=True)
+
+ sampled_indices = np.random.choice(len(dataset), size=len(dataset), replace=False)
+
+ for num, idx in enumerate(tqdm(sampled_indices)):
+ views = dataset[idx]
+ assert len(views) == args.num_of_views
+ sample_name = f"{idx}"
+ for view_idx in range(args.num_of_views):
+ sample_name += f" {view_name(views[view_idx])}"
+ print(sample_name)
+ for view_idx in range(args.num_of_views):
+ image = rgb(
+ views[view_idx]["img"], norm_type=views[view_idx]["data_norm_type"]
+ )
+ depthmap = views[view_idx]["depthmap"]
+ pose = views[view_idx]["camera_pose"]
+ intrinsics = views[view_idx]["camera_intrinsics"]
+ pts3d = views[view_idx]["pts3d"]
+ valid_mask = views[view_idx]["valid_mask"]
+ if "non_ambiguous_mask" in views[view_idx]:
+ non_ambiguous_mask = views[view_idx]["non_ambiguous_mask"]
+ else:
+ non_ambiguous_mask = None
+ if "prior_depth_along_ray" in views[view_idx]:
+ prior_depth_along_ray = views[view_idx]["prior_depth_along_ray"]
+ else:
+ prior_depth_along_ray = None
+ if args.viz:
+ rr.set_time("stable_time", sequence=num)
+ base_name = f"world/view_{view_idx}"
+ pts_name = f"world/view_{view_idx}_pointcloud"
+ # Log camera info and loaded data
+ height, width = image.shape[0], image.shape[1]
+ rr.log(
+ base_name,
+ rr.Transform3D(
+ translation=pose[:3, 3],
+ mat3x3=pose[:3, :3],
+ ),
+ )
+ rr.log(
+ f"{base_name}/pinhole",
+ rr.Pinhole(
+ image_from_camera=intrinsics,
+ height=height,
+ width=width,
+ camera_xyz=rr.ViewCoordinates.RDF,
+ ),
+ )
+ rr.log(
+ f"{base_name}/pinhole/rgb",
+ rr.Image(image),
+ )
+ rr.log(
+ f"{base_name}/pinhole/depth",
+ rr.DepthImage(depthmap),
+ )
+ if prior_depth_along_ray is not None:
+ rr.log(
+ f"prior_depth_along_ray_{view_idx}",
+ rr.DepthImage(prior_depth_along_ray),
+ )
+ if non_ambiguous_mask is not None:
+ rr.log(
+ f"{base_name}/pinhole/non_ambiguous_mask",
+ rr.SegmentationImage(non_ambiguous_mask.astype(int)),
+ )
+ # Log points in 3D
+ filtered_pts = pts3d[valid_mask]
+ filtered_pts_col = image[valid_mask]
+ rr.log(
+ pts_name,
+ rr.Points3D(
+ positions=filtered_pts.reshape(-1, 3),
+ colors=filtered_pts_col.reshape(-1, 3),
+ ),
+ )
diff --git a/mapanything/datasets/wai/gta_sfm.py b/mapanything/datasets/wai/gta_sfm.py
new file mode 100644
index 0000000000000000000000000000000000000000..a07de5b76a864c4cb4e42d6e2b14a3e0bbb67044
--- /dev/null
+++ b/mapanything/datasets/wai/gta_sfm.py
@@ -0,0 +1,303 @@
+"""
+GTA SfM Dataset using WAI format data.
+"""
+
+import os
+
+import numpy as np
+
+from mapanything.datasets.base.base_dataset import BaseDataset
+from wai import load_data, load_frame
+
+
+class GTASfMWAI(BaseDataset):
+ """
+ GTA SfM dataset containing large diversity of synthetic in-the-wild scenes.
+ """
+
+ def __init__(
+ self,
+ *args,
+ ROOT,
+ dataset_metadata_dir,
+ split,
+ overfit_num_sets=None,
+ sample_specific_scene: bool = False,
+ specific_scene_name: str = None,
+ **kwargs,
+ ):
+ """
+ Initialize the dataset attributes.
+ Args:
+ ROOT: Root directory of the dataset.
+ dataset_metadata_dir: Path to the dataset metadata directory.
+ split: Dataset split (train, val, test).
+ overfit_num_sets: If None, use all sets. Else, the dataset will be truncated to this number of sets.
+ sample_specific_scene: Whether to sample a specific scene from the dataset.
+ specific_scene_name: Name of the specific scene to sample.
+ """
+ # Initialize the dataset attributes
+ super().__init__(*args, **kwargs)
+ self.ROOT = ROOT
+ self.dataset_metadata_dir = dataset_metadata_dir
+ self.split = split
+ self.overfit_num_sets = overfit_num_sets
+ self.sample_specific_scene = sample_specific_scene
+ self.specific_scene_name = specific_scene_name
+ self._load_data()
+
+ # Define the dataset type flags
+ self.is_metric_scale = True
+ self.is_synthetic = True
+
+ def _load_data(self):
+ "Load the precomputed dataset metadata"
+ # Load the dataset metadata corresponding to the split
+ split_metadata_path = os.path.join(
+ self.dataset_metadata_dir,
+ self.split,
+ f"gta_sfm_scene_list_{self.split}.npy",
+ )
+ split_scene_list = np.load(split_metadata_path, allow_pickle=True)
+
+ # Get the list of all scenes
+ if not self.sample_specific_scene:
+ self.scenes = list(split_scene_list)
+ else:
+ self.scenes = [self.specific_scene_name]
+ self.num_of_scenes = len(self.scenes)
+
+ def _get_views(self, sampled_idx, num_views_to_sample, resolution):
+ # Get the scene name of the sampled index
+ scene_index = sampled_idx
+ scene_name = self.scenes[scene_index]
+
+ # Get the metadata corresponding to the scene
+ scene_root = os.path.join(self.ROOT, scene_name)
+ scene_meta = load_data(
+ os.path.join(scene_root, "scene_meta.json"), "scene_meta"
+ )
+ scene_file_names = list(scene_meta["frame_names"].keys())
+ num_views_in_scene = len(scene_file_names)
+
+ # Load the scene pairwise covisibility mmap
+ covisibility_version_key = "v0"
+ covisibility_map_dir = os.path.join(
+ scene_root, "covisibility", covisibility_version_key
+ )
+ # Assumes only npy file in directory is covisbility map
+ covisibility_map_name = next(
+ f for f in os.listdir(covisibility_map_dir) if f.endswith(".npy")
+ )
+ covisibility_map_path = os.path.join(
+ scene_root, "covisibility", covisibility_version_key, covisibility_map_name
+ )
+ pairwise_covisibility = load_data(covisibility_map_path, "mmap")
+
+ # Get the indices of the N views in the scene
+ view_indices = self._sample_view_indices(
+ num_views_to_sample, num_views_in_scene, pairwise_covisibility
+ )
+
+ # Get the views corresponding to the selected view indices
+ views = []
+ for view_index in view_indices:
+ # Load the data corresponding to the view
+ view_file_name = scene_file_names[view_index]
+ view_data = load_frame(
+ scene_root,
+ view_file_name,
+ modalities=["image", "depth"],
+ scene_meta=scene_meta,
+ )
+
+ # Convert necessary data to numpy
+ image = view_data["image"].permute(1, 2, 0).numpy()
+ image = (image * 255).astype(np.uint8)
+ depthmap = view_data["depth"].numpy().astype(np.float32)
+ intrinsics = view_data["intrinsics"].numpy().astype(np.float32)
+ c2w_pose = view_data["extrinsics"].numpy().astype(np.float32)
+
+ # Ensure that the depthmap has all valid values
+ depthmap = np.nan_to_num(depthmap, nan=0.0, posinf=0.0, neginf=0.0)
+
+ # Get the non ambiguous mask (zero depth pixels are sky or ambiguous)
+ non_ambiguous_mask = (depthmap > 0).astype(int)
+
+ # Mask out the outlier depth (horizon depth)
+ percentile_depth = np.percentile(depthmap, 95)
+ depthmap[depthmap > percentile_depth] = 0
+
+ # Resize the data to match the desired resolution
+ additional_quantities_to_resize = [non_ambiguous_mask]
+ image, depthmap, intrinsics, additional_quantities_to_resize = (
+ self._crop_resize_if_necessary(
+ image=image,
+ resolution=resolution,
+ depthmap=depthmap,
+ intrinsics=intrinsics,
+ additional_quantities=additional_quantities_to_resize,
+ )
+ )
+ non_ambiguous_mask = additional_quantities_to_resize[0]
+
+ # Append the view dictionary to the list of views
+ views.append(
+ dict(
+ img=image,
+ depthmap=depthmap,
+ camera_pose=c2w_pose, # cam2world
+ camera_intrinsics=intrinsics,
+ non_ambiguous_mask=non_ambiguous_mask,
+ dataset="GTASfM",
+ label=scene_name,
+ instance=os.path.join("images", str(view_file_name)),
+ )
+ )
+
+ return views
+
+
+def get_parser():
+ import argparse
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "-rd", "--root_dir", default="/fsx/xrtech/data/gta_sfm", type=str
+ )
+ parser.add_argument(
+ "-dmd",
+ "--dataset_metadata_dir",
+ default="/fsx/nkeetha/mapanything_dataset_metadata",
+ type=str,
+ )
+ parser.add_argument(
+ "-nv",
+ "--num_of_views",
+ default=2,
+ type=int,
+ )
+ parser.add_argument("--viz", action="store_true")
+
+ return parser
+
+
+if __name__ == "__main__":
+ import rerun as rr
+ from tqdm import tqdm
+
+ from mapanything.datasets.base.base_dataset import view_name
+ from mapanything.utils.image import rgb
+ from mapanything.utils.viz import script_add_rerun_args
+
+ parser = get_parser()
+ script_add_rerun_args(
+ parser
+ ) # Options: --headless, --connect, --serve, --addr, --save, --stdout
+ args = parser.parse_args()
+
+ dataset = GTASfMWAI(
+ num_views=args.num_of_views,
+ split="train",
+ covisibility_thres=0.25,
+ ROOT=args.root_dir,
+ dataset_metadata_dir=args.dataset_metadata_dir,
+ resolution=(518, 392),
+ aug_crop=16,
+ transform="colorjitter+grayscale+gaublur",
+ data_norm_type="dinov2",
+ )
+ # dataset = GTASfMWAI(
+ # num_views=args.num_of_views,
+ # split="val",
+ # covisibility_thres=0.25,
+ # ROOT=args.root_dir,
+ # dataset_metadata_dir=args.dataset_metadata_dir,
+ # resolution=(518, 392),
+ # seed=777,
+ # transform="imgnorm",
+ # data_norm_type="dinov2",
+ # )
+ print(dataset.get_stats())
+
+ if args.viz:
+ rr.script_setup(args, "GTASfM_Dataloader")
+ rr.set_time("stable_time", sequence=0)
+ rr.log("world", rr.ViewCoordinates.RDF, static=True)
+
+ sampled_indices = np.random.choice(len(dataset), size=len(dataset), replace=False)
+
+ for num, idx in enumerate(tqdm(sampled_indices)):
+ views = dataset[idx]
+ assert len(views) == args.num_of_views
+ sample_name = f"{idx}"
+ for view_idx in range(args.num_of_views):
+ sample_name += f" {view_name(views[view_idx])}"
+ print(sample_name)
+ for view_idx in range(args.num_of_views):
+ image = rgb(
+ views[view_idx]["img"], norm_type=views[view_idx]["data_norm_type"]
+ )
+ depthmap = views[view_idx]["depthmap"]
+ pose = views[view_idx]["camera_pose"]
+ intrinsics = views[view_idx]["camera_intrinsics"]
+ pts3d = views[view_idx]["pts3d"]
+ valid_mask = views[view_idx]["valid_mask"]
+ if "non_ambiguous_mask" in views[view_idx]:
+ non_ambiguous_mask = views[view_idx]["non_ambiguous_mask"]
+ else:
+ non_ambiguous_mask = None
+ if "prior_depth_along_ray" in views[view_idx]:
+ prior_depth_along_ray = views[view_idx]["prior_depth_along_ray"]
+ else:
+ prior_depth_along_ray = None
+ if args.viz:
+ rr.set_time("stable_time", sequence=num)
+ base_name = f"world/view_{view_idx}"
+ pts_name = f"world/view_{view_idx}_pointcloud"
+ # Log camera info and loaded data
+ height, width = image.shape[0], image.shape[1]
+ rr.log(
+ base_name,
+ rr.Transform3D(
+ translation=pose[:3, 3],
+ mat3x3=pose[:3, :3],
+ ),
+ )
+ rr.log(
+ f"{base_name}/pinhole",
+ rr.Pinhole(
+ image_from_camera=intrinsics,
+ height=height,
+ width=width,
+ camera_xyz=rr.ViewCoordinates.RDF,
+ ),
+ )
+ rr.log(
+ f"{base_name}/pinhole/rgb",
+ rr.Image(image),
+ )
+ rr.log(
+ f"{base_name}/pinhole/depth",
+ rr.DepthImage(depthmap),
+ )
+ if prior_depth_along_ray is not None:
+ rr.log(
+ f"prior_depth_along_ray_{view_idx}",
+ rr.DepthImage(prior_depth_along_ray),
+ )
+ if non_ambiguous_mask is not None:
+ rr.log(
+ f"{base_name}/pinhole/non_ambiguous_mask",
+ rr.SegmentationImage(non_ambiguous_mask.astype(int)),
+ )
+ # Log points in 3D
+ filtered_pts = pts3d[valid_mask]
+ filtered_pts_col = image[valid_mask]
+ rr.log(
+ pts_name,
+ rr.Points3D(
+ positions=filtered_pts.reshape(-1, 3),
+ colors=filtered_pts_col.reshape(-1, 3),
+ ),
+ )
diff --git a/mapanything/datasets/wai/matrixcity.py b/mapanything/datasets/wai/matrixcity.py
new file mode 100644
index 0000000000000000000000000000000000000000..6849a932423467233c2b60d607eb97ecb65aeab2
--- /dev/null
+++ b/mapanything/datasets/wai/matrixcity.py
@@ -0,0 +1,307 @@
+"""
+Matrix City Dataset using WAI format data.
+"""
+
+import os
+
+import numpy as np
+
+from mapanything.datasets.base.base_dataset import BaseDataset
+from wai import load_data, load_frame
+
+
+class MatrixCityWAI(BaseDataset):
+ """
+ Matrix City dataset containing large scale aerial & street-view urban synthetic scenes.
+ Depth maps are antialiased and there are floaters at all object boundaries due to interpolation.
+ https://github.com/city-super/MatrixCity/issues/4#issuecomment-3027961575
+ Normal based edge masking doesn't fix this issue completely.
+ """
+
+ def __init__(
+ self,
+ *args,
+ ROOT,
+ dataset_metadata_dir,
+ split,
+ overfit_num_sets=None,
+ sample_specific_scene: bool = False,
+ specific_scene_name: str = None,
+ **kwargs,
+ ):
+ """
+ Initialize the dataset attributes.
+ Args:
+ ROOT: Root directory of the dataset.
+ dataset_metadata_dir: Path to the dataset metadata directory.
+ split: Dataset split (train, val, test).
+ overfit_num_sets: If None, use all sets. Else, the dataset will be truncated to this number of sets.
+ sample_specific_scene: Whether to sample a specific scene from the dataset.
+ specific_scene_name: Name of the specific scene to sample.
+ """
+ # Initialize the dataset attributes
+ super().__init__(*args, **kwargs)
+ self.ROOT = ROOT
+ self.dataset_metadata_dir = dataset_metadata_dir
+ self.split = split
+ self.overfit_num_sets = overfit_num_sets
+ self.sample_specific_scene = sample_specific_scene
+ self.specific_scene_name = specific_scene_name
+ self._load_data()
+
+ # Define the dataset type flags
+ self.is_metric_scale = True
+ self.is_synthetic = True
+
+ def _load_data(self):
+ "Load the precomputed dataset metadata"
+ # Load the dataset metadata corresponding to the split
+ split_metadata_path = os.path.join(
+ self.dataset_metadata_dir,
+ self.split,
+ f"matrixcity_scene_list_{self.split}.npy",
+ )
+ split_scene_list = np.load(split_metadata_path, allow_pickle=True)
+
+ # Get the list of all scenes
+ if not self.sample_specific_scene:
+ self.scenes = list(split_scene_list)
+ else:
+ self.scenes = [self.specific_scene_name]
+ self.num_of_scenes = len(self.scenes)
+
+ def _get_views(self, sampled_idx, num_views_to_sample, resolution):
+ # Get the scene name of the sampled index
+ scene_index = sampled_idx
+ scene_name = self.scenes[scene_index]
+
+ # Get the metadata corresponding to the scene
+ scene_root = os.path.join(self.ROOT, scene_name)
+ scene_meta = load_data(
+ os.path.join(scene_root, "scene_meta.json"), "scene_meta"
+ )
+ scene_file_names = list(scene_meta["frame_names"].keys())
+ num_views_in_scene = len(scene_file_names)
+
+ # Load the scene pairwise covisibility mmap
+ covisibility_version_key = "v0"
+ covisibility_map_dir = os.path.join(
+ scene_root, "covisibility", covisibility_version_key
+ )
+ # Assumes only npy file in directory is covisbility map
+ covisibility_map_name = next(
+ f for f in os.listdir(covisibility_map_dir) if f.endswith(".npy")
+ )
+ covisibility_map_path = os.path.join(
+ scene_root, "covisibility", covisibility_version_key, covisibility_map_name
+ )
+ pairwise_covisibility = load_data(covisibility_map_path, "mmap")
+
+ # Get the indices of the N views in the scene
+ view_indices = self._sample_view_indices(
+ num_views_to_sample, num_views_in_scene, pairwise_covisibility
+ )
+
+ # Get the views corresponding to the selected view indices
+ views = []
+ for view_index in view_indices:
+ # Load the data corresponding to the view
+ view_file_name = scene_file_names[view_index]
+ view_data = load_frame(
+ scene_root,
+ view_file_name,
+ modalities=["image", "depth"],
+ scene_meta=scene_meta,
+ )
+
+ # Convert necessary data to numpy
+ image = view_data["image"].permute(1, 2, 0).numpy()
+ image = image[:, :, :3] # RGBA to RGB
+ image = (image * 255).astype(np.uint8)
+ depthmap = view_data["depth"].numpy().astype(np.float32)
+ intrinsics = view_data["intrinsics"].numpy().astype(np.float32)
+ c2w_pose = view_data["extrinsics"].numpy().astype(np.float32)
+
+ # Ensure that the depthmap has all valid values
+ depthmap = np.nan_to_num(depthmap, nan=0.0, posinf=0.0, neginf=0.0)
+
+ # Get the non ambiguous mask (zero depth pixels are sky or ambiguous)
+ non_ambiguous_mask = (depthmap > 0).astype(int)
+
+ # Mask out the outlier depth (horizon depth)
+ percentile_depth = np.percentile(depthmap, 95)
+ depthmap[depthmap > percentile_depth] = 0
+
+ # Resize the data to match the desired resolution
+ additional_quantities_to_resize = [non_ambiguous_mask]
+ image, depthmap, intrinsics, additional_quantities_to_resize = (
+ self._crop_resize_if_necessary(
+ image=image,
+ resolution=resolution,
+ depthmap=depthmap,
+ intrinsics=intrinsics,
+ additional_quantities=additional_quantities_to_resize,
+ )
+ )
+ non_ambiguous_mask = additional_quantities_to_resize[0]
+
+ # Append the view dictionary to the list of views
+ views.append(
+ dict(
+ img=image,
+ depthmap=depthmap,
+ camera_pose=c2w_pose, # cam2world
+ camera_intrinsics=intrinsics,
+ non_ambiguous_mask=non_ambiguous_mask,
+ dataset="MatrixCity",
+ label=scene_name,
+ instance=os.path.join("images", str(view_file_name)),
+ )
+ )
+
+ return views
+
+
+def get_parser():
+ import argparse
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "-rd", "--root_dir", default="/fsx/xrtech/data/matrixcity", type=str
+ )
+ parser.add_argument(
+ "-dmd",
+ "--dataset_metadata_dir",
+ default="/fsx/nkeetha/mapanything_dataset_metadata",
+ type=str,
+ )
+ parser.add_argument(
+ "-nv",
+ "--num_of_views",
+ default=2,
+ type=int,
+ )
+ parser.add_argument("--viz", action="store_true")
+
+ return parser
+
+
+if __name__ == "__main__":
+ import rerun as rr
+ from tqdm import tqdm
+
+ from mapanything.datasets.base.base_dataset import view_name
+ from mapanything.utils.image import rgb
+ from mapanything.utils.viz import script_add_rerun_args
+
+ parser = get_parser()
+ script_add_rerun_args(
+ parser
+ ) # Options: --headless, --connect, --serve, --addr, --save, --stdout
+ args = parser.parse_args()
+
+ dataset = MatrixCityWAI(
+ num_views=args.num_of_views,
+ split="train",
+ covisibility_thres=0.25,
+ ROOT=args.root_dir,
+ dataset_metadata_dir=args.dataset_metadata_dir,
+ resolution=(518, 294),
+ aug_crop=16,
+ transform="colorjitter+grayscale+gaublur",
+ data_norm_type="dinov2",
+ )
+ # dataset = MatrixCityWAI(
+ # num_views=args.num_of_views,
+ # split="val",
+ # covisibility_thres=0.25,
+ # ROOT=args.root_dir,
+ # dataset_metadata_dir=args.dataset_metadata_dir,
+ # resolution=(518, 294),
+ # seed=777,
+ # transform="imgnorm",
+ # data_norm_type="dinov2",
+ # )
+ print(dataset.get_stats())
+
+ if args.viz:
+ rr.script_setup(args, "MatrixCity_Dataloader")
+ rr.set_time("stable_time", sequence=0)
+ rr.log("world", rr.ViewCoordinates.RDF, static=True)
+
+ sampled_indices = np.random.choice(len(dataset), size=len(dataset), replace=False)
+
+ for num, idx in enumerate(tqdm(sampled_indices)):
+ views = dataset[idx]
+ assert len(views) == args.num_of_views
+ sample_name = f"{idx}"
+ for view_idx in range(args.num_of_views):
+ sample_name += f" {view_name(views[view_idx])}"
+ print(sample_name)
+ for view_idx in range(args.num_of_views):
+ image = rgb(
+ views[view_idx]["img"], norm_type=views[view_idx]["data_norm_type"]
+ )
+ depthmap = views[view_idx]["depthmap"]
+ pose = views[view_idx]["camera_pose"]
+ intrinsics = views[view_idx]["camera_intrinsics"]
+ pts3d = views[view_idx]["pts3d"]
+ valid_mask = views[view_idx]["valid_mask"]
+ if "non_ambiguous_mask" in views[view_idx]:
+ non_ambiguous_mask = views[view_idx]["non_ambiguous_mask"]
+ else:
+ non_ambiguous_mask = None
+ if "prior_depth_along_ray" in views[view_idx]:
+ prior_depth_along_ray = views[view_idx]["prior_depth_along_ray"]
+ else:
+ prior_depth_along_ray = None
+ if args.viz:
+ rr.set_time("stable_time", sequence=num)
+ base_name = f"world/view_{view_idx}"
+ pts_name = f"world/view_{view_idx}_pointcloud"
+ # Log camera info and loaded data
+ height, width = image.shape[0], image.shape[1]
+ rr.log(
+ base_name,
+ rr.Transform3D(
+ translation=pose[:3, 3],
+ mat3x3=pose[:3, :3],
+ ),
+ )
+ rr.log(
+ f"{base_name}/pinhole",
+ rr.Pinhole(
+ image_from_camera=intrinsics,
+ height=height,
+ width=width,
+ camera_xyz=rr.ViewCoordinates.RDF,
+ ),
+ )
+ rr.log(
+ f"{base_name}/pinhole/rgb",
+ rr.Image(image),
+ )
+ rr.log(
+ f"{base_name}/pinhole/depth",
+ rr.DepthImage(depthmap),
+ )
+ if prior_depth_along_ray is not None:
+ rr.log(
+ f"prior_depth_along_ray_{view_idx}",
+ rr.DepthImage(prior_depth_along_ray),
+ )
+ if non_ambiguous_mask is not None:
+ rr.log(
+ f"{base_name}/pinhole/non_ambiguous_mask",
+ rr.SegmentationImage(non_ambiguous_mask.astype(int)),
+ )
+ # Log points in 3D
+ filtered_pts = pts3d[valid_mask]
+ filtered_pts_col = image[valid_mask]
+ rr.log(
+ pts_name,
+ rr.Points3D(
+ positions=filtered_pts.reshape(-1, 3),
+ colors=filtered_pts_col.reshape(-1, 3),
+ ),
+ )
diff --git a/mapanything/datasets/wai/megadepth.py b/mapanything/datasets/wai/megadepth.py
new file mode 100644
index 0000000000000000000000000000000000000000..86e1cf41f3018a53b5faf45d4349491bbadc022a
--- /dev/null
+++ b/mapanything/datasets/wai/megadepth.py
@@ -0,0 +1,316 @@
+"""
+MegaDepth Dataset using WAI format data.
+"""
+
+import os
+
+import cv2
+import numpy as np
+
+from mapanything.datasets.base.base_dataset import BaseDataset
+from wai import load_data, load_frame
+
+
+class MegaDepthWAI(BaseDataset):
+ """
+ MegaDepth dataset containing outdoor phototourism and in-the-wild scenes.
+ Also includes Tanks & Temples scenes.
+ """
+
+ def __init__(
+ self,
+ *args,
+ ROOT,
+ dataset_metadata_dir,
+ split,
+ overfit_num_sets=None,
+ sample_specific_scene: bool = False,
+ specific_scene_name: str = None,
+ **kwargs,
+ ):
+ """
+ Initialize the dataset attributes.
+ Args:
+ ROOT: Root directory of the dataset.
+ dataset_metadata_dir: Path to the dataset metadata directory.
+ split: Dataset split (train, val, test).
+ overfit_num_sets: If None, use all sets. Else, the dataset will be truncated to this number of sets.
+ sample_specific_scene: Whether to sample a specific scene from the dataset.
+ specific_scene_name: Name of the specific scene to sample.
+ """
+ # Initialize the dataset attributes
+ super().__init__(*args, **kwargs)
+ self.ROOT = ROOT
+ self.dataset_metadata_dir = dataset_metadata_dir
+ self.split = split
+ self.overfit_num_sets = overfit_num_sets
+ self.sample_specific_scene = sample_specific_scene
+ self.specific_scene_name = specific_scene_name
+ self._load_data()
+
+ # Define the dataset type flags
+ self.is_metric_scale = False
+ self.is_synthetic = False
+
+ def _load_data(self):
+ "Load the precomputed dataset metadata"
+ # Load the dataset metadata corresponding to the split
+ split_metadata_path = os.path.join(
+ self.dataset_metadata_dir,
+ self.split,
+ f"megadepth_scene_list_{self.split}.npy",
+ )
+ split_scene_list = np.load(split_metadata_path, allow_pickle=True)
+
+ # Get the list of all scenes
+ if not self.sample_specific_scene:
+ self.scenes = list(split_scene_list)
+ else:
+ self.scenes = [self.specific_scene_name]
+ self.num_of_scenes = len(self.scenes)
+
+ def _get_views(self, sampled_idx, num_views_to_sample, resolution):
+ # Get the scene name of the sampled index
+ scene_index = sampled_idx
+ scene_name = self.scenes[scene_index]
+
+ # Get the metadata corresponding to the scene
+ scene_root = os.path.join(self.ROOT, scene_name)
+ scene_meta = load_data(
+ os.path.join(scene_root, "scene_meta.json"), "scene_meta"
+ )
+ scene_file_names = list(scene_meta["frame_names"].keys())
+ num_views_in_scene = len(scene_file_names)
+
+ # Load the scene pairwise covisibility mmap
+ covisibility_version_key = "v0"
+ covisibility_map_dir = os.path.join(
+ scene_root, "covisibility", covisibility_version_key
+ )
+ # Assumes only npy file in directory is covisbility map
+ covisibility_map_name = next(
+ f for f in os.listdir(covisibility_map_dir) if f.endswith(".npy")
+ )
+ covisibility_map_path = os.path.join(
+ scene_root, "covisibility", covisibility_version_key, covisibility_map_name
+ )
+ pairwise_covisibility = load_data(covisibility_map_path, "mmap")
+
+ # Get the indices of the N views in the scene
+ view_indices = self._sample_view_indices(
+ num_views_to_sample, num_views_in_scene, pairwise_covisibility
+ )
+
+ # Get the views corresponding to the selected view indices
+ views = []
+ for view_index in view_indices:
+ # Load the data corresponding to the view
+ view_file_name = scene_file_names[view_index]
+ view_data = load_frame(
+ scene_root,
+ view_file_name,
+ modalities=["image", "depth"],
+ # modalities=["image", "depth", "pred_mask/moge2"],
+ scene_meta=scene_meta,
+ )
+ ### HOTFIX: Load required additional masks manually
+ ### Remove once stability issue with scene_meta is fixed
+ mask_path = os.path.join(
+ scene_root, "moge", "v0", "mask", "moge2", f"{view_file_name}.png"
+ )
+ view_data["pred_mask/moge2"] = load_data(mask_path, "binary")
+
+ # Convert necessary data to numpy
+ image = view_data["image"].permute(1, 2, 0).numpy()
+ image = (image * 255).astype(np.uint8)
+ depthmap = view_data["depth"].numpy().astype(np.float32)
+ intrinsics = view_data["intrinsics"].numpy().astype(np.float32)
+ c2w_pose = view_data["extrinsics"].numpy().astype(np.float32)
+
+ # Ensure that the depthmap has all valid values
+ depthmap = np.nan_to_num(depthmap, nan=0.0, posinf=0.0, neginf=0.0)
+
+ # Get the non_ambiguous_mask and ensure it matches image resolution
+ non_ambiguous_mask = view_data["pred_mask/moge2"].numpy().astype(int)
+ non_ambiguous_mask = cv2.resize(
+ non_ambiguous_mask,
+ (image.shape[1], image.shape[0]),
+ interpolation=cv2.INTER_NEAREST,
+ )
+
+ # Mask out the GT depth using the non_ambiguous_mask
+ depthmap = np.where(non_ambiguous_mask, depthmap, 0)
+
+ # Resize the data to match the desired resolution
+ additional_quantities_to_resize = [non_ambiguous_mask]
+ image, depthmap, intrinsics, additional_quantities_to_resize = (
+ self._crop_resize_if_necessary(
+ image=image,
+ resolution=resolution,
+ depthmap=depthmap,
+ intrinsics=intrinsics,
+ additional_quantities=additional_quantities_to_resize,
+ )
+ )
+ non_ambiguous_mask = additional_quantities_to_resize[0]
+
+ # Append the view dictionary to the list of views
+ views.append(
+ dict(
+ img=image,
+ depthmap=depthmap,
+ camera_pose=c2w_pose, # cam2world
+ camera_intrinsics=intrinsics,
+ non_ambiguous_mask=non_ambiguous_mask,
+ dataset="MegaDepth",
+ label=scene_name,
+ instance=os.path.join("images", str(view_file_name)),
+ )
+ )
+
+ return views
+
+
+def get_parser():
+ import argparse
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "-rd", "--root_dir", default="/fsx/xrtech/data/megadepth", type=str
+ )
+ parser.add_argument(
+ "-dmd",
+ "--dataset_metadata_dir",
+ default="/fsx/nkeetha/mapanything_dataset_metadata",
+ type=str,
+ )
+ parser.add_argument(
+ "-nv",
+ "--num_of_views",
+ default=2,
+ type=int,
+ )
+ parser.add_argument("--viz", action="store_true")
+
+ return parser
+
+
+if __name__ == "__main__":
+ import rerun as rr
+ from tqdm import tqdm
+
+ from mapanything.datasets.base.base_dataset import view_name
+ from mapanything.utils.image import rgb
+ from mapanything.utils.viz import script_add_rerun_args
+
+ parser = get_parser()
+ script_add_rerun_args(
+ parser
+ ) # Options: --headless, --connect, --serve, --addr, --save, --stdout
+ args = parser.parse_args()
+
+ dataset = MegaDepthWAI(
+ num_views=args.num_of_views,
+ split="train",
+ covisibility_thres=0.25,
+ ROOT=args.root_dir,
+ dataset_metadata_dir=args.dataset_metadata_dir,
+ resolution=(518, 336),
+ aug_crop=16,
+ transform="colorjitter+grayscale+gaublur",
+ data_norm_type="dinov2",
+ )
+ # dataset = MegaDepthWAI(
+ # num_views=args.num_of_views,
+ # split="val",
+ # covisibility_thres=0.25,
+ # ROOT=args.root_dir,
+ # dataset_metadata_dir=args.dataset_metadata_dir,
+ # resolution=(518, 336),
+ # seed=777,
+ # transform="imgnorm",
+ # data_norm_type="dinov2",
+ # )
+ print(dataset.get_stats())
+
+ if args.viz:
+ rr.script_setup(args, "MegaDepth_Dataloader")
+ rr.set_time("stable_time", sequence=0)
+ rr.log("world", rr.ViewCoordinates.RDF, static=True)
+
+ sampled_indices = np.random.choice(len(dataset), size=len(dataset), replace=False)
+
+ for num, idx in enumerate(tqdm(sampled_indices)):
+ views = dataset[idx]
+ assert len(views) == args.num_of_views
+ sample_name = f"{idx}"
+ for view_idx in range(args.num_of_views):
+ sample_name += f" {view_name(views[view_idx])}"
+ print(sample_name)
+ for view_idx in range(args.num_of_views):
+ image = rgb(
+ views[view_idx]["img"], norm_type=views[view_idx]["data_norm_type"]
+ )
+ depthmap = views[view_idx]["depthmap"]
+ pose = views[view_idx]["camera_pose"]
+ intrinsics = views[view_idx]["camera_intrinsics"]
+ pts3d = views[view_idx]["pts3d"]
+ valid_mask = views[view_idx]["valid_mask"]
+ if "non_ambiguous_mask" in views[view_idx]:
+ non_ambiguous_mask = views[view_idx]["non_ambiguous_mask"]
+ else:
+ non_ambiguous_mask = None
+ if "prior_depth_along_ray" in views[view_idx]:
+ prior_depth_along_ray = views[view_idx]["prior_depth_along_ray"]
+ else:
+ prior_depth_along_ray = None
+ if args.viz:
+ rr.set_time("stable_time", sequence=num)
+ base_name = f"world/view_{view_idx}"
+ pts_name = f"world/view_{view_idx}_pointcloud"
+ # Log camera info and loaded data
+ height, width = image.shape[0], image.shape[1]
+ rr.log(
+ base_name,
+ rr.Transform3D(
+ translation=pose[:3, 3],
+ mat3x3=pose[:3, :3],
+ ),
+ )
+ rr.log(
+ f"{base_name}/pinhole",
+ rr.Pinhole(
+ image_from_camera=intrinsics,
+ height=height,
+ width=width,
+ camera_xyz=rr.ViewCoordinates.RDF,
+ ),
+ )
+ rr.log(
+ f"{base_name}/pinhole/rgb",
+ rr.Image(image),
+ )
+ rr.log(
+ f"{base_name}/pinhole/depth",
+ rr.DepthImage(depthmap),
+ )
+ if prior_depth_along_ray is not None:
+ rr.log(
+ f"prior_depth_along_ray_{view_idx}",
+ rr.DepthImage(prior_depth_along_ray),
+ )
+ if non_ambiguous_mask is not None:
+ rr.log(
+ f"{base_name}/pinhole/non_ambiguous_mask",
+ rr.SegmentationImage(non_ambiguous_mask.astype(int)),
+ )
+ # Log points in 3D
+ filtered_pts = pts3d[valid_mask]
+ filtered_pts_col = image[valid_mask]
+ rr.log(
+ pts_name,
+ rr.Points3D(
+ positions=filtered_pts.reshape(-1, 3),
+ colors=filtered_pts_col.reshape(-1, 3),
+ ),
+ )
diff --git a/mapanything/datasets/wai/mpsd.py b/mapanything/datasets/wai/mpsd.py
new file mode 100644
index 0000000000000000000000000000000000000000..632702634273089ea169770b19553f89a6c92fb0
--- /dev/null
+++ b/mapanything/datasets/wai/mpsd.py
@@ -0,0 +1,313 @@
+"""
+MPSD Dataset using WAI format data.
+"""
+
+import os
+
+import cv2
+import numpy as np
+
+from mapanything.datasets.base.base_dataset import BaseDataset
+from wai import load_data, load_frame
+
+
+class MPSDWAI(BaseDataset):
+ """
+ MPSD dataset containing outdoor planet scale metric reconstructions.
+ """
+
+ def __init__(
+ self,
+ *args,
+ ROOT,
+ dataset_metadata_dir,
+ split,
+ overfit_num_sets=None,
+ sample_specific_scene: bool = False,
+ specific_scene_name: str = None,
+ **kwargs,
+ ):
+ """
+ Initialize the dataset attributes.
+ Args:
+ ROOT: Root directory of the dataset.
+ dataset_metadata_dir: Path to the dataset metadata directory.
+ split: Dataset split (train, val, test).
+ overfit_num_sets: If None, use all sets. Else, the dataset will be truncated to this number of sets.
+ sample_specific_scene: Whether to sample a specific scene from the dataset.
+ specific_scene_name: Name of the specific scene to sample.
+ """
+ # Initialize the dataset attributes
+ super().__init__(*args, **kwargs)
+ self.ROOT = ROOT
+ self.dataset_metadata_dir = dataset_metadata_dir
+ self.split = split
+ self.overfit_num_sets = overfit_num_sets
+ self.sample_specific_scene = sample_specific_scene
+ self.specific_scene_name = specific_scene_name
+ self._load_data()
+
+ # Define the dataset type flags
+ self.is_metric_scale = True
+ self.is_synthetic = False
+
+ def _load_data(self):
+ "Load the precomputed dataset metadata"
+ # Load the dataset metadata corresponding to the split
+ split_metadata_path = os.path.join(
+ self.dataset_metadata_dir,
+ self.split,
+ f"mpsd_scene_list_{self.split}.npy",
+ )
+ split_scene_list = np.load(split_metadata_path, allow_pickle=True)
+
+ # Get the list of all scenes
+ if not self.sample_specific_scene:
+ self.scenes = list(split_scene_list)
+ else:
+ self.scenes = [self.specific_scene_name]
+ self.num_of_scenes = len(self.scenes)
+
+ def _get_views(self, sampled_idx, num_views_to_sample, resolution):
+ # Get the scene name of the sampled index
+ scene_index = sampled_idx
+ scene_name = self.scenes[scene_index]
+
+ # Get the metadata corresponding to the scene
+ scene_root = os.path.join(self.ROOT, scene_name)
+ scene_meta = load_data(
+ os.path.join(scene_root, "scene_meta.json"), "scene_meta"
+ )
+ scene_file_names = list(scene_meta["frame_names"].keys())
+ num_views_in_scene = len(scene_file_names)
+
+ # Load the scene pairwise covisibility mmap
+ covisibility_version_key = "v0"
+ covisibility_map_dir = os.path.join(
+ scene_root, "covisibility", covisibility_version_key
+ )
+ # Assumes only npy file in directory is covisbility map
+ covisibility_map_name = next(
+ f for f in os.listdir(covisibility_map_dir) if f.endswith(".npy")
+ )
+ covisibility_map_path = os.path.join(
+ scene_root, "covisibility", covisibility_version_key, covisibility_map_name
+ )
+ pairwise_covisibility = load_data(covisibility_map_path, "mmap")
+
+ # Get the indices of the N views in the scene
+ view_indices = self._sample_view_indices(
+ num_views_to_sample, num_views_in_scene, pairwise_covisibility
+ )
+
+ # Get the views corresponding to the selected view indices
+ views = []
+ for view_index in view_indices:
+ # Load the data corresponding to the view
+ view_file_name = scene_file_names[view_index]
+ view_data = load_frame(
+ scene_root,
+ view_file_name,
+ modalities=["image", "depth"],
+ # modalities=["image", "depth", "pred_mask/moge2"],
+ scene_meta=scene_meta,
+ )
+ ### HOTFIX: Load required additional masks manually
+ ### Remove once stability issue with scene_meta is fixed
+ mask_path = os.path.join(
+ scene_root, "moge", "v0", "mask", "moge2", f"{view_file_name}.png"
+ )
+ view_data["pred_mask/moge2"] = load_data(mask_path, "binary")
+
+ # Convert necessary data to numpy
+ image = view_data["image"].permute(1, 2, 0).numpy()
+ image = (image * 255).astype(np.uint8)
+ depthmap = view_data["depth"].numpy().astype(np.float32)
+ intrinsics = view_data["intrinsics"].numpy().astype(np.float32)
+ c2w_pose = view_data["extrinsics"].numpy().astype(np.float32)
+
+ # Ensure that the depthmap has all valid values
+ depthmap = np.nan_to_num(depthmap, nan=0.0, posinf=0.0, neginf=0.0)
+
+ # Get the non_ambiguous_mask and ensure it matches image resolution
+ non_ambiguous_mask = view_data["pred_mask/moge2"].numpy().astype(int)
+ non_ambiguous_mask = cv2.resize(
+ non_ambiguous_mask,
+ (image.shape[1], image.shape[0]),
+ interpolation=cv2.INTER_NEAREST,
+ )
+
+ # Mask out the GT depth using the non_ambiguous_mask
+ depthmap = np.where(non_ambiguous_mask, depthmap, 0)
+
+ # Resize the data to match the desired resolution
+ additional_quantities_to_resize = [non_ambiguous_mask]
+ image, depthmap, intrinsics, additional_quantities_to_resize = (
+ self._crop_resize_if_necessary(
+ image=image,
+ resolution=resolution,
+ depthmap=depthmap,
+ intrinsics=intrinsics,
+ additional_quantities=additional_quantities_to_resize,
+ )
+ )
+ non_ambiguous_mask = additional_quantities_to_resize[0]
+
+ # Append the view dictionary to the list of views
+ views.append(
+ dict(
+ img=image,
+ depthmap=depthmap,
+ camera_pose=c2w_pose, # cam2world
+ camera_intrinsics=intrinsics,
+ non_ambiguous_mask=non_ambiguous_mask,
+ dataset="MPSD",
+ label=scene_name,
+ instance=os.path.join("images", str(view_file_name)),
+ )
+ )
+
+ return views
+
+
+def get_parser():
+ import argparse
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument("-rd", "--root_dir", default="/fsx/xrtech/data/mpsd", type=str)
+ parser.add_argument(
+ "-dmd",
+ "--dataset_metadata_dir",
+ default="/fsx/nkeetha/mapanything_dataset_metadata",
+ type=str,
+ )
+ parser.add_argument(
+ "-nv",
+ "--num_of_views",
+ default=2,
+ type=int,
+ )
+ parser.add_argument("--viz", action="store_true")
+
+ return parser
+
+
+if __name__ == "__main__":
+ import rerun as rr
+ from tqdm import tqdm
+
+ from mapanything.datasets.base.base_dataset import view_name
+ from mapanything.utils.image import rgb
+ from mapanything.utils.viz import script_add_rerun_args
+
+ parser = get_parser()
+ script_add_rerun_args(
+ parser
+ ) # Options: --headless, --connect, --serve, --addr, --save, --stdout
+ args = parser.parse_args()
+
+ dataset = MPSDWAI(
+ num_views=args.num_of_views,
+ split="train",
+ covisibility_thres=0.15,
+ ROOT=args.root_dir,
+ dataset_metadata_dir=args.dataset_metadata_dir,
+ resolution=(518, 392),
+ aug_crop=16,
+ transform="colorjitter+grayscale+gaublur",
+ data_norm_type="dinov2",
+ )
+ # dataset = MPSDWAI(
+ # num_views=args.num_of_views,
+ # split="val",
+ # covisibility_thres=0.15,
+ # ROOT=args.root_dir,
+ # dataset_metadata_dir=args.dataset_metadata_dir,
+ # resolution=(518, 392),
+ # seed=777,
+ # transform="imgnorm",
+ # data_norm_type="dinov2",
+ # )
+ print(dataset.get_stats())
+
+ if args.viz:
+ rr.script_setup(args, "MPSD_Dataloader")
+ rr.set_time("stable_time", sequence=0)
+ rr.log("world", rr.ViewCoordinates.RDF, static=True)
+
+ sampled_indices = np.random.choice(len(dataset), size=len(dataset), replace=False)
+
+ for num, idx in enumerate(tqdm(sampled_indices)):
+ views = dataset[idx]
+ assert len(views) == args.num_of_views
+ sample_name = f"{idx}"
+ for view_idx in range(args.num_of_views):
+ sample_name += f" {view_name(views[view_idx])}"
+ print(sample_name)
+ for view_idx in range(args.num_of_views):
+ image = rgb(
+ views[view_idx]["img"], norm_type=views[view_idx]["data_norm_type"]
+ )
+ depthmap = views[view_idx]["depthmap"]
+ pose = views[view_idx]["camera_pose"]
+ intrinsics = views[view_idx]["camera_intrinsics"]
+ pts3d = views[view_idx]["pts3d"]
+ valid_mask = views[view_idx]["valid_mask"]
+ if "non_ambiguous_mask" in views[view_idx]:
+ non_ambiguous_mask = views[view_idx]["non_ambiguous_mask"]
+ else:
+ non_ambiguous_mask = None
+ if "prior_depth_along_ray" in views[view_idx]:
+ prior_depth_along_ray = views[view_idx]["prior_depth_along_ray"]
+ else:
+ prior_depth_along_ray = None
+ if args.viz:
+ rr.set_time("stable_time", sequence=num)
+ base_name = f"world/view_{view_idx}"
+ pts_name = f"world/view_{view_idx}_pointcloud"
+ # Log camera info and loaded data
+ height, width = image.shape[0], image.shape[1]
+ rr.log(
+ base_name,
+ rr.Transform3D(
+ translation=pose[:3, 3],
+ mat3x3=pose[:3, :3],
+ ),
+ )
+ rr.log(
+ f"{base_name}/pinhole",
+ rr.Pinhole(
+ image_from_camera=intrinsics,
+ height=height,
+ width=width,
+ camera_xyz=rr.ViewCoordinates.RDF,
+ ),
+ )
+ rr.log(
+ f"{base_name}/pinhole/rgb",
+ rr.Image(image),
+ )
+ rr.log(
+ f"{base_name}/pinhole/depth",
+ rr.DepthImage(depthmap),
+ )
+ if prior_depth_along_ray is not None:
+ rr.log(
+ f"prior_depth_along_ray_{view_idx}",
+ rr.DepthImage(prior_depth_along_ray),
+ )
+ if non_ambiguous_mask is not None:
+ rr.log(
+ f"{base_name}/pinhole/non_ambiguous_mask",
+ rr.SegmentationImage(non_ambiguous_mask.astype(int)),
+ )
+ # Log points in 3D
+ filtered_pts = pts3d[valid_mask]
+ filtered_pts_col = image[valid_mask]
+ rr.log(
+ pts_name,
+ rr.Points3D(
+ positions=filtered_pts.reshape(-1, 3),
+ colors=filtered_pts_col.reshape(-1, 3),
+ ),
+ )
diff --git a/mapanything/datasets/wai/mvs_synth.py b/mapanything/datasets/wai/mvs_synth.py
new file mode 100644
index 0000000000000000000000000000000000000000..1db1d1f2eba87e5e1764e54402313b17261ac511
--- /dev/null
+++ b/mapanything/datasets/wai/mvs_synth.py
@@ -0,0 +1,303 @@
+"""
+MVS Synth Dataset using WAI format data.
+"""
+
+import os
+
+import numpy as np
+
+from mapanything.datasets.base.base_dataset import BaseDataset
+from wai import load_data, load_frame
+
+
+class MVSSynthWAI(BaseDataset):
+ """
+ MVS Synth dataset containing large diversity of synthetic in-the-wild scenes.
+ """
+
+ def __init__(
+ self,
+ *args,
+ ROOT,
+ dataset_metadata_dir,
+ split,
+ overfit_num_sets=None,
+ sample_specific_scene: bool = False,
+ specific_scene_name: str = None,
+ **kwargs,
+ ):
+ """
+ Initialize the dataset attributes.
+ Args:
+ ROOT: Root directory of the dataset.
+ dataset_metadata_dir: Path to the dataset metadata directory.
+ split: Dataset split (train, val, test).
+ overfit_num_sets: If None, use all sets. Else, the dataset will be truncated to this number of sets.
+ sample_specific_scene: Whether to sample a specific scene from the dataset.
+ specific_scene_name: Name of the specific scene to sample.
+ """
+ # Initialize the dataset attributes
+ super().__init__(*args, **kwargs)
+ self.ROOT = ROOT
+ self.dataset_metadata_dir = dataset_metadata_dir
+ self.split = split
+ self.overfit_num_sets = overfit_num_sets
+ self.sample_specific_scene = sample_specific_scene
+ self.specific_scene_name = specific_scene_name
+ self._load_data()
+
+ # Define the dataset type flags
+ self.is_metric_scale = True
+ self.is_synthetic = True
+
+ def _load_data(self):
+ "Load the precomputed dataset metadata"
+ # Load the dataset metadata corresponding to the split
+ split_metadata_path = os.path.join(
+ self.dataset_metadata_dir,
+ self.split,
+ f"mvs_synth_scene_list_{self.split}.npy",
+ )
+ split_scene_list = np.load(split_metadata_path, allow_pickle=True)
+
+ # Get the list of all scenes
+ if not self.sample_specific_scene:
+ self.scenes = list(split_scene_list)
+ else:
+ self.scenes = [self.specific_scene_name]
+ self.num_of_scenes = len(self.scenes)
+
+ def _get_views(self, sampled_idx, num_views_to_sample, resolution):
+ # Get the scene name of the sampled index
+ scene_index = sampled_idx
+ scene_name = self.scenes[scene_index]
+
+ # Get the metadata corresponding to the scene
+ scene_root = os.path.join(self.ROOT, scene_name)
+ scene_meta = load_data(
+ os.path.join(scene_root, "scene_meta.json"), "scene_meta"
+ )
+ scene_file_names = list(scene_meta["frame_names"].keys())
+ num_views_in_scene = len(scene_file_names)
+
+ # Load the scene pairwise covisibility mmap
+ covisibility_version_key = "v0"
+ covisibility_map_dir = os.path.join(
+ scene_root, "covisibility", covisibility_version_key
+ )
+ # Assumes only npy file in directory is covisbility map
+ covisibility_map_name = next(
+ f for f in os.listdir(covisibility_map_dir) if f.endswith(".npy")
+ )
+ covisibility_map_path = os.path.join(
+ scene_root, "covisibility", covisibility_version_key, covisibility_map_name
+ )
+ pairwise_covisibility = load_data(covisibility_map_path, "mmap")
+
+ # Get the indices of the N views in the scene
+ view_indices = self._sample_view_indices(
+ num_views_to_sample, num_views_in_scene, pairwise_covisibility
+ )
+
+ # Get the views corresponding to the selected view indices
+ views = []
+ for view_index in view_indices:
+ # Load the data corresponding to the view
+ view_file_name = scene_file_names[view_index]
+ view_data = load_frame(
+ scene_root,
+ view_file_name,
+ modalities=["image", "depth"],
+ scene_meta=scene_meta,
+ )
+
+ # Convert necessary data to numpy
+ image = view_data["image"].permute(1, 2, 0).numpy()
+ image = (image * 255).astype(np.uint8)
+ depthmap = view_data["depth"].numpy().astype(np.float32)
+ intrinsics = view_data["intrinsics"].numpy().astype(np.float32)
+ c2w_pose = view_data["extrinsics"].numpy().astype(np.float32)
+
+ # Ensure that the depthmap has all valid values
+ depthmap = np.nan_to_num(depthmap, nan=0.0, posinf=0.0, neginf=0.0)
+
+ # Get the non ambiguous mask (zero depth pixels are sky or ambiguous)
+ non_ambiguous_mask = (depthmap > 0).astype(int)
+
+ # Mask out the outlier depth (horizon depth)
+ percentile_depth = np.percentile(depthmap, 95)
+ depthmap[depthmap > percentile_depth] = 0
+
+ # Resize the data to match the desired resolution
+ additional_quantities_to_resize = [non_ambiguous_mask]
+ image, depthmap, intrinsics, additional_quantities_to_resize = (
+ self._crop_resize_if_necessary(
+ image=image,
+ resolution=resolution,
+ depthmap=depthmap,
+ intrinsics=intrinsics,
+ additional_quantities=additional_quantities_to_resize,
+ )
+ )
+ non_ambiguous_mask = additional_quantities_to_resize[0]
+
+ # Append the view dictionary to the list of views
+ views.append(
+ dict(
+ img=image,
+ depthmap=depthmap,
+ camera_pose=c2w_pose, # cam2world
+ camera_intrinsics=intrinsics,
+ non_ambiguous_mask=non_ambiguous_mask,
+ dataset="MVSSynth",
+ label=scene_name,
+ instance=os.path.join("images", str(view_file_name)),
+ )
+ )
+
+ return views
+
+
+def get_parser():
+ import argparse
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "-rd", "--root_dir", default="/fsx/xrtech/data/mvs_synth", type=str
+ )
+ parser.add_argument(
+ "-dmd",
+ "--dataset_metadata_dir",
+ default="/fsx/nkeetha/mapanything_dataset_metadata",
+ type=str,
+ )
+ parser.add_argument(
+ "-nv",
+ "--num_of_views",
+ default=2,
+ type=int,
+ )
+ parser.add_argument("--viz", action="store_true")
+
+ return parser
+
+
+if __name__ == "__main__":
+ import rerun as rr
+ from tqdm import tqdm
+
+ from mapanything.datasets.base.base_dataset import view_name
+ from mapanything.utils.image import rgb
+ from mapanything.utils.viz import script_add_rerun_args
+
+ parser = get_parser()
+ script_add_rerun_args(
+ parser
+ ) # Options: --headless, --connect, --serve, --addr, --save, --stdout
+ args = parser.parse_args()
+
+ dataset = MVSSynthWAI(
+ num_views=args.num_of_views,
+ split="train",
+ covisibility_thres=0.25,
+ ROOT=args.root_dir,
+ dataset_metadata_dir=args.dataset_metadata_dir,
+ resolution=(518, 294),
+ aug_crop=16,
+ transform="colorjitter+grayscale+gaublur",
+ data_norm_type="dinov2",
+ )
+ # dataset = MVSSynthWAI(
+ # num_views=args.num_of_views,
+ # split="val",
+ # covisibility_thres=0.25,
+ # ROOT=args.root_dir,
+ # dataset_metadata_dir=args.dataset_metadata_dir,
+ # resolution=(518, 294),
+ # seed=777,
+ # transform="imgnorm",
+ # data_norm_type="dinov2",
+ # )
+ print(dataset.get_stats())
+
+ if args.viz:
+ rr.script_setup(args, "MVSSynth_Dataloader")
+ rr.set_time("stable_time", sequence=0)
+ rr.log("world", rr.ViewCoordinates.RDF, static=True)
+
+ sampled_indices = np.random.choice(len(dataset), size=len(dataset), replace=False)
+
+ for num, idx in enumerate(tqdm(sampled_indices)):
+ views = dataset[idx]
+ assert len(views) == args.num_of_views
+ sample_name = f"{idx}"
+ for view_idx in range(args.num_of_views):
+ sample_name += f" {view_name(views[view_idx])}"
+ print(sample_name)
+ for view_idx in range(args.num_of_views):
+ image = rgb(
+ views[view_idx]["img"], norm_type=views[view_idx]["data_norm_type"]
+ )
+ depthmap = views[view_idx]["depthmap"]
+ pose = views[view_idx]["camera_pose"]
+ intrinsics = views[view_idx]["camera_intrinsics"]
+ pts3d = views[view_idx]["pts3d"]
+ valid_mask = views[view_idx]["valid_mask"]
+ if "non_ambiguous_mask" in views[view_idx]:
+ non_ambiguous_mask = views[view_idx]["non_ambiguous_mask"]
+ else:
+ non_ambiguous_mask = None
+ if "prior_depth_along_ray" in views[view_idx]:
+ prior_depth_along_ray = views[view_idx]["prior_depth_along_ray"]
+ else:
+ prior_depth_along_ray = None
+ if args.viz:
+ rr.set_time("stable_time", sequence=num)
+ base_name = f"world/view_{view_idx}"
+ pts_name = f"world/view_{view_idx}_pointcloud"
+ # Log camera info and loaded data
+ height, width = image.shape[0], image.shape[1]
+ rr.log(
+ base_name,
+ rr.Transform3D(
+ translation=pose[:3, 3],
+ mat3x3=pose[:3, :3],
+ ),
+ )
+ rr.log(
+ f"{base_name}/pinhole",
+ rr.Pinhole(
+ image_from_camera=intrinsics,
+ height=height,
+ width=width,
+ camera_xyz=rr.ViewCoordinates.RDF,
+ ),
+ )
+ rr.log(
+ f"{base_name}/pinhole/rgb",
+ rr.Image(image),
+ )
+ rr.log(
+ f"{base_name}/pinhole/depth",
+ rr.DepthImage(depthmap),
+ )
+ if prior_depth_along_ray is not None:
+ rr.log(
+ f"prior_depth_along_ray_{view_idx}",
+ rr.DepthImage(prior_depth_along_ray),
+ )
+ if non_ambiguous_mask is not None:
+ rr.log(
+ f"{base_name}/pinhole/non_ambiguous_mask",
+ rr.SegmentationImage(non_ambiguous_mask.astype(int)),
+ )
+ # Log points in 3D
+ filtered_pts = pts3d[valid_mask]
+ filtered_pts_col = image[valid_mask]
+ rr.log(
+ pts_name,
+ rr.Points3D(
+ positions=filtered_pts.reshape(-1, 3),
+ colors=filtered_pts_col.reshape(-1, 3),
+ ),
+ )
diff --git a/mapanything/datasets/wai/paralleldomain4d.py b/mapanything/datasets/wai/paralleldomain4d.py
new file mode 100644
index 0000000000000000000000000000000000000000..6e247eda6710582ec0ef4375f7f915760cc97c32
--- /dev/null
+++ b/mapanything/datasets/wai/paralleldomain4d.py
@@ -0,0 +1,304 @@
+"""
+Parallel Domain 4D Dataset using WAI format data.
+"""
+
+import os
+
+import numpy as np
+
+from mapanything.datasets.base.base_dataset import BaseDataset
+from wai import load_data, load_frame
+
+
+class ParallelDomain4DWAI(BaseDataset):
+ """
+ Parallel Domain 4D dataset containing large diversity of synthetic AV scenes.
+ """
+
+ def __init__(
+ self,
+ *args,
+ ROOT,
+ dataset_metadata_dir,
+ split,
+ overfit_num_sets=None,
+ sample_specific_scene: bool = False,
+ specific_scene_name: str = None,
+ **kwargs,
+ ):
+ """
+ Initialize the dataset attributes.
+ Args:
+ ROOT: Root directory of the dataset.
+ dataset_metadata_dir: Path to the dataset metadata directory.
+ split: Dataset split (train, val, test).
+ overfit_num_sets: If None, use all sets. Else, the dataset will be truncated to this number of sets.
+ sample_specific_scene: Whether to sample a specific scene from the dataset.
+ specific_scene_name: Name of the specific scene to sample.
+ """
+ # Initialize the dataset attributes
+ super().__init__(*args, **kwargs)
+ self.ROOT = ROOT
+ self.dataset_metadata_dir = dataset_metadata_dir
+ self.split = split
+ self.overfit_num_sets = overfit_num_sets
+ self.sample_specific_scene = sample_specific_scene
+ self.specific_scene_name = specific_scene_name
+ self._load_data()
+
+ # Define the dataset type flags
+ self.is_metric_scale = True
+ self.is_synthetic = True
+
+ def _load_data(self):
+ "Load the precomputed dataset metadata"
+ # Load the dataset metadata corresponding to the split
+ split_metadata_path = os.path.join(
+ self.dataset_metadata_dir,
+ self.split,
+ f"paralleldomain4d_scene_list_{self.split}.npy",
+ )
+ split_scene_list = np.load(split_metadata_path, allow_pickle=True)
+
+ # Get the list of all scenes
+ if not self.sample_specific_scene:
+ self.scenes = list(split_scene_list)
+ else:
+ self.scenes = [self.specific_scene_name]
+ self.num_of_scenes = len(self.scenes)
+
+ def _get_views(self, sampled_idx, num_views_to_sample, resolution):
+ # Get the scene name of the sampled index
+ scene_index = sampled_idx
+ scene_name = self.scenes[scene_index]
+
+ # Get the metadata corresponding to the scene
+ scene_root = os.path.join(self.ROOT, scene_name)
+ scene_meta = load_data(
+ os.path.join(scene_root, "scene_meta.json"), "scene_meta"
+ )
+ scene_file_names = list(scene_meta["frame_names"].keys())
+ num_views_in_scene = len(scene_file_names)
+
+ # Load the scene pairwise covisibility mmap
+ covisibility_version_key = "v0"
+ covisibility_map_dir = os.path.join(
+ scene_root, "covisibility", covisibility_version_key
+ )
+ # Assumes only npy file in directory is covisbility map
+ covisibility_map_name = next(
+ f for f in os.listdir(covisibility_map_dir) if f.endswith(".npy")
+ )
+ covisibility_map_path = os.path.join(
+ scene_root, "covisibility", covisibility_version_key, covisibility_map_name
+ )
+ pairwise_covisibility = load_data(covisibility_map_path, "mmap")
+
+ # Get the indices of the N views in the scene
+ view_indices = self._sample_view_indices(
+ num_views_to_sample, num_views_in_scene, pairwise_covisibility
+ )
+
+ # Get the views corresponding to the selected view indices
+ views = []
+ for view_index in view_indices:
+ # Load the data corresponding to the view
+ view_file_name = scene_file_names[view_index]
+ view_data = load_frame(
+ scene_root,
+ view_file_name,
+ modalities=["image", "depth"],
+ scene_meta=scene_meta,
+ )
+
+ # Convert necessary data to numpy
+ image = view_data["image"].permute(1, 2, 0).numpy()
+ image = image[:, :, :3] # RGBA to RGB
+ image = (image * 255).astype(np.uint8)
+ depthmap = view_data["depth"].numpy().astype(np.float32)
+ intrinsics = view_data["intrinsics"].numpy().astype(np.float32)
+ c2w_pose = view_data["extrinsics"].numpy().astype(np.float32)
+
+ # Ensure that the depthmap has all valid values
+ depthmap = np.nan_to_num(depthmap, nan=0.0, posinf=0.0, neginf=0.0)
+
+ # Get the non ambiguous mask (zero depth pixels are sky or ambiguous)
+ non_ambiguous_mask = (depthmap > 0).astype(int)
+
+ # Mask out the outlier depth (horizon depth)
+ percentile_depth = np.percentile(depthmap, 95)
+ depthmap[depthmap > percentile_depth] = 0
+
+ # Resize the data to match the desired resolution
+ additional_quantities_to_resize = [non_ambiguous_mask]
+ image, depthmap, intrinsics, additional_quantities_to_resize = (
+ self._crop_resize_if_necessary(
+ image=image,
+ resolution=resolution,
+ depthmap=depthmap,
+ intrinsics=intrinsics,
+ additional_quantities=additional_quantities_to_resize,
+ )
+ )
+ non_ambiguous_mask = additional_quantities_to_resize[0]
+
+ # Append the view dictionary to the list of views
+ views.append(
+ dict(
+ img=image,
+ depthmap=depthmap,
+ camera_pose=c2w_pose, # cam2world
+ camera_intrinsics=intrinsics,
+ non_ambiguous_mask=non_ambiguous_mask,
+ dataset="ParallelDomain4D",
+ label=scene_name,
+ instance=os.path.join("images", str(view_file_name)),
+ )
+ )
+
+ return views
+
+
+def get_parser():
+ import argparse
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "-rd", "--root_dir", default="/fsx/xrtech/data/paralleldomain4d", type=str
+ )
+ parser.add_argument(
+ "-dmd",
+ "--dataset_metadata_dir",
+ default="/fsx/nkeetha/mapanything_dataset_metadata",
+ type=str,
+ )
+ parser.add_argument(
+ "-nv",
+ "--num_of_views",
+ default=2,
+ type=int,
+ )
+ parser.add_argument("--viz", action="store_true")
+
+ return parser
+
+
+if __name__ == "__main__":
+ import rerun as rr
+ from tqdm import tqdm
+
+ from mapanything.datasets.base.base_dataset import view_name
+ from mapanything.utils.image import rgb
+ from mapanything.utils.viz import script_add_rerun_args
+
+ parser = get_parser()
+ script_add_rerun_args(
+ parser
+ ) # Options: --headless, --connect, --serve, --addr, --save, --stdout
+ args = parser.parse_args()
+
+ dataset = ParallelDomain4DWAI(
+ num_views=args.num_of_views,
+ split="train",
+ covisibility_thres=0.25,
+ ROOT=args.root_dir,
+ dataset_metadata_dir=args.dataset_metadata_dir,
+ resolution=(518, 392),
+ aug_crop=16,
+ transform="colorjitter+grayscale+gaublur",
+ data_norm_type="dinov2",
+ )
+ # dataset = ParallelDomain4DWAI(
+ # num_views=args.num_of_views,
+ # split="val",
+ # covisibility_thres=0.25,
+ # ROOT=args.root_dir,
+ # dataset_metadata_dir=args.dataset_metadata_dir,
+ # resolution=(518, 392),
+ # seed=777,
+ # transform="imgnorm",
+ # data_norm_type="dinov2",
+ # )
+ print(dataset.get_stats())
+
+ if args.viz:
+ rr.script_setup(args, "ParallelDomain4D_Dataloader")
+ rr.set_time("stable_time", sequence=0)
+ rr.log("world", rr.ViewCoordinates.RDF, static=True)
+
+ sampled_indices = np.random.choice(len(dataset), size=len(dataset), replace=False)
+
+ for num, idx in enumerate(tqdm(sampled_indices)):
+ views = dataset[idx]
+ assert len(views) == args.num_of_views
+ sample_name = f"{idx}"
+ for view_idx in range(args.num_of_views):
+ sample_name += f" {view_name(views[view_idx])}"
+ print(sample_name)
+ for view_idx in range(args.num_of_views):
+ image = rgb(
+ views[view_idx]["img"], norm_type=views[view_idx]["data_norm_type"]
+ )
+ depthmap = views[view_idx]["depthmap"]
+ pose = views[view_idx]["camera_pose"]
+ intrinsics = views[view_idx]["camera_intrinsics"]
+ pts3d = views[view_idx]["pts3d"]
+ valid_mask = views[view_idx]["valid_mask"]
+ if "non_ambiguous_mask" in views[view_idx]:
+ non_ambiguous_mask = views[view_idx]["non_ambiguous_mask"]
+ else:
+ non_ambiguous_mask = None
+ if "prior_depth_along_ray" in views[view_idx]:
+ prior_depth_along_ray = views[view_idx]["prior_depth_along_ray"]
+ else:
+ prior_depth_along_ray = None
+ if args.viz:
+ rr.set_time("stable_time", sequence=num)
+ base_name = f"world/view_{view_idx}"
+ pts_name = f"world/view_{view_idx}_pointcloud"
+ # Log camera info and loaded data
+ height, width = image.shape[0], image.shape[1]
+ rr.log(
+ base_name,
+ rr.Transform3D(
+ translation=pose[:3, 3],
+ mat3x3=pose[:3, :3],
+ ),
+ )
+ rr.log(
+ f"{base_name}/pinhole",
+ rr.Pinhole(
+ image_from_camera=intrinsics,
+ height=height,
+ width=width,
+ camera_xyz=rr.ViewCoordinates.RDF,
+ ),
+ )
+ rr.log(
+ f"{base_name}/pinhole/rgb",
+ rr.Image(image),
+ )
+ rr.log(
+ f"{base_name}/pinhole/depth",
+ rr.DepthImage(depthmap),
+ )
+ if prior_depth_along_ray is not None:
+ rr.log(
+ f"prior_depth_along_ray_{view_idx}",
+ rr.DepthImage(prior_depth_along_ray),
+ )
+ if non_ambiguous_mask is not None:
+ rr.log(
+ f"{base_name}/pinhole/non_ambiguous_mask",
+ rr.SegmentationImage(non_ambiguous_mask.astype(int)),
+ )
+ # Log points in 3D
+ filtered_pts = pts3d[valid_mask]
+ filtered_pts_col = image[valid_mask]
+ rr.log(
+ pts_name,
+ rr.Points3D(
+ positions=filtered_pts.reshape(-1, 3),
+ colors=filtered_pts_col.reshape(-1, 3),
+ ),
+ )
diff --git a/mapanything/datasets/wai/sailvos3d.py b/mapanything/datasets/wai/sailvos3d.py
new file mode 100644
index 0000000000000000000000000000000000000000..f5cdb7c9923ff1b8c1065ecc2e1e5fd557f5874b
--- /dev/null
+++ b/mapanything/datasets/wai/sailvos3d.py
@@ -0,0 +1,303 @@
+"""
+SAIL-VOS 3D Dataset using WAI format data.
+"""
+
+import os
+
+import numpy as np
+
+from mapanything.datasets.base.base_dataset import BaseDataset
+from wai import load_data, load_frame
+
+
+class SAILVOS3DWAI(BaseDataset):
+ """
+ SAIL-VOS 3D dataset containing large diversity of synthetic in-the-wild cut scenes from GTA.
+ """
+
+ def __init__(
+ self,
+ *args,
+ ROOT,
+ dataset_metadata_dir,
+ split,
+ overfit_num_sets=None,
+ sample_specific_scene: bool = False,
+ specific_scene_name: str = None,
+ **kwargs,
+ ):
+ """
+ Initialize the dataset attributes.
+ Args:
+ ROOT: Root directory of the dataset.
+ dataset_metadata_dir: Path to the dataset metadata directory.
+ split: Dataset split (train, val, test).
+ overfit_num_sets: If None, use all sets. Else, the dataset will be truncated to this number of sets.
+ sample_specific_scene: Whether to sample a specific scene from the dataset.
+ specific_scene_name: Name of the specific scene to sample.
+ """
+ # Initialize the dataset attributes
+ super().__init__(*args, **kwargs)
+ self.ROOT = ROOT
+ self.dataset_metadata_dir = dataset_metadata_dir
+ self.split = split
+ self.overfit_num_sets = overfit_num_sets
+ self.sample_specific_scene = sample_specific_scene
+ self.specific_scene_name = specific_scene_name
+ self._load_data()
+
+ # Define the dataset type flags
+ self.is_metric_scale = True
+ self.is_synthetic = True
+
+ def _load_data(self):
+ "Load the precomputed dataset metadata"
+ # Load the dataset metadata corresponding to the split
+ split_metadata_path = os.path.join(
+ self.dataset_metadata_dir,
+ self.split,
+ f"sailvos3d_scene_list_{self.split}.npy",
+ )
+ split_scene_list = np.load(split_metadata_path, allow_pickle=True)
+
+ # Get the list of all scenes
+ if not self.sample_specific_scene:
+ self.scenes = list(split_scene_list)
+ else:
+ self.scenes = [self.specific_scene_name]
+ self.num_of_scenes = len(self.scenes)
+
+ def _get_views(self, sampled_idx, num_views_to_sample, resolution):
+ # Get the scene name of the sampled index
+ scene_index = sampled_idx
+ scene_name = self.scenes[scene_index]
+
+ # Get the metadata corresponding to the scene
+ scene_root = os.path.join(self.ROOT, scene_name)
+ scene_meta = load_data(
+ os.path.join(scene_root, "scene_meta.json"), "scene_meta"
+ )
+ scene_file_names = list(scene_meta["frame_names"].keys())
+ num_views_in_scene = len(scene_file_names)
+
+ # Load the scene pairwise covisibility mmap
+ covisibility_version_key = "v0"
+ covisibility_map_dir = os.path.join(
+ scene_root, "covisibility", covisibility_version_key
+ )
+ # Assumes only npy file in directory is covisbility map
+ covisibility_map_name = next(
+ f for f in os.listdir(covisibility_map_dir) if f.endswith(".npy")
+ )
+ covisibility_map_path = os.path.join(
+ scene_root, "covisibility", covisibility_version_key, covisibility_map_name
+ )
+ pairwise_covisibility = load_data(covisibility_map_path, "mmap")
+
+ # Get the indices of the N views in the scene
+ view_indices = self._sample_view_indices(
+ num_views_to_sample, num_views_in_scene, pairwise_covisibility
+ )
+
+ # Get the views corresponding to the selected view indices
+ views = []
+ for view_index in view_indices:
+ # Load the data corresponding to the view
+ view_file_name = scene_file_names[view_index]
+ view_data = load_frame(
+ scene_root,
+ view_file_name,
+ modalities=["image", "depth"],
+ scene_meta=scene_meta,
+ )
+
+ # Convert necessary data to numpy
+ image = view_data["image"].permute(1, 2, 0).numpy()
+ image = (image * 255).astype(np.uint8)
+ depthmap = view_data["depth"].numpy().astype(np.float32)
+ intrinsics = view_data["intrinsics"].numpy().astype(np.float32)
+ c2w_pose = view_data["extrinsics"].numpy().astype(np.float32)
+
+ # Ensure that the depthmap has all valid values
+ depthmap = np.nan_to_num(depthmap, nan=0.0, posinf=0.0, neginf=0.0)
+
+ # Get the non ambiguous mask (zero depth pixels are sky or ambiguous)
+ non_ambiguous_mask = (depthmap > 0).astype(int)
+
+ # Mask out the outlier depth (horizon depth)
+ percentile_depth = np.percentile(depthmap, 95)
+ depthmap[depthmap > percentile_depth] = 0
+
+ # Resize the data to match the desired resolution
+ additional_quantities_to_resize = [non_ambiguous_mask]
+ image, depthmap, intrinsics, additional_quantities_to_resize = (
+ self._crop_resize_if_necessary(
+ image=image,
+ resolution=resolution,
+ depthmap=depthmap,
+ intrinsics=intrinsics,
+ additional_quantities=additional_quantities_to_resize,
+ )
+ )
+ non_ambiguous_mask = additional_quantities_to_resize[0]
+
+ # Append the view dictionary to the list of views
+ views.append(
+ dict(
+ img=image,
+ depthmap=depthmap,
+ camera_pose=c2w_pose, # cam2world
+ camera_intrinsics=intrinsics,
+ non_ambiguous_mask=non_ambiguous_mask,
+ dataset="SAILVOS3D",
+ label=scene_name,
+ instance=os.path.join("images", str(view_file_name)),
+ )
+ )
+
+ return views
+
+
+def get_parser():
+ import argparse
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "-rd", "--root_dir", default="/fsx/xrtech/data/sailvos3d", type=str
+ )
+ parser.add_argument(
+ "-dmd",
+ "--dataset_metadata_dir",
+ default="/fsx/nkeetha/mapanything_dataset_metadata",
+ type=str,
+ )
+ parser.add_argument(
+ "-nv",
+ "--num_of_views",
+ default=2,
+ type=int,
+ )
+ parser.add_argument("--viz", action="store_true")
+
+ return parser
+
+
+if __name__ == "__main__":
+ import rerun as rr
+ from tqdm import tqdm
+
+ from mapanything.datasets.base.base_dataset import view_name
+ from mapanything.utils.image import rgb
+ from mapanything.utils.viz import script_add_rerun_args
+
+ parser = get_parser()
+ script_add_rerun_args(
+ parser
+ ) # Options: --headless, --connect, --serve, --addr, --save, --stdout
+ args = parser.parse_args()
+
+ dataset = SAILVOS3DWAI(
+ num_views=args.num_of_views,
+ split="train",
+ covisibility_thres=0.25,
+ ROOT=args.root_dir,
+ dataset_metadata_dir=args.dataset_metadata_dir,
+ resolution=(518, 336),
+ aug_crop=16,
+ transform="colorjitter+grayscale+gaublur",
+ data_norm_type="dinov2",
+ )
+ # dataset = SAILVOS3DWAI(
+ # num_views=args.num_of_views,
+ # split="val",
+ # covisibility_thres=0.25,
+ # ROOT=args.root_dir,
+ # dataset_metadata_dir=args.dataset_metadata_dir,
+ # resolution=(518, 336),
+ # seed=777,
+ # transform="imgnorm",
+ # data_norm_type="dinov2",
+ # )
+ print(dataset.get_stats())
+
+ if args.viz:
+ rr.script_setup(args, "SAILVOS3D_Dataloader")
+ rr.set_time("stable_time", sequence=0)
+ rr.log("world", rr.ViewCoordinates.RDF, static=True)
+
+ sampled_indices = np.random.choice(len(dataset), size=len(dataset), replace=False)
+
+ for num, idx in enumerate(tqdm(sampled_indices)):
+ views = dataset[idx]
+ assert len(views) == args.num_of_views
+ sample_name = f"{idx}"
+ for view_idx in range(args.num_of_views):
+ sample_name += f" {view_name(views[view_idx])}"
+ print(sample_name)
+ for view_idx in range(args.num_of_views):
+ image = rgb(
+ views[view_idx]["img"], norm_type=views[view_idx]["data_norm_type"]
+ )
+ depthmap = views[view_idx]["depthmap"]
+ pose = views[view_idx]["camera_pose"]
+ intrinsics = views[view_idx]["camera_intrinsics"]
+ pts3d = views[view_idx]["pts3d"]
+ valid_mask = views[view_idx]["valid_mask"]
+ if "non_ambiguous_mask" in views[view_idx]:
+ non_ambiguous_mask = views[view_idx]["non_ambiguous_mask"]
+ else:
+ non_ambiguous_mask = None
+ if "prior_depth_along_ray" in views[view_idx]:
+ prior_depth_along_ray = views[view_idx]["prior_depth_along_ray"]
+ else:
+ prior_depth_along_ray = None
+ if args.viz:
+ rr.set_time("stable_time", sequence=num)
+ base_name = f"world/view_{view_idx}"
+ pts_name = f"world/view_{view_idx}_pointcloud"
+ # Log camera info and loaded data
+ height, width = image.shape[0], image.shape[1]
+ rr.log(
+ base_name,
+ rr.Transform3D(
+ translation=pose[:3, 3],
+ mat3x3=pose[:3, :3],
+ ),
+ )
+ rr.log(
+ f"{base_name}/pinhole",
+ rr.Pinhole(
+ image_from_camera=intrinsics,
+ height=height,
+ width=width,
+ camera_xyz=rr.ViewCoordinates.RDF,
+ ),
+ )
+ rr.log(
+ f"{base_name}/pinhole/rgb",
+ rr.Image(image),
+ )
+ rr.log(
+ f"{base_name}/pinhole/depth",
+ rr.DepthImage(depthmap),
+ )
+ if prior_depth_along_ray is not None:
+ rr.log(
+ f"prior_depth_along_ray_{view_idx}",
+ rr.DepthImage(prior_depth_along_ray),
+ )
+ if non_ambiguous_mask is not None:
+ rr.log(
+ f"{base_name}/pinhole/non_ambiguous_mask",
+ rr.SegmentationImage(non_ambiguous_mask.astype(int)),
+ )
+ # Log points in 3D
+ filtered_pts = pts3d[valid_mask]
+ filtered_pts_col = image[valid_mask]
+ rr.log(
+ pts_name,
+ rr.Points3D(
+ positions=filtered_pts.reshape(-1, 3),
+ colors=filtered_pts_col.reshape(-1, 3),
+ ),
+ )
diff --git a/mapanything/datasets/wai/scannetpp.py b/mapanything/datasets/wai/scannetpp.py
new file mode 100644
index 0000000000000000000000000000000000000000..9e680dc0f7c5d4d300b3282c837ac4e48287527e
--- /dev/null
+++ b/mapanything/datasets/wai/scannetpp.py
@@ -0,0 +1,302 @@
+"""
+ScanNet++V2 Dataset using WAI format data.
+"""
+
+import os
+
+import numpy as np
+
+from mapanything.datasets.base.base_dataset import BaseDataset
+from wai import load_data, load_frame
+
+
+class ScanNetPPWAI(BaseDataset):
+ """
+ ScanNet++V2 dataset containing large diversity of indoor scenes.
+ """
+
+ def __init__(
+ self,
+ *args,
+ ROOT,
+ dataset_metadata_dir,
+ split,
+ overfit_num_sets=None,
+ sample_specific_scene: bool = False,
+ specific_scene_name: str = None,
+ **kwargs,
+ ):
+ """
+ Initialize the dataset attributes.
+ Args:
+ ROOT: Root directory of the dataset.
+ dataset_metadata_dir: Path to the dataset metadata directory.
+ split: Dataset split (train, val, test).
+ overfit_num_sets: If None, use all sets. Else, the dataset will be truncated to this number of sets.
+ sample_specific_scene: Whether to sample a specific scene from the dataset.
+ specific_scene_name: Name of the specific scene to sample.
+ """
+ # Initialize the dataset attributes
+ super().__init__(*args, **kwargs)
+ self.ROOT = ROOT
+ self.dataset_metadata_dir = dataset_metadata_dir
+ self.split = split
+ self.overfit_num_sets = overfit_num_sets
+ self.sample_specific_scene = sample_specific_scene
+ self.specific_scene_name = specific_scene_name
+ self._load_data()
+
+ # Define the dataset type flags
+ self.is_metric_scale = True
+ self.is_synthetic = False
+
+ def _load_data(self):
+ "Load the precomputed dataset metadata"
+ # Load the dataset metadata corresponding to the split
+ split_metadata_path = os.path.join(
+ self.dataset_metadata_dir,
+ self.split,
+ f"scannetppv2_scene_list_{self.split}.npy",
+ )
+ split_scene_list = np.load(split_metadata_path, allow_pickle=True)
+
+ # Get the list of all scenes
+ if not self.sample_specific_scene:
+ self.scenes = list(split_scene_list)
+ else:
+ self.scenes = [self.specific_scene_name]
+ self.num_of_scenes = len(self.scenes)
+
+ def _get_views(self, sampled_idx, num_views_to_sample, resolution):
+ # Get the scene name of the sampled index
+ scene_index = sampled_idx
+ scene_name = self.scenes[scene_index]
+
+ # Get the metadata corresponding to the scene
+ scene_root = os.path.join(self.ROOT, scene_name)
+ scene_meta = load_data(
+ os.path.join(scene_root, "scene_meta.json"), "scene_meta"
+ )
+ scene_file_names = list(scene_meta["frame_names"].keys())
+ num_views_in_scene = len(scene_file_names)
+
+ # Load the scene pairwise covisibility mmap
+ covisibility_version_key = "v0"
+ covisibility_map_dir = os.path.join(
+ scene_root, "covisibility", covisibility_version_key
+ )
+ # Assumes only npy file in directory is covisbility map
+ covisibility_map_name = next(
+ f for f in os.listdir(covisibility_map_dir) if f.endswith(".npy")
+ )
+ covisibility_map_path = os.path.join(
+ scene_root, "covisibility", covisibility_version_key, covisibility_map_name
+ )
+ pairwise_covisibility = load_data(covisibility_map_path, "mmap")
+
+ # Get the indices of the N views in the scene
+ view_indices = self._sample_view_indices(
+ num_views_to_sample, num_views_in_scene, pairwise_covisibility
+ )
+
+ # Get the views corresponding to the selected view indices
+ views = []
+ for view_index in view_indices:
+ # Load the data corresponding to the view
+ view_file_name = scene_file_names[view_index]
+ view_data = load_frame(
+ scene_root,
+ view_file_name,
+ modalities=["image", "rendered_depth"],
+ scene_meta=scene_meta,
+ )
+
+ # Convert necessary data to numpy
+ image = view_data["image"].permute(1, 2, 0).numpy()
+ image = (image * 255).astype(np.uint8)
+ depthmap = view_data["rendered_depth"].numpy().astype(np.float32)
+ intrinsics = view_data["intrinsics"].numpy().astype(np.float32)
+ c2w_pose = view_data["extrinsics"].numpy().astype(np.float32)
+
+ # Ensure that the depthmap has all valid values
+ depthmap = np.nan_to_num(depthmap, nan=0.0, posinf=0.0, neginf=0.0)
+
+ # Resize the data to match the desired resolution
+ image, depthmap, intrinsics = self._crop_resize_if_necessary(
+ image=image,
+ resolution=resolution,
+ depthmap=depthmap,
+ intrinsics=intrinsics,
+ additional_quantities=None,
+ )
+
+ # Append the view dictionary to the list of views
+ views.append(
+ dict(
+ img=image,
+ depthmap=depthmap,
+ camera_pose=c2w_pose, # cam2world
+ camera_intrinsics=intrinsics,
+ dataset="ScanNetPP",
+ label=scene_name,
+ instance=os.path.join("images", str(view_file_name)),
+ )
+ )
+
+ return views
+
+
+def get_parser():
+ import argparse
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "-rd", "--root_dir", default="/fsx/xrtech/data/scannetppv2", type=str
+ )
+ parser.add_argument(
+ "-dmd",
+ "--dataset_metadata_dir",
+ default="/fsx/nkeetha/mapanything_dataset_metadata",
+ type=str,
+ )
+ parser.add_argument(
+ "-nv",
+ "--num_of_views",
+ default=2,
+ type=int,
+ )
+ parser.add_argument("--viz", action="store_true")
+
+ return parser
+
+
+if __name__ == "__main__":
+ import rerun as rr
+ from tqdm import tqdm
+
+ from mapanything.datasets.base.base_dataset import view_name
+ from mapanything.utils.image import rgb
+ from mapanything.utils.viz import script_add_rerun_args
+
+ parser = get_parser()
+ script_add_rerun_args(
+ parser
+ ) # Options: --headless, --connect, --serve, --addr, --save, --stdout
+ args = parser.parse_args()
+
+ dataset = ScanNetPPWAI(
+ num_views=args.num_of_views,
+ split="train",
+ covisibility_thres=0.25,
+ ROOT=args.root_dir,
+ dataset_metadata_dir=args.dataset_metadata_dir,
+ resolution=(518, 336),
+ aug_crop=16,
+ transform="colorjitter+grayscale+gaublur",
+ data_norm_type="dinov2",
+ )
+ # dataset = ScanNetPPWAI(
+ # num_views=args.num_of_views,
+ # split="val",
+ # covisibility_thres=0.25,
+ # ROOT=args.root_dir,
+ # dataset_metadata_dir=args.dataset_metadata_dir,
+ # resolution=(518, 336),
+ # seed=777,
+ # transform="imgnorm",
+ # data_norm_type="dinov2",
+ # )
+ # dataset = ScanNetPPWAI(
+ # num_views=args.num_of_views,
+ # split="test",
+ # covisibility_thres=0.25,
+ # ROOT=args.root_dir,
+ # dataset_metadata_dir=args.dataset_metadata_dir,
+ # resolution=(518, 336),
+ # seed=777,
+ # transform="imgnorm",
+ # data_norm_type="dinov2",
+ # )
+ print(dataset.get_stats())
+
+ if args.viz:
+ rr.script_setup(args, "ScanNetPP_Dataloader")
+ rr.set_time("stable_time", sequence=0)
+ rr.log("world", rr.ViewCoordinates.RDF, static=True)
+
+ sampled_indices = np.random.choice(len(dataset), size=10, replace=False)
+
+ for num, idx in enumerate(tqdm(sampled_indices)):
+ views = dataset[idx]
+ assert len(views) == args.num_of_views
+ sample_name = f"{idx}"
+ for view_idx in range(args.num_of_views):
+ sample_name += f" {view_name(views[view_idx])}"
+ print(sample_name)
+ for view_idx in range(args.num_of_views):
+ image = rgb(
+ views[view_idx]["img"], norm_type=views[view_idx]["data_norm_type"]
+ )
+ depthmap = views[view_idx]["depthmap"]
+ pose = views[view_idx]["camera_pose"]
+ intrinsics = views[view_idx]["camera_intrinsics"]
+ pts3d = views[view_idx]["pts3d"]
+ valid_mask = views[view_idx]["valid_mask"]
+ if "non_ambiguous_mask" in views[view_idx]:
+ non_ambiguous_mask = views[view_idx]["non_ambiguous_mask"]
+ else:
+ non_ambiguous_mask = None
+ if "prior_depth_along_ray" in views[view_idx]:
+ prior_depth_along_ray = views[view_idx]["prior_depth_along_ray"]
+ else:
+ prior_depth_along_ray = None
+ if args.viz:
+ rr.set_time("stable_time", sequence=num)
+ base_name = f"world/view_{view_idx}"
+ pts_name = f"world/view_{view_idx}_pointcloud"
+ # Log camera info and loaded data
+ height, width = image.shape[0], image.shape[1]
+ rr.log(
+ base_name,
+ rr.Transform3D(
+ translation=pose[:3, 3],
+ mat3x3=pose[:3, :3],
+ ),
+ )
+ rr.log(
+ f"{base_name}/pinhole",
+ rr.Pinhole(
+ image_from_camera=intrinsics,
+ height=height,
+ width=width,
+ camera_xyz=rr.ViewCoordinates.RDF,
+ ),
+ )
+ rr.log(
+ f"{base_name}/pinhole/rgb",
+ rr.Image(image),
+ )
+ rr.log(
+ f"{base_name}/pinhole/depth",
+ rr.DepthImage(depthmap),
+ )
+ if prior_depth_along_ray is not None:
+ rr.log(
+ f"prior_depth_along_ray_{view_idx}",
+ rr.DepthImage(prior_depth_along_ray),
+ )
+ if non_ambiguous_mask is not None:
+ rr.log(
+ f"{base_name}/pinhole/non_ambiguous_mask",
+ rr.SegmentationImage(non_ambiguous_mask.astype(int)),
+ )
+ # Log points in 3D
+ filtered_pts = pts3d[valid_mask]
+ filtered_pts_col = image[valid_mask]
+ rr.log(
+ pts_name,
+ rr.Points3D(
+ positions=filtered_pts.reshape(-1, 3),
+ colors=filtered_pts_col.reshape(-1, 3),
+ ),
+ )
diff --git a/mapanything/datasets/wai/spring.py b/mapanything/datasets/wai/spring.py
new file mode 100644
index 0000000000000000000000000000000000000000..442584a294d818af3a8ac156d358623d63d6d491
--- /dev/null
+++ b/mapanything/datasets/wai/spring.py
@@ -0,0 +1,318 @@
+"""
+Spring Dataset using WAI format data.
+"""
+
+import os
+
+import cv2
+import numpy as np
+
+from mapanything.datasets.base.base_dataset import BaseDataset
+from wai import load_data, load_frame
+
+
+class SpringWAI(BaseDataset):
+ """
+ Spring dataset containing high-quality large-scale in-the-wild scenes with unique animated objects.
+ """
+
+ def __init__(
+ self,
+ *args,
+ ROOT,
+ dataset_metadata_dir,
+ split,
+ overfit_num_sets=None,
+ sample_specific_scene: bool = False,
+ specific_scene_name: str = None,
+ **kwargs,
+ ):
+ """
+ Initialize the dataset attributes.
+ Args:
+ ROOT: Root directory of the dataset.
+ dataset_metadata_dir: Path to the dataset metadata directory.
+ split: Dataset split (train, val, test).
+ overfit_num_sets: If None, use all sets. Else, the dataset will be truncated to this number of sets.
+ sample_specific_scene: Whether to sample a specific scene from the dataset.
+ specific_scene_name: Name of the specific scene to sample.
+ """
+ # Initialize the dataset attributes
+ super().__init__(*args, **kwargs)
+ self.ROOT = ROOT
+ self.dataset_metadata_dir = dataset_metadata_dir
+ self.split = split
+ self.overfit_num_sets = overfit_num_sets
+ self.sample_specific_scene = sample_specific_scene
+ self.specific_scene_name = specific_scene_name
+ self._load_data()
+
+ # Define the dataset type flags
+ self.is_metric_scale = True
+ self.is_synthetic = True
+
+ def _load_data(self):
+ "Load the precomputed dataset metadata"
+ # Load the dataset metadata corresponding to the split
+ split_metadata_path = os.path.join(
+ self.dataset_metadata_dir,
+ self.split,
+ f"spring_scene_list_{self.split}.npy",
+ )
+ split_scene_list = np.load(split_metadata_path, allow_pickle=True)
+
+ # Get the list of all scenes
+ if not self.sample_specific_scene:
+ self.scenes = list(split_scene_list)
+ else:
+ self.scenes = [self.specific_scene_name]
+ self.num_of_scenes = len(self.scenes)
+
+ def _get_views(self, sampled_idx, num_views_to_sample, resolution):
+ # Get the scene name of the sampled index
+ scene_index = sampled_idx
+ scene_name = self.scenes[scene_index]
+
+ # Get the metadata corresponding to the scene
+ scene_root = os.path.join(self.ROOT, scene_name)
+ scene_meta = load_data(
+ os.path.join(scene_root, "scene_meta.json"), "scene_meta"
+ )
+ scene_file_names = list(scene_meta["frame_names"].keys())
+ num_views_in_scene = len(scene_file_names)
+
+ # Load the scene pairwise covisibility mmap
+ covisibility_version_key = "v0"
+ covisibility_map_dir = os.path.join(
+ scene_root, "covisibility", covisibility_version_key
+ )
+ covisibility_map_name = next(
+ f for f in os.listdir(covisibility_map_dir) if f.endswith(".npy")
+ ) # Assumes only npy file in directory is covisbility map
+ covisibility_map_path = os.path.join(
+ scene_root, "covisibility", covisibility_version_key, covisibility_map_name
+ )
+ pairwise_covisibility = load_data(covisibility_map_path, "mmap")
+
+ # Get the indices of the N views in the scene
+ view_indices = self._sample_view_indices(
+ num_views_to_sample, num_views_in_scene, pairwise_covisibility
+ )
+
+ # Get the views corresponding to the selected view indices
+ views = []
+ for view_index in view_indices:
+ # Load the data corresponding to the view
+ view_file_name = scene_file_names[view_index]
+ view_data = load_frame(
+ scene_root,
+ view_file_name,
+ modalities=["image", "depth", "skymask"],
+ # modalities=["image", "depth", "skymask", "pred_mask/moge2"],
+ scene_meta=scene_meta,
+ )
+ ### HOTFIX: Load required additional masks manually
+ ### Remove once stability issue with scene_meta is fixed
+ mask_path = os.path.join(
+ scene_root, "moge", "v0", "mask", "moge2", f"{view_file_name}.png"
+ )
+ view_data["pred_mask/moge2"] = load_data(mask_path, "binary")
+
+ # Convert necessary data to numpy
+ image = view_data["image"].permute(1, 2, 0).numpy()
+ image = (image * 255).astype(np.uint8)
+ depthmap = view_data["depth"].numpy().astype(np.float32)
+ intrinsics = view_data["intrinsics"].numpy().astype(np.float32)
+ c2w_pose = view_data["extrinsics"].numpy().astype(np.float32)
+
+ # Get the sky mask and mask out GT depth
+ sky_mask = view_data["skymask"].numpy().astype(int)
+ depthmap = np.where(sky_mask, 0, depthmap)
+
+ # Ensure that the depthmap has all valid values
+ depthmap = np.nan_to_num(depthmap, nan=0.0, posinf=0.0, neginf=0.0)
+
+ # Get the non_ambiguous_mask and ensure it matches image resolution
+ non_ambiguous_mask = view_data["pred_mask/moge2"].numpy().astype(int)
+ non_ambiguous_mask = cv2.resize(
+ non_ambiguous_mask,
+ (image.shape[1], image.shape[0]),
+ interpolation=cv2.INTER_NEAREST,
+ )
+
+ # Mask out the GT depth using the non_ambiguous_mask
+ depthmap = np.where(non_ambiguous_mask, depthmap, 0)
+
+ # Resize the data to match the desired resolution
+ additional_quantities_to_resize = [non_ambiguous_mask]
+ image, depthmap, intrinsics, additional_quantities_to_resize = (
+ self._crop_resize_if_necessary(
+ image=image,
+ resolution=resolution,
+ depthmap=depthmap,
+ intrinsics=intrinsics,
+ additional_quantities=additional_quantities_to_resize,
+ )
+ )
+ non_ambiguous_mask = additional_quantities_to_resize[0]
+
+ # Append the view dictionary to the list of views
+ views.append(
+ dict(
+ img=image,
+ depthmap=depthmap,
+ camera_pose=c2w_pose, # cam2world
+ camera_intrinsics=intrinsics,
+ non_ambiguous_mask=non_ambiguous_mask,
+ dataset="Spring",
+ label=scene_name,
+ instance=os.path.join("images", str(view_file_name)),
+ )
+ )
+
+ return views
+
+
+def get_parser():
+ import argparse
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "-rd", "--root_dir", default="/fsx/xrtech/data/spring", type=str
+ )
+ parser.add_argument(
+ "-dmd",
+ "--dataset_metadata_dir",
+ default="/fsx/nkeetha/mapanything_dataset_metadata",
+ type=str,
+ )
+ parser.add_argument(
+ "-nv",
+ "--num_of_views",
+ default=2,
+ type=int,
+ )
+ parser.add_argument("--viz", action="store_true")
+
+ return parser
+
+
+if __name__ == "__main__":
+ import rerun as rr
+ from tqdm import tqdm
+
+ from mapanything.datasets.base.base_dataset import view_name
+ from mapanything.utils.image import rgb
+ from mapanything.utils.viz import script_add_rerun_args
+
+ parser = get_parser()
+ script_add_rerun_args(
+ parser
+ ) # Options: --headless, --connect, --serve, --addr, --save, --stdout
+ args = parser.parse_args()
+
+ dataset = SpringWAI(
+ num_views=args.num_of_views,
+ split="train",
+ covisibility_thres=0.25,
+ ROOT=args.root_dir,
+ dataset_metadata_dir=args.dataset_metadata_dir,
+ resolution=(518, 294),
+ aug_crop=16,
+ transform="colorjitter+grayscale+gaublur",
+ data_norm_type="dinov2",
+ )
+ # dataset = SpringWAI(
+ # num_views=args.num_of_views,
+ # split="val",
+ # covisibility_thres=0.25,
+ # ROOT=args.root_dir,
+ # dataset_metadata_dir=args.dataset_metadata_dir,
+ # resolution=(518, 294),
+ # seed=777,
+ # transform="imgnorm",
+ # data_norm_type="dinov2",
+ # )
+ print(dataset.get_stats())
+
+ if args.viz:
+ rr.script_setup(args, "Spring_Dataloader")
+ rr.set_time("stable_time", sequence=0)
+ rr.log("world", rr.ViewCoordinates.RDF, static=True)
+
+ sampled_indices = np.random.choice(len(dataset), size=10, replace=False)
+
+ for num, idx in enumerate(tqdm(sampled_indices)):
+ views = dataset[idx]
+ assert len(views) == args.num_of_views
+ sample_name = f"{idx}"
+ for view_idx in range(args.num_of_views):
+ sample_name += f" {view_name(views[view_idx])}"
+ print(sample_name)
+ for view_idx in range(args.num_of_views):
+ image = rgb(
+ views[view_idx]["img"], norm_type=views[view_idx]["data_norm_type"]
+ )
+ depthmap = views[view_idx]["depthmap"]
+ pose = views[view_idx]["camera_pose"]
+ intrinsics = views[view_idx]["camera_intrinsics"]
+ pts3d = views[view_idx]["pts3d"]
+ valid_mask = views[view_idx]["valid_mask"]
+ if "non_ambiguous_mask" in views[view_idx]:
+ non_ambiguous_mask = views[view_idx]["non_ambiguous_mask"]
+ else:
+ non_ambiguous_mask = None
+ if "prior_depth_along_ray" in views[view_idx]:
+ prior_depth_along_ray = views[view_idx]["prior_depth_along_ray"]
+ else:
+ prior_depth_along_ray = None
+ if args.viz:
+ rr.set_time("stable_time", sequence=num)
+ base_name = f"world/view_{view_idx}"
+ pts_name = f"world/view_{view_idx}_pointcloud"
+ # Log camera info and loaded data
+ height, width = image.shape[0], image.shape[1]
+ rr.log(
+ base_name,
+ rr.Transform3D(
+ translation=pose[:3, 3],
+ mat3x3=pose[:3, :3],
+ ),
+ )
+ rr.log(
+ f"{base_name}/pinhole",
+ rr.Pinhole(
+ image_from_camera=intrinsics,
+ height=height,
+ width=width,
+ camera_xyz=rr.ViewCoordinates.RDF,
+ ),
+ )
+ rr.log(
+ f"{base_name}/pinhole/rgb",
+ rr.Image(image),
+ )
+ rr.log(
+ f"{base_name}/pinhole/depth",
+ rr.DepthImage(depthmap),
+ )
+ if prior_depth_along_ray is not None:
+ rr.log(
+ f"prior_depth_along_ray_{view_idx}",
+ rr.DepthImage(prior_depth_along_ray),
+ )
+ if non_ambiguous_mask is not None:
+ rr.log(
+ f"{base_name}/pinhole/non_ambiguous_mask",
+ rr.SegmentationImage(non_ambiguous_mask.astype(int)),
+ )
+ # Log points in 3D
+ filtered_pts = pts3d[valid_mask]
+ filtered_pts_col = image[valid_mask]
+ rr.log(
+ pts_name,
+ rr.Points3D(
+ positions=filtered_pts.reshape(-1, 3),
+ colors=filtered_pts_col.reshape(-1, 3),
+ ),
+ )
diff --git a/mapanything/datasets/wai/structured3d.py b/mapanything/datasets/wai/structured3d.py
new file mode 100644
index 0000000000000000000000000000000000000000..4a1fbfe24bef3b59f281ab9c46212295149cca61
--- /dev/null
+++ b/mapanything/datasets/wai/structured3d.py
@@ -0,0 +1,292 @@
+"""
+Structured3D Dataset using WAI format data.
+"""
+
+import os
+
+import numpy as np
+
+from mapanything.datasets.base.base_dataset import BaseDataset
+from wai import load_data, load_frame
+
+
+class Structured3DWAI(BaseDataset):
+ """
+ Structured3D dataset containing large diversity of synthetic multi-room indoor scenes.
+ """
+
+ def __init__(
+ self,
+ *args,
+ ROOT,
+ dataset_metadata_dir,
+ split,
+ overfit_num_sets=None,
+ sample_specific_scene: bool = False,
+ specific_scene_name: str = None,
+ **kwargs,
+ ):
+ """
+ Initialize the dataset attributes.
+ Args:
+ ROOT: Root directory of the dataset.
+ dataset_metadata_dir: Path to the dataset metadata directory.
+ split: Dataset split (train, val, test).
+ overfit_num_sets: If None, use all sets. Else, the dataset will be truncated to this number of sets.
+ sample_specific_scene: Whether to sample a specific scene from the dataset.
+ specific_scene_name: Name of the specific scene to sample.
+ """
+ # Initialize the dataset attributes
+ super().__init__(*args, **kwargs)
+ self.ROOT = ROOT
+ self.dataset_metadata_dir = dataset_metadata_dir
+ self.split = split
+ self.overfit_num_sets = overfit_num_sets
+ self.sample_specific_scene = sample_specific_scene
+ self.specific_scene_name = specific_scene_name
+ self._load_data()
+
+ # Define the dataset type flags
+ self.is_metric_scale = True
+ self.is_synthetic = True
+
+ def _load_data(self):
+ "Load the precomputed dataset metadata"
+ # Load the dataset metadata corresponding to the split
+ split_metadata_path = os.path.join(
+ self.dataset_metadata_dir,
+ self.split,
+ f"structured3d_scene_list_{self.split}.npy",
+ )
+ split_scene_list = np.load(split_metadata_path, allow_pickle=True)
+
+ # Get the list of all scenes
+ if not self.sample_specific_scene:
+ self.scenes = list(split_scene_list)
+ else:
+ self.scenes = [self.specific_scene_name]
+ self.num_of_scenes = len(self.scenes)
+
+ def _get_views(self, sampled_idx, num_views_to_sample, resolution):
+ # Get the scene name of the sampled index
+ scene_index = sampled_idx
+ scene_name = self.scenes[scene_index]
+
+ # Get the metadata corresponding to the scene
+ scene_root = os.path.join(self.ROOT, scene_name)
+ scene_meta = load_data(
+ os.path.join(scene_root, "scene_meta.json"), "scene_meta"
+ )
+ scene_file_names = list(scene_meta["frame_names"].keys())
+ num_views_in_scene = len(scene_file_names)
+
+ # Load the scene pairwise covisibility mmap
+ covisibility_version_key = "v0"
+ covisibility_map_dir = os.path.join(
+ scene_root, "covisibility", covisibility_version_key
+ )
+ # Assumes only npy file in directory is covisbility map
+ covisibility_map_name = next(
+ f for f in os.listdir(covisibility_map_dir) if f.endswith(".npy")
+ )
+ covisibility_map_path = os.path.join(
+ scene_root, "covisibility", covisibility_version_key, covisibility_map_name
+ )
+ pairwise_covisibility = load_data(covisibility_map_path, "mmap")
+
+ # Get the indices of the N views in the scene
+ view_indices = self._sample_view_indices(
+ num_views_to_sample, num_views_in_scene, pairwise_covisibility
+ )
+
+ # Get the views corresponding to the selected view indices
+ views = []
+ for view_index in view_indices:
+ # Load the data corresponding to the view
+ view_file_name = scene_file_names[view_index]
+ view_data = load_frame(
+ scene_root,
+ view_file_name,
+ modalities=["image", "depth"],
+ scene_meta=scene_meta,
+ )
+
+ # Convert necessary data to numpy
+ image = view_data["image"].permute(1, 2, 0).numpy()
+ image = image[:, :, :3] # RGBA to RGB
+ image = (image * 255).astype(np.uint8)
+ depthmap = view_data["depth"].numpy().astype(np.float32)
+ intrinsics = view_data["intrinsics"].numpy().astype(np.float32)
+ c2w_pose = view_data["extrinsics"].numpy().astype(np.float32)
+
+ # Ensure that the depthmap has all valid values
+ depthmap = np.nan_to_num(depthmap, nan=0.0, posinf=0.0, neginf=0.0)
+
+ # Resize the data to match the desired resolution
+ image, depthmap, intrinsics = self._crop_resize_if_necessary(
+ image=image,
+ resolution=resolution,
+ depthmap=depthmap,
+ intrinsics=intrinsics,
+ additional_quantities=None,
+ )
+
+ # Append the view dictionary to the list of views
+ views.append(
+ dict(
+ img=image,
+ depthmap=depthmap,
+ camera_pose=c2w_pose, # cam2world
+ camera_intrinsics=intrinsics,
+ dataset="Structured3D",
+ label=scene_name,
+ instance=os.path.join("images", str(view_file_name)),
+ )
+ )
+
+ return views
+
+
+def get_parser():
+ import argparse
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "-rd", "--root_dir", default="/fsx/xrtech/data/structured3d", type=str
+ )
+ parser.add_argument(
+ "-dmd",
+ "--dataset_metadata_dir",
+ default="/fsx/nkeetha/mapanything_dataset_metadata",
+ type=str,
+ )
+ parser.add_argument(
+ "-nv",
+ "--num_of_views",
+ default=2,
+ type=int,
+ )
+ parser.add_argument("--viz", action="store_true")
+
+ return parser
+
+
+if __name__ == "__main__":
+ import rerun as rr
+ from tqdm import tqdm
+
+ from mapanything.datasets.base.base_dataset import view_name
+ from mapanything.utils.image import rgb
+ from mapanything.utils.viz import script_add_rerun_args
+
+ parser = get_parser()
+ script_add_rerun_args(
+ parser
+ ) # Options: --headless, --connect, --serve, --addr, --save, --stdout
+ args = parser.parse_args()
+
+ dataset = Structured3DWAI(
+ num_views=args.num_of_views,
+ split="train",
+ covisibility_thres=0.25,
+ ROOT=args.root_dir,
+ dataset_metadata_dir=args.dataset_metadata_dir,
+ resolution=(518, 294),
+ aug_crop=16,
+ transform="colorjitter+grayscale+gaublur",
+ data_norm_type="dinov2",
+ )
+ # dataset = Structured3DWAI(
+ # num_views=args.num_of_views,
+ # split="val",
+ # covisibility_thres=0.25,
+ # ROOT=args.root_dir,
+ # dataset_metadata_dir=args.dataset_metadata_dir,
+ # resolution=(518, 294),
+ # seed=777,
+ # transform="imgnorm",
+ # data_norm_type="dinov2",
+ # )
+ print(dataset.get_stats())
+
+ if args.viz:
+ rr.script_setup(args, "Structured3D_Dataloader")
+ rr.set_time("stable_time", sequence=0)
+ rr.log("world", rr.ViewCoordinates.RDF, static=True)
+
+ sampled_indices = np.random.choice(len(dataset), size=len(dataset), replace=False)
+
+ for num, idx in enumerate(tqdm(sampled_indices)):
+ views = dataset[idx]
+ assert len(views) == args.num_of_views
+ sample_name = f"{idx}"
+ for view_idx in range(args.num_of_views):
+ sample_name += f" {view_name(views[view_idx])}"
+ print(sample_name)
+ for view_idx in range(args.num_of_views):
+ image = rgb(
+ views[view_idx]["img"], norm_type=views[view_idx]["data_norm_type"]
+ )
+ depthmap = views[view_idx]["depthmap"]
+ pose = views[view_idx]["camera_pose"]
+ intrinsics = views[view_idx]["camera_intrinsics"]
+ pts3d = views[view_idx]["pts3d"]
+ valid_mask = views[view_idx]["valid_mask"]
+ if "non_ambiguous_mask" in views[view_idx]:
+ non_ambiguous_mask = views[view_idx]["non_ambiguous_mask"]
+ else:
+ non_ambiguous_mask = None
+ if "prior_depth_along_ray" in views[view_idx]:
+ prior_depth_along_ray = views[view_idx]["prior_depth_along_ray"]
+ else:
+ prior_depth_along_ray = None
+ if args.viz:
+ rr.set_time("stable_time", sequence=num)
+ base_name = f"world/view_{view_idx}"
+ pts_name = f"world/view_{view_idx}_pointcloud"
+ # Log camera info and loaded data
+ height, width = image.shape[0], image.shape[1]
+ rr.log(
+ base_name,
+ rr.Transform3D(
+ translation=pose[:3, 3],
+ mat3x3=pose[:3, :3],
+ ),
+ )
+ rr.log(
+ f"{base_name}/pinhole",
+ rr.Pinhole(
+ image_from_camera=intrinsics,
+ height=height,
+ width=width,
+ camera_xyz=rr.ViewCoordinates.RDF,
+ ),
+ )
+ rr.log(
+ f"{base_name}/pinhole/rgb",
+ rr.Image(image),
+ )
+ rr.log(
+ f"{base_name}/pinhole/depth",
+ rr.DepthImage(depthmap),
+ )
+ if prior_depth_along_ray is not None:
+ rr.log(
+ f"prior_depth_along_ray_{view_idx}",
+ rr.DepthImage(prior_depth_along_ray),
+ )
+ if non_ambiguous_mask is not None:
+ rr.log(
+ f"{base_name}/pinhole/non_ambiguous_mask",
+ rr.SegmentationImage(non_ambiguous_mask.astype(int)),
+ )
+ # Log points in 3D
+ filtered_pts = pts3d[valid_mask]
+ filtered_pts_col = image[valid_mask]
+ rr.log(
+ pts_name,
+ rr.Points3D(
+ positions=filtered_pts.reshape(-1, 3),
+ colors=filtered_pts_col.reshape(-1, 3),
+ ),
+ )
diff --git a/mapanything/datasets/wai/tav2_wb.py b/mapanything/datasets/wai/tav2_wb.py
new file mode 100644
index 0000000000000000000000000000000000000000..ad4e2e8cca4dbd88fafe347034028a4f13ced784
--- /dev/null
+++ b/mapanything/datasets/wai/tav2_wb.py
@@ -0,0 +1,330 @@
+"""
+TartanAirV2-WB Dataset using WAI format data.
+"""
+
+import os
+
+import cv2
+import numpy as np
+
+from mapanything.datasets.base.base_dataset import BaseDataset
+from wai import load_data, load_frame
+
+
+class TartanAirV2WBWAI(BaseDataset):
+ """
+ TartanAirV2-WB dataset containing vastly-sized in-the-wild synthetic scenes.
+ """
+
+ def __init__(
+ self,
+ *args,
+ ROOT,
+ dataset_metadata_dir,
+ split,
+ overfit_num_sets=None,
+ sample_specific_scene: bool = False,
+ specific_scene_name: str = None,
+ **kwargs,
+ ):
+ """
+ Initialize the dataset attributes.
+ Args:
+ ROOT: Root directory of the dataset.
+ dataset_metadata_dir: Path to the dataset metadata directory.
+ split: Dataset split (train, val, test).
+ overfit_num_sets: If None, use all sets. Else, the dataset will be truncated to this number of sets.
+ sample_specific_scene: Whether to sample a specific scene from the dataset.
+ specific_scene_name: Name of the specific scene to sample.
+ """
+ # Initialize the dataset attributes
+ super().__init__(*args, **kwargs)
+ self.ROOT = ROOT
+ self.dataset_metadata_dir = dataset_metadata_dir
+ self.split = split
+ self.overfit_num_sets = overfit_num_sets
+ self.sample_specific_scene = sample_specific_scene
+ self.specific_scene_name = specific_scene_name
+ self._load_data()
+
+ # Define the dataset type flags
+ self.is_metric_scale = True
+ self.is_synthetic = True
+
+ def _load_data(self):
+ "Load the precomputed dataset metadata"
+ # Load the dataset metadata corresponding to the split
+ split_metadata_path = os.path.join(
+ self.dataset_metadata_dir,
+ self.split,
+ f"tav2_wb_scene_list_{self.split}.npy",
+ )
+ split_scene_list = np.load(split_metadata_path, allow_pickle=True)
+
+ # Get the list of all scenes
+ if not self.sample_specific_scene:
+ self.scenes = list(split_scene_list)
+ else:
+ self.scenes = [self.specific_scene_name]
+ self.num_of_scenes = len(self.scenes)
+
+ def _get_views(self, sampled_idx, num_views_to_sample, resolution):
+ # Get the scene name of the sampled index
+ scene_index = sampled_idx
+ scene_name = self.scenes[scene_index]
+
+ # Get the metadata corresponding to the scene
+ scene_root = os.path.join(self.ROOT, scene_name)
+ scene_meta = load_data(
+ os.path.join(scene_root, "scene_meta.json"), "scene_meta"
+ )
+ scene_file_names = list(scene_meta["frame_names"].keys())
+ num_views_in_scene = len(scene_file_names)
+
+ # Load the scene pairwise covisibility mmap
+ covisibility_version_key = "v0"
+ covisibility_map_dir = os.path.join(
+ scene_root, "covisibility", covisibility_version_key
+ )
+ # Assumes only npy file in directory is covisbility map
+ covisibility_map_name = next(
+ f for f in os.listdir(covisibility_map_dir) if f.endswith(".npy")
+ )
+ covisibility_map_path = os.path.join(
+ scene_root, "covisibility", covisibility_version_key, covisibility_map_name
+ )
+ pairwise_covisibility = load_data(covisibility_map_path, "mmap")
+
+ # Get the indices of the N views in the scene
+ view_indices = self._sample_view_indices(
+ num_views_to_sample, num_views_in_scene, pairwise_covisibility
+ )
+
+ # Get the views corresponding to the selected view indices
+ views = []
+ for view_index in view_indices:
+ # Load the data corresponding to the view
+ view_file_name = scene_file_names[view_index]
+ view_data = load_frame(
+ scene_root,
+ view_file_name,
+ modalities=["image", "depth"],
+ # modalities=["image", "depth", "pred_mask/moge2"],
+ scene_meta=scene_meta,
+ )
+ ### HOTFIX: Load required additional masks manually
+ ### Remove once stability issue with scene_meta is fixed
+ mask_path = os.path.join(
+ scene_root, "moge", "v0", "mask", "moge2", f"{view_file_name}.png"
+ )
+ view_data["pred_mask/moge2"] = load_data(mask_path, "binary")
+
+ # Convert necessary data to numpy
+ image = view_data["image"].permute(1, 2, 0).numpy()
+ image = (image * 255).astype(np.uint8)
+ depthmap = view_data["depth"].numpy().astype(np.float32)
+ intrinsics = view_data["intrinsics"].numpy().astype(np.float32)
+ c2w_pose = view_data["extrinsics"].numpy().astype(np.float32)
+
+ # Ensure that the depthmap has all valid values
+ depthmap = np.nan_to_num(depthmap, nan=0.0, posinf=0.0, neginf=0.0)
+
+ # Mask out the outlier depth caused due to transparent windows in TartanAirV2
+ percentile_depth = np.percentile(depthmap, 95)
+ depthmap[depthmap > percentile_depth] = 0
+
+ # Get the non_ambiguous_mask and ensure it matches image resolution
+ non_ambiguous_mask = view_data["pred_mask/moge2"].numpy().astype(int)
+ non_ambiguous_mask = cv2.resize(
+ non_ambiguous_mask,
+ (image.shape[1], image.shape[0]),
+ interpolation=cv2.INTER_NEAREST,
+ )
+
+ # Mask out the GT depth using the non_ambiguous_mask
+ depthmap = np.where(non_ambiguous_mask, depthmap, 0)
+
+ # Resize the data to match the desired resolution
+ additional_quantities_to_resize = [non_ambiguous_mask]
+ image, depthmap, intrinsics, additional_quantities_to_resize = (
+ self._crop_resize_if_necessary(
+ image=image,
+ resolution=resolution,
+ depthmap=depthmap,
+ intrinsics=intrinsics,
+ additional_quantities=additional_quantities_to_resize,
+ )
+ )
+ non_ambiguous_mask = additional_quantities_to_resize[0]
+
+ # Append the view dictionary to the list of views
+ views.append(
+ dict(
+ img=image,
+ depthmap=depthmap,
+ camera_pose=c2w_pose, # cam2world
+ camera_intrinsics=intrinsics,
+ non_ambiguous_mask=non_ambiguous_mask,
+ dataset="TartanAirV2WB",
+ label=scene_name,
+ instance=os.path.join("images", str(view_file_name)),
+ )
+ )
+
+ return views
+
+
+def get_parser():
+ import argparse
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "-rd", "--root_dir", default="/fsx/xrtech/data/tav2_wb", type=str
+ )
+ parser.add_argument(
+ "-dmd",
+ "--dataset_metadata_dir",
+ default="/fsx/nkeetha/mapanything_dataset_metadata",
+ type=str,
+ )
+ parser.add_argument(
+ "-nv",
+ "--num_of_views",
+ default=2,
+ type=int,
+ )
+ parser.add_argument("--viz", action="store_true")
+
+ return parser
+
+
+if __name__ == "__main__":
+ import rerun as rr
+ from tqdm import tqdm
+
+ from mapanything.datasets.base.base_dataset import view_name
+ from mapanything.utils.image import rgb
+ from mapanything.utils.viz import script_add_rerun_args
+
+ parser = get_parser()
+ script_add_rerun_args(
+ parser
+ ) # Options: --headless, --connect, --serve, --addr, --save, --stdout
+ args = parser.parse_args()
+
+ dataset = TartanAirV2WBWAI(
+ num_views=args.num_of_views,
+ split="train",
+ covisibility_thres=0.25,
+ ROOT=args.root_dir,
+ dataset_metadata_dir=args.dataset_metadata_dir,
+ resolution=(518, 518),
+ aug_crop=16,
+ transform="colorjitter+grayscale+gaublur",
+ data_norm_type="dinov2",
+ )
+ # dataset = TartanAirV2WBWAI(
+ # num_views=args.num_of_views,
+ # split="val",
+ # covisibility_thres=0.25,
+ # ROOT=args.root_dir,
+ # dataset_metadata_dir=args.dataset_metadata_dir,
+ # resolution=(518, 518),
+ # seed=777,
+ # transform="imgnorm",
+ # data_norm_type="dinov2",
+ # )
+ # dataset = TartanAirV2WBWAI(
+ # num_views=args.num_of_views,
+ # split="test",
+ # covisibility_thres=0.25,
+ # ROOT=args.root_dir,
+ # dataset_metadata_dir=args.dataset_metadata_dir,
+ # resolution=(518, 518),
+ # seed=777,
+ # transform="imgnorm",
+ # data_norm_type="dinov2",
+ # )
+ print(dataset.get_stats())
+
+ if args.viz:
+ rr.script_setup(args, "TartanAirV2WB_Dataloader")
+ rr.set_time("stable_time", sequence=0)
+ rr.log("world", rr.ViewCoordinates.RDF, static=True)
+
+ sampled_indices = np.random.choice(len(dataset), size=len(dataset), replace=False)
+
+ for num, idx in enumerate(tqdm(sampled_indices)):
+ views = dataset[idx]
+ assert len(views) == args.num_of_views
+ sample_name = f"{idx}"
+ for view_idx in range(args.num_of_views):
+ sample_name += f" {view_name(views[view_idx])}"
+ print(sample_name)
+ for view_idx in range(args.num_of_views):
+ image = rgb(
+ views[view_idx]["img"], norm_type=views[view_idx]["data_norm_type"]
+ )
+ depthmap = views[view_idx]["depthmap"]
+ pose = views[view_idx]["camera_pose"]
+ intrinsics = views[view_idx]["camera_intrinsics"]
+ pts3d = views[view_idx]["pts3d"]
+ valid_mask = views[view_idx]["valid_mask"]
+ if "non_ambiguous_mask" in views[view_idx]:
+ non_ambiguous_mask = views[view_idx]["non_ambiguous_mask"]
+ else:
+ non_ambiguous_mask = None
+ if "prior_depth_along_ray" in views[view_idx]:
+ prior_depth_along_ray = views[view_idx]["prior_depth_along_ray"]
+ else:
+ prior_depth_along_ray = None
+ if args.viz:
+ rr.set_time("stable_time", sequence=num)
+ base_name = f"world/view_{view_idx}"
+ pts_name = f"world/view_{view_idx}_pointcloud"
+ # Log camera info and loaded data
+ height, width = image.shape[0], image.shape[1]
+ rr.log(
+ base_name,
+ rr.Transform3D(
+ translation=pose[:3, 3],
+ mat3x3=pose[:3, :3],
+ ),
+ )
+ rr.log(
+ f"{base_name}/pinhole",
+ rr.Pinhole(
+ image_from_camera=intrinsics,
+ height=height,
+ width=width,
+ camera_xyz=rr.ViewCoordinates.RDF,
+ ),
+ )
+ rr.log(
+ f"{base_name}/pinhole/rgb",
+ rr.Image(image),
+ )
+ rr.log(
+ f"{base_name}/pinhole/depth",
+ rr.DepthImage(depthmap),
+ )
+ if prior_depth_along_ray is not None:
+ rr.log(
+ f"prior_depth_along_ray_{view_idx}",
+ rr.DepthImage(prior_depth_along_ray),
+ )
+ if non_ambiguous_mask is not None:
+ rr.log(
+ f"{base_name}/pinhole/non_ambiguous_mask",
+ rr.SegmentationImage(non_ambiguous_mask.astype(int)),
+ )
+ # Log points in 3D
+ filtered_pts = pts3d[valid_mask]
+ filtered_pts_col = image[valid_mask]
+ rr.log(
+ pts_name,
+ rr.Points3D(
+ positions=filtered_pts.reshape(-1, 3),
+ colors=filtered_pts_col.reshape(-1, 3),
+ ),
+ )
diff --git a/mapanything/datasets/wai/unrealstereo4k.py b/mapanything/datasets/wai/unrealstereo4k.py
new file mode 100644
index 0000000000000000000000000000000000000000..05c5219d74c2b28aa0bda71b419bb02edf5394e8
--- /dev/null
+++ b/mapanything/datasets/wai/unrealstereo4k.py
@@ -0,0 +1,304 @@
+"""
+UnrealStereo4K Dataset using WAI format data.
+"""
+
+import os
+
+import numpy as np
+
+from mapanything.datasets.base.base_dataset import BaseDataset
+from wai import load_data, load_frame
+
+
+class UnrealStereo4KWAI(BaseDataset):
+ """
+ UnrealStereo4K dataset containing synthetic in-the-wild scenes.
+ """
+
+ def __init__(
+ self,
+ *args,
+ ROOT,
+ dataset_metadata_dir,
+ split,
+ overfit_num_sets=None,
+ sample_specific_scene: bool = False,
+ specific_scene_name: str = None,
+ **kwargs,
+ ):
+ """
+ Initialize the dataset attributes.
+ Args:
+ ROOT: Root directory of the dataset.
+ dataset_metadata_dir: Path to the dataset metadata directory.
+ split: Dataset split (train, val, test).
+ overfit_num_sets: If None, use all sets. Else, the dataset will be truncated to this number of sets.
+ sample_specific_scene: Whether to sample a specific scene from the dataset.
+ specific_scene_name: Name of the specific scene to sample.
+ """
+ # Initialize the dataset attributes
+ super().__init__(*args, **kwargs)
+ self.ROOT = ROOT
+ self.dataset_metadata_dir = dataset_metadata_dir
+ self.split = split
+ self.overfit_num_sets = overfit_num_sets
+ self.sample_specific_scene = sample_specific_scene
+ self.specific_scene_name = specific_scene_name
+ self._load_data()
+
+ # Define the dataset type flags
+ self.is_metric_scale = True
+ self.is_synthetic = True
+
+ def _load_data(self):
+ "Load the precomputed dataset metadata"
+ # Load the dataset metadata corresponding to the split
+ split_metadata_path = os.path.join(
+ self.dataset_metadata_dir,
+ self.split,
+ f"unrealstereo4k_scene_list_{self.split}.npy",
+ )
+ split_scene_list = np.load(split_metadata_path, allow_pickle=True)
+
+ # Get the list of all scenes
+ if not self.sample_specific_scene:
+ self.scenes = list(split_scene_list)
+ else:
+ self.scenes = [self.specific_scene_name]
+ self.num_of_scenes = len(self.scenes)
+
+ def _get_views(self, sampled_idx, num_views_to_sample, resolution):
+ # Get the scene name of the sampled index
+ scene_index = sampled_idx
+ scene_name = self.scenes[scene_index]
+
+ # Get the metadata corresponding to the scene
+ scene_root = os.path.join(self.ROOT, scene_name)
+ scene_meta = load_data(
+ os.path.join(scene_root, "scene_meta.json"), "scene_meta"
+ )
+ scene_file_names = list(scene_meta["frame_names"].keys())
+ num_views_in_scene = len(scene_file_names)
+
+ # Load the scene pairwise covisibility mmap
+ covisibility_version_key = "v0"
+ covisibility_map_dir = os.path.join(
+ scene_root, "covisibility", covisibility_version_key
+ )
+ # Assumes only npy file in directory is covisbility map
+ covisibility_map_name = next(
+ f for f in os.listdir(covisibility_map_dir) if f.endswith(".npy")
+ )
+ covisibility_map_path = os.path.join(
+ scene_root, "covisibility", covisibility_version_key, covisibility_map_name
+ )
+ pairwise_covisibility = load_data(covisibility_map_path, "mmap")
+
+ # Get the indices of the N views in the scene
+ view_indices = self._sample_view_indices(
+ num_views_to_sample, num_views_in_scene, pairwise_covisibility
+ )
+
+ # Get the views corresponding to the selected view indices
+ views = []
+ for view_index in view_indices:
+ # Load the data corresponding to the view
+ view_file_name = scene_file_names[view_index]
+ view_data = load_frame(
+ scene_root,
+ view_file_name,
+ modalities=["image", "depth"],
+ scene_meta=scene_meta,
+ )
+
+ # Convert necessary data to numpy
+ image = view_data["image"].permute(1, 2, 0).numpy()
+ image = image[:, :, :3] # RGBA to RGB
+ image = (image * 255).astype(np.uint8)
+ depthmap = view_data["depth"].numpy().astype(np.float32)
+ intrinsics = view_data["intrinsics"].numpy().astype(np.float32)
+ c2w_pose = view_data["extrinsics"].numpy().astype(np.float32)
+
+ # Ensure that the depthmap has all valid values
+ depthmap = np.nan_to_num(depthmap, nan=0.0, posinf=0.0, neginf=0.0)
+
+ # Get the non ambiguous mask (zero depth pixels are sky or ambiguous)
+ non_ambiguous_mask = (depthmap > 0).astype(int)
+
+ # Mask out the outlier depth (horizon depth)
+ percentile_depth = np.percentile(depthmap, 95)
+ depthmap[depthmap > percentile_depth] = 0
+
+ # Resize the data to match the desired resolution
+ additional_quantities_to_resize = [non_ambiguous_mask]
+ image, depthmap, intrinsics, additional_quantities_to_resize = (
+ self._crop_resize_if_necessary(
+ image=image,
+ resolution=resolution,
+ depthmap=depthmap,
+ intrinsics=intrinsics,
+ additional_quantities=additional_quantities_to_resize,
+ )
+ )
+ non_ambiguous_mask = additional_quantities_to_resize[0]
+
+ # Append the view dictionary to the list of views
+ views.append(
+ dict(
+ img=image,
+ depthmap=depthmap,
+ camera_pose=c2w_pose, # cam2world
+ camera_intrinsics=intrinsics,
+ non_ambiguous_mask=non_ambiguous_mask,
+ dataset="UnrealStereo4K",
+ label=scene_name,
+ instance=os.path.join("images", str(view_file_name)),
+ )
+ )
+
+ return views
+
+
+def get_parser():
+ import argparse
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "-rd", "--root_dir", default="/fsx/xrtech/data/unrealstereo4k", type=str
+ )
+ parser.add_argument(
+ "-dmd",
+ "--dataset_metadata_dir",
+ default="/fsx/nkeetha/mapanything_dataset_metadata",
+ type=str,
+ )
+ parser.add_argument(
+ "-nv",
+ "--num_of_views",
+ default=2,
+ type=int,
+ )
+ parser.add_argument("--viz", action="store_true")
+
+ return parser
+
+
+if __name__ == "__main__":
+ import rerun as rr
+ from tqdm import tqdm
+
+ from mapanything.datasets.base.base_dataset import view_name
+ from mapanything.utils.image import rgb
+ from mapanything.utils.viz import script_add_rerun_args
+
+ parser = get_parser()
+ script_add_rerun_args(
+ parser
+ ) # Options: --headless, --connect, --serve, --addr, --save, --stdout
+ args = parser.parse_args()
+
+ dataset = UnrealStereo4KWAI(
+ num_views=args.num_of_views,
+ split="train",
+ covisibility_thres=0.25,
+ ROOT=args.root_dir,
+ dataset_metadata_dir=args.dataset_metadata_dir,
+ resolution=(518, 294),
+ aug_crop=16,
+ transform="colorjitter+grayscale+gaublur",
+ data_norm_type="dinov2",
+ )
+ # dataset = UnrealStereo4KWAI(
+ # num_views=args.num_of_views,
+ # split="val",
+ # covisibility_thres=0.25,
+ # ROOT=args.root_dir,
+ # dataset_metadata_dir=args.dataset_metadata_dir,
+ # resolution=(518, 294),
+ # seed=777,
+ # transform="imgnorm",
+ # data_norm_type="dinov2",
+ # )
+ print(dataset.get_stats())
+
+ if args.viz:
+ rr.script_setup(args, "UnrealStereo4K_Dataloader")
+ rr.set_time("stable_time", sequence=0)
+ rr.log("world", rr.ViewCoordinates.RDF, static=True)
+
+ sampled_indices = np.random.choice(len(dataset), size=len(dataset), replace=False)
+
+ for num, idx in enumerate(tqdm(sampled_indices)):
+ views = dataset[idx]
+ assert len(views) == args.num_of_views
+ sample_name = f"{idx}"
+ for view_idx in range(args.num_of_views):
+ sample_name += f" {view_name(views[view_idx])}"
+ print(sample_name)
+ for view_idx in range(args.num_of_views):
+ image = rgb(
+ views[view_idx]["img"], norm_type=views[view_idx]["data_norm_type"]
+ )
+ depthmap = views[view_idx]["depthmap"]
+ pose = views[view_idx]["camera_pose"]
+ intrinsics = views[view_idx]["camera_intrinsics"]
+ pts3d = views[view_idx]["pts3d"]
+ valid_mask = views[view_idx]["valid_mask"]
+ if "non_ambiguous_mask" in views[view_idx]:
+ non_ambiguous_mask = views[view_idx]["non_ambiguous_mask"]
+ else:
+ non_ambiguous_mask = None
+ if "prior_depth_along_ray" in views[view_idx]:
+ prior_depth_along_ray = views[view_idx]["prior_depth_along_ray"]
+ else:
+ prior_depth_along_ray = None
+ if args.viz:
+ rr.set_time("stable_time", sequence=num)
+ base_name = f"world/view_{view_idx}"
+ pts_name = f"world/view_{view_idx}_pointcloud"
+ # Log camera info and loaded data
+ height, width = image.shape[0], image.shape[1]
+ rr.log(
+ base_name,
+ rr.Transform3D(
+ translation=pose[:3, 3],
+ mat3x3=pose[:3, :3],
+ ),
+ )
+ rr.log(
+ f"{base_name}/pinhole",
+ rr.Pinhole(
+ image_from_camera=intrinsics,
+ height=height,
+ width=width,
+ camera_xyz=rr.ViewCoordinates.RDF,
+ ),
+ )
+ rr.log(
+ f"{base_name}/pinhole/rgb",
+ rr.Image(image),
+ )
+ rr.log(
+ f"{base_name}/pinhole/depth",
+ rr.DepthImage(depthmap),
+ )
+ if prior_depth_along_ray is not None:
+ rr.log(
+ f"prior_depth_along_ray_{view_idx}",
+ rr.DepthImage(prior_depth_along_ray),
+ )
+ if non_ambiguous_mask is not None:
+ rr.log(
+ f"{base_name}/pinhole/non_ambiguous_mask",
+ rr.SegmentationImage(non_ambiguous_mask.astype(int)),
+ )
+ # Log points in 3D
+ filtered_pts = pts3d[valid_mask]
+ filtered_pts_col = image[valid_mask]
+ rr.log(
+ pts_name,
+ rr.Points3D(
+ positions=filtered_pts.reshape(-1, 3),
+ colors=filtered_pts_col.reshape(-1, 3),
+ ),
+ )
diff --git a/mapanything/datasets/wai/xrooms.py b/mapanything/datasets/wai/xrooms.py
new file mode 100644
index 0000000000000000000000000000000000000000..c5128ffd7180924e30f10134b54f27896003999f
--- /dev/null
+++ b/mapanything/datasets/wai/xrooms.py
@@ -0,0 +1,300 @@
+"""
+XRooms Dataset using WAI format data.
+"""
+
+import os
+
+import numpy as np
+
+from mapanything.datasets.base.base_dataset import BaseDataset
+from wai import load_data, load_frame
+
+
+class XRoomsWAI(BaseDataset):
+ """
+ XRooms dataset containing large diversity of synthetic re-lightable indoor scenes.
+ """
+
+ def __init__(
+ self,
+ *args,
+ ROOT,
+ dataset_metadata_dir,
+ split,
+ overfit_num_sets=None,
+ sample_specific_scene: bool = False,
+ specific_scene_name: str = None,
+ **kwargs,
+ ):
+ """
+ Initialize the dataset attributes.
+ Args:
+ ROOT: Root directory of the dataset.
+ dataset_metadata_dir: Path to the dataset metadata directory.
+ split: Dataset split (train, val, test).
+ overfit_num_sets: If None, use all sets. Else, the dataset will be truncated to this number of sets.
+ sample_specific_scene: Whether to sample a specific scene from the dataset.
+ specific_scene_name: Name of the specific scene to sample.
+ """
+ # Initialize the dataset attributes
+ super().__init__(*args, **kwargs)
+ self.ROOT = ROOT
+ self.dataset_metadata_dir = dataset_metadata_dir
+ self.split = split
+ self.overfit_num_sets = overfit_num_sets
+ self.sample_specific_scene = sample_specific_scene
+ self.specific_scene_name = specific_scene_name
+ self._load_data()
+
+ # Define the dataset type flags
+ self.is_metric_scale = True
+ self.is_synthetic = True
+
+ def _load_data(self):
+ "Load the precomputed dataset metadata"
+ # Load the dataset metadata corresponding to the split
+ split_metadata_path = os.path.join(
+ self.dataset_metadata_dir,
+ self.split,
+ f"xrooms_scene_list_{self.split}.npy",
+ )
+ split_scene_list = np.load(split_metadata_path, allow_pickle=True)
+
+ # Get the list of all scenes
+ if not self.sample_specific_scene:
+ self.scenes = list(split_scene_list)
+ else:
+ self.scenes = [self.specific_scene_name]
+ self.num_of_scenes = len(self.scenes)
+
+ def _get_views(self, sampled_idx, num_views_to_sample, resolution):
+ # Get the scene name of the sampled index
+ scene_index = sampled_idx
+ scene_name = self.scenes[scene_index]
+
+ # Get the metadata corresponding to the scene
+ scene_root = os.path.join(self.ROOT, scene_name)
+ scene_meta = load_data(
+ os.path.join(scene_root, "scene_meta.json"), "scene_meta"
+ )
+ scene_file_names = list(scene_meta["frame_names"].keys())
+ num_views_in_scene = len(scene_file_names)
+
+ # Load the scene pairwise covisibility mmap
+ covisibility_version_key = "v0"
+ covisibility_map_dir = os.path.join(
+ scene_root, "covisibility", covisibility_version_key
+ )
+ # Assumes only npy file in directory is covisbility map
+ covisibility_map_name = next(
+ f for f in os.listdir(covisibility_map_dir) if f.endswith(".npy")
+ )
+ covisibility_map_path = os.path.join(
+ scene_root, "covisibility", covisibility_version_key, covisibility_map_name
+ )
+ pairwise_covisibility = load_data(covisibility_map_path, "mmap")
+
+ ### HOTFIX HACK for incompatible covisibility in a few scenes
+ ### TODO: Re-mine covisibility on errorenous scenes
+ if len(pairwise_covisibility) == num_views_in_scene:
+ # Get the indices of the N views in the scene
+ view_indices = self._sample_view_indices(
+ num_views_to_sample, num_views_in_scene, pairwise_covisibility
+ )
+ else:
+ # Get a random view index
+ view_indices = self._rng.choice(num_views_in_scene, size=1, replace=False)
+ # Repeat the view index to get the desired number of views
+ view_indices = np.repeat(view_indices, num_views_to_sample)
+ ### END HOTFIX HACK
+
+ # Get the views corresponding to the selected view indices
+ views = []
+ for view_index in view_indices:
+ # Load the data corresponding to the view
+ view_file_name = scene_file_names[view_index]
+ view_data = load_frame(
+ scene_root,
+ view_file_name,
+ modalities=["image", "depth"],
+ scene_meta=scene_meta,
+ )
+
+ # Convert necessary data to numpy
+ image = view_data["image"].permute(1, 2, 0).numpy()
+ image = (image * 255).astype(np.uint8)
+ depthmap = view_data["depth"].numpy().astype(np.float32)
+ intrinsics = view_data["intrinsics"].numpy().astype(np.float32)
+ c2w_pose = view_data["extrinsics"].numpy().astype(np.float32)
+
+ # Ensure that the depthmap has all valid values
+ depthmap = np.nan_to_num(depthmap, nan=0.0, posinf=0.0, neginf=0.0)
+
+ # Resize the data to match the desired resolution
+ image, depthmap, intrinsics = self._crop_resize_if_necessary(
+ image=image,
+ resolution=resolution,
+ depthmap=depthmap,
+ intrinsics=intrinsics,
+ additional_quantities=None,
+ )
+
+ # Append the view dictionary to the list of views
+ views.append(
+ dict(
+ img=image,
+ depthmap=depthmap,
+ camera_pose=c2w_pose, # cam2world
+ camera_intrinsics=intrinsics,
+ dataset="XRooms",
+ label=scene_name,
+ instance=os.path.join("images", str(view_file_name)),
+ )
+ )
+
+ return views
+
+
+def get_parser():
+ import argparse
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "-rd", "--root_dir", default="/fsx/xrtech/data/xrooms", type=str
+ )
+ parser.add_argument(
+ "-dmd",
+ "--dataset_metadata_dir",
+ default="/fsx/nkeetha/mapanything_dataset_metadata",
+ type=str,
+ )
+ parser.add_argument(
+ "-nv",
+ "--num_of_views",
+ default=2,
+ type=int,
+ )
+ parser.add_argument("--viz", action="store_true")
+
+ return parser
+
+
+if __name__ == "__main__":
+ import rerun as rr
+ from tqdm import tqdm
+
+ from mapanything.datasets.base.base_dataset import view_name
+ from mapanything.utils.image import rgb
+ from mapanything.utils.viz import script_add_rerun_args
+
+ parser = get_parser()
+ script_add_rerun_args(
+ parser
+ ) # Options: --headless, --connect, --serve, --addr, --save, --stdout
+ args = parser.parse_args()
+
+ dataset = XRoomsWAI(
+ num_views=args.num_of_views,
+ split="train",
+ covisibility_thres=0.25,
+ ROOT=args.root_dir,
+ dataset_metadata_dir=args.dataset_metadata_dir,
+ resolution=(518, 518),
+ aug_crop=16,
+ transform="colorjitter+grayscale+gaublur",
+ data_norm_type="dinov2",
+ )
+ # dataset = XRoomsWAI(
+ # num_views=args.num_of_views,
+ # split="val",
+ # covisibility_thres=0.25,
+ # ROOT=args.root_dir,
+ # dataset_metadata_dir=args.dataset_metadata_dir,
+ # resolution=(518, 518),
+ # seed=777,
+ # transform="imgnorm",
+ # data_norm_type="dinov2",
+ # )
+ print(dataset.get_stats())
+
+ if args.viz:
+ rr.script_setup(args, "XRooms_Dataloader")
+ rr.set_time("stable_time", sequence=0)
+ rr.log("world", rr.ViewCoordinates.RDF, static=True)
+
+ sampled_indices = np.random.choice(len(dataset), size=10, replace=False)
+
+ for num, idx in enumerate(tqdm(sampled_indices)):
+ views = dataset[idx]
+ assert len(views) == args.num_of_views
+ sample_name = f"{idx}"
+ for view_idx in range(args.num_of_views):
+ sample_name += f" {view_name(views[view_idx])}"
+ print(sample_name)
+ for view_idx in range(args.num_of_views):
+ image = rgb(
+ views[view_idx]["img"], norm_type=views[view_idx]["data_norm_type"]
+ )
+ depthmap = views[view_idx]["depthmap"]
+ pose = views[view_idx]["camera_pose"]
+ intrinsics = views[view_idx]["camera_intrinsics"]
+ pts3d = views[view_idx]["pts3d"]
+ valid_mask = views[view_idx]["valid_mask"]
+ if "non_ambiguous_mask" in views[view_idx]:
+ non_ambiguous_mask = views[view_idx]["non_ambiguous_mask"]
+ else:
+ non_ambiguous_mask = None
+ if "prior_depth_along_ray" in views[view_idx]:
+ prior_depth_along_ray = views[view_idx]["prior_depth_along_ray"]
+ else:
+ prior_depth_along_ray = None
+ if args.viz:
+ rr.set_time("stable_time", sequence=num)
+ base_name = f"world/view_{view_idx}"
+ pts_name = f"world/view_{view_idx}_pointcloud"
+ # Log camera info and loaded data
+ height, width = image.shape[0], image.shape[1]
+ rr.log(
+ base_name,
+ rr.Transform3D(
+ translation=pose[:3, 3],
+ mat3x3=pose[:3, :3],
+ ),
+ )
+ rr.log(
+ f"{base_name}/pinhole",
+ rr.Pinhole(
+ image_from_camera=intrinsics,
+ height=height,
+ width=width,
+ camera_xyz=rr.ViewCoordinates.RDF,
+ ),
+ )
+ rr.log(
+ f"{base_name}/pinhole/rgb",
+ rr.Image(image),
+ )
+ rr.log(
+ f"{base_name}/pinhole/depth",
+ rr.DepthImage(depthmap),
+ )
+ if prior_depth_along_ray is not None:
+ rr.log(
+ f"prior_depth_along_ray_{view_idx}",
+ rr.DepthImage(prior_depth_along_ray),
+ )
+ if non_ambiguous_mask is not None:
+ rr.log(
+ f"{base_name}/pinhole/non_ambiguous_mask",
+ rr.SegmentationImage(non_ambiguous_mask.astype(int)),
+ )
+ # Log points in 3D
+ filtered_pts = pts3d[valid_mask]
+ filtered_pts_col = image[valid_mask]
+ rr.log(
+ pts_name,
+ rr.Points3D(
+ positions=filtered_pts.reshape(-1, 3),
+ colors=filtered_pts_col.reshape(-1, 3),
+ ),
+ )
diff --git a/mapanything/models/__init__.py b/mapanything/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..f1b9de65a1cb8ff5cb15fab527bc5a7194896f8a
--- /dev/null
+++ b/mapanything/models/__init__.py
@@ -0,0 +1,185 @@
+"""
+Model Factory for MapAnything
+"""
+
+import importlib.util
+import logging
+import warnings
+
+import numpy as np
+from omegaconf import DictConfig, OmegaConf
+
+# Core models that are always available
+from mapanything.models.mapanything import (
+ MapAnything,
+ MapAnythingAblations,
+ ModularDUSt3R,
+)
+
+# Suppress DINOv2 warnings
+logging.getLogger("dinov2").setLevel(logging.WARNING)
+warnings.filterwarnings("ignore", message="xFormers is available", category=UserWarning)
+warnings.filterwarnings(
+ "ignore", message="xFormers is not available", category=UserWarning
+)
+
+
+def resolve_special_float(value):
+ if value == "inf":
+ return np.inf
+ elif value == "-inf":
+ return -np.inf
+ else:
+ raise ValueError(f"Unknown special float value: {value}")
+
+
+def init_model(
+ model_str: str, model_config: DictConfig, torch_hub_force_reload: bool = False
+):
+ """
+ Initialize a model using OmegaConf configuration.
+
+ Args:
+ model_str (str): Name of the model class to create.
+ model_config (DictConfig): OmegaConf model configuration.
+ torch_hub_force_reload (bool): Whether to force reload relevant parts of the model from torch hub.
+ """
+ if not OmegaConf.has_resolver("special_float"):
+ OmegaConf.register_new_resolver("special_float", resolve_special_float)
+ model_dict = OmegaConf.to_container(model_config, resolve=True)
+ model = model_factory(
+ model_str, torch_hub_force_reload=torch_hub_force_reload, **model_dict
+ )
+
+ return model
+
+
+# Define model configurations with import paths
+MODEL_CONFIGS = {
+ # Core models
+ "mapanything": {
+ "class": MapAnything,
+ },
+ "mapanything_ablations": {
+ "class": MapAnythingAblations,
+ },
+ "modular_dust3r": {
+ "class": ModularDUSt3R,
+ },
+ # External models
+ "anycalib": {
+ "module": "mapanything.models.external.anycalib",
+ "class_name": "AnyCalibWrapper",
+ },
+ "dust3r": {
+ "module": "mapanything.models.external.dust3r",
+ "class_name": "DUSt3RBAWrapper",
+ },
+ "mast3r": {
+ "module": "mapanything.models.external.mast3r",
+ "class_name": "MASt3RSGAWrapper",
+ },
+ "moge": {
+ "module": "mapanything.models.external.moge",
+ "class_name": "MoGeWrapper",
+ },
+ "must3r": {
+ "module": "mapanything.models.external.must3r",
+ "class_name": "MUSt3RWrapper",
+ },
+ "pi3": {
+ "module": "mapanything.models.external.pi3",
+ "class_name": "Pi3Wrapper",
+ },
+ "pow3r": {
+ "module": "mapanything.models.external.pow3r",
+ "class_name": "Pow3RWrapper",
+ },
+ "pow3r_ba": {
+ "module": "mapanything.models.external.pow3r",
+ "class_name": "Pow3RBAWrapper",
+ },
+ "vggt": {
+ "module": "mapanything.models.external.vggt",
+ "class_name": "VGGTWrapper",
+ },
+ # Add other model classes here
+}
+
+
+def check_module_exists(module_path):
+ """
+ Check if a module can be imported without actually importing it.
+
+ Args:
+ module_path (str): The path to the module to check.
+
+ Returns:
+ bool: True if the module can be imported, False otherwise.
+ """
+ return importlib.util.find_spec(module_path) is not None
+
+
+def model_factory(model_str: str, **kwargs):
+ """
+ Model factory for MapAnything.
+
+ Args:
+ model_str (str): Name of the model to create.
+ **kwargs: Additional keyword arguments to pass to the model constructor.
+
+ Returns:
+ nn.Module: An instance of the specified model.
+ """
+ if model_str not in MODEL_CONFIGS:
+ raise ValueError(
+ f"Unknown model: {model_str}. Valid options are: {', '.join(MODEL_CONFIGS.keys())}"
+ )
+
+ model_config = MODEL_CONFIGS[model_str]
+
+ # Handle core models directly
+ if "class" in model_config:
+ model_class = model_config["class"]
+ # Handle external models with dynamic imports
+ elif "module" in model_config:
+ module_path = model_config["module"]
+ class_name = model_config["class_name"]
+
+ # Check if the module can be imported
+ if not check_module_exists(module_path):
+ raise ImportError(
+ f"Model '{model_str}' requires module '{module_path}' which is not installed. "
+ f"Please install the corresponding submodule or package."
+ )
+
+ # Dynamically import the module and get the class
+ try:
+ module = importlib.import_module(module_path)
+ model_class = getattr(module, class_name)
+ except (ImportError, AttributeError) as e:
+ raise ImportError(
+ f"Failed to import {class_name} from {module_path}: {str(e)}"
+ )
+ else:
+ raise ValueError(f"Invalid model configuration for {model_str}")
+
+ print(f"Initializing {model_class} with kwargs: {kwargs}")
+ if model_str != "org_dust3r":
+ return model_class(**kwargs)
+ else:
+ eval_str = kwargs.get("model_eval_str", None)
+ return eval(eval_str)
+
+
+def get_available_models() -> list:
+ """
+ Get a list of available models in MapAnything.
+
+ Returns:
+ list: A list of available model names.
+ """
+ return list(MODEL_CONFIGS.keys())
+
+
+__all__ = ["model_factory", "get_available_models"]
diff --git a/mapanything/models/external/__init__.py b/mapanything/models/external/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/mapanything/models/external/anycalib/__init__.py b/mapanything/models/external/anycalib/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..eb661bead9d4c7eb4e9ed57eca4736e7c42ac7b5
--- /dev/null
+++ b/mapanything/models/external/anycalib/__init__.py
@@ -0,0 +1,95 @@
+"""
+Inference wrapper for AnyCalib
+"""
+
+import torch
+from anycalib import AnyCalib
+
+from mapanything.utils.geometry import get_rays_in_camera_frame
+
+
+class AnyCalibWrapper(torch.nn.Module):
+ def __init__(
+ self,
+ name,
+ model_id="anycalib_pinhole",
+ **kwargs,
+ ):
+ super().__init__()
+ self.name = name
+ self.model_id = model_id
+
+ # Initialize the model
+ self.model = AnyCalib(model_id=self.model_id)
+
+ def forward(self, views):
+ """
+ Forward pass wrapper for AnyCalib.
+
+ Assumption:
+ - The number of input views is 1.
+ - The output camera model is pinhole (fx, fy, cx, cy).
+ This can be relaxed by not hardcoding the cam_id.
+
+ Args:
+ views (List[dict]): List of dictionaries containing the input views' images and instance information.
+ Length of the list should be 1.
+ Each dictionary should contain the following keys:
+ "img" (tensor): Image tensor of shape (B, C, H, W).
+ "data_norm_type" (list): ["identity"]
+
+ Returns:
+ List[dict]: A list containing the final outputs for the single view. Length of the list will be 1.
+ """
+ # Check that the number of input views is 1
+ assert len(views) == 1, "AnyCalib only supports 1 input view."
+
+ # Get input shape of the images and batch size per view
+ _, _, height, width = views[0]["img"].shape
+
+ # Check the data norm type
+ # AnyCalib expects a normalized image but without the DINOv2 mean and std applied ("identity")
+ data_norm_type = views[0]["data_norm_type"][0]
+ assert data_norm_type == "identity", (
+ "AnyCalib expects a normalized image but without the DINOv2 mean and std applied"
+ )
+
+ # Run AnyCalib inference
+ # Corresponding batched output dictionary:
+ # {
+ # "intrinsics": List[(D_i,) tensors] for each camera model "i" at the original input resolution,
+ # "fov_field": (B, N, 2) tensor with the regressed FoV field by the network. N≈320^2 (resolution close to the one seen during training),
+ # "tangent_coords": alias for "fov_field",
+ # "rays": (B, N, 3) tensor with the corresponding (via the exponential map) ray directions in the camera frame (x right, y down, z forward),
+ # "pred_size": (H, W) tuple with the image size used by the network. It can be used e.g. for resizing the FoV/ray fields to the original image size.
+ # }
+ # For "pinhole" camera model, the intrinsics are (fx, fy, cx, cy).
+ model_outputs = self.model.predict(views[0]["img"], cam_id="pinhole")
+
+ # Convert the list of intrinsics to a tensor
+ intrinsics = []
+ for intrinsics_per_sample in model_outputs["intrinsics"]:
+ pred_fx, pred_fy, pred_cx, pred_cy = intrinsics_per_sample
+ intrinsics_per_sample = torch.tensor(
+ [
+ [pred_fx, 0, pred_cx],
+ [0, pred_fy, pred_cy],
+ [0, 0, 1],
+ ],
+ device=views[0]["img"].device,
+ )
+ intrinsics.append(intrinsics_per_sample)
+
+ # Convert the list of intrinsics to a tensor of size (batch_size_per_view, 3, 3)
+ intrinsics = torch.stack(intrinsics)
+
+ # Get the ray directions
+ with torch.autocast("cuda", enabled=False):
+ _, ray_directions = get_rays_in_camera_frame(
+ intrinsics, height, width, normalize_to_unit_sphere=True
+ )
+
+ # Return the output in MapAnything format
+ res = [{"ray_directions": ray_directions, "intrinsics": intrinsics}]
+
+ return res
diff --git a/mapanything/models/external/dinov2/__init__.py b/mapanything/models/external/dinov2/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ae847e46898077fe3d8701b8a181d7b4e3d41cd9
--- /dev/null
+++ b/mapanything/models/external/dinov2/__init__.py
@@ -0,0 +1,6 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+__version__ = "0.0.1"
diff --git a/mapanything/models/external/dinov2/hub/__init__.py b/mapanything/models/external/dinov2/hub/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b88da6bf80be92af00b72dfdb0a806fa64a7a2d9
--- /dev/null
+++ b/mapanything/models/external/dinov2/hub/__init__.py
@@ -0,0 +1,4 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
diff --git a/mapanything/models/external/dinov2/hub/backbones.py b/mapanything/models/external/dinov2/hub/backbones.py
new file mode 100644
index 0000000000000000000000000000000000000000..a56ab3710fc93c9cc3bb95f919dc3a1eb92c7000
--- /dev/null
+++ b/mapanything/models/external/dinov2/hub/backbones.py
@@ -0,0 +1,183 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+from enum import Enum
+from typing import Union
+
+import torch
+
+from mapanything.models.external.dinov2.hub.utils import (
+ _DINOV2_BASE_URL,
+ _make_dinov2_model_name,
+)
+
+
+class Weights(Enum):
+ LVD142M = "LVD142M"
+
+
+def _make_dinov2_model(
+ *,
+ arch_name: str = "vit_large",
+ img_size: int = 518,
+ patch_size: int = 14,
+ init_values: float = 1.0,
+ ffn_layer: str = "mlp",
+ block_chunks: int = 0,
+ num_register_tokens: int = 0,
+ interpolate_antialias: bool = False,
+ interpolate_offset: float = 0.1,
+ pretrained: bool = True,
+ weights: Union[Weights, str] = Weights.LVD142M,
+ **kwargs,
+):
+ from ..models import vision_transformer as vits
+
+ if isinstance(weights, str):
+ try:
+ weights = Weights[weights]
+ except KeyError:
+ raise AssertionError(f"Unsupported weights: {weights}")
+
+ model_base_name = _make_dinov2_model_name(arch_name, patch_size)
+ vit_kwargs = dict(
+ img_size=img_size,
+ patch_size=patch_size,
+ init_values=init_values,
+ ffn_layer=ffn_layer,
+ block_chunks=block_chunks,
+ num_register_tokens=num_register_tokens,
+ interpolate_antialias=interpolate_antialias,
+ interpolate_offset=interpolate_offset,
+ )
+ vit_kwargs.update(**kwargs)
+ model = vits.__dict__[arch_name](**vit_kwargs)
+
+ if pretrained:
+ model_full_name = _make_dinov2_model_name(
+ arch_name, patch_size, num_register_tokens
+ )
+ url = _DINOV2_BASE_URL + f"/{model_base_name}/{model_full_name}_pretrain.pth"
+ state_dict = torch.hub.load_state_dict_from_url(url, map_location="cpu")
+ model.load_state_dict(state_dict, strict=True)
+
+ return model
+
+
+def dinov2_vits14(
+ *, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs
+):
+ """
+ DINOv2 ViT-S/14 model (optionally) pretrained on the LVD-142M dataset.
+ """
+ return _make_dinov2_model(
+ arch_name="vit_small", pretrained=pretrained, weights=weights, **kwargs
+ )
+
+
+def dinov2_vitb14(
+ *, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs
+):
+ """
+ DINOv2 ViT-B/14 model (optionally) pretrained on the LVD-142M dataset.
+ """
+ return _make_dinov2_model(
+ arch_name="vit_base", pretrained=pretrained, weights=weights, **kwargs
+ )
+
+
+def dinov2_vitl14(
+ *, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs
+):
+ """
+ DINOv2 ViT-L/14 model (optionally) pretrained on the LVD-142M dataset.
+ """
+ return _make_dinov2_model(
+ arch_name="vit_large", pretrained=pretrained, weights=weights, **kwargs
+ )
+
+
+def dinov2_vitg14(
+ *, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs
+):
+ """
+ DINOv2 ViT-g/14 model (optionally) pretrained on the LVD-142M dataset.
+ """
+ return _make_dinov2_model(
+ arch_name="vit_giant2",
+ ffn_layer="swiglufused",
+ weights=weights,
+ pretrained=pretrained,
+ **kwargs,
+ )
+
+
+def dinov2_vits14_reg(
+ *, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs
+):
+ """
+ DINOv2 ViT-S/14 model with registers (optionally) pretrained on the LVD-142M dataset.
+ """
+ return _make_dinov2_model(
+ arch_name="vit_small",
+ pretrained=pretrained,
+ weights=weights,
+ num_register_tokens=4,
+ interpolate_antialias=True,
+ interpolate_offset=0.0,
+ **kwargs,
+ )
+
+
+def dinov2_vitb14_reg(
+ *, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs
+):
+ """
+ DINOv2 ViT-B/14 model with registers (optionally) pretrained on the LVD-142M dataset.
+ """
+ return _make_dinov2_model(
+ arch_name="vit_base",
+ pretrained=pretrained,
+ weights=weights,
+ num_register_tokens=4,
+ interpolate_antialias=True,
+ interpolate_offset=0.0,
+ **kwargs,
+ )
+
+
+def dinov2_vitl14_reg(
+ *, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs
+):
+ """
+ DINOv2 ViT-L/14 model with registers (optionally) pretrained on the LVD-142M dataset.
+ """
+ return _make_dinov2_model(
+ arch_name="vit_large",
+ pretrained=pretrained,
+ weights=weights,
+ num_register_tokens=4,
+ interpolate_antialias=True,
+ interpolate_offset=0.0,
+ **kwargs,
+ )
+
+
+def dinov2_vitg14_reg(
+ *, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs
+):
+ """
+ DINOv2 ViT-g/14 model with registers (optionally) pretrained on the LVD-142M dataset.
+ """
+ return _make_dinov2_model(
+ arch_name="vit_giant2",
+ ffn_layer="swiglufused",
+ weights=weights,
+ pretrained=pretrained,
+ num_register_tokens=4,
+ interpolate_antialias=True,
+ interpolate_offset=0.0,
+ **kwargs,
+ )
diff --git a/mapanything/models/external/dinov2/hub/utils.py b/mapanything/models/external/dinov2/hub/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..3943bdf947a4e25ddf30ae33afb1534a95ee21da
--- /dev/null
+++ b/mapanything/models/external/dinov2/hub/utils.py
@@ -0,0 +1,42 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+import itertools
+import math
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+_DINOV2_BASE_URL = "https://dl.fbaipublicfiles.com/dinov2"
+
+
+def _make_dinov2_model_name(
+ arch_name: str, patch_size: int, num_register_tokens: int = 0
+) -> str:
+ compact_arch_name = arch_name.replace("_", "")[:4]
+ registers_suffix = f"_reg{num_register_tokens}" if num_register_tokens else ""
+ return f"dinov2_{compact_arch_name}{patch_size}{registers_suffix}"
+
+
+class CenterPadding(nn.Module):
+ def __init__(self, multiple):
+ super().__init__()
+ self.multiple = multiple
+
+ def _get_pad(self, size):
+ new_size = math.ceil(size / self.multiple) * self.multiple
+ pad_size = new_size - size
+ pad_size_left = pad_size // 2
+ pad_size_right = pad_size - pad_size_left
+ return pad_size_left, pad_size_right
+
+ @torch.inference_mode()
+ def forward(self, x):
+ pads = list(
+ itertools.chain.from_iterable(self._get_pad(m) for m in x.shape[:1:-1])
+ )
+ output = F.pad(x, pads)
+ return output
diff --git a/mapanything/models/external/dinov2/layers/__init__.py b/mapanything/models/external/dinov2/layers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..cc32f90aa750961645f0b4c82cb21f35ac8cc30d
--- /dev/null
+++ b/mapanything/models/external/dinov2/layers/__init__.py
@@ -0,0 +1,14 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+from mapanything.models.external.dinov2.layers.dino_head import DINOHead # noqa
+from mapanything.models.external.dinov2.layers.mlp import Mlp # noqa
+from mapanything.models.external.dinov2.layers.patch_embed import PatchEmbed # noqa
+from mapanything.models.external.dinov2.layers.swiglu_ffn import (
+ SwiGLUFFN, # noqa
+ SwiGLUFFNFused, # noqa
+)
+from mapanything.models.external.dinov2.layers.block import NestedTensorBlock # noqa
+from mapanything.models.external.dinov2.layers.attention import MemEffAttention # noqa
diff --git a/mapanything/models/external/dinov2/layers/attention.py b/mapanything/models/external/dinov2/layers/attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..68ef9b613577e49b1c77aca29307cd2e2e64ff62
--- /dev/null
+++ b/mapanything/models/external/dinov2/layers/attention.py
@@ -0,0 +1,90 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
+
+import logging
+import os
+
+from torch import nn, Tensor
+
+logger = logging.getLogger("dinov2")
+
+
+XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
+try:
+ if XFORMERS_ENABLED:
+ from xformers.ops import memory_efficient_attention, unbind
+
+ XFORMERS_AVAILABLE = True
+ # warnings.warn("xFormers is available (Attention)")
+ else:
+ # warnings.warn("xFormers is disabled (Attention)")
+ raise ImportError
+except ImportError:
+ XFORMERS_AVAILABLE = False
+ # warnings.warn("xFormers is not available (Attention)")
+
+
+class Attention(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int = 8,
+ qkv_bias: bool = False,
+ proj_bias: bool = True,
+ attn_drop: float = 0.0,
+ proj_drop: float = 0.0,
+ ) -> None:
+ super().__init__()
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ self.scale = head_dim**-0.5
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim, bias=proj_bias)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ def forward(self, x: Tensor, attn_bias=None) -> Tensor:
+ B, N, C = x.shape
+ qkv = (
+ self.qkv(x)
+ .reshape(B, N, 3, self.num_heads, C // self.num_heads)
+ .permute(2, 0, 3, 1, 4)
+ )
+
+ q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
+ attn = q @ k.transpose(-2, -1)
+
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+class MemEffAttention(Attention):
+ def forward(self, x: Tensor, attn_bias=None) -> Tensor:
+ if not XFORMERS_AVAILABLE:
+ if attn_bias is not None:
+ raise AssertionError("xFormers is required for using nested tensors")
+ return super().forward(x)
+
+ B, N, C = x.shape
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
+
+ q, k, v = unbind(qkv, 2)
+
+ x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
+ x = x.reshape([B, N, C])
+
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
diff --git a/mapanything/models/external/dinov2/layers/block.py b/mapanything/models/external/dinov2/layers/block.py
new file mode 100644
index 0000000000000000000000000000000000000000..3a955e058ffa77e9e8388f0cb0aa370cd3eaa972
--- /dev/null
+++ b/mapanything/models/external/dinov2/layers/block.py
@@ -0,0 +1,290 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
+
+import logging
+import os
+from typing import Any, Callable, Dict, List, Tuple
+
+import torch
+from torch import nn, Tensor
+
+from mapanything.models.external.dinov2.layers.attention import (
+ Attention,
+ MemEffAttention,
+)
+from mapanything.models.external.dinov2.layers.drop_path import DropPath
+from mapanything.models.external.dinov2.layers.layer_scale import LayerScale
+from mapanything.models.external.dinov2.layers.mlp import Mlp
+
+logger = logging.getLogger("dinov2")
+
+
+XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
+try:
+ if XFORMERS_ENABLED:
+ from xformers.ops import fmha, index_select_cat, scaled_index_add
+
+ XFORMERS_AVAILABLE = True
+ # warnings.warn("xFormers is available (Block)")
+ else:
+ # warnings.warn("xFormers is disabled (Block)")
+ raise ImportError
+except ImportError:
+ XFORMERS_AVAILABLE = False
+ # warnings.warn("xFormers is not available (Block)")
+
+
+class Block(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int,
+ mlp_ratio: float = 4.0,
+ qkv_bias: bool = False,
+ proj_bias: bool = True,
+ ffn_bias: bool = True,
+ drop: float = 0.0,
+ attn_drop: float = 0.0,
+ init_values=None,
+ drop_path: float = 0.0,
+ act_layer: Callable[..., nn.Module] = nn.GELU,
+ norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
+ attn_class: Callable[..., nn.Module] = Attention,
+ ffn_layer: Callable[..., nn.Module] = Mlp,
+ ) -> None:
+ super().__init__()
+ # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}")
+ self.norm1 = norm_layer(dim)
+ self.attn = attn_class(
+ dim,
+ num_heads=num_heads,
+ qkv_bias=qkv_bias,
+ proj_bias=proj_bias,
+ attn_drop=attn_drop,
+ proj_drop=drop,
+ )
+ self.ls1 = (
+ LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
+ )
+ self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+ self.norm2 = norm_layer(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = ffn_layer(
+ in_features=dim,
+ hidden_features=mlp_hidden_dim,
+ act_layer=act_layer,
+ drop=drop,
+ bias=ffn_bias,
+ )
+ self.ls2 = (
+ LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
+ )
+ self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+ self.sample_drop_ratio = drop_path
+
+ def forward(self, x: Tensor) -> Tensor:
+ def attn_residual_func(x: Tensor) -> Tensor:
+ return self.ls1(self.attn(self.norm1(x)))
+
+ def ffn_residual_func(x: Tensor) -> Tensor:
+ return self.ls2(self.mlp(self.norm2(x)))
+
+ if self.training and self.sample_drop_ratio > 0.1:
+ # the overhead is compensated only for a drop path rate larger than 0.1
+ x = drop_add_residual_stochastic_depth(
+ x,
+ residual_func=attn_residual_func,
+ sample_drop_ratio=self.sample_drop_ratio,
+ )
+ x = drop_add_residual_stochastic_depth(
+ x,
+ residual_func=ffn_residual_func,
+ sample_drop_ratio=self.sample_drop_ratio,
+ )
+ elif self.training and self.sample_drop_ratio > 0.0:
+ x = x + self.drop_path1(attn_residual_func(x))
+ x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2
+ else:
+ x = x + attn_residual_func(x)
+ x = x + ffn_residual_func(x)
+ return x
+
+
+def drop_add_residual_stochastic_depth(
+ x: Tensor,
+ residual_func: Callable[[Tensor], Tensor],
+ sample_drop_ratio: float = 0.0,
+) -> Tensor:
+ # 1) extract subset using permutation
+ b, n, d = x.shape
+ sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
+ brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
+ x_subset = x[brange]
+
+ # 2) apply residual_func to get residual
+ residual = residual_func(x_subset)
+
+ x_flat = x.flatten(1)
+ residual = residual.flatten(1)
+
+ residual_scale_factor = b / sample_subset_size
+
+ # 3) add the residual
+ x_plus_residual = torch.index_add(
+ x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor
+ )
+ return x_plus_residual.view_as(x)
+
+
+def get_branges_scales(x, sample_drop_ratio=0.0):
+ b, n, d = x.shape
+ sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
+ brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
+ residual_scale_factor = b / sample_subset_size
+ return brange, residual_scale_factor
+
+
+def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
+ if scaling_vector is None:
+ x_flat = x.flatten(1)
+ residual = residual.flatten(1)
+ x_plus_residual = torch.index_add(
+ x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor
+ )
+ else:
+ x_plus_residual = scaled_index_add(
+ x,
+ brange,
+ residual.to(dtype=x.dtype),
+ scaling=scaling_vector,
+ alpha=residual_scale_factor,
+ )
+ return x_plus_residual
+
+
+attn_bias_cache: Dict[Tuple, Any] = {}
+
+
+def get_attn_bias_and_cat(x_list, branges=None):
+ """
+ this will perform the index select, cat the tensors, and provide the attn_bias from cache
+ """
+ batch_sizes = (
+ [b.shape[0] for b in branges]
+ if branges is not None
+ else [x.shape[0] for x in x_list]
+ )
+ all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
+ if all_shapes not in attn_bias_cache.keys():
+ seqlens = []
+ for b, x in zip(batch_sizes, x_list):
+ for _ in range(b):
+ seqlens.append(x.shape[1])
+ attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
+ attn_bias._batch_sizes = batch_sizes
+ attn_bias_cache[all_shapes] = attn_bias
+
+ if branges is not None:
+ cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(
+ 1, -1, x_list[0].shape[-1]
+ )
+ else:
+ tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
+ cat_tensors = torch.cat(tensors_bs1, dim=1)
+
+ return attn_bias_cache[all_shapes], cat_tensors
+
+
+def drop_add_residual_stochastic_depth_list(
+ x_list: List[Tensor],
+ residual_func: Callable[[Tensor, Any], Tensor],
+ sample_drop_ratio: float = 0.0,
+ scaling_vector=None,
+) -> Tensor:
+ # 1) generate random set of indices for dropping samples in the batch
+ branges_scales = [
+ get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list
+ ]
+ branges = [s[0] for s in branges_scales]
+ residual_scale_factors = [s[1] for s in branges_scales]
+
+ # 2) get attention bias and index+concat the tensors
+ attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)
+
+ # 3) apply residual_func to get residual, and split the result
+ residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore
+
+ outputs = []
+ for x, brange, residual, residual_scale_factor in zip(
+ x_list, branges, residual_list, residual_scale_factors
+ ):
+ outputs.append(
+ add_residual(
+ x, brange, residual, residual_scale_factor, scaling_vector
+ ).view_as(x)
+ )
+ return outputs
+
+
+class NestedTensorBlock(Block):
+ def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
+ """
+ x_list contains a list of tensors to nest together and run
+ """
+ assert isinstance(self.attn, MemEffAttention)
+
+ if self.training and self.sample_drop_ratio > 0.0:
+
+ def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
+ return self.attn(self.norm1(x), attn_bias=attn_bias)
+
+ def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
+ return self.mlp(self.norm2(x))
+
+ x_list = drop_add_residual_stochastic_depth_list(
+ x_list,
+ residual_func=attn_residual_func,
+ sample_drop_ratio=self.sample_drop_ratio,
+ scaling_vector=self.ls1.gamma
+ if isinstance(self.ls1, LayerScale)
+ else None,
+ )
+ x_list = drop_add_residual_stochastic_depth_list(
+ x_list,
+ residual_func=ffn_residual_func,
+ sample_drop_ratio=self.sample_drop_ratio,
+ scaling_vector=self.ls2.gamma
+ if isinstance(self.ls1, LayerScale)
+ else None,
+ )
+ return x_list
+ else:
+
+ def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
+ return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))
+
+ def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
+ return self.ls2(self.mlp(self.norm2(x)))
+
+ attn_bias, x = get_attn_bias_and_cat(x_list)
+ x = x + attn_residual_func(x, attn_bias=attn_bias)
+ x = x + ffn_residual_func(x)
+ return attn_bias.split(x)
+
+ def forward(self, x_or_x_list):
+ if isinstance(x_or_x_list, Tensor):
+ return super().forward(x_or_x_list)
+ elif isinstance(x_or_x_list, list):
+ if not XFORMERS_AVAILABLE:
+ raise AssertionError("xFormers is required for using nested tensors")
+ return self.forward_nested(x_or_x_list)
+ else:
+ raise AssertionError
diff --git a/mapanything/models/external/dinov2/layers/dino_head.py b/mapanything/models/external/dinov2/layers/dino_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..3621df591d8a2dee72ea6ad365b7b39678ca051e
--- /dev/null
+++ b/mapanything/models/external/dinov2/layers/dino_head.py
@@ -0,0 +1,67 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn as nn
+from torch.nn.init import trunc_normal_
+from torch.nn.utils import weight_norm
+
+
+class DINOHead(nn.Module):
+ def __init__(
+ self,
+ in_dim,
+ out_dim,
+ use_bn=False,
+ nlayers=3,
+ hidden_dim=2048,
+ bottleneck_dim=256,
+ mlp_bias=True,
+ ):
+ super().__init__()
+ nlayers = max(nlayers, 1)
+ self.mlp = _build_mlp(
+ nlayers,
+ in_dim,
+ bottleneck_dim,
+ hidden_dim=hidden_dim,
+ use_bn=use_bn,
+ bias=mlp_bias,
+ )
+ self.apply(self._init_weights)
+ self.last_layer = weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))
+ self.last_layer.weight_g.data.fill_(1)
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight, std=0.02)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+
+ def forward(self, x):
+ x = self.mlp(x)
+ eps = 1e-6 if x.dtype == torch.float16 else 1e-12
+ x = nn.functional.normalize(x, dim=-1, p=2, eps=eps)
+ x = self.last_layer(x)
+ return x
+
+
+def _build_mlp(
+ nlayers, in_dim, bottleneck_dim, hidden_dim=None, use_bn=False, bias=True
+):
+ if nlayers == 1:
+ return nn.Linear(in_dim, bottleneck_dim, bias=bias)
+ else:
+ layers = [nn.Linear(in_dim, hidden_dim, bias=bias)]
+ if use_bn:
+ layers.append(nn.BatchNorm1d(hidden_dim))
+ layers.append(nn.GELU())
+ for _ in range(nlayers - 2):
+ layers.append(nn.Linear(hidden_dim, hidden_dim, bias=bias))
+ if use_bn:
+ layers.append(nn.BatchNorm1d(hidden_dim))
+ layers.append(nn.GELU())
+ layers.append(nn.Linear(hidden_dim, bottleneck_dim, bias=bias))
+ return nn.Sequential(*layers)
diff --git a/mapanything/models/external/dinov2/layers/drop_path.py b/mapanything/models/external/dinov2/layers/drop_path.py
new file mode 100644
index 0000000000000000000000000000000000000000..7de14f0f11cb348088cefbb0b976f50635b66bba
--- /dev/null
+++ b/mapanything/models/external/dinov2/layers/drop_path.py
@@ -0,0 +1,36 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py
+
+
+from torch import nn
+
+
+def drop_path(x, drop_prob: float = 0.0, training: bool = False):
+ if drop_prob == 0.0 or not training:
+ return x
+ keep_prob = 1 - drop_prob
+ shape = (x.shape[0],) + (1,) * (
+ x.ndim - 1
+ ) # work with diff dim tensors, not just 2D ConvNets
+ random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
+ if keep_prob > 0.0:
+ random_tensor.div_(keep_prob)
+ output = x * random_tensor
+ return output
+
+
+class DropPath(nn.Module):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
+
+ def __init__(self, drop_prob=None):
+ super(DropPath, self).__init__()
+ self.drop_prob = drop_prob
+
+ def forward(self, x):
+ return drop_path(x, self.drop_prob, self.training)
diff --git a/mapanything/models/external/dinov2/layers/layer_scale.py b/mapanything/models/external/dinov2/layers/layer_scale.py
new file mode 100644
index 0000000000000000000000000000000000000000..e2592ca74e3972ec2ab62ecf6fabd4e2631353a6
--- /dev/null
+++ b/mapanything/models/external/dinov2/layers/layer_scale.py
@@ -0,0 +1,26 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+# Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110
+
+from typing import Union
+
+import torch
+from torch import nn, Tensor
+
+
+class LayerScale(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ init_values: Union[float, Tensor] = 1e-5,
+ inplace: bool = False,
+ ) -> None:
+ super().__init__()
+ self.inplace = inplace
+ self.gamma = nn.Parameter(init_values * torch.ones(dim))
+
+ def forward(self, x: Tensor) -> Tensor:
+ return x.mul_(self.gamma) if self.inplace else x * self.gamma
diff --git a/mapanything/models/external/dinov2/layers/mlp.py b/mapanything/models/external/dinov2/layers/mlp.py
new file mode 100644
index 0000000000000000000000000000000000000000..66ef78b612cb3eb4ea2481b82b3ec02a5fbcaf49
--- /dev/null
+++ b/mapanything/models/external/dinov2/layers/mlp.py
@@ -0,0 +1,40 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py
+
+
+from typing import Callable, Optional
+
+from torch import nn, Tensor
+
+
+class Mlp(nn.Module):
+ def __init__(
+ self,
+ in_features: int,
+ hidden_features: Optional[int] = None,
+ out_features: Optional[int] = None,
+ act_layer: Callable[..., nn.Module] = nn.GELU,
+ drop: float = 0.0,
+ bias: bool = True,
+ ) -> None:
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
+ self.act = act_layer()
+ self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
+ self.drop = nn.Dropout(drop)
+
+ def forward(self, x: Tensor) -> Tensor:
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.drop(x)
+ x = self.fc2(x)
+ x = self.drop(x)
+ return x
diff --git a/mapanything/models/external/dinov2/layers/patch_embed.py b/mapanything/models/external/dinov2/layers/patch_embed.py
new file mode 100644
index 0000000000000000000000000000000000000000..90d8b1b87df8c719a60a654bccdf022d31e6fd80
--- /dev/null
+++ b/mapanything/models/external/dinov2/layers/patch_embed.py
@@ -0,0 +1,100 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
+
+from typing import Callable, Optional, Tuple, Union
+
+import torch.nn as nn
+from torch import Tensor
+
+
+def make_2tuple(x):
+ if isinstance(x, tuple):
+ assert len(x) == 2
+ return x
+
+ assert isinstance(x, int)
+ return (x, x)
+
+
+class PatchEmbed(nn.Module):
+ """
+ 2D image to patch embedding: (B,C,H,W) -> (B,N,D)
+
+ Args:
+ img_size: Image size.
+ patch_size: Patch token size.
+ in_chans: Number of input image channels.
+ embed_dim: Number of linear projection output channels.
+ norm_layer: Normalization layer.
+ """
+
+ def __init__(
+ self,
+ img_size: Union[int, Tuple[int, int]] = 224,
+ patch_size: Union[int, Tuple[int, int]] = 16,
+ in_chans: int = 3,
+ embed_dim: int = 768,
+ norm_layer: Optional[Callable] = None,
+ flatten_embedding: bool = True,
+ ) -> None:
+ super().__init__()
+
+ image_HW = make_2tuple(img_size)
+ patch_HW = make_2tuple(patch_size)
+ patch_grid_size = (
+ image_HW[0] // patch_HW[0],
+ image_HW[1] // patch_HW[1],
+ )
+
+ self.img_size = image_HW
+ self.patch_size = patch_HW
+ self.patches_resolution = patch_grid_size
+ self.num_patches = patch_grid_size[0] * patch_grid_size[1]
+
+ self.in_chans = in_chans
+ self.embed_dim = embed_dim
+
+ self.flatten_embedding = flatten_embedding
+
+ self.proj = nn.Conv2d(
+ in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW
+ )
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
+
+ def forward(self, x: Tensor) -> Tensor:
+ _, _, H, W = x.shape
+ patch_H, patch_W = self.patch_size
+
+ assert H % patch_H == 0, (
+ f"Input image height {H} is not a multiple of patch height {patch_H}"
+ )
+ assert W % patch_W == 0, (
+ f"Input image width {W} is not a multiple of patch width: {patch_W}"
+ )
+
+ x = self.proj(x) # B C H W
+ H, W = x.size(2), x.size(3)
+ x = x.flatten(2).transpose(1, 2) # B HW C
+ x = self.norm(x)
+ if not self.flatten_embedding:
+ x = x.reshape(-1, H, W, self.embed_dim) # B H W C
+ return x
+
+ def flops(self) -> float:
+ Ho, Wo = self.patches_resolution
+ flops = (
+ Ho
+ * Wo
+ * self.embed_dim
+ * self.in_chans
+ * (self.patch_size[0] * self.patch_size[1])
+ )
+ if self.norm is not None:
+ flops += Ho * Wo * self.embed_dim
+ return flops
diff --git a/mapanything/models/external/dinov2/layers/swiglu_ffn.py b/mapanything/models/external/dinov2/layers/swiglu_ffn.py
new file mode 100644
index 0000000000000000000000000000000000000000..9ecff675a8539f382a511d797a9c3ca6da4bd539
--- /dev/null
+++ b/mapanything/models/external/dinov2/layers/swiglu_ffn.py
@@ -0,0 +1,71 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+import os
+from typing import Callable, Optional
+
+import torch.nn.functional as F
+from torch import nn, Tensor
+
+
+class SwiGLUFFN(nn.Module):
+ def __init__(
+ self,
+ in_features: int,
+ hidden_features: Optional[int] = None,
+ out_features: Optional[int] = None,
+ act_layer: Callable[..., nn.Module] = None,
+ drop: float = 0.0,
+ bias: bool = True,
+ ) -> None:
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
+ self.w3 = nn.Linear(hidden_features, out_features, bias=bias)
+
+ def forward(self, x: Tensor) -> Tensor:
+ x12 = self.w12(x)
+ x1, x2 = x12.chunk(2, dim=-1)
+ hidden = F.silu(x1) * x2
+ return self.w3(hidden)
+
+
+XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
+try:
+ if XFORMERS_ENABLED:
+ from xformers.ops import SwiGLU
+
+ XFORMERS_AVAILABLE = True
+ # warnings.warn("xFormers is available (SwiGLU)")
+ else:
+ # warnings.warn("xFormers is disabled (SwiGLU)")
+ raise ImportError
+except ImportError:
+ SwiGLU = SwiGLUFFN
+ XFORMERS_AVAILABLE = False
+
+ # warnings.warn("xFormers is not available (SwiGLU)")
+
+
+class SwiGLUFFNFused(SwiGLU):
+ def __init__(
+ self,
+ in_features: int,
+ hidden_features: Optional[int] = None,
+ out_features: Optional[int] = None,
+ act_layer: Callable[..., nn.Module] = None,
+ drop: float = 0.0,
+ bias: bool = True,
+ ) -> None:
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
+ super().__init__(
+ in_features=in_features,
+ hidden_features=hidden_features,
+ out_features=out_features,
+ bias=bias,
+ )
diff --git a/mapanything/models/external/dinov2/models/__init__.py b/mapanything/models/external/dinov2/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..870e3b24f470774e47d0f43333b138e7c4e8547a
--- /dev/null
+++ b/mapanything/models/external/dinov2/models/__init__.py
@@ -0,0 +1,44 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+import logging
+
+import mapanything.models.external.dinov2.models.vision_transformer as vits
+
+logger = logging.getLogger("dinov2")
+
+
+def build_model(args, only_teacher=False, img_size=224):
+ args.arch = args.arch.removesuffix("_memeff")
+ if "vit" in args.arch:
+ vit_kwargs = dict(
+ img_size=img_size,
+ patch_size=args.patch_size,
+ init_values=args.layerscale,
+ ffn_layer=args.ffn_layer,
+ block_chunks=args.block_chunks,
+ qkv_bias=args.qkv_bias,
+ proj_bias=args.proj_bias,
+ ffn_bias=args.ffn_bias,
+ num_register_tokens=args.num_register_tokens,
+ interpolate_offset=args.interpolate_offset,
+ interpolate_antialias=args.interpolate_antialias,
+ )
+ teacher = vits.__dict__[args.arch](**vit_kwargs)
+ if only_teacher:
+ return teacher, teacher.embed_dim
+ student = vits.__dict__[args.arch](
+ **vit_kwargs,
+ drop_path_rate=args.drop_path_rate,
+ drop_path_uniform=args.drop_path_uniform,
+ )
+ embed_dim = student.embed_dim
+ return student, teacher, embed_dim
+
+
+def build_model_from_cfg(cfg, only_teacher=False):
+ return build_model(
+ cfg.student, only_teacher=only_teacher, img_size=cfg.crops.global_crops_size
+ )
diff --git a/mapanything/models/external/dinov2/models/vision_transformer.py b/mapanything/models/external/dinov2/models/vision_transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..3059c767561bca21b492ce1abfb74a977fc65edd
--- /dev/null
+++ b/mapanything/models/external/dinov2/models/vision_transformer.py
@@ -0,0 +1,448 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
+
+import math
+from functools import partial
+from typing import Callable, Sequence, Tuple, Union
+
+import torch
+import torch.nn as nn
+from torch.nn.init import trunc_normal_
+from torch.utils.checkpoint import checkpoint
+
+from mapanything.models.external.dinov2.layers import (
+ MemEffAttention,
+ Mlp,
+ NestedTensorBlock as Block,
+ PatchEmbed,
+ SwiGLUFFNFused,
+)
+from mapanything.models.external.pi3.layers.attention import FlashAttention
+
+# logger = logging.getLogger("dinov2")
+
+
+def named_apply(
+ fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False
+) -> nn.Module:
+ if not depth_first and include_root:
+ fn(module=module, name=name)
+ for child_name, child_module in module.named_children():
+ child_name = ".".join((name, child_name)) if name else child_name
+ named_apply(
+ fn=fn,
+ module=child_module,
+ name=child_name,
+ depth_first=depth_first,
+ include_root=True,
+ )
+ if depth_first and include_root:
+ fn(module=module, name=name)
+ return module
+
+
+class BlockChunk(nn.ModuleList):
+ def forward(self, x):
+ for b in self:
+ x = b(x)
+ return x
+
+
+class DinoVisionTransformer(nn.Module):
+ def __init__(
+ self,
+ img_size=224,
+ patch_size=16,
+ in_chans=3,
+ embed_dim=768,
+ depth=12,
+ num_heads=12,
+ mlp_ratio=4.0,
+ qkv_bias=True,
+ ffn_bias=True,
+ proj_bias=True,
+ drop_path_rate=0.0,
+ drop_path_uniform=False,
+ init_values=None, # for layerscale: None or 0 => no layerscale
+ embed_layer=PatchEmbed,
+ act_layer=nn.GELU,
+ block_fn=Block,
+ ffn_layer="mlp",
+ block_chunks=1,
+ num_register_tokens=0,
+ interpolate_antialias=False,
+ interpolate_offset=0.1,
+ ):
+ """
+ Args:
+ img_size (int, tuple): input image size
+ patch_size (int, tuple): patch size
+ in_chans (int): number of input channels
+ embed_dim (int): embedding dimension
+ depth (int): depth of transformer
+ num_heads (int): number of attention heads
+ mlp_ratio (int): ratio of mlp hidden dim to embedding dim
+ qkv_bias (bool): enable bias for qkv if True
+ proj_bias (bool): enable bias for proj in attn if True
+ ffn_bias (bool): enable bias for ffn if True
+ drop_path_rate (float): stochastic depth rate
+ drop_path_uniform (bool): apply uniform drop rate across blocks
+ weight_init (str): weight init scheme
+ init_values (float): layer-scale init values
+ embed_layer (nn.Module): patch embedding layer
+ act_layer (nn.Module): MLP activation layer
+ block_fn (nn.Module): transformer block class
+ ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
+ block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
+ num_register_tokens: (int) number of extra cls tokens (so-called "registers")
+ interpolate_antialias: (str) flag to apply anti-aliasing when interpolating positional embeddings
+ interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings
+ """
+ super().__init__()
+ norm_layer = partial(nn.LayerNorm, eps=1e-6)
+
+ self.num_features = self.embed_dim = (
+ embed_dim # num_features for consistency with other models
+ )
+ self.num_tokens = 1
+ self.n_blocks = depth
+ self.num_heads = num_heads
+ self.patch_size = patch_size
+ self.num_register_tokens = num_register_tokens
+ self.interpolate_antialias = interpolate_antialias
+ self.interpolate_offset = interpolate_offset
+
+ self.patch_embed = embed_layer(
+ img_size=img_size,
+ patch_size=patch_size,
+ in_chans=in_chans,
+ embed_dim=embed_dim,
+ )
+ num_patches = self.patch_embed.num_patches
+
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+ self.pos_embed = nn.Parameter(
+ torch.zeros(1, num_patches + self.num_tokens, embed_dim)
+ )
+ assert num_register_tokens >= 0
+ self.register_tokens = (
+ nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim))
+ if num_register_tokens
+ else None
+ )
+
+ if drop_path_uniform is True:
+ dpr = [drop_path_rate] * depth
+ else:
+ dpr = [
+ x.item() for x in torch.linspace(0, drop_path_rate, depth)
+ ] # stochastic depth decay rule
+
+ if ffn_layer == "mlp":
+ # logger.info("using MLP layer as FFN")
+ ffn_layer = Mlp
+ elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
+ # logger.info("using SwiGLU layer as FFN")
+ ffn_layer = SwiGLUFFNFused
+ elif ffn_layer == "identity":
+ # logger.info("using Identity layer as FFN")
+
+ def f(*args, **kwargs):
+ return nn.Identity()
+
+ ffn_layer = f
+ else:
+ raise NotImplementedError
+
+ blocks_list = [
+ block_fn(
+ dim=embed_dim,
+ num_heads=num_heads,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ proj_bias=proj_bias,
+ ffn_bias=ffn_bias,
+ drop_path=dpr[i],
+ norm_layer=norm_layer,
+ act_layer=act_layer,
+ ffn_layer=ffn_layer,
+ init_values=init_values,
+ attn_class=FlashAttention,
+ )
+ for i in range(depth)
+ ]
+ if block_chunks > 0:
+ self.chunked_blocks = True
+ chunked_blocks = []
+ chunksize = depth // block_chunks
+ for i in range(0, depth, chunksize):
+ # this is to keep the block index consistent if we chunk the block list
+ chunked_blocks.append(
+ [nn.Identity()] * i + blocks_list[i : i + chunksize]
+ )
+ self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
+ else:
+ self.chunked_blocks = False
+ self.blocks = nn.ModuleList(blocks_list)
+
+ self.norm = norm_layer(embed_dim)
+ self.head = nn.Identity()
+
+ self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))
+
+ self.init_weights()
+
+ def init_weights(self):
+ trunc_normal_(self.pos_embed, std=0.02)
+ nn.init.normal_(self.cls_token, std=1e-6)
+ if self.register_tokens is not None:
+ nn.init.normal_(self.register_tokens, std=1e-6)
+ named_apply(init_weights_vit_timm, self)
+
+ def interpolate_pos_encoding(self, x, w, h):
+ previous_dtype = x.dtype
+ npatch = x.shape[1] - 1
+ N = self.pos_embed.shape[1] - 1
+ if npatch == N and w == h:
+ return self.pos_embed
+ pos_embed = self.pos_embed.float()
+ class_pos_embed = pos_embed[:, 0]
+ patch_pos_embed = pos_embed[:, 1:]
+ dim = x.shape[-1]
+ w0 = w // self.patch_size
+ h0 = h // self.patch_size
+ M = int(math.sqrt(N)) # Recover the number of patches in each dimension
+ assert N == M * M
+ kwargs = {}
+ if self.interpolate_offset:
+ # Historical kludge: add a small number to avoid floating point error in the interpolation, see https://github.com/facebookresearch/dino/issues/8
+ # Note: still needed for backward-compatibility, the underlying operators are using both output size and scale factors
+ sx = float(w0 + self.interpolate_offset) / M
+ sy = float(h0 + self.interpolate_offset) / M
+ kwargs["scale_factor"] = (sx, sy)
+ else:
+ # Simply specify an output size instead of a scale factor
+ kwargs["size"] = (w0, h0)
+ patch_pos_embed = nn.functional.interpolate(
+ patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2),
+ mode="bicubic",
+ antialias=self.interpolate_antialias,
+ **kwargs,
+ )
+ assert (w0, h0) == patch_pos_embed.shape[-2:]
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
+ return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(
+ previous_dtype
+ )
+
+ def prepare_tokens_with_masks(self, x, masks=None):
+ B, nc, w, h = x.shape
+ x = self.patch_embed(x)
+ if masks is not None:
+ x = torch.where(
+ masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x
+ )
+
+ x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
+ x = x + self.interpolate_pos_encoding(x, w, h)
+
+ if self.register_tokens is not None:
+ x = torch.cat(
+ (
+ x[:, :1],
+ self.register_tokens.expand(x.shape[0], -1, -1),
+ x[:, 1:],
+ ),
+ dim=1,
+ )
+
+ return x
+
+ def forward_features_list(self, x_list, masks_list):
+ x = [
+ self.prepare_tokens_with_masks(x, masks)
+ for x, masks in zip(x_list, masks_list)
+ ]
+ for blk in self.blocks:
+ if self.training:
+ x = checkpoint(blk, x, use_reentrant=False)
+ else:
+ x = blk(x)
+
+ all_x = x
+ output = []
+ for x, masks in zip(all_x, masks_list):
+ x_norm = self.norm(x)
+ output.append(
+ {
+ "x_norm_clstoken": x_norm[:, 0],
+ "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
+ "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
+ "x_prenorm": x,
+ "masks": masks,
+ }
+ )
+ return output
+
+ def forward_features(self, x, masks=None):
+ if isinstance(x, list):
+ return self.forward_features_list(x, masks)
+
+ x = self.prepare_tokens_with_masks(x, masks)
+
+ for blk in self.blocks:
+ if self.training:
+ x = checkpoint(blk, x, use_reentrant=False)
+ else:
+ x = blk(x)
+
+ x_norm = self.norm(x)
+ return {
+ "x_norm_clstoken": x_norm[:, 0],
+ "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
+ "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
+ "x_prenorm": x,
+ "masks": masks,
+ }
+
+ def _get_intermediate_layers_not_chunked(self, x, n=1):
+ x = self.prepare_tokens_with_masks(x)
+ # If n is an int, take the n last blocks. If it's a list, take them
+ output, total_block_len = [], len(self.blocks)
+ blocks_to_take = (
+ range(total_block_len - n, total_block_len) if isinstance(n, int) else n
+ )
+ for i, blk in enumerate(self.blocks):
+ x = blk(x)
+ if i in blocks_to_take:
+ output.append(x)
+ assert len(output) == len(blocks_to_take), (
+ f"only {len(output)} / {len(blocks_to_take)} blocks found"
+ )
+ return output
+
+ def _get_intermediate_layers_chunked(self, x, n=1):
+ x = self.prepare_tokens_with_masks(x)
+ output, i, total_block_len = [], 0, len(self.blocks[-1])
+ # If n is an int, take the n last blocks. If it's a list, take them
+ blocks_to_take = (
+ range(total_block_len - n, total_block_len) if isinstance(n, int) else n
+ )
+ for block_chunk in self.blocks:
+ for blk in block_chunk[i:]: # Passing the nn.Identity()
+ x = blk(x)
+ if i in blocks_to_take:
+ output.append(x)
+ i += 1
+ assert len(output) == len(blocks_to_take), (
+ f"only {len(output)} / {len(blocks_to_take)} blocks found"
+ )
+ return output
+
+ def get_intermediate_layers(
+ self,
+ x: torch.Tensor,
+ n: Union[int, Sequence] = 1, # Layers or n last layers to take
+ reshape: bool = False,
+ return_class_token: bool = False,
+ norm=True,
+ ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
+ if self.chunked_blocks:
+ outputs = self._get_intermediate_layers_chunked(x, n)
+ else:
+ outputs = self._get_intermediate_layers_not_chunked(x, n)
+ if norm:
+ outputs = [self.norm(out) for out in outputs]
+ class_tokens = [out[:, 0] for out in outputs]
+ outputs = [out[:, 1 + self.num_register_tokens :] for out in outputs]
+ if reshape:
+ B, _, w, h = x.shape
+ outputs = [
+ out.reshape(B, w // self.patch_size, h // self.patch_size, -1)
+ .permute(0, 3, 1, 2)
+ .contiguous()
+ for out in outputs
+ ]
+ if return_class_token:
+ return tuple(zip(outputs, class_tokens))
+ return tuple(outputs)
+
+ def forward(self, *args, is_training=False, **kwargs):
+ ret = self.forward_features(*args, **kwargs)
+ if is_training:
+ return ret
+ else:
+ return self.head(ret["x_norm_clstoken"])
+
+
+def init_weights_vit_timm(module: nn.Module, name: str = ""):
+ """ViT weight initialization, original timm impl (for reproducibility)"""
+ if isinstance(module, nn.Linear):
+ trunc_normal_(module.weight, std=0.02)
+ if module.bias is not None:
+ nn.init.zeros_(module.bias)
+
+
+def vit_small(patch_size=16, num_register_tokens=0, **kwargs):
+ model = DinoVisionTransformer(
+ patch_size=patch_size,
+ embed_dim=384,
+ depth=12,
+ num_heads=6,
+ mlp_ratio=4,
+ block_fn=partial(Block, attn_class=MemEffAttention),
+ num_register_tokens=num_register_tokens,
+ **kwargs,
+ )
+ return model
+
+
+def vit_base(patch_size=16, num_register_tokens=0, **kwargs):
+ model = DinoVisionTransformer(
+ patch_size=patch_size,
+ embed_dim=768,
+ depth=12,
+ num_heads=12,
+ mlp_ratio=4,
+ block_fn=partial(Block, attn_class=MemEffAttention),
+ num_register_tokens=num_register_tokens,
+ **kwargs,
+ )
+ return model
+
+
+def vit_large(patch_size=16, num_register_tokens=0, **kwargs):
+ model = DinoVisionTransformer(
+ patch_size=patch_size,
+ embed_dim=1024,
+ depth=24,
+ num_heads=16,
+ mlp_ratio=4,
+ block_fn=partial(Block, attn_class=MemEffAttention),
+ num_register_tokens=num_register_tokens,
+ **kwargs,
+ )
+ return model
+
+
+def vit_giant2(patch_size=16, num_register_tokens=0, **kwargs):
+ """
+ Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64
+ """
+ model = DinoVisionTransformer(
+ patch_size=patch_size,
+ embed_dim=1536,
+ depth=40,
+ num_heads=24,
+ mlp_ratio=4,
+ block_fn=partial(Block, attn_class=MemEffAttention),
+ num_register_tokens=num_register_tokens,
+ **kwargs,
+ )
+ return model
diff --git a/mapanything/models/external/dinov2/utils/__init__.py b/mapanything/models/external/dinov2/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b88da6bf80be92af00b72dfdb0a806fa64a7a2d9
--- /dev/null
+++ b/mapanything/models/external/dinov2/utils/__init__.py
@@ -0,0 +1,4 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
diff --git a/mapanything/models/external/dinov2/utils/cluster.py b/mapanything/models/external/dinov2/utils/cluster.py
new file mode 100644
index 0000000000000000000000000000000000000000..660f1307c14ecc345afe1de8108d084896495596
--- /dev/null
+++ b/mapanything/models/external/dinov2/utils/cluster.py
@@ -0,0 +1,102 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+import os
+from enum import Enum
+from pathlib import Path
+from typing import Any, Dict, Optional
+
+
+class ClusterType(Enum):
+ AWS = "aws"
+ FAIR = "fair"
+ RSC = "rsc"
+
+
+def _guess_cluster_type() -> ClusterType:
+ uname = os.uname()
+ if uname.sysname == "Linux":
+ if uname.release.endswith("-aws"):
+ # Linux kernel versions on AWS instances are of the form "5.4.0-1051-aws"
+ return ClusterType.AWS
+ elif uname.nodename.startswith("rsc"):
+ # Linux kernel versions on RSC instances are standard ones but hostnames start with "rsc"
+ return ClusterType.RSC
+
+ return ClusterType.FAIR
+
+
+def get_cluster_type(
+ cluster_type: Optional[ClusterType] = None,
+) -> Optional[ClusterType]:
+ if cluster_type is None:
+ return _guess_cluster_type()
+
+ return cluster_type
+
+
+def get_checkpoint_path(cluster_type: Optional[ClusterType] = None) -> Optional[Path]:
+ cluster_type = get_cluster_type(cluster_type)
+ if cluster_type is None:
+ return None
+
+ CHECKPOINT_DIRNAMES = {
+ ClusterType.AWS: "checkpoints",
+ ClusterType.FAIR: "checkpoint",
+ ClusterType.RSC: "checkpoint/dino",
+ }
+ return Path("/") / CHECKPOINT_DIRNAMES[cluster_type]
+
+
+def get_user_checkpoint_path(
+ cluster_type: Optional[ClusterType] = None,
+) -> Optional[Path]:
+ checkpoint_path = get_checkpoint_path(cluster_type)
+ if checkpoint_path is None:
+ return None
+
+ username = os.environ.get("USER")
+ assert username is not None
+ return checkpoint_path / username
+
+
+def get_slurm_partition(cluster_type: Optional[ClusterType] = None) -> Optional[str]:
+ cluster_type = get_cluster_type(cluster_type)
+ if cluster_type is None:
+ return None
+
+ SLURM_PARTITIONS = {
+ ClusterType.AWS: "learnlab",
+ ClusterType.FAIR: "learnlab",
+ ClusterType.RSC: "learn",
+ }
+ return SLURM_PARTITIONS[cluster_type]
+
+
+def get_slurm_executor_parameters(
+ nodes: int,
+ num_gpus_per_node: int,
+ cluster_type: Optional[ClusterType] = None,
+ **kwargs,
+) -> Dict[str, Any]:
+ # create default parameters
+ params = {
+ "mem_gb": 0, # Requests all memory on a node, see https://slurm.schedmd.com/sbatch.html
+ "gpus_per_node": num_gpus_per_node,
+ "tasks_per_node": num_gpus_per_node, # one task per GPU
+ "cpus_per_task": 10,
+ "nodes": nodes,
+ "slurm_partition": get_slurm_partition(cluster_type),
+ }
+ # apply cluster-specific adjustments
+ cluster_type = get_cluster_type(cluster_type)
+ if cluster_type == ClusterType.AWS:
+ params["cpus_per_task"] = 12
+ del params["mem_gb"]
+ elif cluster_type == ClusterType.RSC:
+ params["cpus_per_task"] = 12
+ # set additional parameters / apply overrides
+ params.update(kwargs)
+ return params
diff --git a/mapanything/models/external/dinov2/utils/config.py b/mapanything/models/external/dinov2/utils/config.py
new file mode 100644
index 0000000000000000000000000000000000000000..337c8945f6486058a2549b8d22eb2f40fbf3feb7
--- /dev/null
+++ b/mapanything/models/external/dinov2/utils/config.py
@@ -0,0 +1,74 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+import logging
+import math
+import os
+
+import dinov2.distributed as distributed
+from dinov2.configs import dinov2_default_config
+from dinov2.logging import setup_logging
+from dinov2.utils import utils
+from omegaconf import OmegaConf
+
+logger = logging.getLogger("dinov2")
+
+
+def apply_scaling_rules_to_cfg(cfg): # to fix
+ if cfg.optim.scaling_rule == "sqrt_wrt_1024":
+ base_lr = cfg.optim.base_lr
+ cfg.optim.lr = base_lr
+ cfg.optim.lr *= math.sqrt(
+ cfg.train.batch_size_per_gpu * distributed.get_global_size() / 1024.0
+ )
+ logger.info(f"sqrt scaling learning rate; base: {base_lr}, new: {cfg.optim.lr}")
+ else:
+ raise NotImplementedError
+ return cfg
+
+
+def write_config(cfg, output_dir, name="config.yaml"):
+ logger.info(OmegaConf.to_yaml(cfg))
+ saved_cfg_path = os.path.join(output_dir, name)
+ with open(saved_cfg_path, "w") as f:
+ OmegaConf.save(config=cfg, f=f)
+ return saved_cfg_path
+
+
+def get_cfg_from_args(args):
+ args.output_dir = os.path.abspath(args.output_dir)
+ args.opts += [f"train.output_dir={args.output_dir}"]
+ default_cfg = OmegaConf.create(dinov2_default_config)
+ cfg = OmegaConf.load(args.config_file)
+ cfg = OmegaConf.merge(default_cfg, cfg, OmegaConf.from_cli(args.opts))
+ return cfg
+
+
+def default_setup(args):
+ distributed.enable(overwrite=True)
+ seed = getattr(args, "seed", 0)
+ rank = distributed.get_global_rank()
+
+ global logger
+ setup_logging(output=args.output_dir, level=logging.INFO)
+ logger = logging.getLogger("dinov2")
+
+ utils.fix_random_seeds(seed + rank)
+ logger.info("git:\n {}\n".format(utils.get_sha()))
+ logger.info(
+ "\n".join("%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(args)).items()))
+ )
+
+
+def setup(args):
+ """
+ Create configs and perform basic setups.
+ """
+ cfg = get_cfg_from_args(args)
+ os.makedirs(args.output_dir, exist_ok=True)
+ default_setup(args)
+ apply_scaling_rules_to_cfg(cfg)
+ write_config(cfg, args.output_dir)
+ return cfg
diff --git a/mapanything/models/external/dinov2/utils/dtype.py b/mapanything/models/external/dinov2/utils/dtype.py
new file mode 100644
index 0000000000000000000000000000000000000000..490c52912ce0c2fa0ad3df812c05b3aec6883a65
--- /dev/null
+++ b/mapanything/models/external/dinov2/utils/dtype.py
@@ -0,0 +1,38 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+
+from typing import Dict, Union
+
+import numpy as np
+import torch
+
+TypeSpec = Union[str, np.dtype, torch.dtype]
+
+
+_NUMPY_TO_TORCH_DTYPE: Dict[np.dtype, torch.dtype] = {
+ np.dtype("bool"): torch.bool,
+ np.dtype("uint8"): torch.uint8,
+ np.dtype("int8"): torch.int8,
+ np.dtype("int16"): torch.int16,
+ np.dtype("int32"): torch.int32,
+ np.dtype("int64"): torch.int64,
+ np.dtype("float16"): torch.float16,
+ np.dtype("float32"): torch.float32,
+ np.dtype("float64"): torch.float64,
+ np.dtype("complex64"): torch.complex64,
+ np.dtype("complex128"): torch.complex128,
+}
+
+
+def as_torch_dtype(dtype: TypeSpec) -> torch.dtype:
+ if isinstance(dtype, torch.dtype):
+ return dtype
+ if isinstance(dtype, str):
+ dtype = np.dtype(dtype)
+ assert isinstance(dtype, np.dtype), (
+ f"Expected an instance of nunpy dtype, got {type(dtype)}"
+ )
+ return _NUMPY_TO_TORCH_DTYPE[dtype]
diff --git a/mapanything/models/external/dinov2/utils/param_groups.py b/mapanything/models/external/dinov2/utils/param_groups.py
new file mode 100644
index 0000000000000000000000000000000000000000..3fccb5ff667afb5f9ca9117fc253a2fc2eb7d206
--- /dev/null
+++ b/mapanything/models/external/dinov2/utils/param_groups.py
@@ -0,0 +1,122 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+import logging
+from collections import defaultdict
+
+logger = logging.getLogger("dinov2")
+
+
+def get_vit_lr_decay_rate(
+ name,
+ lr_decay_rate=1.0,
+ num_layers=12,
+ force_is_backbone=False,
+ chunked_blocks=False,
+):
+ """
+ Calculate lr decay rate for different ViT blocks.
+ Args:
+ name (string): parameter name.
+ lr_decay_rate (float): base lr decay rate.
+ num_layers (int): number of ViT blocks.
+ Returns:
+ lr decay rate for the given parameter.
+ """
+ layer_id = num_layers + 1
+ if name.startswith("backbone") or force_is_backbone:
+ if (
+ ".pos_embed" in name
+ or ".patch_embed" in name
+ or ".mask_token" in name
+ or ".cls_token" in name
+ or ".register_tokens" in name
+ ):
+ layer_id = 0
+ elif force_is_backbone and (
+ "pos_embed" in name
+ or "patch_embed" in name
+ or "mask_token" in name
+ or "cls_token" in name
+ or "register_tokens" in name
+ ):
+ layer_id = 0
+ elif ".blocks." in name and ".residual." not in name:
+ layer_id = int(name[name.find(".blocks.") :].split(".")[2]) + 1
+ elif chunked_blocks and "blocks." in name and "residual." not in name:
+ layer_id = int(name[name.find("blocks.") :].split(".")[2]) + 1
+ elif "blocks." in name and "residual." not in name:
+ layer_id = int(name[name.find("blocks.") :].split(".")[1]) + 1
+
+ return lr_decay_rate ** (num_layers + 1 - layer_id)
+
+
+def get_params_groups_with_decay(model, lr_decay_rate=1.0, patch_embed_lr_mult=1.0):
+ chunked_blocks = False
+ if hasattr(model, "n_blocks"):
+ logger.info("chunked fsdp")
+ n_blocks = model.n_blocks
+ chunked_blocks = model.chunked_blocks
+ elif hasattr(model, "blocks"):
+ logger.info("first code branch")
+ n_blocks = len(model.blocks)
+ elif hasattr(model, "backbone"):
+ logger.info("second code branch")
+ n_blocks = len(model.backbone.blocks)
+ else:
+ logger.info("else code branch")
+ n_blocks = 0
+ all_param_groups = []
+
+ for name, param in model.named_parameters():
+ name = name.replace("_fsdp_wrapped_module.", "")
+ if not param.requires_grad:
+ continue
+ decay_rate = get_vit_lr_decay_rate(
+ name,
+ lr_decay_rate,
+ num_layers=n_blocks,
+ force_is_backbone=n_blocks > 0,
+ chunked_blocks=chunked_blocks,
+ )
+ d = {
+ "params": param,
+ "is_last_layer": False,
+ "lr_multiplier": decay_rate,
+ "wd_multiplier": 1.0,
+ "name": name,
+ }
+
+ if "last_layer" in name:
+ d.update({"is_last_layer": True})
+
+ if name.endswith(".bias") or "norm" in name or "gamma" in name:
+ d.update({"wd_multiplier": 0.0})
+
+ if "patch_embed" in name:
+ d.update({"lr_multiplier": d["lr_multiplier"] * patch_embed_lr_mult})
+
+ all_param_groups.append(d)
+ logger.info(
+ f"""{name}: lr_multiplier: {d["lr_multiplier"]}, wd_multiplier: {d["wd_multiplier"]}"""
+ )
+
+ return all_param_groups
+
+
+def fuse_params_groups(
+ all_params_groups, keys=("lr_multiplier", "wd_multiplier", "is_last_layer")
+):
+ fused_params_groups = defaultdict(lambda: {"params": []})
+ for d in all_params_groups:
+ identifier = ""
+ for k in keys:
+ identifier += k + str(d[k]) + "_"
+
+ for k in keys:
+ fused_params_groups[identifier][k] = d[k]
+ fused_params_groups[identifier]["params"].append(d["params"])
+
+ return fused_params_groups.values()
diff --git a/mapanything/models/external/dinov2/utils/utils.py b/mapanything/models/external/dinov2/utils/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..7ed972fdd47ee587fa50ec7050addf4575ff7ed8
--- /dev/null
+++ b/mapanything/models/external/dinov2/utils/utils.py
@@ -0,0 +1,105 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+import os
+import random
+import subprocess
+from urllib.parse import urlparse
+
+import numpy as np
+import torch
+from torch import nn
+
+# logger = logging.getLogger("dinov2")
+
+
+def load_pretrained_weights(model, pretrained_weights, checkpoint_key):
+ if urlparse(pretrained_weights).scheme: # If it looks like an URL
+ state_dict = torch.hub.load_state_dict_from_url(
+ pretrained_weights, map_location="cpu"
+ )
+ else:
+ state_dict = torch.load(pretrained_weights, map_location="cpu")
+ if checkpoint_key is not None and checkpoint_key in state_dict:
+ # logger.info(f"Take key {checkpoint_key} in provided checkpoint dict")
+ state_dict = state_dict[checkpoint_key]
+ # remove `module.` prefix
+ state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
+ # remove `backbone.` prefix induced by multicrop wrapper
+ state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()}
+ _ = model.load_state_dict(state_dict, strict=False)
+ # logger.info("Pretrained weights found at {} and loaded with msg: {}".format(pretrained_weights, msg))
+
+
+def fix_random_seeds(seed=31):
+ """
+ Fix random seeds.
+ """
+ torch.manual_seed(seed)
+ torch.cuda.manual_seed_all(seed)
+ np.random.seed(seed)
+ random.seed(seed)
+
+
+def get_sha():
+ cwd = os.path.dirname(os.path.abspath(__file__))
+
+ def _run(command):
+ return subprocess.check_output(command, cwd=cwd).decode("ascii").strip()
+
+ sha = "N/A"
+ diff = "clean"
+ branch = "N/A"
+ try:
+ sha = _run(["git", "rev-parse", "HEAD"])
+ subprocess.check_output(["git", "diff"], cwd=cwd)
+ diff = _run(["git", "diff-index", "HEAD"])
+ diff = "has uncommitted changes" if diff else "clean"
+ branch = _run(["git", "rev-parse", "--abbrev-ref", "HEAD"])
+ except Exception:
+ pass
+ message = f"sha: {sha}, status: {diff}, branch: {branch}"
+ return message
+
+
+class CosineScheduler(object):
+ def __init__(
+ self,
+ base_value,
+ final_value,
+ total_iters,
+ warmup_iters=0,
+ start_warmup_value=0,
+ freeze_iters=0,
+ ):
+ super().__init__()
+ self.final_value = final_value
+ self.total_iters = total_iters
+
+ freeze_schedule = np.zeros((freeze_iters))
+
+ warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters)
+
+ iters = np.arange(total_iters - warmup_iters - freeze_iters)
+ schedule = final_value + 0.5 * (base_value - final_value) * (
+ 1 + np.cos(np.pi * iters / len(iters))
+ )
+ self.schedule = np.concatenate((freeze_schedule, warmup_schedule, schedule))
+
+ assert len(self.schedule) == self.total_iters
+
+ def __getitem__(self, it):
+ if it >= self.total_iters:
+ return self.final_value
+ else:
+ return self.schedule[it]
+
+
+def has_batchnorms(model):
+ bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm)
+ for name, module in model.named_modules():
+ if isinstance(module, bn_types):
+ return True
+ return False
diff --git a/mapanything/models/external/dust3r/__init__.py b/mapanything/models/external/dust3r/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..3a0e622cb2268264fa9716f7035c7ee12cbf011b
--- /dev/null
+++ b/mapanything/models/external/dust3r/__init__.py
@@ -0,0 +1,217 @@
+"""
+Inference wrapper for DUSt3R
+"""
+
+import warnings
+
+import torch
+from dust3r.cloud_opt import global_aligner, GlobalAlignerMode
+from dust3r.image_pairs import make_pairs
+from dust3r.inference import inference
+from dust3r.model import AsymmetricCroCo3DStereo # noqa
+
+from mapanything.models.external.vggt.utils.rotation import mat_to_quat
+from mapanything.utils.geometry import (
+ convert_ray_dirs_depth_along_ray_pose_trans_quats_to_pointmap,
+ convert_z_depth_to_depth_along_ray,
+ depthmap_to_camera_frame,
+ get_rays_in_camera_frame,
+)
+
+inf = float("inf")
+
+
+def load_model(model_path, device, verbose=True):
+ if verbose:
+ print("Loading model from", model_path)
+ ckpt = torch.load(model_path, map_location="cpu", weights_only=False)
+ args = ckpt["args"].model.replace("ManyAR_PatchEmbed", "PatchEmbedDust3R")
+ if "landscape_only" not in args:
+ args = args[:-1] + ", landscape_only=False)"
+ else:
+ args = args.replace(" ", "").replace(
+ "landscape_only=True", "landscape_only=False"
+ )
+ assert "landscape_only=False" in args
+ if verbose:
+ print(f"Instantiating: {args}")
+ try:
+ net = eval(args)
+ except NameError:
+ net = AsymmetricCroCo3DStereo(
+ enc_depth=24,
+ dec_depth=12,
+ enc_embed_dim=1024,
+ dec_embed_dim=768,
+ enc_num_heads=16,
+ dec_num_heads=12,
+ pos_embed="RoPE100",
+ patch_embed_cls="PatchEmbedDust3R",
+ img_size=(512, 512),
+ head_type="dpt",
+ output_mode="pts3d",
+ depth_mode=("exp", -inf, inf),
+ conf_mode=("exp", 1, inf),
+ landscape_only=False,
+ )
+ s = net.load_state_dict(ckpt["model"], strict=False)
+ if verbose:
+ print(s)
+ return net.to(device)
+
+
+class DUSt3RBAWrapper(torch.nn.Module):
+ def __init__(
+ self,
+ name,
+ ckpt_path,
+ scene_graph="complete",
+ inference_batch_size=32,
+ global_optim_schedule="cosine",
+ global_optim_lr=0.01,
+ global_optim_niter=300,
+ **kwargs,
+ ):
+ super().__init__()
+ self.name = name
+ self.ckpt_path = ckpt_path
+ self.scene_graph = scene_graph
+ self.inference_batch_size = inference_batch_size
+ self.global_optim_schedule = global_optim_schedule
+ self.global_optim_lr = global_optim_lr
+ self.global_optim_niter = global_optim_niter
+
+ # Init the model and load the checkpoint
+ self.model = load_model(self.ckpt_path, device="cpu")
+
+ # Init the global aligner mode
+ self.global_aligner_mode = GlobalAlignerMode.PointCloudOptimizer
+
+ def forward(self, views):
+ """
+ Forward pass wrapper for DUSt3R using the global aligner.
+
+ Assumption:
+ - The batch size of input views is 1.
+
+ Args:
+ views (List[dict]): List of dictionaries containing the input views' images and instance information.
+ Each dictionary should contain the following keys, where B is the batch size and is 1:
+ "img" (tensor): Image tensor of shape (B, C, H, W).
+ "data_norm_type" (list): ["dust3r"]
+
+ Returns:
+ List[dict]: A list containing the final outputs for the input views.
+ """
+ # Check the batch size of input views
+ batch_size_per_view, _, height, width = views[0]["img"].shape
+ device = views[0]["img"].device
+ num_views = len(views)
+ assert batch_size_per_view == 1, (
+ f"Batch size of input views should be 1, but got {batch_size_per_view}."
+ )
+
+ # Check the data norm type
+ data_norm_type = views[0]["data_norm_type"][0]
+ assert data_norm_type == "dust3r", (
+ "DUSt3R expects a normalized image with the DUSt3R normalization scheme applied"
+ )
+
+ # Convert the input views to the expected input format
+ images = []
+ for view in views:
+ images.append(
+ dict(
+ img=view["img"],
+ idx=len(images),
+ instance=str(len(images)),
+ )
+ )
+
+ # Make image pairs and run inference pair-wise
+ pairs = make_pairs(
+ images, scene_graph=self.scene_graph, prefilter=None, symmetrize=True
+ )
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore", category=FutureWarning)
+ output = inference(
+ pairs,
+ self.model,
+ device,
+ batch_size=self.inference_batch_size,
+ verbose=False,
+ )
+
+ # Global optimization
+ with torch.enable_grad():
+ scene = global_aligner(
+ output, device=device, mode=self.global_aligner_mode, verbose=False
+ )
+ _ = scene.compute_global_alignment(
+ init="mst",
+ niter=self.global_optim_niter,
+ schedule=self.global_optim_schedule,
+ lr=self.global_optim_lr,
+ )
+
+ # Make sure scene is not None
+ if scene is None:
+ raise RuntimeError("Global optimization failed.")
+
+ # Get the predictions
+ intrinsics = scene.get_intrinsics()
+ c2w_poses = scene.get_im_poses()
+ depths = scene.get_depthmaps()
+
+ # Convert the output to the MapAnything format
+ with torch.autocast("cuda", enabled=False):
+ res = []
+ for view_idx in range(num_views):
+ # Get the current view predictions
+ curr_view_intrinsic = intrinsics[view_idx].unsqueeze(0)
+ curr_view_pose = c2w_poses[view_idx].unsqueeze(0)
+ curr_view_depth_z = depths[view_idx].unsqueeze(0)
+
+ # Convert the pose to quaternions and translation
+ curr_view_cam_translations = curr_view_pose[..., :3, 3]
+ curr_view_cam_quats = mat_to_quat(curr_view_pose[..., :3, :3])
+
+ # Get the camera frame pointmaps
+ curr_view_pts3d_cam, _ = depthmap_to_camera_frame(
+ curr_view_depth_z, curr_view_intrinsic
+ )
+
+ # Convert the z depth to depth along ray
+ curr_view_depth_along_ray = convert_z_depth_to_depth_along_ray(
+ curr_view_depth_z, curr_view_intrinsic
+ )
+ curr_view_depth_along_ray = curr_view_depth_along_ray.unsqueeze(-1)
+
+ # Get the ray directions on the unit sphere in the camera frame
+ _, curr_view_ray_dirs = get_rays_in_camera_frame(
+ curr_view_intrinsic, height, width, normalize_to_unit_sphere=True
+ )
+
+ # Get the pointmaps
+ curr_view_pts3d = (
+ convert_ray_dirs_depth_along_ray_pose_trans_quats_to_pointmap(
+ curr_view_ray_dirs,
+ curr_view_depth_along_ray,
+ curr_view_cam_translations,
+ curr_view_cam_quats,
+ )
+ )
+
+ # Append the outputs to the result list
+ res.append(
+ {
+ "pts3d": curr_view_pts3d,
+ "pts3d_cam": curr_view_pts3d_cam,
+ "ray_directions": curr_view_ray_dirs,
+ "depth_along_ray": curr_view_depth_along_ray,
+ "cam_trans": curr_view_cam_translations,
+ "cam_quats": curr_view_cam_quats,
+ }
+ )
+
+ return res
diff --git a/mapanything/models/external/mast3r/__init__.py b/mapanything/models/external/mast3r/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..65c03e2bf9750b6a5e0c1261a314d7a82cd5e2f6
--- /dev/null
+++ b/mapanything/models/external/mast3r/__init__.py
@@ -0,0 +1,191 @@
+"""
+Inference wrapper for MASt3R + Sparse GA
+"""
+
+import os
+import tempfile
+import warnings
+
+import torch
+from dust3r.image_pairs import make_pairs
+from mast3r.cloud_opt.sparse_ga import sparse_global_alignment
+from mast3r.model import load_model
+
+from mapanything.models.external.vggt.utils.rotation import mat_to_quat
+from mapanything.utils.geometry import (
+ convert_ray_dirs_depth_along_ray_pose_trans_quats_to_pointmap,
+ convert_z_depth_to_depth_along_ray,
+ depthmap_to_camera_frame,
+ get_rays_in_camera_frame,
+)
+
+
+class MASt3RSGAWrapper(torch.nn.Module):
+ def __init__(
+ self,
+ name,
+ ckpt_path,
+ cache_dir,
+ scene_graph="complete",
+ sparse_ga_lr1=0.07,
+ sparse_ga_niter1=300,
+ sparse_ga_lr2=0.01,
+ sparse_ga_niter2=300,
+ sparse_ga_optim_level="refine+depth",
+ sparse_ga_shared_intrinsics=False,
+ sparse_ga_matching_conf_thr=5.0,
+ **kwargs,
+ ):
+ super().__init__()
+ self.name = name
+ self.ckpt_path = ckpt_path
+ self.cache_dir = cache_dir
+ self.scene_graph = scene_graph
+ self.sparse_ga_lr1 = sparse_ga_lr1
+ self.sparse_ga_niter1 = sparse_ga_niter1
+ self.sparse_ga_lr2 = sparse_ga_lr2
+ self.sparse_ga_niter2 = sparse_ga_niter2
+ self.sparse_ga_optim_level = sparse_ga_optim_level
+ self.sparse_ga_shared_intrinsics = sparse_ga_shared_intrinsics
+ self.sparse_ga_matching_conf_thr = sparse_ga_matching_conf_thr
+
+ # Init the model and load the checkpoint
+ self.model = load_model(self.ckpt_path, device="cpu")
+
+ def forward(self, views):
+ """
+ Forward pass wrapper for MASt3R using the sparse global aligner.
+
+ Assumption:
+ - The batch size of input views is 1.
+
+ Args:
+ views (List[dict]): List of dictionaries containing the input views' images and instance information.
+ Each dictionary should contain the following keys, where B is the batch size and is 1:
+ "img" (tensor): Image tensor of shape (B, C, H, W).
+ "data_norm_type" (list): ["dust3r"]
+ "label" (list): ["scene_name"]
+ "instance" (list): ["image_name"]
+
+ Returns:
+ List[dict]: A list containing the final outputs for the input views.
+ """
+ # Check the batch size of input views
+ batch_size_per_view, _, height, width = views[0]["img"].shape
+ device = views[0]["img"].device
+ num_views = len(views)
+ assert batch_size_per_view == 1, (
+ f"Batch size of input views should be 1, but got {batch_size_per_view}."
+ )
+
+ # Check the data norm type
+ data_norm_type = views[0]["data_norm_type"][0]
+ assert data_norm_type == "dust3r", (
+ "MASt3R expects a normalized image with the DUSt3R normalization scheme applied"
+ )
+
+ # Convert the input views to the expected input format
+ images = []
+ image_paths = []
+ for view in views:
+ images.append(
+ dict(
+ img=view["img"].cpu(),
+ idx=len(images),
+ instance=str(len(images)),
+ true_shape=torch.tensor(view["img"].shape[-2:])[None]
+ .repeat(batch_size_per_view, 1)
+ .numpy(),
+ )
+ )
+ view_name = os.path.join(view["label"][0], view["instance"][0])
+ image_paths.append(view_name)
+
+ # Make image pairs and run inference
+ # Sparse GA (forward mast3r -> matching -> 3D optim -> 2D refinement -> triangulation)
+ pairs = make_pairs(
+ images, scene_graph=self.scene_graph, prefilter=None, symmetrize=True
+ )
+ with torch.enable_grad():
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore", category=FutureWarning)
+ tempfile.mkdtemp(dir=self.cache_dir)
+ scene = sparse_global_alignment(
+ image_paths,
+ pairs,
+ self.cache_dir,
+ self.model,
+ lr1=self.sparse_ga_lr1,
+ niter1=self.sparse_ga_niter1,
+ lr2=self.sparse_ga_lr2,
+ niter2=self.sparse_ga_niter2,
+ device=device,
+ opt_depth="depth" in self.sparse_ga_optim_level,
+ shared_intrinsics=self.sparse_ga_shared_intrinsics,
+ matching_conf_thr=self.sparse_ga_matching_conf_thr,
+ verbose=False,
+ )
+
+ # Make sure scene is not None
+ if scene is None:
+ raise RuntimeError("Global optimization failed.")
+
+ # Get the predictions
+ intrinsics = scene.intrinsics
+ c2w_poses = scene.get_im_poses()
+ _, depths, _ = scene.get_dense_pts3d()
+
+ # Convert the output to the MapAnything format
+ with torch.autocast("cuda", enabled=False):
+ res = []
+ for view_idx in range(num_views):
+ # Get the current view predictions
+ curr_view_intrinsic = intrinsics[view_idx].unsqueeze(0)
+ curr_view_pose = c2w_poses[view_idx].unsqueeze(0)
+ curr_view_depth_z = (
+ depths[view_idx].reshape((height, width)).unsqueeze(0)
+ )
+
+ # Convert the pose to quaternions and translation
+ curr_view_cam_translations = curr_view_pose[..., :3, 3]
+ curr_view_cam_quats = mat_to_quat(curr_view_pose[..., :3, :3])
+
+ # Get the camera frame pointmaps
+ curr_view_pts3d_cam, _ = depthmap_to_camera_frame(
+ curr_view_depth_z, curr_view_intrinsic
+ )
+
+ # Convert the z depth to depth along ray
+ curr_view_depth_along_ray = convert_z_depth_to_depth_along_ray(
+ curr_view_depth_z, curr_view_intrinsic
+ )
+ curr_view_depth_along_ray = curr_view_depth_along_ray.unsqueeze(-1)
+
+ # Get the ray directions on the unit sphere in the camera frame
+ _, curr_view_ray_dirs = get_rays_in_camera_frame(
+ curr_view_intrinsic, height, width, normalize_to_unit_sphere=True
+ )
+
+ # Get the pointmaps
+ curr_view_pts3d = (
+ convert_ray_dirs_depth_along_ray_pose_trans_quats_to_pointmap(
+ curr_view_ray_dirs,
+ curr_view_depth_along_ray,
+ curr_view_cam_translations,
+ curr_view_cam_quats,
+ )
+ )
+
+ # Append the outputs to the result list
+ res.append(
+ {
+ "pts3d": curr_view_pts3d,
+ "pts3d_cam": curr_view_pts3d_cam,
+ "ray_directions": curr_view_ray_dirs,
+ "depth_along_ray": curr_view_depth_along_ray,
+ "cam_trans": curr_view_cam_translations,
+ "cam_quats": curr_view_cam_quats,
+ }
+ )
+
+ return res
diff --git a/mapanything/models/external/moge/__init__.py b/mapanything/models/external/moge/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e188fb0f0bb451489dd546b189a7ca05b2b0ee6c
--- /dev/null
+++ b/mapanything/models/external/moge/__init__.py
@@ -0,0 +1,114 @@
+"""
+Inference wrapper for MoGe
+"""
+
+import torch
+
+from mapanything.models.external.moge.models.v1 import MoGeModel as MoGeModelV1
+from mapanything.models.external.moge.models.v2 import MoGeModel as MoGeModelV2
+
+
+class MoGeWrapper(torch.nn.Module):
+ def __init__(
+ self,
+ name,
+ model_string="Ruicheng/moge-2-vitl",
+ torch_hub_force_reload=False,
+ load_custom_ckpt=False,
+ custom_ckpt_path=None,
+ ):
+ super().__init__()
+ self.name = name
+ self.model_string = model_string
+ self.torch_hub_force_reload = torch_hub_force_reload
+ self.load_custom_ckpt = load_custom_ckpt
+ self.custom_ckpt_path = custom_ckpt_path
+
+ # Mapping of MoGe model version to checkpoint strings
+ self.moge_model_map = {
+ "v1": ["Ruicheng/moge-vitl"],
+ "v2": [
+ "Ruicheng/moge-2-vits-normal",
+ "Ruicheng/moge-2-vitb-normal",
+ "Ruicheng/moge-2-vitl-normal",
+ "Ruicheng/moge-2-vitl",
+ ],
+ }
+
+ # Initialize the model
+ if self.model_string in self.moge_model_map["v1"]:
+ self.model = MoGeModelV1.from_pretrained(self.model_string)
+ elif self.model_string in self.moge_model_map["v2"]:
+ self.model = MoGeModelV2.from_pretrained(self.model_string)
+ else:
+ raise ValueError(
+ f"Invalid model string: {self.model_string}. Valid strings are: {self.moge_model_map}"
+ )
+
+ # Load custom checkpoint if requested
+ if self.load_custom_ckpt:
+ print(f"Loading checkpoint from {self.custom_ckpt_path} ...")
+ assert self.custom_ckpt_path is not None, (
+ "custom_ckpt_path must be provided if load_custom_ckpt is set to True"
+ )
+ custom_ckpt = torch.load(self.custom_ckpt_path, weights_only=False)
+ print(self.model.load_state_dict(custom_ckpt, strict=True))
+ del custom_ckpt # in case it occupies memory
+
+ def forward(self, views):
+ """
+ Forward pass wrapper for MoGe-2.
+ The predicted MoGe-2 mask is not applied to the outputs.
+ The number of tokens for inference is determined by the image shape.
+
+ Assumption:
+ - The number of input views is 1.
+
+ Args:
+ views (List[dict]): List of dictionaries containing the input views' images and instance information.
+ Length of the list should be 1.
+ Each dictionary should contain the following keys:
+ "img" (tensor): Image tensor of shape (B, C, H, W).
+ "data_norm_type" (list): ["identity"]
+
+ Returns:
+ List[dict]: A list containing the final outputs for the single view. Length of the list will be 1.
+ """
+ # Check that the number of input views is 1
+ assert len(views) == 1, "MoGe only supports 1 input view."
+
+ # Get input shape of the images, number of tokens for inference, and batch size per view
+ _, _, height, width = views[0]["img"].shape
+ num_tokens = int(height // 14) * int(width // 14)
+
+ # Check the data norm type
+ # MoGe expects a normalized image but without the DINOv2 mean and std applied ("identity")
+ data_norm_type = views[0]["data_norm_type"][0]
+ assert data_norm_type == "identity", (
+ "MoGe expects a normalized image but without the DINOv2 mean and std applied"
+ )
+
+ # Run MoGe inference
+ # Output dict contains: "points", "depth", "mask", "intrinsics", "normal" (based on model config)
+ model_outputs = self.model.infer(
+ image=views[0]["img"], num_tokens=num_tokens, apply_mask=False
+ )
+
+ # Get the ray directions and depth along ray
+ with torch.autocast("cuda", enabled=False):
+ depth_along_ray = torch.norm(model_outputs["points"], dim=-1, keepdim=True)
+ ray_directions = model_outputs["points"] / depth_along_ray
+
+ # Convert the output to MapAnything format
+ result_dict = {
+ "pts3d": model_outputs["points"],
+ "pts3d_cam": model_outputs["points"],
+ "depth_z": model_outputs["depth"].unsqueeze(-1),
+ "intrinsics": model_outputs["intrinsics"],
+ "non_ambiguous_mask": model_outputs["mask"],
+ "ray_directions": ray_directions,
+ "depth_along_ray": depth_along_ray,
+ }
+ res = [result_dict]
+
+ return res
diff --git a/mapanything/models/external/moge/models/modules.py b/mapanything/models/external/moge/models/modules.py
new file mode 100644
index 0000000000000000000000000000000000000000..1e043408573a6bce823e9af87cb74c942561e7e5
--- /dev/null
+++ b/mapanything/models/external/moge/models/modules.py
@@ -0,0 +1,467 @@
+import functools
+import importlib
+import itertools
+from typing import List, Literal, Optional, Sequence, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from mapanything.models.external.dinov2.models.vision_transformer import (
+ DinoVisionTransformer,
+)
+from mapanything.models.external.moge.models.utils import (
+ wrap_dinov2_attention_with_sdpa,
+ wrap_module_with_gradient_checkpointing,
+)
+
+
+class ResidualConvBlock(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int = None,
+ hidden_channels: int = None,
+ kernel_size: int = 3,
+ padding_mode: str = "replicate",
+ activation: Literal["relu", "leaky_relu", "silu", "elu"] = "relu",
+ in_norm: Literal[
+ "group_norm", "layer_norm", "instance_norm", "none"
+ ] = "layer_norm",
+ hidden_norm: Literal[
+ "group_norm", "layer_norm", "instance_norm"
+ ] = "group_norm",
+ ):
+ super(ResidualConvBlock, self).__init__()
+ if out_channels is None:
+ out_channels = in_channels
+ if hidden_channels is None:
+ hidden_channels = in_channels
+
+ if activation == "relu":
+ activation_cls = nn.ReLU
+ elif activation == "leaky_relu":
+ activation_cls = functools.partial(nn.LeakyReLU, negative_slope=0.2)
+ elif activation == "silu":
+ activation_cls = nn.SiLU
+ elif activation == "elu":
+ activation_cls = nn.ELU
+ else:
+ raise ValueError(f"Unsupported activation function: {activation}")
+
+ self.layers = nn.Sequential(
+ nn.GroupNorm(in_channels // 32, in_channels)
+ if in_norm == "group_norm"
+ else nn.GroupNorm(1, in_channels)
+ if in_norm == "layer_norm"
+ else nn.InstanceNorm2d(in_channels)
+ if in_norm == "instance_norm"
+ else nn.Identity(),
+ activation_cls(),
+ nn.Conv2d(
+ in_channels,
+ hidden_channels,
+ kernel_size=kernel_size,
+ padding=kernel_size // 2,
+ padding_mode=padding_mode,
+ ),
+ nn.GroupNorm(hidden_channels // 32, hidden_channels)
+ if hidden_norm == "group_norm"
+ else nn.GroupNorm(1, hidden_channels)
+ if hidden_norm == "layer_norm"
+ else nn.InstanceNorm2d(hidden_channels)
+ if hidden_norm == "instance_norm"
+ else nn.Identity(),
+ activation_cls(),
+ nn.Conv2d(
+ hidden_channels,
+ out_channels,
+ kernel_size=kernel_size,
+ padding=kernel_size // 2,
+ padding_mode=padding_mode,
+ ),
+ )
+
+ self.skip_connection = (
+ nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0)
+ if in_channels != out_channels
+ else nn.Identity()
+ )
+
+ def forward(self, x):
+ skip = self.skip_connection(x)
+ x = self.layers(x)
+ x = x + skip
+ return x
+
+
+class DINOv2Encoder(nn.Module):
+ "Wrapped DINOv2 encoder supporting gradient checkpointing. Input is RGB image in range [0, 1]."
+
+ backbone: DinoVisionTransformer
+ image_mean: torch.Tensor
+ image_std: torch.Tensor
+ dim_features: int
+
+ def __init__(
+ self,
+ backbone: str,
+ intermediate_layers: Union[int, List[int]],
+ dim_out: int,
+ **deprecated_kwargs,
+ ):
+ super(DINOv2Encoder, self).__init__()
+
+ self.intermediate_layers = intermediate_layers
+
+ # Load the backbone
+ self.hub_loader = getattr(
+ importlib.import_module(
+ "mapanything.models.external.dinov2.hub.backbones", __package__
+ ),
+ backbone,
+ )
+ self.backbone_name = backbone
+ self.backbone = self.hub_loader(pretrained=False)
+
+ self.dim_features = self.backbone.blocks[0].attn.qkv.in_features
+ self.num_features = (
+ intermediate_layers
+ if isinstance(intermediate_layers, int)
+ else len(intermediate_layers)
+ )
+
+ self.output_projections = nn.ModuleList(
+ [
+ nn.Conv2d(
+ in_channels=self.dim_features,
+ out_channels=dim_out,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ )
+ for _ in range(self.num_features)
+ ]
+ )
+
+ self.register_buffer(
+ "image_mean", torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
+ )
+ self.register_buffer(
+ "image_std", torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
+ )
+
+ @property
+ def onnx_compatible_mode(self):
+ return getattr(self, "_onnx_compatible_mode", False)
+
+ @onnx_compatible_mode.setter
+ def onnx_compatible_mode(self, value: bool):
+ self._onnx_compatible_mode = value
+ self.backbone.onnx_compatible_mode = value
+
+ def init_weights(self):
+ pretrained_backbone_state_dict = self.hub_loader(pretrained=True).state_dict()
+ self.backbone.load_state_dict(pretrained_backbone_state_dict)
+
+ def enable_gradient_checkpointing(self):
+ for i in range(len(self.backbone.blocks)):
+ wrap_module_with_gradient_checkpointing(self.backbone.blocks[i])
+
+ def enable_pytorch_native_sdpa(self):
+ for i in range(len(self.backbone.blocks)):
+ wrap_dinov2_attention_with_sdpa(self.backbone.blocks[i].attn)
+
+ def forward(
+ self,
+ image: torch.Tensor,
+ token_rows: Union[int, torch.LongTensor],
+ token_cols: Union[int, torch.LongTensor],
+ return_class_token: bool = False,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ image_14 = F.interpolate(
+ image,
+ (token_rows * 14, token_cols * 14),
+ mode="bilinear",
+ align_corners=False,
+ antialias=not self.onnx_compatible_mode,
+ )
+ image_14 = (image_14 - self.image_mean) / self.image_std
+
+ # Get intermediate layers from the backbone
+ features = self.backbone.get_intermediate_layers(
+ image_14, n=self.intermediate_layers, return_class_token=True
+ )
+
+ # Project features to the desired dimensionality
+ x = torch.stack(
+ [
+ proj(
+ feat.permute(0, 2, 1)
+ .unflatten(2, (token_rows, token_cols))
+ .contiguous()
+ )
+ for proj, (feat, clstoken) in zip(self.output_projections, features)
+ ],
+ dim=1,
+ ).sum(dim=1)
+
+ if return_class_token:
+ return x, features[-1][1]
+ else:
+ return x
+
+
+class Resampler(nn.Sequential):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ type_: Literal[
+ "pixel_shuffle",
+ "nearest",
+ "bilinear",
+ "conv_transpose",
+ "pixel_unshuffle",
+ "avg_pool",
+ "max_pool",
+ ],
+ scale_factor: int = 2,
+ ):
+ if type_ == "pixel_shuffle":
+ nn.Sequential.__init__(
+ self,
+ nn.Conv2d(
+ in_channels,
+ out_channels * (scale_factor**2),
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ padding_mode="replicate",
+ ),
+ nn.PixelShuffle(scale_factor),
+ nn.Conv2d(
+ out_channels,
+ out_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ padding_mode="replicate",
+ ),
+ )
+ for i in range(1, scale_factor**2):
+ self[0].weight.data[i :: scale_factor**2] = self[0].weight.data[
+ 0 :: scale_factor**2
+ ]
+ self[0].bias.data[i :: scale_factor**2] = self[0].bias.data[
+ 0 :: scale_factor**2
+ ]
+ elif type_ in ["nearest", "bilinear"]:
+ nn.Sequential.__init__(
+ self,
+ nn.Upsample(
+ scale_factor=scale_factor,
+ mode=type_,
+ align_corners=False if type_ == "bilinear" else None,
+ ),
+ nn.Conv2d(
+ in_channels,
+ out_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ padding_mode="replicate",
+ ),
+ )
+ elif type_ == "conv_transpose":
+ nn.Sequential.__init__(
+ self,
+ nn.ConvTranspose2d(
+ in_channels,
+ out_channels,
+ kernel_size=scale_factor,
+ stride=scale_factor,
+ ),
+ nn.Conv2d(
+ out_channels,
+ out_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ padding_mode="replicate",
+ ),
+ )
+ self[0].weight.data[:] = self[0].weight.data[:, :, :1, :1]
+ elif type_ == "pixel_unshuffle":
+ nn.Sequential.__init__(
+ self,
+ nn.PixelUnshuffle(scale_factor),
+ nn.Conv2d(
+ in_channels * (scale_factor**2),
+ out_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ padding_mode="replicate",
+ ),
+ )
+ elif type_ == "avg_pool":
+ nn.Sequential.__init__(
+ self,
+ nn.Conv2d(
+ in_channels,
+ out_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ padding_mode="replicate",
+ ),
+ nn.AvgPool2d(kernel_size=scale_factor, stride=scale_factor),
+ )
+ elif type_ == "max_pool":
+ nn.Sequential.__init__(
+ self,
+ nn.Conv2d(
+ in_channels,
+ out_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ padding_mode="replicate",
+ ),
+ nn.MaxPool2d(kernel_size=scale_factor, stride=scale_factor),
+ )
+ else:
+ raise ValueError(f"Unsupported resampler type: {type_}")
+
+
+class MLP(nn.Sequential):
+ def __init__(self, dims: Sequence[int]):
+ nn.Sequential.__init__(
+ self,
+ *itertools.chain(
+ *[
+ (nn.Linear(dim_in, dim_out), nn.ReLU(inplace=True))
+ for dim_in, dim_out in zip(dims[:-2], dims[1:-1])
+ ]
+ ),
+ nn.Linear(dims[-2], dims[-1]),
+ )
+
+
+class ConvStack(nn.Module):
+ def __init__(
+ self,
+ dim_in: List[Optional[int]],
+ dim_res_blocks: List[int],
+ dim_out: List[Optional[int]],
+ resamplers: Union[
+ Literal[
+ "pixel_shuffle",
+ "nearest",
+ "bilinear",
+ "conv_transpose",
+ "pixel_unshuffle",
+ "avg_pool",
+ "max_pool",
+ ],
+ List,
+ ],
+ dim_times_res_block_hidden: int = 1,
+ num_res_blocks: int = 1,
+ res_block_in_norm: Literal[
+ "layer_norm", "group_norm", "instance_norm", "none"
+ ] = "layer_norm",
+ res_block_hidden_norm: Literal[
+ "layer_norm", "group_norm", "instance_norm", "none"
+ ] = "group_norm",
+ activation: Literal["relu", "leaky_relu", "silu", "elu"] = "relu",
+ ):
+ super().__init__()
+ self.input_blocks = nn.ModuleList(
+ [
+ nn.Conv2d(dim_in_, dim_res_block_, kernel_size=1, stride=1, padding=0)
+ if dim_in_ is not None
+ else nn.Identity()
+ for dim_in_, dim_res_block_ in zip(
+ dim_in
+ if isinstance(dim_in, Sequence)
+ else itertools.repeat(dim_in),
+ dim_res_blocks,
+ )
+ ]
+ )
+ self.resamplers = nn.ModuleList(
+ [
+ Resampler(dim_prev, dim_succ, scale_factor=2, type_=resampler)
+ for i, (dim_prev, dim_succ, resampler) in enumerate(
+ zip(
+ dim_res_blocks[:-1],
+ dim_res_blocks[1:],
+ resamplers
+ if isinstance(resamplers, Sequence)
+ else itertools.repeat(resamplers),
+ )
+ )
+ ]
+ )
+ self.res_blocks = nn.ModuleList(
+ [
+ nn.Sequential(
+ *(
+ ResidualConvBlock(
+ dim_res_block_,
+ dim_res_block_,
+ dim_times_res_block_hidden * dim_res_block_,
+ activation=activation,
+ in_norm=res_block_in_norm,
+ hidden_norm=res_block_hidden_norm,
+ )
+ for _ in range(
+ num_res_blocks[i]
+ if isinstance(num_res_blocks, list)
+ else num_res_blocks
+ )
+ )
+ )
+ for i, dim_res_block_ in enumerate(dim_res_blocks)
+ ]
+ )
+ self.output_blocks = nn.ModuleList(
+ [
+ nn.Conv2d(dim_res_block_, dim_out_, kernel_size=1, stride=1, padding=0)
+ if dim_out_ is not None
+ else nn.Identity()
+ for dim_out_, dim_res_block_ in zip(
+ dim_out
+ if isinstance(dim_out, Sequence)
+ else itertools.repeat(dim_out),
+ dim_res_blocks,
+ )
+ ]
+ )
+
+ def enable_gradient_checkpointing(self):
+ for i in range(len(self.resamplers)):
+ self.resamplers[i] = wrap_module_with_gradient_checkpointing(
+ self.resamplers[i]
+ )
+ for i in range(len(self.res_blocks)):
+ for j in range(len(self.res_blocks[i])):
+ self.res_blocks[i][j] = wrap_module_with_gradient_checkpointing(
+ self.res_blocks[i][j]
+ )
+
+ def forward(self, in_features: List[torch.Tensor]):
+ out_features = []
+ for i in range(len(self.res_blocks)):
+ feature = self.input_blocks[i](in_features[i])
+ if i == 0:
+ x = feature
+ elif feature is not None:
+ x = x + feature
+ x = self.res_blocks[i](x)
+ out_features.append(self.output_blocks[i](x))
+ if i < len(self.res_blocks) - 1:
+ x = self.resamplers[i](x)
+ return out_features
diff --git a/mapanything/models/external/moge/models/utils.py b/mapanything/models/external/moge/models/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..0052a758dd6ffc1fa7579e423c8a925aa146ac54
--- /dev/null
+++ b/mapanything/models/external/moge/models/utils.py
@@ -0,0 +1,477 @@
+import inspect
+from functools import partial, wraps
+from numbers import Number
+from typing import Tuple, Union
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+def wrap_module_with_gradient_checkpointing(module: nn.Module):
+ from torch.utils.checkpoint import checkpoint
+
+ class _CheckpointingWrapper(module.__class__):
+ _restore_cls = module.__class__
+
+ def forward(self, *args, **kwargs):
+ return checkpoint(super().forward, *args, use_reentrant=False, **kwargs)
+
+ module.__class__ = _CheckpointingWrapper
+ return module
+
+
+def unwrap_module_with_gradient_checkpointing(module: nn.Module):
+ module.__class__ = module.__class__._restore_cls
+
+
+def wrap_dinov2_attention_with_sdpa(module: nn.Module):
+ assert torch.__version__ >= "2.0", "SDPA requires PyTorch 2.0 or later"
+
+ class _AttentionWrapper(module.__class__):
+ def forward(self, x: torch.Tensor, attn_bias=None) -> torch.Tensor:
+ B, N, C = x.shape
+ qkv = (
+ self.qkv(x)
+ .reshape(B, N, 3, self.num_heads, C // self.num_heads)
+ .permute(2, 0, 3, 1, 4)
+ ) # (3, B, H, N, C // H)
+
+ q, k, v = torch.unbind(qkv, 0) # (B, H, N, C // H)
+
+ x = F.scaled_dot_product_attention(q, k, v, attn_bias)
+ x = x.permute(0, 2, 1, 3).reshape(B, N, C)
+
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+ module.__class__ = _AttentionWrapper
+ return module
+
+
+def sync_ddp_hook(
+ state, bucket: torch.distributed.GradBucket
+) -> torch.futures.Future[torch.Tensor]:
+ group_to_use = torch.distributed.group.WORLD
+ world_size = group_to_use.size()
+ grad = bucket.buffer()
+ grad.div_(world_size)
+ torch.distributed.all_reduce(grad, group=group_to_use)
+ fut = torch.futures.Future()
+ fut.set_result(grad)
+ return fut
+
+
+def normalized_view_plane_uv(
+ width: int,
+ height: int,
+ aspect_ratio: float = None,
+ dtype: torch.dtype = None,
+ device: torch.device = None,
+) -> torch.Tensor:
+ "UV with left-top corner as (-width / diagonal, -height / diagonal) and right-bottom corner as (width / diagonal, height / diagonal)"
+ if aspect_ratio is None:
+ aspect_ratio = width / height
+
+ span_x = aspect_ratio / (1 + aspect_ratio**2) ** 0.5
+ span_y = 1 / (1 + aspect_ratio**2) ** 0.5
+
+ u = torch.linspace(
+ -span_x * (width - 1) / width,
+ span_x * (width - 1) / width,
+ width,
+ dtype=dtype,
+ device=device,
+ )
+ v = torch.linspace(
+ -span_y * (height - 1) / height,
+ span_y * (height - 1) / height,
+ height,
+ dtype=dtype,
+ device=device,
+ )
+ u, v = torch.meshgrid(u, v, indexing="xy")
+ uv = torch.stack([u, v], dim=-1)
+ return uv
+
+
+def solve_optimal_focal_shift(uv: np.ndarray, xyz: np.ndarray):
+ "Solve `min |focal * xy / (z + shift) - uv|` with respect to shift and focal"
+ from scipy.optimize import least_squares
+
+ uv, xy, z = uv.reshape(-1, 2), xyz[..., :2].reshape(-1, 2), xyz[..., 2].reshape(-1)
+
+ def fn(uv: np.ndarray, xy: np.ndarray, z: np.ndarray, shift: np.ndarray):
+ xy_proj = xy / (z + shift)[:, None]
+ f = (xy_proj * uv).sum() / np.square(xy_proj).sum()
+ err = (f * xy_proj - uv).ravel()
+ return err
+
+ solution = least_squares(partial(fn, uv, xy, z), x0=0, ftol=1e-3, method="lm")
+ optim_shift = solution["x"].squeeze().astype(np.float32)
+
+ xy_proj = xy / (z + optim_shift)[:, None]
+ optim_focal = (xy_proj * uv).sum() / np.square(xy_proj).sum()
+
+ return optim_shift, optim_focal
+
+
+def solve_optimal_shift(uv: np.ndarray, xyz: np.ndarray, focal: float):
+ "Solve `min |focal * xy / (z + shift) - uv|` with respect to shift"
+ from scipy.optimize import least_squares
+
+ uv, xy, z = uv.reshape(-1, 2), xyz[..., :2].reshape(-1, 2), xyz[..., 2].reshape(-1)
+
+ def fn(uv: np.ndarray, xy: np.ndarray, z: np.ndarray, shift: np.ndarray):
+ xy_proj = xy / (z + shift)[:, None]
+ err = (focal * xy_proj - uv).ravel()
+ return err
+
+ solution = least_squares(partial(fn, uv, xy, z), x0=0, ftol=1e-3, method="lm")
+ optim_shift = solution["x"].squeeze().astype(np.float32)
+
+ return optim_shift
+
+
+def recover_focal_shift(
+ points: torch.Tensor,
+ mask: torch.Tensor = None,
+ focal: torch.Tensor = None,
+ downsample_size: Tuple[int, int] = (64, 64),
+):
+ """
+ Recover the depth map and FoV from a point map with unknown z shift and focal.
+
+ Note that it assumes:
+ - the optical center is at the center of the map
+ - the map is undistorted
+ - the map is isometric in the x and y directions
+
+ ### Parameters:
+ - `points: torch.Tensor` of shape (..., H, W, 3)
+ - `downsample_size: Tuple[int, int]` in (height, width), the size of the downsampled map. Downsampling produces approximate solution and is efficient for large maps.
+
+ ### Returns:
+ - `focal`: torch.Tensor of shape (...) the estimated focal length, relative to the half diagonal of the map
+ - `shift`: torch.Tensor of shape (...) Z-axis shift to translate the point map to camera space
+ """
+ shape = points.shape
+ height, width = points.shape[-3], points.shape[-2]
+
+ points = points.reshape(-1, *shape[-3:])
+ mask = None if mask is None else mask.reshape(-1, *shape[-3:-1])
+ focal = focal.reshape(-1) if focal is not None else None
+ uv = normalized_view_plane_uv(
+ width, height, dtype=points.dtype, device=points.device
+ ) # (H, W, 2)
+
+ points_lr = F.interpolate(
+ points.permute(0, 3, 1, 2), downsample_size, mode="nearest"
+ ).permute(0, 2, 3, 1)
+ uv_lr = (
+ F.interpolate(
+ uv.unsqueeze(0).permute(0, 3, 1, 2), downsample_size, mode="nearest"
+ )
+ .squeeze(0)
+ .permute(1, 2, 0)
+ )
+ mask_lr = (
+ None
+ if mask is None
+ else F.interpolate(
+ mask.to(torch.float32).unsqueeze(1), downsample_size, mode="nearest"
+ ).squeeze(1)
+ > 0
+ )
+
+ uv_lr_np = uv_lr.cpu().numpy()
+ points_lr_np = points_lr.detach().cpu().numpy()
+ focal_np = focal.cpu().numpy() if focal is not None else None
+ mask_lr_np = None if mask is None else mask_lr.cpu().numpy()
+ optim_shift, optim_focal = [], []
+ for i in range(points.shape[0]):
+ points_lr_i_np = (
+ points_lr_np[i] if mask is None else points_lr_np[i][mask_lr_np[i]]
+ )
+ uv_lr_i_np = uv_lr_np if mask is None else uv_lr_np[mask_lr_np[i]]
+ if uv_lr_i_np.shape[0] < 2:
+ optim_focal.append(1)
+ optim_shift.append(0)
+ continue
+ if focal is None:
+ optim_shift_i, optim_focal_i = solve_optimal_focal_shift(
+ uv_lr_i_np, points_lr_i_np
+ )
+ optim_focal.append(float(optim_focal_i))
+ else:
+ optim_shift_i = solve_optimal_shift(uv_lr_i_np, points_lr_i_np, focal_np[i])
+ optim_shift.append(float(optim_shift_i))
+ optim_shift = torch.tensor(
+ optim_shift, device=points.device, dtype=points.dtype
+ ).reshape(shape[:-3])
+
+ if focal is None:
+ optim_focal = torch.tensor(
+ optim_focal, device=points.device, dtype=points.dtype
+ ).reshape(shape[:-3])
+ else:
+ optim_focal = focal.reshape(shape[:-3])
+
+ return optim_focal, optim_shift
+
+
+def suppress_traceback(fn):
+ @wraps(fn)
+ def wrapper(*args, **kwargs):
+ try:
+ return fn(*args, **kwargs)
+ except Exception as e:
+ e.__traceback__ = e.__traceback__.tb_next.tb_next
+ raise
+
+ return wrapper
+
+
+def get_device(args, kwargs):
+ device = None
+ for arg in list(args) + list(kwargs.values()):
+ if isinstance(arg, torch.Tensor):
+ if device is None:
+ device = arg.device
+ elif device != arg.device:
+ raise ValueError("All tensors must be on the same device.")
+ return device
+
+
+def get_args_order(func, args, kwargs):
+ """
+ Get the order of the arguments of a function.
+ """
+ names = inspect.getfullargspec(func).args
+ names_idx = {name: i for i, name in enumerate(names)}
+ args_order = []
+ kwargs_order = {}
+ for name, arg in kwargs.items():
+ if name in names:
+ kwargs_order[name] = names_idx[name]
+ names.remove(name)
+ for i, arg in enumerate(args):
+ if i < len(names):
+ args_order.append(names_idx[names[i]])
+ return args_order, kwargs_order
+
+
+def broadcast_args(args, kwargs, args_dim, kwargs_dim):
+ spatial = []
+ for arg, arg_dim in zip(
+ args + list(kwargs.values()), args_dim + list(kwargs_dim.values())
+ ):
+ if isinstance(arg, torch.Tensor) and arg_dim is not None:
+ arg_spatial = arg.shape[: arg.ndim - arg_dim]
+ if len(arg_spatial) > len(spatial):
+ spatial = [1] * (len(arg_spatial) - len(spatial)) + spatial
+ for j in range(len(arg_spatial)):
+ if spatial[-j] < arg_spatial[-j]:
+ if spatial[-j] == 1:
+ spatial[-j] = arg_spatial[-j]
+ else:
+ raise ValueError("Cannot broadcast arguments.")
+ for i, arg in enumerate(args):
+ if isinstance(arg, torch.Tensor) and args_dim[i] is not None:
+ args[i] = torch.broadcast_to(
+ arg, [*spatial, *arg.shape[arg.ndim - args_dim[i] :]]
+ )
+ for key, arg in kwargs.items():
+ if isinstance(arg, torch.Tensor) and kwargs_dim[key] is not None:
+ kwargs[key] = torch.broadcast_to(
+ arg, [*spatial, *arg.shape[arg.ndim - kwargs_dim[key] :]]
+ )
+ return args, kwargs, spatial
+
+
+@suppress_traceback
+def batched(*dims):
+ """
+ Decorator that allows a function to be called with batched arguments.
+ """
+
+ def decorator(func):
+ @wraps(func)
+ def wrapper(*args, device=torch.device("cpu"), **kwargs):
+ args = list(args)
+ # get arguments dimensions
+ args_order, kwargs_order = get_args_order(func, args, kwargs)
+ args_dim = [dims[i] for i in args_order]
+ kwargs_dim = {key: dims[i] for key, i in kwargs_order.items()}
+ # convert to torch tensor
+ device = get_device(args, kwargs) or device
+ for i, arg in enumerate(args):
+ if isinstance(arg, (Number, list, tuple)) and args_dim[i] is not None:
+ args[i] = torch.tensor(arg, device=device)
+ for key, arg in kwargs.items():
+ if (
+ isinstance(arg, (Number, list, tuple))
+ and kwargs_dim[key] is not None
+ ):
+ kwargs[key] = torch.tensor(arg, device=device)
+ # broadcast arguments
+ args, kwargs, spatial = broadcast_args(args, kwargs, args_dim, kwargs_dim)
+ for i, (arg, arg_dim) in enumerate(zip(args, args_dim)):
+ if isinstance(arg, torch.Tensor) and arg_dim is not None:
+ args[i] = arg.reshape([-1, *arg.shape[arg.ndim - arg_dim :]])
+ for key, arg in kwargs.items():
+ if isinstance(arg, torch.Tensor) and kwargs_dim[key] is not None:
+ kwargs[key] = arg.reshape(
+ [-1, *arg.shape[arg.ndim - kwargs_dim[key] :]]
+ )
+ # call function
+ results = func(*args, **kwargs)
+ type_results = type(results)
+ results = list(results) if isinstance(results, (tuple, list)) else [results]
+ # restore spatial dimensions
+ for i, result in enumerate(results):
+ results[i] = result.reshape([*spatial, *result.shape[1:]])
+ if type_results is tuple:
+ results = tuple(results)
+ elif type_results is list:
+ results = list(results)
+ else:
+ results = results[0]
+ return results
+
+ return wrapper
+
+ return decorator
+
+
+def image_uv(
+ height: int,
+ width: int,
+ left: int = None,
+ top: int = None,
+ right: int = None,
+ bottom: int = None,
+ device: torch.device = None,
+ dtype: torch.dtype = None,
+) -> torch.Tensor:
+ """
+ Get image space UV grid, ranging in [0, 1].
+
+ >>> image_uv(10, 10):
+ [[[0.05, 0.05], [0.15, 0.05], ..., [0.95, 0.05]],
+ [[0.05, 0.15], [0.15, 0.15], ..., [0.95, 0.15]],
+ ... ... ...
+ [[0.05, 0.95], [0.15, 0.95], ..., [0.95, 0.95]]]
+
+ Args:
+ width (int): image width
+ height (int): image height
+
+ Returns:
+ torch.Tensor: shape (height, width, 2)
+ """
+ if left is None:
+ left = 0
+ if top is None:
+ top = 0
+ if right is None:
+ right = width
+ if bottom is None:
+ bottom = height
+ u = torch.linspace(
+ (left + 0.5) / width,
+ (right - 0.5) / width,
+ right - left,
+ device=device,
+ dtype=dtype,
+ )
+ v = torch.linspace(
+ (top + 0.5) / height,
+ (bottom - 0.5) / height,
+ bottom - top,
+ device=device,
+ dtype=dtype,
+ )
+ u, v = torch.meshgrid(u, v, indexing="xy")
+ uv = torch.stack([u, v], dim=-1)
+
+ return uv
+
+
+@batched(2, 1, 2, 2)
+def unproject_cv(
+ uv_coord: torch.Tensor,
+ depth: torch.Tensor = None,
+ extrinsics: torch.Tensor = None,
+ intrinsics: torch.Tensor = None,
+) -> torch.Tensor:
+ """
+ Unproject uv coordinates to 3D view space following the OpenCV convention
+
+ Args:
+ uv_coord (torch.Tensor): [..., N, 2] uv coordinates, value ranging in [0, 1].
+ The origin (0., 0.) is corresponding to the left & top
+ depth (torch.Tensor): [..., N] depth value
+ extrinsics (torch.Tensor): [..., 4, 4] extrinsics matrix
+ intrinsics (torch.Tensor): [..., 3, 3] intrinsics matrix
+
+ Returns:
+ points (torch.Tensor): [..., N, 3] 3d points
+ """
+ assert intrinsics is not None, "intrinsics matrix is required"
+ points = torch.cat([uv_coord, torch.ones_like(uv_coord[..., :1])], dim=-1)
+ points = points @ torch.inverse(intrinsics).transpose(-2, -1)
+ if depth is not None:
+ points = points * depth[..., None]
+ if extrinsics is not None:
+ points = torch.cat([points, torch.ones_like(points[..., :1])], dim=-1)
+ points = (points @ torch.inverse(extrinsics).transpose(-2, -1))[..., :3]
+ return points
+
+
+def depth_to_points(
+ depth: torch.Tensor, intrinsics: torch.Tensor, extrinsics: torch.Tensor = None
+):
+ height, width = depth.shape[-2:]
+ uv = image_uv(width=width, height=height, dtype=depth.dtype, device=depth.device)
+ pts = unproject_cv(
+ uv,
+ depth,
+ intrinsics=intrinsics[..., None, :, :],
+ extrinsics=extrinsics[..., None, :, :] if extrinsics is not None else None,
+ )
+
+ return pts
+
+
+@batched(0, 0, 0, 0, 0, 0)
+def intrinsics_from_focal_center(
+ fx: Union[float, torch.Tensor],
+ fy: Union[float, torch.Tensor],
+ cx: Union[float, torch.Tensor],
+ cy: Union[float, torch.Tensor],
+) -> torch.Tensor:
+ """
+ Get OpenCV intrinsics matrix
+
+ Args:
+ focal_x (float | torch.Tensor): focal length in x axis
+ focal_y (float | torch.Tensor): focal length in y axis
+ cx (float | torch.Tensor): principal point in x axis
+ cy (float | torch.Tensor): principal point in y axis
+
+ Returns:
+ (torch.Tensor): [..., 3, 3] OpenCV intrinsics matrix
+ """
+ N = fx.shape[0]
+ ret = torch.zeros((N, 3, 3), dtype=fx.dtype, device=fx.device)
+ zeros, ones = (
+ torch.zeros(N, dtype=fx.dtype, device=fx.device),
+ torch.ones(N, dtype=fx.dtype, device=fx.device),
+ )
+ ret = torch.stack(
+ [fx, zeros, cx, zeros, fy, cy, zeros, zeros, ones], dim=-1
+ ).unflatten(-1, (3, 3))
+ return ret
diff --git a/mapanything/models/external/moge/models/v1.py b/mapanything/models/external/moge/models/v1.py
new file mode 100644
index 0000000000000000000000000000000000000000..7c8a15bd9280a078e57ec00762a89922ea3864ca
--- /dev/null
+++ b/mapanything/models/external/moge/models/v1.py
@@ -0,0 +1,595 @@
+import importlib
+from numbers import Number
+from pathlib import Path
+from typing import Any, Dict, IO, List, Literal, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils
+import torch.utils.checkpoint
+import torch.version
+from huggingface_hub import hf_hub_download
+
+from mapanything.models.external.moge.models.utils import (
+ depth_to_points,
+ intrinsics_from_focal_center,
+ normalized_view_plane_uv,
+ recover_focal_shift,
+ wrap_module_with_gradient_checkpointing,
+)
+
+
+class ResidualConvBlock(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int = None,
+ hidden_channels: int = None,
+ padding_mode: str = "replicate",
+ activation: Literal["relu", "leaky_relu", "silu", "elu"] = "relu",
+ norm: Literal["group_norm", "layer_norm"] = "group_norm",
+ ):
+ super(ResidualConvBlock, self).__init__()
+ if out_channels is None:
+ out_channels = in_channels
+ if hidden_channels is None:
+ hidden_channels = in_channels
+
+ if activation == "relu":
+ activation_cls = lambda: nn.ReLU(inplace=True) # noqa
+ elif activation == "leaky_relu":
+ activation_cls = lambda: nn.LeakyReLU(negative_slope=0.2, inplace=True) # noqa
+ elif activation == "silu":
+ activation_cls = lambda: nn.SiLU(inplace=True) # noqa
+ elif activation == "elu":
+ activation_cls = lambda: nn.ELU(inplace=True) # noqa
+ else:
+ raise ValueError(f"Unsupported activation function: {activation}")
+
+ self.layers = nn.Sequential(
+ nn.GroupNorm(1, in_channels),
+ activation_cls(),
+ nn.Conv2d(
+ in_channels,
+ hidden_channels,
+ kernel_size=3,
+ padding=1,
+ padding_mode=padding_mode,
+ ),
+ nn.GroupNorm(
+ hidden_channels // 32 if norm == "group_norm" else 1, hidden_channels
+ ),
+ activation_cls(),
+ nn.Conv2d(
+ hidden_channels,
+ out_channels,
+ kernel_size=3,
+ padding=1,
+ padding_mode=padding_mode,
+ ),
+ )
+
+ self.skip_connection = (
+ nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0)
+ if in_channels != out_channels
+ else nn.Identity()
+ )
+
+ def forward(self, x):
+ skip = self.skip_connection(x)
+ x = self.layers(x)
+ x = x + skip
+ return x
+
+
+class Head(nn.Module):
+ def __init__(
+ self,
+ num_features: int,
+ dim_in: int,
+ dim_out: List[int],
+ dim_proj: int = 512,
+ dim_upsample: List[int] = [256, 128, 128],
+ dim_times_res_block_hidden: int = 1,
+ num_res_blocks: int = 1,
+ res_block_norm: Literal["group_norm", "layer_norm"] = "group_norm",
+ last_res_blocks: int = 0,
+ last_conv_channels: int = 32,
+ last_conv_size: int = 1,
+ ):
+ super().__init__()
+
+ self.projects = nn.ModuleList(
+ [
+ nn.Conv2d(
+ in_channels=dim_in,
+ out_channels=dim_proj,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ )
+ for _ in range(num_features)
+ ]
+ )
+
+ self.upsample_blocks = nn.ModuleList(
+ [
+ nn.Sequential(
+ self._make_upsampler(in_ch + 2, out_ch),
+ *(
+ ResidualConvBlock(
+ out_ch,
+ out_ch,
+ dim_times_res_block_hidden * out_ch,
+ activation="relu",
+ norm=res_block_norm,
+ )
+ for _ in range(num_res_blocks)
+ ),
+ )
+ for in_ch, out_ch in zip([dim_proj] + dim_upsample[:-1], dim_upsample)
+ ]
+ )
+
+ self.output_block = nn.ModuleList(
+ [
+ self._make_output_block(
+ dim_upsample[-1] + 2,
+ dim_out_,
+ dim_times_res_block_hidden,
+ last_res_blocks,
+ last_conv_channels,
+ last_conv_size,
+ res_block_norm,
+ )
+ for dim_out_ in dim_out
+ ]
+ )
+
+ def _make_upsampler(self, in_channels: int, out_channels: int):
+ upsampler = nn.Sequential(
+ nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2),
+ nn.Conv2d(
+ out_channels,
+ out_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ padding_mode="replicate",
+ ),
+ )
+ upsampler[0].weight.data[:] = upsampler[0].weight.data[:, :, :1, :1]
+ return upsampler
+
+ def _make_output_block(
+ self,
+ dim_in: int,
+ dim_out: int,
+ dim_times_res_block_hidden: int,
+ last_res_blocks: int,
+ last_conv_channels: int,
+ last_conv_size: int,
+ res_block_norm: Literal["group_norm", "layer_norm"],
+ ):
+ return nn.Sequential(
+ nn.Conv2d(
+ dim_in,
+ last_conv_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ padding_mode="replicate",
+ ),
+ *(
+ ResidualConvBlock(
+ last_conv_channels,
+ last_conv_channels,
+ dim_times_res_block_hidden * last_conv_channels,
+ activation="relu",
+ norm=res_block_norm,
+ )
+ for _ in range(last_res_blocks)
+ ),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(
+ last_conv_channels,
+ dim_out,
+ kernel_size=last_conv_size,
+ stride=1,
+ padding=last_conv_size // 2,
+ padding_mode="replicate",
+ ),
+ )
+
+ def forward(self, hidden_states: torch.Tensor, image: torch.Tensor):
+ img_h, img_w = image.shape[-2:]
+ patch_h, patch_w = img_h // 14, img_w // 14
+
+ # Process the hidden states
+ x = torch.stack(
+ [
+ proj(
+ feat.permute(0, 2, 1).unflatten(2, (patch_h, patch_w)).contiguous()
+ )
+ for proj, (feat, clstoken) in zip(self.projects, hidden_states)
+ ],
+ dim=1,
+ ).sum(dim=1)
+
+ # Upsample stage
+ # (patch_h, patch_w) -> (patch_h * 2, patch_w * 2) -> (patch_h * 4, patch_w * 4) -> (patch_h * 8, patch_w * 8)
+ for i, block in enumerate(self.upsample_blocks):
+ # UV coordinates is for awareness of image aspect ratio
+ uv = normalized_view_plane_uv(
+ width=x.shape[-1],
+ height=x.shape[-2],
+ aspect_ratio=img_w / img_h,
+ dtype=x.dtype,
+ device=x.device,
+ )
+ uv = uv.permute(2, 0, 1).unsqueeze(0).expand(x.shape[0], -1, -1, -1)
+ x = torch.cat([x, uv], dim=1)
+ for layer in block:
+ x = torch.utils.checkpoint.checkpoint(layer, x, use_reentrant=False)
+
+ # (patch_h * 8, patch_w * 8) -> (img_h, img_w)
+ x = F.interpolate(x, (img_h, img_w), mode="bilinear", align_corners=False)
+ uv = normalized_view_plane_uv(
+ width=x.shape[-1],
+ height=x.shape[-2],
+ aspect_ratio=img_w / img_h,
+ dtype=x.dtype,
+ device=x.device,
+ )
+ uv = uv.permute(2, 0, 1).unsqueeze(0).expand(x.shape[0], -1, -1, -1)
+ x = torch.cat([x, uv], dim=1)
+
+ if isinstance(self.output_block, nn.ModuleList):
+ output = [
+ torch.utils.checkpoint.checkpoint(block, x, use_reentrant=False)
+ for block in self.output_block
+ ]
+ else:
+ output = torch.utils.checkpoint.checkpoint(
+ self.output_block, x, use_reentrant=False
+ )
+
+ return output
+
+
+class MoGeModel(nn.Module):
+ image_mean: torch.Tensor
+ image_std: torch.Tensor
+
+ def __init__(
+ self,
+ encoder: str = "dinov2_vitb14",
+ intermediate_layers: Union[int, List[int]] = 4,
+ dim_proj: int = 512,
+ dim_upsample: List[int] = [256, 128, 128],
+ dim_times_res_block_hidden: int = 1,
+ num_res_blocks: int = 1,
+ remap_output: Literal[
+ False, True, "linear", "sinh", "exp", "sinh_exp"
+ ] = "linear",
+ res_block_norm: Literal["group_norm", "layer_norm"] = "group_norm",
+ num_tokens_range: Tuple[Number, Number] = [1200, 2500],
+ last_res_blocks: int = 0,
+ last_conv_channels: int = 32,
+ last_conv_size: int = 1,
+ mask_threshold: float = 0.5,
+ **deprecated_kwargs,
+ ):
+ super(MoGeModel, self).__init__()
+
+ if deprecated_kwargs:
+ # Process legacy arguments
+ if "trained_area_range" in deprecated_kwargs:
+ num_tokens_range = [
+ deprecated_kwargs["trained_area_range"][0] // 14**2,
+ deprecated_kwargs["trained_area_range"][1] // 14**2,
+ ]
+ del deprecated_kwargs["trained_area_range"]
+ # warnings.warn(
+ # f"The following deprecated/invalid arguments are ignored: {deprecated_kwargs}"
+ # )
+
+ self.encoder = encoder
+ self.remap_output = remap_output
+ self.intermediate_layers = intermediate_layers
+ self.num_tokens_range = num_tokens_range
+ self.mask_threshold = mask_threshold
+
+ # NOTE: We have copied the DINOv2 code in torchhub to this repository.
+ # Minimal modifications have been made: removing irrelevant code, unnecessary warnings and fixing importing issues.
+ hub_loader = getattr(
+ importlib.import_module(
+ "mapanything.models.external.dinov2.hub.backbones", __package__
+ ),
+ encoder,
+ )
+ self.backbone = hub_loader(pretrained=False)
+ dim_feature = self.backbone.blocks[0].attn.qkv.in_features
+
+ self.head = Head(
+ num_features=intermediate_layers
+ if isinstance(intermediate_layers, int)
+ else len(intermediate_layers),
+ dim_in=dim_feature,
+ dim_out=[3, 1],
+ dim_proj=dim_proj,
+ dim_upsample=dim_upsample,
+ dim_times_res_block_hidden=dim_times_res_block_hidden,
+ num_res_blocks=num_res_blocks,
+ res_block_norm=res_block_norm,
+ last_res_blocks=last_res_blocks,
+ last_conv_channels=last_conv_channels,
+ last_conv_size=last_conv_size,
+ )
+
+ image_mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
+ image_std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
+
+ self.register_buffer("image_mean", image_mean)
+ self.register_buffer("image_std", image_std)
+
+ @property
+ def device(self) -> torch.device:
+ return next(self.parameters()).device
+
+ @property
+ def dtype(self) -> torch.dtype:
+ return next(self.parameters()).dtype
+
+ @classmethod
+ def from_pretrained(
+ cls,
+ pretrained_model_name_or_path: Union[str, Path, IO[bytes]],
+ model_kwargs: Optional[Dict[str, Any]] = None,
+ **hf_kwargs,
+ ) -> "MoGeModel":
+ """
+ Load a model from a checkpoint file.
+
+ ### Parameters:
+ - `pretrained_model_name_or_path`: path to the checkpoint file or repo id.
+ - `model_kwargs`: additional keyword arguments to override the parameters in the checkpoint.
+ - `hf_kwargs`: additional keyword arguments to pass to the `hf_hub_download` function. Ignored if `pretrained_model_name_or_path` is a local path.
+
+ ### Returns:
+ - A new instance of `MoGe` with the parameters loaded from the checkpoint.
+ """
+ if Path(pretrained_model_name_or_path).exists():
+ checkpoint = torch.load(
+ pretrained_model_name_or_path, map_location="cpu", weights_only=True
+ )
+ else:
+ cached_checkpoint_path = hf_hub_download(
+ repo_id=pretrained_model_name_or_path,
+ repo_type="model",
+ filename="model.pt",
+ **hf_kwargs,
+ )
+ checkpoint = torch.load(
+ cached_checkpoint_path, map_location="cpu", weights_only=True
+ )
+ model_config = checkpoint["model_config"]
+ if model_kwargs is not None:
+ model_config.update(model_kwargs)
+ model = cls(**model_config)
+ model.load_state_dict(checkpoint["model"])
+ return model
+
+ def init_weights(self):
+ "Load the backbone with pretrained dinov2 weights from torch hub"
+ state_dict = torch.hub.load(
+ "facebookresearch/dinov2", self.encoder, pretrained=True
+ ).state_dict()
+ self.backbone.load_state_dict(state_dict)
+
+ def enable_gradient_checkpointing(self):
+ for i in range(len(self.backbone.blocks)):
+ self.backbone.blocks[i] = wrap_module_with_gradient_checkpointing(
+ self.backbone.blocks[i]
+ )
+
+ def _remap_points(self, points: torch.Tensor) -> torch.Tensor:
+ if self.remap_output == "linear":
+ pass
+ elif self.remap_output == "sinh":
+ points = torch.sinh(points)
+ elif self.remap_output == "exp":
+ xy, z = points.split([2, 1], dim=-1)
+ z = torch.exp(z)
+ points = torch.cat([xy * z, z], dim=-1)
+ elif self.remap_output == "sinh_exp":
+ xy, z = points.split([2, 1], dim=-1)
+ points = torch.cat([torch.sinh(xy), torch.exp(z)], dim=-1)
+ else:
+ raise ValueError(f"Invalid remap output type: {self.remap_output}")
+ return points
+
+ def forward(self, image: torch.Tensor, num_tokens: int) -> Dict[str, torch.Tensor]:
+ original_height, original_width = image.shape[-2:]
+
+ # Resize to expected resolution defined by num_tokens
+ resize_factor = (
+ (num_tokens * 14**2) / (original_height * original_width)
+ ) ** 0.5
+ resized_width, resized_height = (
+ int(original_width * resize_factor),
+ int(original_height * resize_factor),
+ )
+ image = F.interpolate(
+ image,
+ (resized_height, resized_width),
+ mode="bicubic",
+ align_corners=False,
+ antialias=True,
+ )
+
+ # Apply image transformation for DINOv2
+ image = (image - self.image_mean) / self.image_std
+ image_14 = F.interpolate(
+ image,
+ (resized_height // 14 * 14, resized_width // 14 * 14),
+ mode="bilinear",
+ align_corners=False,
+ antialias=True,
+ )
+
+ # Get intermediate layers from the backbone
+ features = self.backbone.get_intermediate_layers(
+ image_14, self.intermediate_layers, return_class_token=True
+ )
+
+ # Predict points (and mask)
+ output = self.head(features, image)
+ points, mask = output
+
+ # Make sure fp32 precision for output
+ with torch.autocast(device_type=image.device.type, dtype=torch.float32):
+ # Resize to original resolution
+ points = F.interpolate(
+ points,
+ (original_height, original_width),
+ mode="bilinear",
+ align_corners=False,
+ antialias=False,
+ )
+ mask = F.interpolate(
+ mask,
+ (original_height, original_width),
+ mode="bilinear",
+ align_corners=False,
+ antialias=False,
+ )
+
+ # Post-process points and mask
+ points, mask = points.permute(0, 2, 3, 1), mask.squeeze(1)
+ points = self._remap_points(
+ points
+ ) # slightly improves the performance in case of very large output values
+
+ return_dict = {"points": points, "mask": mask}
+ return return_dict
+
+ @torch.inference_mode()
+ def infer(
+ self,
+ image: torch.Tensor,
+ fov_x: Union[Number, torch.Tensor] = None,
+ resolution_level: int = 9,
+ num_tokens: int = None,
+ apply_mask: bool = True,
+ force_projection: bool = True,
+ use_fp16: bool = True,
+ ) -> Dict[str, torch.Tensor]:
+ """
+ User-friendly inference function
+
+ ### Parameters
+ - `image`: input image tensor of shape (B, 3, H, W) or (3, H, W)\
+ - `fov_x`: the horizontal camera FoV in degrees. If None, it will be inferred from the predicted point map. Default: None
+ - `resolution_level`: An integer [0-9] for the resolution level for inference.
+ The higher, the finer details will be captured, but slower. Defaults to 9. Note that it is irrelevant to the output size, which is always the same as the input size.
+ `resolution_level` actually controls `num_tokens`. See `num_tokens` for more details.
+ - `num_tokens`: number of tokens used for inference. A integer in the (suggested) range of `[1200, 2500]`.
+ `resolution_level` will be ignored if `num_tokens` is provided. Default: None
+ - `apply_mask`: if True, the output point map will be masked using the predicted mask. Default: True
+ - `force_projection`: if True, the output point map will be recomputed to match the projection constraint. Default: True
+ - `use_fp16`: if True, use mixed precision to speed up inference. Default: True
+
+ ### Returns
+
+ A dictionary containing the following keys:
+ - `points`: output tensor of shape (B, H, W, 3) or (H, W, 3).
+ - `depth`: tensor of shape (B, H, W) or (H, W) containing the depth map.
+ - `intrinsics`: tensor of shape (B, 3, 3) or (3, 3) containing the camera intrinsics.
+ """
+ if image.dim() == 3:
+ omit_batch_dim = True
+ image = image.unsqueeze(0)
+ else:
+ omit_batch_dim = False
+ image = image.to(dtype=self.dtype, device=self.device)
+
+ original_height, original_width = image.shape[-2:]
+ aspect_ratio = original_width / original_height
+
+ if num_tokens is None:
+ min_tokens, max_tokens = self.num_tokens_range
+ num_tokens = int(
+ min_tokens + (resolution_level / 9) * (max_tokens - min_tokens)
+ )
+
+ with torch.autocast(
+ device_type=self.device.type,
+ dtype=torch.float16,
+ enabled=use_fp16 and self.dtype != torch.float16,
+ ):
+ output = self.forward(image, num_tokens)
+ points, mask = output["points"], output["mask"]
+
+ # Always process the output in fp32 precision
+ with torch.autocast(device_type=self.device.type, dtype=torch.float32):
+ points, mask, fov_x = map(
+ lambda x: x.float() if isinstance(x, torch.Tensor) else x,
+ [points, mask, fov_x],
+ )
+
+ mask_binary = mask > self.mask_threshold
+
+ # Get camera-space point map. (Focal here is the focal length relative to half the image diagonal)
+ if fov_x is None:
+ focal, shift = recover_focal_shift(points, mask_binary)
+ else:
+ focal = (
+ aspect_ratio
+ / (1 + aspect_ratio**2) ** 0.5
+ / torch.tan(
+ torch.deg2rad(
+ torch.as_tensor(
+ fov_x, device=points.device, dtype=points.dtype
+ )
+ / 2
+ )
+ )
+ )
+ if focal.ndim == 0:
+ focal = focal[None].expand(points.shape[0])
+ _, shift = recover_focal_shift(points, mask_binary, focal=focal)
+ fx = focal / 2 * (1 + aspect_ratio**2) ** 0.5 / aspect_ratio
+ fy = focal / 2 * (1 + aspect_ratio**2) ** 0.5
+ intrinsics = intrinsics_from_focal_center(fx, fy, 0.5, 0.5)
+ depth = points[..., 2] + shift[..., None, None]
+
+ # If projection constraint is forced, recompute the point map using the actual depth map
+ if force_projection:
+ points = depth_to_points(depth, intrinsics=intrinsics)
+ else:
+ points = (
+ points
+ + torch.stack(
+ [torch.zeros_like(shift), torch.zeros_like(shift), shift],
+ dim=-1,
+ )[..., None, None, :]
+ )
+
+ # Apply mask if needed
+ if apply_mask:
+ points = torch.where(mask_binary[..., None], points, torch.inf)
+ depth = torch.where(mask_binary, depth, torch.inf)
+
+ return_dict = {
+ "points": points,
+ "intrinsics": intrinsics,
+ "depth": depth,
+ "mask": mask_binary,
+ }
+
+ if omit_batch_dim:
+ return_dict = {k: v.squeeze(0) for k, v in return_dict.items()}
+
+ return return_dict
diff --git a/mapanything/models/external/moge/models/v2.py b/mapanything/models/external/moge/models/v2.py
new file mode 100644
index 0000000000000000000000000000000000000000..d623400208b034bf55fb6156dceff10158d7e87d
--- /dev/null
+++ b/mapanything/models/external/moge/models/v2.py
@@ -0,0 +1,379 @@
+import warnings
+from numbers import Number
+from pathlib import Path
+from typing import Any, Dict, IO, List, Literal, Optional, Union
+
+import torch
+import torch.amp
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils
+import torch.utils.checkpoint
+import torch.version
+from huggingface_hub import hf_hub_download
+
+from mapanything.models.external.moge.models.modules import (
+ ConvStack,
+ DINOv2Encoder,
+ MLP,
+)
+from mapanything.models.external.moge.models.utils import (
+ depth_to_points,
+ intrinsics_from_focal_center,
+ normalized_view_plane_uv,
+ recover_focal_shift,
+)
+
+
+class MoGeModel(nn.Module):
+ encoder: DINOv2Encoder
+ neck: ConvStack
+ points_head: ConvStack
+ mask_head: ConvStack
+ scale_head: MLP
+ onnx_compatible_mode: bool
+
+ def __init__(
+ self,
+ encoder: Dict[str, Any],
+ neck: Dict[str, Any],
+ points_head: Dict[str, Any] = None,
+ mask_head: Dict[str, Any] = None,
+ normal_head: Dict[str, Any] = None,
+ scale_head: Dict[str, Any] = None,
+ remap_output: Literal["linear", "sinh", "exp", "sinh_exp"] = "linear",
+ num_tokens_range: List[int] = [1200, 3600],
+ **deprecated_kwargs,
+ ):
+ super(MoGeModel, self).__init__()
+ if deprecated_kwargs:
+ warnings.warn(
+ f"The following deprecated/invalid arguments are ignored: {deprecated_kwargs}"
+ )
+
+ self.remap_output = remap_output
+ self.num_tokens_range = num_tokens_range
+
+ self.encoder = DINOv2Encoder(**encoder)
+ self.neck = ConvStack(**neck)
+ if points_head is not None:
+ self.points_head = ConvStack(**points_head)
+ if mask_head is not None:
+ self.mask_head = ConvStack(**mask_head)
+ if normal_head is not None:
+ self.normal_head = ConvStack(**normal_head)
+ if scale_head is not None:
+ self.scale_head = MLP(**scale_head)
+
+ @property
+ def device(self) -> torch.device:
+ return next(self.parameters()).device
+
+ @property
+ def dtype(self) -> torch.dtype:
+ return next(self.parameters()).dtype
+
+ @property
+ def onnx_compatible_mode(self) -> bool:
+ return getattr(self, "_onnx_compatible_mode", False)
+
+ @onnx_compatible_mode.setter
+ def onnx_compatible_mode(self, value: bool):
+ self._onnx_compatible_mode = value
+ self.encoder.onnx_compatible_mode = value
+
+ @classmethod
+ def from_pretrained(
+ cls,
+ pretrained_model_name_or_path: Union[str, Path, IO[bytes]],
+ model_kwargs: Optional[Dict[str, Any]] = None,
+ **hf_kwargs,
+ ) -> "MoGeModel":
+ """
+ Load a model from a checkpoint file.
+
+ ### Parameters:
+ - `pretrained_model_name_or_path`: path to the checkpoint file or repo id.
+ - `compiled`
+ - `model_kwargs`: additional keyword arguments to override the parameters in the checkpoint.
+ - `hf_kwargs`: additional keyword arguments to pass to the `hf_hub_download` function. Ignored if `pretrained_model_name_or_path` is a local path.
+
+ ### Returns:
+ - A new instance of `MoGe` with the parameters loaded from the checkpoint.
+ """
+ if Path(pretrained_model_name_or_path).exists():
+ checkpoint_path = pretrained_model_name_or_path
+ else:
+ checkpoint_path = hf_hub_download(
+ repo_id=pretrained_model_name_or_path,
+ repo_type="model",
+ filename="model.pt",
+ **hf_kwargs,
+ )
+ checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
+
+ model_config = checkpoint["model_config"]
+ if model_kwargs is not None:
+ model_config.update(model_kwargs)
+ model = cls(**model_config)
+ model.load_state_dict(checkpoint["model"], strict=False)
+
+ return model
+
+ def init_weights(self):
+ self.encoder.init_weights()
+
+ def enable_gradient_checkpointing(self):
+ self.encoder.enable_gradient_checkpointing()
+ self.neck.enable_gradient_checkpointing()
+ for head in ["points_head", "normal_head", "mask_head"]:
+ if hasattr(self, head):
+ getattr(self, head).enable_gradient_checkpointing()
+
+ def enable_pytorch_native_sdpa(self):
+ self.encoder.enable_pytorch_native_sdpa()
+
+ def _remap_points(self, points: torch.Tensor) -> torch.Tensor:
+ if self.remap_output == "linear":
+ pass
+ elif self.remap_output == "sinh":
+ points = torch.sinh(points)
+ elif self.remap_output == "exp":
+ xy, z = points.split([2, 1], dim=-1)
+ z = torch.exp(z)
+ points = torch.cat([xy * z, z], dim=-1)
+ elif self.remap_output == "sinh_exp":
+ xy, z = points.split([2, 1], dim=-1)
+ points = torch.cat([torch.sinh(xy), torch.exp(z)], dim=-1)
+ else:
+ raise ValueError(f"Invalid remap output type: {self.remap_output}")
+ return points
+
+ def forward(self, image: torch.Tensor, num_tokens: int) -> Dict[str, torch.Tensor]:
+ batch_size, _, img_h, img_w = image.shape
+ device, dtype = image.device, image.dtype
+
+ aspect_ratio = img_w / img_h
+ base_h, base_w = (
+ int((num_tokens / aspect_ratio) ** 0.5),
+ int((num_tokens * aspect_ratio) ** 0.5),
+ )
+ num_tokens = base_h * base_w
+
+ # Backbones encoding
+ features, cls_token = self.encoder(
+ image, base_h, base_w, return_class_token=True
+ )
+ features = [features, None, None, None, None]
+
+ # Concat UVs for aspect ratio input
+ for level in range(5):
+ uv = normalized_view_plane_uv(
+ width=base_w * 2**level,
+ height=base_h * 2**level,
+ aspect_ratio=aspect_ratio,
+ dtype=dtype,
+ device=device,
+ )
+ uv = uv.permute(2, 0, 1).unsqueeze(0).expand(batch_size, -1, -1, -1)
+ if features[level] is None:
+ features[level] = uv
+ else:
+ features[level] = torch.concat([features[level], uv], dim=1)
+
+ # Shared neck
+ features = self.neck(features)
+
+ # Heads decoding
+ points, normal, mask = (
+ getattr(self, head)(features)[-1] if hasattr(self, head) else None
+ for head in ["points_head", "normal_head", "mask_head"]
+ )
+ metric_scale = (
+ self.scale_head(cls_token) if hasattr(self, "scale_head") else None
+ )
+
+ # Resize
+ points, normal, mask = (
+ F.interpolate(
+ v, (img_h, img_w), mode="bilinear", align_corners=False, antialias=False
+ )
+ if v is not None
+ else None
+ for v in [points, normal, mask]
+ )
+
+ # Remap output
+ if points is not None:
+ points = points.permute(0, 2, 3, 1)
+ points = self._remap_points(
+ points
+ ) # slightly improves the performance in case of very large output values
+ if normal is not None:
+ normal = normal.permute(0, 2, 3, 1)
+ normal = F.normalize(normal, dim=-1)
+ if mask is not None:
+ mask = mask.squeeze(1).sigmoid()
+ if metric_scale is not None:
+ metric_scale = metric_scale.squeeze(1).exp()
+
+ return_dict = {
+ "points": points,
+ "normal": normal,
+ "mask": mask,
+ "metric_scale": metric_scale,
+ }
+ return_dict = {k: v for k, v in return_dict.items() if v is not None}
+
+ return return_dict
+
+ @torch.inference_mode()
+ def infer(
+ self,
+ image: torch.Tensor,
+ num_tokens: int = None,
+ resolution_level: int = 9,
+ force_projection: bool = True,
+ apply_mask: Literal[False, True, "blend"] = True,
+ fov_x: Optional[Union[Number, torch.Tensor]] = None,
+ use_fp16: bool = True,
+ ) -> Dict[str, torch.Tensor]:
+ """
+ User-friendly inference function
+
+ ### Parameters
+ - `image`: input image tensor of shape (B, 3, H, W) or (3, H, W)
+ - `num_tokens`: the number of base ViT tokens to use for inference, `'least'` or `'most'` or an integer. Suggested range: 1200 ~ 2500.
+ More tokens will result in significantly higher accuracy and finer details, but slower inference time. Default: `'most'`.
+ - `force_projection`: if True, the output point map will be computed using the actual depth map. Default: True
+ - `apply_mask`: if True, the output point map will be masked using the predicted mask. Default: True
+ - `fov_x`: the horizontal camera FoV in degrees. If None, it will be inferred from the predicted point map. Default: None
+ - `use_fp16`: if True, use mixed precision to speed up inference. Default: True
+
+ ### Returns
+
+ A dictionary containing the following keys:
+ - `points`: output tensor of shape (B, H, W, 3) or (H, W, 3).
+ - `depth`: tensor of shape (B, H, W) or (H, W) containing the depth map.
+ - `intrinsics`: tensor of shape (B, 3, 3) or (3, 3) containing the camera intrinsics.
+ """
+ if image.dim() == 3:
+ omit_batch_dim = True
+ image = image.unsqueeze(0)
+ else:
+ omit_batch_dim = False
+ image = image.to(dtype=self.dtype, device=self.device)
+
+ original_height, original_width = image.shape[-2:]
+ aspect_ratio = original_width / original_height
+
+ # Determine the number of base tokens to use
+ if num_tokens is None:
+ min_tokens, max_tokens = self.num_tokens_range
+ num_tokens = int(
+ min_tokens + (resolution_level / 9) * (max_tokens - min_tokens)
+ )
+
+ # Forward pass
+ with torch.autocast(
+ device_type=self.device.type,
+ dtype=torch.float16,
+ enabled=use_fp16 and self.dtype != torch.float16,
+ ):
+ output = self.forward(image, num_tokens=num_tokens)
+ points, normal, mask, metric_scale = (
+ output.get(k, None) for k in ["points", "normal", "mask", "metric_scale"]
+ )
+
+ # Always process the output in fp32 precision
+ points, normal, mask, metric_scale, fov_x = map(
+ lambda x: x.float() if isinstance(x, torch.Tensor) else x,
+ [points, normal, mask, metric_scale, fov_x],
+ )
+ with torch.autocast(device_type=self.device.type, dtype=torch.float32):
+ if mask is not None:
+ mask_binary = mask > 0.5
+ else:
+ mask_binary = None
+
+ if points is not None:
+ # Convert affine point map to camera-space. Recover depth and intrinsics from point map.
+ # NOTE: Focal here is the focal length relative to half the image diagonal
+ if fov_x is None:
+ # Recover focal and shift from predicted point map
+ focal, shift = recover_focal_shift(points, mask_binary)
+ else:
+ # Focal is known, recover shift only
+ focal = (
+ aspect_ratio
+ / (1 + aspect_ratio**2) ** 0.5
+ / torch.tan(
+ torch.deg2rad(
+ torch.as_tensor(
+ fov_x, device=points.device, dtype=points.dtype
+ )
+ / 2
+ )
+ )
+ )
+ if focal.ndim == 0:
+ focal = focal[None].expand(points.shape[0])
+ _, shift = recover_focal_shift(points, mask_binary, focal=focal)
+ fx, fy = (
+ focal / 2 * (1 + aspect_ratio**2) ** 0.5 / aspect_ratio,
+ focal / 2 * (1 + aspect_ratio**2) ** 0.5,
+ )
+ intrinsics = intrinsics_from_focal_center(fx, fy, 0.5, 0.5)
+ points[..., 2] += shift[..., None, None]
+ if mask_binary is not None:
+ mask_binary &= (
+ points[..., 2] > 0
+ ) # in case depth is contains negative values (which should never happen in practice)
+ depth = points[..., 2].clone()
+ else:
+ depth, intrinsics = None, None
+
+ # If projection constraint is forced, recompute the point map using the actual depth map & intrinsics
+ if force_projection and depth is not None:
+ points = depth_to_points(depth, intrinsics=intrinsics)
+
+ # Apply metric scale
+ if metric_scale is not None:
+ if points is not None:
+ points *= metric_scale[:, None, None, None]
+ if depth is not None:
+ depth *= metric_scale[:, None, None]
+
+ # Apply mask
+ if apply_mask and mask_binary is not None:
+ points = (
+ torch.where(mask_binary[..., None], points, torch.inf)
+ if points is not None
+ else None
+ )
+ depth = (
+ torch.where(mask_binary, depth, torch.inf)
+ if depth is not None
+ else None
+ )
+ normal = (
+ torch.where(
+ mask_binary[..., None], normal, torch.zeros_like(normal)
+ )
+ if normal is not None
+ else None
+ )
+
+ return_dict = {
+ "points": points,
+ "intrinsics": intrinsics,
+ "depth": depth,
+ "mask": mask_binary,
+ "normal": normal,
+ }
+ return_dict = {k: v for k, v in return_dict.items() if v is not None}
+
+ if omit_batch_dim:
+ return_dict = {k: v.squeeze(0) for k, v in return_dict.items()}
+
+ return return_dict
diff --git a/mapanything/models/external/must3r/__init__.py b/mapanything/models/external/must3r/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..5d5ebd000f3a529c3020739a4b3977927a4d0c5e
--- /dev/null
+++ b/mapanything/models/external/must3r/__init__.py
@@ -0,0 +1,283 @@
+"""
+Inference wrapper for MUSt3R
+"""
+
+import datetime
+import os
+
+import numpy as np
+import torch
+from dust3r.viz import rgb
+from must3r.demo.inference import SceneState
+from must3r.engine.inference import inference_multi_ar, postprocess
+from must3r.model import get_pointmaps_activation, load_model
+
+from mapanything.models.external.vggt.utils.rotation import mat_to_quat
+
+
+def must3r_inference(
+ views,
+ filelist,
+ model,
+ retrieval,
+ device,
+ amp,
+ num_mem_images,
+ max_bs,
+ init_num_images=2,
+ batch_num_views=1,
+ render_once=False,
+ is_sequence=False,
+ viser_server=None,
+ num_refinements_iterations=2,
+ verbose=True,
+):
+ if amp == "fp16":
+ dtype = torch.float16
+ elif amp == "bf16":
+ assert torch.cuda.is_bf16_supported()
+ dtype = torch.bfloat16
+ else:
+ assert not amp
+ dtype = torch.float32
+
+ max_bs = None if max_bs == 0 else max_bs
+ encoder, decoder = model
+ pointmaps_activation = get_pointmaps_activation(decoder, verbose=verbose)
+
+ def post_process_function(x):
+ return postprocess(
+ x, pointmaps_activation=pointmaps_activation, compute_cam=True
+ )
+
+ if verbose:
+ print("loading images")
+ time_start = datetime.datetime.now()
+ nimgs = len(views)
+
+ ellapsed = datetime.datetime.now() - time_start
+ if verbose:
+ print(f"loaded in {ellapsed}")
+ print("running inference")
+ time_start = datetime.datetime.now()
+ if viser_server is not None:
+ viser_server.reset(nimgs)
+
+ imgs = [b["img"].to("cpu") for b in views]
+ true_shape = [torch.from_numpy(b["true_shape"]).to("cpu") for b in views]
+ true_shape = torch.stack(true_shape, dim=0)
+ nimgs = true_shape.shape[0]
+
+ # Use all images as keyframes
+ keyframes = np.linspace(0, len(imgs) - 1, num_mem_images, dtype=int).tolist()
+ encoder_precomputed_features = None
+
+ not_keyframes = sorted(set(range(nimgs)).difference(set(keyframes)))
+ assert (len(keyframes) + len(not_keyframes)) == nimgs
+ # reorder images
+ views = [views[i] for i in keyframes] + [views[i] for i in not_keyframes]
+ imgs = [b["img"].to(device) for b in views]
+ true_shape = [torch.from_numpy(b["true_shape"]).to(device) for b in views]
+ filenames = [filelist[i] for i in keyframes + not_keyframes]
+ img_ids = [torch.tensor(v) for v in keyframes + not_keyframes]
+
+ if encoder_precomputed_features is not None:
+ x_start, pos_start = encoder_precomputed_features
+ x = [x_start[i] for i in keyframes] + [x_start[i] for i in not_keyframes]
+ pos = [pos_start[i] for i in keyframes] + [pos_start[i] for i in not_keyframes]
+ encoder_precomputed_features = (x, pos)
+
+ mem_batches = [init_num_images]
+ while (sum_b := sum(mem_batches)) != max(num_mem_images, init_num_images):
+ size_b = min(batch_num_views, num_mem_images - sum_b)
+ mem_batches.append(size_b)
+
+ if render_once:
+ to_render = list(range(num_mem_images, nimgs))
+ else:
+ to_render = None
+
+ with torch.autocast("cuda", dtype=dtype):
+ x_out_0, x_out = inference_multi_ar(
+ encoder,
+ decoder,
+ imgs,
+ img_ids,
+ true_shape,
+ mem_batches,
+ max_bs=max_bs,
+ verbose=verbose,
+ to_render=to_render,
+ encoder_precomputed_features=encoder_precomputed_features,
+ device=device,
+ preserve_gpu_mem=True,
+ post_process_function=post_process_function,
+ viser_server=viser_server,
+ num_refinements_iterations=num_refinements_iterations,
+ )
+ if to_render is not None:
+ x_out = x_out_0 + x_out
+
+ ellapsed = datetime.datetime.now() - time_start
+ if verbose:
+ print(f"inference in {ellapsed}")
+ try:
+ print(str(int(torch.cuda.max_memory_reserved(device) / (1024**2))) + " MB")
+ except Exception:
+ pass
+
+ if viser_server is not None:
+ viser_server.reset_cam_visility()
+ viser_server.send_message("Finished")
+
+ if verbose:
+ print("preparing pointcloud")
+ time_start = datetime.datetime.now()
+ focals = []
+ cams2world = []
+ for i in range(nimgs):
+ focals.append(float(x_out[i]["focal"].cpu()))
+ cams2world.append(x_out[i]["c2w"].cpu())
+
+ # x_out to cpu
+ for i in range(len(x_out)):
+ for k in x_out[i].keys():
+ x_out[i][k] = x_out[i][k].cpu()
+
+ rgbimg = [rgb(imgs[i], true_shape[i]) for i in range(nimgs)]
+ scene = SceneState(x_out, rgbimg, true_shape, focals, cams2world, filenames)
+
+ ellapsed = datetime.datetime.now() - time_start
+ if verbose:
+ print(f"pointcloud prepared in {ellapsed}")
+
+ return scene
+
+
+class MUSt3RWrapper(torch.nn.Module):
+ def __init__(
+ self,
+ name,
+ ckpt_path,
+ retrieval_ckpt_path,
+ img_size=512,
+ amp="bf16",
+ max_bs=1,
+ **kwargs,
+ ):
+ super().__init__()
+ self.name = name
+ self.ckpt_path = ckpt_path
+ self.retrieval_ckpt_path = retrieval_ckpt_path
+ self.amp = amp
+ self.max_bs = max_bs
+
+ # Init the model and load the checkpoint
+ self.model = load_model(self.ckpt_path, img_size=512)
+
+ def forward(self, views):
+ """
+ Forward pass wrapper for MUSt3R.
+
+ Assumption:
+ - The batch size of input views is 1.
+
+ Args:
+ views (List[dict]): List of dictionaries containing the input views' images and instance information.
+ Each dictionary should contain the following keys, where B is the batch size and is 1:
+ "img" (tensor): Image tensor of shape (B, C, H, W).
+ "data_norm_type" (list): ["dust3r"]
+ "label" (list): ["scene_name"]
+ "instance" (list): ["image_name"]
+
+ Returns:
+ List[dict]: A list containing the final outputs for the input views.
+ """
+ # Check the batch size of input views
+ batch_size_per_view, _, height, width = views[0]["img"].shape
+ device = views[0]["img"].device
+ num_views = len(views)
+ assert batch_size_per_view == 1, (
+ f"Batch size of input views should be 1, but got {batch_size_per_view}."
+ )
+
+ # Check the data norm type
+ data_norm_type = views[0]["data_norm_type"][0]
+ assert data_norm_type == "dust3r", (
+ "MUSt3R expects a normalized image with the DUSt3R normalization scheme applied"
+ )
+
+ # Convert the input views to the expected input format
+ images = []
+ image_paths = []
+ for view in views:
+ images.append(
+ dict(
+ img=view["img"][0].cpu(),
+ idx=len(images),
+ instance=str(len(images)),
+ true_shape=np.int32([view["img"].shape[-2], view["img"].shape[-1]]),
+ )
+ )
+ view_name = os.path.join(view["label"][0], view["instance"][0])
+ image_paths.append(view_name)
+
+ # Run MUSt3R inference
+ scene = must3r_inference(
+ images,
+ image_paths,
+ self.model,
+ self.retrieval_ckpt_path,
+ device,
+ self.amp,
+ num_views,
+ self.max_bs,
+ verbose=False,
+ )
+
+ # Make sure scene is not None
+ if scene is None:
+ raise RuntimeError("MUSt3R failed.")
+
+ # Get the predictions
+ predictions = scene.x_out
+
+ # Convert the output to the MapAnything format
+ with torch.autocast("cuda", enabled=False):
+ res = []
+ for view_idx in range(num_views):
+ # Get the current view predictions
+ curr_view_prediction = predictions[view_idx]
+ curr_view_conf = curr_view_prediction["conf"]
+ curr_view_pose = curr_view_prediction["c2w"].unsqueeze(0)
+
+ # Convert the pose to quaternions and translation
+ curr_view_cam_translations = curr_view_pose[..., :3, 3]
+ curr_view_cam_quats = mat_to_quat(curr_view_pose[..., :3, :3])
+
+ # Get the camera frame pointmaps
+ curr_view_pts3d_cam = curr_view_prediction["pts3d_local"].unsqueeze(0)
+
+ # Get the depth along ray and ray directions
+ curr_view_depth_along_ray = torch.norm(
+ curr_view_pts3d_cam, dim=-1, keepdim=True
+ )
+ curr_view_ray_dirs = curr_view_pts3d_cam / curr_view_depth_along_ray
+
+ # Get the pointmaps
+ curr_view_pts3d = curr_view_prediction["pts3d"].unsqueeze(0)
+
+ # Append the outputs to the result list
+ res.append(
+ {
+ "pts3d": curr_view_pts3d.to(device),
+ "pts3d_cam": curr_view_pts3d_cam.to(device),
+ "ray_directions": curr_view_ray_dirs.to(device),
+ "depth_along_ray": curr_view_depth_along_ray.to(device),
+ "cam_trans": curr_view_cam_translations.to(device),
+ "cam_quats": curr_view_cam_quats.to(device),
+ "conf": curr_view_conf.to(device),
+ }
+ )
+
+ return res
diff --git a/mapanything/models/external/pi3/__init__.py b/mapanything/models/external/pi3/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ba79abbf4916b24eb4f3f0bf76cdea02da03b9c9
--- /dev/null
+++ b/mapanything/models/external/pi3/__init__.py
@@ -0,0 +1,119 @@
+"""
+Inference wrapper for Pi3
+"""
+
+import torch
+
+from mapanything.models.external.pi3.models.pi3 import Pi3
+from mapanything.models.external.vggt.utils.rotation import mat_to_quat
+
+
+class Pi3Wrapper(torch.nn.Module):
+ def __init__(
+ self,
+ name,
+ torch_hub_force_reload,
+ load_pretrained_weights=True,
+ pos_type="rope100",
+ decoder_size="large",
+ ):
+ super().__init__()
+ self.name = name
+ self.torch_hub_force_reload = torch_hub_force_reload
+
+ if load_pretrained_weights:
+ # Load pre-trained weights
+ if not torch_hub_force_reload:
+ # Initialize the Pi3 model from huggingface hub cache
+ print("Loading Pi3 from huggingface cache ...")
+ self.model = Pi3.from_pretrained(
+ "yyfz233/Pi3",
+ )
+ else:
+ # Initialize the Pi3 model
+ self.model = Pi3.from_pretrained("yyfz233/Pi3", force_download=True)
+ else:
+ # Load the Pi3 class
+ self.model = Pi3(
+ pos_type=pos_type,
+ decoder_size=decoder_size,
+ )
+
+ # Get the dtype for Pi3 inference
+ # bfloat16 is supported on Ampere GPUs (Compute Capability 8.0+)
+ self.dtype = (
+ torch.bfloat16
+ if torch.cuda.get_device_capability()[0] >= 8
+ else torch.float16
+ )
+
+ def forward(self, views):
+ """
+ Forward pass wrapper for Pi3
+
+ Assumption:
+ - All the input views have the same image shape.
+
+ Args:
+ views (List[dict]): List of dictionaries containing the input views' images and instance information.
+ Each dictionary should contain the following keys:
+ "img" (tensor): Image tensor of shape (B, C, H, W).
+ "data_norm_type" (list): ["identity"]
+
+ Returns:
+ List[dict]: A list containing the final outputs for all N views.
+ """
+ # Get input shape of the images, number of views, and batch size per view
+ batch_size_per_view, _, height, width = views[0]["img"].shape
+ num_views = len(views)
+
+ # Check the data norm type
+ # Pi3 expects a normalized image but without the DINOv2 mean and std applied ("identity")
+ data_norm_type = views[0]["data_norm_type"][0]
+ assert data_norm_type == "identity", (
+ "Pi3 expects a normalized image but without the DINOv2 mean and std applied"
+ )
+
+ # Concatenate the images to create a single (B, V, C, H, W) tensor
+ img_list = [view["img"] for view in views]
+ images = torch.stack(img_list, dim=1)
+
+ # Run the Pi3 aggregator
+ with torch.autocast("cuda", dtype=self.dtype):
+ results = self.model(images)
+
+ # Need high precision for transformations
+ with torch.autocast("cuda", enabled=False):
+ # Convert the output to MapAnything format
+ res = []
+ for view_idx in range(num_views):
+ # Get the extrinsics
+ curr_view_extrinsic = results["camera_poses"][:, view_idx, ...]
+ curr_view_cam_translations = curr_view_extrinsic[..., :3, 3]
+ curr_view_cam_quats = mat_to_quat(curr_view_extrinsic[..., :3, :3])
+
+ # Get the depth along ray, ray directions, local point cloud & global point cloud
+ curr_view_pts3d_cam = results["local_points"][:, view_idx, ...]
+ curr_view_depth_along_ray = torch.norm(
+ curr_view_pts3d_cam, dim=-1, keepdim=True
+ )
+ curr_view_ray_dirs = curr_view_pts3d_cam / curr_view_depth_along_ray
+ curr_view_pts3d = results["points"][:, view_idx, ...]
+
+ # Get the confidence
+ curr_view_confidence = results["conf"][:, view_idx, ...]
+
+ # Append the outputs to the result list
+ res.append(
+ {
+ "pts3d": curr_view_pts3d,
+ "pts3d_cam": curr_view_pts3d_cam,
+ "ray_directions": curr_view_ray_dirs,
+ "depth_along_ray": curr_view_depth_along_ray,
+ "cam_trans": curr_view_cam_translations,
+ "cam_quats": curr_view_cam_quats,
+ "conf": curr_view_confidence,
+ }
+ )
+
+ return res
diff --git a/mapanything/models/external/pi3/layers/__init__.py b/mapanything/models/external/pi3/layers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/mapanything/models/external/pi3/layers/attention.py b/mapanything/models/external/pi3/layers/attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..71ee31d3dd177aaf33c98d0f7a2456160bb74967
--- /dev/null
+++ b/mapanything/models/external/pi3/layers/attention.py
@@ -0,0 +1,429 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
+
+
+import os
+
+import torch
+from torch import nn, Tensor
+from torch.nn.attention import SDPBackend
+from torch.nn.functional import scaled_dot_product_attention
+
+XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
+try:
+ if XFORMERS_ENABLED:
+ from xformers.ops import memory_efficient_attention
+
+ XFORMERS_AVAILABLE = True
+ # warnings.warn("xFormers is available (Attention)")
+ else:
+ # warnings.warn("xFormers is disabled (Attention)")
+ raise ImportError
+except ImportError:
+ XFORMERS_AVAILABLE = False
+ # warnings.warn("xFormers is not available (Attention)")
+
+
+class Attention(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int = 8,
+ qkv_bias: bool = False,
+ proj_bias: bool = True,
+ attn_drop: float = 0.0,
+ proj_drop: float = 0.0,
+ ) -> None:
+ super().__init__()
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ self.scale = head_dim**-0.5
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim, bias=proj_bias)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ def forward(self, x: Tensor, attn_bias=None) -> Tensor:
+ B, N, C = x.shape
+ qkv = (
+ self.qkv(x)
+ .reshape(B, N, 3, self.num_heads, C // self.num_heads)
+ .permute(2, 0, 3, 1, 4)
+ )
+
+ q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
+ attn = q @ k.transpose(-2, -1)
+
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+class MemEffAttention(Attention):
+ def forward(self, x: Tensor, attn_bias=None) -> Tensor:
+ if not XFORMERS_AVAILABLE:
+ if attn_bias is not None:
+ raise AssertionError("xFormers is required for using nested tensors")
+ return super().forward(x)
+
+ B, N, C = x.shape
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
+
+ # q, k, v = unbind(qkv, 2)
+ q, k, v = [qkv[:, :, i] for i in range(3)]
+
+ x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
+ x = x.reshape([B, N, C])
+
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+class FlashAttention(Attention):
+ def forward(self, x: Tensor, attn_bias=None) -> Tensor:
+ B, N, C = x.shape
+ qkv = (
+ self.qkv(x)
+ .reshape(B, N, 3, self.num_heads, C // self.num_heads)
+ .transpose(1, 3)
+ )
+
+ # q, k, v = unbind(qkv, 2)
+ q, k, v = [qkv[:, :, i] for i in range(3)]
+
+ if q.dtype == torch.bfloat16:
+ with nn.attention.sdpa_kernel(SDPBackend.FLASH_ATTENTION):
+ x = scaled_dot_product_attention(q, k, v)
+ else:
+ with nn.attention.sdpa_kernel(
+ [SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION]
+ ):
+ x = scaled_dot_product_attention(q, k, v)
+
+ x = x.transpose(1, 2).reshape([B, N, C])
+
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+"""
+Following is written by GPT-4o
+"""
+
+
+class CrossAttentionRope(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int = 8,
+ qkv_bias: bool = False,
+ proj_bias: bool = True,
+ attn_drop: float = 0.0,
+ proj_drop: float = 0.0,
+ qk_norm: bool = False,
+ norm_layer: nn.Module = nn.LayerNorm,
+ rope=None,
+ ) -> None:
+ super().__init__()
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ self.scale = head_dim**-0.5
+
+ # Separate projection layers for query, key, and value
+ self.q_proj = nn.Linear(dim, dim, bias=qkv_bias)
+ self.k_proj = nn.Linear(dim, dim, bias=qkv_bias)
+ self.v_proj = nn.Linear(dim, dim, bias=qkv_bias)
+
+ self.q_norm = norm_layer(head_dim) if qk_norm else nn.Identity()
+ self.k_norm = norm_layer(head_dim) if qk_norm else nn.Identity()
+
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim, bias=proj_bias)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ self.rope = rope
+
+ def forward(
+ self,
+ query: Tensor,
+ key: Tensor,
+ value: Tensor,
+ attn_bias=None,
+ qpos=None,
+ kpos=None,
+ ) -> Tensor:
+ """
+ Args:
+ query: Tensor of shape (B, N, C), input query
+ key: Tensor of shape (B, M, C), input key
+ value: Tensor of shape (B, M, C), input value
+ attn_bias: Optional tensor for attention bias
+ Returns:
+ Tensor of shape (B, N, C), output of cross-attention
+ """
+ B, N, C = query.shape
+ _, M, _ = key.shape
+
+ # Project query, key, and value
+ q = (
+ self.q_proj(query)
+ .reshape(B, N, self.num_heads, C // self.num_heads)
+ .permute(0, 2, 1, 3)
+ )
+ k = (
+ self.k_proj(key)
+ .reshape(B, M, self.num_heads, C // self.num_heads)
+ .permute(0, 2, 1, 3)
+ )
+ v = (
+ self.v_proj(value)
+ .reshape(B, M, self.num_heads, C // self.num_heads)
+ .permute(0, 2, 1, 3)
+ )
+ q, k = self.q_norm(q).to(v.dtype), self.k_norm(k).to(v.dtype)
+
+ if self.rope is not None:
+ q = self.rope(q, qpos)
+ k = self.rope(k, kpos)
+
+ # Scale query
+ q = q * self.scale
+
+ # Compute attention scores
+ attn = q @ k.transpose(-2, -1) # (B, num_heads, N, M)
+ if attn_bias is not None:
+ attn = attn + attn_bias
+
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+
+ # Compute attention output
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C) # (B, N, C)
+
+ # Final projection
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+class MemEffCrossAttentionRope(CrossAttentionRope):
+ def forward(
+ self,
+ query: Tensor,
+ key: Tensor,
+ value: Tensor,
+ attn_bias=None,
+ qpos=None,
+ kpos=None,
+ ) -> Tensor:
+ """
+ Args:
+ query: Tensor of shape (B, N, C), input query
+ key: Tensor of shape (B, M, C), input key
+ value: Tensor of shape (B, M, C), input value
+ attn_bias: Optional tensor for attention bias
+ Returns:
+ Tensor of shape (B, N, C), output of cross-attention
+ """
+ if not XFORMERS_AVAILABLE:
+ if attn_bias is not None:
+ raise AssertionError("xFormers is required for using nested tensors")
+ return super().forward(query, key, value, attn_bias)
+
+ B, N, C = query.shape
+ _, M, _ = key.shape
+
+ # Project query, key, and value
+ q = self.q_proj(query).reshape(B, N, self.num_heads, C // self.num_heads)
+ k = self.k_proj(key).reshape(B, M, self.num_heads, C // self.num_heads)
+ v = self.v_proj(value).reshape(B, M, self.num_heads, C // self.num_heads)
+
+ q = q.transpose(1, 2)
+ k = k.transpose(1, 2)
+ q, k = self.q_norm(q).to(v.dtype), self.k_norm(k).to(v.dtype)
+
+ if self.rope is not None:
+ q = self.rope(q, qpos)
+ k = self.rope(k, kpos)
+
+ q = q.transpose(1, 2)
+ k = k.transpose(1, 2)
+
+ # Compute memory-efficient attention
+ x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
+ x = x.reshape(B, N, C)
+
+ # Final projection
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+class AttentionRope(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int = 8,
+ qkv_bias: bool = False,
+ proj_bias: bool = True,
+ attn_drop: float = 0.0,
+ proj_drop: float = 0.0,
+ qk_norm: bool = False,
+ norm_layer: nn.Module = nn.LayerNorm,
+ rope=None,
+ ) -> None:
+ super().__init__()
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ self.scale = head_dim**-0.5
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim, bias=proj_bias)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ self.q_norm = norm_layer(head_dim) if qk_norm else nn.Identity()
+ self.k_norm = norm_layer(head_dim) if qk_norm else nn.Identity()
+
+ self.rope = rope
+
+ def forward(self, x: Tensor, attn_bias=None, xpos=None) -> Tensor:
+ B, N, C = x.shape
+ qkv = (
+ self.qkv(x)
+ .reshape(B, N, 3, self.num_heads, C // self.num_heads)
+ .permute(2, 0, 3, 1, 4)
+ )
+ q, k, v = qkv[0], qkv[1], qkv[2]
+ q, k = self.q_norm(q).to(v.dtype), self.k_norm(k).to(v.dtype)
+
+ if self.rope is not None:
+ q = self.rope(q, xpos)
+ k = self.rope(k, xpos)
+
+ q = q * self.scale
+ attn = q @ k.transpose(-2, -1)
+
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+class MemEffAttentionRope(AttentionRope):
+ def forward(self, x: Tensor, attn_bias=None, xpos=None) -> Tensor:
+ if not XFORMERS_AVAILABLE:
+ if attn_bias is not None:
+ raise AssertionError("xFormers is required for using nested tensors")
+ return super().forward(x)
+
+ B, N, C = x.shape
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
+
+ qkv = qkv.transpose(1, 3)
+ # q, k, v = unbind(qkv, 2)
+ q, k, v = [qkv[:, :, i] for i in range(3)]
+ q, k = self.q_norm(q).to(v.dtype), self.k_norm(k).to(v.dtype)
+
+ if self.rope is not None:
+ q = self.rope(q, xpos)
+ k = self.rope(k, xpos)
+
+ q = q.transpose(1, 2)
+ k = k.transpose(1, 2)
+ v = v.transpose(1, 2)
+
+ x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
+ x = x.reshape([B, N, C])
+
+ # score_matrix = (q.permute(0, 2, 1, 3) * self.scale @ k.permute(0, 2, 1, 3).transpose(-2, -1)).sum(dim=1).reshape(frame_num, 261, frame_num, 261).mean(dim=[1, 3]).sum(1) # for frame attention matrix
+ # global_valid_id = torch.where(score_matrix > 0)
+ # score_matrix = (q.permute(0, 2, 1, 3) * self.scale @ k.permute(0, 2, 1, 3).transpose(-2, -1)).sum(dim=1)
+
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+class FlashAttentionRope(AttentionRope):
+ def forward(self, x: Tensor, attn_bias=None, xpos=None) -> Tensor:
+ B, N, C = x.shape
+ qkv = (
+ self.qkv(x)
+ .reshape(B, N, 3, self.num_heads, C // self.num_heads)
+ .transpose(1, 3)
+ )
+
+ # q, k, v = unbind(qkv, 2)
+ q, k, v = [qkv[:, :, i] for i in range(3)]
+ q, k = self.q_norm(q).to(v.dtype), self.k_norm(k).to(v.dtype)
+
+ if self.rope is not None:
+ q = self.rope(q, xpos)
+ k = self.rope(k, xpos)
+
+ if q.dtype == torch.bfloat16:
+ with nn.attention.sdpa_kernel(SDPBackend.FLASH_ATTENTION):
+ x = scaled_dot_product_attention(q, k, v)
+ else:
+ with nn.attention.sdpa_kernel(
+ [SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION]
+ ):
+ x = scaled_dot_product_attention(q, k, v)
+
+ x = x.transpose(1, 2).reshape([B, N, C])
+
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+def get_attn_score(blk_class, x, frame_num, token_length, xpos=None):
+ x = blk_class.norm1(x)
+
+ B, N, C = x.shape
+ qkv = blk_class.attn.qkv(x).reshape(
+ B, N, 3, blk_class.attn.num_heads, C // blk_class.attn.num_heads
+ )
+
+ qkv = qkv.transpose(1, 3)
+ # q, k, v = unbind(qkv, 2)
+ q, k, v = [qkv[:, :, i] for i in range(3)]
+ q, k = blk_class.attn.q_norm(q).to(v.dtype), blk_class.attn.k_norm(k).to(v.dtype)
+
+ if blk_class.attn.rope is not None:
+ q = blk_class.attn.rope(q, xpos)
+ k = blk_class.attn.rope(k, xpos)
+
+ q = q.transpose(1, 2)
+ k = k.transpose(1, 2)
+
+ score = (
+ (
+ q.permute(0, 2, 1, 3)
+ * blk_class.attn.scale
+ @ k.permute(0, 2, 1, 3).transpose(-2, -1)
+ )
+ .sum(dim=1)
+ .reshape(B, frame_num, token_length, frame_num, token_length)
+ .mean(dim=[2, 4])
+ .sum(-1)
+ )
+
+ return score
diff --git a/mapanything/models/external/pi3/layers/block.py b/mapanything/models/external/pi3/layers/block.py
new file mode 100644
index 0000000000000000000000000000000000000000..59702755abf42310952398c42a15537df92a0382
--- /dev/null
+++ b/mapanything/models/external/pi3/layers/block.py
@@ -0,0 +1,448 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
+
+import os
+from typing import Any, Callable, Dict, List, Tuple
+
+import torch
+from torch import nn, Tensor
+
+from mapanything.models.external.dinov2.layers.drop_path import DropPath
+from mapanything.models.external.dinov2.layers.layer_scale import LayerScale
+from mapanything.models.external.dinov2.layers.mlp import Mlp
+from mapanything.models.external.pi3.layers.attention import (
+ Attention,
+ CrossAttentionRope,
+ MemEffAttention,
+)
+
+XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
+try:
+ if XFORMERS_ENABLED:
+ from xformers.ops import fmha, index_select_cat, scaled_index_add
+
+ XFORMERS_AVAILABLE = True
+ # warnings.warn("xFormers is available (Block)")
+ else:
+ # warnings.warn("xFormers is disabled (Block)")
+ raise ImportError
+except ImportError:
+ XFORMERS_AVAILABLE = False
+ # warnings.warn("xFormers is not available (Block)")
+
+
+class Block(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int,
+ mlp_ratio: float = 4.0,
+ qkv_bias: bool = False,
+ proj_bias: bool = True,
+ ffn_bias: bool = True,
+ drop: float = 0.0,
+ attn_drop: float = 0.0,
+ init_values=None,
+ drop_path: float = 0.0,
+ act_layer: Callable[..., nn.Module] = nn.GELU,
+ norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
+ attn_class: Callable[..., nn.Module] = Attention,
+ ffn_layer: Callable[..., nn.Module] = Mlp,
+ ) -> None:
+ super().__init__()
+ # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}")
+ self.norm1 = norm_layer(dim)
+ self.attn = attn_class(
+ dim,
+ num_heads=num_heads,
+ qkv_bias=qkv_bias,
+ proj_bias=proj_bias,
+ attn_drop=attn_drop,
+ proj_drop=drop,
+ )
+
+ self.ls1 = (
+ LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
+ )
+ self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+ self.norm2 = norm_layer(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = ffn_layer(
+ in_features=dim,
+ hidden_features=mlp_hidden_dim,
+ act_layer=act_layer,
+ drop=drop,
+ bias=ffn_bias,
+ )
+ self.ls2 = (
+ LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
+ )
+ self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+ self.sample_drop_ratio = drop_path
+
+ def forward(self, x: Tensor) -> Tensor:
+ def attn_residual_func(x: Tensor) -> Tensor:
+ return self.ls1(self.attn(self.norm1(x)))
+
+ def ffn_residual_func(x: Tensor) -> Tensor:
+ return self.ls2(self.mlp(self.norm2(x)))
+
+ if self.training and self.sample_drop_ratio > 0.1:
+ # the overhead is compensated only for a drop path rate larger than 0.1
+ x = drop_add_residual_stochastic_depth(
+ x,
+ residual_func=attn_residual_func,
+ sample_drop_ratio=self.sample_drop_ratio,
+ )
+ x = drop_add_residual_stochastic_depth(
+ x,
+ residual_func=ffn_residual_func,
+ sample_drop_ratio=self.sample_drop_ratio,
+ )
+ elif self.training and self.sample_drop_ratio > 0.0:
+ x = x + self.drop_path1(attn_residual_func(x))
+ x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2
+ else:
+ x = x + attn_residual_func(x)
+ x = x + ffn_residual_func(x)
+ return x
+
+
+def drop_add_residual_stochastic_depth(
+ x: Tensor,
+ residual_func: Callable[[Tensor], Tensor],
+ sample_drop_ratio: float = 0.0,
+) -> Tensor:
+ # 1) extract subset using permutation
+ b, n, d = x.shape
+ sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
+ brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
+ x_subset = x[brange]
+
+ # 2) apply residual_func to get residual
+ residual = residual_func(x_subset)
+
+ x_flat = x.flatten(1)
+ residual = residual.flatten(1)
+
+ residual_scale_factor = b / sample_subset_size
+
+ # 3) add the residual
+ x_plus_residual = torch.index_add(
+ x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor
+ )
+ return x_plus_residual.view_as(x)
+
+
+def get_branges_scales(x, sample_drop_ratio=0.0):
+ b, n, d = x.shape
+ sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
+ brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
+ residual_scale_factor = b / sample_subset_size
+ return brange, residual_scale_factor
+
+
+def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
+ if scaling_vector is None:
+ x_flat = x.flatten(1)
+ residual = residual.flatten(1)
+ x_plus_residual = torch.index_add(
+ x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor
+ )
+ else:
+ x_plus_residual = scaled_index_add(
+ x,
+ brange,
+ residual.to(dtype=x.dtype),
+ scaling=scaling_vector,
+ alpha=residual_scale_factor,
+ )
+ return x_plus_residual
+
+
+attn_bias_cache: Dict[Tuple, Any] = {}
+
+
+def get_attn_bias_and_cat(x_list, branges=None):
+ """
+ this will perform the index select, cat the tensors, and provide the attn_bias from cache
+ """
+ batch_sizes = (
+ [b.shape[0] for b in branges]
+ if branges is not None
+ else [x.shape[0] for x in x_list]
+ )
+ all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
+ if all_shapes not in attn_bias_cache.keys():
+ seqlens = []
+ for b, x in zip(batch_sizes, x_list):
+ for _ in range(b):
+ seqlens.append(x.shape[1])
+ attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
+ attn_bias._batch_sizes = batch_sizes
+ attn_bias_cache[all_shapes] = attn_bias
+
+ if branges is not None:
+ cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(
+ 1, -1, x_list[0].shape[-1]
+ )
+ else:
+ tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
+ cat_tensors = torch.cat(tensors_bs1, dim=1)
+
+ return attn_bias_cache[all_shapes], cat_tensors
+
+
+def drop_add_residual_stochastic_depth_list(
+ x_list: List[Tensor],
+ residual_func: Callable[[Tensor, Any], Tensor],
+ sample_drop_ratio: float = 0.0,
+ scaling_vector=None,
+) -> Tensor:
+ # 1) generate random set of indices for dropping samples in the batch
+ branges_scales = [
+ get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list
+ ]
+ branges = [s[0] for s in branges_scales]
+ residual_scale_factors = [s[1] for s in branges_scales]
+
+ # 2) get attention bias and index+concat the tensors
+ attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)
+
+ # 3) apply residual_func to get residual, and split the result
+ residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore
+
+ outputs = []
+ for x, brange, residual, residual_scale_factor in zip(
+ x_list, branges, residual_list, residual_scale_factors
+ ):
+ outputs.append(
+ add_residual(
+ x, brange, residual, residual_scale_factor, scaling_vector
+ ).view_as(x)
+ )
+ return outputs
+
+
+class NestedTensorBlock(Block):
+ def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
+ """
+ x_list contains a list of tensors to nest together and run
+ """
+ assert isinstance(self.attn, MemEffAttention)
+
+ if self.training and self.sample_drop_ratio > 0.0:
+
+ def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
+ return self.attn(self.norm1(x), attn_bias=attn_bias)
+
+ def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
+ return self.mlp(self.norm2(x))
+
+ x_list = drop_add_residual_stochastic_depth_list(
+ x_list,
+ residual_func=attn_residual_func,
+ sample_drop_ratio=self.sample_drop_ratio,
+ scaling_vector=self.ls1.gamma
+ if isinstance(self.ls1, LayerScale)
+ else None,
+ )
+ x_list = drop_add_residual_stochastic_depth_list(
+ x_list,
+ residual_func=ffn_residual_func,
+ sample_drop_ratio=self.sample_drop_ratio,
+ scaling_vector=self.ls2.gamma
+ if isinstance(self.ls1, LayerScale)
+ else None,
+ )
+ return x_list
+ else:
+
+ def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
+ return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))
+
+ def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
+ return self.ls2(self.mlp(self.norm2(x)))
+
+ attn_bias, x = get_attn_bias_and_cat(x_list)
+ x = x + attn_residual_func(x, attn_bias=attn_bias)
+ x = x + ffn_residual_func(x)
+ return attn_bias.split(x)
+
+ def forward(self, x_or_x_list):
+ if isinstance(x_or_x_list, Tensor):
+ return super().forward(x_or_x_list)
+ elif isinstance(x_or_x_list, list):
+ if not XFORMERS_AVAILABLE:
+ raise AssertionError("xFormers is required for using nested tensors")
+ return self.forward_nested(x_or_x_list)
+ else:
+ raise AssertionError
+
+
+class BlockRope(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int,
+ mlp_ratio: float = 4.0,
+ qkv_bias: bool = False,
+ proj_bias: bool = True,
+ ffn_bias: bool = True,
+ drop: float = 0.0,
+ attn_drop: float = 0.0,
+ init_values=None,
+ drop_path: float = 0.0,
+ act_layer: Callable[..., nn.Module] = nn.GELU,
+ norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
+ attn_class: Callable[..., nn.Module] = Attention,
+ ffn_layer: Callable[..., nn.Module] = Mlp,
+ qk_norm: bool = False,
+ rope=None,
+ ) -> None:
+ super().__init__()
+ # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}")
+ self.norm1 = norm_layer(dim)
+ self.attn = attn_class(
+ dim,
+ num_heads=num_heads,
+ qkv_bias=qkv_bias,
+ proj_bias=proj_bias,
+ attn_drop=attn_drop,
+ proj_drop=drop,
+ qk_norm=qk_norm,
+ rope=rope,
+ )
+
+ self.ls1 = (
+ LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
+ )
+ self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+ self.norm2 = norm_layer(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = ffn_layer(
+ in_features=dim,
+ hidden_features=mlp_hidden_dim,
+ act_layer=act_layer,
+ drop=drop,
+ bias=ffn_bias,
+ )
+ self.ls2 = (
+ LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
+ )
+ self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+ self.sample_drop_ratio = drop_path
+
+ def forward(self, x: Tensor, xpos=None) -> Tensor:
+ def attn_residual_func(x: Tensor) -> Tensor:
+ return self.ls1(self.attn(self.norm1(x), xpos=xpos))
+
+ def ffn_residual_func(x: Tensor) -> Tensor:
+ return self.ls2(self.mlp(self.norm2(x)))
+
+ if self.training and self.sample_drop_ratio > 0.1:
+ # the overhead is compensated only for a drop path rate larger than 0.1
+ x = drop_add_residual_stochastic_depth(
+ x,
+ residual_func=attn_residual_func,
+ sample_drop_ratio=self.sample_drop_ratio,
+ )
+ x = drop_add_residual_stochastic_depth(
+ x,
+ residual_func=ffn_residual_func,
+ sample_drop_ratio=self.sample_drop_ratio,
+ )
+ elif self.training and self.sample_drop_ratio > 0.0:
+ x = x + self.drop_path1(attn_residual_func(x))
+ x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2
+ else:
+ x = x + attn_residual_func(x)
+ x = x + ffn_residual_func(x)
+ return x
+
+
+class CrossBlockRope(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int,
+ mlp_ratio: float = 4.0,
+ qkv_bias: bool = False,
+ proj_bias: bool = True,
+ ffn_bias: bool = True,
+ act_layer: Callable[..., nn.Module] = nn.GELU,
+ norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
+ attn_class: Callable[..., nn.Module] = Attention,
+ cross_attn_class: Callable[..., nn.Module] = CrossAttentionRope,
+ ffn_layer: Callable[..., nn.Module] = Mlp,
+ init_values=None,
+ qk_norm: bool = False,
+ rope=None,
+ ) -> None:
+ super().__init__()
+ # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}")
+ self.ls1 = (
+ LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
+ )
+ self.norm1 = norm_layer(dim)
+ self.attn = attn_class(
+ dim,
+ num_heads=num_heads,
+ qkv_bias=qkv_bias,
+ proj_bias=proj_bias,
+ rope=rope,
+ qk_norm=qk_norm,
+ )
+
+ self.ls2 = (
+ LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
+ )
+ self.ls_y = (
+ LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
+ )
+ self.norm2 = norm_layer(dim)
+ self.norm_y = norm_layer(dim)
+ self.cross_attn = cross_attn_class(
+ dim,
+ num_heads=num_heads,
+ qkv_bias=qkv_bias,
+ proj_bias=proj_bias,
+ rope=rope,
+ qk_norm=qk_norm,
+ )
+
+ self.norm3 = norm_layer(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = ffn_layer(
+ in_features=dim,
+ hidden_features=mlp_hidden_dim,
+ act_layer=act_layer,
+ bias=ffn_bias,
+ )
+
+ def forward(self, x: Tensor, y: Tensor, xpos=None, ypos=None) -> Tensor:
+ def attn_residual_func(x: Tensor) -> Tensor:
+ return self.ls1(self.attn(self.norm1(x), xpos=xpos))
+
+ def cross_attn_residual_func(x: Tensor, y: Tensor) -> Tensor:
+ return self.ls_y(self.cross_attn(self.norm2(x), y, y, qpos=xpos, kpos=ypos))
+
+ def ffn_residual_func(x: Tensor) -> Tensor:
+ return self.ls2(self.mlp(self.norm3(x)))
+
+ x = x + attn_residual_func(x)
+ y_ = self.norm_y(y)
+ x = x + cross_attn_residual_func(x, y_)
+ x = x + ffn_residual_func(x)
+
+ return x
diff --git a/mapanything/models/external/pi3/layers/camera_head.py b/mapanything/models/external/pi3/layers/camera_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..3989249b3900454278227c62df98af9bab691a56
--- /dev/null
+++ b/mapanything/models/external/pi3/layers/camera_head.py
@@ -0,0 +1,106 @@
+from copy import deepcopy
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+# code adapted from 'https://github.com/nianticlabs/marepo/blob/9a45e2bb07e5bb8cb997620088d352b439b13e0e/transformer/transformer.py#L172'
+class ResConvBlock(nn.Module):
+ """
+ 1x1 convolution residual block
+ """
+
+ def __init__(self, in_channels, out_channels):
+ super().__init__()
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.head_skip = (
+ nn.Identity()
+ if self.in_channels == self.out_channels
+ else nn.Conv2d(self.in_channels, self.out_channels, 1, 1, 0)
+ )
+ # self.res_conv1 = nn.Conv2d(self.in_channels, self.out_channels, 1, 1, 0)
+ # self.res_conv2 = nn.Conv2d(self.out_channels, self.out_channels, 1, 1, 0)
+ # self.res_conv3 = nn.Conv2d(self.out_channels, self.out_channels, 1, 1, 0)
+
+ # change 1x1 convolution to linear
+ self.res_conv1 = nn.Linear(self.in_channels, self.out_channels)
+ self.res_conv2 = nn.Linear(self.out_channels, self.out_channels)
+ self.res_conv3 = nn.Linear(self.out_channels, self.out_channels)
+
+ def forward(self, res):
+ x = F.relu(self.res_conv1(res))
+ x = F.relu(self.res_conv2(x))
+ x = F.relu(self.res_conv3(x))
+ res = self.head_skip(res) + x
+ return res
+
+
+class CameraHead(nn.Module):
+ def __init__(self, dim=512):
+ super().__init__()
+ output_dim = dim
+ self.res_conv = nn.ModuleList(
+ [deepcopy(ResConvBlock(output_dim, output_dim)) for _ in range(2)]
+ )
+ self.avgpool = nn.AdaptiveAvgPool2d(1)
+ self.more_mlps = nn.Sequential(
+ nn.Linear(output_dim, output_dim),
+ nn.ReLU(),
+ nn.Linear(output_dim, output_dim),
+ nn.ReLU(),
+ )
+ self.fc_t = nn.Linear(output_dim, 3)
+ self.fc_rot = nn.Linear(output_dim, 9)
+
+ def forward(self, feat, patch_h, patch_w):
+ BN, hw, c = feat.shape
+
+ for i in range(2):
+ feat = self.res_conv[i](feat)
+
+ # feat = self.avgpool(feat)
+ feat = self.avgpool(
+ feat.permute(0, 2, 1).reshape(BN, -1, patch_h, patch_w).contiguous()
+ ) ##########
+ feat = feat.view(feat.size(0), -1)
+
+ feat = self.more_mlps(feat) # [B, D_]
+ with torch.amp.autocast(device_type="cuda", enabled=False):
+ out_t = self.fc_t(feat.float()) # [B,3]
+ out_r = self.fc_rot(feat.float()) # [B,9]
+ pose = self.convert_pose_to_4x4(BN, out_r, out_t, feat.device)
+
+ return pose
+
+ def convert_pose_to_4x4(self, B, out_r, out_t, device):
+ out_r = self.svd_orthogonalize(out_r) # [N,3,3]
+ pose = torch.zeros((B, 4, 4), device=device)
+ pose[:, :3, :3] = out_r
+ pose[:, :3, 3] = out_t
+ pose[:, 3, 3] = 1.0
+ return pose
+
+ def svd_orthogonalize(self, m):
+ """Convert 9D representation to SO(3) using SVD orthogonalization.
+
+ Args:
+ m: [BATCH, 3, 3] 3x3 matrices.
+
+ Returns:
+ [BATCH, 3, 3] SO(3) rotation matrices.
+ """
+ if m.dim() < 3:
+ m = m.reshape((-1, 3, 3))
+ m_transpose = torch.transpose(
+ torch.nn.functional.normalize(m, p=2, dim=-1), dim0=-1, dim1=-2
+ )
+ u, s, v = torch.svd(m_transpose)
+ det = torch.det(torch.matmul(v, u.transpose(-2, -1)))
+ # Check orientation reflection.
+ r = torch.matmul(
+ torch.cat([v[:, :, :-1], v[:, :, -1:] * det.view(-1, 1, 1)], dim=2),
+ u.transpose(-2, -1),
+ )
+ return r
diff --git a/mapanything/models/external/pi3/layers/pos_embed.py b/mapanything/models/external/pi3/layers/pos_embed.py
new file mode 100644
index 0000000000000000000000000000000000000000..beb0d7271f70cbacb77ae4e97bba7300f5f6b99f
--- /dev/null
+++ b/mapanything/models/external/pi3/layers/pos_embed.py
@@ -0,0 +1,190 @@
+# Copyright (C) 2022-present Naver Corporation. All rights reserved.
+# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
+
+
+# --------------------------------------------------------
+# Position embedding utils
+# --------------------------------------------------------
+
+
+import numpy as np
+import torch
+
+
+# --------------------------------------------------------
+# 2D sine-cosine position embedding
+# References:
+# MAE: https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
+# Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py
+# MoCo v3: https://github.com/facebookresearch/moco-v3
+# --------------------------------------------------------
+def get_2d_sincos_pos_embed(embed_dim, grid_size, n_cls_token=0):
+ """
+ grid_size: int of the grid height and width
+ return:
+ pos_embed: [grid_size*grid_size, embed_dim] or [n_cls_token+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
+ """
+ grid_h = np.arange(grid_size, dtype=np.float32)
+ grid_w = np.arange(grid_size, dtype=np.float32)
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
+ grid = np.stack(grid, axis=0)
+
+ grid = grid.reshape([2, 1, grid_size, grid_size])
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
+ if n_cls_token > 0:
+ pos_embed = np.concatenate(
+ [np.zeros([n_cls_token, embed_dim]), pos_embed], axis=0
+ )
+ return pos_embed
+
+
+def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
+ assert embed_dim % 2 == 0
+
+ # use half of dimensions to encode grid_h
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
+
+ emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
+ return emb
+
+
+def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
+ """
+ embed_dim: output dimension for each position
+ pos: a list of positions to be encoded: size (M,)
+ out: (M, D)
+ """
+ assert embed_dim % 2 == 0
+ omega = np.arange(embed_dim // 2, dtype=float)
+ omega /= embed_dim / 2.0
+ omega = 1.0 / 10000**omega # (D/2,)
+
+ pos = pos.reshape(-1) # (M,)
+ out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
+
+ emb_sin = np.sin(out) # (M, D/2)
+ emb_cos = np.cos(out) # (M, D/2)
+
+ emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
+ return emb
+
+
+# --------------------------------------------------------
+# Interpolate position embeddings for high-resolution
+# References:
+# MAE: https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
+# DeiT: https://github.com/facebookresearch/deit
+# --------------------------------------------------------
+def interpolate_pos_embed(model, checkpoint_model):
+ if "pos_embed" in checkpoint_model:
+ pos_embed_checkpoint = checkpoint_model["pos_embed"]
+ embedding_size = pos_embed_checkpoint.shape[-1]
+ num_patches = model.patch_embed.num_patches
+ num_extra_tokens = model.pos_embed.shape[-2] - num_patches
+ # height (== width) for the checkpoint position embedding
+ orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
+ # height (== width) for the new position embedding
+ new_size = int(num_patches**0.5)
+ # class_token and dist_token are kept unchanged
+ if orig_size != new_size:
+ print(
+ "Position interpolate from %dx%d to %dx%d"
+ % (orig_size, orig_size, new_size, new_size)
+ )
+ extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
+ # only the position tokens are interpolated
+ pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
+ pos_tokens = pos_tokens.reshape(
+ -1, orig_size, orig_size, embedding_size
+ ).permute(0, 3, 1, 2)
+ pos_tokens = torch.nn.functional.interpolate(
+ pos_tokens,
+ size=(new_size, new_size),
+ mode="bicubic",
+ align_corners=False,
+ )
+ pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
+ new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
+ checkpoint_model["pos_embed"] = new_pos_embed
+
+
+# ----------------------------------------------------------
+# RoPE2D: RoPE implementation in 2D
+# ----------------------------------------------------------
+
+try:
+ from models.curope import cuRoPE2D
+
+ RoPE2D = cuRoPE2D
+except ImportError:
+
+ class RoPE2D(torch.nn.Module):
+ def __init__(self, freq=100.0, F0=1.0):
+ super().__init__()
+ self.base = freq
+ self.F0 = F0
+ self.cache = {}
+
+ def get_cos_sin(self, D, seq_len, device, dtype):
+ if (D, seq_len, device, dtype) not in self.cache:
+ inv_freq = 1.0 / (
+ self.base ** (torch.arange(0, D, 2).float().to(device) / D)
+ )
+ t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype)
+ freqs = torch.einsum("i,j->ij", t, inv_freq).to(dtype)
+ freqs = torch.cat((freqs, freqs), dim=-1)
+ cos = freqs.cos() # (Seq, Dim)
+ sin = freqs.sin()
+ self.cache[D, seq_len, device, dtype] = (cos, sin)
+ return self.cache[D, seq_len, device, dtype]
+
+ @staticmethod
+ def rotate_half(x):
+ x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
+ return torch.cat((-x2, x1), dim=-1)
+
+ def apply_rope1d(self, tokens, pos1d, cos, sin):
+ assert pos1d.ndim == 2
+ cos = torch.nn.functional.embedding(pos1d, cos)[:, None, :, :]
+ sin = torch.nn.functional.embedding(pos1d, sin)[:, None, :, :]
+ return (tokens * cos) + (self.rotate_half(tokens) * sin)
+
+ def forward(self, tokens, positions):
+ """
+ input:
+ * tokens: batch_size x nheads x ntokens x dim
+ * positions: batch_size x ntokens x 2 (y and x position of each token)
+ output:
+ * tokens after appplying RoPE2D (batch_size x nheads x ntokens x dim)
+ """
+ assert tokens.size(3) % 2 == 0, (
+ "number of dimensions should be a multiple of two"
+ )
+ D = tokens.size(3) // 2
+ assert positions.ndim == 3 and positions.shape[-1] == 2 # Batch, Seq, 2
+ cos, sin = self.get_cos_sin(
+ D, int(positions.max()) + 1, tokens.device, tokens.dtype
+ )
+ # split features into two along the feature dimension, and apply rope1d on each half
+ y, x = tokens.chunk(2, dim=-1)
+ y = self.apply_rope1d(y, positions[:, :, 0], cos, sin)
+ x = self.apply_rope1d(x, positions[:, :, 1], cos, sin)
+ tokens = torch.cat((y, x), dim=-1)
+ return tokens
+
+
+# patch embedding
+class PositionGetter(object):
+ """return positions of patches"""
+
+ def __init__(self):
+ self.cache_positions = {}
+
+ def __call__(self, b, h, w, device):
+ if (h, w) not in self.cache_positions:
+ x = torch.arange(w, device=device)
+ y = torch.arange(h, device=device)
+ self.cache_positions[h, w] = torch.cartesian_prod(y, x) # (h, w, 2)
+ pos = self.cache_positions[h, w].view(1, h * w, 2).expand(b, -1, 2).clone()
+ return pos
diff --git a/mapanything/models/external/pi3/layers/transformer_head.py b/mapanything/models/external/pi3/layers/transformer_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..0094a72fb954df688166345cd0672ac74eafac58
--- /dev/null
+++ b/mapanything/models/external/pi3/layers/transformer_head.py
@@ -0,0 +1,98 @@
+from functools import partial
+
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.utils.checkpoint import checkpoint
+
+from mapanything.models.external.dinov2.layers import Mlp
+from mapanything.models.external.pi3.layers.attention import FlashAttentionRope
+from mapanything.models.external.pi3.layers.block import BlockRope
+
+
+class TransformerDecoder(nn.Module):
+ def __init__(
+ self,
+ in_dim,
+ out_dim,
+ dec_embed_dim=512,
+ depth=5,
+ dec_num_heads=8,
+ mlp_ratio=4,
+ rope=None,
+ need_project=True,
+ use_checkpoint=False,
+ ):
+ super().__init__()
+
+ self.projects = (
+ nn.Linear(in_dim, dec_embed_dim) if need_project else nn.Identity()
+ )
+ self.use_checkpoint = use_checkpoint
+
+ self.blocks = nn.ModuleList(
+ [
+ BlockRope(
+ dim=dec_embed_dim,
+ num_heads=dec_num_heads,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=True,
+ proj_bias=True,
+ ffn_bias=True,
+ drop_path=0.0,
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
+ act_layer=nn.GELU,
+ ffn_layer=Mlp,
+ init_values=None,
+ qk_norm=False,
+ # attn_class=MemEffAttentionRope,
+ attn_class=FlashAttentionRope,
+ rope=rope,
+ )
+ for _ in range(depth)
+ ]
+ )
+
+ self.linear_out = nn.Linear(dec_embed_dim, out_dim)
+
+ def forward(self, hidden, xpos=None):
+ hidden = self.projects(hidden)
+ for i, blk in enumerate(self.blocks):
+ if self.use_checkpoint and self.training:
+ hidden = checkpoint(blk, hidden, xpos=xpos, use_reentrant=False)
+ else:
+ hidden = blk(hidden, xpos=xpos)
+ out = self.linear_out(hidden)
+ return out
+
+
+class LinearPts3d(nn.Module):
+ """
+ Linear head for dust3r
+ Each token outputs: - 16x16 3D points (+ confidence)
+ """
+
+ def __init__(
+ self,
+ patch_size,
+ dec_embed_dim,
+ output_dim=3,
+ ):
+ super().__init__()
+ self.patch_size = patch_size
+
+ self.proj = nn.Linear(dec_embed_dim, (output_dim) * self.patch_size**2)
+
+ def forward(self, decout, img_shape):
+ H, W = img_shape
+ tokens = decout[-1]
+ B, S, D = tokens.shape
+
+ # extract 3D points
+ feat = self.proj(tokens) # B,S,D
+ feat = feat.transpose(-1, -2).view(
+ B, -1, H // self.patch_size, W // self.patch_size
+ )
+ feat = F.pixel_shuffle(feat, self.patch_size) # B,3,H,W
+
+ # permute + norm depth
+ return feat.permute(0, 2, 3, 1)
diff --git a/mapanything/models/external/pi3/models/__init__.py b/mapanything/models/external/pi3/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/mapanything/models/external/pi3/models/pi3.py b/mapanything/models/external/pi3/models/pi3.py
new file mode 100644
index 0000000000000000000000000000000000000000..bc07c319c760b9be04428dfaf9fbf44d4b6443d0
--- /dev/null
+++ b/mapanything/models/external/pi3/models/pi3.py
@@ -0,0 +1,251 @@
+from copy import deepcopy
+from functools import partial
+
+import torch
+import torch.nn as nn
+from huggingface_hub import PyTorchModelHubMixin
+
+from mapanything.models.external.dinov2.hub.backbones import dinov2_vitl14_reg
+from mapanything.models.external.dinov2.layers import Mlp
+from mapanything.models.external.pi3.layers.attention import FlashAttentionRope
+from mapanything.models.external.pi3.layers.block import BlockRope
+from mapanything.models.external.pi3.layers.camera_head import CameraHead
+from mapanything.models.external.pi3.layers.pos_embed import PositionGetter, RoPE2D
+from mapanything.models.external.pi3.layers.transformer_head import (
+ LinearPts3d,
+ TransformerDecoder,
+)
+
+
+def homogenize_points(
+ points,
+):
+ """Convert batched points (xyz) to (xyz1)."""
+ return torch.cat([points, torch.ones_like(points[..., :1])], dim=-1)
+
+
+class Pi3(nn.Module, PyTorchModelHubMixin):
+ def __init__(
+ self,
+ pos_type="rope100",
+ decoder_size="large",
+ ):
+ super().__init__()
+
+ # ----------------------
+ # Encoder
+ # ----------------------
+ self.encoder = dinov2_vitl14_reg(pretrained=False)
+ self.patch_size = 14
+ del self.encoder.mask_token
+
+ # ----------------------
+ # Positonal Encoding
+ # ----------------------
+ self.pos_type = pos_type if pos_type is not None else "none"
+ self.rope = None
+ if self.pos_type.startswith("rope"): # eg rope100
+ if RoPE2D is None:
+ raise ImportError(
+ "Cannot find cuRoPE2D, please install it following the README instructions"
+ )
+ freq = float(self.pos_type[len("rope") :])
+ self.rope = RoPE2D(freq=freq)
+ self.position_getter = PositionGetter()
+ else:
+ raise NotImplementedError
+
+ # ----------------------
+ # Decoder
+ # ----------------------
+ if decoder_size == "small":
+ dec_embed_dim = 384
+ dec_num_heads = 6
+ mlp_ratio = 4
+ dec_depth = 24
+ elif decoder_size == "base":
+ dec_embed_dim = 768
+ dec_num_heads = 12
+ mlp_ratio = 4
+ dec_depth = 24
+ elif decoder_size == "large":
+ dec_embed_dim = 1024
+ dec_num_heads = 16
+ mlp_ratio = 4
+ dec_depth = 36
+ else:
+ raise NotImplementedError
+ self.decoder = nn.ModuleList(
+ [
+ BlockRope(
+ dim=dec_embed_dim,
+ num_heads=dec_num_heads,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=True,
+ proj_bias=True,
+ ffn_bias=True,
+ drop_path=0.0,
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
+ act_layer=nn.GELU,
+ ffn_layer=Mlp,
+ init_values=0.01,
+ qk_norm=True,
+ attn_class=FlashAttentionRope,
+ rope=self.rope,
+ )
+ for _ in range(dec_depth)
+ ]
+ )
+ self.dec_embed_dim = dec_embed_dim
+
+ # ----------------------
+ # Register_token
+ # ----------------------
+ num_register_tokens = 5
+ self.patch_start_idx = num_register_tokens
+ self.register_token = nn.Parameter(
+ torch.randn(1, 1, num_register_tokens, self.dec_embed_dim)
+ )
+ nn.init.normal_(self.register_token, std=1e-6)
+
+ # ----------------------
+ # Local Points Decoder
+ # ----------------------
+ self.point_decoder = TransformerDecoder(
+ in_dim=2 * self.dec_embed_dim,
+ dec_embed_dim=1024,
+ dec_num_heads=16,
+ out_dim=1024,
+ rope=self.rope,
+ )
+ self.point_head = LinearPts3d(patch_size=14, dec_embed_dim=1024, output_dim=3)
+
+ # ----------------------
+ # Conf Decoder
+ # ----------------------
+ self.conf_decoder = deepcopy(self.point_decoder)
+ self.conf_head = LinearPts3d(patch_size=14, dec_embed_dim=1024, output_dim=1)
+
+ # ----------------------
+ # Camera Pose Decoder
+ # ----------------------
+ self.camera_decoder = TransformerDecoder(
+ in_dim=2 * self.dec_embed_dim,
+ dec_embed_dim=1024,
+ dec_num_heads=16, # 8
+ out_dim=512,
+ rope=self.rope,
+ use_checkpoint=False,
+ )
+ self.camera_head = CameraHead(dim=512)
+
+ # For ImageNet Normalize
+ image_mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
+ image_std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
+
+ self.register_buffer("image_mean", image_mean)
+ self.register_buffer("image_std", image_std)
+
+ def decode(self, hidden, N, H, W):
+ BN, hw, _ = hidden.shape
+ B = BN // N
+
+ final_output = []
+
+ hidden = hidden.reshape(B * N, hw, -1)
+
+ register_token = self.register_token.repeat(B, N, 1, 1).reshape(
+ B * N, *self.register_token.shape[-2:]
+ )
+
+ # Concatenate special tokens with patch tokens
+ hidden = torch.cat([register_token, hidden], dim=1)
+ hw = hidden.shape[1]
+
+ if self.pos_type.startswith("rope"):
+ pos = self.position_getter(
+ B * N, H // self.patch_size, W // self.patch_size, hidden.device
+ )
+
+ if self.patch_start_idx > 0:
+ # do not use position embedding for special tokens (camera and register tokens)
+ # so set pos to 0 for the special tokens
+ pos = pos + 1
+ pos_special = (
+ torch.zeros(B * N, self.patch_start_idx, 2)
+ .to(hidden.device)
+ .to(pos.dtype)
+ )
+ pos = torch.cat([pos_special, pos], dim=1)
+
+ for i in range(len(self.decoder)):
+ blk = self.decoder[i]
+
+ if i % 2 == 0:
+ pos = pos.reshape(B * N, hw, -1)
+ hidden = hidden.reshape(B * N, hw, -1)
+ else:
+ pos = pos.reshape(B, N * hw, -1)
+ hidden = hidden.reshape(B, N * hw, -1)
+
+ hidden = blk(hidden, xpos=pos)
+
+ if i + 1 in [len(self.decoder) - 1, len(self.decoder)]:
+ final_output.append(hidden.reshape(B * N, hw, -1))
+
+ return torch.cat([final_output[0], final_output[1]], dim=-1), pos.reshape(
+ B * N, hw, -1
+ )
+
+ def forward(self, imgs):
+ imgs = (imgs - self.image_mean) / self.image_std
+
+ B, N, _, H, W = imgs.shape
+ patch_h, patch_w = H // 14, W // 14
+
+ # encode by dinov2
+ imgs = imgs.reshape(B * N, _, H, W)
+ hidden = self.encoder(imgs, is_training=True)
+
+ if isinstance(hidden, dict):
+ hidden = hidden["x_norm_patchtokens"]
+
+ hidden, pos = self.decode(hidden, N, H, W)
+
+ point_hidden = self.point_decoder(hidden, xpos=pos)
+ conf_hidden = self.conf_decoder(hidden, xpos=pos)
+ camera_hidden = self.camera_decoder(hidden, xpos=pos)
+
+ with torch.amp.autocast(device_type="cuda", enabled=False):
+ # local points
+ point_hidden = point_hidden.float()
+ ret = self.point_head(
+ [point_hidden[:, self.patch_start_idx :]], (H, W)
+ ).reshape(B, N, H, W, -1)
+ xy, z = ret.split([2, 1], dim=-1)
+ z = torch.exp(z)
+ local_points = torch.cat([xy * z, z], dim=-1)
+
+ # confidence
+ conf_hidden = conf_hidden.float()
+ conf = self.conf_head(
+ [conf_hidden[:, self.patch_start_idx :]], (H, W)
+ ).reshape(B, N, H, W, -1)
+
+ # camera
+ camera_hidden = camera_hidden.float()
+ camera_poses = self.camera_head(
+ camera_hidden[:, self.patch_start_idx :], patch_h, patch_w
+ ).reshape(B, N, 4, 4)
+
+ # unproject local points using camera poses
+ points = torch.einsum(
+ "bnij, bnhwj -> bnhwi", camera_poses, homogenize_points(local_points)
+ )[..., :3]
+
+ return dict(
+ points=points,
+ local_points=local_points,
+ conf=conf,
+ camera_poses=camera_poses,
+ )
diff --git a/mapanything/models/external/pow3r/__init__.py b/mapanything/models/external/pow3r/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e7fb802cf629d935eff55dfd1abbbfb2913a79e7
--- /dev/null
+++ b/mapanything/models/external/pow3r/__init__.py
@@ -0,0 +1,860 @@
+"""
+Inference wrapper for Pow3R
+"""
+
+import warnings
+from copy import deepcopy
+
+import pow3r.model.blocks # noqa
+import roma
+import torch
+import torch.nn as nn
+import tqdm
+from dust3r.cloud_opt import global_aligner, GlobalAlignerMode
+from dust3r.image_pairs import make_pairs
+from dust3r.inference import check_if_same_size
+from dust3r.model import CroCoNet
+from dust3r.patch_embed import get_patch_embed as dust3r_patch_embed
+from dust3r.utils.device import collate_with_cat, to_cpu
+from dust3r.utils.misc import (
+ fill_default_args,
+ freeze_all_params,
+ interleave,
+ is_symmetrized,
+ transpose_to_landscape,
+)
+from pow3r.model.blocks import Block, BlockInject, DecoderBlock, DecoderBlockInject, Mlp
+from pow3r.model.heads import head_factory
+from pow3r.model.inference import (
+ add_depth,
+ add_intrinsics,
+ add_relpose,
+)
+from pow3r.model.patch_embed import get_patch_embed
+
+from mapanything.models.external.vggt.utils.rotation import mat_to_quat
+from mapanything.utils.geometry import (
+ convert_ray_dirs_depth_along_ray_pose_trans_quats_to_pointmap,
+ convert_z_depth_to_depth_along_ray,
+ depthmap_to_camera_frame,
+ get_rays_in_camera_frame,
+)
+
+
+class Pow3R(CroCoNet):
+ """Two siamese encoders, followed by two decoders.
+ The goal is to output 3d points directly, both images in view1's frame
+ (hence the asymmetry).
+ """
+
+ def __init__(
+ self,
+ mode="embed",
+ head_type="linear",
+ patch_embed_cls="PatchEmbedDust3R",
+ freeze="none",
+ landscape_only=True,
+ **croco_kwargs,
+ ):
+ # retrieve all default arguments using python magic
+ self.croco_args = fill_default_args(croco_kwargs, super().__init__)
+ super().__init__(**croco_kwargs)
+ del self.mask_token # useless
+ del self.prediction_head
+
+ dec_dim, enc_dim = self.decoder_embed.weight.shape
+ self.enc_embed_dim = enc_dim
+ self.dec_embed_dim = dec_dim
+
+ self.mode = mode
+ # additional parameters in the encoder
+ img_size = self.patch_embed.img_size
+ patch_size = self.patch_embed.patch_size[0]
+ self.patch_embed = dust3r_patch_embed(
+ patch_embed_cls, img_size, patch_size, self.enc_embed_dim
+ )
+ self.patch_embed_rays = get_patch_embed(
+ patch_embed_cls + "_Mlp",
+ img_size,
+ patch_size,
+ self.enc_embed_dim,
+ in_chans=3,
+ )
+ self.patch_embed_depth = get_patch_embed(
+ patch_embed_cls + "_Mlp",
+ img_size,
+ patch_size,
+ self.enc_embed_dim,
+ in_chans=2,
+ )
+ self.pose_embed = Mlp(12, 4 * dec_dim, dec_dim)
+
+ # additional parameters in the decoder
+ self.dec_cls = "_cls" in self.mode
+ self.dec_num_cls = 0
+ if self.dec_cls:
+ # use a CLS token in the decoder only
+ self.mode = self.mode.replace("_cls", "")
+ self.cls_token1 = nn.Parameter(torch.zeros((dec_dim,)))
+ self.cls_token2 = nn.Parameter(torch.zeros((dec_dim,)))
+ self.dec_num_cls = 1 # affects all blocks
+
+ use_ln = "_ln" in self.mode # TODO remove?
+ self.patch_ln = nn.LayerNorm(enc_dim) if use_ln else nn.Identity()
+ self.dec1_pre_ln = nn.LayerNorm(dec_dim) if use_ln else nn.Identity()
+ self.dec2_pre_ln = nn.LayerNorm(dec_dim) if use_ln else nn.Identity()
+
+ self.dec_blocks2 = deepcopy(self.dec_blocks)
+
+ # here we modify some of the blocks
+ self.replace_some_blocks()
+
+ self.set_downstream_head(head_type, landscape_only, **croco_kwargs)
+ self.set_freeze(freeze)
+
+ def replace_some_blocks(self):
+ assert self.mode.startswith("inject") # inject[0,0.5]
+ NewBlock = BlockInject
+ DecoderNewBlock = DecoderBlockInject
+
+ all_layers = {
+ i / n
+ for i in range(len(self.enc_blocks))
+ for n in [len(self.enc_blocks), len(self.dec_blocks)]
+ }
+ which_layers = eval(self.mode[self.mode.find("[") :]) or all_layers
+ assert isinstance(which_layers, (set, list))
+
+ n = 0
+ for i, block in enumerate(self.enc_blocks):
+ if i / len(self.enc_blocks) in which_layers:
+ block.__class__ = NewBlock
+ block.init(self.enc_embed_dim)
+ n += 1
+ else:
+ block.__class__ = Block
+ assert n == len(which_layers), breakpoint()
+
+ n = 0
+ for i in range(len(self.dec_blocks)):
+ for blocks in [self.dec_blocks, self.dec_blocks2]:
+ block = blocks[i]
+ if i / len(self.dec_blocks) in which_layers:
+ block.__class__ = DecoderNewBlock
+ block.init(self.dec_embed_dim)
+ n += 1
+ else:
+ block.__class__ = DecoderBlock
+ assert n == 2 * len(which_layers), breakpoint()
+
+ @classmethod
+ def from_pretrained(cls, pretrained_model_path, **kw):
+ return _load_model(pretrained_model_path, device="cpu")
+
+ def load_state_dict(self, ckpt, **kw):
+ # duplicate all weights for the second decoder if not present
+ new_ckpt = dict(ckpt)
+ if not any(k.startswith("dec_blocks2") for k in ckpt):
+ for key, value in ckpt.items():
+ if key.startswith("dec_blocks"):
+ new_ckpt[key.replace("dec_blocks", "dec_blocks2")] = value
+ # remove layers that have different shapes
+ cur_ckpt = self.state_dict()
+ for key, val in ckpt.items():
+ if key.startswith("downstream_head2.proj"):
+ if key in cur_ckpt and cur_ckpt[key].shape != val.shape:
+ print(f" (removing ckpt[{key}] because wrong shape)")
+ del new_ckpt[key]
+ return super().load_state_dict(new_ckpt, **kw)
+
+ def set_freeze(self, freeze): # this is for use by downstream models
+ self.freeze = freeze
+ to_be_frozen = {
+ "none": [],
+ "encoder": [self.patch_embed, self.enc_blocks],
+ }
+ freeze_all_params(to_be_frozen[freeze])
+
+ def set_prediction_head(self, *args, **kwargs):
+ """No prediction head"""
+ return
+
+ def set_downstream_head(
+ self,
+ head_type,
+ landscape_only,
+ patch_size,
+ img_size,
+ mlp_ratio,
+ dec_depth,
+ **kw,
+ ):
+ assert img_size[0] % patch_size == 0 and img_size[1] % patch_size == 0, (
+ f"{img_size=} must be multiple of {patch_size=}"
+ )
+
+ # split heads if different
+ heads = head_type.split(";")
+ assert len(heads) in (1, 2)
+ head1_type, head2_type = (heads + heads)[:2]
+
+ # allocate heads
+ self.downstream_head1 = head_factory(head1_type, self)
+ self.downstream_head2 = head_factory(head2_type, self)
+
+ # magic wrapper
+ self.head1 = transpose_to_landscape(
+ self.downstream_head1, activate=landscape_only
+ )
+ self.head2 = transpose_to_landscape(
+ self.downstream_head2, activate=landscape_only
+ )
+
+ def _encode_image(self, image, true_shape, rays=None, depth=None):
+ # embed the image into patches (x has size B x Npatches x C)
+ x, pos = self.patch_embed(image, true_shape=true_shape)
+
+ if rays is not None: # B,3,H,W
+ rays_emb, pos2 = self.patch_embed_rays(rays, true_shape=true_shape)
+ assert (pos == pos2).all()
+ if self.mode.startswith("embed"):
+ x = x + rays_emb
+ else:
+ rays_emb = None
+
+ if depth is not None: # B,2,H,W
+ depth_emb, pos2 = self.patch_embed_depth(depth, true_shape=true_shape)
+ assert (pos == pos2).all()
+ if self.mode.startswith("embed"):
+ x = x + depth_emb
+ else:
+ depth_emb = None
+
+ x = self.patch_ln(x)
+
+ # add positional embedding without cls token
+ assert self.enc_pos_embed is None
+
+ # now apply the transformer encoder and normalization
+ for blk in self.enc_blocks:
+ x = blk(x, pos, rays=rays_emb, depth=depth_emb)
+
+ x = self.enc_norm(x)
+ return x, pos
+
+ def encode_symmetrized(self, view1, view2):
+ img1 = view1["img"]
+ img2 = view2["img"]
+ B = img1.shape[0]
+ # Recover true_shape when available, otherwise assume that the img shape is the true one
+ shape1 = view1.get(
+ "true_shape", torch.tensor(img1.shape[-2:])[None].repeat(B, 1)
+ )
+ shape2 = view2.get(
+ "true_shape", torch.tensor(img2.shape[-2:])[None].repeat(B, 1)
+ )
+ # warning! maybe the images have different portrait/landscape orientations
+
+ # privileged information
+ rays1 = view1.get("known_rays", None)
+ rays2 = view2.get("known_rays", None)
+ depth1 = view1.get("known_depth", None)
+ depth2 = view2.get("known_depth", None)
+
+ if is_symmetrized(view1, view2):
+ # computing half of forward pass!'
+ def hsub(x):
+ return None if x is None else x[::2]
+
+ feat1, pos1 = self._encode_image(
+ img1[::2], shape1[::2], rays=hsub(rays1), depth=hsub(depth1)
+ )
+ feat2, pos2 = self._encode_image(
+ img2[::2], shape2[::2], rays=hsub(rays2), depth=hsub(depth2)
+ )
+
+ feat1, feat2 = interleave(feat1, feat2)
+ pos1, pos2 = interleave(pos1, pos2)
+ else:
+ feat1, pos1 = self._encode_image(img1, shape1, rays=rays1, depth=depth1)
+ feat2, pos2 = self._encode_image(img2, shape2, rays=rays2, depth=depth2)
+
+ return (shape1, shape2), (feat1, feat2), (pos1, pos2)
+
+ def _decoder(self, f1, pos1, f2, pos2, relpose1=None, relpose2=None):
+ final_output = [(f1, f2)] # before projection
+
+ # project to decoder dim
+ f1 = self.decoder_embed(f1)
+ f2 = self.decoder_embed(f2)
+
+ # add CLS token for the decoder
+ if self.dec_cls:
+ cls1 = self.cls_token1[None, None].expand(len(f1), 1, -1).clone()
+ cls2 = self.cls_token2[None, None].expand(len(f2), 1, -1).clone()
+
+ if relpose1 is not None: # shape = (B, 4, 4)
+ pose_emb1 = self.pose_embed(relpose1[:, :3].flatten(1)).unsqueeze(1)
+ if self.mode.startswith("embed"):
+ if self.dec_cls:
+ cls1 = cls1 + pose_emb1
+ else:
+ f1 = f1 + pose_emb1
+ else:
+ pose_emb1 = None
+
+ if relpose2 is not None: # shape = (B, 4, 4)
+ pose_emb2 = self.pose_embed(relpose2[:, :3].flatten(1)).unsqueeze(1)
+ if self.mode.startswith("embed"):
+ if self.dec_cls:
+ cls2 = cls2 + pose_emb2
+ else:
+ f2 = f2 + pose_emb2
+ else:
+ pose_emb2 = None
+
+ if self.dec_cls:
+ f1, pos1 = cat_cls(cls1, f1, pos1)
+ f2, pos2 = cat_cls(cls2, f2, pos2)
+
+ f1 = self.dec1_pre_ln(f1)
+ f2 = self.dec2_pre_ln(f2)
+
+ final_output.append((f1, f2)) # to be removed later
+ for blk1, blk2 in zip(self.dec_blocks, self.dec_blocks2):
+ # img1 side
+ f1, _ = blk1(
+ *final_output[-1][::+1],
+ pos1,
+ pos2,
+ relpose=pose_emb1,
+ num_cls=self.dec_num_cls,
+ )
+ # img2 side
+ f2, _ = blk2(
+ *final_output[-1][::-1],
+ pos2,
+ pos1,
+ relpose=pose_emb2,
+ num_cls=self.dec_num_cls,
+ )
+ # store the result
+ final_output.append((f1, f2))
+
+ del final_output[1] # duplicate with final_output[0] (after decoder proj)
+ if self.dec_cls: # remove cls token for decoder layers
+ final_output[1:] = [(f1[:, 1:], f2[:, 1:]) for f1, f2 in final_output[1:]]
+ # normalize last output
+ final_output[-1] = tuple(map(self.dec_norm, final_output[-1]))
+ return zip(*final_output)
+
+ def _downstream_head(self, head_num, decout, img_shape):
+ B, S, D = decout[-1].shape
+ head = getattr(self, f"head{head_num}")
+ return head(decout, img_shape)
+
+ def forward(self, view1, view2):
+ # encode the two images --> B,S,D
+ (shape1, shape2), (feat1, feat2), (pos1, pos2) = self.encode_symmetrized(
+ view1, view2
+ )
+
+ # combine all ref images into object-centric representation
+ dec1, dec2 = self._decoder(
+ feat1,
+ pos1,
+ feat2,
+ pos2,
+ relpose1=view1.get("known_pose"),
+ relpose2=view2.get("known_pose"),
+ )
+ with torch.autocast("cuda", enabled=False):
+ res1 = self._downstream_head(1, [tok.float() for tok in dec1], shape1)
+ res2 = self._downstream_head(2, [tok.float() for tok in dec2], shape2)
+
+ res2["pts3d_in_other_view"] = res2.pop(
+ "pts3d"
+ ) # predict view2's pts3d in view1's frame
+ return res1, res2
+
+
+def convert_release_dust3r_args(args):
+ args.model = (
+ args.model.replace("patch_embed_cls", "patch_embed")
+ .replace("AsymmetricMASt3R", "AsymmetricCroCo3DStereo")
+ .replace("PatchEmbedDust3R", "convManyAR")
+ .replace(
+ "pos_embed='RoPE100'",
+ "enc_pos_embed='cuRoPE100', dec_pos_embed='cuRoPE100'",
+ )
+ )
+ return args
+
+
+def _load_model(model_path, device):
+ print("... loading model from", model_path)
+ ckpt = torch.load(model_path, map_location="cpu")
+ try:
+ net = eval(
+ ckpt["args"].model[:-1].replace("convManyAR", "convP")
+ + ", landscape_only=False)"
+ )
+ except Exception:
+ args = convert_release_dust3r_args(ckpt["args"])
+ net = eval(
+ args.model[:-1].replace("convManyAR", "convP") + ", landscape_only=False)"
+ )
+ ckpt["model"] = {
+ k.replace("_downstream_head", "downstream_head"): v
+ for k, v in ckpt["model"].items()
+ }
+ print(net.load_state_dict(ckpt["model"], strict=False))
+ return net.to(device)
+
+
+def cat_cls(cls, tokens, pos):
+ tokens = torch.cat((cls, tokens), dim=1)
+ pos = torch.cat((-pos.new_ones(len(cls), 1, 2), pos), dim=1)
+ return tokens, pos
+
+
+class Pow3RWrapper(torch.nn.Module):
+ def __init__(
+ self,
+ name,
+ ckpt_path,
+ geometric_input_config,
+ **kwargs,
+ ):
+ super().__init__()
+ self.name = name
+ self.ckpt_path = ckpt_path
+ self.geometric_input_config = geometric_input_config
+
+ # Init the model and load the checkpoint
+ print(f"Loading checkpoint from {self.ckpt_path} ...")
+ ckpt = torch.load(self.ckpt_path, map_location="cpu", weights_only=False)
+ model = ckpt["definition"]
+ print(f"Creating model = {model}")
+ self.model = eval(model)
+ print(self.model.load_state_dict(ckpt["weights"]))
+
+ def forward(self, views):
+ """
+ Forward pass wrapper for Pow3R.
+
+ Assumption:
+ - The number of input views is 2.
+
+ Args:
+ views (List[dict]): List of dictionaries containing the input views' images and instance information.
+ Length of the list should be 2.
+ Each dictionary should contain the following keys:
+ "img" (tensor): Image tensor of shape (B, C, H, W).
+ "data_norm_type" (list): ["dust3r"]
+ Optionally, each dictionary can also contain the following keys for the respective optional geometric inputs:
+ "camera_intrinsics" (tensor): Camera intrinsics. Tensor of shape (B, 3, 3).
+ "camera_pose" (tensor): Camera pose. Tensor of shape (B, 4, 4). Camera pose is opencv (RDF) cam2world transformation.
+ "depthmap" (tensor): Z Depth map. Tensor of shape (B, H, W, 1).
+
+ Returns:
+ List[dict]: A list containing the final outputs for the two views. Length of the list will be 2.
+ """
+ # Check that the number of input views is 2
+ assert len(views) == 2, "Pow3R requires 2 input views."
+
+ # Check the data norm type
+ data_norm_type = views[0]["data_norm_type"][0]
+ assert data_norm_type == "dust3r", (
+ "Pow3R expects a normalized image with the DUSt3R normalization scheme applied"
+ )
+
+ # Get the batch size per view, device and two views
+ batch_size_per_view = views[0]["img"].shape[0]
+ device = views[0]["img"].device
+ view1, view2 = views
+
+ # Decide if we need to use the geometric inputs
+ if torch.rand(1, device=device) < self.geometric_input_config["overall_prob"]:
+ # Decide if we need to use the camera intrinsics
+ if (
+ torch.rand(1, device=device)
+ < self.geometric_input_config["ray_dirs_prob"]
+ ):
+ add_intrinsics(view1, view1.get("camera_intrinsics"))
+ add_intrinsics(view2, view2.get("camera_intrinsics"))
+
+ # Decide if we need to use the depth map
+ if torch.rand(1, device=device) < self.geometric_input_config["depth_prob"]:
+ depthmap1 = view1.get("depthmap")
+ depthmap2 = view2.get("depthmap")
+ if depthmap1 is not None:
+ depthmap1 = depthmap1.squeeze(-1).to(device)
+ if depthmap2 is not None:
+ depthmap2 = depthmap2.squeeze(-1).to(device)
+ add_depth(view1, depthmap1)
+ add_depth(view2, depthmap2)
+
+ # Decide if we need to use the camera pose
+ if torch.rand(1, device=device) < self.geometric_input_config["cam_prob"]:
+ cam1 = view1.get("camera_pose")
+ cam2 = view2.get("camera_pose")
+ add_relpose(view1, cam2_to_world=cam2, cam1_to_world=cam1)
+ add_relpose(view2, cam2_to_world=cam2, cam1_to_world=cam1)
+
+ # Get the model predictions
+ preds = self.model(view1, view2)
+
+ # Convert the output to MapAnything format
+ with torch.autocast("cuda", enabled=False):
+ res = []
+ for view_idx in range(2):
+ # Get the model predictions for the current view
+ curr_view_pred = preds[view_idx]
+
+ # For the first view
+ if view_idx == 0:
+ # Get the global frame and camera frame pointmaps
+ global_pts = curr_view_pred["pts3d"]
+ cam_pts = curr_view_pred["pts3d"]
+ conf = curr_view_pred["conf"]
+
+ # Get the ray directions and depth along ray
+ depth_along_ray = torch.norm(cam_pts, dim=-1, keepdim=True)
+ ray_directions = cam_pts / depth_along_ray
+
+ # Initalize identity camera pose
+ cam_rot = torch.eye(3, device=device)
+ cam_quat = mat_to_quat(cam_rot)
+ cam_trans = torch.zeros(3, device=device)
+ cam_quat = cam_quat.unsqueeze(0).repeat(batch_size_per_view, 1)
+ cam_trans = cam_trans.unsqueeze(0).repeat(batch_size_per_view, 1)
+ # For the second view
+ elif view_idx == 1:
+ # Get the global frame and camera frame pointmaps
+ pred_global_pts = curr_view_pred["pts3d_in_other_view"]
+ cam_pts = curr_view_pred["pts3d2"]
+ conf = (curr_view_pred["conf"] * curr_view_pred["conf2"]).sqrt()
+
+ # Get the ray directions and depth along ray
+ depth_along_ray = torch.norm(cam_pts, dim=-1, keepdim=True)
+ ray_directions = cam_pts / depth_along_ray
+
+ # Compute the camera pose using the pointmaps
+ cam_rot, cam_trans, scale = roma.rigid_points_registration(
+ cam_pts.reshape(batch_size_per_view, -1, 3),
+ pred_global_pts.reshape(batch_size_per_view, -1, 3),
+ weights=conf.reshape(batch_size_per_view, -1),
+ compute_scaling=True,
+ )
+ cam_quat = mat_to_quat(cam_rot)
+
+ # Scale the predicted camera frame pointmap and compute the new global frame pointmap
+ cam_pts = scale.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) * cam_pts
+ global_pts = cam_pts.reshape(
+ batch_size_per_view, -1, 3
+ ) @ cam_rot.permute(0, 2, 1) + cam_trans.unsqueeze(1)
+ global_pts = global_pts.view(pred_global_pts.shape)
+
+ # Append the result in MapAnything format
+ res.append(
+ {
+ "pts3d": global_pts,
+ "pts3d_cam": cam_pts,
+ "ray_directions": ray_directions,
+ "depth_along_ray": depth_along_ray,
+ "cam_trans": cam_trans,
+ "cam_quats": cam_quat,
+ "conf": conf,
+ }
+ )
+
+ return res
+
+
+class Pow3RBAWrapper(torch.nn.Module):
+ def __init__(
+ self,
+ name,
+ ckpt_path,
+ geometric_input_config,
+ scene_graph="complete",
+ inference_batch_size=32,
+ global_optim_schedule="cosine",
+ global_optim_lr=0.01,
+ global_optim_niter=300,
+ **kwargs,
+ ):
+ super().__init__()
+ self.name = name
+ self.ckpt_path = ckpt_path
+ self.geometric_input_config = geometric_input_config
+ self.scene_graph = scene_graph
+ self.inference_batch_size = inference_batch_size
+ self.global_optim_schedule = global_optim_schedule
+ self.global_optim_lr = global_optim_lr
+ self.global_optim_niter = global_optim_niter
+
+ # Init the model and load the checkpoint
+ print(f"Loading checkpoint from {self.ckpt_path} ...")
+ ckpt = torch.load(self.ckpt_path, map_location="cpu", weights_only=False)
+ model = ckpt["definition"]
+ print(f"Creating model = {model}")
+ self.model = eval(model)
+ print(self.model.load_state_dict(ckpt["weights"]))
+
+ # Init the global aligner mode
+ self.global_aligner_mode = GlobalAlignerMode.PointCloudOptimizer
+
+ def infer_two_views(self, views):
+ """
+ Wrapper for Pow3R 2-View inference.
+
+ Assumption:
+ - The number of input views is 2.
+
+ Args:
+ views (List[dict]): List of dictionaries containing the input views' images and instance information.
+ Length of the list should be 2.
+ Each dictionary should contain the following keys:
+ "img" (tensor): Image tensor of shape (B, C, H, W).
+ "data_norm_type" (list): ["dust3r"]
+ Optionally, each dictionary can also contain the following keys for the respective optional geometric inputs:
+ "camera_intrinsics" (tensor): Camera intrinsics. Tensor of shape (B, 3, 3).
+ "camera_pose" (tensor): Camera pose. Tensor of shape (B, 4, 4). Camera pose is opencv (RDF) cam2world transformation.
+ "depthmap" (tensor): Z Depth map. Tensor of shape (B, H, W, 1).
+
+ Returns:
+ List[dict]: A list containing the final outputs for the two views. Length of the list will be 2.
+ """
+ # Check that the number of input views is 2
+ assert len(views) == 2, "Pow3R requires 2 input views."
+
+ # Check the data norm type
+ data_norm_type = views[0]["data_norm_type"][0]
+ assert data_norm_type == "dust3r", (
+ "Pow3R expects a normalized image with the DUSt3R normalization scheme applied"
+ )
+
+ # Get the device and two views
+ device = views[0]["img"].device
+ view1, view2 = views
+
+ # Decide if we need to use the geometric inputs
+ if torch.rand(1, device=device) < self.geometric_input_config["overall_prob"]:
+ # Decide if we need to use the camera intrinsics
+ if (
+ torch.rand(1, device=device)
+ < self.geometric_input_config["ray_dirs_prob"]
+ ):
+ add_intrinsics(view1, view1.get("camera_intrinsics"))
+ add_intrinsics(view2, view2.get("camera_intrinsics"))
+
+ # Decide if we need to use the depth map
+ if torch.rand(1, device=device) < self.geometric_input_config["depth_prob"]:
+ depthmap1 = view1.get("depthmap")
+ depthmap2 = view2.get("depthmap")
+ if depthmap1 is not None:
+ depthmap1 = depthmap1.squeeze(-1).to(device)
+ if depthmap2 is not None:
+ depthmap2 = depthmap2.squeeze(-1).to(device)
+ add_depth(view1, depthmap1)
+ add_depth(view2, depthmap2)
+
+ # Decide if we need to use the camera pose
+ if torch.rand(1, device=device) < self.geometric_input_config["cam_prob"]:
+ cam1 = view1.get("camera_pose")
+ cam2 = view2.get("camera_pose")
+ add_relpose(view1, cam2_to_world=cam2, cam1_to_world=cam1)
+ add_relpose(view2, cam2_to_world=cam2, cam1_to_world=cam1)
+
+ # Get the model predictions
+ preds = self.model(view1, view2)
+
+ return preds
+
+ def loss_of_one_batch(self, batch, device):
+ """
+ Compute prediction for two views.
+ """
+ view1, view2 = batch
+ ignore_keys = set(
+ [
+ "dataset",
+ "label",
+ "instance",
+ "idx",
+ "true_shape",
+ "rng",
+ "name",
+ "data_norm_type",
+ ]
+ )
+ for view in batch:
+ for name in view.keys(): # pseudo_focal
+ if name in ignore_keys:
+ continue
+ view[name] = view[name].to(device, non_blocking=True)
+
+ pred1, pred2 = self.infer_two_views([view1, view2])
+
+ result = dict(view1=view1, view2=view2, pred1=pred1, pred2=pred2)
+
+ return result
+
+ @torch.no_grad()
+ def inference(self, pairs, device, verbose=False):
+ """
+ Wrapper for multi-pair inference using Pow3R.
+ """
+ if verbose:
+ print(f">> Inference with model on {len(pairs)} image pairs")
+ result = []
+
+ multiple_shapes = not (check_if_same_size(pairs))
+ if multiple_shapes:
+ self.inference_batch_size = 1
+
+ for i in tqdm.trange(
+ 0, len(pairs), self.inference_batch_size, disable=not verbose
+ ):
+ res = self.loss_of_one_batch(
+ collate_with_cat(pairs[i : i + self.inference_batch_size]), device
+ )
+ result.append(to_cpu(res))
+
+ result = collate_with_cat(result, lists=multiple_shapes)
+
+ return result
+
+ def forward(self, views):
+ """
+ Forward pass wrapper for Pow3R using the global aligner.
+
+ Assumption:
+ - The batch size of input views is 1.
+
+ Args:
+ views (List[dict]): List of dictionaries containing the input views' images and instance information.
+ Each dictionary should contain the following keys, where B is the batch size and is 1:
+ "img" (tensor): Image tensor of shape (B, C, H, W).
+ "data_norm_type" (list): ["dust3r"]
+
+ Returns:
+ List[dict]: A list containing the final outputs for the input views.
+ """
+ # Check the batch size of input views
+ batch_size_per_view, _, height, width = views[0]["img"].shape
+ device = views[0]["img"].device
+ num_views = len(views)
+ assert batch_size_per_view == 1, (
+ f"Batch size of input views should be 1, but got {batch_size_per_view}."
+ )
+
+ # Check the data norm type
+ data_norm_type = views[0]["data_norm_type"][0]
+ assert data_norm_type == "dust3r", (
+ "Pow3R-BA expects a normalized image with the DUSt3R normalization scheme applied"
+ )
+
+ # Convert the input views to the expected input format
+ images = []
+ for view in views:
+ images.append(
+ dict(
+ img=view["img"],
+ camera_intrinsics=view["camera_intrinsics"],
+ depthmap=view["depthmap"],
+ camera_pose=view["camera_pose"],
+ data_norm_type=view["data_norm_type"],
+ true_shape=view["true_shape"],
+ idx=len(images),
+ instance=str(len(images)),
+ )
+ )
+
+ # Make image pairs and run inference pair-wise
+ pairs = make_pairs(
+ images, scene_graph=self.scene_graph, prefilter=None, symmetrize=True
+ )
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore", category=FutureWarning)
+ output = self.inference(
+ pairs,
+ device,
+ verbose=False,
+ )
+
+ # Global optimization
+ with torch.enable_grad():
+ scene = global_aligner(
+ output, device=device, mode=self.global_aligner_mode, verbose=False
+ )
+ _ = scene.compute_global_alignment(
+ init="mst",
+ niter=self.global_optim_niter,
+ schedule=self.global_optim_schedule,
+ lr=self.global_optim_lr,
+ )
+
+ # Make sure scene is not None
+ if scene is None:
+ raise RuntimeError("Global optimization failed.")
+
+ # Get the predictions
+ intrinsics = scene.get_intrinsics()
+ c2w_poses = scene.get_im_poses()
+ depths = scene.get_depthmaps()
+
+ # Convert the output to the MapAnything format
+ with torch.autocast("cuda", enabled=False):
+ res = []
+ for view_idx in range(num_views):
+ # Get the current view predictions
+ curr_view_intrinsic = intrinsics[view_idx].unsqueeze(0)
+ curr_view_pose = c2w_poses[view_idx].unsqueeze(0)
+ curr_view_depth_z = depths[view_idx].unsqueeze(0)
+
+ # Convert the pose to quaternions and translation
+ curr_view_cam_translations = curr_view_pose[..., :3, 3]
+ curr_view_cam_quats = mat_to_quat(curr_view_pose[..., :3, :3])
+
+ # Get the camera frame pointmaps
+ curr_view_pts3d_cam, _ = depthmap_to_camera_frame(
+ curr_view_depth_z, curr_view_intrinsic
+ )
+
+ # Convert the z depth to depth along ray
+ curr_view_depth_along_ray = convert_z_depth_to_depth_along_ray(
+ curr_view_depth_z, curr_view_intrinsic
+ )
+ curr_view_depth_along_ray = curr_view_depth_along_ray.unsqueeze(-1)
+
+ # Get the ray directions on the unit sphere in the camera frame
+ _, curr_view_ray_dirs = get_rays_in_camera_frame(
+ curr_view_intrinsic, height, width, normalize_to_unit_sphere=True
+ )
+
+ # Get the pointmaps
+ curr_view_pts3d = (
+ convert_ray_dirs_depth_along_ray_pose_trans_quats_to_pointmap(
+ curr_view_ray_dirs,
+ curr_view_depth_along_ray,
+ curr_view_cam_translations,
+ curr_view_cam_quats,
+ )
+ )
+
+ # Append the outputs to the result list
+ res.append(
+ {
+ "pts3d": curr_view_pts3d,
+ "pts3d_cam": curr_view_pts3d_cam,
+ "ray_directions": curr_view_ray_dirs,
+ "depth_along_ray": curr_view_depth_along_ray,
+ "cam_trans": curr_view_cam_translations,
+ "cam_quats": curr_view_cam_quats,
+ }
+ )
+
+ return res
diff --git a/mapanything/models/external/vggt/__init__.py b/mapanything/models/external/vggt/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..4293799cc24c3e954a92af9d824053c7ceadb63d
--- /dev/null
+++ b/mapanything/models/external/vggt/__init__.py
@@ -0,0 +1,186 @@
+"""
+Inference wrapper for VGGT
+"""
+
+import torch
+
+from mapanything.models.external.vggt.models.vggt import VGGT
+from mapanything.models.external.vggt.utils.geometry import closed_form_inverse_se3
+from mapanything.models.external.vggt.utils.pose_enc import pose_encoding_to_extri_intri
+from mapanything.models.external.vggt.utils.rotation import mat_to_quat
+from mapanything.utils.geometry import (
+ convert_ray_dirs_depth_along_ray_pose_trans_quats_to_pointmap,
+ convert_z_depth_to_depth_along_ray,
+ depthmap_to_camera_frame,
+ get_rays_in_camera_frame,
+)
+
+
+class VGGTWrapper(torch.nn.Module):
+ def __init__(
+ self,
+ name,
+ torch_hub_force_reload,
+ load_pretrained_weights=True,
+ depth=24,
+ num_heads=16,
+ intermediate_layer_idx=[4, 11, 17, 23],
+ load_custom_ckpt=False,
+ custom_ckpt_path=None,
+ ):
+ super().__init__()
+ self.name = name
+ self.torch_hub_force_reload = torch_hub_force_reload
+ self.load_custom_ckpt = load_custom_ckpt
+ self.custom_ckpt_path = custom_ckpt_path
+
+ if load_pretrained_weights:
+ # Load pre-trained weights
+ if not torch_hub_force_reload:
+ # Initialize the 1B VGGT model from huggingface hub cache
+ print("Loading facebook/VGGT-1B from huggingface cache ...")
+ self.model = VGGT.from_pretrained(
+ "facebook/VGGT-1B",
+ )
+ else:
+ # Initialize the 1B VGGT model
+ print("Re-downloading facebook/VGGT-1B ...")
+ self.model = VGGT.from_pretrained(
+ "facebook/VGGT-1B", force_download=True
+ )
+ else:
+ # Load the VGGT class
+ self.model = VGGT(
+ depth=depth,
+ num_heads=num_heads,
+ intermediate_layer_idx=intermediate_layer_idx,
+ )
+
+ # Get the dtype for VGGT inference
+ # bfloat16 is supported on Ampere GPUs (Compute Capability 8.0+)
+ self.dtype = (
+ torch.bfloat16
+ if torch.cuda.get_device_capability()[0] >= 8
+ else torch.float16
+ )
+
+ # Load custom checkpoint if requested
+ if self.load_custom_ckpt:
+ print(f"Loading checkpoint from {self.custom_ckpt_path} ...")
+ assert self.custom_ckpt_path is not None, (
+ "custom_ckpt_path must be provided if load_custom_ckpt is set to True"
+ )
+ custom_ckpt = torch.load(self.custom_ckpt_path, weights_only=False)
+ print(self.model.load_state_dict(custom_ckpt, strict=True))
+ del custom_ckpt # in case it occupies memory
+
+ def forward(self, views):
+ """
+ Forward pass wrapper for VGGT
+
+ Assumption:
+ - All the input views have the same image shape.
+
+ Args:
+ views (List[dict]): List of dictionaries containing the input views' images and instance information.
+ Each dictionary should contain the following keys:
+ "img" (tensor): Image tensor of shape (B, C, H, W).
+ "data_norm_type" (list): ["identity"]
+
+ Returns:
+ List[dict]: A list containing the final outputs for all N views.
+ """
+ # Get input shape of the images, number of views, and batch size per view
+ batch_size_per_view, _, height, width = views[0]["img"].shape
+ num_views = len(views)
+
+ # Check the data norm type
+ # VGGT expects a normalized image but without the DINOv2 mean and std applied ("identity")
+ data_norm_type = views[0]["data_norm_type"][0]
+ assert data_norm_type == "identity", (
+ "VGGT expects a normalized image but without the DINOv2 mean and std applied"
+ )
+
+ # Concatenate the images to create a single (B, V, C, H, W) tensor
+ img_list = [view["img"] for view in views]
+ images = torch.stack(img_list, dim=1)
+
+ # Run the VGGT aggregator
+ with torch.autocast("cuda", dtype=self.dtype):
+ aggregated_tokens_list, ps_idx = self.model.aggregator(images)
+
+ # Run the Camera + Pose Branch of VGGT
+ with torch.autocast("cuda", enabled=False):
+ # Predict Cameras
+ pose_enc = self.model.camera_head(aggregated_tokens_list)[-1]
+ # Extrinsic and intrinsic matrices, following OpenCV convention (camera from world)
+ # Extrinsics Shape: (B, V, 3, 4)
+ # Intrinsics Shape: (B, V, 3, 3)
+ extrinsic, intrinsic = pose_encoding_to_extri_intri(
+ pose_enc, images.shape[-2:]
+ )
+
+ # Predict Depth Maps
+ # Depth Shape: (B, V, H, W, 1)
+ # Depth Confidence Shape: (B, V, H, W)
+ depth_map, depth_conf = self.model.depth_head(
+ aggregated_tokens_list, images, ps_idx
+ )
+
+ # Convert the output to MapAnything format
+ res = []
+ for view_idx in range(num_views):
+ # Get the extrinsics, intrinsics, depth map for the current view
+ curr_view_extrinsic = extrinsic[:, view_idx, ...]
+ curr_view_extrinsic = closed_form_inverse_se3(
+ curr_view_extrinsic
+ ) # Convert to cam2world
+ curr_view_intrinsic = intrinsic[:, view_idx, ...]
+ curr_view_depth_z = depth_map[:, view_idx, ...]
+ curr_view_depth_z = curr_view_depth_z.squeeze(-1)
+ curr_view_confidence = depth_conf[:, view_idx, ...]
+
+ # Get the camera frame pointmaps
+ curr_view_pts3d_cam, _ = depthmap_to_camera_frame(
+ curr_view_depth_z, curr_view_intrinsic
+ )
+
+ # Convert the extrinsics to quaternions and translations
+ curr_view_cam_translations = curr_view_extrinsic[..., :3, 3]
+ curr_view_cam_quats = mat_to_quat(curr_view_extrinsic[..., :3, :3])
+
+ # Convert the z depth to depth along ray
+ curr_view_depth_along_ray = convert_z_depth_to_depth_along_ray(
+ curr_view_depth_z, curr_view_intrinsic
+ )
+ curr_view_depth_along_ray = curr_view_depth_along_ray.unsqueeze(-1)
+
+ # Get the ray directions on the unit sphere in the camera frame
+ _, curr_view_ray_dirs = get_rays_in_camera_frame(
+ curr_view_intrinsic, height, width, normalize_to_unit_sphere=True
+ )
+
+ # Get the pointmaps
+ curr_view_pts3d = (
+ convert_ray_dirs_depth_along_ray_pose_trans_quats_to_pointmap(
+ curr_view_ray_dirs,
+ curr_view_depth_along_ray,
+ curr_view_cam_translations,
+ curr_view_cam_quats,
+ )
+ )
+
+ # Append the outputs to the result list
+ res.append(
+ {
+ "pts3d": curr_view_pts3d,
+ "pts3d_cam": curr_view_pts3d_cam,
+ "ray_directions": curr_view_ray_dirs,
+ "depth_along_ray": curr_view_depth_along_ray,
+ "cam_trans": curr_view_cam_translations,
+ "cam_quats": curr_view_cam_quats,
+ "conf": curr_view_confidence,
+ }
+ )
+
+ return res
diff --git a/mapanything/models/external/vggt/heads/__init__.py b/mapanything/models/external/vggt/heads/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/mapanything/models/external/vggt/heads/camera_head.py b/mapanything/models/external/vggt/heads/camera_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..b8bf683e8925d4002c9c4d547753724e23034eae
--- /dev/null
+++ b/mapanything/models/external/vggt/heads/camera_head.py
@@ -0,0 +1,167 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+import torch
+import torch.nn as nn
+
+from mapanything.models.external.vggt.heads.head_act import activate_pose
+from mapanything.models.external.vggt.layers import Mlp
+from mapanything.models.external.vggt.layers.block import Block
+
+
+class CameraHead(nn.Module):
+ """
+ CameraHead predicts camera parameters from token representations using iterative refinement.
+
+ It applies a series of transformer blocks (the "trunk") to dedicated camera tokens.
+ """
+
+ def __init__(
+ self,
+ dim_in: int = 2048,
+ trunk_depth: int = 4,
+ pose_encoding_type: str = "absT_quaR_FoV",
+ num_heads: int = 16,
+ mlp_ratio: int = 4,
+ init_values: float = 0.01,
+ trans_act: str = "linear",
+ quat_act: str = "linear",
+ fl_act: str = "relu", # Field of view activations: ensures FOV values are positive.
+ ):
+ super().__init__()
+
+ if pose_encoding_type == "absT_quaR_FoV":
+ self.target_dim = 9
+ else:
+ raise ValueError(f"Unsupported camera encoding type: {pose_encoding_type}")
+
+ self.trans_act = trans_act
+ self.quat_act = quat_act
+ self.fl_act = fl_act
+ self.trunk_depth = trunk_depth
+
+ # Build the trunk using a sequence of transformer blocks.
+ self.trunk = nn.Sequential(
+ *[
+ Block(
+ dim=dim_in,
+ num_heads=num_heads,
+ mlp_ratio=mlp_ratio,
+ init_values=init_values,
+ )
+ for _ in range(trunk_depth)
+ ]
+ )
+
+ # Normalizations for camera token and trunk output.
+ self.token_norm = nn.LayerNorm(dim_in)
+ self.trunk_norm = nn.LayerNorm(dim_in)
+
+ # Learnable empty camera pose token.
+ self.empty_pose_tokens = nn.Parameter(torch.zeros(1, 1, self.target_dim))
+ self.embed_pose = nn.Linear(self.target_dim, dim_in)
+
+ # Module for producing modulation parameters: shift, scale, and a gate.
+ self.poseLN_modulation = nn.Sequential(
+ nn.SiLU(), nn.Linear(dim_in, 3 * dim_in, bias=True)
+ )
+
+ # Adaptive layer normalization without affine parameters.
+ self.adaln_norm = nn.LayerNorm(dim_in, elementwise_affine=False, eps=1e-6)
+ self.pose_branch = Mlp(
+ in_features=dim_in,
+ hidden_features=dim_in // 2,
+ out_features=self.target_dim,
+ drop=0,
+ )
+
+ def forward(self, aggregated_tokens_list: list, num_iterations: int = 4) -> list:
+ """
+ Forward pass to predict camera parameters.
+
+ Args:
+ aggregated_tokens_list (list): List of token tensors from the network;
+ the last tensor is used for prediction.
+ num_iterations (int, optional): Number of iterative refinement steps. Defaults to 4.
+
+ Returns:
+ list: A list of predicted camera encodings (post-activation) from each iteration.
+ """
+ # Use tokens from the last block for camera prediction.
+ tokens = aggregated_tokens_list[-1]
+
+ # Extract the camera tokens
+ pose_tokens = tokens[:, :, 0]
+ pose_tokens = self.token_norm(pose_tokens)
+
+ pred_pose_enc_list = self.trunk_fn(pose_tokens, num_iterations)
+ return pred_pose_enc_list
+
+ def trunk_fn(self, pose_tokens: torch.Tensor, num_iterations: int) -> list:
+ """
+ Iteratively refine camera pose predictions.
+
+ Args:
+ pose_tokens (torch.Tensor): Normalized camera tokens with shape [B, 1, C].
+ num_iterations (int): Number of refinement iterations.
+
+ Returns:
+ list: List of activated camera encodings from each iteration.
+ """
+ B, S, C = pose_tokens.shape # S is expected to be 1.
+ pred_pose_enc = None
+ pred_pose_enc_list = []
+
+ for _ in range(num_iterations):
+ # Use a learned empty pose for the first iteration.
+ if pred_pose_enc is None:
+ module_input = self.embed_pose(self.empty_pose_tokens.expand(B, S, -1))
+ else:
+ # Detach the previous prediction to avoid backprop through time.
+ pred_pose_enc = pred_pose_enc.detach()
+ module_input = self.embed_pose(pred_pose_enc)
+
+ # Generate modulation parameters and split them into shift, scale, and gate components.
+ shift_msa, scale_msa, gate_msa = self.poseLN_modulation(module_input).chunk(
+ 3, dim=-1
+ )
+
+ # Adaptive layer normalization and modulation.
+ pose_tokens_modulated = gate_msa * modulate(
+ self.adaln_norm(pose_tokens), shift_msa, scale_msa
+ )
+ pose_tokens_modulated = pose_tokens_modulated + pose_tokens
+
+ pose_tokens_modulated = self.trunk(pose_tokens_modulated)
+ # Compute the delta update for the pose encoding.
+ pred_pose_enc_delta = self.pose_branch(
+ self.trunk_norm(pose_tokens_modulated)
+ )
+
+ if pred_pose_enc is None:
+ pred_pose_enc = pred_pose_enc_delta
+ else:
+ pred_pose_enc = pred_pose_enc + pred_pose_enc_delta
+
+ # Apply final activation functions for translation, quaternion, and field-of-view.
+ activated_pose = activate_pose(
+ pred_pose_enc,
+ trans_act=self.trans_act,
+ quat_act=self.quat_act,
+ fl_act=self.fl_act,
+ )
+ pred_pose_enc_list.append(activated_pose)
+
+ return pred_pose_enc_list
+
+
+def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
+ """
+ Modulate the input tensor using scaling and shifting parameters.
+ """
+ # modified from https://github.com/facebookresearch/DiT/blob/796c29e532f47bba17c5b9c5eb39b9354b8b7c64/models.py#L19
+ return x * (1 + scale) + shift
diff --git a/mapanything/models/external/vggt/heads/dpt_head.py b/mapanything/models/external/vggt/heads/dpt_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..10e49054db3eebfb79386dded42f83e929b47004
--- /dev/null
+++ b/mapanything/models/external/vggt/heads/dpt_head.py
@@ -0,0 +1,600 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+# Inspired by https://github.com/DepthAnything/Depth-Anything-V2
+
+
+from typing import List, Tuple, Union
+
+import torch
+import torch.nn as nn
+
+from .head_act import activate_head
+from .utils import create_uv_grid, position_grid_to_embed
+
+
+class DPTHead(nn.Module):
+ """
+ DPT Head for dense prediction tasks.
+
+ This implementation follows the architecture described in "Vision Transformers for Dense Prediction"
+ (https://arxiv.org/abs/2103.13413). The DPT head processes features from a vision transformer
+ backbone and produces dense predictions by fusing multi-scale features.
+
+ Args:
+ dim_in (int): Input dimension (channels).
+ patch_size (int, optional): Patch size. Default is 14.
+ output_dim (int, optional): Number of output channels. Default is 4.
+ activation (str, optional): Activation type. Default is "inv_log".
+ conf_activation (str, optional): Confidence activation type. Default is "expp1".
+ features (int, optional): Feature channels for intermediate representations. Default is 256.
+ out_channels (List[int], optional): Output channels for each intermediate layer.
+ intermediate_layer_idx (List[int], optional): Indices of layers from aggregated tokens used for DPT.
+ pos_embed (bool, optional): Whether to use positional embedding. Default is True.
+ feature_only (bool, optional): If True, return features only without the last several layers and activation head. Default is False.
+ down_ratio (int, optional): Downscaling factor for the output resolution. Default is 1.
+ """
+
+ def __init__(
+ self,
+ dim_in: int,
+ patch_size: int = 14,
+ output_dim: int = 4,
+ activation: str = "inv_log",
+ conf_activation: str = "expp1",
+ features: int = 256,
+ out_channels: List[int] = [256, 512, 1024, 1024],
+ intermediate_layer_idx: List[int] = [4, 11, 17, 23],
+ pos_embed: bool = True,
+ feature_only: bool = False,
+ down_ratio: int = 1,
+ ) -> None:
+ super(DPTHead, self).__init__()
+ self.patch_size = patch_size
+ self.activation = activation
+ self.conf_activation = conf_activation
+ self.pos_embed = pos_embed
+ self.feature_only = feature_only
+ self.down_ratio = down_ratio
+ self.intermediate_layer_idx = intermediate_layer_idx
+
+ self.norm = nn.LayerNorm(dim_in)
+
+ # Projection layers for each output channel from tokens.
+ self.projects = nn.ModuleList(
+ [
+ nn.Conv2d(
+ in_channels=dim_in,
+ out_channels=oc,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ )
+ for oc in out_channels
+ ]
+ )
+
+ # Resize layers for upsampling feature maps.
+ self.resize_layers = nn.ModuleList(
+ [
+ nn.ConvTranspose2d(
+ in_channels=out_channels[0],
+ out_channels=out_channels[0],
+ kernel_size=4,
+ stride=4,
+ padding=0,
+ ),
+ nn.ConvTranspose2d(
+ in_channels=out_channels[1],
+ out_channels=out_channels[1],
+ kernel_size=2,
+ stride=2,
+ padding=0,
+ ),
+ nn.Identity(),
+ nn.Conv2d(
+ in_channels=out_channels[3],
+ out_channels=out_channels[3],
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ ),
+ ]
+ )
+
+ self.scratch = _make_scratch(
+ out_channels,
+ features,
+ expand=False,
+ )
+
+ # Attach additional modules to scratch.
+ self.scratch.stem_transpose = None
+ self.scratch.refinenet1 = _make_fusion_block(features)
+ self.scratch.refinenet2 = _make_fusion_block(features)
+ self.scratch.refinenet3 = _make_fusion_block(features)
+ self.scratch.refinenet4 = _make_fusion_block(features, has_residual=False)
+
+ head_features_1 = features
+ head_features_2 = 32
+
+ if feature_only:
+ self.scratch.output_conv1 = nn.Conv2d(
+ head_features_1, head_features_1, kernel_size=3, stride=1, padding=1
+ )
+ else:
+ self.scratch.output_conv1 = nn.Conv2d(
+ head_features_1,
+ head_features_1 // 2,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ )
+ conv2_in_channels = head_features_1 // 2
+
+ self.scratch.output_conv2 = nn.Sequential(
+ nn.Conv2d(
+ conv2_in_channels,
+ head_features_2,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ ),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(
+ head_features_2, output_dim, kernel_size=1, stride=1, padding=0
+ ),
+ )
+
+ def forward(
+ self,
+ aggregated_tokens_list: List[torch.Tensor],
+ images: torch.Tensor,
+ patch_start_idx: int,
+ frames_chunk_size: int = 8,
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+ """
+ Forward pass through the DPT head, supports processing by chunking frames.
+ Args:
+ aggregated_tokens_list (List[Tensor]): List of token tensors from different transformer layers.
+ images (Tensor): Input images with shape [B, S, 3, H, W], in range [0, 1].
+ patch_start_idx (int): Starting index for patch tokens in the token sequence.
+ Used to separate patch tokens from other tokens (e.g., camera or register tokens).
+ frames_chunk_size (int, optional): Number of frames to process in each chunk.
+ If None or larger than S, all frames are processed at once. Default: 8.
+
+ Returns:
+ Tensor or Tuple[Tensor, Tensor]:
+ - If feature_only=True: Feature maps with shape [B, S, C, H, W]
+ - Otherwise: Tuple of (predictions, confidence) both with shape [B, S, 1, H, W]
+ """
+ B, S, _, H, W = images.shape
+
+ # If frames_chunk_size is not specified or greater than S, process all frames at once
+ if frames_chunk_size is None or frames_chunk_size >= S:
+ return self._forward_impl(aggregated_tokens_list, images, patch_start_idx)
+
+ # Otherwise, process frames in chunks to manage memory usage
+ assert frames_chunk_size > 0
+
+ # Process frames in batches
+ all_preds = []
+ all_conf = []
+
+ for frames_start_idx in range(0, S, frames_chunk_size):
+ frames_end_idx = min(frames_start_idx + frames_chunk_size, S)
+
+ # Process batch of frames
+ if self.feature_only:
+ chunk_output = self._forward_impl(
+ aggregated_tokens_list,
+ images,
+ patch_start_idx,
+ frames_start_idx,
+ frames_end_idx,
+ )
+ all_preds.append(chunk_output)
+ else:
+ chunk_preds, chunk_conf = self._forward_impl(
+ aggregated_tokens_list,
+ images,
+ patch_start_idx,
+ frames_start_idx,
+ frames_end_idx,
+ )
+ all_preds.append(chunk_preds)
+ all_conf.append(chunk_conf)
+
+ # Concatenate results along the sequence dimension
+ if self.feature_only:
+ return torch.cat(all_preds, dim=1)
+ else:
+ return torch.cat(all_preds, dim=1), torch.cat(all_conf, dim=1)
+
+ def _forward_impl(
+ self,
+ aggregated_tokens_list: List[torch.Tensor],
+ images: torch.Tensor,
+ patch_start_idx: int,
+ frames_start_idx: int = None,
+ frames_end_idx: int = None,
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+ """
+ Implementation of the forward pass through the DPT head.
+
+ This method processes a specific chunk of frames from the sequence.
+
+ Args:
+ aggregated_tokens_list (List[Tensor]): List of token tensors from different transformer layers.
+ images (Tensor): Input images with shape [B, S, 3, H, W].
+ patch_start_idx (int): Starting index for patch tokens.
+ frames_start_idx (int, optional): Starting index for frames to process.
+ frames_end_idx (int, optional): Ending index for frames to process.
+
+ Returns:
+ Tensor or Tuple[Tensor, Tensor]: Feature maps or (predictions, confidence).
+ """
+ if frames_start_idx is not None and frames_end_idx is not None:
+ images = images[:, frames_start_idx:frames_end_idx].contiguous()
+
+ B, S, _, H, W = images.shape
+
+ patch_h, patch_w = H // self.patch_size, W // self.patch_size
+
+ out = []
+ dpt_idx = 0
+
+ for layer_idx in self.intermediate_layer_idx:
+ x = aggregated_tokens_list[layer_idx][:, :, patch_start_idx:]
+
+ # Select frames if processing a chunk
+ if frames_start_idx is not None and frames_end_idx is not None:
+ x = x[:, frames_start_idx:frames_end_idx]
+
+ x = x.reshape(B * S, -1, x.shape[-1])
+
+ x = self.norm(x)
+
+ x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w))
+
+ x = self.projects[dpt_idx](x)
+ if self.pos_embed:
+ x = self._apply_pos_embed(x, W, H)
+ x = self.resize_layers[dpt_idx](x)
+
+ out.append(x)
+ dpt_idx += 1
+
+ # Fuse features from multiple layers.
+ out = self.scratch_forward(out)
+ # Interpolate fused output to match target image resolution.
+ out = custom_interpolate(
+ out,
+ (
+ int(patch_h * self.patch_size / self.down_ratio),
+ int(patch_w * self.patch_size / self.down_ratio),
+ ),
+ mode="bilinear",
+ align_corners=True,
+ )
+
+ if self.pos_embed:
+ out = self._apply_pos_embed(out, W, H)
+
+ if self.feature_only:
+ return out.view(B, S, *out.shape[1:])
+
+ out = self.scratch.output_conv2(out)
+ preds, conf = activate_head(
+ out, activation=self.activation, conf_activation=self.conf_activation
+ )
+
+ preds = preds.view(B, S, *preds.shape[1:])
+ conf = conf.view(B, S, *conf.shape[1:])
+ return preds, conf
+
+ def _apply_pos_embed(
+ self, x: torch.Tensor, W: int, H: int, ratio: float = 0.1
+ ) -> torch.Tensor:
+ """
+ Apply positional embedding to tensor x.
+ """
+ patch_w = x.shape[-1]
+ patch_h = x.shape[-2]
+ pos_embed = create_uv_grid(
+ patch_w, patch_h, aspect_ratio=W / H, dtype=x.dtype, device=x.device
+ )
+ pos_embed = position_grid_to_embed(pos_embed, x.shape[1])
+ pos_embed = pos_embed * ratio
+ pos_embed = pos_embed.permute(2, 0, 1)[None].expand(x.shape[0], -1, -1, -1)
+ return x + pos_embed
+
+ def scratch_forward(self, features: List[torch.Tensor]) -> torch.Tensor:
+ """
+ Forward pass through the fusion blocks.
+
+ Args:
+ features (List[Tensor]): List of feature maps from different layers.
+
+ Returns:
+ Tensor: Fused feature map.
+ """
+ layer_1, layer_2, layer_3, layer_4 = features
+
+ layer_1_rn = self.scratch.layer1_rn(layer_1)
+ layer_2_rn = self.scratch.layer2_rn(layer_2)
+ layer_3_rn = self.scratch.layer3_rn(layer_3)
+ layer_4_rn = self.scratch.layer4_rn(layer_4)
+
+ out = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:])
+ del layer_4_rn, layer_4
+
+ out = self.scratch.refinenet3(out, layer_3_rn, size=layer_2_rn.shape[2:])
+ del layer_3_rn, layer_3
+
+ out = self.scratch.refinenet2(out, layer_2_rn, size=layer_1_rn.shape[2:])
+ del layer_2_rn, layer_2
+
+ out = self.scratch.refinenet1(out, layer_1_rn)
+ del layer_1_rn, layer_1
+
+ out = self.scratch.output_conv1(out)
+ return out
+
+
+################################################################################
+# Modules
+################################################################################
+
+
+def _make_fusion_block(
+ features: int, size: int = None, has_residual: bool = True, groups: int = 1
+) -> nn.Module:
+ return FeatureFusionBlock(
+ features,
+ nn.ReLU(inplace=True),
+ deconv=False,
+ bn=False,
+ expand=False,
+ align_corners=True,
+ size=size,
+ has_residual=has_residual,
+ groups=groups,
+ )
+
+
+def _make_scratch(
+ in_shape: List[int], out_shape: int, groups: int = 1, expand: bool = False
+) -> nn.Module:
+ scratch = nn.Module()
+ out_shape1 = out_shape
+ out_shape2 = out_shape
+ out_shape3 = out_shape
+ if len(in_shape) >= 4:
+ out_shape4 = out_shape
+
+ if expand:
+ out_shape1 = out_shape
+ out_shape2 = out_shape * 2
+ out_shape3 = out_shape * 4
+ if len(in_shape) >= 4:
+ out_shape4 = out_shape * 8
+
+ scratch.layer1_rn = nn.Conv2d(
+ in_shape[0],
+ out_shape1,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ bias=False,
+ groups=groups,
+ )
+ scratch.layer2_rn = nn.Conv2d(
+ in_shape[1],
+ out_shape2,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ bias=False,
+ groups=groups,
+ )
+ scratch.layer3_rn = nn.Conv2d(
+ in_shape[2],
+ out_shape3,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ bias=False,
+ groups=groups,
+ )
+ if len(in_shape) >= 4:
+ scratch.layer4_rn = nn.Conv2d(
+ in_shape[3],
+ out_shape4,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ bias=False,
+ groups=groups,
+ )
+ return scratch
+
+
+class ResidualConvUnit(nn.Module):
+ """Residual convolution module."""
+
+ def __init__(self, features, activation, bn, groups=1):
+ """Init.
+
+ Args:
+ features (int): number of features
+ """
+ super().__init__()
+
+ self.bn = bn
+ self.groups = groups
+ self.conv1 = nn.Conv2d(
+ features,
+ features,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ bias=True,
+ groups=self.groups,
+ )
+ self.conv2 = nn.Conv2d(
+ features,
+ features,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ bias=True,
+ groups=self.groups,
+ )
+
+ self.norm1 = None
+ self.norm2 = None
+
+ self.activation = activation
+ self.skip_add = nn.quantized.FloatFunctional()
+
+ def forward(self, x):
+ """Forward pass.
+
+ Args:
+ x (tensor): input
+
+ Returns:
+ tensor: output
+ """
+
+ out = self.activation(x)
+ out = self.conv1(out)
+ if self.norm1 is not None:
+ out = self.norm1(out)
+
+ out = self.activation(out)
+ out = self.conv2(out)
+ if self.norm2 is not None:
+ out = self.norm2(out)
+
+ return self.skip_add.add(out, x)
+
+
+class FeatureFusionBlock(nn.Module):
+ """Feature fusion block."""
+
+ def __init__(
+ self,
+ features,
+ activation,
+ deconv=False,
+ bn=False,
+ expand=False,
+ align_corners=True,
+ size=None,
+ has_residual=True,
+ groups=1,
+ ):
+ """Init.
+
+ Args:
+ features (int): number of features
+ """
+ super(FeatureFusionBlock, self).__init__()
+
+ self.deconv = deconv
+ self.align_corners = align_corners
+ self.groups = groups
+ self.expand = expand
+ out_features = features
+ if self.expand:
+ out_features = features // 2
+
+ self.out_conv = nn.Conv2d(
+ features,
+ out_features,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ bias=True,
+ groups=self.groups,
+ )
+
+ if has_residual:
+ self.resConfUnit1 = ResidualConvUnit(
+ features, activation, bn, groups=self.groups
+ )
+
+ self.has_residual = has_residual
+ self.resConfUnit2 = ResidualConvUnit(
+ features, activation, bn, groups=self.groups
+ )
+
+ self.skip_add = nn.quantized.FloatFunctional()
+ self.size = size
+
+ def forward(self, *xs, size=None):
+ """Forward pass.
+
+ Returns:
+ tensor: output
+ """
+ output = xs[0]
+
+ if self.has_residual:
+ res = self.resConfUnit1(xs[1])
+ output = self.skip_add.add(output, res)
+
+ output = self.resConfUnit2(output)
+
+ if (size is None) and (self.size is None):
+ modifier = {"scale_factor": 2}
+ elif size is None:
+ modifier = {"size": self.size}
+ else:
+ modifier = {"size": size}
+
+ output = custom_interpolate(
+ output, **modifier, mode="bilinear", align_corners=self.align_corners
+ )
+ output = self.out_conv(output)
+
+ return output
+
+
+def custom_interpolate(
+ x: torch.Tensor,
+ size: Tuple[int, int] = None,
+ scale_factor: float = None,
+ mode: str = "bilinear",
+ align_corners: bool = True,
+) -> torch.Tensor:
+ """
+ Custom interpolate to avoid INT_MAX issues in nn.functional.interpolate.
+ """
+ if size is None:
+ size = (int(x.shape[-2] * scale_factor), int(x.shape[-1] * scale_factor))
+
+ INT_MAX = 1610612736
+
+ input_elements = size[0] * size[1] * x.shape[0] * x.shape[1]
+
+ if input_elements > INT_MAX:
+ chunks = torch.chunk(x, chunks=(input_elements // INT_MAX) + 1, dim=0)
+ interpolated_chunks = [
+ nn.functional.interpolate(
+ chunk, size=size, mode=mode, align_corners=align_corners
+ )
+ for chunk in chunks
+ ]
+ x = torch.cat(interpolated_chunks, dim=0)
+ return x.contiguous()
+ else:
+ return nn.functional.interpolate(
+ x, size=size, mode=mode, align_corners=align_corners
+ )
diff --git a/mapanything/models/external/vggt/heads/head_act.py b/mapanything/models/external/vggt/heads/head_act.py
new file mode 100644
index 0000000000000000000000000000000000000000..152c9a6b07f63748c78e6205f48ad8f9db590082
--- /dev/null
+++ b/mapanything/models/external/vggt/heads/head_act.py
@@ -0,0 +1,127 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+import torch
+import torch.nn.functional as F
+
+
+def activate_pose(
+ pred_pose_enc, trans_act="linear", quat_act="linear", fl_act="linear"
+):
+ """
+ Activate pose parameters with specified activation functions.
+
+ Args:
+ pred_pose_enc: Tensor containing encoded pose parameters [translation, quaternion, focal length]
+ trans_act: Activation type for translation component
+ quat_act: Activation type for quaternion component
+ fl_act: Activation type for focal length component
+
+ Returns:
+ Activated pose parameters tensor
+ """
+ T = pred_pose_enc[..., :3]
+ quat = pred_pose_enc[..., 3:7]
+ fl = pred_pose_enc[..., 7:] # or fov
+
+ T = base_pose_act(T, trans_act)
+ quat = base_pose_act(quat, quat_act)
+ fl = base_pose_act(fl, fl_act) # or fov
+
+ pred_pose_enc = torch.cat([T, quat, fl], dim=-1)
+
+ return pred_pose_enc
+
+
+def base_pose_act(pose_enc, act_type="linear"):
+ """
+ Apply basic activation function to pose parameters.
+
+ Args:
+ pose_enc: Tensor containing encoded pose parameters
+ act_type: Activation type ("linear", "inv_log", "exp", "relu")
+
+ Returns:
+ Activated pose parameters
+ """
+ if act_type == "linear":
+ return pose_enc
+ elif act_type == "inv_log":
+ return inverse_log_transform(pose_enc)
+ elif act_type == "exp":
+ return torch.exp(pose_enc)
+ elif act_type == "relu":
+ return F.relu(pose_enc)
+ else:
+ raise ValueError(f"Unknown act_type: {act_type}")
+
+
+def activate_head(out, activation="norm_exp", conf_activation="expp1"):
+ """
+ Process network output to extract 3D points and confidence values.
+
+ Args:
+ out: Network output tensor (B, C, H, W)
+ activation: Activation type for 3D points
+ conf_activation: Activation type for confidence values
+
+ Returns:
+ Tuple of (3D points tensor, confidence tensor)
+ """
+ # Move channels from last dim to the 4th dimension => (B, H, W, C)
+ fmap = out.permute(0, 2, 3, 1) # B,H,W,C expected
+
+ # Split into xyz (first C-1 channels) and confidence (last channel)
+ xyz = fmap[:, :, :, :-1]
+ conf = fmap[:, :, :, -1]
+
+ if activation == "norm_exp":
+ d = xyz.norm(dim=-1, keepdim=True).clamp(min=1e-8)
+ xyz_normed = xyz / d
+ pts3d = xyz_normed * torch.expm1(d)
+ elif activation == "norm":
+ pts3d = xyz / xyz.norm(dim=-1, keepdim=True)
+ elif activation == "exp":
+ pts3d = torch.exp(xyz)
+ elif activation == "relu":
+ pts3d = F.relu(xyz)
+ elif activation == "inv_log":
+ pts3d = inverse_log_transform(xyz)
+ elif activation == "xy_inv_log":
+ xy, z = xyz.split([2, 1], dim=-1)
+ z = inverse_log_transform(z)
+ pts3d = torch.cat([xy * z, z], dim=-1)
+ elif activation == "sigmoid":
+ pts3d = torch.sigmoid(xyz)
+ elif activation == "linear":
+ pts3d = xyz
+ else:
+ raise ValueError(f"Unknown activation: {activation}")
+
+ if conf_activation == "expp1":
+ conf_out = 1 + conf.exp()
+ elif conf_activation == "expp0":
+ conf_out = conf.exp()
+ elif conf_activation == "sigmoid":
+ conf_out = torch.sigmoid(conf)
+ else:
+ raise ValueError(f"Unknown conf_activation: {conf_activation}")
+
+ return pts3d, conf_out
+
+
+def inverse_log_transform(y):
+ """
+ Apply inverse log transform: sign(y) * (exp(|y|) - 1)
+
+ Args:
+ y: Input tensor
+
+ Returns:
+ Transformed tensor
+ """
+ return torch.sign(y) * (torch.expm1(torch.abs(y)))
diff --git a/mapanything/models/external/vggt/heads/track_head.py b/mapanything/models/external/vggt/heads/track_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..afb93d622577602c2d0b46b086ededf3433529f9
--- /dev/null
+++ b/mapanything/models/external/vggt/heads/track_head.py
@@ -0,0 +1,118 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch.nn as nn
+
+from .dpt_head import DPTHead
+from .track_modules.base_track_predictor import BaseTrackerPredictor
+
+
+class TrackHead(nn.Module):
+ """
+ Track head that uses DPT head to process tokens and BaseTrackerPredictor for tracking.
+ The tracking is performed iteratively, refining predictions over multiple iterations.
+ """
+
+ def __init__(
+ self,
+ dim_in,
+ patch_size=14,
+ features=128,
+ iters=4,
+ predict_conf=True,
+ stride=2,
+ corr_levels=7,
+ corr_radius=4,
+ hidden_size=384,
+ ):
+ """
+ Initialize the TrackHead module.
+
+ Args:
+ dim_in (int): Input dimension of tokens from the backbone.
+ patch_size (int): Size of image patches used in the vision transformer.
+ features (int): Number of feature channels in the feature extractor output.
+ iters (int): Number of refinement iterations for tracking predictions.
+ predict_conf (bool): Whether to predict confidence scores for tracked points.
+ stride (int): Stride value for the tracker predictor.
+ corr_levels (int): Number of correlation pyramid levels
+ corr_radius (int): Radius for correlation computation, controlling the search area.
+ hidden_size (int): Size of hidden layers in the tracker network.
+ """
+ super().__init__()
+
+ self.patch_size = patch_size
+
+ # Feature extractor based on DPT architecture
+ # Processes tokens into feature maps for tracking
+ self.feature_extractor = DPTHead(
+ dim_in=dim_in,
+ patch_size=patch_size,
+ features=features,
+ feature_only=True, # Only output features, no activation
+ down_ratio=2, # Reduces spatial dimensions by factor of 2
+ pos_embed=False,
+ )
+
+ # Tracker module that predicts point trajectories
+ # Takes feature maps and predicts coordinates and visibility
+ self.tracker = BaseTrackerPredictor(
+ latent_dim=features, # Match the output_dim of feature extractor
+ predict_conf=predict_conf,
+ stride=stride,
+ corr_levels=corr_levels,
+ corr_radius=corr_radius,
+ hidden_size=hidden_size,
+ )
+
+ self.iters = iters
+
+ def forward(
+ self,
+ aggregated_tokens_list,
+ images,
+ patch_start_idx,
+ query_points=None,
+ iters=None,
+ ):
+ """
+ Forward pass of the TrackHead.
+
+ Args:
+ aggregated_tokens_list (list): List of aggregated tokens from the backbone.
+ images (torch.Tensor): Input images of shape (B, S, C, H, W) where:
+ B = batch size, S = sequence length.
+ patch_start_idx (int): Starting index for patch tokens.
+ query_points (torch.Tensor, optional): Initial query points to track.
+ If None, points are initialized by the tracker.
+ iters (int, optional): Number of refinement iterations. If None, uses self.iters.
+
+ Returns:
+ tuple:
+ - coord_preds (torch.Tensor): Predicted coordinates for tracked points.
+ - vis_scores (torch.Tensor): Visibility scores for tracked points.
+ - conf_scores (torch.Tensor): Confidence scores for tracked points (if predict_conf=True).
+ """
+ B, S, _, H, W = images.shape
+
+ # Extract features from tokens
+ # feature_maps has shape (B, S, C, H//2, W//2) due to down_ratio=2
+ feature_maps = self.feature_extractor(
+ aggregated_tokens_list, images, patch_start_idx
+ )
+
+ # Use default iterations if not specified
+ if iters is None:
+ iters = self.iters
+
+ # Perform tracking using the extracted features
+ coord_preds, vis_scores, conf_scores = self.tracker(
+ query_points=query_points,
+ fmaps=feature_maps,
+ iters=iters,
+ )
+
+ return coord_preds, vis_scores, conf_scores
diff --git a/mapanything/models/external/vggt/heads/track_modules/__init__.py b/mapanything/models/external/vggt/heads/track_modules/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..0952fcc3f57e34b3747962e9ebd6fc57aeea63fa
--- /dev/null
+++ b/mapanything/models/external/vggt/heads/track_modules/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
diff --git a/mapanything/models/external/vggt/heads/track_modules/base_track_predictor.py b/mapanything/models/external/vggt/heads/track_modules/base_track_predictor.py
new file mode 100644
index 0000000000000000000000000000000000000000..9a34b8c683dc7b9bcd80a690c31620ec1c32dd08
--- /dev/null
+++ b/mapanything/models/external/vggt/heads/track_modules/base_track_predictor.py
@@ -0,0 +1,242 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn as nn
+from einops import rearrange
+
+from .blocks import CorrBlock, EfficientUpdateFormer
+from .modules import Mlp
+from .utils import get_2d_embedding, get_2d_sincos_pos_embed, sample_features4d
+
+
+class BaseTrackerPredictor(nn.Module):
+ def __init__(
+ self,
+ stride=1,
+ corr_levels=5,
+ corr_radius=4,
+ latent_dim=128,
+ hidden_size=384,
+ use_spaceatt=True,
+ depth=6,
+ max_scale=518,
+ predict_conf=True,
+ ):
+ super(BaseTrackerPredictor, self).__init__()
+ """
+ The base template to create a track predictor
+
+ Modified from https://github.com/facebookresearch/co-tracker/
+ and https://github.com/facebookresearch/vggsfm
+ """
+
+ self.stride = stride
+ self.latent_dim = latent_dim
+ self.corr_levels = corr_levels
+ self.corr_radius = corr_radius
+ self.hidden_size = hidden_size
+ self.max_scale = max_scale
+ self.predict_conf = predict_conf
+
+ self.flows_emb_dim = latent_dim // 2
+
+ self.corr_mlp = Mlp(
+ in_features=self.corr_levels * (self.corr_radius * 2 + 1) ** 2,
+ hidden_features=self.hidden_size,
+ out_features=self.latent_dim,
+ )
+
+ self.transformer_dim = self.latent_dim + self.latent_dim + self.latent_dim + 4
+
+ self.query_ref_token = nn.Parameter(torch.randn(1, 2, self.transformer_dim))
+
+ space_depth = depth if use_spaceatt else 0
+ time_depth = depth
+
+ self.updateformer = EfficientUpdateFormer(
+ space_depth=space_depth,
+ time_depth=time_depth,
+ input_dim=self.transformer_dim,
+ hidden_size=self.hidden_size,
+ output_dim=self.latent_dim + 2,
+ mlp_ratio=4.0,
+ add_space_attn=use_spaceatt,
+ )
+
+ self.fmap_norm = nn.LayerNorm(self.latent_dim)
+ self.ffeat_norm = nn.GroupNorm(1, self.latent_dim)
+
+ # A linear layer to update track feats at each iteration
+ self.ffeat_updater = nn.Sequential(
+ nn.Linear(self.latent_dim, self.latent_dim), nn.GELU()
+ )
+
+ self.vis_predictor = nn.Sequential(nn.Linear(self.latent_dim, 1))
+
+ if predict_conf:
+ self.conf_predictor = nn.Sequential(nn.Linear(self.latent_dim, 1))
+
+ def forward(
+ self,
+ query_points,
+ fmaps=None,
+ iters=6,
+ return_feat=False,
+ down_ratio=1,
+ apply_sigmoid=True,
+ ):
+ """
+ query_points: B x N x 2, the number of batches, tracks, and xy
+ fmaps: B x S x C x HH x WW, the number of batches, frames, and feature dimension.
+ note HH and WW is the size of feature maps instead of original images
+ """
+ B, N, D = query_points.shape
+ B, S, C, HH, WW = fmaps.shape
+
+ assert D == 2, "Input points must be 2D coordinates"
+
+ # apply a layernorm to fmaps here
+ fmaps = self.fmap_norm(fmaps.permute(0, 1, 3, 4, 2))
+ fmaps = fmaps.permute(0, 1, 4, 2, 3)
+
+ # Scale the input query_points because we may downsample the images
+ # by down_ratio or self.stride
+ # e.g., if a 3x1024x1024 image is processed to a 128x256x256 feature map
+ # its query_points should be query_points/4
+ if down_ratio > 1:
+ query_points = query_points / float(down_ratio)
+
+ query_points = query_points / float(self.stride)
+
+ # Init with coords as the query points
+ # It means the search will start from the position of query points at the reference frames
+ coords = query_points.clone().reshape(B, 1, N, 2).repeat(1, S, 1, 1)
+
+ # Sample/extract the features of the query points in the query frame
+ query_track_feat = sample_features4d(fmaps[:, 0], coords[:, 0])
+
+ # init track feats by query feats
+ track_feats = query_track_feat.unsqueeze(1).repeat(1, S, 1, 1) # B, S, N, C
+ # back up the init coords
+ coords_backup = coords.clone()
+
+ fcorr_fn = CorrBlock(
+ fmaps, num_levels=self.corr_levels, radius=self.corr_radius
+ )
+
+ coord_preds = []
+
+ # Iterative Refinement
+ for _ in range(iters):
+ # Detach the gradients from the last iteration
+ # (in my experience, not very important for performance)
+ coords = coords.detach()
+
+ fcorrs = fcorr_fn.corr_sample(track_feats, coords)
+
+ corr_dim = fcorrs.shape[3]
+ fcorrs_ = fcorrs.permute(0, 2, 1, 3).reshape(B * N, S, corr_dim)
+ fcorrs_ = self.corr_mlp(fcorrs_)
+
+ # Movement of current coords relative to query points
+ flows = (coords - coords[:, 0:1]).permute(0, 2, 1, 3).reshape(B * N, S, 2)
+
+ flows_emb = get_2d_embedding(flows, self.flows_emb_dim, cat_coords=False)
+
+ # (In my trials, it is also okay to just add the flows_emb instead of concat)
+ flows_emb = torch.cat(
+ [flows_emb, flows / self.max_scale, flows / self.max_scale], dim=-1
+ )
+
+ track_feats_ = track_feats.permute(0, 2, 1, 3).reshape(
+ B * N, S, self.latent_dim
+ )
+
+ # Concatenate them as the input for the transformers
+ transformer_input = torch.cat([flows_emb, fcorrs_, track_feats_], dim=2)
+
+ # 2D positional embed
+ # TODO: this can be much simplified
+ pos_embed = get_2d_sincos_pos_embed(
+ self.transformer_dim, grid_size=(HH, WW)
+ ).to(query_points.device)
+ sampled_pos_emb = sample_features4d(
+ pos_embed.expand(B, -1, -1, -1), coords[:, 0]
+ )
+
+ sampled_pos_emb = rearrange(sampled_pos_emb, "b n c -> (b n) c").unsqueeze(
+ 1
+ )
+
+ x = transformer_input + sampled_pos_emb
+
+ # Add the query ref token to the track feats
+ query_ref_token = torch.cat(
+ [
+ self.query_ref_token[:, 0:1],
+ self.query_ref_token[:, 1:2].expand(-1, S - 1, -1),
+ ],
+ dim=1,
+ )
+ x = x + query_ref_token.to(x.device).to(x.dtype)
+
+ # B, N, S, C
+ x = rearrange(x, "(b n) s d -> b n s d", b=B)
+
+ # Compute the delta coordinates and delta track features
+ delta, _ = self.updateformer(x)
+
+ # BN, S, C
+ delta = rearrange(delta, " b n s d -> (b n) s d", b=B)
+ delta_coords_ = delta[:, :, :2]
+ delta_feats_ = delta[:, :, 2:]
+
+ track_feats_ = track_feats_.reshape(B * N * S, self.latent_dim)
+ delta_feats_ = delta_feats_.reshape(B * N * S, self.latent_dim)
+
+ # Update the track features
+ track_feats_ = (
+ self.ffeat_updater(self.ffeat_norm(delta_feats_)) + track_feats_
+ )
+
+ track_feats = track_feats_.reshape(B, N, S, self.latent_dim).permute(
+ 0, 2, 1, 3
+ ) # BxSxNxC
+
+ # B x S x N x 2
+ coords = coords + delta_coords_.reshape(B, N, S, 2).permute(0, 2, 1, 3)
+
+ # Force coord0 as query
+ # because we assume the query points should not be changed
+ coords[:, 0] = coords_backup[:, 0]
+
+ # The predicted tracks are in the original image scale
+ if down_ratio > 1:
+ coord_preds.append(coords * self.stride * down_ratio)
+ else:
+ coord_preds.append(coords * self.stride)
+
+ # B, S, N
+ vis_e = self.vis_predictor(
+ track_feats.reshape(B * S * N, self.latent_dim)
+ ).reshape(B, S, N)
+ if apply_sigmoid:
+ vis_e = torch.sigmoid(vis_e)
+
+ if self.predict_conf:
+ conf_e = self.conf_predictor(
+ track_feats.reshape(B * S * N, self.latent_dim)
+ ).reshape(B, S, N)
+ if apply_sigmoid:
+ conf_e = torch.sigmoid(conf_e)
+ else:
+ conf_e = None
+
+ if return_feat:
+ return coord_preds, vis_e, track_feats, query_track_feat, conf_e
+ else:
+ return coord_preds, vis_e, conf_e
diff --git a/mapanything/models/external/vggt/heads/track_modules/blocks.py b/mapanything/models/external/vggt/heads/track_modules/blocks.py
new file mode 100644
index 0000000000000000000000000000000000000000..2452d4e5c5c64e18d811417dbcc6b2d2ec5aae95
--- /dev/null
+++ b/mapanything/models/external/vggt/heads/track_modules/blocks.py
@@ -0,0 +1,288 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+# Modified from https://github.com/facebookresearch/co-tracker/
+
+import math
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .modules import AttnBlock, CrossAttnBlock
+from .utils import bilinear_sampler
+
+
+class EfficientUpdateFormer(nn.Module):
+ """
+ Transformer model that updates track estimates.
+ """
+
+ def __init__(
+ self,
+ space_depth=6,
+ time_depth=6,
+ input_dim=320,
+ hidden_size=384,
+ num_heads=8,
+ output_dim=130,
+ mlp_ratio=4.0,
+ add_space_attn=True,
+ num_virtual_tracks=64,
+ ):
+ super().__init__()
+
+ self.out_channels = 2
+ self.num_heads = num_heads
+ self.hidden_size = hidden_size
+ self.add_space_attn = add_space_attn
+
+ # Add input LayerNorm before linear projection
+ self.input_norm = nn.LayerNorm(input_dim)
+ self.input_transform = torch.nn.Linear(input_dim, hidden_size, bias=True)
+
+ # Add output LayerNorm before final projection
+ self.output_norm = nn.LayerNorm(hidden_size)
+ self.flow_head = torch.nn.Linear(hidden_size, output_dim, bias=True)
+ self.num_virtual_tracks = num_virtual_tracks
+
+ if self.add_space_attn:
+ self.virual_tracks = nn.Parameter(
+ torch.randn(1, num_virtual_tracks, 1, hidden_size)
+ )
+ else:
+ self.virual_tracks = None
+
+ self.time_blocks = nn.ModuleList(
+ [
+ AttnBlock(
+ hidden_size,
+ num_heads,
+ mlp_ratio=mlp_ratio,
+ attn_class=nn.MultiheadAttention,
+ )
+ for _ in range(time_depth)
+ ]
+ )
+
+ if add_space_attn:
+ self.space_virtual_blocks = nn.ModuleList(
+ [
+ AttnBlock(
+ hidden_size,
+ num_heads,
+ mlp_ratio=mlp_ratio,
+ attn_class=nn.MultiheadAttention,
+ )
+ for _ in range(space_depth)
+ ]
+ )
+ self.space_point2virtual_blocks = nn.ModuleList(
+ [
+ CrossAttnBlock(
+ hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio
+ )
+ for _ in range(space_depth)
+ ]
+ )
+ self.space_virtual2point_blocks = nn.ModuleList(
+ [
+ CrossAttnBlock(
+ hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio
+ )
+ for _ in range(space_depth)
+ ]
+ )
+ assert len(self.time_blocks) >= len(self.space_virtual2point_blocks)
+ self.initialize_weights()
+
+ def initialize_weights(self):
+ def _basic_init(module):
+ if isinstance(module, nn.Linear):
+ torch.nn.init.xavier_uniform_(module.weight)
+ if module.bias is not None:
+ nn.init.constant_(module.bias, 0)
+ torch.nn.init.trunc_normal_(self.flow_head.weight, std=0.001)
+
+ self.apply(_basic_init)
+
+ def forward(self, input_tensor, mask=None):
+ # Apply input LayerNorm
+ input_tensor = self.input_norm(input_tensor)
+ tokens = self.input_transform(input_tensor)
+
+ init_tokens = tokens
+
+ B, _, T, _ = tokens.shape
+
+ if self.add_space_attn:
+ virtual_tokens = self.virual_tracks.repeat(B, 1, T, 1)
+ tokens = torch.cat([tokens, virtual_tokens], dim=1)
+
+ _, N, _, _ = tokens.shape
+
+ j = 0
+ for i in range(len(self.time_blocks)):
+ time_tokens = tokens.contiguous().view(B * N, T, -1) # B N T C -> (B N) T C
+
+ time_tokens = self.time_blocks[i](time_tokens)
+
+ tokens = time_tokens.view(B, N, T, -1) # (B N) T C -> B N T C
+ if self.add_space_attn and (
+ i % (len(self.time_blocks) // len(self.space_virtual_blocks)) == 0
+ ):
+ space_tokens = (
+ tokens.permute(0, 2, 1, 3).contiguous().view(B * T, N, -1)
+ ) # B N T C -> (B T) N C
+ point_tokens = space_tokens[:, : N - self.num_virtual_tracks]
+ virtual_tokens = space_tokens[:, N - self.num_virtual_tracks :]
+
+ virtual_tokens = self.space_virtual2point_blocks[j](
+ virtual_tokens, point_tokens, mask=mask
+ )
+ virtual_tokens = self.space_virtual_blocks[j](virtual_tokens)
+ point_tokens = self.space_point2virtual_blocks[j](
+ point_tokens, virtual_tokens, mask=mask
+ )
+
+ space_tokens = torch.cat([point_tokens, virtual_tokens], dim=1)
+ tokens = space_tokens.view(B, T, N, -1).permute(
+ 0, 2, 1, 3
+ ) # (B T) N C -> B N T C
+ j += 1
+
+ if self.add_space_attn:
+ tokens = tokens[:, : N - self.num_virtual_tracks]
+
+ tokens = tokens + init_tokens
+
+ # Apply output LayerNorm before final projection
+ tokens = self.output_norm(tokens)
+ flow = self.flow_head(tokens)
+
+ return flow, None
+
+
+class CorrBlock:
+ def __init__(
+ self,
+ fmaps,
+ num_levels=4,
+ radius=4,
+ multiple_track_feats=False,
+ padding_mode="zeros",
+ ):
+ """
+ Build a pyramid of feature maps from the input.
+
+ fmaps: Tensor (B, S, C, H, W)
+ num_levels: number of pyramid levels (each downsampled by factor 2)
+ radius: search radius for sampling correlation
+ multiple_track_feats: if True, split the target features per pyramid level
+ padding_mode: passed to grid_sample / bilinear_sampler
+ """
+ B, S, C, H, W = fmaps.shape
+ self.S, self.C, self.H, self.W = S, C, H, W
+ self.num_levels = num_levels
+ self.radius = radius
+ self.padding_mode = padding_mode
+ self.multiple_track_feats = multiple_track_feats
+
+ # Build pyramid: each level is half the spatial resolution of the previous
+ self.fmaps_pyramid = [fmaps] # level 0 is full resolution
+ current_fmaps = fmaps
+ for i in range(num_levels - 1):
+ B, S, C, H, W = current_fmaps.shape
+ # Merge batch & sequence dimensions
+ current_fmaps = current_fmaps.reshape(B * S, C, H, W)
+ # Avg pool down by factor 2
+ current_fmaps = F.avg_pool2d(current_fmaps, kernel_size=2, stride=2)
+ _, _, H_new, W_new = current_fmaps.shape
+ current_fmaps = current_fmaps.reshape(B, S, C, H_new, W_new)
+ self.fmaps_pyramid.append(current_fmaps)
+
+ # Precompute a delta grid (of shape (2r+1, 2r+1, 2)) for sampling.
+ # This grid is added to the (scaled) coordinate centroids.
+ r = self.radius
+ dx = torch.linspace(-r, r, 2 * r + 1, device=fmaps.device, dtype=fmaps.dtype)
+ dy = torch.linspace(-r, r, 2 * r + 1, device=fmaps.device, dtype=fmaps.dtype)
+ # delta: for every (dy,dx) displacement (i.e. Δx, Δy)
+ self.delta = torch.stack(
+ torch.meshgrid(dy, dx, indexing="ij"), dim=-1
+ ) # shape: (2r+1, 2r+1, 2)
+
+ def corr_sample(self, targets, coords):
+ """
+ Instead of storing the entire correlation pyramid, we compute each level's correlation
+ volume, sample it immediately, then discard it. This saves GPU memory.
+
+ Args:
+ targets: Tensor (B, S, N, C) — features for the current targets.
+ coords: Tensor (B, S, N, 2) — coordinates at full resolution.
+
+ Returns:
+ Tensor (B, S, N, L) where L = num_levels * (2*radius+1)**2 (concatenated sampled correlations)
+ """
+ B, S, N, C = targets.shape
+
+ # If you have multiple track features, split them per level.
+ if self.multiple_track_feats:
+ targets_split = torch.split(targets, C // self.num_levels, dim=-1)
+
+ out_pyramid = []
+ for i, fmaps in enumerate(self.fmaps_pyramid):
+ # Get current spatial resolution H, W for this pyramid level.
+ B, S, C, H, W = fmaps.shape
+ # Reshape feature maps for correlation computation:
+ # fmap2s: (B, S, C, H*W)
+ fmap2s = fmaps.view(B, S, C, H * W)
+ # Choose appropriate target features.
+ fmap1 = (
+ targets_split[i] if self.multiple_track_feats else targets
+ ) # shape: (B, S, N, C)
+
+ # Compute correlation directly
+ corrs = compute_corr_level(fmap1, fmap2s, C)
+ corrs = corrs.view(B, S, N, H, W)
+
+ # Prepare sampling grid:
+ # Scale down the coordinates for the current level.
+ centroid_lvl = coords.reshape(B * S * N, 1, 1, 2) / (2**i)
+ # Make sure our precomputed delta grid is on the same device/dtype.
+ delta_lvl = self.delta.to(coords.device).to(coords.dtype)
+ # Now the grid for grid_sample is:
+ # coords_lvl = centroid_lvl + delta_lvl (broadcasted over grid)
+ coords_lvl = centroid_lvl + delta_lvl.view(
+ 1, 2 * self.radius + 1, 2 * self.radius + 1, 2
+ )
+
+ # Sample from the correlation volume using bilinear interpolation.
+ # We reshape corrs to (B * S * N, 1, H, W) so grid_sample acts over each target.
+ corrs_sampled = bilinear_sampler(
+ corrs.reshape(B * S * N, 1, H, W),
+ coords_lvl,
+ padding_mode=self.padding_mode,
+ )
+ # The sampled output is (B * S * N, 1, 2r+1, 2r+1). Flatten the last two dims.
+ corrs_sampled = corrs_sampled.view(
+ B, S, N, -1
+ ) # Now shape: (B, S, N, (2r+1)^2)
+ out_pyramid.append(corrs_sampled)
+
+ # Concatenate all levels along the last dimension.
+ out = torch.cat(out_pyramid, dim=-1).contiguous()
+ return out
+
+
+def compute_corr_level(fmap1, fmap2s, C):
+ # fmap1: (B, S, N, C)
+ # fmap2s: (B, S, C, H*W)
+ corrs = torch.matmul(fmap1, fmap2s) # (B, S, N, H*W)
+ corrs = corrs.view(
+ fmap1.shape[0], fmap1.shape[1], fmap1.shape[2], -1
+ ) # (B, S, N, H*W)
+ return corrs / math.sqrt(C)
diff --git a/mapanything/models/external/vggt/heads/track_modules/modules.py b/mapanything/models/external/vggt/heads/track_modules/modules.py
new file mode 100644
index 0000000000000000000000000000000000000000..d2ef7d9890da8da2fd1109b7247a3861a1234c3f
--- /dev/null
+++ b/mapanything/models/external/vggt/heads/track_modules/modules.py
@@ -0,0 +1,220 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+import collections
+from functools import partial
+from itertools import repeat
+from typing import Callable
+
+import torch.nn as nn
+
+
+# From PyTorch internals
+def _ntuple(n):
+ def parse(x):
+ if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
+ return tuple(x)
+ return tuple(repeat(x, n))
+
+ return parse
+
+
+def exists(val):
+ return val is not None
+
+
+def default(val, d):
+ return val if exists(val) else d
+
+
+to_2tuple = _ntuple(2)
+
+
+class ResidualBlock(nn.Module):
+ """
+ ResidualBlock: construct a block of two conv layers with residual connections
+ """
+
+ def __init__(self, in_planes, planes, norm_fn="group", stride=1, kernel_size=3):
+ super(ResidualBlock, self).__init__()
+
+ self.conv1 = nn.Conv2d(
+ in_planes,
+ planes,
+ kernel_size=kernel_size,
+ padding=1,
+ stride=stride,
+ padding_mode="zeros",
+ )
+ self.conv2 = nn.Conv2d(
+ planes,
+ planes,
+ kernel_size=kernel_size,
+ padding=1,
+ padding_mode="zeros",
+ )
+ self.relu = nn.ReLU(inplace=True)
+
+ num_groups = planes // 8
+
+ if norm_fn == "group":
+ self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
+ self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
+ if not stride == 1:
+ self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
+
+ elif norm_fn == "batch":
+ self.norm1 = nn.BatchNorm2d(planes)
+ self.norm2 = nn.BatchNorm2d(planes)
+ if not stride == 1:
+ self.norm3 = nn.BatchNorm2d(planes)
+
+ elif norm_fn == "instance":
+ self.norm1 = nn.InstanceNorm2d(planes)
+ self.norm2 = nn.InstanceNorm2d(planes)
+ if not stride == 1:
+ self.norm3 = nn.InstanceNorm2d(planes)
+
+ elif norm_fn == "none":
+ self.norm1 = nn.Sequential()
+ self.norm2 = nn.Sequential()
+ if not stride == 1:
+ self.norm3 = nn.Sequential()
+ else:
+ raise NotImplementedError
+
+ if stride == 1:
+ self.downsample = None
+ else:
+ self.downsample = nn.Sequential(
+ nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride),
+ self.norm3,
+ )
+
+ def forward(self, x):
+ y = x
+ y = self.relu(self.norm1(self.conv1(y)))
+ y = self.relu(self.norm2(self.conv2(y)))
+
+ if self.downsample is not None:
+ x = self.downsample(x)
+
+ return self.relu(x + y)
+
+
+class Mlp(nn.Module):
+ """MLP as used in Vision Transformer, MLP-Mixer and related networks"""
+
+ def __init__(
+ self,
+ in_features,
+ hidden_features=None,
+ out_features=None,
+ act_layer=nn.GELU,
+ norm_layer=None,
+ bias=True,
+ drop=0.0,
+ use_conv=False,
+ ):
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ bias = to_2tuple(bias)
+ drop_probs = to_2tuple(drop)
+ linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear
+
+ self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0])
+ self.act = act_layer()
+ self.drop1 = nn.Dropout(drop_probs[0])
+ self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1])
+ self.drop2 = nn.Dropout(drop_probs[1])
+
+ def forward(self, x):
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.drop1(x)
+ x = self.fc2(x)
+ x = self.drop2(x)
+ return x
+
+
+class AttnBlock(nn.Module):
+ def __init__(
+ self,
+ hidden_size,
+ num_heads,
+ attn_class: Callable[..., nn.Module] = nn.MultiheadAttention,
+ mlp_ratio=4.0,
+ **block_kwargs,
+ ):
+ """
+ Self attention block
+ """
+ super().__init__()
+
+ self.norm1 = nn.LayerNorm(hidden_size)
+ self.norm2 = nn.LayerNorm(hidden_size)
+
+ self.attn = attn_class(
+ embed_dim=hidden_size, num_heads=num_heads, batch_first=True, **block_kwargs
+ )
+
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
+
+ self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, drop=0)
+
+ def forward(self, x, mask=None):
+ # Prepare the mask for PyTorch's attention (it expects a different format)
+ # attn_mask = mask if mask is not None else None
+ # Normalize before attention
+ x = self.norm1(x)
+
+ # PyTorch's MultiheadAttention returns attn_output, attn_output_weights
+ # attn_output, _ = self.attn(x, x, x, attn_mask=attn_mask)
+
+ attn_output, _ = self.attn(x, x, x)
+
+ # Add & Norm
+ x = x + attn_output
+ x = x + self.mlp(self.norm2(x))
+ return x
+
+
+class CrossAttnBlock(nn.Module):
+ def __init__(
+ self, hidden_size, context_dim, num_heads=1, mlp_ratio=4.0, **block_kwargs
+ ):
+ """
+ Cross attention block
+ """
+ super().__init__()
+
+ self.norm1 = nn.LayerNorm(hidden_size)
+ self.norm_context = nn.LayerNorm(hidden_size)
+ self.norm2 = nn.LayerNorm(hidden_size)
+
+ self.cross_attn = nn.MultiheadAttention(
+ embed_dim=hidden_size, num_heads=num_heads, batch_first=True, **block_kwargs
+ )
+
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
+
+ self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, drop=0)
+
+ def forward(self, x, context, mask=None):
+ # Normalize inputs
+ x = self.norm1(x)
+ context = self.norm_context(context)
+
+ # Apply cross attention
+ # Note: nn.MultiheadAttention returns attn_output, attn_output_weights
+ attn_output, _ = self.cross_attn(x, context, context, attn_mask=mask)
+
+ # Add & Norm
+ x = x + attn_output
+ x = x + self.mlp(self.norm2(x))
+ return x
diff --git a/mapanything/models/external/vggt/heads/track_modules/utils.py b/mapanything/models/external/vggt/heads/track_modules/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..9073520140db9fbf1e3cda632aec5ed65bc6b70a
--- /dev/null
+++ b/mapanything/models/external/vggt/heads/track_modules/utils.py
@@ -0,0 +1,243 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# Modified from https://github.com/facebookresearch/vggsfm
+# and https://github.com/facebookresearch/co-tracker/tree/main
+
+
+from typing import Tuple, Union
+
+import torch
+import torch.nn.functional as F
+
+
+def get_2d_sincos_pos_embed(
+ embed_dim: int, grid_size: Union[int, Tuple[int, int]], return_grid=False
+) -> torch.Tensor:
+ """
+ This function initializes a grid and generates a 2D positional embedding using sine and cosine functions.
+ It is a wrapper of get_2d_sincos_pos_embed_from_grid.
+ Args:
+ - embed_dim: The embedding dimension.
+ - grid_size: The grid size.
+ Returns:
+ - pos_embed: The generated 2D positional embedding.
+ """
+ if isinstance(grid_size, tuple):
+ grid_size_h, grid_size_w = grid_size
+ else:
+ grid_size_h = grid_size_w = grid_size
+ grid_h = torch.arange(grid_size_h, dtype=torch.float)
+ grid_w = torch.arange(grid_size_w, dtype=torch.float)
+ grid = torch.meshgrid(grid_w, grid_h, indexing="xy")
+ grid = torch.stack(grid, dim=0)
+ grid = grid.reshape([2, 1, grid_size_h, grid_size_w])
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
+ if return_grid:
+ return (
+ pos_embed.reshape(1, grid_size_h, grid_size_w, -1).permute(0, 3, 1, 2),
+ grid,
+ )
+ return pos_embed.reshape(1, grid_size_h, grid_size_w, -1).permute(0, 3, 1, 2)
+
+
+def get_2d_sincos_pos_embed_from_grid(
+ embed_dim: int, grid: torch.Tensor
+) -> torch.Tensor:
+ """
+ This function generates a 2D positional embedding from a given grid using sine and cosine functions.
+
+ Args:
+ - embed_dim: The embedding dimension.
+ - grid: The grid to generate the embedding from.
+
+ Returns:
+ - emb: The generated 2D positional embedding.
+ """
+ assert embed_dim % 2 == 0
+
+ # use half of dimensions to encode grid_h
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
+
+ emb = torch.cat([emb_h, emb_w], dim=2) # (H*W, D)
+ return emb
+
+
+def get_1d_sincos_pos_embed_from_grid(
+ embed_dim: int, pos: torch.Tensor
+) -> torch.Tensor:
+ """
+ This function generates a 1D positional embedding from a given grid using sine and cosine functions.
+
+ Args:
+ - embed_dim: The embedding dimension.
+ - pos: The position to generate the embedding from.
+
+ Returns:
+ - emb: The generated 1D positional embedding.
+ """
+ assert embed_dim % 2 == 0
+ omega = torch.arange(embed_dim // 2, dtype=torch.double)
+ omega /= embed_dim / 2.0
+ omega = 1.0 / 10000**omega # (D/2,)
+
+ pos = pos.reshape(-1) # (M,)
+ out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product
+
+ emb_sin = torch.sin(out) # (M, D/2)
+ emb_cos = torch.cos(out) # (M, D/2)
+
+ emb = torch.cat([emb_sin, emb_cos], dim=1) # (M, D)
+ return emb[None].float()
+
+
+def get_2d_embedding(xy: torch.Tensor, C: int, cat_coords: bool = True) -> torch.Tensor:
+ """
+ This function generates a 2D positional embedding from given coordinates using sine and cosine functions.
+
+ Args:
+ - xy: The coordinates to generate the embedding from.
+ - C: The size of the embedding.
+ - cat_coords: A flag to indicate whether to concatenate the original coordinates to the embedding.
+
+ Returns:
+ - pe: The generated 2D positional embedding.
+ """
+ B, N, D = xy.shape
+ assert D == 2
+
+ x = xy[:, :, 0:1]
+ y = xy[:, :, 1:2]
+ div_term = (
+ torch.arange(0, C, 2, device=xy.device, dtype=torch.float32) * (1000.0 / C)
+ ).reshape(1, 1, int(C / 2))
+
+ pe_x = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32)
+ pe_y = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32)
+
+ pe_x[:, :, 0::2] = torch.sin(x * div_term)
+ pe_x[:, :, 1::2] = torch.cos(x * div_term)
+
+ pe_y[:, :, 0::2] = torch.sin(y * div_term)
+ pe_y[:, :, 1::2] = torch.cos(y * div_term)
+
+ pe = torch.cat([pe_x, pe_y], dim=2) # (B, N, C*3)
+ if cat_coords:
+ pe = torch.cat([xy, pe], dim=2) # (B, N, C*3+3)
+ return pe
+
+
+def bilinear_sampler(input, coords, align_corners=True, padding_mode="border"):
+ r"""Sample a tensor using bilinear interpolation
+
+ `bilinear_sampler(input, coords)` samples a tensor :attr:`input` at
+ coordinates :attr:`coords` using bilinear interpolation. It is the same
+ as `torch.nn.functional.grid_sample()` but with a different coordinate
+ convention.
+
+ The input tensor is assumed to be of shape :math:`(B, C, H, W)`, where
+ :math:`B` is the batch size, :math:`C` is the number of channels,
+ :math:`H` is the height of the image, and :math:`W` is the width of the
+ image. The tensor :attr:`coords` of shape :math:`(B, H_o, W_o, 2)` is
+ interpreted as an array of 2D point coordinates :math:`(x_i,y_i)`.
+
+ Alternatively, the input tensor can be of size :math:`(B, C, T, H, W)`,
+ in which case sample points are triplets :math:`(t_i,x_i,y_i)`. Note
+ that in this case the order of the components is slightly different
+ from `grid_sample()`, which would expect :math:`(x_i,y_i,t_i)`.
+
+ If `align_corners` is `True`, the coordinate :math:`x` is assumed to be
+ in the range :math:`[0,W-1]`, with 0 corresponding to the center of the
+ left-most image pixel :math:`W-1` to the center of the right-most
+ pixel.
+
+ If `align_corners` is `False`, the coordinate :math:`x` is assumed to
+ be in the range :math:`[0,W]`, with 0 corresponding to the left edge of
+ the left-most pixel :math:`W` to the right edge of the right-most
+ pixel.
+
+ Similar conventions apply to the :math:`y` for the range
+ :math:`[0,H-1]` and :math:`[0,H]` and to :math:`t` for the range
+ :math:`[0,T-1]` and :math:`[0,T]`.
+
+ Args:
+ input (Tensor): batch of input images.
+ coords (Tensor): batch of coordinates.
+ align_corners (bool, optional): Coordinate convention. Defaults to `True`.
+ padding_mode (str, optional): Padding mode. Defaults to `"border"`.
+
+ Returns:
+ Tensor: sampled points.
+ """
+ coords = coords.detach().clone()
+ ############################################################
+ # IMPORTANT:
+ coords = coords.to(input.device).to(input.dtype)
+ ############################################################
+
+ sizes = input.shape[2:]
+
+ assert len(sizes) in [2, 3]
+
+ if len(sizes) == 3:
+ # t x y -> x y t to match dimensions T H W in grid_sample
+ coords = coords[..., [1, 2, 0]]
+
+ if align_corners:
+ scale = torch.tensor(
+ [2 / max(size - 1, 1) for size in reversed(sizes)],
+ device=coords.device,
+ dtype=coords.dtype,
+ )
+ else:
+ scale = torch.tensor(
+ [2 / size for size in reversed(sizes)],
+ device=coords.device,
+ dtype=coords.dtype,
+ )
+
+ coords.mul_(scale) # coords = coords * scale
+ coords.sub_(1) # coords = coords - 1
+
+ return F.grid_sample(
+ input, coords, align_corners=align_corners, padding_mode=padding_mode
+ )
+
+
+def sample_features4d(input, coords):
+ r"""Sample spatial features
+
+ `sample_features4d(input, coords)` samples the spatial features
+ :attr:`input` represented by a 4D tensor :math:`(B, C, H, W)`.
+
+ The field is sampled at coordinates :attr:`coords` using bilinear
+ interpolation. :attr:`coords` is assumed to be of shape :math:`(B, R,
+ 2)`, where each sample has the format :math:`(x_i, y_i)`. This uses the
+ same convention as :func:`bilinear_sampler` with `align_corners=True`.
+
+ The output tensor has one feature per point, and has shape :math:`(B,
+ R, C)`.
+
+ Args:
+ input (Tensor): spatial features.
+ coords (Tensor): points.
+
+ Returns:
+ Tensor: sampled features.
+ """
+
+ B, _, _, _ = input.shape
+
+ # B R 2 -> B R 1 2
+ coords = coords.unsqueeze(2)
+
+ # B C R 1
+ feats = bilinear_sampler(input, coords)
+
+ return feats.permute(0, 2, 1, 3).view(
+ B, -1, feats.shape[1] * feats.shape[3]
+ ) # B C R 1 -> B R C
diff --git a/mapanything/models/external/vggt/heads/utils.py b/mapanything/models/external/vggt/heads/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..ce05048a1ef9c88fe19f94af2fbd50cb473350c8
--- /dev/null
+++ b/mapanything/models/external/vggt/heads/utils.py
@@ -0,0 +1,124 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+
+
+def position_grid_to_embed(
+ pos_grid: torch.Tensor, embed_dim: int, omega_0: float = 100
+) -> torch.Tensor:
+ """
+ Convert 2D position grid (HxWx2) to sinusoidal embeddings (HxWxC)
+
+ Args:
+ pos_grid: Tensor of shape (H, W, 2) containing 2D coordinates
+ embed_dim: Output channel dimension for embeddings
+
+ Returns:
+ Tensor of shape (H, W, embed_dim) with positional embeddings
+ """
+ H, W, grid_dim = pos_grid.shape
+ assert grid_dim == 2
+ pos_flat = pos_grid.reshape(-1, grid_dim) # Flatten to (H*W, 2)
+
+ # Process x and y coordinates separately
+ emb_x = make_sincos_pos_embed(
+ embed_dim // 2, pos_flat[:, 0], omega_0=omega_0
+ ) # [1, H*W, D/2]
+ emb_y = make_sincos_pos_embed(
+ embed_dim // 2, pos_flat[:, 1], omega_0=omega_0
+ ) # [1, H*W, D/2]
+
+ # Combine and reshape
+ emb = torch.cat([emb_x, emb_y], dim=-1) # [1, H*W, D]
+
+ return emb.view(H, W, embed_dim) # [H, W, D]
+
+
+def make_sincos_pos_embed(
+ embed_dim: int, pos: torch.Tensor, omega_0: float = 100
+) -> torch.Tensor:
+ """
+ This function generates a 1D positional embedding from a given grid using sine and cosine functions.
+
+ Args:
+ - embed_dim: The embedding dimension.
+ - pos: The position to generate the embedding from.
+
+ Returns:
+ - emb: The generated 1D positional embedding.
+ """
+ assert embed_dim % 2 == 0
+ device = pos.device
+ omega = torch.arange(
+ embed_dim // 2,
+ dtype=torch.float32 if device.type == "mps" else torch.double,
+ device=device,
+ )
+ omega /= embed_dim / 2.0
+ omega = 1.0 / omega_0**omega # (D/2,)
+
+ pos = pos.reshape(-1) # (M,)
+ out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product
+
+ emb_sin = torch.sin(out) # (M, D/2)
+ emb_cos = torch.cos(out) # (M, D/2)
+
+ emb = torch.cat([emb_sin, emb_cos], dim=1) # (M, D)
+ return emb.float()
+
+
+# Inspired by https://github.com/microsoft/moge
+
+
+def create_uv_grid(
+ width: int,
+ height: int,
+ aspect_ratio: float = None,
+ dtype: torch.dtype = None,
+ device: torch.device = None,
+) -> torch.Tensor:
+ """
+ Create a normalized UV grid of shape (width, height, 2).
+
+ The grid spans horizontally and vertically according to an aspect ratio,
+ ensuring the top-left corner is at (-x_span, -y_span) and the bottom-right
+ corner is at (x_span, y_span), normalized by the diagonal of the plane.
+
+ Args:
+ width (int): Number of points horizontally.
+ height (int): Number of points vertically.
+ aspect_ratio (float, optional): Width-to-height ratio. Defaults to width/height.
+ dtype (torch.dtype, optional): Data type of the resulting tensor.
+ device (torch.device, optional): Device on which the tensor is created.
+
+ Returns:
+ torch.Tensor: A (width, height, 2) tensor of UV coordinates.
+ """
+ # Derive aspect ratio if not explicitly provided
+ if aspect_ratio is None:
+ aspect_ratio = float(width) / float(height)
+
+ # Compute normalized spans for X and Y
+ diag_factor = (aspect_ratio**2 + 1.0) ** 0.5
+ span_x = aspect_ratio / diag_factor
+ span_y = 1.0 / diag_factor
+
+ # Establish the linspace boundaries
+ left_x = -span_x * (width - 1) / width
+ right_x = span_x * (width - 1) / width
+ top_y = -span_y * (height - 1) / height
+ bottom_y = span_y * (height - 1) / height
+
+ # Generate 1D coordinates
+ x_coords = torch.linspace(left_x, right_x, steps=width, dtype=dtype, device=device)
+ y_coords = torch.linspace(top_y, bottom_y, steps=height, dtype=dtype, device=device)
+
+ # Create 2D meshgrid (width x height) and stack into UV
+ uu, vv = torch.meshgrid(x_coords, y_coords, indexing="xy")
+ uv_grid = torch.stack((uu, vv), dim=-1)
+
+ return uv_grid
diff --git a/mapanything/models/external/vggt/layers/__init__.py b/mapanything/models/external/vggt/layers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..20cc0ab8b9bd83a38bfa1f0e3ba3668821b68ba3
--- /dev/null
+++ b/mapanything/models/external/vggt/layers/__init__.py
@@ -0,0 +1,13 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from .mlp import Mlp
+from .patch_embed import PatchEmbed
+
+__all__ = [
+ "Mlp",
+ "PatchEmbed",
+]
diff --git a/mapanything/models/external/vggt/layers/attention.py b/mapanything/models/external/vggt/layers/attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..6ea9d8c6651396729cc785b4a2931abb1f7b0e56
--- /dev/null
+++ b/mapanything/models/external/vggt/layers/attention.py
@@ -0,0 +1,98 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
+
+
+import torch.nn.functional as F
+from torch import nn, Tensor
+
+XFORMERS_AVAILABLE = False
+
+
+class Attention(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int = 8,
+ qkv_bias: bool = True,
+ proj_bias: bool = True,
+ attn_drop: float = 0.0,
+ proj_drop: float = 0.0,
+ norm_layer: nn.Module = nn.LayerNorm,
+ qk_norm: bool = False,
+ fused_attn: bool = True, # use F.scaled_dot_product_attention or not
+ rope=None,
+ ) -> None:
+ super().__init__()
+ assert dim % num_heads == 0, "dim should be divisible by num_heads"
+ self.num_heads = num_heads
+ self.head_dim = dim // num_heads
+ self.scale = self.head_dim**-0.5
+ self.fused_attn = fused_attn
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
+ self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim, bias=proj_bias)
+ self.proj_drop = nn.Dropout(proj_drop)
+ self.rope = rope
+
+ def forward(self, x: Tensor, pos=None) -> Tensor:
+ B, N, C = x.shape
+ qkv = (
+ self.qkv(x)
+ .reshape(B, N, 3, self.num_heads, self.head_dim)
+ .permute(2, 0, 3, 1, 4)
+ )
+ q, k, v = qkv.unbind(0)
+ q, k = self.q_norm(q), self.k_norm(k)
+
+ if self.rope is not None:
+ q = self.rope(q, pos)
+ k = self.rope(k, pos)
+
+ if self.fused_attn:
+ x = F.scaled_dot_product_attention(
+ q,
+ k,
+ v,
+ dropout_p=self.attn_drop.p if self.training else 0.0,
+ )
+ else:
+ q = q * self.scale
+ attn = q @ k.transpose(-2, -1)
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+ x = attn @ v
+
+ x = x.transpose(1, 2).reshape(B, N, C)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+# class MemEffAttention(Attention):
+# def forward(self, x: Tensor, attn_bias=None, pos=None) -> Tensor:
+# assert pos is None
+# if not XFORMERS_AVAILABLE:
+# if attn_bias is not None:
+# raise AssertionError("xFormers is required for using nested tensors")
+# return super().forward(x)
+
+# B, N, C = x.shape
+# qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
+
+# q, k, v = unbind(qkv, 2)
+
+# x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
+# x = x.reshape([B, N, C])
+
+# x = self.proj(x)
+# x = self.proj_drop(x)
+# return x
diff --git a/mapanything/models/external/vggt/layers/block.py b/mapanything/models/external/vggt/layers/block.py
new file mode 100644
index 0000000000000000000000000000000000000000..e5487f56a3a6a207763482cefcfa90406fdb9604
--- /dev/null
+++ b/mapanything/models/external/vggt/layers/block.py
@@ -0,0 +1,280 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
+
+from typing import Any, Callable, Dict, List, Tuple
+
+import torch
+from torch import nn, Tensor
+
+from .attention import Attention
+from .drop_path import DropPath
+from .layer_scale import LayerScale
+from .mlp import Mlp
+
+XFORMERS_AVAILABLE = False
+
+
+class Block(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int,
+ mlp_ratio: float = 4.0,
+ qkv_bias: bool = True,
+ proj_bias: bool = True,
+ ffn_bias: bool = True,
+ drop: float = 0.0,
+ attn_drop: float = 0.0,
+ init_values=None,
+ drop_path: float = 0.0,
+ act_layer: Callable[..., nn.Module] = nn.GELU,
+ norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
+ attn_class: Callable[..., nn.Module] = Attention,
+ ffn_layer: Callable[..., nn.Module] = Mlp,
+ qk_norm: bool = False,
+ fused_attn: bool = True, # use F.scaled_dot_product_attention or not
+ rope=None,
+ ) -> None:
+ super().__init__()
+
+ self.norm1 = norm_layer(dim)
+
+ self.attn = attn_class(
+ dim,
+ num_heads=num_heads,
+ qkv_bias=qkv_bias,
+ proj_bias=proj_bias,
+ attn_drop=attn_drop,
+ proj_drop=drop,
+ qk_norm=qk_norm,
+ fused_attn=fused_attn,
+ rope=rope,
+ )
+
+ self.ls1 = (
+ LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
+ )
+ self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+ self.norm2 = norm_layer(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = ffn_layer(
+ in_features=dim,
+ hidden_features=mlp_hidden_dim,
+ act_layer=act_layer,
+ drop=drop,
+ bias=ffn_bias,
+ )
+ self.ls2 = (
+ LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
+ )
+ self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+ self.sample_drop_ratio = drop_path
+
+ def forward(self, x: Tensor, pos=None) -> Tensor:
+ def attn_residual_func(x: Tensor, pos=None) -> Tensor:
+ return self.ls1(self.attn(self.norm1(x), pos=pos))
+
+ def ffn_residual_func(x: Tensor) -> Tensor:
+ return self.ls2(self.mlp(self.norm2(x)))
+
+ if self.training and self.sample_drop_ratio > 0.1:
+ # the overhead is compensated only for a drop path rate larger than 0.1
+ x = drop_add_residual_stochastic_depth(
+ x,
+ pos=pos,
+ residual_func=attn_residual_func,
+ sample_drop_ratio=self.sample_drop_ratio,
+ )
+ x = drop_add_residual_stochastic_depth(
+ x,
+ residual_func=ffn_residual_func,
+ sample_drop_ratio=self.sample_drop_ratio,
+ )
+ elif self.training and self.sample_drop_ratio > 0.0:
+ x = x + self.drop_path1(attn_residual_func(x, pos=pos))
+ x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2
+ else:
+ x = x + attn_residual_func(x, pos=pos)
+ x = x + ffn_residual_func(x)
+ return x
+
+
+def drop_add_residual_stochastic_depth(
+ x: Tensor,
+ residual_func: Callable[[Tensor], Tensor],
+ sample_drop_ratio: float = 0.0,
+ pos=None,
+) -> Tensor:
+ # 1) extract subset using permutation
+ b, n, d = x.shape
+ sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
+ brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
+ x_subset = x[brange]
+
+ # 2) apply residual_func to get residual
+ if pos is not None:
+ # if necessary, apply rope to the subset
+ pos = pos[brange]
+ residual = residual_func(x_subset, pos=pos)
+ else:
+ residual = residual_func(x_subset)
+
+ x_flat = x.flatten(1)
+ residual = residual.flatten(1)
+
+ residual_scale_factor = b / sample_subset_size
+
+ # 3) add the residual
+ x_plus_residual = torch.index_add(
+ x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor
+ )
+ return x_plus_residual.view_as(x)
+
+
+def get_branges_scales(x, sample_drop_ratio=0.0):
+ b, n, d = x.shape
+ sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
+ brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
+ residual_scale_factor = b / sample_subset_size
+ return brange, residual_scale_factor
+
+
+def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
+ if scaling_vector is None:
+ x_flat = x.flatten(1)
+ residual = residual.flatten(1)
+ x_plus_residual = torch.index_add(
+ x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor
+ )
+ else:
+ pass
+ return x_plus_residual
+
+
+attn_bias_cache: Dict[Tuple, Any] = {}
+
+
+def get_attn_bias_and_cat(x_list, branges=None):
+ """
+ this will perform the index select, cat the tensors, and provide the attn_bias from cache
+ """
+ batch_sizes = (
+ [b.shape[0] for b in branges]
+ if branges is not None
+ else [x.shape[0] for x in x_list]
+ )
+ all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
+ if all_shapes not in attn_bias_cache.keys():
+ seqlens = []
+ for b, x in zip(batch_sizes, x_list):
+ for _ in range(b):
+ seqlens.append(x.shape[1])
+ # attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
+ # attn_bias._batch_sizes = batch_sizes
+ # attn_bias_cache[all_shapes] = attn_bias
+
+ if branges is not None:
+ pass
+ # cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(
+ # 1, -1, x_list[0].shape[-1]
+ # )
+ else:
+ tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
+ cat_tensors = torch.cat(tensors_bs1, dim=1)
+
+ return attn_bias_cache[all_shapes], cat_tensors
+
+
+def drop_add_residual_stochastic_depth_list(
+ x_list: List[Tensor],
+ residual_func: Callable[[Tensor, Any], Tensor],
+ sample_drop_ratio: float = 0.0,
+ scaling_vector=None,
+) -> Tensor:
+ # 1) generate random set of indices for dropping samples in the batch
+ branges_scales = [
+ get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list
+ ]
+ branges = [s[0] for s in branges_scales]
+ residual_scale_factors = [s[1] for s in branges_scales]
+
+ # 2) get attention bias and index+concat the tensors
+ attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)
+
+ # 3) apply residual_func to get residual, and split the result
+ residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore
+
+ outputs = []
+ for x, brange, residual, residual_scale_factor in zip(
+ x_list, branges, residual_list, residual_scale_factors
+ ):
+ outputs.append(
+ add_residual(
+ x, brange, residual, residual_scale_factor, scaling_vector
+ ).view_as(x)
+ )
+ return outputs
+
+
+class NestedTensorBlock(Block):
+ def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
+ """
+ x_list contains a list of tensors to nest together and run
+ """
+ # assert isinstance(self.attn, MemEffAttention)
+
+ if self.training and self.sample_drop_ratio > 0.0:
+
+ def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
+ return self.attn(self.norm1(x), attn_bias=attn_bias)
+
+ def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
+ return self.mlp(self.norm2(x))
+
+ x_list = drop_add_residual_stochastic_depth_list(
+ x_list,
+ residual_func=attn_residual_func,
+ sample_drop_ratio=self.sample_drop_ratio,
+ scaling_vector=self.ls1.gamma
+ if isinstance(self.ls1, LayerScale)
+ else None,
+ )
+ x_list = drop_add_residual_stochastic_depth_list(
+ x_list,
+ residual_func=ffn_residual_func,
+ sample_drop_ratio=self.sample_drop_ratio,
+ scaling_vector=self.ls2.gamma
+ if isinstance(self.ls1, LayerScale)
+ else None,
+ )
+ return x_list
+ else:
+
+ def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
+ return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))
+
+ def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
+ return self.ls2(self.mlp(self.norm2(x)))
+
+ attn_bias, x = get_attn_bias_and_cat(x_list)
+ x = x + attn_residual_func(x, attn_bias=attn_bias)
+ x = x + ffn_residual_func(x)
+ return attn_bias.split(x)
+
+ def forward(self, x_or_x_list):
+ if isinstance(x_or_x_list, Tensor):
+ return super().forward(x_or_x_list)
+ elif isinstance(x_or_x_list, list):
+ if not XFORMERS_AVAILABLE:
+ raise AssertionError("xFormers is required for using nested tensors")
+ return self.forward_nested(x_or_x_list)
+ else:
+ raise AssertionError
diff --git a/mapanything/models/external/vggt/layers/drop_path.py b/mapanything/models/external/vggt/layers/drop_path.py
new file mode 100644
index 0000000000000000000000000000000000000000..7de14f0f11cb348088cefbb0b976f50635b66bba
--- /dev/null
+++ b/mapanything/models/external/vggt/layers/drop_path.py
@@ -0,0 +1,36 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py
+
+
+from torch import nn
+
+
+def drop_path(x, drop_prob: float = 0.0, training: bool = False):
+ if drop_prob == 0.0 or not training:
+ return x
+ keep_prob = 1 - drop_prob
+ shape = (x.shape[0],) + (1,) * (
+ x.ndim - 1
+ ) # work with diff dim tensors, not just 2D ConvNets
+ random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
+ if keep_prob > 0.0:
+ random_tensor.div_(keep_prob)
+ output = x * random_tensor
+ return output
+
+
+class DropPath(nn.Module):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
+
+ def __init__(self, drop_prob=None):
+ super(DropPath, self).__init__()
+ self.drop_prob = drop_prob
+
+ def forward(self, x):
+ return drop_path(x, self.drop_prob, self.training)
diff --git a/mapanything/models/external/vggt/layers/layer_scale.py b/mapanything/models/external/vggt/layers/layer_scale.py
new file mode 100644
index 0000000000000000000000000000000000000000..e2592ca74e3972ec2ab62ecf6fabd4e2631353a6
--- /dev/null
+++ b/mapanything/models/external/vggt/layers/layer_scale.py
@@ -0,0 +1,26 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+# Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110
+
+from typing import Union
+
+import torch
+from torch import nn, Tensor
+
+
+class LayerScale(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ init_values: Union[float, Tensor] = 1e-5,
+ inplace: bool = False,
+ ) -> None:
+ super().__init__()
+ self.inplace = inplace
+ self.gamma = nn.Parameter(init_values * torch.ones(dim))
+
+ def forward(self, x: Tensor) -> Tensor:
+ return x.mul_(self.gamma) if self.inplace else x * self.gamma
diff --git a/mapanything/models/external/vggt/layers/mlp.py b/mapanything/models/external/vggt/layers/mlp.py
new file mode 100644
index 0000000000000000000000000000000000000000..66ef78b612cb3eb4ea2481b82b3ec02a5fbcaf49
--- /dev/null
+++ b/mapanything/models/external/vggt/layers/mlp.py
@@ -0,0 +1,40 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py
+
+
+from typing import Callable, Optional
+
+from torch import nn, Tensor
+
+
+class Mlp(nn.Module):
+ def __init__(
+ self,
+ in_features: int,
+ hidden_features: Optional[int] = None,
+ out_features: Optional[int] = None,
+ act_layer: Callable[..., nn.Module] = nn.GELU,
+ drop: float = 0.0,
+ bias: bool = True,
+ ) -> None:
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
+ self.act = act_layer()
+ self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
+ self.drop = nn.Dropout(drop)
+
+ def forward(self, x: Tensor) -> Tensor:
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.drop(x)
+ x = self.fc2(x)
+ x = self.drop(x)
+ return x
diff --git a/mapanything/models/external/vggt/layers/patch_embed.py b/mapanything/models/external/vggt/layers/patch_embed.py
new file mode 100644
index 0000000000000000000000000000000000000000..90d8b1b87df8c719a60a654bccdf022d31e6fd80
--- /dev/null
+++ b/mapanything/models/external/vggt/layers/patch_embed.py
@@ -0,0 +1,100 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
+
+from typing import Callable, Optional, Tuple, Union
+
+import torch.nn as nn
+from torch import Tensor
+
+
+def make_2tuple(x):
+ if isinstance(x, tuple):
+ assert len(x) == 2
+ return x
+
+ assert isinstance(x, int)
+ return (x, x)
+
+
+class PatchEmbed(nn.Module):
+ """
+ 2D image to patch embedding: (B,C,H,W) -> (B,N,D)
+
+ Args:
+ img_size: Image size.
+ patch_size: Patch token size.
+ in_chans: Number of input image channels.
+ embed_dim: Number of linear projection output channels.
+ norm_layer: Normalization layer.
+ """
+
+ def __init__(
+ self,
+ img_size: Union[int, Tuple[int, int]] = 224,
+ patch_size: Union[int, Tuple[int, int]] = 16,
+ in_chans: int = 3,
+ embed_dim: int = 768,
+ norm_layer: Optional[Callable] = None,
+ flatten_embedding: bool = True,
+ ) -> None:
+ super().__init__()
+
+ image_HW = make_2tuple(img_size)
+ patch_HW = make_2tuple(patch_size)
+ patch_grid_size = (
+ image_HW[0] // patch_HW[0],
+ image_HW[1] // patch_HW[1],
+ )
+
+ self.img_size = image_HW
+ self.patch_size = patch_HW
+ self.patches_resolution = patch_grid_size
+ self.num_patches = patch_grid_size[0] * patch_grid_size[1]
+
+ self.in_chans = in_chans
+ self.embed_dim = embed_dim
+
+ self.flatten_embedding = flatten_embedding
+
+ self.proj = nn.Conv2d(
+ in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW
+ )
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
+
+ def forward(self, x: Tensor) -> Tensor:
+ _, _, H, W = x.shape
+ patch_H, patch_W = self.patch_size
+
+ assert H % patch_H == 0, (
+ f"Input image height {H} is not a multiple of patch height {patch_H}"
+ )
+ assert W % patch_W == 0, (
+ f"Input image width {W} is not a multiple of patch width: {patch_W}"
+ )
+
+ x = self.proj(x) # B C H W
+ H, W = x.size(2), x.size(3)
+ x = x.flatten(2).transpose(1, 2) # B HW C
+ x = self.norm(x)
+ if not self.flatten_embedding:
+ x = x.reshape(-1, H, W, self.embed_dim) # B H W C
+ return x
+
+ def flops(self) -> float:
+ Ho, Wo = self.patches_resolution
+ flops = (
+ Ho
+ * Wo
+ * self.embed_dim
+ * self.in_chans
+ * (self.patch_size[0] * self.patch_size[1])
+ )
+ if self.norm is not None:
+ flops += Ho * Wo * self.embed_dim
+ return flops
diff --git a/mapanything/models/external/vggt/layers/rope.py b/mapanything/models/external/vggt/layers/rope.py
new file mode 100644
index 0000000000000000000000000000000000000000..178132f0be1e501c2038874c2dfdaf2a830be258
--- /dev/null
+++ b/mapanything/models/external/vggt/layers/rope.py
@@ -0,0 +1,206 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+
+# Implementation of 2D Rotary Position Embeddings (RoPE).
+
+# This module provides a clean implementation of 2D Rotary Position Embeddings,
+# which extends the original RoPE concept to handle 2D spatial positions.
+
+# Inspired by:
+# https://github.com/meta-llama/codellama/blob/main/llama/model.py
+# https://github.com/naver-ai/rope-vit
+
+
+from typing import Dict, Tuple
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class PositionGetter:
+ """Generates and caches 2D spatial positions for patches in a grid.
+
+ This class efficiently manages the generation of spatial coordinates for patches
+ in a 2D grid, caching results to avoid redundant computations.
+
+ Attributes:
+ position_cache: Dictionary storing precomputed position tensors for different
+ grid dimensions.
+ """
+
+ def __init__(self):
+ """Initializes the position generator with an empty cache."""
+ self.position_cache: Dict[Tuple[int, int], torch.Tensor] = {}
+
+ def __call__(
+ self, batch_size: int, height: int, width: int, device: torch.device
+ ) -> torch.Tensor:
+ """Generates spatial positions for a batch of patches.
+
+ Args:
+ batch_size: Number of samples in the batch.
+ height: Height of the grid in patches.
+ width: Width of the grid in patches.
+ device: Target device for the position tensor.
+
+ Returns:
+ Tensor of shape (batch_size, height*width, 2) containing y,x coordinates
+ for each position in the grid, repeated for each batch item.
+ """
+ if (height, width) not in self.position_cache:
+ y_coords = torch.arange(height, device=device)
+ x_coords = torch.arange(width, device=device)
+ positions = torch.cartesian_prod(y_coords, x_coords)
+ self.position_cache[height, width] = positions
+
+ cached_positions = self.position_cache[height, width]
+ return (
+ cached_positions.view(1, height * width, 2)
+ .expand(batch_size, -1, -1)
+ .clone()
+ )
+
+
+class RotaryPositionEmbedding2D(nn.Module):
+ """2D Rotary Position Embedding implementation.
+
+ This module applies rotary position embeddings to input tokens based on their
+ 2D spatial positions. It handles the position-dependent rotation of features
+ separately for vertical and horizontal dimensions.
+
+ Args:
+ frequency: Base frequency for the position embeddings. Default: 100.0
+ scaling_factor: Scaling factor for frequency computation. Default: 1.0
+
+ Attributes:
+ base_frequency: Base frequency for computing position embeddings.
+ scaling_factor: Factor to scale the computed frequencies.
+ frequency_cache: Cache for storing precomputed frequency components.
+ """
+
+ def __init__(self, frequency: float = 100.0, scaling_factor: float = 1.0):
+ """Initializes the 2D RoPE module."""
+ super().__init__()
+ self.base_frequency = frequency
+ self.scaling_factor = scaling_factor
+ self.frequency_cache: Dict[Tuple, Tuple[torch.Tensor, torch.Tensor]] = {}
+
+ def _compute_frequency_components(
+ self, dim: int, seq_len: int, device: torch.device, dtype: torch.dtype
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Computes frequency components for rotary embeddings.
+
+ Args:
+ dim: Feature dimension (must be even).
+ seq_len: Maximum sequence length.
+ device: Target device for computations.
+ dtype: Data type for the computed tensors.
+
+ Returns:
+ Tuple of (cosine, sine) tensors for frequency components.
+ """
+ cache_key = (dim, seq_len, device, dtype)
+ if cache_key not in self.frequency_cache:
+ # Compute frequency bands
+ exponents = torch.arange(0, dim, 2, device=device).float() / dim
+ inv_freq = 1.0 / (self.base_frequency**exponents)
+
+ # Generate position-dependent frequencies
+ positions = torch.arange(seq_len, device=device, dtype=inv_freq.dtype)
+ angles = torch.einsum("i,j->ij", positions, inv_freq)
+
+ # Compute and cache frequency components
+ angles = angles.to(dtype)
+ angles = torch.cat((angles, angles), dim=-1)
+ cos_components = angles.cos().to(dtype)
+ sin_components = angles.sin().to(dtype)
+ self.frequency_cache[cache_key] = (cos_components, sin_components)
+
+ return self.frequency_cache[cache_key]
+
+ @staticmethod
+ def _rotate_features(x: torch.Tensor) -> torch.Tensor:
+ """Performs feature rotation by splitting and recombining feature dimensions.
+
+ Args:
+ x: Input tensor to rotate.
+
+ Returns:
+ Rotated feature tensor.
+ """
+ feature_dim = x.shape[-1]
+ x1, x2 = x[..., : feature_dim // 2], x[..., feature_dim // 2 :]
+ return torch.cat((-x2, x1), dim=-1)
+
+ def _apply_1d_rope(
+ self,
+ tokens: torch.Tensor,
+ positions: torch.Tensor,
+ cos_comp: torch.Tensor,
+ sin_comp: torch.Tensor,
+ ) -> torch.Tensor:
+ """Applies 1D rotary position embeddings along one dimension.
+
+ Args:
+ tokens: Input token features.
+ positions: Position indices.
+ cos_comp: Cosine components for rotation.
+ sin_comp: Sine components for rotation.
+
+ Returns:
+ Tokens with applied rotary position embeddings.
+ """
+ # Embed positions with frequency components
+ cos = F.embedding(positions, cos_comp)[:, None, :, :]
+ sin = F.embedding(positions, sin_comp)[:, None, :, :]
+
+ # Apply rotation
+ return (tokens * cos) + (self._rotate_features(tokens) * sin)
+
+ def forward(self, tokens: torch.Tensor, positions: torch.Tensor) -> torch.Tensor:
+ """Applies 2D rotary position embeddings to input tokens.
+
+ Args:
+ tokens: Input tensor of shape (batch_size, n_heads, n_tokens, dim).
+ The feature dimension (dim) must be divisible by 4.
+ positions: Position tensor of shape (batch_size, n_tokens, 2) containing
+ the y and x coordinates for each token.
+
+ Returns:
+ Tensor of same shape as input with applied 2D rotary position embeddings.
+
+ Raises:
+ AssertionError: If input dimensions are invalid or positions are malformed.
+ """
+ # Validate inputs
+ assert tokens.size(-1) % 2 == 0, "Feature dimension must be even"
+ assert positions.ndim == 3 and positions.shape[-1] == 2, (
+ "Positions must have shape (batch_size, n_tokens, 2)"
+ )
+
+ # Compute feature dimension for each spatial direction
+ feature_dim = tokens.size(-1) // 2
+
+ # Get frequency components
+ max_position = int(positions.max()) + 1
+ cos_comp, sin_comp = self._compute_frequency_components(
+ feature_dim, max_position, tokens.device, tokens.dtype
+ )
+
+ # Split features for vertical and horizontal processing
+ vertical_features, horizontal_features = tokens.chunk(2, dim=-1)
+
+ # Apply RoPE separately for each dimension
+ vertical_features = self._apply_1d_rope(
+ vertical_features, positions[..., 0], cos_comp, sin_comp
+ )
+ horizontal_features = self._apply_1d_rope(
+ horizontal_features, positions[..., 1], cos_comp, sin_comp
+ )
+
+ # Combine processed features
+ return torch.cat((vertical_features, horizontal_features), dim=-1)
diff --git a/mapanything/models/external/vggt/layers/swiglu_ffn.py b/mapanything/models/external/vggt/layers/swiglu_ffn.py
new file mode 100644
index 0000000000000000000000000000000000000000..d9f5a29548fd92188eae213ae41f79e0757f4842
--- /dev/null
+++ b/mapanything/models/external/vggt/layers/swiglu_ffn.py
@@ -0,0 +1,71 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+import os
+from typing import Callable, Optional
+
+import torch.nn.functional as F
+from torch import nn, Tensor
+
+
+class SwiGLUFFN(nn.Module):
+ def __init__(
+ self,
+ in_features: int,
+ hidden_features: Optional[int] = None,
+ out_features: Optional[int] = None,
+ act_layer: Callable[..., nn.Module] = None,
+ drop: float = 0.0,
+ bias: bool = True,
+ ) -> None:
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
+ self.w3 = nn.Linear(hidden_features, out_features, bias=bias)
+
+ def forward(self, x: Tensor) -> Tensor:
+ x12 = self.w12(x)
+ x1, x2 = x12.chunk(2, dim=-1)
+ hidden = F.silu(x1) * x2
+ return self.w3(hidden)
+
+
+XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
+# try:
+# if XFORMERS_ENABLED:
+# from xformers.ops import SwiGLU
+
+# XFORMERS_AVAILABLE = True
+# warnings.warn("xFormers is available (SwiGLU)")
+# else:
+# warnings.warn("xFormers is disabled (SwiGLU)")
+# raise ImportError
+# except ImportError:
+SwiGLU = SwiGLUFFN
+XFORMERS_AVAILABLE = False
+
+# warnings.warn("xFormers is not available (SwiGLU)")
+
+
+class SwiGLUFFNFused(SwiGLU):
+ def __init__(
+ self,
+ in_features: int,
+ hidden_features: Optional[int] = None,
+ out_features: Optional[int] = None,
+ act_layer: Callable[..., nn.Module] = None,
+ drop: float = 0.0,
+ bias: bool = True,
+ ) -> None:
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
+ super().__init__(
+ in_features=in_features,
+ hidden_features=hidden_features,
+ out_features=out_features,
+ bias=bias,
+ )
diff --git a/mapanything/models/external/vggt/layers/vision_transformer.py b/mapanything/models/external/vggt/layers/vision_transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..5e921360e94c6dc351cc635810e3cdfcaf525ac4
--- /dev/null
+++ b/mapanything/models/external/vggt/layers/vision_transformer.py
@@ -0,0 +1,454 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
+
+import logging
+import math
+from functools import partial
+from typing import Callable, Sequence, Tuple, Union
+
+import torch
+import torch.nn as nn
+from torch.nn.init import trunc_normal_
+from torch.utils.checkpoint import checkpoint
+
+from . import (
+ MemEffAttention,
+ Mlp,
+ NestedTensorBlock as Block,
+ PatchEmbed,
+ SwiGLUFFNFused,
+)
+
+logger = logging.getLogger("dinov2")
+
+
+def named_apply(
+ fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False
+) -> nn.Module:
+ if not depth_first and include_root:
+ fn(module=module, name=name)
+ for child_name, child_module in module.named_children():
+ child_name = ".".join((name, child_name)) if name else child_name
+ named_apply(
+ fn=fn,
+ module=child_module,
+ name=child_name,
+ depth_first=depth_first,
+ include_root=True,
+ )
+ if depth_first and include_root:
+ fn(module=module, name=name)
+ return module
+
+
+class BlockChunk(nn.ModuleList):
+ def forward(self, x):
+ for b in self:
+ x = b(x)
+ return x
+
+
+class DinoVisionTransformer(nn.Module):
+ def __init__(
+ self,
+ img_size=224,
+ patch_size=16,
+ in_chans=3,
+ embed_dim=768,
+ depth=12,
+ num_heads=12,
+ mlp_ratio=4.0,
+ qkv_bias=True,
+ ffn_bias=True,
+ proj_bias=True,
+ drop_path_rate=0.0,
+ drop_path_uniform=False,
+ init_values=None, # for layerscale: None or 0 => no layerscale
+ embed_layer=PatchEmbed,
+ act_layer=nn.GELU,
+ block_fn=Block,
+ ffn_layer="mlp",
+ block_chunks=1,
+ num_register_tokens=0,
+ interpolate_antialias=False,
+ interpolate_offset=0.1,
+ qk_norm=False,
+ ):
+ """
+ Args:
+ img_size (int, tuple): input image size
+ patch_size (int, tuple): patch size
+ in_chans (int): number of input channels
+ embed_dim (int): embedding dimension
+ depth (int): depth of transformer
+ num_heads (int): number of attention heads
+ mlp_ratio (int): ratio of mlp hidden dim to embedding dim
+ qkv_bias (bool): enable bias for qkv if True
+ proj_bias (bool): enable bias for proj in attn if True
+ ffn_bias (bool): enable bias for ffn if True
+ drop_path_rate (float): stochastic depth rate
+ drop_path_uniform (bool): apply uniform drop rate across blocks
+ weight_init (str): weight init scheme
+ init_values (float): layer-scale init values
+ embed_layer (nn.Module): patch embedding layer
+ act_layer (nn.Module): MLP activation layer
+ block_fn (nn.Module): transformer block class
+ ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
+ block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
+ num_register_tokens: (int) number of extra cls tokens (so-called "registers")
+ interpolate_antialias: (str) flag to apply anti-aliasing when interpolating positional embeddings
+ interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings
+ """
+ super().__init__()
+ norm_layer = partial(nn.LayerNorm, eps=1e-6)
+
+ # tricky but makes it work
+ self.use_checkpoint = False
+ #
+
+ self.num_features = self.embed_dim = (
+ embed_dim # num_features for consistency with other models
+ )
+ self.num_tokens = 1
+ self.n_blocks = depth
+ self.num_heads = num_heads
+ self.patch_size = patch_size
+ self.num_register_tokens = num_register_tokens
+ self.interpolate_antialias = interpolate_antialias
+ self.interpolate_offset = interpolate_offset
+
+ self.patch_embed = embed_layer(
+ img_size=img_size,
+ patch_size=patch_size,
+ in_chans=in_chans,
+ embed_dim=embed_dim,
+ )
+ num_patches = self.patch_embed.num_patches
+
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+ self.pos_embed = nn.Parameter(
+ torch.zeros(1, num_patches + self.num_tokens, embed_dim)
+ )
+ assert num_register_tokens >= 0
+ self.register_tokens = (
+ nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim))
+ if num_register_tokens
+ else None
+ )
+
+ if drop_path_uniform is True:
+ dpr = [drop_path_rate] * depth
+ else:
+ dpr = [
+ x.item() for x in torch.linspace(0, drop_path_rate, depth)
+ ] # stochastic depth decay rule
+
+ if ffn_layer == "mlp":
+ logger.info("using MLP layer as FFN")
+ ffn_layer = Mlp
+ elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
+ logger.info("using SwiGLU layer as FFN")
+ ffn_layer = SwiGLUFFNFused
+ elif ffn_layer == "identity":
+ logger.info("using Identity layer as FFN")
+
+ def f(*args, **kwargs):
+ return nn.Identity()
+
+ ffn_layer = f
+ else:
+ raise NotImplementedError
+
+ blocks_list = [
+ block_fn(
+ dim=embed_dim,
+ num_heads=num_heads,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ proj_bias=proj_bias,
+ ffn_bias=ffn_bias,
+ drop_path=dpr[i],
+ norm_layer=norm_layer,
+ act_layer=act_layer,
+ ffn_layer=ffn_layer,
+ init_values=init_values,
+ qk_norm=qk_norm,
+ )
+ for i in range(depth)
+ ]
+ if block_chunks > 0:
+ self.chunked_blocks = True
+ chunked_blocks = []
+ chunksize = depth // block_chunks
+ for i in range(0, depth, chunksize):
+ # this is to keep the block index consistent if we chunk the block list
+ chunked_blocks.append(
+ [nn.Identity()] * i + blocks_list[i : i + chunksize]
+ )
+ self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
+ else:
+ self.chunked_blocks = False
+ self.blocks = nn.ModuleList(blocks_list)
+
+ self.norm = norm_layer(embed_dim)
+ self.head = nn.Identity()
+
+ self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))
+
+ self.init_weights()
+
+ def init_weights(self):
+ trunc_normal_(self.pos_embed, std=0.02)
+ nn.init.normal_(self.cls_token, std=1e-6)
+ if self.register_tokens is not None:
+ nn.init.normal_(self.register_tokens, std=1e-6)
+ named_apply(init_weights_vit_timm, self)
+
+ def interpolate_pos_encoding(self, x, w, h):
+ previous_dtype = x.dtype
+ npatch = x.shape[1] - 1
+ N = self.pos_embed.shape[1] - 1
+ if npatch == N and w == h:
+ return self.pos_embed
+ pos_embed = self.pos_embed.float()
+ class_pos_embed = pos_embed[:, 0]
+ patch_pos_embed = pos_embed[:, 1:]
+ dim = x.shape[-1]
+ w0 = w // self.patch_size
+ h0 = h // self.patch_size
+ M = int(math.sqrt(N)) # Recover the number of patches in each dimension
+ assert N == M * M
+ kwargs = {}
+ if self.interpolate_offset:
+ # Historical kludge: add a small number to avoid floating point error in the interpolation, see https://github.com/facebookresearch/dino/issues/8
+ # Note: still needed for backward-compatibility, the underlying operators are using both output size and scale factors
+ sx = float(w0 + self.interpolate_offset) / M
+ sy = float(h0 + self.interpolate_offset) / M
+ kwargs["scale_factor"] = (sx, sy)
+ else:
+ # Simply specify an output size instead of a scale factor
+ kwargs["size"] = (w0, h0)
+ patch_pos_embed = nn.functional.interpolate(
+ patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2),
+ mode="bicubic",
+ antialias=self.interpolate_antialias,
+ **kwargs,
+ )
+ assert (w0, h0) == patch_pos_embed.shape[-2:]
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
+ return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(
+ previous_dtype
+ )
+
+ def prepare_tokens_with_masks(self, x, masks=None):
+ B, nc, w, h = x.shape
+ x = self.patch_embed(x)
+ if masks is not None:
+ x = torch.where(
+ masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x
+ )
+
+ x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
+ x = x + self.interpolate_pos_encoding(x, w, h)
+
+ if self.register_tokens is not None:
+ x = torch.cat(
+ (
+ x[:, :1],
+ self.register_tokens.expand(x.shape[0], -1, -1),
+ x[:, 1:],
+ ),
+ dim=1,
+ )
+
+ return x
+
+ def forward_features_list(self, x_list, masks_list):
+ x = [
+ self.prepare_tokens_with_masks(x, masks)
+ for x, masks in zip(x_list, masks_list)
+ ]
+
+ for blk in self.blocks:
+ if self.use_checkpoint:
+ x = checkpoint(blk, x, use_reentrant=self.use_reentrant)
+ else:
+ x = blk(x)
+
+ all_x = x
+ output = []
+ for x, masks in zip(all_x, masks_list):
+ x_norm = self.norm(x)
+ output.append(
+ {
+ "x_norm_clstoken": x_norm[:, 0],
+ "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
+ "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
+ "x_prenorm": x,
+ "masks": masks,
+ }
+ )
+ return output
+
+ def forward_features(self, x, masks=None):
+ if isinstance(x, list):
+ return self.forward_features_list(x, masks)
+
+ x = self.prepare_tokens_with_masks(x, masks)
+
+ for blk in self.blocks:
+ if self.use_checkpoint:
+ x = checkpoint(blk, x, use_reentrant=self.use_reentrant)
+ else:
+ x = blk(x)
+
+ x_norm = self.norm(x)
+ return {
+ "x_norm_clstoken": x_norm[:, 0],
+ "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
+ "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
+ "x_prenorm": x,
+ "masks": masks,
+ }
+
+ def _get_intermediate_layers_not_chunked(self, x, n=1):
+ x = self.prepare_tokens_with_masks(x)
+ # If n is an int, take the n last blocks. If it's a list, take them
+ output, total_block_len = [], len(self.blocks)
+ blocks_to_take = (
+ range(total_block_len - n, total_block_len) if isinstance(n, int) else n
+ )
+ for i, blk in enumerate(self.blocks):
+ x = blk(x)
+ if i in blocks_to_take:
+ output.append(x)
+ assert len(output) == len(blocks_to_take), (
+ f"only {len(output)} / {len(blocks_to_take)} blocks found"
+ )
+ return output
+
+ def _get_intermediate_layers_chunked(self, x, n=1):
+ x = self.prepare_tokens_with_masks(x)
+ output, i, total_block_len = [], 0, len(self.blocks[-1])
+ # If n is an int, take the n last blocks. If it's a list, take them
+ blocks_to_take = (
+ range(total_block_len - n, total_block_len) if isinstance(n, int) else n
+ )
+ for block_chunk in self.blocks:
+ for blk in block_chunk[i:]: # Passing the nn.Identity()
+ x = blk(x)
+ if i in blocks_to_take:
+ output.append(x)
+ i += 1
+ assert len(output) == len(blocks_to_take), (
+ f"only {len(output)} / {len(blocks_to_take)} blocks found"
+ )
+ return output
+
+ def get_intermediate_layers(
+ self,
+ x: torch.Tensor,
+ n: Union[int, Sequence] = 1, # Layers or n last layers to take
+ reshape: bool = False,
+ return_class_token: bool = False,
+ norm=True,
+ ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
+ if self.chunked_blocks:
+ outputs = self._get_intermediate_layers_chunked(x, n)
+ else:
+ outputs = self._get_intermediate_layers_not_chunked(x, n)
+ if norm:
+ outputs = [self.norm(out) for out in outputs]
+ class_tokens = [out[:, 0] for out in outputs]
+ outputs = [out[:, 1 + self.num_register_tokens :] for out in outputs]
+ if reshape:
+ B, _, w, h = x.shape
+ outputs = [
+ out.reshape(B, w // self.patch_size, h // self.patch_size, -1)
+ .permute(0, 3, 1, 2)
+ .contiguous()
+ for out in outputs
+ ]
+ if return_class_token:
+ return tuple(zip(outputs, class_tokens))
+ return tuple(outputs)
+
+ def forward(self, *args, is_training=True, **kwargs):
+ ret = self.forward_features(*args, **kwargs)
+ if is_training:
+ return ret
+ else:
+ return self.head(ret["x_norm_clstoken"])
+
+
+def init_weights_vit_timm(module: nn.Module, name: str = ""):
+ """ViT weight initialization, original timm impl (for reproducibility)"""
+ if isinstance(module, nn.Linear):
+ trunc_normal_(module.weight, std=0.02)
+ if module.bias is not None:
+ nn.init.zeros_(module.bias)
+
+
+def vit_small(patch_size=16, num_register_tokens=0, **kwargs):
+ model = DinoVisionTransformer(
+ patch_size=patch_size,
+ embed_dim=384,
+ depth=12,
+ num_heads=6,
+ mlp_ratio=4,
+ block_fn=partial(Block, attn_class=MemEffAttention),
+ num_register_tokens=num_register_tokens,
+ **kwargs,
+ )
+ return model
+
+
+def vit_base(patch_size=16, num_register_tokens=0, **kwargs):
+ model = DinoVisionTransformer(
+ patch_size=patch_size,
+ embed_dim=768,
+ depth=12,
+ num_heads=12,
+ mlp_ratio=4,
+ block_fn=partial(Block, attn_class=MemEffAttention),
+ num_register_tokens=num_register_tokens,
+ **kwargs,
+ )
+ return model
+
+
+def vit_large(patch_size=16, num_register_tokens=0, **kwargs):
+ model = DinoVisionTransformer(
+ patch_size=patch_size,
+ embed_dim=1024,
+ depth=24,
+ num_heads=16,
+ mlp_ratio=4,
+ block_fn=partial(Block, attn_class=MemEffAttention),
+ num_register_tokens=num_register_tokens,
+ **kwargs,
+ )
+ return model
+
+
+def vit_giant2(patch_size=16, num_register_tokens=0, **kwargs):
+ """
+ Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64
+ """
+ model = DinoVisionTransformer(
+ patch_size=patch_size,
+ embed_dim=1536,
+ depth=40,
+ num_heads=24,
+ mlp_ratio=4,
+ block_fn=partial(Block, attn_class=MemEffAttention),
+ num_register_tokens=num_register_tokens,
+ **kwargs,
+ )
+ return model
diff --git a/mapanything/models/external/vggt/models/__init__.py b/mapanything/models/external/vggt/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/mapanything/models/external/vggt/models/aggregator.py b/mapanything/models/external/vggt/models/aggregator.py
new file mode 100644
index 0000000000000000000000000000000000000000..c86275f05c25cc3a1623bf8b10acd6aae94562d3
--- /dev/null
+++ b/mapanything/models/external/vggt/models/aggregator.py
@@ -0,0 +1,385 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import logging
+from typing import List, Tuple
+
+import torch
+import torch.nn as nn
+from torch.utils.checkpoint import checkpoint
+
+from mapanything.models.external.vggt.layers import PatchEmbed
+from mapanything.models.external.vggt.layers.block import Block
+from mapanything.models.external.vggt.layers.rope import (
+ PositionGetter,
+ RotaryPositionEmbedding2D,
+)
+
+logger = logging.getLogger(__name__)
+
+_RESNET_MEAN = [0.485, 0.456, 0.406]
+_RESNET_STD = [0.229, 0.224, 0.225]
+
+
+class Aggregator(nn.Module):
+ """
+ The Aggregator applies alternating-attention over input frames,
+ as described in VGGT: Visual Geometry Grounded Transformer.
+
+
+ Args:
+ img_size (int): Image size in pixels.
+ patch_size (int): Size of each patch for PatchEmbed.
+ embed_dim (int): Dimension of the token embeddings.
+ depth (int): Number of blocks.
+ num_heads (int): Number of attention heads.
+ mlp_ratio (float): Ratio of MLP hidden dim to embedding dim.
+ num_register_tokens (int): Number of register tokens.
+ block_fn (nn.Module): The block type used for attention (Block by default).
+ qkv_bias (bool): Whether to include bias in QKV projections.
+ proj_bias (bool): Whether to include bias in the output projection.
+ ffn_bias (bool): Whether to include bias in MLP layers.
+ patch_embed (str): Type of patch embed. e.g., "conv" or "dinov2_vitl14_reg".
+ aa_order (list[str]): The order of alternating attention, e.g. ["frame", "global"].
+ aa_block_size (int): How many blocks to group under each attention type before switching. If not necessary, set to 1.
+ qk_norm (bool): Whether to apply QK normalization.
+ rope_freq (int): Base frequency for rotary embedding. -1 to disable.
+ init_values (float): Init scale for layer scale.
+ """
+
+ def __init__(
+ self,
+ img_size=518,
+ patch_size=14,
+ embed_dim=1024,
+ depth=24,
+ num_heads=16,
+ mlp_ratio=4.0,
+ num_register_tokens=4,
+ block_fn=Block,
+ qkv_bias=True,
+ proj_bias=True,
+ ffn_bias=True,
+ patch_embed="dinov2_vitl14_reg",
+ aa_order=["frame", "global"],
+ aa_block_size=1,
+ qk_norm=True,
+ rope_freq=100,
+ init_values=0.01,
+ ):
+ super().__init__()
+
+ self.__build_patch_embed__(
+ patch_embed, img_size, patch_size, num_register_tokens, embed_dim=embed_dim
+ )
+
+ # Initialize rotary position embedding if frequency > 0
+ self.rope = (
+ RotaryPositionEmbedding2D(frequency=rope_freq) if rope_freq > 0 else None
+ )
+ self.position_getter = PositionGetter() if self.rope is not None else None
+
+ self.frame_blocks = nn.ModuleList(
+ [
+ block_fn(
+ dim=embed_dim,
+ num_heads=num_heads,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ proj_bias=proj_bias,
+ ffn_bias=ffn_bias,
+ init_values=init_values,
+ qk_norm=qk_norm,
+ rope=self.rope,
+ )
+ for _ in range(depth)
+ ]
+ )
+
+ self.global_blocks = nn.ModuleList(
+ [
+ block_fn(
+ dim=embed_dim,
+ num_heads=num_heads,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ proj_bias=proj_bias,
+ ffn_bias=ffn_bias,
+ init_values=init_values,
+ qk_norm=qk_norm,
+ rope=self.rope,
+ )
+ for _ in range(depth)
+ ]
+ )
+
+ self.depth = depth
+ self.aa_order = aa_order
+ self.patch_size = patch_size
+ self.aa_block_size = aa_block_size
+
+ # Validate that depth is divisible by aa_block_size
+ if self.depth % self.aa_block_size != 0:
+ raise ValueError(
+ f"depth ({depth}) must be divisible by aa_block_size ({aa_block_size})"
+ )
+
+ self.aa_block_num = self.depth // self.aa_block_size
+
+ # Note: We have two camera tokens, one for the first frame and one for the rest
+ # The same applies for register tokens
+ self.camera_token = nn.Parameter(torch.randn(1, 2, 1, embed_dim))
+ self.register_token = nn.Parameter(
+ torch.randn(1, 2, num_register_tokens, embed_dim)
+ )
+
+ # The patch tokens start after the camera and register tokens
+ self.patch_start_idx = 1 + num_register_tokens
+
+ # Initialize parameters with small values
+ nn.init.normal_(self.camera_token, std=1e-6)
+ nn.init.normal_(self.register_token, std=1e-6)
+
+ # Register normalization constants as buffers
+ for name, value in (
+ ("_resnet_mean", _RESNET_MEAN),
+ ("_resnet_std", _RESNET_STD),
+ ):
+ self.register_buffer(
+ name,
+ torch.FloatTensor(value).view(1, 1, 3, 1, 1),
+ persistent=False,
+ )
+
+ def __build_patch_embed__(
+ self,
+ patch_embed,
+ img_size,
+ patch_size,
+ num_register_tokens,
+ interpolate_antialias=True,
+ interpolate_offset=0.0,
+ block_chunks=0,
+ init_values=1.0,
+ embed_dim=1024,
+ ):
+ """
+ Build the patch embed layer. If 'conv', we use a
+ simple PatchEmbed conv layer. Otherwise, we use a vision transformer.
+ """
+
+ if "conv" in patch_embed:
+ self.patch_embed = PatchEmbed(
+ img_size=img_size,
+ patch_size=patch_size,
+ in_chans=3,
+ embed_dim=embed_dim,
+ )
+ else:
+ ### From original VGGT codebase: Doesn't load pre-trained DINOv2 weights
+ # vit_models = {
+ # "dinov2_vitl14_reg": vit_large,
+ # "dinov2_vitb14_reg": vit_base,
+ # "dinov2_vits14_reg": vit_small,
+ # "dinov2_vitg2_reg": vit_giant2,
+ # }
+
+ # self.patch_embed = vit_models[patch_embed](
+ # img_size=img_size,
+ # patch_size=patch_size,
+ # num_register_tokens=num_register_tokens,
+ # interpolate_antialias=interpolate_antialias,
+ # interpolate_offset=interpolate_offset,
+ # block_chunks=block_chunks,
+ # init_values=init_values,
+ # )
+
+ ### Use pre-trained DINOv2 with gradient checkpointing
+ self.patch_embed = torch.hub.load("facebookresearch/dinov2", patch_embed)
+ for i in range(len(self.patch_embed.blocks)):
+ self.patch_embed.blocks[i] = (
+ self.wrap_module_with_gradient_checkpointing(
+ self.patch_embed.blocks[i]
+ )
+ )
+
+ # Disable gradient updates for mask token
+ if hasattr(self.patch_embed, "mask_token"):
+ self.patch_embed.mask_token.requires_grad_(False)
+
+ ### Gradient Checkpointing Wrapper from UniCeption:
+ def wrap_module_with_gradient_checkpointing(self, module: nn.Module):
+ """
+ Wrapper for Gradient Checkpointing
+ References: https://github.com/microsoft/MoGe
+ """
+
+ class _CheckpointingWrapper(module.__class__):
+ _restore_cls = module.__class__
+
+ def forward(self, *args, **kwargs):
+ return checkpoint(super().forward, *args, use_reentrant=False, **kwargs)
+
+ module.__class__ = _CheckpointingWrapper
+ return module
+
+ def forward(
+ self,
+ images: torch.Tensor,
+ ) -> Tuple[List[torch.Tensor], int]:
+ """
+ Args:
+ images (torch.Tensor): Input images with shape [B, S, 3, H, W], in range [0, 1].
+ B: batch size, S: sequence length, 3: RGB channels, H: height, W: width
+
+ Returns:
+ (list[torch.Tensor], int):
+ The list of outputs from the attention blocks,
+ and the patch_start_idx indicating where patch tokens begin.
+ """
+ B, S, C_in, H, W = images.shape
+
+ if C_in != 3:
+ raise ValueError(f"Expected 3 input channels, got {C_in}")
+
+ # Normalize images and reshape for patch embed
+ images = (images - self._resnet_mean) / self._resnet_std
+
+ # Reshape to [B*S, C, H, W] for patch embedding
+ images = images.view(B * S, C_in, H, W)
+ patch_tokens = self.patch_embed.forward_features(images)
+
+ if isinstance(patch_tokens, dict):
+ patch_tokens = patch_tokens["x_norm_patchtokens"]
+
+ _, P, C = patch_tokens.shape
+
+ # Expand camera and register tokens to match batch size and sequence length
+ camera_token = slice_expand_and_flatten(self.camera_token, B, S)
+ register_token = slice_expand_and_flatten(self.register_token, B, S)
+
+ # Concatenate special tokens with patch tokens
+ tokens = torch.cat([camera_token, register_token, patch_tokens], dim=1)
+
+ pos = None
+ if self.rope is not None:
+ pos = self.position_getter(
+ B * S, H // self.patch_size, W // self.patch_size, device=images.device
+ )
+
+ if self.patch_start_idx > 0:
+ # do not use position embedding for special tokens (camera and register tokens)
+ # so set pos to 0 for the special tokens
+ pos = pos + 1
+ pos_special = (
+ torch.zeros(B * S, self.patch_start_idx, 2)
+ .to(images.device)
+ .to(pos.dtype)
+ )
+ pos = torch.cat([pos_special, pos], dim=1)
+
+ # update P because we added special tokens
+ _, P, C = tokens.shape
+
+ frame_idx = 0
+ global_idx = 0
+ output_list = []
+
+ for _ in range(self.aa_block_num):
+ for attn_type in self.aa_order:
+ if attn_type == "frame":
+ tokens, frame_idx, frame_intermediates = (
+ self._process_frame_attention(
+ tokens, B, S, P, C, frame_idx, pos=pos
+ )
+ )
+ elif attn_type == "global":
+ tokens, global_idx, global_intermediates = (
+ self._process_global_attention(
+ tokens, B, S, P, C, global_idx, pos=pos
+ )
+ )
+ else:
+ raise ValueError(f"Unknown attention type: {attn_type}")
+
+ for i in range(len(frame_intermediates)):
+ # concat frame and global intermediates, [B x S x P x 2C]
+ concat_inter = torch.cat(
+ [frame_intermediates[i], global_intermediates[i]], dim=-1
+ )
+ output_list.append(concat_inter)
+
+ del concat_inter
+ del frame_intermediates
+ del global_intermediates
+ return output_list, self.patch_start_idx
+
+ def _process_frame_attention(self, tokens, B, S, P, C, frame_idx, pos=None):
+ """
+ Process frame attention blocks. We keep tokens in shape (B*S, P, C).
+ """
+ # If needed, reshape tokens or positions:
+ if tokens.shape != (B * S, P, C):
+ tokens = tokens.view(B, S, P, C).view(B * S, P, C)
+
+ if pos is not None and pos.shape != (B * S, P, 2):
+ pos = pos.view(B, S, P, 2).view(B * S, P, 2)
+
+ intermediates = []
+
+ # by default, self.aa_block_size=1, which processes one block at a time
+ for _ in range(self.aa_block_size):
+ tokens = self.frame_blocks[frame_idx](tokens, pos=pos)
+ frame_idx += 1
+ intermediates.append(tokens.view(B, S, P, C))
+
+ return tokens, frame_idx, intermediates
+
+ def _process_global_attention(self, tokens, B, S, P, C, global_idx, pos=None):
+ """
+ Process global attention blocks. We keep tokens in shape (B, S*P, C).
+ """
+ if tokens.shape != (B, S * P, C):
+ tokens = tokens.view(B, S, P, C).view(B, S * P, C)
+
+ if pos is not None and pos.shape != (B, S * P, 2):
+ pos = pos.view(B, S, P, 2).view(B, S * P, 2)
+
+ intermediates = []
+
+ # by default, self.aa_block_size=1, which processes one block at a time
+ for _ in range(self.aa_block_size):
+ tokens = self.global_blocks[global_idx](tokens, pos=pos)
+ global_idx += 1
+ intermediates.append(tokens.view(B, S, P, C))
+
+ return tokens, global_idx, intermediates
+
+
+def slice_expand_and_flatten(token_tensor, B, S):
+ """
+ Processes specialized tokens with shape (1, 2, X, C) for multi-frame processing:
+ 1) Uses the first position (index=0) for the first frame only
+ 2) Uses the second position (index=1) for all remaining frames (S-1 frames)
+ 3) Expands both to match batch size B
+ 4) Concatenates to form (B, S, X, C) where each sequence has 1 first-position token
+ followed by (S-1) second-position tokens
+ 5) Flattens to (B*S, X, C) for processing
+
+ Returns:
+ torch.Tensor: Processed tokens with shape (B*S, X, C)
+ """
+
+ # Slice out the "query" tokens => shape (1, 1, ...)
+ query = token_tensor[:, 0:1, ...].expand(B, 1, *token_tensor.shape[2:])
+ # Slice out the "other" tokens => shape (1, S-1, ...)
+ others = token_tensor[:, 1:, ...].expand(B, S - 1, *token_tensor.shape[2:])
+ # Concatenate => shape (B, S, ...)
+ combined = torch.cat([query, others], dim=1)
+
+ # Finally flatten => shape (B*S, ...)
+ combined = combined.view(B * S, *combined.shape[2:])
+ return combined
diff --git a/mapanything/models/external/vggt/models/vggt.py b/mapanything/models/external/vggt/models/vggt.py
new file mode 100644
index 0000000000000000000000000000000000000000..a4d353dbaf55beed37860601299cb344677eb01b
--- /dev/null
+++ b/mapanything/models/external/vggt/models/vggt.py
@@ -0,0 +1,131 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn as nn
+from huggingface_hub import PyTorchModelHubMixin # used for model hub
+
+from mapanything.models.external.vggt.heads.camera_head import CameraHead
+from mapanything.models.external.vggt.heads.dpt_head import DPTHead
+from mapanything.models.external.vggt.heads.track_head import TrackHead
+from mapanything.models.external.vggt.models.aggregator import Aggregator
+
+
+class VGGT(nn.Module, PyTorchModelHubMixin):
+ def __init__(
+ self,
+ img_size=518,
+ patch_size=14,
+ embed_dim=1024,
+ depth=24,
+ num_heads=16,
+ intermediate_layer_idx=[4, 11, 17, 23],
+ ):
+ super().__init__()
+
+ self.aggregator = Aggregator(
+ img_size=img_size,
+ patch_size=patch_size,
+ embed_dim=embed_dim,
+ depth=depth,
+ num_heads=num_heads,
+ )
+ self.camera_head = CameraHead(dim_in=2 * embed_dim)
+ self.point_head = DPTHead(
+ dim_in=2 * embed_dim,
+ output_dim=4,
+ activation="inv_log",
+ conf_activation="expp1",
+ intermediate_layer_idx=intermediate_layer_idx,
+ )
+ self.depth_head = DPTHead(
+ dim_in=2 * embed_dim,
+ output_dim=2,
+ activation="exp",
+ conf_activation="expp1",
+ intermediate_layer_idx=intermediate_layer_idx,
+ )
+ self.track_head = TrackHead(dim_in=2 * embed_dim, patch_size=patch_size)
+
+ def forward(
+ self,
+ images: torch.Tensor,
+ query_points: torch.Tensor = None,
+ ):
+ """
+ Forward pass of the VGGT model.
+
+ Args:
+ images (torch.Tensor): Input images with shape [S, 3, H, W] or [B, S, 3, H, W], in range [0, 1].
+ B: batch size, S: sequence length, 3: RGB channels, H: height, W: width
+ query_points (torch.Tensor, optional): Query points for tracking, in pixel coordinates.
+ Shape: [N, 2] or [B, N, 2], where N is the number of query points.
+ Default: None
+
+ Returns:
+ dict: A dictionary containing the following predictions:
+ - pose_enc (torch.Tensor): Camera pose encoding with shape [B, S, 9] (from the last iteration)
+ - depth (torch.Tensor): Predicted depth maps with shape [B, S, H, W, 1]
+ - depth_conf (torch.Tensor): Confidence scores for depth predictions with shape [B, S, H, W]
+ - world_points (torch.Tensor): 3D world coordinates for each pixel with shape [B, S, H, W, 3]
+ - world_points_conf (torch.Tensor): Confidence scores for world points with shape [B, S, H, W]
+ - images (torch.Tensor): Original input images, preserved for visualization
+
+ If query_points is provided, also includes:
+ - track (torch.Tensor): Point tracks with shape [B, S, N, 2] (from the last iteration), in pixel coordinates
+ - vis (torch.Tensor): Visibility scores for tracked points with shape [B, S, N]
+ - conf (torch.Tensor): Confidence scores for tracked points with shape [B, S, N]
+ """
+
+ # If without batch dimension, add it
+ if len(images.shape) == 4:
+ images = images.unsqueeze(0)
+ if query_points is not None and len(query_points.shape) == 2:
+ query_points = query_points.unsqueeze(0)
+
+ aggregated_tokens_list, patch_start_idx = self.aggregator(images)
+
+ predictions = {}
+
+ with torch.cuda.amp.autocast(enabled=False):
+ if self.camera_head is not None:
+ pose_enc_list = self.camera_head(aggregated_tokens_list)
+ predictions["pose_enc"] = pose_enc_list[
+ -1
+ ] # pose encoding of the last iteration
+
+ if self.depth_head is not None:
+ depth, depth_conf = self.depth_head(
+ aggregated_tokens_list,
+ images=images,
+ patch_start_idx=patch_start_idx,
+ )
+ predictions["depth"] = depth
+ predictions["depth_conf"] = depth_conf
+
+ if self.point_head is not None:
+ pts3d, pts3d_conf = self.point_head(
+ aggregated_tokens_list,
+ images=images,
+ patch_start_idx=patch_start_idx,
+ )
+ predictions["world_points"] = pts3d
+ predictions["world_points_conf"] = pts3d_conf
+
+ if self.track_head is not None and query_points is not None:
+ track_list, vis, conf = self.track_head(
+ aggregated_tokens_list,
+ images=images,
+ patch_start_idx=patch_start_idx,
+ query_points=query_points,
+ )
+ predictions["track"] = track_list[-1] # track of the last iteration
+ predictions["vis"] = vis
+ predictions["conf"] = conf
+
+ predictions["images"] = images
+
+ return predictions
diff --git a/mapanything/models/external/vggt/utils/__init__.py b/mapanything/models/external/vggt/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/mapanything/models/external/vggt/utils/geometry.py b/mapanything/models/external/vggt/utils/geometry.py
new file mode 100644
index 0000000000000000000000000000000000000000..0c547838e7b1a8191ff5f24b6683c67d684ea016
--- /dev/null
+++ b/mapanything/models/external/vggt/utils/geometry.py
@@ -0,0 +1,174 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+import numpy as np
+import torch
+
+
+def unproject_depth_map_to_point_map(
+ depth_map: np.ndarray, extrinsics_cam: np.ndarray, intrinsics_cam: np.ndarray
+) -> np.ndarray:
+ """
+ Unproject a batch of depth maps to 3D world coordinates.
+
+ Args:
+ depth_map (np.ndarray): Batch of depth maps of shape (S, H, W, 1) or (S, H, W)
+ extrinsics_cam (np.ndarray): Batch of camera extrinsic matrices of shape (S, 3, 4)
+ intrinsics_cam (np.ndarray): Batch of camera intrinsic matrices of shape (S, 3, 3)
+
+ Returns:
+ np.ndarray: Batch of 3D world coordinates of shape (S, H, W, 3)
+ """
+ if isinstance(depth_map, torch.Tensor):
+ depth_map = depth_map.cpu().numpy()
+ if isinstance(extrinsics_cam, torch.Tensor):
+ extrinsics_cam = extrinsics_cam.cpu().numpy()
+ if isinstance(intrinsics_cam, torch.Tensor):
+ intrinsics_cam = intrinsics_cam.cpu().numpy()
+
+ world_points_list = []
+ for frame_idx in range(depth_map.shape[0]):
+ cur_world_points, _, _ = depth_to_world_coords_points(
+ depth_map[frame_idx].squeeze(-1),
+ extrinsics_cam[frame_idx],
+ intrinsics_cam[frame_idx],
+ )
+ world_points_list.append(cur_world_points)
+ world_points_array = np.stack(world_points_list, axis=0)
+
+ return world_points_array
+
+
+def depth_to_world_coords_points(
+ depth_map: np.ndarray,
+ extrinsic: np.ndarray,
+ intrinsic: np.ndarray,
+ eps=1e-8,
+) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
+ """
+ Convert a depth map to world coordinates.
+
+ Args:
+ depth_map (np.ndarray): Depth map of shape (H, W).
+ intrinsic (np.ndarray): Camera intrinsic matrix of shape (3, 3).
+ extrinsic (np.ndarray): Camera extrinsic matrix of shape (3, 4). OpenCV camera coordinate convention, cam from world.
+
+ Returns:
+ tuple[np.ndarray, np.ndarray]: World coordinates (H, W, 3) and valid depth mask (H, W).
+ """
+ if depth_map is None:
+ return None, None, None
+
+ # Valid depth mask
+ point_mask = depth_map > eps
+
+ # Convert depth map to camera coordinates
+ cam_coords_points = depth_to_cam_coords_points(depth_map, intrinsic)
+
+ # Multiply with the inverse of extrinsic matrix to transform to world coordinates
+ # extrinsic_inv is 4x4 (note closed_form_inverse_OpenCV is batched, the output is (N, 4, 4))
+ cam_to_world_extrinsic = closed_form_inverse_se3(extrinsic[None])[0]
+
+ R_cam_to_world = cam_to_world_extrinsic[:3, :3]
+ t_cam_to_world = cam_to_world_extrinsic[:3, 3]
+
+ # Apply the rotation and translation to the camera coordinates
+ world_coords_points = (
+ np.dot(cam_coords_points, R_cam_to_world.T) + t_cam_to_world
+ ) # HxWx3, 3x3 -> HxWx3
+ # world_coords_points = np.einsum("ij,hwj->hwi", R_cam_to_world, cam_coords_points) + t_cam_to_world
+
+ return world_coords_points, cam_coords_points, point_mask
+
+
+def depth_to_cam_coords_points(
+ depth_map: np.ndarray, intrinsic: np.ndarray
+) -> tuple[np.ndarray, np.ndarray]:
+ """
+ Convert a depth map to camera coordinates.
+
+ Args:
+ depth_map (np.ndarray): Depth map of shape (H, W).
+ intrinsic (np.ndarray): Camera intrinsic matrix of shape (3, 3).
+
+ Returns:
+ tuple[np.ndarray, np.ndarray]: Camera coordinates (H, W, 3)
+ """
+ H, W = depth_map.shape
+ assert intrinsic.shape == (3, 3), "Intrinsic matrix must be 3x3"
+ assert intrinsic[0, 1] == 0 and intrinsic[1, 0] == 0, (
+ "Intrinsic matrix must have zero skew"
+ )
+
+ # Intrinsic parameters
+ fu, fv = intrinsic[0, 0], intrinsic[1, 1]
+ cu, cv = intrinsic[0, 2], intrinsic[1, 2]
+
+ # Generate grid of pixel coordinates
+ u, v = np.meshgrid(np.arange(W), np.arange(H))
+
+ # Unproject to camera coordinates
+ x_cam = (u - cu) * depth_map / fu
+ y_cam = (v - cv) * depth_map / fv
+ z_cam = depth_map
+
+ # Stack to form camera coordinates
+ cam_coords = np.stack((x_cam, y_cam, z_cam), axis=-1).astype(np.float32)
+
+ return cam_coords
+
+
+def closed_form_inverse_se3(se3, R=None, T=None):
+ """
+ Compute the inverse of each 4x4 (or 3x4) SE3 matrix in a batch.
+
+ If `R` and `T` are provided, they must correspond to the rotation and translation
+ components of `se3`. Otherwise, they will be extracted from `se3`.
+
+ Args:
+ se3: Nx4x4 or Nx3x4 array or tensor of SE3 matrices.
+ R (optional): Nx3x3 array or tensor of rotation matrices.
+ T (optional): Nx3x1 array or tensor of translation vectors.
+
+ Returns:
+ Inverted SE3 matrices with the same type and device as `se3`.
+
+ Shapes:
+ se3: (N, 4, 4)
+ R: (N, 3, 3)
+ T: (N, 3, 1)
+ """
+ # Check if se3 is a numpy array or a torch tensor
+ is_numpy = isinstance(se3, np.ndarray)
+
+ # Validate shapes
+ if se3.shape[-2:] != (4, 4) and se3.shape[-2:] != (3, 4):
+ raise ValueError(f"se3 must be of shape (N,4,4), got {se3.shape}.")
+
+ # Extract R and T if not provided
+ if R is None:
+ R = se3[:, :3, :3] # (N,3,3)
+ if T is None:
+ T = se3[:, :3, 3:] # (N,3,1)
+
+ # Transpose R
+ if is_numpy:
+ # Compute the transpose of the rotation for NumPy
+ R_transposed = np.transpose(R, (0, 2, 1))
+ # -R^T t for NumPy
+ top_right = -np.matmul(R_transposed, T)
+ inverted_matrix = np.tile(np.eye(4), (len(R), 1, 1))
+ else:
+ R_transposed = R.transpose(1, 2) # (N,3,3)
+ top_right = -torch.bmm(R_transposed, T) # (N,3,1)
+ inverted_matrix = torch.eye(4, 4)[None].repeat(len(R), 1, 1)
+ inverted_matrix = inverted_matrix.to(R.dtype).to(R.device)
+
+ inverted_matrix[:, :3, :3] = R_transposed
+ inverted_matrix[:, :3, 3:] = top_right
+
+ return inverted_matrix
diff --git a/mapanything/models/external/vggt/utils/load_fn.py b/mapanything/models/external/vggt/utils/load_fn.py
new file mode 100644
index 0000000000000000000000000000000000000000..2d282f409a45605e5964bf96548512ce195677a7
--- /dev/null
+++ b/mapanything/models/external/vggt/utils/load_fn.py
@@ -0,0 +1,155 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from PIL import Image
+from torchvision import transforms as TF
+
+
+def load_and_preprocess_images(image_path_list, mode="crop"):
+ """
+ A quick start function to load and preprocess images for model input.
+ This assumes the images should have the same shape for easier batching, but our model can also work well with different shapes.
+
+ Args:
+ image_path_list (list): List of paths to image files
+ mode (str, optional): Preprocessing mode, either "crop" or "pad".
+ - "crop" (default): Sets width to 518px and center crops height if needed.
+ - "pad": Preserves all pixels by making the largest dimension 518px
+ and padding the smaller dimension to reach a square shape.
+
+ Returns:
+ torch.Tensor: Batched tensor of preprocessed images with shape (N, 3, H, W)
+
+ Raises:
+ ValueError: If the input list is empty or if mode is invalid
+
+ Notes:
+ - Images with different dimensions will be padded with white (value=1.0)
+ - A warning is printed when images have different shapes
+ - When mode="crop": The function ensures width=518px while maintaining aspect ratio
+ and height is center-cropped if larger than 518px
+ - When mode="pad": The function ensures the largest dimension is 518px while maintaining aspect ratio
+ and the smaller dimension is padded to reach a square shape (518x518)
+ - Dimensions are adjusted to be divisible by 14 for compatibility with model requirements
+ """
+ # Check for empty list
+ if len(image_path_list) == 0:
+ raise ValueError("At least 1 image is required")
+
+ # Validate mode
+ if mode not in ["crop", "pad"]:
+ raise ValueError("Mode must be either 'crop' or 'pad'")
+
+ images = []
+ shapes = set()
+ to_tensor = TF.ToTensor()
+ target_size = 518
+
+ # First process all images and collect their shapes
+ for image_path in image_path_list:
+ # Open image
+ img = Image.open(image_path)
+
+ # If there's an alpha channel, blend onto white background:
+ if img.mode == "RGBA":
+ # Create white background
+ background = Image.new("RGBA", img.size, (255, 255, 255, 255))
+ # Alpha composite onto the white background
+ img = Image.alpha_composite(background, img)
+
+ # Now convert to "RGB" (this step assigns white for transparent areas)
+ img = img.convert("RGB")
+
+ width, height = img.size
+
+ if mode == "pad":
+ # Make the largest dimension 518px while maintaining aspect ratio
+ if width >= height:
+ new_width = target_size
+ new_height = (
+ round(height * (new_width / width) / 14) * 14
+ ) # Make divisible by 14
+ else:
+ new_height = target_size
+ new_width = (
+ round(width * (new_height / height) / 14) * 14
+ ) # Make divisible by 14
+ else: # mode == "crop"
+ # Original behavior: set width to 518px
+ new_width = target_size
+ # Calculate height maintaining aspect ratio, divisible by 14
+ new_height = round(height * (new_width / width) / 14) * 14
+
+ # Resize with new dimensions (width, height)
+ img = img.resize((new_width, new_height), Image.Resampling.BICUBIC)
+ img = to_tensor(img) # Convert to tensor (0, 1)
+
+ # Center crop height if it's larger than 518 (only in crop mode)
+ if mode == "crop" and new_height > target_size:
+ start_y = (new_height - target_size) // 2
+ img = img[:, start_y : start_y + target_size, :]
+
+ # For pad mode, pad to make a square of target_size x target_size
+ if mode == "pad":
+ h_padding = target_size - img.shape[1]
+ w_padding = target_size - img.shape[2]
+
+ if h_padding > 0 or w_padding > 0:
+ pad_top = h_padding // 2
+ pad_bottom = h_padding - pad_top
+ pad_left = w_padding // 2
+ pad_right = w_padding - pad_left
+
+ # Pad with white (value=1.0)
+ img = torch.nn.functional.pad(
+ img,
+ (pad_left, pad_right, pad_top, pad_bottom),
+ mode="constant",
+ value=1.0,
+ )
+
+ shapes.add((img.shape[1], img.shape[2]))
+ images.append(img)
+
+ # Check if we have different shapes
+ # In theory our model can also work well with different shapes
+ if len(shapes) > 1:
+ print(f"Warning: Found images with different shapes: {shapes}")
+ # Find maximum dimensions
+ max_height = max(shape[0] for shape in shapes)
+ max_width = max(shape[1] for shape in shapes)
+
+ # Pad images if necessary
+ padded_images = []
+ for img in images:
+ h_padding = max_height - img.shape[1]
+ w_padding = max_width - img.shape[2]
+
+ if h_padding > 0 or w_padding > 0:
+ pad_top = h_padding // 2
+ pad_bottom = h_padding - pad_top
+ pad_left = w_padding // 2
+ pad_right = w_padding - pad_left
+
+ img = torch.nn.functional.pad(
+ img,
+ (pad_left, pad_right, pad_top, pad_bottom),
+ mode="constant",
+ value=1.0,
+ )
+ padded_images.append(img)
+ images = padded_images
+
+ images = torch.stack(images) # concatenate images
+
+ # Ensure correct shape when single image
+ if len(image_path_list) == 1:
+ # Verify shape is (1, C, H, W)
+ if images.dim() == 3:
+ images = images.unsqueeze(0)
+
+ return images
diff --git a/mapanything/models/external/vggt/utils/pose_enc.py b/mapanything/models/external/vggt/utils/pose_enc.py
new file mode 100644
index 0000000000000000000000000000000000000000..9bf5a6d6759d86c28b2fe581df5c7b79303374a5
--- /dev/null
+++ b/mapanything/models/external/vggt/utils/pose_enc.py
@@ -0,0 +1,135 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+
+from .rotation import mat_to_quat, quat_to_mat
+
+
+def extri_intri_to_pose_encoding(
+ extrinsics,
+ intrinsics,
+ image_size_hw=None, # e.g., (256, 512)
+ pose_encoding_type="absT_quaR_FoV",
+):
+ """Convert camera extrinsics and intrinsics to a compact pose encoding.
+
+ This function transforms camera parameters into a unified pose encoding format,
+ which can be used for various downstream tasks like pose prediction or representation.
+
+ Args:
+ extrinsics (torch.Tensor): Camera extrinsic parameters with shape BxSx3x4,
+ where B is batch size and S is sequence length.
+ In OpenCV coordinate system (x-right, y-down, z-forward), representing camera from world transformation.
+ The format is [R|t] where R is a 3x3 rotation matrix and t is a 3x1 translation vector.
+ intrinsics (torch.Tensor): Camera intrinsic parameters with shape BxSx3x3.
+ Defined in pixels, with format:
+ [[fx, 0, cx],
+ [0, fy, cy],
+ [0, 0, 1]]
+ where fx, fy are focal lengths and (cx, cy) is the principal point
+ image_size_hw (tuple): Tuple of (height, width) of the image in pixels.
+ Required for computing field of view values. For example: (256, 512).
+ pose_encoding_type (str): Type of pose encoding to use. Currently only
+ supports "absT_quaR_FoV" (absolute translation, quaternion rotation, field of view).
+
+ Returns:
+ torch.Tensor: Encoded camera pose parameters with shape BxSx9.
+ For "absT_quaR_FoV" type, the 9 dimensions are:
+ - [:3] = absolute translation vector T (3D)
+ - [3:7] = rotation as quaternion quat (4D)
+ - [7:] = field of view (2D)
+ """
+
+ # extrinsics: BxSx3x4
+ # intrinsics: BxSx3x3
+
+ if pose_encoding_type == "absT_quaR_FoV":
+ R = extrinsics[:, :, :3, :3] # BxSx3x3
+ T = extrinsics[:, :, :3, 3] # BxSx3
+
+ quat = mat_to_quat(R)
+ # Note the order of h and w here
+ H, W = image_size_hw
+ fov_h = 2 * torch.atan((H / 2) / intrinsics[..., 1, 1])
+ fov_w = 2 * torch.atan((W / 2) / intrinsics[..., 0, 0])
+ pose_encoding = torch.cat(
+ [T, quat, fov_h[..., None], fov_w[..., None]], dim=-1
+ ).float()
+ else:
+ raise NotImplementedError
+
+ return pose_encoding
+
+
+def pose_encoding_to_extri_intri(
+ pose_encoding,
+ image_size_hw=None, # e.g., (256, 512)
+ pose_encoding_type="absT_quaR_FoV",
+ build_intrinsics=True,
+):
+ """Convert a pose encoding back to camera extrinsics and intrinsics.
+
+ This function performs the inverse operation of extri_intri_to_pose_encoding,
+ reconstructing the full camera parameters from the compact encoding.
+
+ Args:
+ pose_encoding (torch.Tensor): Encoded camera pose parameters with shape BxSx9,
+ where B is batch size and S is sequence length.
+ For "absT_quaR_FoV" type, the 9 dimensions are:
+ - [:3] = absolute translation vector T (3D)
+ - [3:7] = rotation as quaternion quat (4D)
+ - [7:] = field of view (2D)
+ image_size_hw (tuple): Tuple of (height, width) of the image in pixels.
+ Required for reconstructing intrinsics from field of view values.
+ For example: (256, 512).
+ pose_encoding_type (str): Type of pose encoding used. Currently only
+ supports "absT_quaR_FoV" (absolute translation, quaternion rotation, field of view).
+ build_intrinsics (bool): Whether to reconstruct the intrinsics matrix.
+ If False, only extrinsics are returned and intrinsics will be None.
+
+ Returns:
+ tuple: (extrinsics, intrinsics)
+ - extrinsics (torch.Tensor): Camera extrinsic parameters with shape BxSx3x4.
+ In OpenCV coordinate system (x-right, y-down, z-forward), representing camera from world
+ transformation. The format is [R|t] where R is a 3x3 rotation matrix and t is
+ a 3x1 translation vector.
+ - intrinsics (torch.Tensor or None): Camera intrinsic parameters with shape BxSx3x3,
+ or None if build_intrinsics is False. Defined in pixels, with format:
+ [[fx, 0, cx],
+ [0, fy, cy],
+ [0, 0, 1]]
+ where fx, fy are focal lengths and (cx, cy) is the principal point,
+ assumed to be at the center of the image (W/2, H/2).
+ """
+
+ intrinsics = None
+
+ if pose_encoding_type == "absT_quaR_FoV":
+ T = pose_encoding[..., :3]
+ quat = pose_encoding[..., 3:7]
+ fov_h = pose_encoding[..., 7]
+ fov_w = pose_encoding[..., 8]
+
+ R = quat_to_mat(quat)
+ extrinsics = torch.cat([R, T[..., None]], dim=-1)
+
+ if build_intrinsics:
+ H, W = image_size_hw
+ fy = (H / 2.0) / torch.tan(fov_h / 2.0)
+ fx = (W / 2.0) / torch.tan(fov_w / 2.0)
+ intrinsics = torch.zeros(
+ pose_encoding.shape[:2] + (3, 3), device=pose_encoding.device
+ )
+ intrinsics[..., 0, 0] = fx
+ intrinsics[..., 1, 1] = fy
+ intrinsics[..., 0, 2] = W / 2
+ intrinsics[..., 1, 2] = H / 2
+ intrinsics[..., 2, 2] = 1.0 # Set the homogeneous coordinate to 1
+ else:
+ raise NotImplementedError
+
+ return extrinsics, intrinsics
diff --git a/mapanything/models/external/vggt/utils/rotation.py b/mapanything/models/external/vggt/utils/rotation.py
new file mode 100644
index 0000000000000000000000000000000000000000..4c955951385da74b61234ed3e629dd823caa6e6c
--- /dev/null
+++ b/mapanything/models/external/vggt/utils/rotation.py
@@ -0,0 +1,141 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# Modified from PyTorch3D, https://github.com/facebookresearch/pytorch3d
+
+import torch
+import torch.nn.functional as F
+
+
+def quat_to_mat(quaternions: torch.Tensor) -> torch.Tensor:
+ """
+ Quaternion Order: XYZW or say ijkr, scalar-last
+
+ Convert rotations given as quaternions to rotation matrices.
+ Args:
+ quaternions: quaternions with real part last,
+ as tensor of shape (..., 4).
+
+ Returns:
+ Rotation matrices as tensor of shape (..., 3, 3).
+ """
+ i, j, k, r = torch.unbind(quaternions, -1)
+ # pyre-fixme[58]: `/` is not supported for operand types `float` and `Tensor`.
+ two_s = 2.0 / (quaternions * quaternions).sum(-1)
+
+ o = torch.stack(
+ (
+ 1 - two_s * (j * j + k * k),
+ two_s * (i * j - k * r),
+ two_s * (i * k + j * r),
+ two_s * (i * j + k * r),
+ 1 - two_s * (i * i + k * k),
+ two_s * (j * k - i * r),
+ two_s * (i * k - j * r),
+ two_s * (j * k + i * r),
+ 1 - two_s * (i * i + j * j),
+ ),
+ -1,
+ )
+ return o.reshape(quaternions.shape[:-1] + (3, 3))
+
+
+def mat_to_quat(matrix: torch.Tensor) -> torch.Tensor:
+ """
+ Convert rotations given as rotation matrices to quaternions.
+
+ Args:
+ matrix: Rotation matrices as tensor of shape (..., 3, 3).
+
+ Returns:
+ quaternions with real part last, as tensor of shape (..., 4).
+ Quaternion Order: XYZW or say ijkr, scalar-last
+ """
+ if matrix.size(-1) != 3 or matrix.size(-2) != 3:
+ raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")
+
+ batch_dim = matrix.shape[:-2]
+ m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(
+ matrix.reshape(batch_dim + (9,)), dim=-1
+ )
+
+ q_abs = _sqrt_positive_part(
+ torch.stack(
+ [
+ 1.0 + m00 + m11 + m22,
+ 1.0 + m00 - m11 - m22,
+ 1.0 - m00 + m11 - m22,
+ 1.0 - m00 - m11 + m22,
+ ],
+ dim=-1,
+ )
+ )
+
+ # we produce the desired quaternion multiplied by each of r, i, j, k
+ quat_by_rijk = torch.stack(
+ [
+ # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
+ # `int`.
+ torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1),
+ # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
+ # `int`.
+ torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1),
+ # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
+ # `int`.
+ torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1),
+ # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
+ # `int`.
+ torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1),
+ ],
+ dim=-2,
+ )
+
+ # We floor here at 0.1 but the exact level is not important; if q_abs is small,
+ # the candidate won't be picked.
+ flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device)
+ quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr))
+
+ # if not for numerical problems, quat_candidates[i] should be same (up to a sign),
+ # forall i; we pick the best-conditioned one (with the largest denominator)
+ out = quat_candidates[
+ F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, :
+ ].reshape(batch_dim + (4,))
+
+ # Convert from rijk to ijkr
+ out = out[..., [1, 2, 3, 0]]
+
+ out = standardize_quaternion(out)
+
+ return out
+
+
+def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor:
+ """
+ Returns torch.sqrt(torch.max(0, x))
+ but with a zero subgradient where x is 0.
+ """
+ ret = torch.zeros_like(x)
+ positive_mask = x > 0
+ if torch.is_grad_enabled():
+ ret[positive_mask] = torch.sqrt(x[positive_mask])
+ else:
+ ret = torch.where(positive_mask, torch.sqrt(x), ret)
+ return ret
+
+
+def standardize_quaternion(quaternions: torch.Tensor) -> torch.Tensor:
+ """
+ Convert a unit quaternion to a standard form: one in which the real
+ part is non negative.
+
+ Args:
+ quaternions: Quaternions with real part last,
+ as tensor of shape (..., 4).
+
+ Returns:
+ Standardized quaternions as tensor of shape (..., 4).
+ """
+ return torch.where(quaternions[..., 3:4] < 0, -quaternions, quaternions)
diff --git a/mapanything/models/external/vggt/utils/visual_track.py b/mapanything/models/external/vggt/utils/visual_track.py
new file mode 100644
index 0000000000000000000000000000000000000000..d30aad68265900aac74832b2e62e28433d5f42ab
--- /dev/null
+++ b/mapanything/models/external/vggt/utils/visual_track.py
@@ -0,0 +1,244 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import os
+
+import cv2
+import numpy as np
+import torch
+
+
+def color_from_xy(x, y, W, H, cmap_name="hsv"):
+ """
+ Map (x, y) -> color in (R, G, B).
+ 1) Normalize x,y to [0,1].
+ 2) Combine them into a single scalar c in [0,1].
+ 3) Use matplotlib's colormap to convert c -> (R,G,B).
+
+ You can customize step 2, e.g., c = (x + y)/2, or some function of (x, y).
+ """
+ import matplotlib.cm
+ import matplotlib.colors
+
+ x_norm = x / max(W - 1, 1)
+ y_norm = y / max(H - 1, 1)
+ # Simple combination:
+ c = (x_norm + y_norm) / 2.0
+
+ cmap = matplotlib.cm.get_cmap(cmap_name)
+ # cmap(c) -> (r,g,b,a) in [0,1]
+ rgba = cmap(c)
+ r, g, b = rgba[0], rgba[1], rgba[2]
+ return (r, g, b) # in [0,1], RGB order
+
+
+def get_track_colors_by_position(
+ tracks_b, vis_mask_b=None, image_width=None, image_height=None, cmap_name="hsv"
+):
+ """
+ Given all tracks in one sample (b), compute a (N,3) array of RGB color values
+ in [0,255]. The color is determined by the (x,y) position in the first
+ visible frame for each track.
+
+ Args:
+ tracks_b: Tensor of shape (S, N, 2). (x,y) for each track in each frame.
+ vis_mask_b: (S, N) boolean mask; if None, assume all are visible.
+ image_width, image_height: used for normalizing (x, y).
+ cmap_name: for matplotlib (e.g., 'hsv', 'rainbow', 'jet').
+
+ Returns:
+ track_colors: np.ndarray of shape (N, 3), each row is (R,G,B) in [0,255].
+ """
+ S, N, _ = tracks_b.shape
+ track_colors = np.zeros((N, 3), dtype=np.uint8)
+
+ if vis_mask_b is None:
+ # treat all as visible
+ vis_mask_b = torch.ones(S, N, dtype=torch.bool, device=tracks_b.device)
+
+ for i in range(N):
+ # Find first visible frame for track i
+ visible_frames = torch.where(vis_mask_b[:, i])[0]
+ if len(visible_frames) == 0:
+ # track is never visible; just assign black or something
+ track_colors[i] = (0, 0, 0)
+ continue
+
+ first_s = int(visible_frames[0].item())
+ # use that frame's (x,y)
+ x, y = tracks_b[first_s, i].tolist()
+
+ # map (x,y) -> (R,G,B) in [0,1]
+ r, g, b = color_from_xy(
+ x, y, W=image_width, H=image_height, cmap_name=cmap_name
+ )
+ # scale to [0,255]
+ r, g, b = int(r * 255), int(g * 255), int(b * 255)
+ track_colors[i] = (r, g, b)
+
+ return track_colors
+
+
+def visualize_tracks_on_images(
+ images,
+ tracks,
+ track_vis_mask=None,
+ out_dir="track_visuals_concat_by_xy",
+ image_format="CHW", # "CHW" or "HWC"
+ normalize_mode="[0,1]",
+ cmap_name="hsv", # e.g. "hsv", "rainbow", "jet"
+ frames_per_row=4, # New parameter for grid layout
+ save_grid=True, # Flag to control whether to save the grid image
+):
+ """
+ Visualizes frames in a grid layout with specified frames per row.
+ Each track's color is determined by its (x,y) position
+ in the first visible frame (or frame 0 if always visible).
+ Finally convert the BGR result to RGB before saving.
+ Also saves each individual frame as a separate PNG file.
+
+ Args:
+ images: torch.Tensor (S, 3, H, W) if CHW or (S, H, W, 3) if HWC.
+ tracks: torch.Tensor (S, N, 2), last dim = (x, y).
+ track_vis_mask: torch.Tensor (S, N) or None.
+ out_dir: folder to save visualizations.
+ image_format: "CHW" or "HWC".
+ normalize_mode: "[0,1]", "[-1,1]", or None for direct raw -> 0..255
+ cmap_name: a matplotlib colormap name for color_from_xy.
+ frames_per_row: number of frames to display in each row of the grid.
+ save_grid: whether to save all frames in one grid image.
+
+ Returns:
+ None (saves images in out_dir).
+ """
+
+ if len(tracks.shape) == 4:
+ tracks = tracks.squeeze(0)
+ images = images.squeeze(0)
+ if track_vis_mask is not None:
+ track_vis_mask = track_vis_mask.squeeze(0)
+
+ import matplotlib
+
+ matplotlib.use("Agg") # for non-interactive (optional)
+
+ os.makedirs(out_dir, exist_ok=True)
+
+ S = images.shape[0]
+ _, N, _ = tracks.shape # (S, N, 2)
+
+ # Move to CPU
+ images = images.cpu().clone()
+ tracks = tracks.cpu().clone()
+ if track_vis_mask is not None:
+ track_vis_mask = track_vis_mask.cpu().clone()
+
+ # Infer H, W from images shape
+ if image_format == "CHW":
+ # e.g. images[s].shape = (3, H, W)
+ H, W = images.shape[2], images.shape[3]
+ else:
+ # e.g. images[s].shape = (H, W, 3)
+ H, W = images.shape[1], images.shape[2]
+
+ # Pre-compute the color for each track i based on first visible position
+ track_colors_rgb = get_track_colors_by_position(
+ tracks, # shape (S, N, 2)
+ vis_mask_b=track_vis_mask if track_vis_mask is not None else None,
+ image_width=W,
+ image_height=H,
+ cmap_name=cmap_name,
+ )
+
+ # We'll accumulate each frame's drawn image in a list
+ frame_images = []
+
+ for s in range(S):
+ # shape => either (3, H, W) or (H, W, 3)
+ img = images[s]
+
+ # Convert to (H, W, 3)
+ if image_format == "CHW":
+ img = img.permute(1, 2, 0) # (H, W, 3)
+ # else "HWC", do nothing
+
+ img = img.numpy().astype(np.float32)
+
+ # Scale to [0,255] if needed
+ if normalize_mode == "[0,1]":
+ img = np.clip(img, 0, 1) * 255.0
+ elif normalize_mode == "[-1,1]":
+ img = (img + 1.0) * 0.5 * 255.0
+ img = np.clip(img, 0, 255.0)
+ # else no normalization
+
+ # Convert to uint8
+ img = img.astype(np.uint8)
+
+ # For drawing in OpenCV, convert to BGR
+ img_bgr = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
+
+ # Draw each visible track
+ cur_tracks = tracks[s] # shape (N, 2)
+ if track_vis_mask is not None:
+ valid_indices = torch.where(track_vis_mask[s])[0]
+ else:
+ valid_indices = range(N)
+
+ cur_tracks_np = cur_tracks.numpy()
+ for i in valid_indices:
+ x, y = cur_tracks_np[i]
+ pt = (int(round(x)), int(round(y)))
+
+ # track_colors_rgb[i] is (R,G,B). For OpenCV circle, we need BGR
+ R, G, B = track_colors_rgb[i]
+ color_bgr = (int(B), int(G), int(R))
+ cv2.circle(img_bgr, pt, radius=3, color=color_bgr, thickness=-1)
+
+ # Convert back to RGB for consistent final saving:
+ img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
+
+ # Save individual frame
+ frame_path = os.path.join(out_dir, f"frame_{s:04d}.png")
+ # Convert to BGR for OpenCV imwrite
+ frame_bgr = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2BGR)
+ cv2.imwrite(frame_path, frame_bgr)
+
+ frame_images.append(img_rgb)
+
+ # Only create and save the grid image if save_grid is True
+ if save_grid:
+ # Calculate grid dimensions
+ num_rows = (S + frames_per_row - 1) // frames_per_row # Ceiling division
+
+ # Create a grid of images
+ grid_img = None
+ for row in range(num_rows):
+ start_idx = row * frames_per_row
+ end_idx = min(start_idx + frames_per_row, S)
+
+ # Concatenate this row horizontally
+ row_img = np.concatenate(frame_images[start_idx:end_idx], axis=1)
+
+ # If this row has fewer than frames_per_row images, pad with black
+ if end_idx - start_idx < frames_per_row:
+ padding_width = (frames_per_row - (end_idx - start_idx)) * W
+ padding = np.zeros((H, padding_width, 3), dtype=np.uint8)
+ row_img = np.concatenate([row_img, padding], axis=1)
+
+ # Add this row to the grid
+ if grid_img is None:
+ grid_img = row_img
+ else:
+ grid_img = np.concatenate([grid_img, row_img], axis=0)
+
+ out_path = os.path.join(out_dir, "tracks_grid.png")
+ # Convert back to BGR for OpenCV imwrite
+ grid_img_bgr = cv2.cvtColor(grid_img, cv2.COLOR_RGB2BGR)
+ cv2.imwrite(out_path, grid_img_bgr)
+ print(f"[INFO] Saved color-by-XY track visualization grid -> {out_path}")
+
+ print(f"[INFO] Saved {S} individual frames to {out_dir}/frame_*.png")
diff --git a/mapanything/models/mapanything/__init__.py b/mapanything/models/mapanything/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..13e6352b2a6e6a5374d5e8a079ef65fc503816e8
--- /dev/null
+++ b/mapanything/models/mapanything/__init__.py
@@ -0,0 +1,9 @@
+from mapanything.models.mapanything.ablations import MapAnythingAblations
+from mapanything.models.mapanything.model import MapAnything
+from mapanything.models.mapanything.modular_dust3r import ModularDUSt3R
+
+__all__ = [
+ "MapAnything",
+ "MapAnythingAblations",
+ "ModularDUSt3R",
+]
diff --git a/mapanything/models/mapanything/ablations.py b/mapanything/models/mapanything/ablations.py
new file mode 100644
index 0000000000000000000000000000000000000000..a8e3652d226526bf38070c5552b7c8a8171b63a6
--- /dev/null
+++ b/mapanything/models/mapanything/ablations.py
@@ -0,0 +1,1653 @@
+"""
+MapAnything Ablation model classes defined using UniCeption modules.
+"""
+
+from functools import partial
+from typing import Callable, Dict, Type, Union
+
+import torch
+import torch.nn as nn
+
+from mapanything.utils.geometry import (
+ apply_log_to_norm,
+ convert_ray_dirs_depth_along_ray_pose_trans_quats_to_pointmap,
+ normalize_depth_using_non_zero_pixels,
+ normalize_pose_translations,
+ transform_pose_using_quats_and_trans_2_to_1,
+)
+from uniception.models.encoders import (
+ encoder_factory,
+ EncoderGlobalRepInput,
+ ViTEncoderInput,
+ ViTEncoderNonImageInput,
+)
+from uniception.models.info_sharing.alternating_attention_transformer import (
+ MultiViewAlternatingAttentionTransformer,
+ MultiViewAlternatingAttentionTransformerIFR,
+)
+from uniception.models.info_sharing.base import MultiViewTransformerInput
+from uniception.models.info_sharing.cross_attention_transformer import (
+ MultiViewCrossAttentionTransformer,
+ MultiViewCrossAttentionTransformerIFR,
+)
+from uniception.models.info_sharing.global_attention_transformer import (
+ MultiViewGlobalAttentionTransformer,
+ MultiViewGlobalAttentionTransformerIFR,
+)
+from uniception.models.libs.croco.pos_embed import RoPE2D
+from uniception.models.prediction_heads.adaptors import (
+ CamTranslationPlusQuatsAdaptor,
+ PointMapAdaptor,
+ PointMapPlusRayDirectionsPlusDepthAdaptor,
+ PointMapPlusRayDirectionsPlusDepthWithConfidenceAdaptor,
+ PointMapPlusRayDirectionsPlusDepthWithConfidenceAndMaskAdaptor,
+ PointMapPlusRayDirectionsPlusDepthWithMaskAdaptor,
+ PointMapWithConfidenceAdaptor,
+ PointMapWithConfidenceAndMaskAdaptor,
+ PointMapWithMaskAdaptor,
+ RayDirectionsPlusDepthAdaptor,
+ RayDirectionsPlusDepthWithConfidenceAdaptor,
+ RayDirectionsPlusDepthWithConfidenceAndMaskAdaptor,
+ RayDirectionsPlusDepthWithMaskAdaptor,
+ RayMapPlusDepthAdaptor,
+ RayMapPlusDepthWithConfidenceAdaptor,
+ RayMapPlusDepthWithConfidenceAndMaskAdaptor,
+ RayMapPlusDepthWithMaskAdaptor,
+)
+from uniception.models.prediction_heads.base import (
+ AdaptorInput,
+ PredictionHeadInput,
+ PredictionHeadLayeredInput,
+)
+from uniception.models.prediction_heads.dpt import DPTFeature, DPTRegressionProcessor
+from uniception.models.prediction_heads.linear import LinearFeature
+from uniception.models.prediction_heads.pose_head import PoseHead
+
+# Enable TF32 precision if supported (for GPU >= Ampere and PyTorch >= 1.12)
+if hasattr(torch.backends.cuda, "matmul") and hasattr(
+ torch.backends.cuda.matmul, "allow_tf32"
+):
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+
+class MapAnythingAblations(nn.Module):
+ "Modular MapAnything Multi-View model class with no scale token."
+
+ def __init__(
+ self,
+ name: str,
+ encoder_config: Dict,
+ info_sharing_config: Dict,
+ pred_head_config: Dict,
+ geometric_input_config: Dict,
+ fusion_norm_layer: Union[Type[nn.Module], Callable[..., nn.Module]] = partial(
+ nn.LayerNorm, eps=1e-6
+ ),
+ pretrained_checkpoint_path: str = None,
+ load_specific_pretrained_submodules: bool = False,
+ specific_pretrained_submodules: list = [],
+ torch_hub_force_reload: bool = False,
+ ):
+ """
+ Multi-view model containing an image encoder followed by a multi-view attention transformer and respective downstream heads.
+ The goal is to output scene representation directly in view 0's frame.
+
+ Args:
+ name (str): Name of the model.
+ encoder_config (Dict): Configuration for the encoder.
+ info_sharing_config (Dict): Configuration for the multi-view attention transformer.
+ pred_head_config (Dict): Configuration for the prediction heads.
+ pretrained_checkpoint_path (str): Path to pretrained checkpoint. (default: None)
+ load_specific_pretrained_submodules (bool): Whether to load specific pretrained submodules. (default: False)
+ specific_pretrained_submodules (list): List of specific pretrained submodules to load. Must be provided when load_specific_pretrained_submodules is True. (default: [])
+ torch_hub_force_reload (bool): Whether to force reload the encoder from torch hub. (default: False)
+ """
+ super().__init__()
+
+ # Initalize the attributes
+ self.name = name
+ self.encoder_config = encoder_config
+ self.info_sharing_config = info_sharing_config
+ self.pred_head_config = pred_head_config
+ self.geometric_input_config = geometric_input_config
+ self.pretrained_checkpoint_path = pretrained_checkpoint_path
+ self.load_specific_pretrained_submodules = load_specific_pretrained_submodules
+ self.specific_pretrained_submodules = specific_pretrained_submodules
+ self.torch_hub_force_reload = torch_hub_force_reload
+ self.class_init_args = {
+ "name": self.name,
+ "encoder_config": self.encoder_config,
+ "info_sharing_config": self.info_sharing_config,
+ "pred_head_config": self.pred_head_config,
+ "geometric_input_config": self.geometric_input_config,
+ "pretrained_checkpoint_path": self.pretrained_checkpoint_path,
+ "load_specific_pretrained_submodules": self.load_specific_pretrained_submodules,
+ "specific_pretrained_submodules": self.specific_pretrained_submodules,
+ "torch_hub_force_reload": self.torch_hub_force_reload,
+ }
+
+ # Get relevant parameters from the configs
+ self.info_sharing_type = info_sharing_config["model_type"]
+ self.info_sharing_return_type = info_sharing_config["model_return_type"]
+ self.pred_head_type = pred_head_config["type"]
+
+ # Initialize image encoder
+ if self.encoder_config["uses_torch_hub"]:
+ self.encoder_config["torch_hub_force_reload"] = torch_hub_force_reload
+ del self.encoder_config["uses_torch_hub"]
+ self.encoder = encoder_factory(**self.encoder_config)
+
+ # Initialize the encoder for ray directions
+ ray_dirs_encoder_config = self.geometric_input_config["ray_dirs_encoder_config"]
+ ray_dirs_encoder_config["enc_embed_dim"] = self.encoder.enc_embed_dim
+ ray_dirs_encoder_config["patch_size"] = self.encoder.patch_size
+ self.ray_dirs_encoder = encoder_factory(**ray_dirs_encoder_config)
+
+ # Initialize the encoder for depth (normalized per view and values after normalization are scaled logarithmically)
+ depth_encoder_config = self.geometric_input_config["depth_encoder_config"]
+ depth_encoder_config["enc_embed_dim"] = self.encoder.enc_embed_dim
+ depth_encoder_config["patch_size"] = self.encoder.patch_size
+ self.depth_encoder = encoder_factory(**depth_encoder_config)
+
+ # Initialize the encoder for log scale factor of depth
+ depth_scale_encoder_config = self.geometric_input_config["scale_encoder_config"]
+ depth_scale_encoder_config["enc_embed_dim"] = self.encoder.enc_embed_dim
+ self.depth_scale_encoder = encoder_factory(**depth_scale_encoder_config)
+
+ # Initialize the encoder for camera rotation
+ cam_rot_encoder_config = self.geometric_input_config["cam_rot_encoder_config"]
+ cam_rot_encoder_config["enc_embed_dim"] = self.encoder.enc_embed_dim
+ self.cam_rot_encoder = encoder_factory(**cam_rot_encoder_config)
+
+ # Initialize the encoder for camera translation (normalized across all provided camera translations)
+ cam_trans_encoder_config = self.geometric_input_config[
+ "cam_trans_encoder_config"
+ ]
+ cam_trans_encoder_config["enc_embed_dim"] = self.encoder.enc_embed_dim
+ self.cam_trans_encoder = encoder_factory(**cam_trans_encoder_config)
+
+ # Initialize the encoder for log scale factor of camera translation
+ cam_trans_scale_encoder_config = self.geometric_input_config[
+ "scale_encoder_config"
+ ]
+ cam_trans_scale_encoder_config["enc_embed_dim"] = self.encoder.enc_embed_dim
+ self.cam_trans_scale_encoder = encoder_factory(**cam_trans_scale_encoder_config)
+
+ # Initialize the fusion norm layer
+ self.fusion_norm_layer = fusion_norm_layer(self.encoder.enc_embed_dim)
+
+ # Initialize the info sharing module (Multi-View Transformer)
+ self._initialize_info_sharing(info_sharing_config)
+
+ # Initialize the prediction heads
+ self._initialize_prediction_heads(pred_head_config)
+
+ # Initialize the final adaptors
+ self._initialize_adaptors(pred_head_config)
+
+ # Load pretrained weights
+ self._load_pretrained_weights()
+
+ def _initialize_info_sharing(self, info_sharing_config):
+ """
+ Initialize the information sharing module based on the configuration.
+
+ This method sets up the custom positional encoding if specified and initializes
+ the appropriate multi-view transformer based on the configuration type.
+
+ Args:
+ info_sharing_config (Dict): Configuration for the multi-view attention transformer.
+ Should contain 'custom_positional_encoding', 'model_type', and 'model_return_type'.
+
+ Returns:
+ None
+
+ Raises:
+ ValueError: If invalid configuration options are provided.
+ """
+ # Initialize Custom Positional Encoding if required
+ custom_positional_encoding = info_sharing_config["custom_positional_encoding"]
+ if custom_positional_encoding is not None:
+ if isinstance(custom_positional_encoding, str):
+ print(
+ f"Using custom positional encoding for multi-view attention transformer: {custom_positional_encoding}"
+ )
+ if custom_positional_encoding.startswith("RoPE"):
+ rope_freq = float(custom_positional_encoding[len("RoPE") :])
+ print(f"RoPE frequency: {rope_freq}")
+ self.custom_positional_encoding = RoPE2D(freq=rope_freq)
+ else:
+ raise ValueError(
+ f"Invalid custom_positional_encoding: {custom_positional_encoding}."
+ )
+ elif isinstance(custom_positional_encoding, Callable):
+ print(
+ "Using callable function as custom positional encoding for multi-view attention transformer."
+ )
+ self.custom_positional_encoding = custom_positional_encoding
+ else:
+ self.custom_positional_encoding = None
+
+ # Add dependecies to info_sharing_config
+ info_sharing_config["module_args"]["input_embed_dim"] = (
+ self.encoder.enc_embed_dim
+ )
+ info_sharing_config["module_args"]["custom_positional_encoding"] = (
+ self.custom_positional_encoding
+ )
+
+ # Initialize Multi-View Transformer
+ if self.info_sharing_return_type == "no_intermediate_features":
+ # Returns only normalized last layer features
+ # Intialize multi-view transformer based on type
+ if self.info_sharing_type == "cross_attention":
+ self.info_sharing = MultiViewCrossAttentionTransformer(
+ **info_sharing_config["module_args"]
+ )
+ elif self.info_sharing_type == "global_attention":
+ self.info_sharing = MultiViewGlobalAttentionTransformer(
+ **info_sharing_config["module_args"]
+ )
+ elif self.info_sharing_type == "alternating_attention":
+ self.info_sharing = MultiViewAlternatingAttentionTransformer(
+ **info_sharing_config["module_args"]
+ )
+ else:
+ raise ValueError(
+ f"Invalid info_sharing_type: {self.info_sharing_type}. Valid options: ['cross_attention', 'global_attention', 'alternating_attention']"
+ )
+ elif self.info_sharing_return_type == "intermediate_features":
+ # Returns intermediate features and normalized last layer features
+ # Initialize mulit-view transformer based on type
+ if self.info_sharing_type == "cross_attention":
+ self.info_sharing = MultiViewCrossAttentionTransformerIFR(
+ **info_sharing_config["module_args"]
+ )
+ elif self.info_sharing_type == "global_attention":
+ self.info_sharing = MultiViewGlobalAttentionTransformerIFR(
+ **info_sharing_config["module_args"]
+ )
+ elif self.info_sharing_type == "alternating_attention":
+ self.info_sharing = MultiViewAlternatingAttentionTransformerIFR(
+ **info_sharing_config["module_args"]
+ )
+ else:
+ raise ValueError(
+ f"Invalid info_sharing_type: {self.info_sharing_type}. Valid options: ['cross_attention', 'global_attention', 'alternating_attention']"
+ )
+ # Assess if the DPT needs to use encoder features
+ if len(self.info_sharing.indices) == 2:
+ self.use_encoder_features_for_dpt = True
+ elif len(self.info_sharing.indices) == 3:
+ self.use_encoder_features_for_dpt = False
+ else:
+ raise ValueError(
+ "Invalid number of indices provided for info sharing feature returner. Please provide 2 or 3 indices."
+ )
+ else:
+ raise ValueError(
+ f"Invalid info_sharing_return_type: {self.info_sharing_return_type}. Valid options: ['no_intermediate_features', 'intermediate_features']"
+ )
+
+ def _initialize_prediction_heads(self, pred_head_config):
+ """
+ Initialize the prediction heads based on the prediction head configuration.
+
+ This method configures and initializes the appropriate prediction heads based on the
+ specified prediction head type (linear, DPT, or DPT+pose). It sets up the necessary
+ dependencies and creates the required model components.
+
+ Args:
+ pred_head_config (Dict): Configuration for the prediction heads.
+
+ Returns:
+ None
+
+ Raises:
+ ValueError: If an invalid pred_head_type is provided.
+ """
+ # Add dependencies to prediction head config
+ pred_head_config["feature_head"]["patch_size"] = self.encoder.patch_size
+ if self.pred_head_type == "linear":
+ pred_head_config["feature_head"]["input_feature_dim"] = (
+ self.info_sharing.dim
+ )
+ elif "dpt" in self.pred_head_type:
+ # Add dependencies for DPT & Regressor head
+ if self.use_encoder_features_for_dpt:
+ pred_head_config["feature_head"]["input_feature_dims"] = [
+ self.encoder.enc_embed_dim
+ ] + [self.info_sharing.dim] * 3
+ else:
+ pred_head_config["feature_head"]["input_feature_dims"] = [
+ self.info_sharing.dim
+ ] * 4
+ pred_head_config["regressor_head"]["input_feature_dim"] = pred_head_config[
+ "feature_head"
+ ]["feature_dim"]
+ # Add dependencies for Pose head if required
+ if "pose" in self.pred_head_type:
+ pred_head_config["pose_head"]["patch_size"] = self.encoder.patch_size
+ pred_head_config["pose_head"]["input_feature_dim"] = (
+ self.info_sharing.dim
+ )
+ else:
+ raise ValueError(
+ f"Invalid pred_head_type: {self.pred_head_type}. Valid options: ['linear', 'dpt', 'dpt+pose']"
+ )
+
+ # Initialize Prediction Heads
+ if self.pred_head_type == "linear":
+ # Initialize Dense Prediction Head for all views
+ self.dense_head = LinearFeature(**pred_head_config["feature_head"])
+ elif "dpt" in self.pred_head_type:
+ # Initialze Dense Predction Head for all views
+ self.dpt_feature_head = DPTFeature(**pred_head_config["feature_head"])
+ self.dpt_regressor_head = DPTRegressionProcessor(
+ **pred_head_config["regressor_head"]
+ )
+ self.dense_head = nn.Sequential(
+ self.dpt_feature_head, self.dpt_regressor_head
+ )
+ # Initialize Pose Head for all views if required
+ if "pose" in self.pred_head_type:
+ self.pose_head = PoseHead(**pred_head_config["pose_head"])
+ else:
+ raise ValueError(
+ f"Invalid pred_head_type: {self.pred_head_type}. Valid options: ['linear', 'dpt', 'dpt+pose']"
+ )
+
+ def _initialize_adaptors(self, pred_head_config):
+ """
+ Initialize the adaptors based on the prediction head configuration.
+
+ This method sets up the appropriate adaptors for different scene representation types,
+ such as pointmaps, ray maps with depth, or ray directions with depth and pose.
+
+ Args:
+ pred_head_config (Dict): Configuration for the prediction heads including adaptor type.
+
+ Returns:
+ None
+
+ Raises:
+ ValueError: If an invalid adaptor_type is provided.
+ AssertionError: If ray directions + depth + pose is used with an incompatible head type.
+ """
+ if pred_head_config["adaptor_type"] == "pointmap":
+ self.dense_adaptor = PointMapAdaptor(**pred_head_config["adaptor"])
+ self.scene_rep_type = "pointmap"
+ elif pred_head_config["adaptor_type"] == "pointmap+confidence":
+ self.dense_adaptor = PointMapWithConfidenceAdaptor(
+ **pred_head_config["adaptor"]
+ )
+ self.scene_rep_type = "pointmap+confidence"
+ elif pred_head_config["adaptor_type"] == "pointmap+mask":
+ self.dense_adaptor = PointMapWithMaskAdaptor(**pred_head_config["adaptor"])
+ self.scene_rep_type = "pointmap+mask"
+ elif pred_head_config["adaptor_type"] == "pointmap+confidence+mask":
+ self.dense_adaptor = PointMapWithConfidenceAndMaskAdaptor(
+ **pred_head_config["adaptor"]
+ )
+ self.scene_rep_type = "pointmap+confidence+mask"
+ elif pred_head_config["adaptor_type"] == "raymap+depth":
+ self.dense_adaptor = RayMapPlusDepthAdaptor(**pred_head_config["adaptor"])
+ self.scene_rep_type = "raymap+depth"
+ elif pred_head_config["adaptor_type"] == "raymap+depth+confidence":
+ self.dense_adaptor = RayMapPlusDepthWithConfidenceAdaptor(
+ **pred_head_config["adaptor"]
+ )
+ self.scene_rep_type = "raymap+depth+confidence"
+ elif pred_head_config["adaptor_type"] == "raymap+depth+mask":
+ self.dense_adaptor = RayMapPlusDepthWithMaskAdaptor(
+ **pred_head_config["adaptor"]
+ )
+ self.scene_rep_type = "raymap+depth+mask"
+ elif pred_head_config["adaptor_type"] == "raymap+depth+confidence+mask":
+ self.dense_adaptor = RayMapPlusDepthWithConfidenceAndMaskAdaptor(
+ **pred_head_config["adaptor"]
+ )
+ self.scene_rep_type = "raymap+depth+confidence+mask"
+ elif pred_head_config["adaptor_type"] == "raydirs+depth+pose":
+ assert self.pred_head_type == "dpt+pose", (
+ "Ray directions + depth + pose can only be used as scene representation with dpt + pose head."
+ )
+ self.dense_adaptor = RayDirectionsPlusDepthAdaptor(
+ **pred_head_config["dpt_adaptor"]
+ )
+ self.pose_adaptor = CamTranslationPlusQuatsAdaptor(
+ **pred_head_config["pose_adaptor"]
+ )
+ self.scene_rep_type = "raydirs+depth+pose"
+ elif pred_head_config["adaptor_type"] == "raydirs+depth+pose+confidence":
+ assert self.pred_head_type == "dpt+pose", (
+ "Ray directions + depth + pose can only be used as scene representation with dpt + pose head."
+ )
+ self.dense_adaptor = RayDirectionsPlusDepthWithConfidenceAdaptor(
+ **pred_head_config["dpt_adaptor"]
+ )
+ self.pose_adaptor = CamTranslationPlusQuatsAdaptor(
+ **pred_head_config["pose_adaptor"]
+ )
+ self.scene_rep_type = "raydirs+depth+pose+confidence"
+ elif pred_head_config["adaptor_type"] == "raydirs+depth+pose+mask":
+ assert self.pred_head_type == "dpt+pose", (
+ "Ray directions + depth + pose can only be used as scene representation with dpt + pose head."
+ )
+ self.dense_adaptor = RayDirectionsPlusDepthWithMaskAdaptor(
+ **pred_head_config["dpt_adaptor"]
+ )
+ self.pose_adaptor = CamTranslationPlusQuatsAdaptor(
+ **pred_head_config["pose_adaptor"]
+ )
+ self.scene_rep_type = "raydirs+depth+pose+mask"
+ elif pred_head_config["adaptor_type"] == "raydirs+depth+pose+confidence+mask":
+ assert self.pred_head_type == "dpt+pose", (
+ "Ray directions + depth + pose can only be used as scene representation with dpt + pose head."
+ )
+ self.dense_adaptor = RayDirectionsPlusDepthWithConfidenceAndMaskAdaptor(
+ **pred_head_config["dpt_adaptor"]
+ )
+ self.pose_adaptor = CamTranslationPlusQuatsAdaptor(
+ **pred_head_config["pose_adaptor"]
+ )
+ self.scene_rep_type = "raydirs+depth+pose+confidence+mask"
+ elif pred_head_config["adaptor_type"] == "campointmap+pose":
+ assert self.pred_head_type == "dpt+pose", (
+ "Camera pointmap + pose can only be used as scene representation with dpt + pose head."
+ )
+ self.dense_adaptor = PointMapAdaptor(**pred_head_config["dpt_adaptor"])
+ self.pose_adaptor = CamTranslationPlusQuatsAdaptor(
+ **pred_head_config["pose_adaptor"]
+ )
+ self.scene_rep_type = "campointmap+pose"
+ elif pred_head_config["adaptor_type"] == "campointmap+pose+confidence":
+ assert self.pred_head_type == "dpt+pose", (
+ "Camera pointmap + pose can only be used as scene representation with dpt + pose head."
+ )
+ self.dense_adaptor = PointMapWithConfidenceAdaptor(
+ **pred_head_config["dpt_adaptor"]
+ )
+ self.pose_adaptor = CamTranslationPlusQuatsAdaptor(
+ **pred_head_config["pose_adaptor"]
+ )
+ self.scene_rep_type = "campointmap+pose+confidence"
+ elif pred_head_config["adaptor_type"] == "campointmap+pose+mask":
+ assert self.pred_head_type == "dpt+pose", (
+ "Camera pointmap + pose can only be used as scene representation with dpt + pose head."
+ )
+ self.dense_adaptor = PointMapWithMaskAdaptor(
+ **pred_head_config["dpt_adaptor"]
+ )
+ self.pose_adaptor = CamTranslationPlusQuatsAdaptor(
+ **pred_head_config["pose_adaptor"]
+ )
+ self.scene_rep_type = "campointmap+pose+mask"
+ elif pred_head_config["adaptor_type"] == "campointmap+pose+confidence+mask":
+ assert self.pred_head_type == "dpt+pose", (
+ "Camera pointmap + pose can only be used as scene representation with dpt + pose head."
+ )
+ self.dense_adaptor = PointMapWithConfidenceAndMaskAdaptor(
+ **pred_head_config["dpt_adaptor"]
+ )
+ self.pose_adaptor = CamTranslationPlusQuatsAdaptor(
+ **pred_head_config["pose_adaptor"]
+ )
+ self.scene_rep_type = "campointmap+pose+confidence+mask"
+ elif pred_head_config["adaptor_type"] == "pointmap+raydirs+depth+pose":
+ assert self.pred_head_type == "dpt+pose", (
+ "Pointmap + ray directions + depth + pose can only be used as scene representation with dpt + pose head."
+ )
+ self.dense_adaptor = PointMapPlusRayDirectionsPlusDepthAdaptor(
+ **pred_head_config["dpt_adaptor"]
+ )
+ self.pose_adaptor = CamTranslationPlusQuatsAdaptor(
+ **pred_head_config["pose_adaptor"]
+ )
+ self.scene_rep_type = "pointmap+raydirs+depth+pose"
+ elif (
+ pred_head_config["adaptor_type"] == "pointmap+raydirs+depth+pose+confidence"
+ ):
+ assert self.pred_head_type == "dpt+pose", (
+ "Pointmap + ray directions + depth + pose can only be used as scene representation with dpt + pose head."
+ )
+ self.dense_adaptor = (
+ PointMapPlusRayDirectionsPlusDepthWithConfidenceAdaptor(
+ **pred_head_config["dpt_adaptor"]
+ )
+ )
+ self.pose_adaptor = CamTranslationPlusQuatsAdaptor(
+ **pred_head_config["pose_adaptor"]
+ )
+ self.scene_rep_type = "pointmap+raydirs+depth+pose+confidence"
+ elif pred_head_config["adaptor_type"] == "pointmap+raydirs+depth+pose+mask":
+ assert self.pred_head_type == "dpt+pose", (
+ "Pointmap + ray directions + depth + pose can only be used as scene representation with dpt + pose head."
+ )
+ self.dense_adaptor = PointMapPlusRayDirectionsPlusDepthWithMaskAdaptor(
+ **pred_head_config["dpt_adaptor"]
+ )
+ self.pose_adaptor = CamTranslationPlusQuatsAdaptor(
+ **pred_head_config["pose_adaptor"]
+ )
+ self.scene_rep_type = "pointmap+raydirs+depth+pose+mask"
+ elif (
+ pred_head_config["adaptor_type"]
+ == "pointmap+raydirs+depth+pose+confidence+mask"
+ ):
+ assert self.pred_head_type == "dpt+pose", (
+ "Pointmap + ray directions + depth + pose can only be used as scene representation with dpt + pose head."
+ )
+ self.dense_adaptor = (
+ PointMapPlusRayDirectionsPlusDepthWithConfidenceAndMaskAdaptor(
+ **pred_head_config["dpt_adaptor"]
+ )
+ )
+ self.pose_adaptor = CamTranslationPlusQuatsAdaptor(
+ **pred_head_config["pose_adaptor"]
+ )
+ self.scene_rep_type = "pointmap+raydirs+depth+pose+confidence+mask"
+ else:
+ raise ValueError(
+ f"Invalid adaptor_type: {pred_head_config['adaptor_type']}. \
+ Valid options: ['pointmap', 'raymap+depth', 'raydirs+depth+pose', 'campointmap+pose', 'pointmap+raydirs+depth+pose' \
+ 'pointmap+confidence', 'raymap+depth+confidence', 'raydirs+depth+pose+confidence', 'campointmap+pose+confidence', 'pointmap+raydirs+depth+pose+confidence' \
+ 'pointmap+mask', 'raymap+depth+mask', 'raydirs+depth+pose+mask', 'campointmap+pose+mask', 'pointmap+raydirs+depth+pose+mask' \
+ 'pointmap+confidence+mask', 'raymap+depth+confidence+mask', 'raydirs+depth+pose+confidence+mask', 'campointmap+pose+confidence+mask', 'pointmap+raydirs+depth+pose+confidence+mask']"
+ )
+
+ def _load_pretrained_weights(self):
+ """
+ Load pretrained weights from a checkpoint file.
+
+ If load_specific_pretrained_submodules is True, only loads weights for the specified submodules.
+ Otherwise, loads all weights from the checkpoint.
+
+ Returns:
+ None
+ """
+ if self.pretrained_checkpoint_path is not None:
+ if not self.load_specific_pretrained_submodules:
+ print(
+ f"Loading pretrained MapAnything weights from {self.pretrained_checkpoint_path} ..."
+ )
+ ckpt = torch.load(self.pretrained_checkpoint_path, weights_only=False)
+ print(self.load_state_dict(ckpt["model"]))
+ else:
+ print(
+ f"Loading pretrained MapAnything weights from {self.pretrained_checkpoint_path} for specific submodules: {self.specific_pretrained_submodules} ..."
+ )
+ ckpt = torch.load(self.pretrained_checkpoint_path, weights_only=False)
+ filtered_ckpt = {}
+ for ckpt_key, ckpt_value in ckpt["model"].items():
+ for submodule in self.specific_pretrained_submodules:
+ if ckpt_key.startswith(submodule):
+ filtered_ckpt[ckpt_key] = ckpt_value
+ print(self.load_state_dict(filtered_ckpt, strict=False))
+
+ def _encode_n_views(self, views):
+ """
+ Encode all the input views (batch of images) in a single forward pass.
+ Assumes all the input views have the same image shape, batch size, and data normalization type.
+
+ Args:
+ views (List[dict]): List of dictionaries containing the input views' images and instance information.
+
+ Returns:
+ List[torch.Tensor]: A list containing the encoded features for all N views.
+ """
+ num_views = len(views)
+ data_norm_type = views[0]["data_norm_type"][0]
+ imgs_list = [view["img"] for view in views]
+ all_imgs_across_views = torch.cat(imgs_list, dim=0)
+ encoder_input = ViTEncoderInput(
+ image=all_imgs_across_views, data_norm_type=data_norm_type
+ )
+ encoder_output = self.encoder(encoder_input)
+ all_encoder_features_across_views = encoder_output.features.chunk(
+ num_views, dim=0
+ )
+
+ return all_encoder_features_across_views
+
+ def _compute_pose_quats_and_trans_for_across_views_in_ref_view(
+ self,
+ views,
+ num_views,
+ device,
+ dtype,
+ batch_size_per_view,
+ per_sample_cam_input_mask,
+ ):
+ """
+ Compute the pose quats and trans for all the views in the frame of the reference view 0.
+ Returns identity pose for views where the camera input mask is False or the pose is not provided.
+
+ Args:
+ views (List[dict]): List of dictionaries containing the input views' images and instance information.
+ num_views (int): Number of views.
+ device (torch.device): Device to use for the computation.
+ dtype (torch.dtype): Data type to use for the computation.
+ per_sample_cam_input_mask (torch.Tensor): Tensor containing the per sample camera input mask.
+
+ Returns:
+ torch.Tensor: A tensor containing the pose quats for all the views in the frame of the reference view 0. (batch_size_per_view * view, 4)
+ torch.Tensor: A tensor containing the pose trans for all the views in the frame of the reference view 0. (batch_size_per_view * view, 3)
+ torch.Tensor: A tensor containing the per sample camera input mask.
+ """
+ # Compute the pose quats and trans for all the non-reference views in the frame of the reference view 0
+ pose_quats_non_ref_views = []
+ pose_trans_non_ref_views = []
+ pose_quats_ref_view_0 = []
+ pose_trans_ref_view_0 = []
+ for view_idx in range(num_views):
+ per_sample_cam_input_mask_for_curr_view = per_sample_cam_input_mask[
+ view_idx * batch_size_per_view : (view_idx + 1) * batch_size_per_view
+ ]
+ if (
+ "camera_pose_quats" in views[view_idx]
+ and "camera_pose_trans" in views[view_idx]
+ and per_sample_cam_input_mask_for_curr_view.any()
+ ):
+ # Get the camera pose quats and trans for the current view
+ cam_pose_quats = views[view_idx]["camera_pose_quats"][
+ per_sample_cam_input_mask_for_curr_view
+ ]
+ cam_pose_trans = views[view_idx]["camera_pose_trans"][
+ per_sample_cam_input_mask_for_curr_view
+ ]
+ # Append to the list
+ pose_quats_non_ref_views.append(cam_pose_quats)
+ pose_trans_non_ref_views.append(cam_pose_trans)
+ # Get the camera pose quats and trans for the reference view 0
+ cam_pose_quats = views[0]["camera_pose_quats"][
+ per_sample_cam_input_mask_for_curr_view
+ ]
+ cam_pose_trans = views[0]["camera_pose_trans"][
+ per_sample_cam_input_mask_for_curr_view
+ ]
+ # Append to the list
+ pose_quats_ref_view_0.append(cam_pose_quats)
+ pose_trans_ref_view_0.append(cam_pose_trans)
+ else:
+ per_sample_cam_input_mask[
+ view_idx * batch_size_per_view : (view_idx + 1)
+ * batch_size_per_view
+ ] = False
+
+ # Initialize the pose quats and trans for all views as identity
+ pose_quats_across_views = torch.tensor(
+ [0.0, 0.0, 0.0, 1.0], dtype=dtype, device=device
+ ).repeat(batch_size_per_view * num_views, 1) # (q_x, q_y, q_z, q_w)
+ pose_trans_across_views = torch.zeros(
+ (batch_size_per_view * num_views, 3), dtype=dtype, device=device
+ )
+
+ # Compute the pose quats and trans for all the non-reference views in the frame of the reference view 0
+ if len(pose_quats_non_ref_views) > 0:
+ # Stack the pose quats and trans for all the non-reference views and reference view 0
+ pose_quats_non_ref_views = torch.cat(pose_quats_non_ref_views, dim=0)
+ pose_trans_non_ref_views = torch.cat(pose_trans_non_ref_views, dim=0)
+ pose_quats_ref_view_0 = torch.cat(pose_quats_ref_view_0, dim=0)
+ pose_trans_ref_view_0 = torch.cat(pose_trans_ref_view_0, dim=0)
+
+ # Compute the pose quats and trans for all the non-reference views in the frame of the reference view 0
+ (
+ pose_quats_non_ref_views_in_ref_view_0,
+ pose_trans_non_ref_views_in_ref_view_0,
+ ) = transform_pose_using_quats_and_trans_2_to_1(
+ pose_quats_ref_view_0,
+ pose_trans_ref_view_0,
+ pose_quats_non_ref_views,
+ pose_trans_non_ref_views,
+ )
+
+ # Update the pose quats and trans for all the non-reference views
+ pose_quats_across_views[per_sample_cam_input_mask] = (
+ pose_quats_non_ref_views_in_ref_view_0.to(dtype=dtype)
+ )
+ pose_trans_across_views[per_sample_cam_input_mask] = (
+ pose_trans_non_ref_views_in_ref_view_0.to(dtype=dtype)
+ )
+
+ return (
+ pose_quats_across_views,
+ pose_trans_across_views,
+ per_sample_cam_input_mask,
+ )
+
+ def _encode_and_fuse_ray_dirs(
+ self,
+ views,
+ num_views,
+ batch_size_per_view,
+ all_encoder_features_across_views,
+ per_sample_ray_dirs_input_mask,
+ ):
+ """
+ Encode the ray directions for all the views and fuse it with the other encoder features in a single forward pass.
+
+ Args:
+ views (List[dict]): List of dictionaries containing the input views' images and instance information.
+ num_views (int): Number of views.
+ batch_size_per_view (int): Batch size per view.
+ all_encoder_features_across_views (torch.Tensor): Tensor containing the encoded features for all N views.
+ per_sample_ray_dirs_input_mask (torch.Tensor): Tensor containing the per sample ray direction input mask.
+
+ Returns:
+ torch.Tensor: A tensor containing the encoded features for all the views.
+ """
+ # Get the height and width of the images
+ _, _, height, width = views[0]["img"].shape
+
+ # Get the ray directions for all the views where info is provided and the ray direction input mask is True
+ ray_dirs_list = []
+ for view_idx in range(num_views):
+ per_sample_ray_dirs_input_mask_for_curr_view = (
+ per_sample_ray_dirs_input_mask[
+ view_idx * batch_size_per_view : (view_idx + 1)
+ * batch_size_per_view
+ ]
+ )
+ ray_dirs_for_curr_view = torch.zeros(
+ (batch_size_per_view, height, width, 3),
+ dtype=all_encoder_features_across_views.dtype,
+ device=all_encoder_features_across_views.device,
+ )
+ if (
+ "ray_directions_cam" in views[view_idx]
+ and per_sample_ray_dirs_input_mask_for_curr_view.any()
+ ):
+ ray_dirs_for_curr_view[per_sample_ray_dirs_input_mask_for_curr_view] = (
+ views[view_idx]["ray_directions_cam"][
+ per_sample_ray_dirs_input_mask_for_curr_view
+ ]
+ )
+ else:
+ per_sample_ray_dirs_input_mask[
+ view_idx * batch_size_per_view : (view_idx + 1)
+ * batch_size_per_view
+ ] = False
+ ray_dirs_list.append(ray_dirs_for_curr_view)
+
+ # Stack the ray directions for all the views and permute to (B * V, C, H, W)
+ ray_dirs = torch.cat(ray_dirs_list, dim=0) # (B * V, H, W, 3)
+ ray_dirs = ray_dirs.permute(0, 3, 1, 2).contiguous() # (B * V, 3, H, W)
+
+ # Encode the ray directions
+ ray_dirs_features_across_views = self.ray_dirs_encoder(
+ ViTEncoderNonImageInput(data=ray_dirs)
+ ).features
+
+ # Fuse the ray direction features with the other encoder features (zero out the features where the ray direction input mask is False)
+ ray_dirs_features_across_views = (
+ ray_dirs_features_across_views
+ * per_sample_ray_dirs_input_mask.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
+ )
+ all_encoder_features_across_views = (
+ all_encoder_features_across_views + ray_dirs_features_across_views
+ )
+
+ return all_encoder_features_across_views
+
+ def _encode_and_fuse_depths(
+ self,
+ views,
+ num_views,
+ batch_size_per_view,
+ all_encoder_features_across_views,
+ per_sample_depth_input_mask,
+ ):
+ """
+ Encode the z depths for all the views and fuse it with the other encoder features in a single forward pass.
+
+ Args:
+ views (List[dict]): List of dictionaries containing the input views' images and instance information.
+ num_views (int): Number of views.
+ batch_size_per_view (int): Batch size per view.
+ all_encoder_features_across_views (torch.Tensor): Tensor containing the encoded features for all N views.
+ per_sample_depth_input_mask (torch.Tensor): Tensor containing the per sample depth input mask.
+
+ Returns:
+ torch.Tensor: A tensor containing the encoded features for all the views.
+ """
+ # Get the device and height and width of the images
+ device = all_encoder_features_across_views.device
+ _, _, height, width = views[0]["img"].shape
+
+ # Decide to use randomly sampled sparse depth or dense depth
+ if torch.rand(1) < self.geometric_input_config["sparse_depth_prob"]:
+ use_sparse_depth = True
+ else:
+ use_sparse_depth = False
+
+ # Get the depths for all the views
+ depth_list = []
+ depth_norm_factors_list = []
+ metric_scale_depth_mask_list = []
+ for view_idx in range(num_views):
+ # Get the input mask for current view
+ per_sample_depth_input_mask_for_curr_view = per_sample_depth_input_mask[
+ view_idx * batch_size_per_view : (view_idx + 1) * batch_size_per_view
+ ]
+ depth_for_curr_view = torch.zeros(
+ (batch_size_per_view, height, width, 1),
+ dtype=all_encoder_features_across_views.dtype,
+ device=device,
+ )
+ depth_norm_factor_for_curr_view = torch.zeros(
+ (batch_size_per_view),
+ dtype=all_encoder_features_across_views.dtype,
+ device=device,
+ )
+ metric_scale_mask_for_curr_view = torch.zeros(
+ (batch_size_per_view),
+ dtype=torch.bool,
+ device=device,
+ )
+ if (
+ "depth_along_ray" in views[view_idx]
+ ) and per_sample_depth_input_mask_for_curr_view.any():
+ # Get depth for current view
+ depth_for_curr_view_input = views[view_idx]["depth_along_ray"][
+ per_sample_depth_input_mask_for_curr_view
+ ]
+ # Get the metric scale mask
+ if "is_metric_scale" in views[view_idx]:
+ metric_scale_mask = views[view_idx]["is_metric_scale"][
+ per_sample_depth_input_mask_for_curr_view
+ ]
+ else:
+ metric_scale_mask = torch.zeros(
+ depth_for_curr_view_input.shape[0],
+ dtype=torch.bool,
+ device=device,
+ )
+ # Turn off indication of metric scale samples based on the depth_scale_norm_all_prob
+ depth_scale_norm_all_mask = (
+ torch.rand(metric_scale_mask.shape[0])
+ < self.geometric_input_config["depth_scale_norm_all_prob"]
+ )
+ if depth_scale_norm_all_mask.any():
+ metric_scale_mask[depth_scale_norm_all_mask] = False
+ # Assign the metric scale mask to the respective indices
+ metric_scale_mask_for_curr_view[
+ per_sample_depth_input_mask_for_curr_view
+ ] = metric_scale_mask
+ # Sparsely sample the depth if required
+ if use_sparse_depth:
+ # Create a mask of ones
+ sparsification_mask = torch.ones_like(
+ depth_for_curr_view_input, device=device
+ )
+ # Create a mask for valid pixels (depth > 0)
+ valid_pixel_mask = depth_for_curr_view_input > 0
+ # Calculate the number of valid pixels
+ num_valid_pixels = valid_pixel_mask.sum().item()
+ # Calculate the number of valid pixels to set to zero
+ num_to_zero = int(
+ num_valid_pixels
+ * self.geometric_input_config["sparsification_removal_percent"]
+ )
+ if num_to_zero > 0:
+ # Get the indices of valid pixels
+ valid_indices = valid_pixel_mask.nonzero(as_tuple=True)
+ # Randomly select indices to zero out
+ indices_to_zero = torch.randperm(num_valid_pixels)[:num_to_zero]
+ # Set selected valid indices to zero in the mask
+ sparsification_mask[
+ valid_indices[0][indices_to_zero],
+ valid_indices[1][indices_to_zero],
+ valid_indices[2][indices_to_zero],
+ valid_indices[3][indices_to_zero],
+ ] = 0
+ # Apply the mask on the depth
+ depth_for_curr_view_input = (
+ depth_for_curr_view_input * sparsification_mask
+ )
+ # Normalize the depth
+ scaled_depth_for_curr_view_input, depth_norm_factor = (
+ normalize_depth_using_non_zero_pixels(
+ depth_for_curr_view_input, return_norm_factor=True
+ )
+ )
+ # Assign the depth and depth norm factor to the respective indices
+ depth_for_curr_view[per_sample_depth_input_mask_for_curr_view] = (
+ scaled_depth_for_curr_view_input
+ )
+ depth_norm_factor_for_curr_view[
+ per_sample_depth_input_mask_for_curr_view
+ ] = depth_norm_factor
+ else:
+ per_sample_depth_input_mask[
+ view_idx * batch_size_per_view : (view_idx + 1)
+ * batch_size_per_view
+ ] = False
+ # Append the depths, depth norm factor and metric scale mask for the current view
+ depth_list.append(depth_for_curr_view)
+ depth_norm_factors_list.append(depth_norm_factor_for_curr_view)
+ metric_scale_depth_mask_list.append(metric_scale_mask_for_curr_view)
+
+ # Stack the depths for all the views and permute to (B * V, C, H, W)
+ depths = torch.cat(depth_list, dim=0) # (B * V, H, W, 1)
+ depths = apply_log_to_norm(
+ depths
+ ) # Scale logarithimically (norm is computed along last dim)
+ depths = depths.permute(0, 3, 1, 2).contiguous() # (B * V, 1, H, W)
+ # Encode the depths using the depth encoder
+ depth_features_across_views = self.depth_encoder(
+ ViTEncoderNonImageInput(data=depths)
+ ).features
+ # Zero out the depth features where the depth input mask is False
+ depth_features_across_views = (
+ depth_features_across_views
+ * per_sample_depth_input_mask.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
+ )
+
+ # Stack the depth norm factors for all the views
+ depth_norm_factors = torch.cat(depth_norm_factors_list, dim=0) # (B * V, )
+ # Encode the depth norm factors using the log scale encoder for depth
+ log_depth_norm_factors = torch.log(depth_norm_factors + 1e-8) # (B * V, )
+ depth_scale_features_across_views = self.depth_scale_encoder(
+ EncoderGlobalRepInput(data=log_depth_norm_factors.unsqueeze(-1))
+ ).features
+ # Zero out the depth scale features where the depth input mask is False
+ depth_scale_features_across_views = (
+ depth_scale_features_across_views
+ * per_sample_depth_input_mask.unsqueeze(-1)
+ )
+ # Stack the metric scale mask for all the views
+ metric_scale_depth_mask = torch.cat(
+ metric_scale_depth_mask_list, dim=0
+ ) # (B * V, )
+ # Zero out the depth scale features where the metric scale mask is False
+ # Scale encoding is only provided for metric scale samples
+ depth_scale_features_across_views = (
+ depth_scale_features_across_views * metric_scale_depth_mask.unsqueeze(-1)
+ )
+
+ # Fuse the depth features & depth scale features with the other encoder features
+ all_encoder_features_across_views = (
+ all_encoder_features_across_views
+ + depth_features_across_views
+ + depth_scale_features_across_views.unsqueeze(-1).unsqueeze(-1)
+ )
+
+ return all_encoder_features_across_views
+
+ def _encode_and_fuse_cam_quats_and_trans(
+ self,
+ views,
+ num_views,
+ batch_size_per_view,
+ all_encoder_features_across_views,
+ pose_quats_across_views,
+ pose_trans_across_views,
+ per_sample_cam_input_mask,
+ ):
+ """
+ Encode the camera quats and trans for all the views and fuse it with the other encoder features in a single forward pass.
+
+ Args:
+ views (List[dict]): List of dictionaries containing the input views' images and instance information.
+ num_views (int): Number of views.
+ batch_size_per_view (int): Batch size per view.
+ all_encoder_features_across_views (torch.Tensor): Tensor containing the encoded features for all N views.
+ pose_quats_across_views (torch.Tensor): Tensor containing the pose quats for all the views in the frame of the reference view 0. (batch_size_per_view * view, 4)
+ pose_trans_across_views (torch.Tensor): Tensor containing the pose trans for all the views in the frame of the reference view 0. (batch_size_per_view * view, 3)
+ per_sample_cam_input_mask (torch.Tensor): Tensor containing the per sample camera input mask.
+
+ Returns:
+ torch.Tensor: A tensor containing the encoded features for all the views.
+ """
+ # Encode the pose quats
+ pose_quats_features_across_views = self.cam_rot_encoder(
+ EncoderGlobalRepInput(data=pose_quats_across_views)
+ ).features
+ # Zero out the pose quat features where the camera input mask is False
+ pose_quats_features_across_views = (
+ pose_quats_features_across_views * per_sample_cam_input_mask.unsqueeze(-1)
+ )
+
+ # Get the metric scale mask for all samples
+ device = all_encoder_features_across_views.device
+ metric_scale_pose_trans_mask = torch.zeros(
+ (batch_size_per_view * num_views), dtype=torch.bool, device=device
+ )
+ for view_idx in range(num_views):
+ if "is_metric_scale" in views[view_idx]:
+ # Get the metric scale mask for the input pose priors
+ metric_scale_mask = views[view_idx]["is_metric_scale"]
+ else:
+ metric_scale_mask = torch.zeros(
+ batch_size_per_view, dtype=torch.bool, device=device
+ )
+ metric_scale_pose_trans_mask[
+ view_idx * batch_size_per_view : (view_idx + 1) * batch_size_per_view
+ ] = metric_scale_mask
+
+ # Turn off indication of metric scale samples based on the pose_scale_norm_all_prob
+ pose_norm_all_mask = (
+ torch.rand(batch_size_per_view * num_views)
+ < self.geometric_input_config["pose_scale_norm_all_prob"]
+ )
+ if pose_norm_all_mask.any():
+ metric_scale_pose_trans_mask[pose_norm_all_mask] = False
+
+ # Get the scale norm factor for all the samples and scale the pose translations
+ pose_trans_across_views = torch.split(
+ pose_trans_across_views, batch_size_per_view, dim=0
+ ) # Split into num_views chunks
+ pose_trans_across_views = torch.stack(
+ pose_trans_across_views, dim=1
+ ) # Stack the views along a new dimension (batch_size_per_view, num_views, 3)
+ scaled_pose_trans_across_views, pose_trans_norm_factors = (
+ normalize_pose_translations(
+ pose_trans_across_views, return_norm_factor=True
+ )
+ )
+
+ # Resize the pose translation back to (batch_size_per_view * num_views, 3) and extend the norm factor to (batch_size_per_view * num_views, 1)
+ scaled_pose_trans_across_views = scaled_pose_trans_across_views.unbind(
+ dim=1
+ ) # Convert back to list of views, where each view has batch_size_per_view tensor
+ scaled_pose_trans_across_views = torch.cat(
+ scaled_pose_trans_across_views, dim=0
+ ) # Concatenate back to (batch_size_per_view * num_views, 3)
+ pose_trans_norm_factors_across_views = pose_trans_norm_factors.unsqueeze(
+ -1
+ ).repeat(num_views, 1) # (B, ) -> (B * V, 1)
+
+ # Encode the pose trans
+ pose_trans_features_across_views = self.cam_trans_encoder(
+ EncoderGlobalRepInput(data=scaled_pose_trans_across_views)
+ ).features
+ # Zero out the pose trans features where the camera input mask is False
+ pose_trans_features_across_views = (
+ pose_trans_features_across_views * per_sample_cam_input_mask.unsqueeze(-1)
+ )
+
+ # Encode the pose translation norm factors using the log scale encoder for pose trans
+ log_pose_trans_norm_factors_across_views = torch.log(
+ pose_trans_norm_factors_across_views + 1e-8
+ )
+ pose_trans_scale_features_across_views = self.cam_trans_scale_encoder(
+ EncoderGlobalRepInput(data=log_pose_trans_norm_factors_across_views)
+ ).features
+ # Zero out the pose trans scale features where the camera input mask is False
+ pose_trans_scale_features_across_views = (
+ pose_trans_scale_features_across_views
+ * per_sample_cam_input_mask.unsqueeze(-1)
+ )
+ # Zero out the pose trans scale features where the metric scale mask is False
+ # Scale encoding is only provided for metric scale samples
+ pose_trans_scale_features_across_views = (
+ pose_trans_scale_features_across_views
+ * metric_scale_pose_trans_mask.unsqueeze(-1)
+ )
+
+ # Fuse the pose quat features, pose trans features, pose trans scale features and pose trans type PE features with the other encoder features
+ all_encoder_features_across_views = (
+ all_encoder_features_across_views
+ + pose_quats_features_across_views.unsqueeze(-1).unsqueeze(-1)
+ + pose_trans_features_across_views.unsqueeze(-1).unsqueeze(-1)
+ + pose_trans_scale_features_across_views.unsqueeze(-1).unsqueeze(-1)
+ )
+
+ return all_encoder_features_across_views
+
+ def _encode_and_fuse_optional_geometric_inputs(
+ self, views, all_encoder_features_across_views_list
+ ):
+ """
+ Encode all the input optional geometric modalities and fuses it with the image encoder features in a single forward pass.
+ Assumes all the input views have the same shape and batch size.
+
+ Args:
+ views (List[dict]): List of dictionaries containing the input views' images and instance information.
+ all_encoder_features_across_views (List[torch.Tensor]): List of tensors containing the encoded image features for all N views.
+
+ Returns:
+ List[torch.Tensor]: A list containing the encoded features for all N views.
+ """
+ num_views = len(views)
+ batch_size_per_view, _, _, _ = views[0]["img"].shape
+ device = all_encoder_features_across_views_list[0].device
+ dtype = all_encoder_features_across_views_list[0].dtype
+ all_encoder_features_across_views = torch.cat(
+ all_encoder_features_across_views_list, dim=0
+ )
+
+ # Get the overall input mask for all the views
+ overall_geometric_input_mask = (
+ torch.rand(batch_size_per_view, device=device)
+ < self.geometric_input_config["overall_prob"]
+ )
+ overall_geometric_input_mask = overall_geometric_input_mask.repeat(num_views)
+
+ # Get the per sample input mask after dropout
+ # Per sample input mask is in view-major order so that index v*B + b in each mask corresponds to sample b of view v: (B * V)
+ per_sample_geometric_input_mask = torch.rand(
+ batch_size_per_view * num_views, device=device
+ ) < (1 - self.geometric_input_config["dropout_prob"])
+ per_sample_geometric_input_mask = (
+ per_sample_geometric_input_mask & overall_geometric_input_mask
+ )
+
+ # Get the ray direction input mask
+ per_sample_ray_dirs_input_mask = (
+ torch.rand(batch_size_per_view, device=device)
+ < self.geometric_input_config["ray_dirs_prob"]
+ )
+ per_sample_ray_dirs_input_mask = per_sample_ray_dirs_input_mask.repeat(
+ num_views
+ )
+ per_sample_ray_dirs_input_mask = (
+ per_sample_ray_dirs_input_mask & per_sample_geometric_input_mask
+ )
+
+ # Get the depth input mask
+ per_sample_depth_input_mask = (
+ torch.rand(batch_size_per_view, device=device)
+ < self.geometric_input_config["depth_prob"]
+ )
+ per_sample_depth_input_mask = per_sample_depth_input_mask.repeat(num_views)
+ per_sample_depth_input_mask = (
+ per_sample_depth_input_mask & per_sample_geometric_input_mask
+ )
+
+ # Get the camera input mask
+ per_sample_cam_input_mask = (
+ torch.rand(batch_size_per_view, device=device)
+ < self.geometric_input_config["cam_prob"]
+ )
+ per_sample_cam_input_mask = per_sample_cam_input_mask.repeat(num_views)
+ per_sample_cam_input_mask = (
+ per_sample_cam_input_mask & per_sample_geometric_input_mask
+ )
+
+ # Compute the pose quats and trans for all the non-reference views in the frame of the reference view 0
+ # Returned pose quats and trans represent identity pose for views/samples where the camera input mask is False
+ pose_quats_across_views, pose_trans_across_views, per_sample_cam_input_mask = (
+ self._compute_pose_quats_and_trans_for_across_views_in_ref_view(
+ views,
+ num_views,
+ device,
+ dtype,
+ batch_size_per_view,
+ per_sample_cam_input_mask,
+ )
+ )
+
+ # Encode the ray directions and fuse with the image encoder features
+ all_encoder_features_across_views = self._encode_and_fuse_ray_dirs(
+ views,
+ num_views,
+ batch_size_per_view,
+ all_encoder_features_across_views,
+ per_sample_ray_dirs_input_mask,
+ )
+
+ # Encode the depths and fuse with the image encoder features
+ all_encoder_features_across_views = self._encode_and_fuse_depths(
+ views,
+ num_views,
+ batch_size_per_view,
+ all_encoder_features_across_views,
+ per_sample_depth_input_mask,
+ )
+
+ # Encode the cam quat and trans and fuse with the image encoder features
+ all_encoder_features_across_views = self._encode_and_fuse_cam_quats_and_trans(
+ views,
+ num_views,
+ batch_size_per_view,
+ all_encoder_features_across_views,
+ pose_quats_across_views,
+ pose_trans_across_views,
+ per_sample_cam_input_mask,
+ )
+
+ # Normalize the fused features (permute -> normalize -> permute)
+ all_encoder_features_across_views = all_encoder_features_across_views.permute(
+ 0, 2, 3, 1
+ ).contiguous()
+ all_encoder_features_across_views = self.fusion_norm_layer(
+ all_encoder_features_across_views
+ )
+ all_encoder_features_across_views = all_encoder_features_across_views.permute(
+ 0, 3, 1, 2
+ ).contiguous()
+
+ # Split the batched views into individual views
+ fused_all_encoder_features_across_views = (
+ all_encoder_features_across_views.chunk(num_views, dim=0)
+ )
+
+ return fused_all_encoder_features_across_views
+
+ def forward(self, views):
+ """
+ Forward pass performing the following operations:
+ 1. Encodes the N input views (images).
+ 2. Encodes the optional geometric inputs (ray directions, depths, camera rotations, camera translations).
+ 3. Fuses the encoded features from the N input views and the optional geometric inputs using addition and normalization.
+ 4. Information sharing between the encoded features using a multi-view attention transformer.
+ 5. Passes the final features through the prediction heads.
+ 6. Returns the processed final outputs for N views.
+
+ Assumption:
+ - All the input views have the same image shape.
+
+ Args:
+ views (List[dict]): List of dictionaries containing the input views' images and instance information.
+ Each dictionary should contain the following keys:
+ "img" (tensor): Image tensor of shape (B, C, H, W). Input images must be normalized based on the data norm type of image encoder.
+ "data_norm_type" (list): [model.encoder.data_norm_type]
+
+ Returns:
+ List[dict]: A list containing the final outputs for all N views.
+ """
+ # Get input shape of the images, number of views, and batch size per view
+ batch_size_per_view, _, height, width = views[0]["img"].shape
+ img_shape = (int(height), int(width))
+ num_views = len(views)
+
+ # Run the encoder on all the input views
+ all_encoder_features_across_views = self._encode_n_views(views)
+
+ # Encode the optional geometric inputs and fuse with the encoded features from the N input views
+ # Use high precision to prevent NaN values after layer norm in dense representation encoder (due to high variance in last dim of features)
+ with torch.autocast("cuda", enabled=False):
+ all_encoder_features_across_views = (
+ self._encode_and_fuse_optional_geometric_inputs(
+ views, all_encoder_features_across_views
+ )
+ )
+
+ # Combine all images into view-centric representation
+ info_sharing_input = MultiViewTransformerInput(
+ features=all_encoder_features_across_views
+ )
+ if self.info_sharing_return_type == "no_intermediate_features":
+ final_info_sharing_multi_view_feat = self.info_sharing(info_sharing_input)
+ elif self.info_sharing_return_type == "intermediate_features":
+ (
+ final_info_sharing_multi_view_feat,
+ intermediate_info_sharing_multi_view_feat,
+ ) = self.info_sharing(info_sharing_input)
+
+ if self.pred_head_type == "linear":
+ # Stack the features for all views
+ dense_head_inputs = torch.cat(
+ final_info_sharing_multi_view_feat.features, dim=0
+ )
+ elif self.pred_head_type in ["dpt", "dpt+pose"]:
+ # Get the list of features for all views
+ dense_head_inputs_list = []
+ if self.use_encoder_features_for_dpt:
+ # Stack all the image encoder features for all views
+ stacked_encoder_features = torch.cat(
+ all_encoder_features_across_views, dim=0
+ )
+ dense_head_inputs_list.append(stacked_encoder_features)
+ # Stack the first intermediate features for all views
+ stacked_intermediate_features_1 = torch.cat(
+ intermediate_info_sharing_multi_view_feat[0].features, dim=0
+ )
+ dense_head_inputs_list.append(stacked_intermediate_features_1)
+ # Stack the second intermediate features for all views
+ stacked_intermediate_features_2 = torch.cat(
+ intermediate_info_sharing_multi_view_feat[1].features, dim=0
+ )
+ dense_head_inputs_list.append(stacked_intermediate_features_2)
+ # Stack the last layer features for all views
+ stacked_final_features = torch.cat(
+ final_info_sharing_multi_view_feat.features, dim=0
+ )
+ dense_head_inputs_list.append(stacked_final_features)
+ else:
+ # Stack the first intermediate features for all views
+ stacked_intermediate_features_1 = torch.cat(
+ intermediate_info_sharing_multi_view_feat[0].features, dim=0
+ )
+ dense_head_inputs_list.append(stacked_intermediate_features_1)
+ # Stack the second intermediate features for all views
+ stacked_intermediate_features_2 = torch.cat(
+ intermediate_info_sharing_multi_view_feat[1].features, dim=0
+ )
+ dense_head_inputs_list.append(stacked_intermediate_features_2)
+ # Stack the third intermediate features for all views
+ stacked_intermediate_features_3 = torch.cat(
+ intermediate_info_sharing_multi_view_feat[2].features, dim=0
+ )
+ dense_head_inputs_list.append(stacked_intermediate_features_3)
+ # Stack the last layer
+ stacked_final_features = torch.cat(
+ final_info_sharing_multi_view_feat.features, dim=0
+ )
+ dense_head_inputs_list.append(stacked_final_features)
+ else:
+ raise ValueError(
+ f"Invalid pred_head_type: {self.pred_head_type}. Valid options: ['linear', 'dpt', 'dpt+pose']"
+ )
+
+ # Downstream task prediction
+ with torch.autocast("cuda", enabled=False):
+ # Run Prediction Heads & Post-Process Outputs
+ if self.pred_head_type == "linear":
+ dense_head_outputs = self.dense_head(
+ PredictionHeadInput(last_feature=dense_head_inputs)
+ )
+ dense_final_outputs = self.dense_adaptor(
+ AdaptorInput(
+ adaptor_feature=dense_head_outputs.decoded_channels,
+ output_shape_hw=img_shape,
+ )
+ )
+ elif self.pred_head_type == "dpt":
+ dense_head_outputs = self.dense_head(
+ PredictionHeadLayeredInput(
+ list_features=dense_head_inputs_list,
+ target_output_shape=img_shape,
+ )
+ )
+ dense_final_outputs = self.dense_adaptor(
+ AdaptorInput(
+ adaptor_feature=dense_head_outputs.decoded_channels,
+ output_shape_hw=img_shape,
+ )
+ )
+ elif self.pred_head_type == "dpt+pose":
+ dense_head_outputs = self.dense_head(
+ PredictionHeadLayeredInput(
+ list_features=dense_head_inputs_list,
+ target_output_shape=img_shape,
+ )
+ )
+ dense_final_outputs = self.dense_adaptor(
+ AdaptorInput(
+ adaptor_feature=dense_head_outputs.decoded_channels,
+ output_shape_hw=img_shape,
+ )
+ )
+ pose_head_outputs = self.pose_head(
+ PredictionHeadInput(last_feature=dense_head_inputs_list[-1])
+ )
+ pose_final_outputs = self.pose_adaptor(
+ AdaptorInput(
+ adaptor_feature=pose_head_outputs.decoded_channels,
+ output_shape_hw=img_shape,
+ )
+ )
+ else:
+ raise ValueError(
+ f"Invalid pred_head_type: {self.pred_head_type}. Valid options: ['linear', 'dpt', 'dpt+pose']"
+ )
+
+ # Prepare the final scene representation for all views
+ if self.scene_rep_type in [
+ "pointmap",
+ "pointmap+confidence",
+ "pointmap+mask",
+ "pointmap+confidence+mask",
+ ]:
+ output_pts3d = dense_final_outputs.value
+ # Reshape final scene representation to (B * V, H, W, C)
+ output_pts3d = output_pts3d.permute(0, 2, 3, 1).contiguous()
+ # Split the predicted pointmaps back to their respective views
+ output_pts3d_per_view = output_pts3d.chunk(num_views, dim=0)
+ # Pack the output as a list of dictionaries
+ res = []
+ for i in range(num_views):
+ res.append({"pts3d": output_pts3d_per_view[i]})
+ elif self.scene_rep_type in [
+ "raymap+depth",
+ "raymap+depth+confidence",
+ "raymap+depth+mask",
+ "raymap+depth+confidence+mask",
+ ]:
+ # Reshape final scene representation to (B * V, H, W, C)
+ output_scene_rep = dense_final_outputs.value.permute(
+ 0, 2, 3, 1
+ ).contiguous()
+ # Get the predicted ray origins, directions, and depths along rays
+ output_ray_origins, output_ray_directions, output_depth_along_ray = (
+ output_scene_rep.split([3, 3, 1], dim=-1)
+ )
+ # Get the predicted pointmaps
+ output_pts3d = (
+ output_ray_origins + output_ray_directions * output_depth_along_ray
+ )
+ # Split the predicted quantities back to their respective views
+ output_ray_origins_per_view = output_ray_origins.chunk(num_views, dim=0)
+ output_ray_directions_per_view = output_ray_directions.chunk(
+ num_views, dim=0
+ )
+ output_depth_along_ray_per_view = output_depth_along_ray.chunk(
+ num_views, dim=0
+ )
+ output_pts3d_per_view = output_pts3d.chunk(num_views, dim=0)
+ # Pack the output as a list of dictionaries
+ res = []
+ for i in range(num_views):
+ res.append(
+ {
+ "pts3d": output_pts3d_per_view[i],
+ "ray_origins": output_ray_origins_per_view[i],
+ "ray_directions": output_ray_directions_per_view[i],
+ "depth_along_ray": output_depth_along_ray_per_view[i],
+ }
+ )
+ elif self.scene_rep_type in [
+ "raydirs+depth+pose",
+ "raydirs+depth+pose+confidence",
+ "raydirs+depth+pose+mask",
+ "raydirs+depth+pose+confidence+mask",
+ ]:
+ # Reshape output dense rep to (B * V, H, W, C)
+ output_dense_rep = dense_final_outputs.value.permute(
+ 0, 2, 3, 1
+ ).contiguous()
+ # Get the predicted ray directions and depths along rays
+ output_ray_directions, output_depth_along_ray = output_dense_rep.split(
+ [3, 1], dim=-1
+ )
+ # Get the predicted camera translations and quaternions
+ output_cam_translations, output_cam_quats = (
+ pose_final_outputs.value.split([3, 4], dim=-1)
+ )
+ # Get the predicted pointmaps in world frame and camera frame
+ output_pts3d = (
+ convert_ray_dirs_depth_along_ray_pose_trans_quats_to_pointmap(
+ output_ray_directions,
+ output_depth_along_ray,
+ output_cam_translations,
+ output_cam_quats,
+ )
+ )
+ output_pts3d_cam = output_ray_directions * output_depth_along_ray
+ # Split the predicted quantities back to their respective views
+ output_ray_directions_per_view = output_ray_directions.chunk(
+ num_views, dim=0
+ )
+ output_depth_along_ray_per_view = output_depth_along_ray.chunk(
+ num_views, dim=0
+ )
+ output_cam_translations_per_view = output_cam_translations.chunk(
+ num_views, dim=0
+ )
+ output_cam_quats_per_view = output_cam_quats.chunk(num_views, dim=0)
+ output_pts3d_per_view = output_pts3d.chunk(num_views, dim=0)
+ output_pts3d_cam_per_view = output_pts3d_cam.chunk(num_views, dim=0)
+ # Pack the output as a list of dictionaries
+ res = []
+ for i in range(num_views):
+ res.append(
+ {
+ "pts3d": output_pts3d_per_view[i],
+ "pts3d_cam": output_pts3d_cam_per_view[i],
+ "ray_directions": output_ray_directions_per_view[i],
+ "depth_along_ray": output_depth_along_ray_per_view[i],
+ "cam_trans": output_cam_translations_per_view[i],
+ "cam_quats": output_cam_quats_per_view[i],
+ }
+ )
+ elif self.scene_rep_type in [
+ "campointmap+pose",
+ "campointmap+pose+confidence",
+ "campointmap+pose+mask",
+ "campointmap+pose+confidence+mask",
+ ]:
+ # Get the predicted camera frame pointmaps
+ output_pts3d_cam = dense_final_outputs.value
+ # Reshape final scene representation to (B * V, H, W, C)
+ output_pts3d_cam = output_pts3d_cam.permute(0, 2, 3, 1).contiguous()
+ # Get the predicted camera translations and quaternions
+ output_cam_translations, output_cam_quats = (
+ pose_final_outputs.value.split([3, 4], dim=-1)
+ )
+ # Get the ray directions and depths along rays
+ output_depth_along_ray = torch.norm(
+ output_pts3d_cam, dim=-1, keepdim=True
+ )
+ output_ray_directions = output_pts3d_cam / output_depth_along_ray
+ # Get the predicted pointmaps in world frame
+ output_pts3d = (
+ convert_ray_dirs_depth_along_ray_pose_trans_quats_to_pointmap(
+ output_ray_directions,
+ output_depth_along_ray,
+ output_cam_translations,
+ output_cam_quats,
+ )
+ )
+ # Split the predicted quantities back to their respective views
+ output_ray_directions_per_view = output_ray_directions.chunk(
+ num_views, dim=0
+ )
+ output_depth_along_ray_per_view = output_depth_along_ray.chunk(
+ num_views, dim=0
+ )
+ output_cam_translations_per_view = output_cam_translations.chunk(
+ num_views, dim=0
+ )
+ output_cam_quats_per_view = output_cam_quats.chunk(num_views, dim=0)
+ output_pts3d_per_view = output_pts3d.chunk(num_views, dim=0)
+ output_pts3d_cam_per_view = output_pts3d_cam.chunk(num_views, dim=0)
+ # Pack the output as a list of dictionaries
+ res = []
+ for i in range(num_views):
+ res.append(
+ {
+ "pts3d": output_pts3d_per_view[i],
+ "pts3d_cam": output_pts3d_cam_per_view[i],
+ "ray_directions": output_ray_directions_per_view[i],
+ "depth_along_ray": output_depth_along_ray_per_view[i],
+ "cam_trans": output_cam_translations_per_view[i],
+ "cam_quats": output_cam_quats_per_view[i],
+ }
+ )
+ elif self.scene_rep_type in [
+ "pointmap+raydirs+depth+pose",
+ "pointmap+raydirs+depth+pose+confidence",
+ "pointmap+raydirs+depth+pose+mask",
+ "pointmap+raydirs+depth+pose+confidence+mask",
+ ]:
+ # Reshape final scene representation to (B * V, H, W, C)
+ output_dense_rep = dense_final_outputs.value.permute(
+ 0, 2, 3, 1
+ ).contiguous()
+ # Get the predicted pointmaps, ray directions and depths along rays
+ output_pts3d, output_ray_directions, output_depth_along_ray = (
+ output_dense_rep.split([3, 3, 1], dim=-1)
+ )
+ # Get the predicted camera translations and quaternions
+ output_cam_translations, output_cam_quats = (
+ pose_final_outputs.value.split([3, 4], dim=-1)
+ )
+ # Get the predicted pointmaps in camera frame
+ output_pts3d_cam = output_ray_directions * output_depth_along_ray
+ # Replace the predicted world-frame pointmaps if required
+ if self.pred_head_config["adaptor_config"][
+ "use_factored_predictions_for_global_pointmaps"
+ ]:
+ output_pts3d = (
+ convert_ray_dirs_depth_along_ray_pose_trans_quats_to_pointmap(
+ output_ray_directions,
+ output_depth_along_ray,
+ output_cam_translations,
+ output_cam_quats,
+ )
+ )
+ # Split the predicted quantities back to their respective views
+ output_ray_directions_per_view = output_ray_directions.chunk(
+ num_views, dim=0
+ )
+ output_depth_along_ray_per_view = output_depth_along_ray.chunk(
+ num_views, dim=0
+ )
+ output_cam_translations_per_view = output_cam_translations.chunk(
+ num_views, dim=0
+ )
+ output_cam_quats_per_view = output_cam_quats.chunk(num_views, dim=0)
+ output_pts3d_per_view = output_pts3d.chunk(num_views, dim=0)
+ output_pts3d_cam_per_view = output_pts3d_cam.chunk(num_views, dim=0)
+ # Pack the output as a list of dictionaries
+ res = []
+ for i in range(num_views):
+ res.append(
+ {
+ "pts3d": output_pts3d_per_view[i],
+ "pts3d_cam": output_pts3d_cam_per_view[i],
+ "ray_directions": output_ray_directions_per_view[i],
+ "depth_along_ray": output_depth_along_ray_per_view[i],
+ "cam_trans": output_cam_translations_per_view[i],
+ "cam_quats": output_cam_quats_per_view[i],
+ }
+ )
+ else:
+ raise ValueError(
+ f"Invalid scene_rep_type: {self.scene_rep_type}. \
+ Valid options: ['pointmap', 'raymap+depth', 'raydirs+depth+pose', 'campointmap+pose', 'pointmap+raydirs+depth+pose' \
+ 'pointmap+confidence', 'raymap+depth+confidence', 'raydirs+depth+pose+confidence', 'campointmap+pose+confidence', 'pointmap+raydirs+depth+pose+confidence' \
+ 'pointmap+mask', 'raymap+depth+mask', 'raydirs+depth+pose+mask', 'campointmap+pose+mask', 'pointmap+raydirs+depth+pose+mask' \
+ 'pointmap+confidence+mask', 'raymap+depth+confidence+mask', 'raydirs+depth+pose+confidence+mask', 'campointmap+pose+confidence+mask', 'pointmap+raydirs+depth+pose+confidence+mask']"
+ )
+
+ # Get the output confidences for all views (if available) and add them to the result
+ if "confidence" in self.scene_rep_type:
+ output_confidences = dense_final_outputs.confidence
+ # Reshape confidences to (B * V, H, W)
+ output_confidences = (
+ output_confidences.permute(0, 2, 3, 1).squeeze(-1).contiguous()
+ )
+ # Split the predicted confidences back to their respective views
+ output_confidences_per_view = output_confidences.chunk(num_views, dim=0)
+ # Add the confidences to the result
+ for i in range(num_views):
+ res[i]["conf"] = output_confidences_per_view[i]
+
+ # Get the output masks (and logits) for all views (if available) and add them to the result
+ if "mask" in self.scene_rep_type:
+ # Get the output masks
+ output_masks = dense_final_outputs.mask
+ # Reshape masks to (B * V, H, W)
+ output_masks = output_masks.permute(0, 2, 3, 1).squeeze(-1).contiguous()
+ # Threshold the masks at 0.5 to get binary masks (0: ambiguous/invalid, 1: non-ambiguous/valid)
+ output_masks = output_masks > 0.5
+ # Split the predicted masks back to their respective views
+ output_masks_per_view = output_masks.chunk(num_views, dim=0)
+ # Get the output mask logits (for loss)
+ output_mask_logits = dense_final_outputs.logits
+ # Reshape mask logits to (B * V, H, W)
+ output_mask_logits = (
+ output_mask_logits.permute(0, 2, 3, 1).squeeze(-1).contiguous()
+ )
+ # Split the predicted mask logits back to their respective views
+ output_mask_logits_per_view = output_mask_logits.chunk(num_views, dim=0)
+ # Add the masks and logits to the result
+ for i in range(num_views):
+ res[i]["non_ambiguous_mask"] = output_masks_per_view[i]
+ res[i]["non_ambiguous_mask_logits"] = output_mask_logits_per_view[i]
+
+ return res
diff --git a/mapanything/models/mapanything/model.py b/mapanything/models/mapanything/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..82c1a15a59a5b645ee44926b71956447bcdf3b69
--- /dev/null
+++ b/mapanything/models/mapanything/model.py
@@ -0,0 +1,1719 @@
+"""
+MapAnything model class defined using UniCeption modules.
+"""
+
+from functools import partial
+from typing import Callable, Dict, Type, Union
+
+import torch
+import torch.nn as nn
+
+from mapanything.utils.geometry import (
+ apply_log_to_norm,
+ convert_ray_dirs_depth_along_ray_pose_trans_quats_to_pointmap,
+ normalize_depth_using_non_zero_pixels,
+ normalize_pose_translations,
+ transform_pose_using_quats_and_trans_2_to_1,
+)
+from uniception.models.encoders import (
+ encoder_factory,
+ EncoderGlobalRepInput,
+ ViTEncoderInput,
+ ViTEncoderNonImageInput,
+)
+from uniception.models.info_sharing.alternating_attention_transformer import (
+ MultiViewAlternatingAttentionTransformer,
+ MultiViewAlternatingAttentionTransformerIFR,
+)
+from uniception.models.info_sharing.base import MultiViewTransformerInput
+from uniception.models.info_sharing.cross_attention_transformer import (
+ MultiViewCrossAttentionTransformer,
+ MultiViewCrossAttentionTransformerIFR,
+)
+from uniception.models.info_sharing.global_attention_transformer import (
+ MultiViewGlobalAttentionTransformer,
+ MultiViewGlobalAttentionTransformerIFR,
+)
+from uniception.models.prediction_heads.adaptors import (
+ CamTranslationPlusQuatsAdaptor,
+ PointMapAdaptor,
+ PointMapPlusRayDirectionsPlusDepthAdaptor,
+ PointMapPlusRayDirectionsPlusDepthWithConfidenceAdaptor,
+ PointMapPlusRayDirectionsPlusDepthWithConfidenceAndMaskAdaptor,
+ PointMapPlusRayDirectionsPlusDepthWithMaskAdaptor,
+ PointMapWithConfidenceAdaptor,
+ PointMapWithConfidenceAndMaskAdaptor,
+ PointMapWithMaskAdaptor,
+ RayDirectionsPlusDepthAdaptor,
+ RayDirectionsPlusDepthWithConfidenceAdaptor,
+ RayDirectionsPlusDepthWithConfidenceAndMaskAdaptor,
+ RayDirectionsPlusDepthWithMaskAdaptor,
+ RayMapPlusDepthAdaptor,
+ RayMapPlusDepthWithConfidenceAdaptor,
+ RayMapPlusDepthWithConfidenceAndMaskAdaptor,
+ RayMapPlusDepthWithMaskAdaptor,
+ ScaleAdaptor,
+)
+from uniception.models.prediction_heads.base import (
+ AdaptorInput,
+ PredictionHeadInput,
+ PredictionHeadLayeredInput,
+ PredictionHeadTokenInput,
+)
+from uniception.models.prediction_heads.dpt import DPTFeature, DPTRegressionProcessor
+from uniception.models.prediction_heads.linear import LinearFeature
+from uniception.models.prediction_heads.mlp_head import MLPHead
+from uniception.models.prediction_heads.pose_head import PoseHead
+
+# Enable TF32 precision if supported (for GPU >= Ampere and PyTorch >= 1.12)
+if hasattr(torch.backends.cuda, "matmul") and hasattr(
+ torch.backends.cuda.matmul, "allow_tf32"
+):
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+
+class MapAnything(nn.Module):
+ "Modular MapAnything model class that supports input of images & optional geometric modalities (multiple reconstruction tasks)."
+
+ def __init__(
+ self,
+ name: str,
+ encoder_config: Dict,
+ info_sharing_config: Dict,
+ pred_head_config: Dict,
+ geometric_input_config: Dict,
+ fusion_norm_layer: Union[Type[nn.Module], Callable[..., nn.Module]] = partial(
+ nn.LayerNorm, eps=1e-6
+ ),
+ pretrained_checkpoint_path: str = None,
+ load_specific_pretrained_submodules: bool = False,
+ specific_pretrained_submodules: list = None,
+ torch_hub_force_reload: bool = False,
+ ):
+ """
+ Multi-view model containing an image encoder fused with optional geometric modalities followed by a multi-view attention transformer and respective downstream heads.
+ The goal is to output scene representation.
+ The multi-view attention transformer also takes as input a scale token to predict the metric scaling factor for the predicted scene representation.
+
+ Args:
+ name (str): Name of the model.
+ encoder_config (Dict): Configuration for the encoder.
+ info_sharing_config (Dict): Configuration for the multi-view attention transformer.
+ pred_head_config (Dict): Configuration for the prediction heads.
+ geometric_input_config (Dict): Configuration for the input of optional geometric modalities.
+ fusion_norm_layer (Union[Type[nn.Module], Callable[..., nn.Module]]): Normalization layer to use after fusion (addition) of encoder and geometric modalities. (default: partial(nn.LayerNorm, eps=1e-6))
+ pretrained_checkpoint_path (str): Path to pretrained checkpoint. (default: None)
+ load_specific_pretrained_submodules (bool): Whether to load specific pretrained submodules. (default: False)
+ specific_pretrained_submodules (list): List of specific pretrained submodules to load. Must be provided when load_specific_pretrained_submodules is True. (default: None)
+ torch_hub_force_reload (bool): Whether to force reload the encoder from torch hub. (default: False)
+ """
+ super().__init__()
+
+ # Initalize the attributes
+ self.name = name
+ self.encoder_config = encoder_config
+ self.info_sharing_config = info_sharing_config
+ self.pred_head_config = pred_head_config
+ self.geometric_input_config = geometric_input_config
+ self.pretrained_checkpoint_path = pretrained_checkpoint_path
+ self.load_specific_pretrained_submodules = load_specific_pretrained_submodules
+ self.specific_pretrained_submodules = specific_pretrained_submodules
+ self.torch_hub_force_reload = torch_hub_force_reload
+ self.class_init_args = {
+ "name": self.name,
+ "encoder_config": self.encoder_config,
+ "info_sharing_config": self.info_sharing_config,
+ "pred_head_config": self.pred_head_config,
+ "geometric_input_config": self.geometric_input_config,
+ "pretrained_checkpoint_path": self.pretrained_checkpoint_path,
+ "load_specific_pretrained_submodules": self.load_specific_pretrained_submodules,
+ "specific_pretrained_submodules": self.specific_pretrained_submodules,
+ "torch_hub_force_reload": self.torch_hub_force_reload,
+ }
+
+ # Get relevant parameters from the configs
+ self.info_sharing_type = info_sharing_config["model_type"]
+ self.info_sharing_return_type = info_sharing_config["model_return_type"]
+ self.pred_head_type = pred_head_config["type"]
+
+ # Initialize image encoder
+ if self.encoder_config["uses_torch_hub"]:
+ self.encoder_config["torch_hub_force_reload"] = torch_hub_force_reload
+ del self.encoder_config["uses_torch_hub"]
+ self.encoder = encoder_factory(**self.encoder_config)
+
+ # Initialize the encoder for ray directions
+ ray_dirs_encoder_config = self.geometric_input_config["ray_dirs_encoder_config"]
+ ray_dirs_encoder_config["enc_embed_dim"] = self.encoder.enc_embed_dim
+ ray_dirs_encoder_config["patch_size"] = self.encoder.patch_size
+ self.ray_dirs_encoder = encoder_factory(**ray_dirs_encoder_config)
+
+ # Initialize the encoder for depth (normalized per view and values after normalization are scaled logarithmically)
+ depth_encoder_config = self.geometric_input_config["depth_encoder_config"]
+ depth_encoder_config["enc_embed_dim"] = self.encoder.enc_embed_dim
+ depth_encoder_config["patch_size"] = self.encoder.patch_size
+ self.depth_encoder = encoder_factory(**depth_encoder_config)
+
+ # Initialize the encoder for log scale factor of depth
+ depth_scale_encoder_config = self.geometric_input_config["scale_encoder_config"]
+ depth_scale_encoder_config["enc_embed_dim"] = self.encoder.enc_embed_dim
+ self.depth_scale_encoder = encoder_factory(**depth_scale_encoder_config)
+
+ # Initialize the encoder for camera rotation
+ cam_rot_encoder_config = self.geometric_input_config["cam_rot_encoder_config"]
+ cam_rot_encoder_config["enc_embed_dim"] = self.encoder.enc_embed_dim
+ self.cam_rot_encoder = encoder_factory(**cam_rot_encoder_config)
+
+ # Initialize the encoder for camera translation (normalized across all provided camera translations)
+ cam_trans_encoder_config = self.geometric_input_config[
+ "cam_trans_encoder_config"
+ ]
+ cam_trans_encoder_config["enc_embed_dim"] = self.encoder.enc_embed_dim
+ self.cam_trans_encoder = encoder_factory(**cam_trans_encoder_config)
+
+ # Initialize the encoder for log scale factor of camera translation
+ cam_trans_scale_encoder_config = self.geometric_input_config[
+ "scale_encoder_config"
+ ]
+ cam_trans_scale_encoder_config["enc_embed_dim"] = self.encoder.enc_embed_dim
+ self.cam_trans_scale_encoder = encoder_factory(**cam_trans_scale_encoder_config)
+
+ # Initialize the fusion norm layer
+ self.fusion_norm_layer = fusion_norm_layer(self.encoder.enc_embed_dim)
+
+ # Initialize the Scale Token
+ # Used to scale the final scene predictions to metric scale
+ # During inference extended to (B, C, T), where T is the number of tokens (i.e., 1)
+ self.scale_token = nn.Parameter(torch.zeros(self.encoder.enc_embed_dim))
+ torch.nn.init.trunc_normal_(self.scale_token, std=0.02)
+
+ # Initialize the info sharing module (multi-view transformer)
+ self._initialize_info_sharing(info_sharing_config)
+
+ # Initialize the prediction heads
+ self._initialize_prediction_heads(pred_head_config)
+
+ # Initialize the final adaptors
+ self._initialize_adaptors(pred_head_config)
+
+ # Load pretrained weights
+ self._load_pretrained_weights()
+
+ def _initialize_info_sharing(self, info_sharing_config):
+ """
+ Initialize the information sharing module based on the configuration.
+
+ This method sets up the custom positional encoding if specified and initializes
+ the appropriate multi-view transformer based on the configuration type.
+
+ Args:
+ info_sharing_config (Dict): Configuration for the multi-view attention transformer.
+ Should contain 'custom_positional_encoding', 'model_type', and 'model_return_type'.
+
+ Returns:
+ None
+
+ Raises:
+ ValueError: If invalid configuration options are provided.
+ """
+ # Initialize Custom Positional Encoding if required
+ custom_positional_encoding = info_sharing_config["custom_positional_encoding"]
+ if custom_positional_encoding is not None:
+ if isinstance(custom_positional_encoding, str):
+ print(
+ f"Using custom positional encoding for multi-view attention transformer: {custom_positional_encoding}"
+ )
+ raise ValueError(
+ f"Invalid custom_positional_encoding: {custom_positional_encoding}. None implemented."
+ )
+ elif isinstance(custom_positional_encoding, Callable):
+ print(
+ "Using callable function as custom positional encoding for multi-view attention transformer."
+ )
+ self.custom_positional_encoding = custom_positional_encoding
+ else:
+ self.custom_positional_encoding = None
+
+ # Add dependecies to info_sharing_config
+ info_sharing_config["module_args"]["input_embed_dim"] = (
+ self.encoder.enc_embed_dim
+ )
+ info_sharing_config["module_args"]["custom_positional_encoding"] = (
+ self.custom_positional_encoding
+ )
+
+ # Initialize Multi-View Transformer
+ if self.info_sharing_return_type == "no_intermediate_features":
+ # Returns only normalized last layer features
+ # Intialize multi-view transformer based on type
+ if self.info_sharing_type == "cross_attention":
+ self.info_sharing = MultiViewCrossAttentionTransformer(
+ **info_sharing_config["module_args"]
+ )
+ elif self.info_sharing_type == "global_attention":
+ self.info_sharing = MultiViewGlobalAttentionTransformer(
+ **info_sharing_config["module_args"]
+ )
+ elif self.info_sharing_type == "alternating_attention":
+ self.info_sharing = MultiViewAlternatingAttentionTransformer(
+ **info_sharing_config["module_args"]
+ )
+ else:
+ raise ValueError(
+ f"Invalid info_sharing_type: {self.info_sharing_type}. Valid options: ['cross_attention', 'global_attention', 'alternating_attention']"
+ )
+ elif self.info_sharing_return_type == "intermediate_features":
+ # Returns intermediate features and normalized last layer features
+ # Initialize mulit-view transformer based on type
+ if self.info_sharing_type == "cross_attention":
+ self.info_sharing = MultiViewCrossAttentionTransformerIFR(
+ **info_sharing_config["module_args"]
+ )
+ elif self.info_sharing_type == "global_attention":
+ self.info_sharing = MultiViewGlobalAttentionTransformerIFR(
+ **info_sharing_config["module_args"]
+ )
+ elif self.info_sharing_type == "alternating_attention":
+ self.info_sharing = MultiViewAlternatingAttentionTransformerIFR(
+ **info_sharing_config["module_args"]
+ )
+ else:
+ raise ValueError(
+ f"Invalid info_sharing_type: {self.info_sharing_type}. Valid options: ['cross_attention', 'global_attention', 'alternating_attention']"
+ )
+ # Assess if the DPT needs to use encoder features
+ if len(self.info_sharing.indices) == 2:
+ self.use_encoder_features_for_dpt = True
+ elif len(self.info_sharing.indices) == 3:
+ self.use_encoder_features_for_dpt = False
+ else:
+ raise ValueError(
+ "Invalid number of indices provided for info sharing feature returner. Please provide 2 or 3 indices."
+ )
+ else:
+ raise ValueError(
+ f"Invalid info_sharing_return_type: {self.info_sharing_return_type}. Valid options: ['no_intermediate_features', 'intermediate_features']"
+ )
+
+ def _initialize_prediction_heads(self, pred_head_config):
+ """
+ Initialize the prediction heads based on the prediction head configuration.
+
+ This method configures and initializes the appropriate prediction heads based on the
+ specified prediction head type (linear, DPT, or DPT+pose). It sets up the necessary
+ dependencies and creates the required model components.
+
+ Args:
+ pred_head_config (Dict): Configuration for the prediction heads.
+
+ Returns:
+ None
+
+ Raises:
+ ValueError: If an invalid pred_head_type is provided.
+ """
+ # Add dependencies to prediction head config
+ pred_head_config["feature_head"]["patch_size"] = self.encoder.patch_size
+ if self.pred_head_type == "linear":
+ pred_head_config["feature_head"]["input_feature_dim"] = (
+ self.info_sharing.dim
+ )
+ elif "dpt" in self.pred_head_type:
+ # Add dependencies for DPT & Regressor head
+ if self.use_encoder_features_for_dpt:
+ pred_head_config["feature_head"]["input_feature_dims"] = [
+ self.encoder.enc_embed_dim
+ ] + [self.info_sharing.dim] * 3
+ else:
+ pred_head_config["feature_head"]["input_feature_dims"] = [
+ self.info_sharing.dim
+ ] * 4
+ pred_head_config["regressor_head"]["input_feature_dim"] = pred_head_config[
+ "feature_head"
+ ]["feature_dim"]
+ # Add dependencies for Pose head if required
+ if "pose" in self.pred_head_type:
+ pred_head_config["pose_head"]["patch_size"] = self.encoder.patch_size
+ pred_head_config["pose_head"]["input_feature_dim"] = (
+ self.info_sharing.dim
+ )
+ else:
+ raise ValueError(
+ f"Invalid pred_head_type: {self.pred_head_type}. Valid options: ['linear', 'dpt', 'dpt+pose']"
+ )
+ pred_head_config["scale_head"]["input_feature_dim"] = self.info_sharing.dim
+
+ # Initialize Prediction Heads
+ if self.pred_head_type == "linear":
+ # Initialize Dense Prediction Head for all views
+ self.dense_head = LinearFeature(**pred_head_config["feature_head"])
+ elif "dpt" in self.pred_head_type:
+ # Initialze Dense Predction Head for all views
+ self.dpt_feature_head = DPTFeature(**pred_head_config["feature_head"])
+ self.dpt_regressor_head = DPTRegressionProcessor(
+ **pred_head_config["regressor_head"]
+ )
+ self.dense_head = nn.Sequential(
+ self.dpt_feature_head, self.dpt_regressor_head
+ )
+ # Initialize Pose Head for all views if required
+ if "pose" in self.pred_head_type:
+ self.pose_head = PoseHead(**pred_head_config["pose_head"])
+ else:
+ raise ValueError(
+ f"Invalid pred_head_type: {self.pred_head_type}. Valid options: ['linear', 'dpt', 'dpt+pose']"
+ )
+ self.scale_head = MLPHead(**pred_head_config["scale_head"])
+
+ def _initialize_adaptors(self, pred_head_config):
+ """
+ Initialize the adaptors based on the prediction head configuration.
+
+ This method sets up the appropriate adaptors for different scene representation types,
+ such as pointmaps, ray maps with depth, or ray directions with depth and pose.
+
+ Args:
+ pred_head_config (Dict): Configuration for the prediction heads including adaptor type.
+
+ Returns:
+ None
+
+ Raises:
+ ValueError: If an invalid adaptor_type is provided.
+ AssertionError: If ray directions + depth + pose is used with an incompatible head type.
+ """
+ if pred_head_config["adaptor_type"] == "pointmap":
+ self.dense_adaptor = PointMapAdaptor(**pred_head_config["adaptor"])
+ self.scene_rep_type = "pointmap"
+ elif pred_head_config["adaptor_type"] == "pointmap+confidence":
+ self.dense_adaptor = PointMapWithConfidenceAdaptor(
+ **pred_head_config["adaptor"]
+ )
+ self.scene_rep_type = "pointmap+confidence"
+ elif pred_head_config["adaptor_type"] == "pointmap+mask":
+ self.dense_adaptor = PointMapWithMaskAdaptor(**pred_head_config["adaptor"])
+ self.scene_rep_type = "pointmap+mask"
+ elif pred_head_config["adaptor_type"] == "pointmap+confidence+mask":
+ self.dense_adaptor = PointMapWithConfidenceAndMaskAdaptor(
+ **pred_head_config["adaptor"]
+ )
+ self.scene_rep_type = "pointmap+confidence+mask"
+ elif pred_head_config["adaptor_type"] == "raymap+depth":
+ self.dense_adaptor = RayMapPlusDepthAdaptor(**pred_head_config["adaptor"])
+ self.scene_rep_type = "raymap+depth"
+ elif pred_head_config["adaptor_type"] == "raymap+depth+confidence":
+ self.dense_adaptor = RayMapPlusDepthWithConfidenceAdaptor(
+ **pred_head_config["adaptor"]
+ )
+ self.scene_rep_type = "raymap+depth+confidence"
+ elif pred_head_config["adaptor_type"] == "raymap+depth+mask":
+ self.dense_adaptor = RayMapPlusDepthWithMaskAdaptor(
+ **pred_head_config["adaptor"]
+ )
+ self.scene_rep_type = "raymap+depth+mask"
+ elif pred_head_config["adaptor_type"] == "raymap+depth+confidence+mask":
+ self.dense_adaptor = RayMapPlusDepthWithConfidenceAndMaskAdaptor(
+ **pred_head_config["adaptor"]
+ )
+ self.scene_rep_type = "raymap+depth+confidence+mask"
+ elif pred_head_config["adaptor_type"] == "raydirs+depth+pose":
+ assert self.pred_head_type == "dpt+pose", (
+ "Ray directions + depth + pose can only be used as scene representation with dpt + pose head."
+ )
+ self.dense_adaptor = RayDirectionsPlusDepthAdaptor(
+ **pred_head_config["dpt_adaptor"]
+ )
+ self.pose_adaptor = CamTranslationPlusQuatsAdaptor(
+ **pred_head_config["pose_adaptor"]
+ )
+ self.scene_rep_type = "raydirs+depth+pose"
+ elif pred_head_config["adaptor_type"] == "raydirs+depth+pose+confidence":
+ assert self.pred_head_type == "dpt+pose", (
+ "Ray directions + depth + pose can only be used as scene representation with dpt + pose head."
+ )
+ self.dense_adaptor = RayDirectionsPlusDepthWithConfidenceAdaptor(
+ **pred_head_config["dpt_adaptor"]
+ )
+ self.pose_adaptor = CamTranslationPlusQuatsAdaptor(
+ **pred_head_config["pose_adaptor"]
+ )
+ self.scene_rep_type = "raydirs+depth+pose+confidence"
+ elif pred_head_config["adaptor_type"] == "raydirs+depth+pose+mask":
+ assert self.pred_head_type == "dpt+pose", (
+ "Ray directions + depth + pose can only be used as scene representation with dpt + pose head."
+ )
+ self.dense_adaptor = RayDirectionsPlusDepthWithMaskAdaptor(
+ **pred_head_config["dpt_adaptor"]
+ )
+ self.pose_adaptor = CamTranslationPlusQuatsAdaptor(
+ **pred_head_config["pose_adaptor"]
+ )
+ self.scene_rep_type = "raydirs+depth+pose+mask"
+ elif pred_head_config["adaptor_type"] == "raydirs+depth+pose+confidence+mask":
+ assert self.pred_head_type == "dpt+pose", (
+ "Ray directions + depth + pose can only be used as scene representation with dpt + pose head."
+ )
+ self.dense_adaptor = RayDirectionsPlusDepthWithConfidenceAndMaskAdaptor(
+ **pred_head_config["dpt_adaptor"]
+ )
+ self.pose_adaptor = CamTranslationPlusQuatsAdaptor(
+ **pred_head_config["pose_adaptor"]
+ )
+ self.scene_rep_type = "raydirs+depth+pose+confidence+mask"
+ elif pred_head_config["adaptor_type"] == "campointmap+pose":
+ assert self.pred_head_type == "dpt+pose", (
+ "Camera pointmap + pose can only be used as scene representation with dpt + pose head."
+ )
+ self.dense_adaptor = PointMapAdaptor(**pred_head_config["dpt_adaptor"])
+ self.pose_adaptor = CamTranslationPlusQuatsAdaptor(
+ **pred_head_config["pose_adaptor"]
+ )
+ self.scene_rep_type = "campointmap+pose"
+ elif pred_head_config["adaptor_type"] == "campointmap+pose+confidence":
+ assert self.pred_head_type == "dpt+pose", (
+ "Camera pointmap + pose can only be used as scene representation with dpt + pose head."
+ )
+ self.dense_adaptor = PointMapWithConfidenceAdaptor(
+ **pred_head_config["dpt_adaptor"]
+ )
+ self.pose_adaptor = CamTranslationPlusQuatsAdaptor(
+ **pred_head_config["pose_adaptor"]
+ )
+ self.scene_rep_type = "campointmap+pose+confidence"
+ elif pred_head_config["adaptor_type"] == "campointmap+pose+mask":
+ assert self.pred_head_type == "dpt+pose", (
+ "Camera pointmap + pose can only be used as scene representation with dpt + pose head."
+ )
+ self.dense_adaptor = PointMapWithMaskAdaptor(
+ **pred_head_config["dpt_adaptor"]
+ )
+ self.pose_adaptor = CamTranslationPlusQuatsAdaptor(
+ **pred_head_config["pose_adaptor"]
+ )
+ self.scene_rep_type = "campointmap+pose+mask"
+ elif pred_head_config["adaptor_type"] == "campointmap+pose+confidence+mask":
+ assert self.pred_head_type == "dpt+pose", (
+ "Camera pointmap + pose can only be used as scene representation with dpt + pose head."
+ )
+ self.dense_adaptor = PointMapWithConfidenceAndMaskAdaptor(
+ **pred_head_config["dpt_adaptor"]
+ )
+ self.pose_adaptor = CamTranslationPlusQuatsAdaptor(
+ **pred_head_config["pose_adaptor"]
+ )
+ self.scene_rep_type = "campointmap+pose+confidence+mask"
+ elif pred_head_config["adaptor_type"] == "pointmap+raydirs+depth+pose":
+ assert self.pred_head_type == "dpt+pose", (
+ "Pointmap + ray directions + depth + pose can only be used as scene representation with dpt + pose head."
+ )
+ self.dense_adaptor = PointMapPlusRayDirectionsPlusDepthAdaptor(
+ **pred_head_config["dpt_adaptor"]
+ )
+ self.pose_adaptor = CamTranslationPlusQuatsAdaptor(
+ **pred_head_config["pose_adaptor"]
+ )
+ self.scene_rep_type = "pointmap+raydirs+depth+pose"
+ elif (
+ pred_head_config["adaptor_type"] == "pointmap+raydirs+depth+pose+confidence"
+ ):
+ assert self.pred_head_type == "dpt+pose", (
+ "Pointmap + ray directions + depth + pose can only be used as scene representation with dpt + pose head."
+ )
+ self.dense_adaptor = (
+ PointMapPlusRayDirectionsPlusDepthWithConfidenceAdaptor(
+ **pred_head_config["dpt_adaptor"]
+ )
+ )
+ self.pose_adaptor = CamTranslationPlusQuatsAdaptor(
+ **pred_head_config["pose_adaptor"]
+ )
+ self.scene_rep_type = "pointmap+raydirs+depth+pose+confidence"
+ elif pred_head_config["adaptor_type"] == "pointmap+raydirs+depth+pose+mask":
+ assert self.pred_head_type == "dpt+pose", (
+ "Pointmap + ray directions + depth + pose can only be used as scene representation with dpt + pose head."
+ )
+ self.dense_adaptor = PointMapPlusRayDirectionsPlusDepthWithMaskAdaptor(
+ **pred_head_config["dpt_adaptor"]
+ )
+ self.pose_adaptor = CamTranslationPlusQuatsAdaptor(
+ **pred_head_config["pose_adaptor"]
+ )
+ self.scene_rep_type = "pointmap+raydirs+depth+pose+mask"
+ elif (
+ pred_head_config["adaptor_type"]
+ == "pointmap+raydirs+depth+pose+confidence+mask"
+ ):
+ assert self.pred_head_type == "dpt+pose", (
+ "Pointmap + ray directions + depth + pose can only be used as scene representation with dpt + pose head."
+ )
+ self.dense_adaptor = (
+ PointMapPlusRayDirectionsPlusDepthWithConfidenceAndMaskAdaptor(
+ **pred_head_config["dpt_adaptor"]
+ )
+ )
+ self.pose_adaptor = CamTranslationPlusQuatsAdaptor(
+ **pred_head_config["pose_adaptor"]
+ )
+ self.scene_rep_type = "pointmap+raydirs+depth+pose+confidence+mask"
+ else:
+ raise ValueError(
+ f"Invalid adaptor_type: {pred_head_config['adaptor_type']}. \
+ Valid options: ['pointmap', 'raymap+depth', 'raydirs+depth+pose', 'campointmap+pose', 'pointmap+raydirs+depth+pose' \
+ 'pointmap+confidence', 'raymap+depth+confidence', 'raydirs+depth+pose+confidence', 'campointmap+pose+confidence', 'pointmap+raydirs+depth+pose+confidence' \
+ 'pointmap+mask', 'raymap+depth+mask', 'raydirs+depth+pose+mask', 'campointmap+pose+mask', 'pointmap+raydirs+depth+pose+mask' \
+ 'pointmap+confidence+mask', 'raymap+depth+confidence+mask', 'raydirs+depth+pose+confidence+mask', 'campointmap+pose+confidence+mask', 'pointmap+raydirs+depth+pose+confidence+mask']"
+ )
+ self.scale_adaptor = ScaleAdaptor(**pred_head_config["scale_adaptor"])
+
+ def _load_pretrained_weights(self):
+ """
+ Load pretrained weights from a checkpoint file.
+
+ If load_specific_pretrained_submodules is True, only loads weights for the specified submodules.
+ Otherwise, loads all weights from the checkpoint.
+
+ Returns:
+ None
+ """
+ if self.pretrained_checkpoint_path is not None:
+ if not self.load_specific_pretrained_submodules:
+ print(
+ f"Loading pretrained MapAnything weights from {self.pretrained_checkpoint_path} ..."
+ )
+ ckpt = torch.load(self.pretrained_checkpoint_path, weights_only=False)
+ print(self.load_state_dict(ckpt["model"]))
+ else:
+ print(
+ f"Loading pretrained MapAnything weights from {self.pretrained_checkpoint_path} for specific submodules: {self.specific_pretrained_submodules} ..."
+ )
+ assert self.pred_head_type is not None, (
+ "Specific submodules to load cannot be None."
+ )
+ ckpt = torch.load(self.pretrained_checkpoint_path, weights_only=False)
+ filtered_ckpt = {}
+ for ckpt_key, ckpt_value in ckpt["model"].items():
+ for submodule in self.specific_pretrained_submodules:
+ if ckpt_key.startswith(submodule):
+ filtered_ckpt[ckpt_key] = ckpt_value
+ print(self.load_state_dict(filtered_ckpt, strict=False))
+
+ def _encode_n_views(self, views):
+ """
+ Encode all the input views (batch of images) in a single forward pass.
+ Assumes all the input views have the same image shape, batch size, and data normalization type.
+
+ Args:
+ views (List[dict]): List of dictionaries containing the input views' images and instance information.
+
+ Returns:
+ List[torch.Tensor]: A list containing the encoded features for all N views.
+ """
+ num_views = len(views)
+ data_norm_type = views[0]["data_norm_type"][0]
+ imgs_list = [view["img"] for view in views]
+ all_imgs_across_views = torch.cat(imgs_list, dim=0)
+ encoder_input = ViTEncoderInput(
+ image=all_imgs_across_views, data_norm_type=data_norm_type
+ )
+ encoder_output = self.encoder(encoder_input)
+ all_encoder_features_across_views = encoder_output.features.chunk(
+ num_views, dim=0
+ )
+
+ return all_encoder_features_across_views
+
+ def _compute_pose_quats_and_trans_for_across_views_in_ref_view(
+ self,
+ views,
+ num_views,
+ device,
+ dtype,
+ batch_size_per_view,
+ per_sample_cam_input_mask,
+ ):
+ """
+ Compute the pose quats and trans for all the views in the frame of the reference view 0.
+ Returns identity pose for views where the camera input mask is False or the pose is not provided.
+
+ Args:
+ views (List[dict]): List of dictionaries containing the input views' images and instance information.
+ num_views (int): Number of views.
+ device (torch.device): Device to use for the computation.
+ dtype (torch.dtype): Data type to use for the computation.
+ per_sample_cam_input_mask (torch.Tensor): Tensor containing the per sample camera input mask.
+
+ Returns:
+ torch.Tensor: A tensor containing the pose quats for all the views in the frame of the reference view 0. (batch_size_per_view * view, 4)
+ torch.Tensor: A tensor containing the pose trans for all the views in the frame of the reference view 0. (batch_size_per_view * view, 3)
+ torch.Tensor: A tensor containing the per sample camera input mask.
+ """
+ # Compute the pose quats and trans for all the non-reference views in the frame of the reference view 0
+ pose_quats_non_ref_views = []
+ pose_trans_non_ref_views = []
+ pose_quats_ref_view_0 = []
+ pose_trans_ref_view_0 = []
+ for view_idx in range(num_views):
+ per_sample_cam_input_mask_for_curr_view = per_sample_cam_input_mask[
+ view_idx * batch_size_per_view : (view_idx + 1) * batch_size_per_view
+ ]
+ if (
+ "camera_pose_quats" in views[view_idx]
+ and "camera_pose_trans" in views[view_idx]
+ and per_sample_cam_input_mask_for_curr_view.any()
+ ):
+ # Get the camera pose quats and trans for the current view
+ cam_pose_quats = views[view_idx]["camera_pose_quats"][
+ per_sample_cam_input_mask_for_curr_view
+ ]
+ cam_pose_trans = views[view_idx]["camera_pose_trans"][
+ per_sample_cam_input_mask_for_curr_view
+ ]
+ # Append to the list
+ pose_quats_non_ref_views.append(cam_pose_quats)
+ pose_trans_non_ref_views.append(cam_pose_trans)
+ # Get the camera pose quats and trans for the reference view 0
+ cam_pose_quats = views[0]["camera_pose_quats"][
+ per_sample_cam_input_mask_for_curr_view
+ ]
+ cam_pose_trans = views[0]["camera_pose_trans"][
+ per_sample_cam_input_mask_for_curr_view
+ ]
+ # Append to the list
+ pose_quats_ref_view_0.append(cam_pose_quats)
+ pose_trans_ref_view_0.append(cam_pose_trans)
+ else:
+ per_sample_cam_input_mask[
+ view_idx * batch_size_per_view : (view_idx + 1)
+ * batch_size_per_view
+ ] = False
+
+ # Initialize the pose quats and trans for all views as identity
+ pose_quats_across_views = torch.tensor(
+ [0.0, 0.0, 0.0, 1.0], dtype=dtype, device=device
+ ).repeat(batch_size_per_view * num_views, 1) # (q_x, q_y, q_z, q_w)
+ pose_trans_across_views = torch.zeros(
+ (batch_size_per_view * num_views, 3), dtype=dtype, device=device
+ )
+
+ # Compute the pose quats and trans for all the non-reference views in the frame of the reference view 0
+ if len(pose_quats_non_ref_views) > 0:
+ # Stack the pose quats and trans for all the non-reference views and reference view 0
+ pose_quats_non_ref_views = torch.cat(pose_quats_non_ref_views, dim=0)
+ pose_trans_non_ref_views = torch.cat(pose_trans_non_ref_views, dim=0)
+ pose_quats_ref_view_0 = torch.cat(pose_quats_ref_view_0, dim=0)
+ pose_trans_ref_view_0 = torch.cat(pose_trans_ref_view_0, dim=0)
+
+ # Compute the pose quats and trans for all the non-reference views in the frame of the reference view 0
+ (
+ pose_quats_non_ref_views_in_ref_view_0,
+ pose_trans_non_ref_views_in_ref_view_0,
+ ) = transform_pose_using_quats_and_trans_2_to_1(
+ pose_quats_ref_view_0,
+ pose_trans_ref_view_0,
+ pose_quats_non_ref_views,
+ pose_trans_non_ref_views,
+ )
+
+ # Update the pose quats and trans for all the non-reference views
+ pose_quats_across_views[per_sample_cam_input_mask] = (
+ pose_quats_non_ref_views_in_ref_view_0.to(dtype=dtype)
+ )
+ pose_trans_across_views[per_sample_cam_input_mask] = (
+ pose_trans_non_ref_views_in_ref_view_0.to(dtype=dtype)
+ )
+
+ return (
+ pose_quats_across_views,
+ pose_trans_across_views,
+ per_sample_cam_input_mask,
+ )
+
+ def _encode_and_fuse_ray_dirs(
+ self,
+ views,
+ num_views,
+ batch_size_per_view,
+ all_encoder_features_across_views,
+ per_sample_ray_dirs_input_mask,
+ ):
+ """
+ Encode the ray directions for all the views and fuse it with the other encoder features in a single forward pass.
+
+ Args:
+ views (List[dict]): List of dictionaries containing the input views' images and instance information.
+ num_views (int): Number of views.
+ batch_size_per_view (int): Batch size per view.
+ all_encoder_features_across_views (torch.Tensor): Tensor containing the encoded features for all N views.
+ per_sample_ray_dirs_input_mask (torch.Tensor): Tensor containing the per sample ray direction input mask.
+
+ Returns:
+ torch.Tensor: A tensor containing the encoded features for all the views.
+ """
+ # Get the height and width of the images
+ _, _, height, width = views[0]["img"].shape
+
+ # Get the ray directions for all the views where info is provided and the ray direction input mask is True
+ ray_dirs_list = []
+ for view_idx in range(num_views):
+ per_sample_ray_dirs_input_mask_for_curr_view = (
+ per_sample_ray_dirs_input_mask[
+ view_idx * batch_size_per_view : (view_idx + 1)
+ * batch_size_per_view
+ ]
+ )
+ ray_dirs_for_curr_view = torch.zeros(
+ (batch_size_per_view, height, width, 3),
+ dtype=all_encoder_features_across_views.dtype,
+ device=all_encoder_features_across_views.device,
+ )
+ if (
+ "ray_directions_cam" in views[view_idx]
+ and per_sample_ray_dirs_input_mask_for_curr_view.any()
+ ):
+ ray_dirs_for_curr_view[per_sample_ray_dirs_input_mask_for_curr_view] = (
+ views[view_idx]["ray_directions_cam"][
+ per_sample_ray_dirs_input_mask_for_curr_view
+ ]
+ )
+ else:
+ per_sample_ray_dirs_input_mask[
+ view_idx * batch_size_per_view : (view_idx + 1)
+ * batch_size_per_view
+ ] = False
+ ray_dirs_list.append(ray_dirs_for_curr_view)
+
+ # Stack the ray directions for all the views and permute to (B * V, C, H, W)
+ ray_dirs = torch.cat(ray_dirs_list, dim=0) # (B * V, H, W, 3)
+ ray_dirs = ray_dirs.permute(0, 3, 1, 2).contiguous() # (B * V, 3, H, W)
+
+ # Encode the ray directions
+ ray_dirs_features_across_views = self.ray_dirs_encoder(
+ ViTEncoderNonImageInput(data=ray_dirs)
+ ).features
+
+ # Fuse the ray direction features with the other encoder features (zero out the features where the ray direction input mask is False)
+ ray_dirs_features_across_views = (
+ ray_dirs_features_across_views
+ * per_sample_ray_dirs_input_mask.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
+ )
+ all_encoder_features_across_views = (
+ all_encoder_features_across_views + ray_dirs_features_across_views
+ )
+
+ return all_encoder_features_across_views
+
+ def _encode_and_fuse_depths(
+ self,
+ views,
+ num_views,
+ batch_size_per_view,
+ all_encoder_features_across_views,
+ per_sample_depth_input_mask,
+ ):
+ """
+ Encode the z depths for all the views and fuse it with the other encoder features in a single forward pass.
+
+ Args:
+ views (List[dict]): List of dictionaries containing the input views' images and instance information.
+ num_views (int): Number of views.
+ batch_size_per_view (int): Batch size per view.
+ all_encoder_features_across_views (torch.Tensor): Tensor containing the encoded features for all N views.
+ per_sample_depth_input_mask (torch.Tensor): Tensor containing the per sample depth input mask.
+
+ Returns:
+ torch.Tensor: A tensor containing the encoded features for all the views.
+ """
+ # Get the device and height and width of the images
+ device = all_encoder_features_across_views.device
+ _, _, height, width = views[0]["img"].shape
+
+ # Decide to use randomly sampled sparse depth or dense depth
+ if torch.rand(1) < self.geometric_input_config["sparse_depth_prob"]:
+ use_sparse_depth = True
+ else:
+ use_sparse_depth = False
+
+ # Get the depths for all the views
+ depth_list = []
+ depth_norm_factors_list = []
+ metric_scale_depth_mask_list = []
+ for view_idx in range(num_views):
+ # Get the input mask for current view
+ per_sample_depth_input_mask_for_curr_view = per_sample_depth_input_mask[
+ view_idx * batch_size_per_view : (view_idx + 1) * batch_size_per_view
+ ]
+ depth_for_curr_view = torch.zeros(
+ (batch_size_per_view, height, width, 1),
+ dtype=all_encoder_features_across_views.dtype,
+ device=device,
+ )
+ depth_norm_factor_for_curr_view = torch.zeros(
+ (batch_size_per_view),
+ dtype=all_encoder_features_across_views.dtype,
+ device=device,
+ )
+ metric_scale_mask_for_curr_view = torch.zeros(
+ (batch_size_per_view),
+ dtype=torch.bool,
+ device=device,
+ )
+ if (
+ "depth_along_ray" in views[view_idx]
+ ) and per_sample_depth_input_mask_for_curr_view.any():
+ # Get depth for current view
+ depth_for_curr_view_input = views[view_idx]["depth_along_ray"][
+ per_sample_depth_input_mask_for_curr_view
+ ]
+ # Get the metric scale mask
+ if "is_metric_scale" in views[view_idx]:
+ metric_scale_mask = views[view_idx]["is_metric_scale"][
+ per_sample_depth_input_mask_for_curr_view
+ ]
+ else:
+ metric_scale_mask = torch.zeros(
+ depth_for_curr_view_input.shape[0],
+ dtype=torch.bool,
+ device=device,
+ )
+ # Turn off indication of metric scale samples based on the depth_scale_norm_all_prob
+ depth_scale_norm_all_mask = (
+ torch.rand(metric_scale_mask.shape[0])
+ < self.geometric_input_config["depth_scale_norm_all_prob"]
+ )
+ if depth_scale_norm_all_mask.any():
+ metric_scale_mask[depth_scale_norm_all_mask] = False
+ # Assign the metric scale mask to the respective indices
+ metric_scale_mask_for_curr_view[
+ per_sample_depth_input_mask_for_curr_view
+ ] = metric_scale_mask
+ # Sparsely sample the depth if required
+ if use_sparse_depth:
+ # Create a mask of ones
+ sparsification_mask = torch.ones_like(
+ depth_for_curr_view_input, device=device
+ )
+ # Create a mask for valid pixels (depth > 0)
+ valid_pixel_mask = depth_for_curr_view_input > 0
+ # Calculate the number of valid pixels
+ num_valid_pixels = valid_pixel_mask.sum().item()
+ # Calculate the number of valid pixels to set to zero
+ num_to_zero = int(
+ num_valid_pixels
+ * self.geometric_input_config["sparsification_removal_percent"]
+ )
+ if num_to_zero > 0:
+ # Get the indices of valid pixels
+ valid_indices = valid_pixel_mask.nonzero(as_tuple=True)
+ # Randomly select indices to zero out
+ indices_to_zero = torch.randperm(num_valid_pixels)[:num_to_zero]
+ # Set selected valid indices to zero in the mask
+ sparsification_mask[
+ valid_indices[0][indices_to_zero],
+ valid_indices[1][indices_to_zero],
+ valid_indices[2][indices_to_zero],
+ valid_indices[3][indices_to_zero],
+ ] = 0
+ # Apply the mask on the depth
+ depth_for_curr_view_input = (
+ depth_for_curr_view_input * sparsification_mask
+ )
+ # Normalize the depth
+ scaled_depth_for_curr_view_input, depth_norm_factor = (
+ normalize_depth_using_non_zero_pixels(
+ depth_for_curr_view_input, return_norm_factor=True
+ )
+ )
+ # Assign the depth and depth norm factor to the respective indices
+ depth_for_curr_view[per_sample_depth_input_mask_for_curr_view] = (
+ scaled_depth_for_curr_view_input
+ )
+ depth_norm_factor_for_curr_view[
+ per_sample_depth_input_mask_for_curr_view
+ ] = depth_norm_factor
+ else:
+ per_sample_depth_input_mask[
+ view_idx * batch_size_per_view : (view_idx + 1)
+ * batch_size_per_view
+ ] = False
+ # Append the depths, depth norm factor and metric scale mask for the current view
+ depth_list.append(depth_for_curr_view)
+ depth_norm_factors_list.append(depth_norm_factor_for_curr_view)
+ metric_scale_depth_mask_list.append(metric_scale_mask_for_curr_view)
+
+ # Stack the depths for all the views and permute to (B * V, C, H, W)
+ depths = torch.cat(depth_list, dim=0) # (B * V, H, W, 1)
+ depths = apply_log_to_norm(
+ depths
+ ) # Scale logarithimically (norm is computed along last dim)
+ depths = depths.permute(0, 3, 1, 2).contiguous() # (B * V, 1, H, W)
+ # Encode the depths using the depth encoder
+ depth_features_across_views = self.depth_encoder(
+ ViTEncoderNonImageInput(data=depths)
+ ).features
+ # Zero out the depth features where the depth input mask is False
+ depth_features_across_views = (
+ depth_features_across_views
+ * per_sample_depth_input_mask.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
+ )
+
+ # Stack the depth norm factors for all the views
+ depth_norm_factors = torch.cat(depth_norm_factors_list, dim=0) # (B * V, )
+ # Encode the depth norm factors using the log scale encoder for depth
+ log_depth_norm_factors = torch.log(depth_norm_factors + 1e-8) # (B * V, )
+ depth_scale_features_across_views = self.depth_scale_encoder(
+ EncoderGlobalRepInput(data=log_depth_norm_factors.unsqueeze(-1))
+ ).features
+ # Zero out the depth scale features where the depth input mask is False
+ depth_scale_features_across_views = (
+ depth_scale_features_across_views
+ * per_sample_depth_input_mask.unsqueeze(-1)
+ )
+ # Stack the metric scale mask for all the views
+ metric_scale_depth_mask = torch.cat(
+ metric_scale_depth_mask_list, dim=0
+ ) # (B * V, )
+ # Zero out the depth scale features where the metric scale mask is False
+ # Scale encoding is only provided for metric scale samples
+ depth_scale_features_across_views = (
+ depth_scale_features_across_views * metric_scale_depth_mask.unsqueeze(-1)
+ )
+
+ # Fuse the depth features & depth scale features with the other encoder features
+ all_encoder_features_across_views = (
+ all_encoder_features_across_views
+ + depth_features_across_views
+ + depth_scale_features_across_views.unsqueeze(-1).unsqueeze(-1)
+ )
+
+ return all_encoder_features_across_views
+
+ def _encode_and_fuse_cam_quats_and_trans(
+ self,
+ views,
+ num_views,
+ batch_size_per_view,
+ all_encoder_features_across_views,
+ pose_quats_across_views,
+ pose_trans_across_views,
+ per_sample_cam_input_mask,
+ ):
+ """
+ Encode the camera quats and trans for all the views and fuse it with the other encoder features in a single forward pass.
+
+ Args:
+ views (List[dict]): List of dictionaries containing the input views' images and instance information.
+ num_views (int): Number of views.
+ batch_size_per_view (int): Batch size per view.
+ all_encoder_features_across_views (torch.Tensor): Tensor containing the encoded features for all N views.
+ pose_quats_across_views (torch.Tensor): Tensor containing the pose quats for all the views in the frame of the reference view 0. (batch_size_per_view * view, 4)
+ pose_trans_across_views (torch.Tensor): Tensor containing the pose trans for all the views in the frame of the reference view 0. (batch_size_per_view * view, 3)
+ per_sample_cam_input_mask (torch.Tensor): Tensor containing the per sample camera input mask.
+
+ Returns:
+ torch.Tensor: A tensor containing the encoded features for all the views.
+ """
+ # Encode the pose quats
+ pose_quats_features_across_views = self.cam_rot_encoder(
+ EncoderGlobalRepInput(data=pose_quats_across_views)
+ ).features
+ # Zero out the pose quat features where the camera input mask is False
+ pose_quats_features_across_views = (
+ pose_quats_features_across_views * per_sample_cam_input_mask.unsqueeze(-1)
+ )
+
+ # Get the metric scale mask for all samples
+ device = all_encoder_features_across_views.device
+ metric_scale_pose_trans_mask = torch.zeros(
+ (batch_size_per_view * num_views), dtype=torch.bool, device=device
+ )
+ for view_idx in range(num_views):
+ if "is_metric_scale" in views[view_idx]:
+ # Get the metric scale mask for the input pose priors
+ metric_scale_mask = views[view_idx]["is_metric_scale"]
+ else:
+ metric_scale_mask = torch.zeros(
+ batch_size_per_view, dtype=torch.bool, device=device
+ )
+ metric_scale_pose_trans_mask[
+ view_idx * batch_size_per_view : (view_idx + 1) * batch_size_per_view
+ ] = metric_scale_mask
+
+ # Turn off indication of metric scale samples based on the pose_scale_norm_all_prob
+ pose_norm_all_mask = (
+ torch.rand(batch_size_per_view * num_views)
+ < self.geometric_input_config["pose_scale_norm_all_prob"]
+ )
+ if pose_norm_all_mask.any():
+ metric_scale_pose_trans_mask[pose_norm_all_mask] = False
+
+ # Get the scale norm factor for all the samples and scale the pose translations
+ pose_trans_across_views = torch.split(
+ pose_trans_across_views, batch_size_per_view, dim=0
+ ) # Split into num_views chunks
+ pose_trans_across_views = torch.stack(
+ pose_trans_across_views, dim=1
+ ) # Stack the views along a new dimension (batch_size_per_view, num_views, 3)
+ scaled_pose_trans_across_views, pose_trans_norm_factors = (
+ normalize_pose_translations(
+ pose_trans_across_views, return_norm_factor=True
+ )
+ )
+
+ # Resize the pose translation back to (batch_size_per_view * num_views, 3) and extend the norm factor to (batch_size_per_view * num_views, 1)
+ scaled_pose_trans_across_views = scaled_pose_trans_across_views.unbind(
+ dim=1
+ ) # Convert back to list of views, where each view has batch_size_per_view tensor
+ scaled_pose_trans_across_views = torch.cat(
+ scaled_pose_trans_across_views, dim=0
+ ) # Concatenate back to (batch_size_per_view * num_views, 3)
+ pose_trans_norm_factors_across_views = pose_trans_norm_factors.unsqueeze(
+ -1
+ ).repeat(num_views, 1) # (B, ) -> (B * V, 1)
+
+ # Encode the pose trans
+ pose_trans_features_across_views = self.cam_trans_encoder(
+ EncoderGlobalRepInput(data=scaled_pose_trans_across_views)
+ ).features
+ # Zero out the pose trans features where the camera input mask is False
+ pose_trans_features_across_views = (
+ pose_trans_features_across_views * per_sample_cam_input_mask.unsqueeze(-1)
+ )
+
+ # Encode the pose translation norm factors using the log scale encoder for pose trans
+ log_pose_trans_norm_factors_across_views = torch.log(
+ pose_trans_norm_factors_across_views + 1e-8
+ )
+ pose_trans_scale_features_across_views = self.cam_trans_scale_encoder(
+ EncoderGlobalRepInput(data=log_pose_trans_norm_factors_across_views)
+ ).features
+ # Zero out the pose trans scale features where the camera input mask is False
+ pose_trans_scale_features_across_views = (
+ pose_trans_scale_features_across_views
+ * per_sample_cam_input_mask.unsqueeze(-1)
+ )
+ # Zero out the pose trans scale features where the metric scale mask is False
+ # Scale encoding is only provided for metric scale samples
+ pose_trans_scale_features_across_views = (
+ pose_trans_scale_features_across_views
+ * metric_scale_pose_trans_mask.unsqueeze(-1)
+ )
+
+ # Fuse the pose quat features, pose trans features, pose trans scale features and pose trans type PE features with the other encoder features
+ all_encoder_features_across_views = (
+ all_encoder_features_across_views
+ + pose_quats_features_across_views.unsqueeze(-1).unsqueeze(-1)
+ + pose_trans_features_across_views.unsqueeze(-1).unsqueeze(-1)
+ + pose_trans_scale_features_across_views.unsqueeze(-1).unsqueeze(-1)
+ )
+
+ return all_encoder_features_across_views
+
+ def _encode_and_fuse_optional_geometric_inputs(
+ self, views, all_encoder_features_across_views_list
+ ):
+ """
+ Encode all the input optional geometric modalities and fuses it with the image encoder features in a single forward pass.
+ Assumes all the input views have the same shape and batch size.
+
+ Args:
+ views (List[dict]): List of dictionaries containing the input views' images and instance information.
+ all_encoder_features_across_views (List[torch.Tensor]): List of tensors containing the encoded image features for all N views.
+
+ Returns:
+ List[torch.Tensor]: A list containing the encoded features for all N views.
+ """
+ num_views = len(views)
+ batch_size_per_view, _, _, _ = views[0]["img"].shape
+ device = all_encoder_features_across_views_list[0].device
+ dtype = all_encoder_features_across_views_list[0].dtype
+ all_encoder_features_across_views = torch.cat(
+ all_encoder_features_across_views_list, dim=0
+ )
+
+ # Get the overall input mask for all the views
+ overall_geometric_input_mask = (
+ torch.rand(batch_size_per_view, device=device)
+ < self.geometric_input_config["overall_prob"]
+ )
+ overall_geometric_input_mask = overall_geometric_input_mask.repeat(num_views)
+
+ # Get the per sample input mask after dropout
+ # Per sample input mask is in view-major order so that index v*B + b in each mask corresponds to sample b of view v: (B * V)
+ per_sample_geometric_input_mask = torch.rand(
+ batch_size_per_view * num_views, device=device
+ ) < (1 - self.geometric_input_config["dropout_prob"])
+ per_sample_geometric_input_mask = (
+ per_sample_geometric_input_mask & overall_geometric_input_mask
+ )
+
+ # Get the ray direction input mask
+ per_sample_ray_dirs_input_mask = (
+ torch.rand(batch_size_per_view, device=device)
+ < self.geometric_input_config["ray_dirs_prob"]
+ )
+ per_sample_ray_dirs_input_mask = per_sample_ray_dirs_input_mask.repeat(
+ num_views
+ )
+ per_sample_ray_dirs_input_mask = (
+ per_sample_ray_dirs_input_mask & per_sample_geometric_input_mask
+ )
+
+ # Get the depth input mask
+ per_sample_depth_input_mask = (
+ torch.rand(batch_size_per_view, device=device)
+ < self.geometric_input_config["depth_prob"]
+ )
+ per_sample_depth_input_mask = per_sample_depth_input_mask.repeat(num_views)
+ per_sample_depth_input_mask = (
+ per_sample_depth_input_mask & per_sample_geometric_input_mask
+ )
+
+ # Get the camera input mask
+ per_sample_cam_input_mask = (
+ torch.rand(batch_size_per_view, device=device)
+ < self.geometric_input_config["cam_prob"]
+ )
+ per_sample_cam_input_mask = per_sample_cam_input_mask.repeat(num_views)
+ per_sample_cam_input_mask = (
+ per_sample_cam_input_mask & per_sample_geometric_input_mask
+ )
+
+ # Compute the pose quats and trans for all the non-reference views in the frame of the reference view 0
+ # Returned pose quats and trans represent identity pose for views/samples where the camera input mask is False
+ pose_quats_across_views, pose_trans_across_views, per_sample_cam_input_mask = (
+ self._compute_pose_quats_and_trans_for_across_views_in_ref_view(
+ views,
+ num_views,
+ device,
+ dtype,
+ batch_size_per_view,
+ per_sample_cam_input_mask,
+ )
+ )
+
+ # Encode the ray directions and fuse with the image encoder features
+ all_encoder_features_across_views = self._encode_and_fuse_ray_dirs(
+ views,
+ num_views,
+ batch_size_per_view,
+ all_encoder_features_across_views,
+ per_sample_ray_dirs_input_mask,
+ )
+
+ # Encode the depths and fuse with the image encoder features
+ all_encoder_features_across_views = self._encode_and_fuse_depths(
+ views,
+ num_views,
+ batch_size_per_view,
+ all_encoder_features_across_views,
+ per_sample_depth_input_mask,
+ )
+
+ # Encode the cam quat and trans and fuse with the image encoder features
+ all_encoder_features_across_views = self._encode_and_fuse_cam_quats_and_trans(
+ views,
+ num_views,
+ batch_size_per_view,
+ all_encoder_features_across_views,
+ pose_quats_across_views,
+ pose_trans_across_views,
+ per_sample_cam_input_mask,
+ )
+
+ # Normalize the fused features (permute -> normalize -> permute)
+ all_encoder_features_across_views = all_encoder_features_across_views.permute(
+ 0, 2, 3, 1
+ ).contiguous()
+ all_encoder_features_across_views = self.fusion_norm_layer(
+ all_encoder_features_across_views
+ )
+ all_encoder_features_across_views = all_encoder_features_across_views.permute(
+ 0, 3, 1, 2
+ ).contiguous()
+
+ # Split the batched views into individual views
+ fused_all_encoder_features_across_views = (
+ all_encoder_features_across_views.chunk(num_views, dim=0)
+ )
+
+ return fused_all_encoder_features_across_views
+
+ def forward(self, views):
+ """
+ Forward pass performing the following operations:
+ 1. Encodes the N input views (images).
+ 2. Encodes the optional geometric inputs (ray directions, depths, camera rotations, camera translations).
+ 3. Fuses the encoded features from the N input views and the optional geometric inputs using addition and normalization.
+ 4. Information sharing across the encoded features and a scale token using a multi-view attention transformer.
+ 5. Passes the final features from transformer through the prediction heads.
+ 6. Returns the processed final outputs for N views.
+
+ Assumption:
+ - All the input views and dense geometric inputs have the same image shape.
+
+ Args:
+ views (List[dict]): List of dictionaries containing the input views' images and instance information.
+ Each dictionary should contain the following keys:
+ "img" (tensor): Image tensor of shape (B, C, H, W). Input images must be normalized based on the data norm type of image encoder.
+ "data_norm_type" (list): [model.encoder.data_norm_type]
+ Optionally, each dictionary can also contain the following keys for the respective optional geometric inputs:
+ "ray_directions_cam" (tensor): Ray directions in the local camera frame. Tensor of shape (B, H, W, 3).
+ "depth_along_ray" (tensor): Depth along the ray. Tensor of shape (B, H, W, 1).
+ "camera_pose_quats" (tensor): Camera pose quaternions. Tensor of shape (B, 4). Camera pose is opencv (RDF) cam2world transformation.
+ "camera_pose_trans" (tensor): Camera pose translations. Tensor of shape (B, 3). Camera pose is opencv (RDF) cam2world transformation.
+ "is_metric_scale" (tensor): Boolean tensor indicating whether the geometric inputs are in metric scale or not. Tensor of shape (B, 1).
+
+ Returns:
+ List[dict]: A list containing the final outputs for all N views.
+ """
+ # Get input shape of the images, number of views, and batch size per view
+ batch_size_per_view, _, height, width = views[0]["img"].shape
+ img_shape = (int(height), int(width))
+ num_views = len(views)
+
+ # Run the image encoder on all the input views
+ all_encoder_features_across_views = self._encode_n_views(views)
+
+ # Encode the optional geometric inputs and fuse with the encoded features from the N input views
+ # Use high precision to prevent NaN values after layer norm in dense representation encoder (due to high variance in last dim of features)
+ with torch.autocast("cuda", enabled=False):
+ all_encoder_features_across_views = (
+ self._encode_and_fuse_optional_geometric_inputs(
+ views, all_encoder_features_across_views
+ )
+ )
+
+ # Expand the scale token to match the batch size
+ input_scale_token = (
+ self.scale_token.unsqueeze(0)
+ .unsqueeze(-1)
+ .repeat(batch_size_per_view, 1, 1)
+ ) # (B, C, 1)
+
+ # Combine all images into view-centric representation
+ # Output is a list containing the encoded features for all N views after information sharing.
+ info_sharing_input = MultiViewTransformerInput(
+ features=all_encoder_features_across_views,
+ additional_input_tokens=input_scale_token,
+ )
+ if self.info_sharing_return_type == "no_intermediate_features":
+ final_info_sharing_multi_view_feat = self.info_sharing(info_sharing_input)
+ elif self.info_sharing_return_type == "intermediate_features":
+ (
+ final_info_sharing_multi_view_feat,
+ intermediate_info_sharing_multi_view_feat,
+ ) = self.info_sharing(info_sharing_input)
+
+ if self.pred_head_type == "linear":
+ # Stack the features for all views
+ dense_head_inputs = torch.cat(
+ final_info_sharing_multi_view_feat.features, dim=0
+ )
+ elif self.pred_head_type in ["dpt", "dpt+pose"]:
+ # Get the list of features for all views
+ dense_head_inputs_list = []
+ if self.use_encoder_features_for_dpt:
+ # Stack all the image encoder features for all views
+ stacked_encoder_features = torch.cat(
+ all_encoder_features_across_views, dim=0
+ )
+ dense_head_inputs_list.append(stacked_encoder_features)
+ # Stack the first intermediate features for all views
+ stacked_intermediate_features_1 = torch.cat(
+ intermediate_info_sharing_multi_view_feat[0].features, dim=0
+ )
+ dense_head_inputs_list.append(stacked_intermediate_features_1)
+ # Stack the second intermediate features for all views
+ stacked_intermediate_features_2 = torch.cat(
+ intermediate_info_sharing_multi_view_feat[1].features, dim=0
+ )
+ dense_head_inputs_list.append(stacked_intermediate_features_2)
+ # Stack the last layer features for all views
+ stacked_final_features = torch.cat(
+ final_info_sharing_multi_view_feat.features, dim=0
+ )
+ dense_head_inputs_list.append(stacked_final_features)
+ else:
+ # Stack the first intermediate features for all views
+ stacked_intermediate_features_1 = torch.cat(
+ intermediate_info_sharing_multi_view_feat[0].features, dim=0
+ )
+ dense_head_inputs_list.append(stacked_intermediate_features_1)
+ # Stack the second intermediate features for all views
+ stacked_intermediate_features_2 = torch.cat(
+ intermediate_info_sharing_multi_view_feat[1].features, dim=0
+ )
+ dense_head_inputs_list.append(stacked_intermediate_features_2)
+ # Stack the third intermediate features for all views
+ stacked_intermediate_features_3 = torch.cat(
+ intermediate_info_sharing_multi_view_feat[2].features, dim=0
+ )
+ dense_head_inputs_list.append(stacked_intermediate_features_3)
+ # Stack the last layer
+ stacked_final_features = torch.cat(
+ final_info_sharing_multi_view_feat.features, dim=0
+ )
+ dense_head_inputs_list.append(stacked_final_features)
+ else:
+ raise ValueError(
+ f"Invalid pred_head_type: {self.pred_head_type}. Valid options: ['linear', 'dpt', 'dpt+pose']"
+ )
+
+ # Downstream task prediction
+ with torch.autocast("cuda", enabled=False):
+ # Run Prediction Heads & Post-Process Outputs
+ if self.pred_head_type == "linear":
+ dense_head_outputs = self.dense_head(
+ PredictionHeadInput(last_feature=dense_head_inputs)
+ )
+ dense_final_outputs = self.dense_adaptor(
+ AdaptorInput(
+ adaptor_feature=dense_head_outputs.decoded_channels,
+ output_shape_hw=img_shape,
+ )
+ )
+ elif self.pred_head_type == "dpt":
+ dense_head_outputs = self.dense_head(
+ PredictionHeadLayeredInput(
+ list_features=dense_head_inputs_list,
+ target_output_shape=img_shape,
+ )
+ )
+ dense_final_outputs = self.dense_adaptor(
+ AdaptorInput(
+ adaptor_feature=dense_head_outputs.decoded_channels,
+ output_shape_hw=img_shape,
+ )
+ )
+ elif self.pred_head_type == "dpt+pose":
+ dense_head_outputs = self.dense_head(
+ PredictionHeadLayeredInput(
+ list_features=dense_head_inputs_list,
+ target_output_shape=img_shape,
+ )
+ )
+ dense_final_outputs = self.dense_adaptor(
+ AdaptorInput(
+ adaptor_feature=dense_head_outputs.decoded_channels,
+ output_shape_hw=img_shape,
+ )
+ )
+ pose_head_outputs = self.pose_head(
+ PredictionHeadInput(last_feature=dense_head_inputs_list[-1])
+ )
+ pose_final_outputs = self.pose_adaptor(
+ AdaptorInput(
+ adaptor_feature=pose_head_outputs.decoded_channels,
+ output_shape_hw=img_shape,
+ )
+ )
+ else:
+ raise ValueError(
+ f"Invalid pred_head_type: {self.pred_head_type}. Valid options: ['linear', 'dpt', 'dpt+pose']"
+ )
+ scale_head_output = self.scale_head(
+ PredictionHeadTokenInput(
+ last_feature=final_info_sharing_multi_view_feat.additional_token_features
+ )
+ )
+ scale_final_output = self.scale_adaptor(
+ AdaptorInput(
+ adaptor_feature=scale_head_output.decoded_channels,
+ output_shape_hw=img_shape,
+ )
+ )
+ scale_final_output = scale_final_output.value.squeeze(
+ -1
+ ) # (B, 1, 1) -> (B, 1)
+
+ # Prepare the final scene representation for all views
+ if self.scene_rep_type in [
+ "pointmap",
+ "pointmap+confidence",
+ "pointmap+mask",
+ "pointmap+confidence+mask",
+ ]:
+ output_pts3d = dense_final_outputs.value
+ # Reshape final scene representation to (B * V, H, W, C)
+ output_pts3d = output_pts3d.permute(0, 2, 3, 1).contiguous()
+ # Split the predicted pointmaps back to their respective views
+ output_pts3d_per_view = output_pts3d.chunk(num_views, dim=0)
+ # Pack the output as a list of dictionaries
+ res = []
+ for i in range(num_views):
+ res.append(
+ {
+ "pts3d": output_pts3d_per_view[i]
+ * scale_final_output.unsqueeze(-1).unsqueeze(-1),
+ "metric_scaling_factor": scale_final_output,
+ }
+ )
+ elif self.scene_rep_type in [
+ "raymap+depth",
+ "raymap+depth+confidence",
+ "raymap+depth+mask",
+ "raymap+depth+confidence+mask",
+ ]:
+ # Reshape final scene representation to (B * V, H, W, C)
+ output_scene_rep = dense_final_outputs.value.permute(
+ 0, 2, 3, 1
+ ).contiguous()
+ # Get the predicted ray origins, directions, and depths along rays
+ output_ray_origins, output_ray_directions, output_depth_along_ray = (
+ output_scene_rep.split([3, 3, 1], dim=-1)
+ )
+ # Get the predicted pointmaps
+ output_pts3d = (
+ output_ray_origins + output_ray_directions * output_depth_along_ray
+ )
+ # Split the predicted quantities back to their respective views
+ output_ray_origins_per_view = output_ray_origins.chunk(num_views, dim=0)
+ output_ray_directions_per_view = output_ray_directions.chunk(
+ num_views, dim=0
+ )
+ output_depth_along_ray_per_view = output_depth_along_ray.chunk(
+ num_views, dim=0
+ )
+ output_pts3d_per_view = output_pts3d.chunk(num_views, dim=0)
+ # Pack the output as a list of dictionaries
+ res = []
+ for i in range(num_views):
+ res.append(
+ {
+ "pts3d": output_pts3d_per_view[i]
+ * scale_final_output.unsqueeze(-1).unsqueeze(-1),
+ "ray_origins": output_ray_origins_per_view[i]
+ * scale_final_output.unsqueeze(-1).unsqueeze(-1),
+ "ray_directions": output_ray_directions_per_view[i],
+ "depth_along_ray": output_depth_along_ray_per_view[i]
+ * scale_final_output.unsqueeze(-1).unsqueeze(-1),
+ "metric_scaling_factor": scale_final_output,
+ }
+ )
+ elif self.scene_rep_type in [
+ "raydirs+depth+pose",
+ "raydirs+depth+pose+confidence",
+ "raydirs+depth+pose+mask",
+ "raydirs+depth+pose+confidence+mask",
+ ]:
+ # Reshape output dense rep to (B * V, H, W, C)
+ output_dense_rep = dense_final_outputs.value.permute(
+ 0, 2, 3, 1
+ ).contiguous()
+ # Get the predicted ray directions and depths along rays
+ output_ray_directions, output_depth_along_ray = output_dense_rep.split(
+ [3, 1], dim=-1
+ )
+ # Get the predicted camera translations and quaternions
+ output_cam_translations, output_cam_quats = (
+ pose_final_outputs.value.split([3, 4], dim=-1)
+ )
+ # Get the predicted pointmaps in world frame and camera frame
+ output_pts3d = (
+ convert_ray_dirs_depth_along_ray_pose_trans_quats_to_pointmap(
+ output_ray_directions,
+ output_depth_along_ray,
+ output_cam_translations,
+ output_cam_quats,
+ )
+ )
+ output_pts3d_cam = output_ray_directions * output_depth_along_ray
+ # Split the predicted quantities back to their respective views
+ output_ray_directions_per_view = output_ray_directions.chunk(
+ num_views, dim=0
+ )
+ output_depth_along_ray_per_view = output_depth_along_ray.chunk(
+ num_views, dim=0
+ )
+ output_cam_translations_per_view = output_cam_translations.chunk(
+ num_views, dim=0
+ )
+ output_cam_quats_per_view = output_cam_quats.chunk(num_views, dim=0)
+ output_pts3d_per_view = output_pts3d.chunk(num_views, dim=0)
+ output_pts3d_cam_per_view = output_pts3d_cam.chunk(num_views, dim=0)
+ # Pack the output as a list of dictionaries
+ res = []
+ for i in range(num_views):
+ res.append(
+ {
+ "pts3d": output_pts3d_per_view[i]
+ * scale_final_output.unsqueeze(-1).unsqueeze(-1),
+ "pts3d_cam": output_pts3d_cam_per_view[i]
+ * scale_final_output.unsqueeze(-1).unsqueeze(-1),
+ "ray_directions": output_ray_directions_per_view[i],
+ "depth_along_ray": output_depth_along_ray_per_view[i]
+ * scale_final_output.unsqueeze(-1).unsqueeze(-1),
+ "cam_trans": output_cam_translations_per_view[i]
+ * scale_final_output,
+ "cam_quats": output_cam_quats_per_view[i],
+ "metric_scaling_factor": scale_final_output,
+ }
+ )
+ elif self.scene_rep_type in [
+ "campointmap+pose",
+ "campointmap+pose+confidence",
+ "campointmap+pose+mask",
+ "campointmap+pose+confidence+mask",
+ ]:
+ # Get the predicted camera frame pointmaps
+ output_pts3d_cam = dense_final_outputs.value
+ # Reshape final scene representation to (B * V, H, W, C)
+ output_pts3d_cam = output_pts3d_cam.permute(0, 2, 3, 1).contiguous()
+ # Get the predicted camera translations and quaternions
+ output_cam_translations, output_cam_quats = (
+ pose_final_outputs.value.split([3, 4], dim=-1)
+ )
+ # Get the ray directions and depths along rays
+ output_depth_along_ray = torch.norm(
+ output_pts3d_cam, dim=-1, keepdim=True
+ )
+ output_ray_directions = output_pts3d_cam / output_depth_along_ray
+ # Get the predicted pointmaps in world frame
+ output_pts3d = (
+ convert_ray_dirs_depth_along_ray_pose_trans_quats_to_pointmap(
+ output_ray_directions,
+ output_depth_along_ray,
+ output_cam_translations,
+ output_cam_quats,
+ )
+ )
+ # Split the predicted quantities back to their respective views
+ output_ray_directions_per_view = output_ray_directions.chunk(
+ num_views, dim=0
+ )
+ output_depth_along_ray_per_view = output_depth_along_ray.chunk(
+ num_views, dim=0
+ )
+ output_cam_translations_per_view = output_cam_translations.chunk(
+ num_views, dim=0
+ )
+ output_cam_quats_per_view = output_cam_quats.chunk(num_views, dim=0)
+ output_pts3d_per_view = output_pts3d.chunk(num_views, dim=0)
+ output_pts3d_cam_per_view = output_pts3d_cam.chunk(num_views, dim=0)
+ # Pack the output as a list of dictionaries
+ res = []
+ for i in range(num_views):
+ res.append(
+ {
+ "pts3d": output_pts3d_per_view[i]
+ * scale_final_output.unsqueeze(-1).unsqueeze(-1),
+ "pts3d_cam": output_pts3d_cam_per_view[i]
+ * scale_final_output.unsqueeze(-1).unsqueeze(-1),
+ "ray_directions": output_ray_directions_per_view[i],
+ "depth_along_ray": output_depth_along_ray_per_view[i]
+ * scale_final_output.unsqueeze(-1).unsqueeze(-1),
+ "cam_trans": output_cam_translations_per_view[i]
+ * scale_final_output,
+ "cam_quats": output_cam_quats_per_view[i],
+ "metric_scaling_factor": scale_final_output,
+ }
+ )
+ elif self.scene_rep_type in [
+ "pointmap+raydirs+depth+pose",
+ "pointmap+raydirs+depth+pose+confidence",
+ "pointmap+raydirs+depth+pose+mask",
+ "pointmap+raydirs+depth+pose+confidence+mask",
+ ]:
+ # Reshape final scene representation to (B * V, H, W, C)
+ output_dense_rep = dense_final_outputs.value.permute(
+ 0, 2, 3, 1
+ ).contiguous()
+ # Get the predicted pointmaps, ray directions and depths along rays
+ output_pts3d, output_ray_directions, output_depth_along_ray = (
+ output_dense_rep.split([3, 3, 1], dim=-1)
+ )
+ # Get the predicted camera translations and quaternions
+ output_cam_translations, output_cam_quats = (
+ pose_final_outputs.value.split([3, 4], dim=-1)
+ )
+ # Get the predicted pointmaps in camera frame
+ output_pts3d_cam = output_ray_directions * output_depth_along_ray
+ # Replace the predicted world-frame pointmaps if required
+ if self.pred_head_config["adaptor_config"][
+ "use_factored_predictions_for_global_pointmaps"
+ ]:
+ output_pts3d = (
+ convert_ray_dirs_depth_along_ray_pose_trans_quats_to_pointmap(
+ output_ray_directions,
+ output_depth_along_ray,
+ output_cam_translations,
+ output_cam_quats,
+ )
+ )
+ # Split the predicted quantities back to their respective views
+ output_ray_directions_per_view = output_ray_directions.chunk(
+ num_views, dim=0
+ )
+ output_depth_along_ray_per_view = output_depth_along_ray.chunk(
+ num_views, dim=0
+ )
+ output_cam_translations_per_view = output_cam_translations.chunk(
+ num_views, dim=0
+ )
+ output_cam_quats_per_view = output_cam_quats.chunk(num_views, dim=0)
+ output_pts3d_per_view = output_pts3d.chunk(num_views, dim=0)
+ output_pts3d_cam_per_view = output_pts3d_cam.chunk(num_views, dim=0)
+ # Pack the output as a list of dictionaries
+ res = []
+ for i in range(num_views):
+ res.append(
+ {
+ "pts3d": output_pts3d_per_view[i]
+ * scale_final_output.unsqueeze(-1).unsqueeze(-1),
+ "pts3d_cam": output_pts3d_cam_per_view[i]
+ * scale_final_output.unsqueeze(-1).unsqueeze(-1),
+ "ray_directions": output_ray_directions_per_view[i],
+ "depth_along_ray": output_depth_along_ray_per_view[i]
+ * scale_final_output.unsqueeze(-1).unsqueeze(-1),
+ "cam_trans": output_cam_translations_per_view[i]
+ * scale_final_output,
+ "cam_quats": output_cam_quats_per_view[i],
+ "metric_scaling_factor": scale_final_output,
+ }
+ )
+ else:
+ raise ValueError(
+ f"Invalid scene_rep_type: {self.scene_rep_type}. \
+ Valid options: ['pointmap', 'raymap+depth', 'raydirs+depth+pose', 'campointmap+pose', 'pointmap+raydirs+depth+pose' \
+ 'pointmap+confidence', 'raymap+depth+confidence', 'raydirs+depth+pose+confidence', 'campointmap+pose+confidence', 'pointmap+raydirs+depth+pose+confidence' \
+ 'pointmap+mask', 'raymap+depth+mask', 'raydirs+depth+pose+mask', 'campointmap+pose+mask', 'pointmap+raydirs+depth+pose+mask' \
+ 'pointmap+confidence+mask', 'raymap+depth+confidence+mask', 'raydirs+depth+pose+confidence+mask', 'campointmap+pose+confidence+mask', 'pointmap+raydirs+depth+pose+confidence+mask']"
+ )
+
+ # Get the output confidences for all views (if available) and add them to the result
+ if "confidence" in self.scene_rep_type:
+ output_confidences = dense_final_outputs.confidence
+ # Reshape confidences to (B * V, H, W)
+ output_confidences = (
+ output_confidences.permute(0, 2, 3, 1).squeeze(-1).contiguous()
+ )
+ # Split the predicted confidences back to their respective views
+ output_confidences_per_view = output_confidences.chunk(num_views, dim=0)
+ # Add the confidences to the result
+ for i in range(num_views):
+ res[i]["conf"] = output_confidences_per_view[i]
+
+ # Get the output masks (and logits) for all views (if available) and add them to the result
+ if "mask" in self.scene_rep_type:
+ # Get the output masks
+ output_masks = dense_final_outputs.mask
+ # Reshape masks to (B * V, H, W)
+ output_masks = output_masks.permute(0, 2, 3, 1).squeeze(-1).contiguous()
+ # Threshold the masks at 0.5 to get binary masks (0: ambiguous, 1: non-ambiguous)
+ output_masks = output_masks > 0.5
+ # Split the predicted masks back to their respective views
+ output_masks_per_view = output_masks.chunk(num_views, dim=0)
+ # Get the output mask logits (for loss)
+ output_mask_logits = dense_final_outputs.logits
+ # Reshape mask logits to (B * V, H, W)
+ output_mask_logits = (
+ output_mask_logits.permute(0, 2, 3, 1).squeeze(-1).contiguous()
+ )
+ # Split the predicted mask logits back to their respective views
+ output_mask_logits_per_view = output_mask_logits.chunk(num_views, dim=0)
+ # Add the masks and logits to the result
+ for i in range(num_views):
+ res[i]["non_ambiguous_mask"] = output_masks_per_view[i]
+ res[i]["non_ambiguous_mask_logits"] = output_mask_logits_per_view[i]
+
+ return res
diff --git a/mapanything/models/mapanything/modular_dust3r.py b/mapanything/models/mapanything/modular_dust3r.py
new file mode 100644
index 0000000000000000000000000000000000000000..8d9d724fb401861d36e5322af51b64345ac76f72
--- /dev/null
+++ b/mapanything/models/mapanything/modular_dust3r.py
@@ -0,0 +1,468 @@
+"""
+Modular DUSt3R class defined using UniCeption modules.
+"""
+
+from typing import Callable, Dict
+
+import torch
+import torch.nn as nn
+
+from uniception.models.encoders import encoder_factory, ViTEncoderInput
+from uniception.models.info_sharing.alternating_attention_transformer import (
+ MultiViewAlternatingAttentionTransformer,
+ MultiViewAlternatingAttentionTransformerIFR,
+)
+from uniception.models.info_sharing.base import MultiViewTransformerInput
+from uniception.models.info_sharing.cross_attention_transformer import (
+ MultiViewCrossAttentionTransformer,
+ MultiViewCrossAttentionTransformerIFR,
+)
+from uniception.models.info_sharing.global_attention_transformer import (
+ MultiViewGlobalAttentionTransformer,
+ MultiViewGlobalAttentionTransformerIFR,
+)
+from uniception.models.libs.croco.pos_embed import RoPE2D
+from uniception.models.prediction_heads.adaptors import PointMapWithConfidenceAdaptor
+from uniception.models.prediction_heads.base import (
+ AdaptorInput,
+ PredictionHeadInput,
+ PredictionHeadLayeredInput,
+)
+from uniception.models.prediction_heads.dpt import DPTFeature, DPTRegressionProcessor
+from uniception.models.prediction_heads.linear import LinearFeature
+
+# Enable TF32 precision if supported (for GPU >= Ampere and PyTorch >= 1.12)
+if hasattr(torch.backends.cuda, "matmul") and hasattr(
+ torch.backends.cuda.matmul, "allow_tf32"
+):
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+
+class ModularDUSt3R(nn.Module):
+ "Modular DUSt3R model class."
+
+ def __init__(
+ self,
+ name: str,
+ encoder_config: Dict,
+ info_sharing_config: Dict,
+ pred_head_config: Dict,
+ pretrained_checkpoint_path: str = None,
+ load_specific_pretrained_submodules: bool = False,
+ specific_pretrained_submodules: list = [],
+ torch_hub_force_reload: bool = False,
+ *args,
+ **kwargs,
+ ):
+ """
+ Two-view model containing siamese encoders followed by a two-view attention transformer and respective downstream heads.
+ The goal is to output scene representation directly, both outputs in view1's frame (hence the asymmetry).
+
+ Args:
+ name (str): Name of the model.
+ encoder_config (Dict): Configuration for the encoder.
+ info_sharing_config (Dict): Configuration for the two-view attention transformer.
+ pred_head_config (Dict): Configuration for the prediction heads.
+ pretrained_checkpoint_path (str): Path to pretrained checkpoint. (default: None)
+ load_specific_pretrained_submodules (bool): Whether to load specific pretrained submodules. (default: False)
+ specific_pretrained_submodules (list): List of specific pretrained submodules to load. Must be provided when load_specific_pretrained_submodules is True. (default: [])
+ torch_hub_force_reload (bool): Whether to force reload the encoder from torch hub. (default: False)
+ """
+ super().__init__(*args, **kwargs)
+
+ # Initalize the attributes
+ self.name = name
+ self.encoder_config = encoder_config
+ self.info_sharing_config = info_sharing_config
+ self.pred_head_config = pred_head_config
+ self.pretrained_checkpoint_path = pretrained_checkpoint_path
+ self.load_specific_pretrained_submodules = load_specific_pretrained_submodules
+ self.specific_pretrained_submodules = specific_pretrained_submodules
+ self.torch_hub_force_reload = torch_hub_force_reload
+ self.class_init_args = {
+ "name": self.name,
+ "encoder_config": self.encoder_config,
+ "info_sharing_config": self.info_sharing_config,
+ "pred_head_config": self.pred_head_config,
+ "pretrained_checkpoint_path": self.pretrained_checkpoint_path,
+ "load_specific_pretrained_submodules": self.load_specific_pretrained_submodules,
+ "specific_pretrained_submodules": self.specific_pretrained_submodules,
+ "torch_hub_force_reload": self.torch_hub_force_reload,
+ }
+
+ # Get relevant parameters from the configs
+ custom_positional_encoding = info_sharing_config["custom_positional_encoding"]
+ self.info_sharing_type = info_sharing_config["model_type"]
+ self.info_sharing_return_type = info_sharing_config["model_return_type"]
+ self.pred_head_type = pred_head_config["type"]
+
+ # Initialize Encoder
+ if self.encoder_config["uses_torch_hub"]:
+ self.encoder_config["torch_hub_force_reload"] = torch_hub_force_reload
+ del self.encoder_config["uses_torch_hub"]
+ self.encoder = encoder_factory(**self.encoder_config)
+
+ # Initialize Custom Positional Encoding if required
+ if custom_positional_encoding is not None:
+ if isinstance(custom_positional_encoding, str):
+ print(
+ f"Using custom positional encoding for multi-view cross attention transformer: {custom_positional_encoding}"
+ )
+ if custom_positional_encoding.startswith("RoPE"):
+ rope_freq = float(custom_positional_encoding[len("RoPE") :])
+ print(f"RoPE frequency: {rope_freq}")
+ self.custom_positional_encoding = RoPE2D(freq=rope_freq)
+ else:
+ raise ValueError(
+ f"Invalid custom_positional_encoding: {custom_positional_encoding}."
+ )
+ elif isinstance(custom_positional_encoding, Callable):
+ print(
+ "Using callable function as custom positional encoding for multi-view cross attention transformer."
+ )
+ self.custom_positional_encoding = custom_positional_encoding
+ else:
+ self.custom_positional_encoding = None
+
+ # Add dependecies to info_sharing_config
+ info_sharing_config["module_args"]["input_embed_dim"] = (
+ self.encoder.enc_embed_dim
+ )
+ info_sharing_config["module_args"]["custom_positional_encoding"] = (
+ self.custom_positional_encoding
+ )
+
+ # Initialize Multi-View Transformer
+ if self.info_sharing_return_type == "no_intermediate_features":
+ # Returns only normalized last layer features
+ # Intialize multi-view transformer based on type
+ if self.info_sharing_type == "cross_attention":
+ self.info_sharing = MultiViewCrossAttentionTransformer(
+ **info_sharing_config["module_args"]
+ )
+ elif self.info_sharing_type == "global_attention":
+ self.info_sharing = MultiViewGlobalAttentionTransformer(
+ **info_sharing_config["module_args"]
+ )
+ elif self.info_sharing_type == "alternating_attention":
+ self.info_sharing = MultiViewAlternatingAttentionTransformer(
+ **info_sharing_config["module_args"]
+ )
+ else:
+ raise ValueError(
+ f"Invalid info_sharing_type: {self.info_sharing_type}. Valid options: ['cross_attention', 'global_attention', 'alternating_attention']"
+ )
+ elif self.info_sharing_return_type == "intermediate_features":
+ # Returns intermediate features and normalized last layer features
+ # Initialize mulit-view transformer based on type
+ if self.info_sharing_type == "cross_attention":
+ self.info_sharing = MultiViewCrossAttentionTransformerIFR(
+ **info_sharing_config["module_args"]
+ )
+ elif self.info_sharing_type == "global_attention":
+ self.info_sharing = MultiViewGlobalAttentionTransformerIFR(
+ **info_sharing_config["module_args"]
+ )
+ elif self.info_sharing_type == "alternating_attention":
+ self.info_sharing = MultiViewAlternatingAttentionTransformerIFR(
+ **info_sharing_config["module_args"]
+ )
+ else:
+ raise ValueError(
+ f"Invalid info_sharing_type: {self.info_sharing_type}. Valid options: ['cross_attention', 'global_attention', 'alternating_attention']"
+ )
+ # Assess if the DPT needs to use encoder features
+ if len(self.info_sharing.indices) == 2:
+ self.use_encoder_features_for_dpt = True
+ elif len(self.info_sharing.indices) == 3:
+ self.use_encoder_features_for_dpt = False
+ else:
+ raise ValueError(
+ "Invalid number of indices provided for info sharing feature returner. Please provide 2 or 3 indices."
+ )
+ else:
+ raise ValueError(
+ f"Invalid info_sharing_return_type: {self.info_sharing_return_type}. Valid options: ['no_intermediate_features', 'intermediate_features']"
+ )
+
+ # Add dependencies to prediction head config
+ pred_head_config["feature_head"]["patch_size"] = self.encoder.patch_size
+ if self.pred_head_type == "linear":
+ pred_head_config["feature_head"]["input_feature_dim"] = (
+ self.info_sharing.dim
+ )
+ elif self.pred_head_type == "dpt":
+ if self.use_encoder_features_for_dpt:
+ pred_head_config["feature_head"]["input_feature_dims"] = [
+ self.encoder.enc_embed_dim
+ ] + [self.info_sharing.dim] * 3
+ else:
+ pred_head_config["feature_head"]["input_feature_dims"] = [
+ self.info_sharing.dim
+ ] * 4
+ pred_head_config["regressor_head"]["input_feature_dim"] = pred_head_config[
+ "feature_head"
+ ]["feature_dim"]
+ else:
+ raise ValueError(
+ f"Invalid pred_head_type: {self.pred_head_type}. Valid options: ['linear', 'dpt']"
+ )
+
+ # Initialize Prediction Heads
+ if self.pred_head_type == "linear":
+ # Initialize Prediction Head 1
+ self.head1 = LinearFeature(**pred_head_config["feature_head"])
+ # Initialize Prediction Head 2
+ self.head2 = LinearFeature(**pred_head_config["feature_head"])
+ elif self.pred_head_type == "dpt":
+ # Initialze Predction Head 1
+ self.dpt_feature_head1 = DPTFeature(**pred_head_config["feature_head"])
+ self.dpt_regressor_head1 = DPTRegressionProcessor(
+ **pred_head_config["regressor_head"]
+ )
+ self.head1 = nn.Sequential(self.dpt_feature_head1, self.dpt_regressor_head1)
+ # Initialize Prediction Head 2
+ self.dpt_feature_head2 = DPTFeature(**pred_head_config["feature_head"])
+ self.dpt_regressor_head2 = DPTRegressionProcessor(
+ **pred_head_config["regressor_head"]
+ )
+ self.head2 = nn.Sequential(self.dpt_feature_head2, self.dpt_regressor_head2)
+ else:
+ raise ValueError(
+ f"Invalid pred_head_type: {self.pred_head_type}. Valid options: ['linear', 'dpt']"
+ )
+
+ # Initialize Final Output Adaptor
+ if pred_head_config["adaptor_type"] == "pointmap+confidence":
+ self.adaptor = PointMapWithConfidenceAdaptor(**pred_head_config["adaptor"])
+ self.scene_rep_type = "pointmap"
+ else:
+ raise ValueError(
+ f"Invalid adaptor_type: {pred_head_config['adaptor_type']}. Valid options: ['pointmap+confidence']"
+ )
+
+ # Load pretrained weights
+ if self.pretrained_checkpoint_path is not None:
+ if not self.load_specific_pretrained_submodules:
+ print(
+ f"Loading pretrained weights from {self.pretrained_checkpoint_path} ..."
+ )
+ ckpt = torch.load(self.pretrained_checkpoint_path, weights_only=False)
+ print(self.load_state_dict(ckpt["model"]))
+ else:
+ print(
+ f"Loading pretrained weights from {self.pretrained_checkpoint_path} for specific submodules: {specific_pretrained_submodules} ..."
+ )
+ ckpt = torch.load(self.pretrained_checkpoint_path, weights_only=False)
+ filtered_ckpt = {}
+ for ckpt_key, ckpt_value in ckpt["model"].items():
+ for submodule in specific_pretrained_submodules:
+ if ckpt_key.startswith(submodule):
+ filtered_ckpt[ckpt_key] = ckpt_value
+ print(self.load_state_dict(filtered_ckpt, strict=False))
+
+ def _encode_image_pairs(self, img1, img2, data_norm_type):
+ "Encode two different batches of images (each batch can have different image shape)"
+ if img1.shape[-2:] == img2.shape[-2:]:
+ encoder_input = ViTEncoderInput(
+ image=torch.cat((img1, img2), dim=0), data_norm_type=data_norm_type
+ )
+ encoder_output = self.encoder(encoder_input)
+ out, out2 = encoder_output.features.chunk(2, dim=0)
+ else:
+ encoder_input = ViTEncoderInput(image=img1, data_norm_type=data_norm_type)
+ out = self.encoder(encoder_input)
+ out = out.features
+ encoder_input2 = ViTEncoderInput(image=img2, data_norm_type=data_norm_type)
+ out2 = self.encoder(encoder_input2)
+ out2 = out2.features
+
+ return out, out2
+
+ def _encode_symmetrized(self, view1, view2):
+ "Encode image pairs accounting for symmetrization, i.e., (a, b) and (b, a) always exist in the input"
+ img1 = view1["img"]
+ img2 = view2["img"]
+ if isinstance(view1["data_norm_type"], list):
+ assert all(
+ [x == view1["data_norm_type"][0] for x in view1["data_norm_type"]]
+ ), "All data_norm_type values should be the same in the list."
+ data_norm_type = view1["data_norm_type"][0]
+ elif isinstance(view1["data_norm_type"], str):
+ data_norm_type = view1["data_norm_type"]
+ else:
+ raise ValueError(
+ f"Invalid data_norm_type: {view1['data_norm_type']}. Should be either a list with all same values or a string."
+ )
+ feat1, feat2 = self._encode_image_pairs(
+ img1, img2, data_norm_type=data_norm_type
+ )
+
+ return feat1, feat2
+
+ def _downstream_head(self, head_num, decout, img_shape):
+ "Run the respective prediction heads"
+ head = getattr(self, f"head{head_num}")
+ if self.pred_head_type == "linear":
+ head_input = PredictionHeadInput(last_feature=decout[f"{head_num}"])
+ elif self.pred_head_type == "dpt":
+ head_input = PredictionHeadLayeredInput(
+ list_features=decout[f"{head_num}"], target_output_shape=img_shape
+ )
+
+ return head(head_input)
+
+ def forward(self, views):
+ """
+ Forward pass performing the following operations:
+ 1. Encodes the two input views (images).
+ 2. Combines the encoded features using a two-view attention transformer.
+ 3. Passes the combined features through the respective prediction heads.
+ 4. Returns the processed final outputs for both views.
+
+ Args:
+ views (List(dict)): A list of size two whose elements are:
+ view1 (dict): Dictionary containing the first view's images and instance information.
+ "img" is a required key and value is a tensor of shape (B, C, H, W).
+ view2 (dict): Dictionary containing the second view's images and instance information.
+ "img" is a required key and value is a tensor of shape (B, C, H, W).
+
+ Returns:
+ List[dict, dict]: A list containing the final outputs for both views.
+ """
+ # Get input shapes
+ view1 = views[0]
+ view2 = views[1]
+ _, _, height1, width1 = view1["img"].shape
+ _, _, height2, width2 = view2["img"].shape
+ shape1 = (int(height1), int(width1))
+ shape2 = (int(height2), int(width2))
+
+ if "img_encoder_feats" in view1 and "img_encoder_feats" in view2:
+ # Reuse the pre-computed image features for the two views
+ feat1 = view1["img_encoder_feats"]
+ feat2 = view2["img_encoder_feats"]
+ else:
+ # Encode the two images --> Each feat output: BCHW features (batch_size, feature_dim, feature_height, feature_width)
+ feat1, feat2 = self._encode_symmetrized(view1, view2)
+
+ # Combine all images into view-centric representation
+ info_sharing_input = MultiViewTransformerInput(features=[feat1, feat2])
+ if self.info_sharing_return_type == "no_intermediate_features":
+ final_info_sharing_multi_view_feat = self.info_sharing(info_sharing_input)
+ elif self.info_sharing_return_type == "intermediate_features":
+ (
+ final_info_sharing_multi_view_feat,
+ intermediate_info_sharing_multi_view_feat,
+ ) = self.info_sharing(info_sharing_input)
+
+ if self.pred_head_type == "linear":
+ # Define feature dictionary for linear head
+ info_sharing_outputs = {
+ "1": final_info_sharing_multi_view_feat.features[0].float(),
+ "2": final_info_sharing_multi_view_feat.features[1].float(),
+ }
+ elif self.pred_head_type == "dpt":
+ # Define feature dictionary for DPT head
+ if self.use_encoder_features_for_dpt:
+ info_sharing_outputs = {
+ "1": [
+ feat1.float(),
+ intermediate_info_sharing_multi_view_feat[0]
+ .features[0]
+ .float(),
+ intermediate_info_sharing_multi_view_feat[1]
+ .features[0]
+ .float(),
+ final_info_sharing_multi_view_feat.features[0].float(),
+ ],
+ "2": [
+ feat2.float(),
+ intermediate_info_sharing_multi_view_feat[0]
+ .features[1]
+ .float(),
+ intermediate_info_sharing_multi_view_feat[1]
+ .features[1]
+ .float(),
+ final_info_sharing_multi_view_feat.features[1].float(),
+ ],
+ }
+ else:
+ info_sharing_outputs = {
+ "1": [
+ intermediate_info_sharing_multi_view_feat[0]
+ .features[0]
+ .float(),
+ intermediate_info_sharing_multi_view_feat[1]
+ .features[0]
+ .float(),
+ intermediate_info_sharing_multi_view_feat[2]
+ .features[0]
+ .float(),
+ final_info_sharing_multi_view_feat.features[0].float(),
+ ],
+ "2": [
+ intermediate_info_sharing_multi_view_feat[0]
+ .features[1]
+ .float(),
+ intermediate_info_sharing_multi_view_feat[1]
+ .features[1]
+ .float(),
+ intermediate_info_sharing_multi_view_feat[2]
+ .features[1]
+ .float(),
+ final_info_sharing_multi_view_feat.features[1].float(),
+ ],
+ }
+
+ # Downstream task prediction
+ with torch.autocast("cuda", enabled=False):
+ # Prediction heads
+ head_output1 = self._downstream_head(1, info_sharing_outputs, shape1)
+ head_output2 = self._downstream_head(2, info_sharing_outputs, shape2)
+
+ # Post-process outputs
+ final_output1 = self.adaptor(
+ AdaptorInput(
+ adaptor_feature=head_output1.decoded_channels,
+ output_shape_hw=shape1,
+ )
+ )
+ final_output2 = self.adaptor(
+ AdaptorInput(
+ adaptor_feature=head_output2.decoded_channels,
+ output_shape_hw=shape2,
+ )
+ )
+
+ # Reshape final scene representation to (B, H, W, C)
+ final_scene_rep1 = final_output1.value.permute(0, 2, 3, 1).contiguous()
+ final_scene_rep2 = final_output2.value.permute(0, 2, 3, 1).contiguous()
+
+ # Convert output scene representation to pointmaps
+ if self.scene_rep_type == "pointmap":
+ output_pts3d1 = final_scene_rep1
+ output_pts3d2 = final_scene_rep2
+ else:
+ raise ValueError(f"Invalid scene_rep_type: {self.scene_rep_type}.")
+
+ # Reshape confidence to (B, H, W, 1)
+ output_conf1 = (
+ final_output1.confidence.permute(0, 2, 3, 1).squeeze(-1).contiguous()
+ )
+ output_conf2 = (
+ final_output2.confidence.permute(0, 2, 3, 1).squeeze(-1).contiguous()
+ )
+
+ # Convert outputs to dictionary
+ res1 = {
+ "pts3d": output_pts3d1,
+ "conf": output_conf1,
+ }
+ res2 = {
+ "pts3d": output_pts3d2,
+ "conf": output_conf2,
+ }
+ res = [res1, res2]
+
+ return res
diff --git a/mapanything/train/__init__.py b/mapanything/train/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/mapanything/train/losses.py b/mapanything/train/losses.py
new file mode 100644
index 0000000000000000000000000000000000000000..85c727d8d91ecfab3697193372d72b9e759886d5
--- /dev/null
+++ b/mapanything/train/losses.py
@@ -0,0 +1,4786 @@
+"""
+Multi-view geometric losses for training 3D reconstruction models.
+
+References: DUSt3R & MASt3R
+"""
+
+import math
+from copy import copy, deepcopy
+
+import einops as ein
+import torch
+import torch.nn as nn
+
+from mapanything.utils.geometry import (
+ angle_diff_vec3,
+ apply_log_to_norm,
+ closed_form_pose_inverse,
+ convert_ray_dirs_depth_along_ray_pose_trans_quats_to_pointmap,
+ geotrf,
+ normalize_multiple_pointclouds,
+ quaternion_inverse,
+ quaternion_multiply,
+ quaternion_to_rotation_matrix,
+ transform_pose_using_quats_and_trans_2_to_1,
+)
+
+
+def get_loss_terms_and_details(
+ losses_dict, valid_masks, self_name, n_views, flatten_across_image_only
+):
+ """
+ Helper function to generate loss terms and details for different loss types.
+
+ Args:
+ losses_dict (dict): Dictionary mapping loss types to their values.
+ Format: {
+ 'loss_type': {
+ 'values': list_of_loss_tensors or single_tensor,
+ 'use_mask': bool,
+ 'is_multi_view': bool
+ }
+ }
+ valid_masks (list): List of valid masks for each view.
+ self_name (str): Name of the loss class.
+ n_views (int): Number of views.
+ flatten_across_image_only (bool): Whether flattening was done across image only.
+
+ Returns:
+ tuple: (loss_terms, details) where loss_terms is a list of tuples (loss, mask, type)
+ and details is a dictionary of loss details.
+ """
+ loss_terms = []
+ details = {}
+
+ for loss_type, loss_info in losses_dict.items():
+ values = loss_info["values"]
+ use_mask = loss_info["use_mask"]
+ is_multi_view = loss_info["is_multi_view"]
+ if is_multi_view:
+ # Handle multi-view losses (list of tensors)
+ view_loss_details = []
+ for i in range(n_views):
+ mask = valid_masks[i] if use_mask else None
+ loss_terms.append((values[i], mask, loss_type))
+
+ # Add details for individual view
+ if not flatten_across_image_only or not use_mask:
+ values_after_masking = values[i]
+ else:
+ values_after_masking = values[i][mask]
+
+ if values_after_masking.numel() > 0:
+ view_loss_detail = float(values_after_masking.mean())
+ if view_loss_detail > 0:
+ details[f"{self_name}_{loss_type}_view{i + 1}"] = (
+ view_loss_detail
+ )
+ view_loss_details.append(view_loss_detail)
+ # Add average across views
+ if len(view_loss_details) > 0:
+ details[f"{self_name}_{loss_type}_avg"] = sum(view_loss_details) / len(
+ view_loss_details
+ )
+ else:
+ # Handle single tensor losses
+ if values is not None:
+ loss_terms.append((values, None, loss_type))
+ if values.numel() > 0:
+ loss_detail = float(values.mean())
+ if loss_detail > 0:
+ details[f"{self_name}_{loss_type}"] = loss_detail
+
+ return loss_terms, details
+
+
+def _smooth(err: torch.FloatTensor, beta: float = 0.0) -> torch.FloatTensor:
+ if beta == 0:
+ return err
+ else:
+ return torch.where(err < beta, 0.5 * err.square() / beta, err - 0.5 * beta)
+
+
+def compute_normal_loss(points, gt_points, mask):
+ """
+ Compute the normal loss between the predicted and ground truth points.
+ References:
+ https://github.com/microsoft/MoGe/blob/a8c37341bc0325ca99b9d57981cc3bb2bd3e255b/moge/train/losses.py#L205
+
+ Args:
+ points (torch.Tensor): Predicted points. Shape: (..., H, W, 3).
+ gt_points (torch.Tensor): Ground truth points. Shape: (..., H, W, 3).
+ mask (torch.Tensor): Mask indicating valid points. Shape: (..., H, W).
+
+ Returns:
+ torch.Tensor: Normal loss.
+ """
+ height, width = points.shape[-3:-1]
+
+ leftup, rightup, leftdown, rightdown = (
+ points[..., :-1, :-1, :],
+ points[..., :-1, 1:, :],
+ points[..., 1:, :-1, :],
+ points[..., 1:, 1:, :],
+ )
+ upxleft = torch.cross(rightup - rightdown, leftdown - rightdown, dim=-1)
+ leftxdown = torch.cross(leftup - rightup, rightdown - rightup, dim=-1)
+ downxright = torch.cross(leftdown - leftup, rightup - leftup, dim=-1)
+ rightxup = torch.cross(rightdown - leftdown, leftup - leftdown, dim=-1)
+
+ gt_leftup, gt_rightup, gt_leftdown, gt_rightdown = (
+ gt_points[..., :-1, :-1, :],
+ gt_points[..., :-1, 1:, :],
+ gt_points[..., 1:, :-1, :],
+ gt_points[..., 1:, 1:, :],
+ )
+ gt_upxleft = torch.cross(
+ gt_rightup - gt_rightdown, gt_leftdown - gt_rightdown, dim=-1
+ )
+ gt_leftxdown = torch.cross(
+ gt_leftup - gt_rightup, gt_rightdown - gt_rightup, dim=-1
+ )
+ gt_downxright = torch.cross(gt_leftdown - gt_leftup, gt_rightup - gt_leftup, dim=-1)
+ gt_rightxup = torch.cross(
+ gt_rightdown - gt_leftdown, gt_leftup - gt_leftdown, dim=-1
+ )
+
+ mask_leftup, mask_rightup, mask_leftdown, mask_rightdown = (
+ mask[..., :-1, :-1],
+ mask[..., :-1, 1:],
+ mask[..., 1:, :-1],
+ mask[..., 1:, 1:],
+ )
+ mask_upxleft = mask_rightup & mask_leftdown & mask_rightdown
+ mask_leftxdown = mask_leftup & mask_rightdown & mask_rightup
+ mask_downxright = mask_leftdown & mask_rightup & mask_leftup
+ mask_rightxup = mask_rightdown & mask_leftup & mask_leftdown
+
+ MIN_ANGLE, MAX_ANGLE, BETA_RAD = math.radians(1), math.radians(90), math.radians(3)
+
+ loss = (
+ mask_upxleft
+ * _smooth(
+ angle_diff_vec3(upxleft, gt_upxleft).clamp(MIN_ANGLE, MAX_ANGLE),
+ beta=BETA_RAD,
+ )
+ + mask_leftxdown
+ * _smooth(
+ angle_diff_vec3(leftxdown, gt_leftxdown).clamp(MIN_ANGLE, MAX_ANGLE),
+ beta=BETA_RAD,
+ )
+ + mask_downxright
+ * _smooth(
+ angle_diff_vec3(downxright, gt_downxright).clamp(MIN_ANGLE, MAX_ANGLE),
+ beta=BETA_RAD,
+ )
+ + mask_rightxup
+ * _smooth(
+ angle_diff_vec3(rightxup, gt_rightxup).clamp(MIN_ANGLE, MAX_ANGLE),
+ beta=BETA_RAD,
+ )
+ )
+
+ total_valid_mask = mask_upxleft | mask_leftxdown | mask_downxright | mask_rightxup
+ valid_count = total_valid_mask.sum()
+ if valid_count > 0:
+ loss = loss.sum() / (valid_count * (4 * max(points.shape[-3:-1])))
+ else:
+ loss = 0 * loss.sum()
+
+ return loss
+
+
+def compute_gradient_loss(prediction, gt_target, mask):
+ """
+ Compute the gradient loss between the prediction and GT target at valid points.
+ References:
+ https://docs.nerf.studio/_modules/nerfstudio/model_components/losses.html#GradientLoss
+ https://github.com/autonomousvision/monosdf/blob/main/code/model/loss.py
+
+ Args:
+ prediction (torch.Tensor): Predicted scene representation. Shape: (B, H, W, C).
+ gt_target (torch.Tensor): Ground truth scene representation. Shape: (B, H, W, C).
+ mask (torch.Tensor): Mask indicating valid points. Shape: (B, H, W).
+ """
+ # Expand mask to match number of channels in prediction
+ mask = mask[..., None].expand(-1, -1, -1, prediction.shape[-1])
+ summed_mask = torch.sum(mask, (1, 2, 3))
+
+ # Compute the gradient of the prediction and GT target
+ diff = prediction - gt_target
+ diff = torch.mul(mask, diff)
+
+ # Gradient in x direction
+ grad_x = torch.abs(diff[:, :, 1:] - diff[:, :, :-1])
+ mask_x = torch.mul(mask[:, :, 1:], mask[:, :, :-1])
+ grad_x = torch.mul(mask_x, grad_x)
+
+ # Gradient in y direction
+ grad_y = torch.abs(diff[:, 1:, :] - diff[:, :-1, :])
+ mask_y = torch.mul(mask[:, 1:, :], mask[:, :-1, :])
+ grad_y = torch.mul(mask_y, grad_y)
+
+ # Clamp the outlier gradients
+ grad_x = grad_x.clamp(max=100)
+ grad_y = grad_y.clamp(max=100)
+
+ # Compute the total loss
+ image_loss = torch.sum(grad_x, (1, 2, 3)) + torch.sum(grad_y, (1, 2, 3))
+ num_valid_pixels = torch.sum(summed_mask)
+ if num_valid_pixels > 0:
+ image_loss = torch.sum(image_loss) / num_valid_pixels
+ else:
+ image_loss = 0 * torch.sum(image_loss)
+
+ return image_loss
+
+
+def compute_gradient_matching_loss(prediction, gt_target, mask, scales=4):
+ """
+ Compute the multi-scale gradient matching loss between the prediction and GT target at valid points.
+ This loss biases discontinuities to be sharp and to coincide with discontinuities in the ground truth.
+ More info in MiDAS: https://arxiv.org/pdf/1907.01341.pdf; Equation 11
+ References:
+ https://docs.nerf.studio/_modules/nerfstudio/model_components/losses.html#GradientLoss
+ https://github.com/autonomousvision/monosdf/blob/main/code/model/loss.py
+
+ Args:
+ prediction (torch.Tensor): Predicted scene representation. Shape: (B, H, W, C).
+ gt_target (torch.Tensor): Ground truth scene representation. Shape: (B, H, W, C).
+ mask (torch.Tensor): Mask indicating valid points. Shape: (B, H, W).
+ scales (int): Number of scales to compute the loss at. Default: 4.
+ """
+ # Define total loss
+ total_loss = 0.0
+
+ # Compute the gradient loss at different scales
+ for scale in range(scales):
+ step = pow(2, scale)
+ grad_loss = compute_gradient_loss(
+ prediction[:, ::step, ::step],
+ gt_target[:, ::step, ::step],
+ mask[:, ::step, ::step],
+ )
+ total_loss += grad_loss
+
+ return total_loss
+
+
+def Sum(*losses_and_masks):
+ """
+ Aggregates multiple losses into a single loss value or returns the original losses.
+
+ Args:
+ *losses_and_masks: Variable number of tuples, each containing (loss, mask, rep_type)
+ - loss: Tensor containing loss values
+ - mask: Mask indicating valid pixels/regions
+ - rep_type: String indicating the type of representation (e.g., 'pts3d', 'depth')
+
+ Returns:
+ If the first loss has dimensions > 0:
+ Returns the original list of (loss, mask, rep_type) tuples
+ Otherwise:
+ Returns a scalar tensor that is the sum of all loss values
+ """
+ loss, mask, rep_type = losses_and_masks[0]
+ if loss.ndim > 0:
+ # we are actually returning the loss for every pixels
+ return losses_and_masks
+ else:
+ # we are returning the global loss
+ for loss2, mask2, rep_type2 in losses_and_masks[1:]:
+ loss = loss + loss2
+ return loss
+
+
+class BaseCriterion(nn.Module):
+ "Base Criterion to support different reduction methods"
+
+ def __init__(self, reduction="mean"):
+ super().__init__()
+ self.reduction = reduction
+
+
+class LLoss(BaseCriterion):
+ "L-norm loss"
+
+ def forward(self, a, b, **kwargs):
+ assert a.shape == b.shape and a.ndim >= 2 and 1 <= a.shape[-1] <= 4, (
+ f"Bad shape = {a.shape}"
+ )
+ dist = self.distance(a, b, **kwargs)
+ assert dist.ndim == a.ndim - 1 # one dimension less
+ if self.reduction == "none":
+ return dist
+ if self.reduction == "sum":
+ return dist.sum()
+ if self.reduction == "mean":
+ return dist.mean() if dist.numel() > 0 else dist.new_zeros(())
+ raise ValueError(f"bad {self.reduction=} mode")
+
+ def distance(self, a, b, **kwargs):
+ raise NotImplementedError()
+
+
+class L1Loss(LLoss):
+ "L1 distance"
+
+ def distance(self, a, b, **kwargs):
+ return torch.abs(a - b).sum(dim=-1)
+
+
+class L2Loss(LLoss):
+ "Euclidean (L2 Norm) distance"
+
+ def distance(self, a, b, **kwargs):
+ return torch.norm(a - b, dim=-1)
+
+
+class GenericLLoss(LLoss):
+ "Criterion that supports different L-norms"
+
+ def distance(self, a, b, loss_type, **kwargs):
+ if loss_type == "l1":
+ # L1 distance
+ return torch.abs(a - b).sum(dim=-1)
+ elif loss_type == "l2":
+ # Euclidean (L2 norm) distance
+ return torch.norm(a - b, dim=-1)
+ else:
+ raise ValueError(
+ f"Unsupported loss type: {loss_type}. Supported types are 'l1' and 'l2'."
+ )
+
+
+class FactoredLLoss(LLoss):
+ "Criterion that supports different L-norms for the factored loss functions"
+
+ def __init__(
+ self,
+ reduction="mean",
+ points_loss_type="l2",
+ depth_loss_type="l1",
+ ray_directions_loss_type="l1",
+ pose_quats_loss_type="l1",
+ pose_trans_loss_type="l1",
+ scale_loss_type="l1",
+ ):
+ super().__init__(reduction)
+ self.points_loss_type = points_loss_type
+ self.depth_loss_type = depth_loss_type
+ self.ray_directions_loss_type = ray_directions_loss_type
+ self.pose_quats_loss_type = pose_quats_loss_type
+ self.pose_trans_loss_type = pose_trans_loss_type
+ self.scale_loss_type = scale_loss_type
+
+ def _distance(self, a, b, loss_type):
+ if loss_type == "l1":
+ # L1 distance
+ return torch.abs(a - b).sum(dim=-1)
+ elif loss_type == "l2":
+ # Euclidean (L2 norm) distance
+ return torch.norm(a - b, dim=-1)
+ else:
+ raise ValueError(f"Unsupported loss type: {loss_type}.")
+
+ def distance(self, a, b, factor, **kwargs):
+ if factor == "points":
+ return self._distance(a, b, self.points_loss_type)
+ elif factor == "depth":
+ return self._distance(a, b, self.depth_loss_type)
+ elif factor == "ray_directions":
+ return self._distance(a, b, self.ray_directions_loss_type)
+ elif factor == "pose_quats":
+ return self._distance(a, b, self.pose_quats_loss_type)
+ elif factor == "pose_trans":
+ return self._distance(a, b, self.pose_trans_loss_type)
+ elif factor == "scale":
+ return self._distance(a, b, self.scale_loss_type)
+ else:
+ raise ValueError(f"Unsupported factor type: {factor}.")
+
+
+class RobustRegressionLoss(LLoss):
+ """
+ Generalized Robust Loss introduced in https://arxiv.org/abs/1701.03077.
+ """
+
+ def __init__(self, alpha=0.5, scaling_c=0.25, reduction="mean"):
+ """
+ Initialize the Robust Regression Loss.
+
+ Args:
+ alpha (float): Shape parameter controlling the robustness of the loss.
+ Lower values make the loss more robust to outliers. Default: 0.5.
+ scaling_c (float): Scale parameter controlling the transition between
+ quadratic and robust behavior. Default: 0.1.
+ reduction (str): Specifies the reduction to apply to the output:
+ 'none' | 'mean' | 'sum'. Default: 'mean'.
+ """
+ super().__init__(reduction)
+ self.alpha = alpha
+ self.scaling_c = scaling_c
+
+ def distance(self, a, b, **kwargs):
+ error_scaled = torch.sum(((a - b) / self.scaling_c) ** 2, dim=-1)
+ robust_loss = (abs(self.alpha - 2) / self.alpha) * (
+ torch.pow((error_scaled / abs(self.alpha - 2)) + 1, self.alpha / 2) - 1
+ )
+ return robust_loss
+
+
+class BCELoss(BaseCriterion):
+ """Binary Cross Entropy loss"""
+
+ def forward(self, predicted_logits, reference_mask):
+ """
+ Args:
+ predicted_logits: (B, H, W) tensor of predicted logits for the mask
+ reference_mask: (B, H, W) tensor of reference mask
+
+ Returns:
+ loss: scalar tensor of the BCE loss
+ """
+ bce_loss = torch.nn.functional.binary_cross_entropy_with_logits(
+ predicted_logits, reference_mask.float()
+ )
+
+ return bce_loss
+
+
+class Criterion(nn.Module):
+ """
+ Base class for all criterion modules that wrap a BaseCriterion.
+
+ This class serves as a wrapper around BaseCriterion objects, providing
+ additional functionality like naming and reduction mode control.
+
+ Args:
+ criterion (BaseCriterion): The base criterion to wrap.
+ """
+
+ def __init__(self, criterion=None):
+ super().__init__()
+ assert isinstance(criterion, BaseCriterion), (
+ f"{criterion} is not a proper criterion!"
+ )
+ self.criterion = copy(criterion)
+
+ def get_name(self):
+ """
+ Returns a string representation of this criterion.
+
+ Returns:
+ str: A string containing the class name and the wrapped criterion.
+ """
+ return f"{type(self).__name__}({self.criterion})"
+
+ def with_reduction(self, mode="none"):
+ """
+ Creates a deep copy of this criterion with the specified reduction mode.
+
+ This method recursively sets the reduction mode for this criterion and
+ any chained MultiLoss criteria.
+
+ Args:
+ mode (str): The reduction mode to set. Default: "none".
+
+ Returns:
+ Criterion: A new criterion with the specified reduction mode.
+ """
+ res = loss = deepcopy(self)
+ while loss is not None:
+ assert isinstance(loss, Criterion)
+ loss.criterion.reduction = mode # make it return the loss for each sample
+ loss = loss._loss2 # we assume loss is a Multiloss
+ return res
+
+
+class MultiLoss(nn.Module):
+ """
+ Base class for combinable loss functions with automatic tracking of individual loss values.
+
+ This class enables easy combination of multiple loss functions through arithmetic operations:
+ loss = MyLoss1() + 0.1*MyLoss2()
+
+ The combined loss functions maintain their individual weights and the forward pass
+ automatically computes and aggregates all losses while tracking individual loss values.
+
+ Usage:
+ Inherit from this class and override get_name() and compute_loss() methods.
+
+ Attributes:
+ _alpha (float): Weight multiplier for this loss component.
+ _loss2 (MultiLoss): Reference to the next loss in the chain, if any.
+ """
+
+ def __init__(self):
+ """Initialize the MultiLoss with default weight of 1 and no chained loss."""
+ super().__init__()
+ self._alpha = 1
+ self._loss2 = None
+
+ def compute_loss(self, *args, **kwargs):
+ """
+ Compute the loss value for this specific loss component.
+
+ Args:
+ *args: Variable length argument list.
+ **kwargs: Arbitrary keyword arguments.
+
+ Returns:
+ torch.Tensor or tuple: Either the loss tensor or a tuple of (loss, details_dict).
+
+ Raises:
+ NotImplementedError: This method must be implemented by subclasses.
+ """
+ raise NotImplementedError()
+
+ def get_name(self):
+ """
+ Get the name of this loss component.
+
+ Returns:
+ str: The name of the loss.
+
+ Raises:
+ NotImplementedError: This method must be implemented by subclasses.
+ """
+ raise NotImplementedError()
+
+ def __mul__(self, alpha):
+ """
+ Multiply the loss by a scalar weight.
+
+ Args:
+ alpha (int or float): The weight to multiply the loss by.
+
+ Returns:
+ MultiLoss: A new loss object with the updated weight.
+
+ Raises:
+ AssertionError: If alpha is not a number.
+ """
+ assert isinstance(alpha, (int, float))
+ res = copy(self)
+ res._alpha = alpha
+ return res
+
+ __rmul__ = __mul__ # Support both loss*alpha and alpha*loss
+
+ def __add__(self, loss2):
+ """
+ Add another loss to this loss, creating a chain of losses.
+
+ Args:
+ loss2 (MultiLoss): Another loss to add to this one.
+
+ Returns:
+ MultiLoss: A new loss object representing the combined losses.
+
+ Raises:
+ AssertionError: If loss2 is not a MultiLoss.
+ """
+ assert isinstance(loss2, MultiLoss)
+ res = cur = copy(self)
+ # Find the end of the chain
+ while cur._loss2 is not None:
+ cur = cur._loss2
+ cur._loss2 = loss2
+ return res
+
+ def __repr__(self):
+ """
+ Create a string representation of the loss, including weights and chained losses.
+
+ Returns:
+ str: String representation of the loss.
+ """
+ name = self.get_name()
+ if self._alpha != 1:
+ name = f"{self._alpha:g}*{name}"
+ if self._loss2:
+ name = f"{name} + {self._loss2}"
+ return name
+
+ def forward(self, *args, **kwargs):
+ """
+ Compute the weighted loss and aggregate with any chained losses.
+
+ Args:
+ *args: Variable length argument list.
+ **kwargs: Arbitrary keyword arguments.
+
+ Returns:
+ tuple: A tuple containing:
+ - torch.Tensor: The total weighted loss.
+ - dict: Details about individual loss components.
+ """
+ loss = self.compute_loss(*args, **kwargs)
+ if isinstance(loss, tuple):
+ loss, details = loss
+ elif loss.ndim == 0:
+ details = {self.get_name(): float(loss)}
+ else:
+ details = {}
+ loss = loss * self._alpha
+
+ if self._loss2:
+ loss2, details2 = self._loss2(*args, **kwargs)
+ loss = loss + loss2
+ details |= details2
+
+ return loss, details
+
+
+class NonAmbiguousMaskLoss(Criterion, MultiLoss):
+ """
+ Loss on non-ambiguous mask prediction logits.
+ """
+
+ def __init__(self, criterion):
+ super().__init__(criterion)
+
+ def compute_loss(self, batch, preds, **kw):
+ """
+ Args:
+ batch: list of dicts with the gt data
+ preds: list of dicts with the predictions
+
+ Returns:
+ loss: Sum class of the lossses for N-views and the loss details
+ """
+ # Init loss list to keep track of individual losses for each view
+ loss_list = []
+ mask_loss_details = {}
+ mask_loss_total = 0
+ self_name = type(self).__name__
+
+ # Loop over the views
+ for view_idx, (gt, pred) in enumerate(zip(batch, preds)):
+ # Get the GT non-ambiguous masks
+ gt_non_ambiguous_mask = gt["non_ambiguous_mask"]
+
+ # Get the predicted non-ambiguous mask logits
+ pred_non_ambiguous_mask_logits = pred["non_ambiguous_mask_logits"]
+
+ # Compute the loss for the current view
+ loss = self.criterion(pred_non_ambiguous_mask_logits, gt_non_ambiguous_mask)
+
+ # Add the loss to the list
+ loss_list.append((loss, None, "non_ambiguous_mask"))
+
+ # Add the loss details to the dictionary
+ mask_loss_details[f"{self_name}_mask_view{view_idx + 1}"] = float(loss)
+ mask_loss_total += float(loss)
+
+ # Compute the average loss across all views
+ mask_loss_details[f"{self_name}_mask_avg"] = mask_loss_total / len(batch)
+
+ return Sum(*loss_list), (mask_loss_details | {})
+
+
+class ConfLoss(MultiLoss):
+ """
+ Applies confidence-weighted regression loss using model-predicted confidence values.
+
+ The confidence-weighted loss has the form:
+ conf_loss = raw_loss * conf - alpha * log(conf)
+
+ Where:
+ - raw_loss is the original per-pixel loss
+ - conf is the predicted confidence (higher values = higher confidence)
+ - alpha is a hyperparameter controlling the regularization strength
+
+ This loss can be selectively applied to specific loss components in factored and multi-view settings.
+ """
+
+ def __init__(self, pixel_loss, alpha=1, loss_set_indices=None):
+ """
+ Args:
+ pixel_loss (MultiLoss): The pixel-level regression loss to be used.
+ alpha (float): Hyperparameter controlling the confidence regularization strength.
+ loss_set_indices (list or None): Indices of the loss sets to apply confidence weighting to.
+ Each index selects a specific loss set across all views (with the same rep_type).
+ If None, defaults to [0] which applies to the first loss set only.
+ """
+ super().__init__()
+ assert alpha > 0
+ self.alpha = alpha
+ self.pixel_loss = pixel_loss.with_reduction("none")
+ self.loss_set_indices = [0] if loss_set_indices is None else loss_set_indices
+
+ def get_name(self):
+ return f"ConfLoss({self.pixel_loss})"
+
+ def get_conf_log(self, x):
+ return x, torch.log(x)
+
+ def compute_loss(self, batch, preds, **kw):
+ # Init loss list and details
+ total_loss = 0
+ conf_loss_details = {}
+ running_avg_dict = {}
+ self_name = type(self.pixel_loss).__name__
+ n_views = len(batch)
+
+ # Compute per-pixel loss for each view
+ losses, pixel_loss_details = self.pixel_loss(batch, preds, **kw)
+
+ # Select specific loss sets based on indices
+ selected_losses = []
+ processed_indices = set()
+ for idx in self.loss_set_indices:
+ start_idx = idx * n_views
+ end_idx = min((idx + 1) * n_views, len(losses))
+ selected_losses.extend(losses[start_idx:end_idx])
+ processed_indices.update(range(start_idx, end_idx))
+
+ # Process selected losses with confidence weighting
+ for loss_idx, (loss, msk, rep_type) in enumerate(selected_losses):
+ view_idx = loss_idx % n_views # Map to corresponding view index
+
+ if loss.numel() == 0:
+ # print(f"NO VALID VALUES in loss idx {loss_idx} (Rep Type: {rep_type}, Num Views: {n_views})", force=True)
+ continue
+
+ # Get the confidence and log confidence
+ if (
+ hasattr(self.pixel_loss, "flatten_across_image_only")
+ and self.pixel_loss.flatten_across_image_only
+ ):
+ # Reshape confidence to match the flattened dimensions
+ conf_reshaped = preds[view_idx]["conf"].view(
+ preds[view_idx]["conf"].shape[0], -1
+ )
+ conf, log_conf = self.get_conf_log(conf_reshaped[msk])
+ loss = loss[msk]
+ else:
+ conf, log_conf = self.get_conf_log(preds[view_idx]["conf"][msk])
+
+ # Weight the loss by the confidence
+ conf_loss = loss * conf - self.alpha * log_conf
+
+ # Only add to total loss and store details if there are valid elements
+ if conf_loss.numel() > 0:
+ conf_loss = conf_loss.mean()
+ total_loss = total_loss + conf_loss
+
+ # Store details
+ conf_loss_details[
+ f"{self_name}_{rep_type}_conf_loss_view{view_idx + 1}"
+ ] = float(conf_loss)
+
+ # Initialize or update running average directly
+ avg_key = f"{self_name}_{rep_type}_conf_loss_avg"
+ if avg_key not in conf_loss_details:
+ conf_loss_details[avg_key] = float(conf_loss)
+ running_avg_dict[
+ f"{self_name}_{rep_type}_conf_loss_valid_views"
+ ] = 1
+ else:
+ valid_views = (
+ running_avg_dict[
+ f"{self_name}_{rep_type}_conf_loss_valid_views"
+ ]
+ + 1
+ )
+ running_avg_dict[
+ f"{self_name}_{rep_type}_conf_loss_valid_views"
+ ] = valid_views
+ conf_loss_details[avg_key] += (
+ float(conf_loss) - conf_loss_details[avg_key]
+ ) / valid_views
+
+ # Add unmodified losses for sets not in selected_losses
+ for idx, (loss, msk, rep_type) in enumerate(losses):
+ if idx not in processed_indices:
+ if msk is not None:
+ loss_after_masking = loss[msk]
+ else:
+ loss_after_masking = loss
+ if loss_after_masking.numel() > 0:
+ loss_mean = loss_after_masking.mean()
+ else:
+ # print(f"NO VALID VALUES in loss idx {idx} (Rep Type: {rep_type}, Num Views: {n_views})", force=True)
+ loss_mean = 0
+ total_loss = total_loss + loss_mean
+
+ return total_loss, dict(**conf_loss_details, **pixel_loss_details)
+
+
+class ExcludeTopNPercentPixelLoss(MultiLoss):
+ """
+ Pixel-level regression loss where for each instance in a batch the top N% of per-pixel loss values are ignored
+ for the mean loss computation.
+ Allows selecting which pixel-level regression loss sets to apply the exclusion to.
+ """
+
+ def __init__(
+ self,
+ pixel_loss,
+ top_n_percent=5,
+ apply_to_real_data_only=True,
+ loss_set_indices=None,
+ ):
+ """
+ Args:
+ pixel_loss (MultiLoss): The pixel-level regression loss to be used.
+ top_n_percent (float): The percentage of top per-pixel loss values to ignore. Range: [0, 100]. Default: 5.
+ apply_to_real_data_only (bool): Whether to apply the loss only to real world data. Default: True.
+ loss_set_indices (list or None): Indices of the loss sets to apply the exclusion to.
+ Each index selects a specific loss set across all views (with the same rep_type).
+ If None, defaults to [0] which applies to the first loss set only.
+ """
+ super().__init__()
+ self.pixel_loss = pixel_loss.with_reduction("none")
+ self.top_n_percent = top_n_percent
+ self.bottom_n_percent = 100 - top_n_percent
+ self.apply_to_real_data_only = apply_to_real_data_only
+ self.loss_set_indices = [0] if loss_set_indices is None else loss_set_indices
+
+ def get_name(self):
+ return f"ExcludeTopNPercentPixelLoss({self.pixel_loss})"
+
+ def keep_bottom_n_percent(self, tensor, mask, bottom_n_percent):
+ """
+ Function to compute the mask for keeping the bottom n percent of per-pixel loss values.
+
+ Args:
+ tensor (torch.Tensor): The tensor containing the per-pixel loss values.
+ Shape: (B, N) where B is the batch size and N is the number of total pixels.
+ mask (torch.Tensor): The mask indicating valid pixels. Shape: (B, N).
+
+ Returns:
+ torch.Tensor: Flattened tensor containing the bottom n percent of per-pixel loss values.
+ """
+ B, N = tensor.shape
+
+ # Calculate the number of valid elements (where mask is True)
+ num_valid = mask.sum(dim=1)
+
+ # Calculate the number of elements to keep (n% of valid elements)
+ num_keep = (num_valid * bottom_n_percent / 100).long()
+
+ # Create a mask for the bottom n% elements
+ keep_mask = torch.arange(N, device=tensor.device).unsqueeze(
+ 0
+ ) < num_keep.unsqueeze(1)
+
+ # Create a tensor with inf where mask is False
+ masked_tensor = torch.where(
+ mask, tensor, torch.tensor(float("inf"), device=tensor.device)
+ )
+
+ # Sort the masked tensor along the N dimension
+ sorted_tensor, _ = torch.sort(masked_tensor, dim=1, descending=False)
+
+ # Get the bottom n% elements
+ bottom_n_percent_elements = sorted_tensor[keep_mask]
+
+ return bottom_n_percent_elements
+
+ def compute_loss(self, batch, preds, **kw):
+ # Compute per-pixel loss
+ losses, details = self.pixel_loss(batch, preds, **kw)
+ n_views = len(batch)
+
+ # Select specific loss sets based on indices
+ selected_losses = []
+ processed_indices = set()
+ for idx in self.loss_set_indices:
+ start_idx = idx * n_views
+ end_idx = min((idx + 1) * n_views, len(losses))
+ selected_losses.extend(losses[start_idx:end_idx])
+ processed_indices.update(range(start_idx, end_idx))
+
+ # Initialize total loss
+ total_loss = 0.0
+ loss_details = {}
+ running_avg_dict = {}
+ self_name = type(self.pixel_loss).__name__
+
+ # Process selected losses with top N percent exclusion
+ for loss_idx, (loss, msk, rep_type) in enumerate(selected_losses):
+ view_idx = loss_idx % n_views # Map to corresponding view index
+
+ if loss.numel() == 0:
+ # print(f"NO VALID VALUES in loss idx {loss_idx} (Rep Type: {rep_type}, Num Views: {n_views})", force=True)
+ continue
+
+ # Create empty list for current view's aggregated tensors
+ aggregated_losses = []
+
+ if self.apply_to_real_data_only:
+ # Get the synthetic and real world data mask
+ synthetic_mask = batch[view_idx]["is_synthetic"]
+ real_data_mask = ~batch[view_idx]["is_synthetic"]
+ else:
+ # Apply the filtering to all data
+ synthetic_mask = torch.zeros_like(batch[view_idx]["is_synthetic"])
+ real_data_mask = torch.ones_like(batch[view_idx]["is_synthetic"])
+
+ # Process synthetic data
+ if synthetic_mask.any():
+ synthetic_loss = loss[synthetic_mask]
+ synthetic_msk = msk[synthetic_mask]
+ aggregated_losses.append(synthetic_loss[synthetic_msk])
+
+ # Process real data
+ if real_data_mask.any():
+ real_loss = loss[real_data_mask]
+ real_msk = msk[real_data_mask]
+ real_bottom_n_percent_loss = self.keep_bottom_n_percent(
+ real_loss, real_msk, self.bottom_n_percent
+ )
+ aggregated_losses.append(real_bottom_n_percent_loss)
+
+ # Compute view loss
+ view_loss = torch.cat(aggregated_losses, dim=0)
+
+ # Only add to total loss and store details if there are valid elements
+ if view_loss.numel() > 0:
+ view_loss = view_loss.mean()
+ total_loss = total_loss + view_loss
+
+ # Store details
+ loss_details[
+ f"{self_name}_{rep_type}_bot{self.bottom_n_percent}%_loss_view{view_idx + 1}"
+ ] = float(view_loss)
+
+ # Initialize or update running average directly
+ avg_key = f"{self_name}_{rep_type}_bot{self.bottom_n_percent}%_loss_avg"
+ if avg_key not in loss_details:
+ loss_details[avg_key] = float(view_loss)
+ running_avg_dict[
+ f"{self_name}_{rep_type}_bot{self.bottom_n_percent}%_valid_views"
+ ] = 1
+ else:
+ valid_views = (
+ running_avg_dict[
+ f"{self_name}_{rep_type}_bot{self.bottom_n_percent}%_valid_views"
+ ]
+ + 1
+ )
+ running_avg_dict[
+ f"{self_name}_{rep_type}_bot{self.bottom_n_percent}%_valid_views"
+ ] = valid_views
+ loss_details[avg_key] += (
+ float(view_loss) - loss_details[avg_key]
+ ) / valid_views
+
+ # Add unmodified losses for sets not in selected_losses
+ for idx, (loss, msk, rep_type) in enumerate(losses):
+ if idx not in processed_indices:
+ if msk is not None:
+ loss_after_masking = loss[msk]
+ else:
+ loss_after_masking = loss
+ if loss_after_masking.numel() > 0:
+ loss_mean = loss_after_masking.mean()
+ else:
+ # print(f"NO VALID VALUES in loss idx {idx} (Rep Type: {rep_type}, Num Views: {n_views})", force=True)
+ loss_mean = 0
+ total_loss = total_loss + loss_mean
+
+ return total_loss, dict(**loss_details, **details)
+
+
+class ConfAndExcludeTopNPercentPixelLoss(MultiLoss):
+ """
+ Combined loss that applies ConfLoss to one set of pixel-level regression losses
+ and ExcludeTopNPercentPixelLoss to another set of pixel-level regression losses.
+ """
+
+ def __init__(
+ self,
+ pixel_loss,
+ conf_alpha=1,
+ top_n_percent=5,
+ apply_to_real_data_only=True,
+ conf_loss_set_indices=None,
+ exclude_loss_set_indices=None,
+ ):
+ """
+ Args:
+ pixel_loss (MultiLoss): The pixel-level regression loss to be used.
+ conf_alpha (float): Alpha parameter for ConfLoss. Default: 1.
+ top_n_percent (float): Percentage of top per-pixel loss values to ignore. Range: [0, 100]. Default: 5.
+ apply_to_real_data_only (bool): Whether to apply the exclude loss only to real world data. Default: True.
+ conf_loss_set_indices (list or None): Indices of the loss sets to apply confidence weighting to.
+ Each index selects a specific loss set across all views (with the same rep_type).
+ If None, defaults to [0] which applies to the first loss set only.
+ exclude_loss_set_indices (list or None): Indices of the loss sets to apply top N percent exclusion to.
+ Each index selects a specific loss set across all views (with the same rep_type).
+ If None, defaults to [1] which applies to the second loss set only.
+ """
+ super().__init__()
+ self.pixel_loss = pixel_loss.with_reduction("none")
+ assert conf_alpha > 0
+ self.conf_alpha = conf_alpha
+ self.top_n_percent = top_n_percent
+ self.bottom_n_percent = 100 - top_n_percent
+ self.apply_to_real_data_only = apply_to_real_data_only
+ self.conf_loss_set_indices = (
+ [0] if conf_loss_set_indices is None else conf_loss_set_indices
+ )
+ self.exclude_loss_set_indices = (
+ [1] if exclude_loss_set_indices is None else exclude_loss_set_indices
+ )
+
+ def get_name(self):
+ return f"ConfAndExcludeTopNPercentPixelLoss({self.pixel_loss})"
+
+ def get_conf_log(self, x):
+ return x, torch.log(x)
+
+ def keep_bottom_n_percent(self, tensor, mask, bottom_n_percent):
+ """
+ Function to compute the mask for keeping the bottom n percent of per-pixel loss values.
+ """
+ B, N = tensor.shape
+
+ # Calculate the number of valid elements (where mask is True)
+ num_valid = mask.sum(dim=1)
+
+ # Calculate the number of elements to keep (n% of valid elements)
+ num_keep = (num_valid * bottom_n_percent / 100).long()
+
+ # Create a mask for the bottom n% elements
+ keep_mask = torch.arange(N, device=tensor.device).unsqueeze(
+ 0
+ ) < num_keep.unsqueeze(1)
+
+ # Create a tensor with inf where mask is False
+ masked_tensor = torch.where(
+ mask, tensor, torch.tensor(float("inf"), device=tensor.device)
+ )
+
+ # Sort the masked tensor along the N dimension
+ sorted_tensor, _ = torch.sort(masked_tensor, dim=1, descending=False)
+
+ # Get the bottom n% elements
+ bottom_n_percent_elements = sorted_tensor[keep_mask]
+
+ return bottom_n_percent_elements
+
+ def compute_loss(self, batch, preds, **kw):
+ # Compute per-pixel loss
+ losses, pixel_loss_details = self.pixel_loss(batch, preds, **kw)
+ n_views = len(batch)
+
+ # Select specific loss sets for confidence weighting
+ conf_selected_losses = []
+ conf_processed_indices = set()
+ for idx in self.conf_loss_set_indices:
+ start_idx = idx * n_views
+ end_idx = min((idx + 1) * n_views, len(losses))
+ conf_selected_losses.extend(losses[start_idx:end_idx])
+ conf_processed_indices.update(range(start_idx, end_idx))
+
+ # Select specific loss sets for top N percent exclusion
+ exclude_selected_losses = []
+ exclude_processed_indices = set()
+ for idx in self.exclude_loss_set_indices:
+ start_idx = idx * n_views
+ end_idx = min((idx + 1) * n_views, len(losses))
+ exclude_selected_losses.extend(losses[start_idx:end_idx])
+ exclude_processed_indices.update(range(start_idx, end_idx))
+
+ # Initialize total loss and details
+ total_loss = 0
+ loss_details = {}
+ running_avg_dict = {}
+ self_name = type(self.pixel_loss).__name__
+
+ # Process selected losses with confidence weighting
+ for loss_idx, (loss, msk, rep_type) in enumerate(conf_selected_losses):
+ view_idx = loss_idx % n_views # Map to corresponding view index
+
+ if loss.numel() == 0:
+ # print(f"NO VALID VALUES in loss idx {loss_idx} (Rep Type: {rep_type}, Num Views: {n_views}) for conf loss", force=True)
+ continue
+
+ # Get the confidence and log confidence
+ if (
+ hasattr(self.pixel_loss, "flatten_across_image_only")
+ and self.pixel_loss.flatten_across_image_only
+ ):
+ # Reshape confidence to match the flattened dimensions
+ conf_reshaped = preds[view_idx]["conf"].view(
+ preds[view_idx]["conf"].shape[0], -1
+ )
+ conf, log_conf = self.get_conf_log(conf_reshaped[msk])
+ loss = loss[msk]
+ else:
+ conf, log_conf = self.get_conf_log(preds[view_idx]["conf"][msk])
+
+ # Weight the loss by the confidence
+ conf_loss = loss * conf - self.conf_alpha * log_conf
+
+ # Only add to total loss and store details if there are valid elements
+ if conf_loss.numel() > 0:
+ conf_loss = conf_loss.mean()
+ total_loss = total_loss + conf_loss
+
+ # Store details
+ loss_details[f"{self_name}_{rep_type}_conf_loss_view{view_idx + 1}"] = (
+ float(conf_loss)
+ )
+
+ # Initialize or update running average directly
+ avg_key = f"{self_name}_{rep_type}_conf_loss_avg"
+ if avg_key not in loss_details:
+ loss_details[avg_key] = float(conf_loss)
+ running_avg_dict[
+ f"{self_name}_{rep_type}_conf_loss_valid_views"
+ ] = 1
+ else:
+ valid_views = (
+ running_avg_dict[
+ f"{self_name}_{rep_type}_conf_loss_valid_views"
+ ]
+ + 1
+ )
+ running_avg_dict[
+ f"{self_name}_{rep_type}_conf_loss_valid_views"
+ ] = valid_views
+ loss_details[avg_key] += (
+ float(conf_loss) - loss_details[avg_key]
+ ) / valid_views
+
+ # Process selected losses with top N percent exclusion
+ for loss_idx, (loss, msk, rep_type) in enumerate(exclude_selected_losses):
+ view_idx = loss_idx % n_views # Map to corresponding view index
+
+ if loss.numel() == 0:
+ # print(f"NO VALID VALUES in loss idx {loss_idx} (Rep Type: {rep_type}, Num Views: {n_views}) for exclude loss", force=True)
+ continue
+
+ # Create empty list for current view's aggregated tensors
+ aggregated_losses = []
+
+ if self.apply_to_real_data_only:
+ # Get the synthetic and real world data mask
+ synthetic_mask = batch[view_idx]["is_synthetic"]
+ real_data_mask = ~batch[view_idx]["is_synthetic"]
+ else:
+ # Apply the filtering to all data
+ synthetic_mask = torch.zeros_like(batch[view_idx]["is_synthetic"])
+ real_data_mask = torch.ones_like(batch[view_idx]["is_synthetic"])
+
+ # Process synthetic data
+ if synthetic_mask.any():
+ synthetic_loss = loss[synthetic_mask]
+ synthetic_msk = msk[synthetic_mask]
+ aggregated_losses.append(synthetic_loss[synthetic_msk])
+
+ # Process real data
+ if real_data_mask.any():
+ real_loss = loss[real_data_mask]
+ real_msk = msk[real_data_mask]
+ real_bottom_n_percent_loss = self.keep_bottom_n_percent(
+ real_loss, real_msk, self.bottom_n_percent
+ )
+ aggregated_losses.append(real_bottom_n_percent_loss)
+
+ # Compute view loss
+ view_loss = torch.cat(aggregated_losses, dim=0)
+
+ # Only add to total loss and store details if there are valid elements
+ if view_loss.numel() > 0:
+ view_loss = view_loss.mean()
+ total_loss = total_loss + view_loss
+
+ # Store details
+ loss_details[
+ f"{self_name}_{rep_type}_bot{self.bottom_n_percent}%_loss_view{view_idx + 1}"
+ ] = float(view_loss)
+
+ # Initialize or update running average directly
+ avg_key = f"{self_name}_{rep_type}_bot{self.bottom_n_percent}%_loss_avg"
+ if avg_key not in loss_details:
+ loss_details[avg_key] = float(view_loss)
+ running_avg_dict[
+ f"{self_name}_{rep_type}_bot{self.bottom_n_percent}%_valid_views"
+ ] = 1
+ else:
+ valid_views = (
+ running_avg_dict[
+ f"{self_name}_{rep_type}_bot{self.bottom_n_percent}%_valid_views"
+ ]
+ + 1
+ )
+ running_avg_dict[
+ f"{self_name}_{rep_type}_bot{self.bottom_n_percent}%_valid_views"
+ ] = valid_views
+ loss_details[avg_key] += (
+ float(view_loss) - loss_details[avg_key]
+ ) / valid_views
+
+ # Add unmodified losses for sets not processed with either confidence or exclusion
+ all_processed_indices = conf_processed_indices.union(exclude_processed_indices)
+ for idx, (loss, msk, rep_type) in enumerate(losses):
+ if idx not in all_processed_indices:
+ if msk is not None:
+ loss_after_masking = loss[msk]
+ else:
+ loss_after_masking = loss
+ if loss_after_masking.numel() > 0:
+ loss_mean = loss_after_masking.mean()
+ else:
+ # print(f"NO VALID VALUES in loss idx {idx} (Rep Type: {rep_type}, Num Views: {n_views})", force=True)
+ loss_mean = 0
+ total_loss = total_loss + loss_mean
+
+ return total_loss, dict(**loss_details, **pixel_loss_details)
+
+
+class Regr3D(Criterion, MultiLoss):
+ """
+ Regression Loss for World Frame Pointmaps.
+ Asymmetric loss where view 1 is supposed to be the anchor.
+
+ For each view i:
+ Pi = RTi @ Di
+ lossi = (RTi1 @ pred_Di) - (RT1^-1 @ RTi @ Di)
+ where RT1 is the anchor view camera pose
+ """
+
+ def __init__(
+ self,
+ criterion,
+ norm_mode="?avg_dis",
+ gt_scale=False,
+ ambiguous_loss_value=0,
+ max_metric_scale=False,
+ loss_in_log=True,
+ flatten_across_image_only=False,
+ ):
+ """
+ Initialize the loss criterion for World Frame Pointmaps.
+
+ Args:
+ criterion (BaseCriterion): The base criterion to use for computing the loss.
+ norm_mode (str): Normalization mode for scene representation. Default: "?avg_dis".
+ If prefixed with "?", normalization is only applied to non-metric scale data.
+ gt_scale (bool): If True, enforce predictions to have the same scale as ground truth.
+ If False, both GT and predictions are normalized independently. Default: False.
+ ambiguous_loss_value (float): Value to use for ambiguous pixels in the loss.
+ If 0, ambiguous pixels are ignored. Default: 0.
+ max_metric_scale (float): Maximum scale for metric scale data. If data exceeds this
+ value, it will be treated as non-metric. Default: False (no limit).
+ loss_in_log (bool): If True, apply logarithmic transformation to input before
+ computing the loss for pointmaps. Default: True.
+ flatten_across_image_only (bool): If True, flatten H x W dimensions only when computing
+ the loss. If False, flatten across batch and spatial dimensions. Default: False.
+ """
+ super().__init__(criterion)
+ if norm_mode.startswith("?"):
+ # Do no norm pts from metric scale datasets
+ self.norm_all = False
+ self.norm_mode = norm_mode[1:]
+ else:
+ self.norm_all = True
+ self.norm_mode = norm_mode
+ self.gt_scale = gt_scale
+ self.ambiguous_loss_value = ambiguous_loss_value
+ self.max_metric_scale = max_metric_scale
+ self.loss_in_log = loss_in_log
+ self.flatten_across_image_only = flatten_across_image_only
+
+ def get_all_info(self, batch, preds, dist_clip=None):
+ n_views = len(batch)
+ in_camera0 = closed_form_pose_inverse(batch[0]["camera_pose"])
+
+ # Initialize lists to store points and masks
+ no_norm_gt_pts = []
+ valid_masks = []
+
+ # Process ground truth points and valid masks
+ for view_idx in range(n_views):
+ no_norm_gt_pts.append(
+ geotrf(in_camera0, batch[view_idx]["pts3d"])
+ ) # B,H,W,3
+ valid_masks.append(batch[view_idx]["valid_mask"].clone())
+
+ if dist_clip is not None:
+ # Points that are too far-away == invalid
+ for view_idx in range(n_views):
+ dis = no_norm_gt_pts[view_idx].norm(dim=-1) # (B, H, W)
+ valid_masks[view_idx] = valid_masks[view_idx] & (dis <= dist_clip)
+
+ # Get predicted points
+ no_norm_pr_pts = []
+ for view_idx in range(n_views):
+ no_norm_pr_pts.append(preds[view_idx]["pts3d"])
+
+ if not self.norm_all:
+ if self.max_metric_scale:
+ B = valid_masks[0].shape[0]
+ # Calculate distances to camera for all views
+ dists_to_cam1 = []
+ for view_idx in range(n_views):
+ dist = torch.where(
+ valid_masks[view_idx],
+ torch.norm(no_norm_gt_pts[view_idx], dim=-1),
+ 0,
+ ).view(B, -1)
+ dists_to_cam1.append(dist)
+
+ # Update metric scale flags
+ metric_scale_mask = batch[0]["is_metric_scale"]
+ for dist in dists_to_cam1:
+ metric_scale_mask = metric_scale_mask & (
+ dist.max(dim=-1).values < self.max_metric_scale
+ )
+
+ for view_idx in range(n_views):
+ batch[view_idx]["is_metric_scale"] = metric_scale_mask
+
+ non_metric_scale_mask = ~batch[0]["is_metric_scale"]
+ else:
+ non_metric_scale_mask = torch.ones_like(batch[0]["is_metric_scale"])
+
+ # Initialize normalized points
+ gt_pts = [torch.zeros_like(pts) for pts in no_norm_gt_pts]
+ pr_pts = [torch.zeros_like(pts) for pts in no_norm_pr_pts]
+
+ # Normalize 3d points
+ if self.norm_mode and non_metric_scale_mask.any():
+ normalized_pr_pts = normalize_multiple_pointclouds(
+ [pts[non_metric_scale_mask] for pts in no_norm_pr_pts],
+ [mask[non_metric_scale_mask] for mask in valid_masks],
+ self.norm_mode,
+ )
+ for i in range(n_views):
+ pr_pts[i][non_metric_scale_mask] = normalized_pr_pts[i]
+ elif non_metric_scale_mask.any():
+ for i in range(n_views):
+ pr_pts[i][non_metric_scale_mask] = no_norm_pr_pts[i][
+ non_metric_scale_mask
+ ]
+
+ if self.norm_mode and not self.gt_scale:
+ gt_normalization_output = normalize_multiple_pointclouds(
+ no_norm_gt_pts, valid_masks, self.norm_mode, ret_factor=True
+ )
+ normalized_gt_pts = gt_normalization_output[:-1]
+ norm_factor = gt_normalization_output[-1]
+ for i in range(n_views):
+ gt_pts[i] = normalized_gt_pts[i]
+ pr_pts[i][~non_metric_scale_mask] = (
+ no_norm_pr_pts[i][~non_metric_scale_mask]
+ / norm_factor[~non_metric_scale_mask]
+ )
+ elif ~non_metric_scale_mask.any():
+ for i in range(n_views):
+ gt_pts[i] = no_norm_gt_pts[i]
+ pr_pts[i][~non_metric_scale_mask] = no_norm_pr_pts[i][
+ ~non_metric_scale_mask
+ ]
+ else:
+ for i in range(n_views):
+ gt_pts[i] = no_norm_gt_pts[i]
+
+ # Get ambiguous masks
+ ambiguous_masks = []
+ for view_idx in range(n_views):
+ ambiguous_masks.append(
+ (~batch[view_idx]["non_ambiguous_mask"]) & (~valid_masks[view_idx])
+ )
+
+ return gt_pts, pr_pts, valid_masks, ambiguous_masks, {}
+
+ def compute_loss(self, batch, preds, **kw):
+ gt_pts, pred_pts, masks, ambiguous_masks, monitoring = self.get_all_info(
+ batch, preds, **kw
+ )
+ n_views = len(batch)
+
+ if self.ambiguous_loss_value > 0:
+ assert self.criterion.reduction == "none", (
+ "ambiguous_loss_value should be 0 if no conf loss"
+ )
+ # Add the ambiguous pixels as "valid" pixels
+ masks = [mask | amb_mask for mask, amb_mask in zip(masks, ambiguous_masks)]
+
+ losses = []
+ details = {}
+ running_avg_dict = {}
+ self_name = type(self).__name__
+
+ if not self.flatten_across_image_only:
+ for view_idx in range(n_views):
+ pred = pred_pts[view_idx][masks[view_idx]]
+ gt = gt_pts[view_idx][masks[view_idx]]
+
+ if self.loss_in_log:
+ pred = apply_log_to_norm(pred)
+ gt = apply_log_to_norm(gt)
+
+ loss = self.criterion(pred, gt)
+
+ if self.ambiguous_loss_value > 0:
+ loss = torch.where(
+ ambiguous_masks[view_idx][masks[view_idx]],
+ self.ambiguous_loss_value,
+ loss,
+ )
+
+ losses.append((loss, masks[view_idx], "pts3d"))
+ if loss.numel() > 0:
+ loss_mean = float(loss.mean())
+ details[f"{self_name}_pts3d_view{view_idx + 1}"] = loss_mean
+ # Initialize or update running average directly
+ avg_key = f"{self_name}_pts3d_avg"
+ if avg_key not in details:
+ details[avg_key] = loss_mean
+ running_avg_dict[f"{self_name}_pts3d_valid_views"] = 1
+ else:
+ valid_views = (
+ running_avg_dict[f"{self_name}_pts3d_valid_views"] + 1
+ )
+ running_avg_dict[f"{self_name}_pts3d_valid_views"] = valid_views
+ details[avg_key] += (loss_mean - details[avg_key]) / valid_views
+ else:
+ batch_size, _, _, dim = gt_pts[0].shape
+
+ for view_idx in range(n_views):
+ gt = gt_pts[view_idx].view(batch_size, -1, dim)
+ pred = pred_pts[view_idx].view(batch_size, -1, dim)
+ view_mask = masks[view_idx].view(batch_size, -1)
+ amb_mask = ambiguous_masks[view_idx].view(batch_size, -1)
+
+ if self.loss_in_log:
+ pred = apply_log_to_norm(pred)
+ gt = apply_log_to_norm(gt)
+
+ loss = self.criterion(pred, gt)
+
+ if self.ambiguous_loss_value > 0:
+ loss = torch.where(amb_mask, self.ambiguous_loss_value, loss)
+
+ losses.append((loss, view_mask, "pts3d"))
+ loss_after_masking = loss[view_mask]
+ if loss_after_masking.numel() > 0:
+ loss_mean = float(loss_after_masking.mean())
+ details[f"{self_name}_pts3d_view{view_idx + 1}"] = loss_mean
+ # Initialize or update running average directly
+ avg_key = f"{self_name}_pts3d_avg"
+ if avg_key not in details:
+ details[avg_key] = loss_mean
+ running_avg_dict[f"{self_name}_pts3d_valid_views"] = 1
+ else:
+ valid_views = (
+ running_avg_dict[f"{self_name}_pts3d_valid_views"] + 1
+ )
+ running_avg_dict[f"{self_name}_pts3d_valid_views"] = valid_views
+ details[avg_key] += (loss_mean - details[avg_key]) / valid_views
+
+ return Sum(*losses), (details | monitoring)
+
+
+class PointsPlusScaleRegr3D(Criterion, MultiLoss):
+ """
+ Regression Loss for World Frame Pointmaps & Scale.
+ """
+
+ def __init__(
+ self,
+ criterion,
+ norm_predictions=True,
+ norm_mode="avg_dis",
+ ambiguous_loss_value=0,
+ loss_in_log=True,
+ flatten_across_image_only=False,
+ world_frame_points_loss_weight=1,
+ scale_loss_weight=1,
+ ):
+ """
+ Initialize the loss criterion for World Frame Pointmaps & Scale.
+ The predicited scene representation is always normalized w.r.t. the frame of view0.
+ Loss is applied between the predicted metric scale and the ground truth metric scale.
+
+ Args:
+ criterion (BaseCriterion): The base criterion to use for computing the loss.
+ norm_predictions (bool): If True, normalize the predictions before computing the loss.
+ norm_mode (str): Normalization mode for the gt and predicted (optional) scene representation. Default: "avg_dis".
+ ambiguous_loss_value (float): Value to use for ambiguous pixels in the loss.
+ If 0, ambiguous pixels are ignored. Default: 0.
+ loss_in_log (bool): If True, apply logarithmic transformation to input before
+ computing the loss for depth, pointmaps and scale. Default: True.
+ flatten_across_image_only (bool): If True, flatten H x W dimensions only when computing
+ the loss. If False, flatten across batch and spatial dimensions. Default: False.
+ world_frame_points_loss_weight (float): Weight to use for the world frame pointmap loss. Default: 1.
+ scale_loss_weight (float): Weight to use for the scale loss. Default: 1.
+ """
+ super().__init__(criterion)
+ self.norm_predictions = norm_predictions
+ self.norm_mode = norm_mode
+ self.ambiguous_loss_value = ambiguous_loss_value
+ self.loss_in_log = loss_in_log
+ self.flatten_across_image_only = flatten_across_image_only
+ self.world_frame_points_loss_weight = world_frame_points_loss_weight
+ self.scale_loss_weight = scale_loss_weight
+
+ def get_all_info(self, batch, preds, dist_clip=None):
+ """
+ Function to get all the information needed to compute the loss.
+ Returns all quantities normalized w.r.t. camera of view0.
+ """
+ n_views = len(batch)
+
+ # Everything is normalized w.r.t. camera of view0
+ # Intialize lists to store data for all views
+ # Ground truth quantities
+ in_camera0 = closed_form_pose_inverse(batch[0]["camera_pose"])
+ no_norm_gt_pts = []
+ valid_masks = []
+ # Predicted quantities
+ no_norm_pr_pts = []
+ metric_pr_pts_to_compute_scale = []
+
+ # Get ground truth & prediction info for all views
+ for i in range(n_views):
+ # Get the ground truth
+ no_norm_gt_pts.append(geotrf(in_camera0, batch[i]["pts3d"]))
+ valid_masks.append(batch[i]["valid_mask"].clone())
+
+ # Get predictions for normalized loss
+ if "metric_scaling_factor" in preds[i].keys():
+ # Divide by the predicted metric scaling factor to get the raw predicted points, depth_along_ray, and pose_trans
+ # This detaches the predicted metric scaling factor from the geometry based loss
+ curr_view_no_norm_pr_pts = preds[i]["pts3d"] / preds[i][
+ "metric_scaling_factor"
+ ].unsqueeze(-1).unsqueeze(-1)
+ else:
+ curr_view_no_norm_pr_pts = preds[i]["pts3d"]
+ no_norm_pr_pts.append(curr_view_no_norm_pr_pts)
+
+ # Get the predicted metric scale points
+ if "metric_scaling_factor" in preds[i].keys():
+ # Detach the raw predicted points so that the scale loss is only applied to the scaling factor
+ curr_view_metric_pr_pts_to_compute_scale = (
+ curr_view_no_norm_pr_pts.detach()
+ * preds[i]["metric_scaling_factor"].unsqueeze(-1).unsqueeze(-1)
+ )
+ else:
+ curr_view_metric_pr_pts_to_compute_scale = (
+ curr_view_no_norm_pr_pts.clone()
+ )
+ metric_pr_pts_to_compute_scale.append(
+ curr_view_metric_pr_pts_to_compute_scale
+ )
+
+ if dist_clip is not None:
+ # Points that are too far-away == invalid
+ for i in range(n_views):
+ dis = no_norm_gt_pts[i].norm(dim=-1)
+ valid_masks[i] = valid_masks[i] & (dis <= dist_clip)
+
+ # Initialize normalized tensors
+ gt_pts = [torch.zeros_like(pts) for pts in no_norm_gt_pts]
+ pr_pts = [torch.zeros_like(pts) for pts in no_norm_pr_pts]
+
+ # Normalize the predicted points if specified
+ if self.norm_predictions:
+ pr_normalization_output = normalize_multiple_pointclouds(
+ no_norm_pr_pts,
+ valid_masks,
+ self.norm_mode,
+ ret_factor=True,
+ )
+ pr_pts_norm = pr_normalization_output[:-1]
+
+ # Normalize the ground truth points
+ gt_normalization_output = normalize_multiple_pointclouds(
+ no_norm_gt_pts, valid_masks, self.norm_mode, ret_factor=True
+ )
+ gt_pts_norm = gt_normalization_output[:-1]
+ gt_norm_factor = gt_normalization_output[-1]
+
+ for i in range(n_views):
+ if self.norm_predictions:
+ # Assign the normalized predictions
+ pr_pts[i] = pr_pts_norm[i]
+ else:
+ pr_pts[i] = no_norm_pr_pts[i]
+ # Assign the normalized ground truth quantities
+ gt_pts[i] = gt_pts_norm[i]
+
+ # Get the mask indicating ground truth metric scale quantities
+ metric_scale_mask = batch[0]["is_metric_scale"]
+ valid_gt_norm_factor_mask = (
+ gt_norm_factor[:, 0, 0, 0] > 1e-8
+ ) # Mask out cases where depth for all views is invalid
+ valid_metric_scale_mask = metric_scale_mask & valid_gt_norm_factor_mask
+
+ if valid_metric_scale_mask.any():
+ # Compute the scale norm factor using the predicted metric scale points
+ metric_pr_normalization_output = normalize_multiple_pointclouds(
+ metric_pr_pts_to_compute_scale,
+ valid_masks,
+ self.norm_mode,
+ ret_factor=True,
+ )
+ pr_metric_norm_factor = metric_pr_normalization_output[-1]
+
+ # Get the valid ground truth and predicted scale norm factors for the metric ground truth quantities
+ gt_metric_norm_factor = gt_norm_factor[valid_metric_scale_mask]
+ pr_metric_norm_factor = pr_metric_norm_factor[valid_metric_scale_mask]
+ else:
+ gt_metric_norm_factor = None
+ pr_metric_norm_factor = None
+
+ # Get ambiguous masks
+ ambiguous_masks = []
+ for i in range(n_views):
+ ambiguous_masks.append(
+ (~batch[i]["non_ambiguous_mask"]) & (~valid_masks[i])
+ )
+
+ # Pack into info dicts
+ gt_info = []
+ pred_info = []
+ for i in range(n_views):
+ gt_info.append(
+ {
+ "pts3d": gt_pts[i],
+ }
+ )
+ pred_info.append(
+ {
+ "pts3d": pr_pts[i],
+ }
+ )
+
+ return (
+ gt_info,
+ pred_info,
+ valid_masks,
+ ambiguous_masks,
+ gt_metric_norm_factor,
+ pr_metric_norm_factor,
+ )
+
+ def compute_loss(self, batch, preds, **kw):
+ (
+ gt_info,
+ pred_info,
+ valid_masks,
+ ambiguous_masks,
+ gt_metric_norm_factor,
+ pr_metric_norm_factor,
+ ) = self.get_all_info(batch, preds, **kw)
+ n_views = len(batch)
+
+ if self.ambiguous_loss_value > 0:
+ assert self.criterion.reduction == "none", (
+ "ambiguous_loss_value should be 0 if no conf loss"
+ )
+ # Add the ambiguous pixel as "valid" pixels...
+ valid_masks = [
+ mask | ambig_mask
+ for mask, ambig_mask in zip(valid_masks, ambiguous_masks)
+ ]
+
+ pts3d_losses = []
+
+ for i in range(n_views):
+ # Get the predicted dense quantities
+ if not self.flatten_across_image_only:
+ # Flatten the points across the entire batch with the masks
+ pred_pts3d = pred_info[i]["pts3d"][valid_masks[i]]
+ gt_pts3d = gt_info[i]["pts3d"][valid_masks[i]]
+ else:
+ # Flatten the H x W dimensions to H*W
+ batch_size, _, _, pts_dim = gt_info[i]["pts3d"].shape
+ gt_pts3d = gt_info[i]["pts3d"].view(batch_size, -1, pts_dim)
+ pred_pts3d = pred_info[i]["pts3d"].view(batch_size, -1, pts_dim)
+ valid_masks[i] = valid_masks[i].view(batch_size, -1)
+
+ # Apply loss in log space if specified
+ if self.loss_in_log:
+ gt_pts3d = apply_log_to_norm(gt_pts3d)
+ pred_pts3d = apply_log_to_norm(pred_pts3d)
+
+ # Compute point loss
+ pts3d_loss = self.criterion(pred_pts3d, gt_pts3d, factor="points")
+ pts3d_loss = pts3d_loss * self.world_frame_points_loss_weight
+ pts3d_losses.append(pts3d_loss)
+
+ # Handle ambiguous pixels
+ if self.ambiguous_loss_value > 0:
+ if not self.flatten_across_image_only:
+ pts3d_losses[i] = torch.where(
+ ambiguous_masks[i][valid_masks[i]],
+ self.ambiguous_loss_value,
+ pts3d_losses[i],
+ )
+ else:
+ pts3d_losses[i] = torch.where(
+ ambiguous_masks[i].view(ambiguous_masks[i].shape[0], -1),
+ self.ambiguous_loss_value,
+ pts3d_losses[i],
+ )
+
+ # Compute the scale loss
+ if gt_metric_norm_factor is not None:
+ if self.loss_in_log:
+ gt_metric_norm_factor = apply_log_to_norm(gt_metric_norm_factor)
+ pr_metric_norm_factor = apply_log_to_norm(pr_metric_norm_factor)
+ scale_loss = (
+ self.criterion(
+ pr_metric_norm_factor, gt_metric_norm_factor, factor="scale"
+ )
+ * self.scale_loss_weight
+ )
+ else:
+ scale_loss = None
+
+ # Use helper function to generate loss terms and details
+
+ losses_dict = {
+ "pts3d": {
+ "values": pts3d_losses,
+ "use_mask": True,
+ "is_multi_view": True,
+ },
+ "scale": {
+ "values": scale_loss,
+ "use_mask": False,
+ "is_multi_view": False,
+ },
+ }
+
+ loss_terms, details = get_loss_terms_and_details(
+ losses_dict,
+ valid_masks,
+ type(self).__name__,
+ n_views,
+ self.flatten_across_image_only,
+ )
+ losses = Sum(*loss_terms)
+
+ return losses, (details | {})
+
+
+class FactoredGeometryRegr3D(Criterion, MultiLoss):
+ """
+ Regression Loss for Factored Geometry.
+ """
+
+ def __init__(
+ self,
+ criterion,
+ norm_mode="?avg_dis",
+ gt_scale=False,
+ ambiguous_loss_value=0,
+ max_metric_scale=False,
+ loss_in_log=True,
+ flatten_across_image_only=False,
+ depth_type_for_loss="depth_along_ray",
+ cam_frame_points_loss_weight=1,
+ depth_loss_weight=1,
+ ray_directions_loss_weight=1,
+ pose_quats_loss_weight=1,
+ pose_trans_loss_weight=1,
+ compute_pairwise_relative_pose_loss=False,
+ compute_world_frame_points_loss=True,
+ world_frame_points_loss_weight=1,
+ ):
+ """
+ Initialize the loss criterion for Factored Geometry (Ray Directions, Depth, Pose),
+ and the Collective Geometry i.e. Local Frame Pointmaps & optionally World Frame Pointmaps.
+ If world-frame pointmap loss is computed, the pixel-level losses are computed in the following order:
+ (1) world points, (2) cam points, (3) depth, (4) ray directions, (5) pose quats, (6) pose trans.
+ Else, the pixel-level losses are returned in the following order:
+ (1) cam points, (2) depth, (3) ray directions, (4) pose quats, (5) pose trans.
+
+ Args:
+ criterion (BaseCriterion): The base criterion to use for computing the loss.
+ norm_mode (str): Normalization mode for scene representation. Default: "?avg_dis".
+ If prefixed with "?", normalization is only applied to non-metric scale data.
+ gt_scale (bool): If True, enforce predictions to have the same scale as ground truth.
+ If False, both GT and predictions are normalized independently. Default: False.
+ ambiguous_loss_value (float): Value to use for ambiguous pixels in the loss.
+ If 0, ambiguous pixels are ignored. Default: 0.
+ max_metric_scale (float): Maximum scale for metric scale data. If data exceeds this
+ value, it will be treated as non-metric. Default: False (no limit).
+ loss_in_log (bool): If True, apply logarithmic transformation to input before
+ computing the loss for depth and pointmaps. Default: True.
+ flatten_across_image_only (bool): If True, flatten H x W dimensions only when computing
+ the loss. If False, flatten across batch and spatial dimensions. Default: False.
+ depth_type_for_loss (str): Type of depth to use for loss computation. Default: "depth_along_ray".
+ Options: "depth_along_ray", "depth_z"
+ cam_frame_points_loss_weight (float): Weight to use for the camera frame pointmap loss. Default: 1.
+ depth_loss_weight (float): Weight to use for the depth loss. Default: 1.
+ ray_directions_loss_weight (float): Weight to use for the ray directions loss. Default: 1.
+ pose_quats_loss_weight (float): Weight to use for the pose quats loss. Default: 1.
+ pose_trans_loss_weight (float): Weight to use for the pose trans loss. Default: 1.
+ compute_pairwise_relative_pose_loss (bool): If True, the pose loss is computed on the
+ exhaustive pairwise relative poses. Default: False.
+ compute_world_frame_points_loss (bool): If True, compute the world frame pointmap loss. Default: True.
+ world_frame_points_loss_weight (float): Weight to use for the world frame pointmap loss. Default: 1.
+ """
+ super().__init__(criterion)
+ if norm_mode.startswith("?"):
+ # Do no norm pts from metric scale datasets
+ self.norm_all = False
+ self.norm_mode = norm_mode[1:]
+ else:
+ self.norm_all = True
+ self.norm_mode = norm_mode
+ self.gt_scale = gt_scale
+ self.ambiguous_loss_value = ambiguous_loss_value
+ self.max_metric_scale = max_metric_scale
+ self.loss_in_log = loss_in_log
+ self.flatten_across_image_only = flatten_across_image_only
+ self.depth_type_for_loss = depth_type_for_loss
+ assert self.depth_type_for_loss in ["depth_along_ray", "depth_z"], (
+ "depth_type_for_loss must be one of ['depth_along_ray', 'depth_z']"
+ )
+ self.cam_frame_points_loss_weight = cam_frame_points_loss_weight
+ self.depth_loss_weight = depth_loss_weight
+ self.ray_directions_loss_weight = ray_directions_loss_weight
+ self.pose_quats_loss_weight = pose_quats_loss_weight
+ self.pose_trans_loss_weight = pose_trans_loss_weight
+ self.compute_pairwise_relative_pose_loss = compute_pairwise_relative_pose_loss
+ self.compute_world_frame_points_loss = compute_world_frame_points_loss
+ self.world_frame_points_loss_weight = world_frame_points_loss_weight
+
+ def get_all_info(self, batch, preds, dist_clip=None):
+ """
+ Function to get all the information needed to compute the loss.
+ Returns all quantities normalized w.r.t. camera of view0.
+ """
+ n_views = len(batch)
+
+ # Everything is normalized w.r.t. camera of view0
+ # Intialize lists to store data for all views
+ # Ground truth quantities
+ in_camera0 = closed_form_pose_inverse(batch[0]["camera_pose"])
+ no_norm_gt_pts = []
+ no_norm_gt_pts_cam = []
+ no_norm_gt_depth = []
+ no_norm_gt_pose_trans = []
+ valid_masks = []
+ gt_ray_directions = []
+ gt_pose_quats = []
+ # Predicted quantities
+ no_norm_pr_pts = []
+ no_norm_pr_pts_cam = []
+ no_norm_pr_depth = []
+ no_norm_pr_pose_trans = []
+ pr_ray_directions = []
+ pr_pose_quats = []
+
+ # Get ground truth & prediction info for all views
+ for i in range(n_views):
+ # Get ground truth
+ no_norm_gt_pts.append(geotrf(in_camera0, batch[i]["pts3d"]))
+ valid_masks.append(batch[i]["valid_mask"].clone())
+ no_norm_gt_pts_cam.append(batch[i]["pts3d_cam"])
+ gt_ray_directions.append(batch[i]["ray_directions_cam"])
+ if self.depth_type_for_loss == "depth_along_ray":
+ no_norm_gt_depth.append(batch[i]["depth_along_ray"])
+ elif self.depth_type_for_loss == "depth_z":
+ no_norm_gt_depth.append(batch[i]["pts3d_cam"][..., 2:])
+ if i == 0:
+ # For view0, initialize identity pose
+ gt_pose_quats.append(
+ torch.tensor(
+ [0, 0, 0, 1],
+ dtype=gt_ray_directions[0].dtype,
+ device=gt_ray_directions[0].device,
+ )
+ .unsqueeze(0)
+ .repeat(gt_ray_directions[0].shape[0], 1)
+ )
+ no_norm_gt_pose_trans.append(
+ torch.tensor(
+ [0, 0, 0],
+ dtype=gt_ray_directions[0].dtype,
+ device=gt_ray_directions[0].device,
+ )
+ .unsqueeze(0)
+ .repeat(gt_ray_directions[0].shape[0], 1)
+ )
+ else:
+ # For other views, transform pose to view0's frame
+ gt_pose_quats_world = batch[i]["camera_pose_quats"]
+ no_norm_gt_pose_trans_world = batch[i]["camera_pose_trans"]
+ gt_pose_quats_in_view0, no_norm_gt_pose_trans_in_view0 = (
+ transform_pose_using_quats_and_trans_2_to_1(
+ batch[0]["camera_pose_quats"],
+ batch[0]["camera_pose_trans"],
+ gt_pose_quats_world,
+ no_norm_gt_pose_trans_world,
+ )
+ )
+ gt_pose_quats.append(gt_pose_quats_in_view0)
+ no_norm_gt_pose_trans.append(no_norm_gt_pose_trans_in_view0)
+
+ # Get predictions
+ no_norm_pr_pts.append(preds[i]["pts3d"])
+ no_norm_pr_pts_cam.append(preds[i]["pts3d_cam"])
+ pr_ray_directions.append(preds[i]["ray_directions"])
+ if self.depth_type_for_loss == "depth_along_ray":
+ no_norm_pr_depth.append(preds[i]["depth_along_ray"])
+ elif self.depth_type_for_loss == "depth_z":
+ no_norm_pr_depth.append(preds[i]["pts3d_cam"][..., 2:])
+ no_norm_pr_pose_trans.append(preds[i]["cam_trans"])
+ pr_pose_quats.append(preds[i]["cam_quats"])
+
+ if dist_clip is not None:
+ # Points that are too far-away == invalid
+ for i in range(n_views):
+ dis = no_norm_gt_pts[i].norm(dim=-1)
+ valid_masks[i] = valid_masks[i] & (dis <= dist_clip)
+
+ # Handle metric scale
+ if not self.norm_all:
+ if self.max_metric_scale:
+ B = valid_masks[0].shape[0]
+ dists_to_cam1 = []
+ for i in range(n_views):
+ dists_to_cam1.append(
+ torch.where(
+ valid_masks[i], torch.norm(no_norm_gt_pts[i], dim=-1), 0
+ ).view(B, -1)
+ )
+
+ batch[0]["is_metric_scale"] = batch[0]["is_metric_scale"]
+ for dist in dists_to_cam1:
+ batch[0]["is_metric_scale"] &= (
+ dist.max(dim=-1).values < self.max_metric_scale
+ )
+
+ for i in range(1, n_views):
+ batch[i]["is_metric_scale"] = batch[0]["is_metric_scale"]
+
+ non_metric_scale_mask = ~batch[0]["is_metric_scale"]
+ else:
+ non_metric_scale_mask = torch.ones_like(batch[0]["is_metric_scale"])
+
+ # Initialize normalized tensors
+ gt_pts = [torch.zeros_like(pts) for pts in no_norm_gt_pts]
+ gt_pts_cam = [torch.zeros_like(pts_cam) for pts_cam in no_norm_gt_pts_cam]
+ gt_depth = [torch.zeros_like(depth) for depth in no_norm_gt_depth]
+ gt_pose_trans = [torch.zeros_like(trans) for trans in no_norm_gt_pose_trans]
+
+ pr_pts = [torch.zeros_like(pts) for pts in no_norm_pr_pts]
+ pr_pts_cam = [torch.zeros_like(pts_cam) for pts_cam in no_norm_pr_pts_cam]
+ pr_depth = [torch.zeros_like(depth) for depth in no_norm_pr_depth]
+ pr_pose_trans = [torch.zeros_like(trans) for trans in no_norm_pr_pose_trans]
+
+ # Normalize points
+ if self.norm_mode and non_metric_scale_mask.any():
+ pr_normalization_output = normalize_multiple_pointclouds(
+ [pts[non_metric_scale_mask] for pts in no_norm_pr_pts],
+ [mask[non_metric_scale_mask] for mask in valid_masks],
+ self.norm_mode,
+ ret_factor=True,
+ )
+ pr_pts_norm = pr_normalization_output[:-1]
+ pr_norm_factor = pr_normalization_output[-1]
+
+ for i in range(n_views):
+ pr_pts[i][non_metric_scale_mask] = pr_pts_norm[i]
+ pr_pts_cam[i][non_metric_scale_mask] = (
+ no_norm_pr_pts_cam[i][non_metric_scale_mask] / pr_norm_factor
+ )
+ pr_depth[i][non_metric_scale_mask] = (
+ no_norm_pr_depth[i][non_metric_scale_mask] / pr_norm_factor
+ )
+ pr_pose_trans[i][non_metric_scale_mask] = (
+ no_norm_pr_pose_trans[i][non_metric_scale_mask]
+ / pr_norm_factor[:, :, 0, 0]
+ )
+
+ elif non_metric_scale_mask.any():
+ for i in range(n_views):
+ pr_pts[i][non_metric_scale_mask] = no_norm_pr_pts[i][
+ non_metric_scale_mask
+ ]
+ pr_pts_cam[i][non_metric_scale_mask] = no_norm_pr_pts_cam[i][
+ non_metric_scale_mask
+ ]
+ pr_depth[i][non_metric_scale_mask] = no_norm_pr_depth[i][
+ non_metric_scale_mask
+ ]
+ pr_pose_trans[i][non_metric_scale_mask] = no_norm_pr_pose_trans[i][
+ non_metric_scale_mask
+ ]
+
+ if self.norm_mode and not self.gt_scale:
+ gt_normalization_output = normalize_multiple_pointclouds(
+ no_norm_gt_pts, valid_masks, self.norm_mode, ret_factor=True
+ )
+ gt_pts_norm = gt_normalization_output[:-1]
+ norm_factor = gt_normalization_output[-1]
+
+ for i in range(n_views):
+ gt_pts[i] = gt_pts_norm[i]
+ gt_pts_cam[i] = no_norm_gt_pts_cam[i] / norm_factor
+ gt_depth[i] = no_norm_gt_depth[i] / norm_factor
+ gt_pose_trans[i] = no_norm_gt_pose_trans[i] / norm_factor[:, :, 0, 0]
+
+ pr_pts[i][~non_metric_scale_mask] = (
+ no_norm_pr_pts[i][~non_metric_scale_mask]
+ / norm_factor[~non_metric_scale_mask]
+ )
+ pr_pts_cam[i][~non_metric_scale_mask] = (
+ no_norm_pr_pts_cam[i][~non_metric_scale_mask]
+ / norm_factor[~non_metric_scale_mask]
+ )
+ pr_depth[i][~non_metric_scale_mask] = (
+ no_norm_pr_depth[i][~non_metric_scale_mask]
+ / norm_factor[~non_metric_scale_mask]
+ )
+ pr_pose_trans[i][~non_metric_scale_mask] = (
+ no_norm_pr_pose_trans[i][~non_metric_scale_mask]
+ / norm_factor[~non_metric_scale_mask][:, :, 0, 0]
+ )
+
+ elif ~non_metric_scale_mask.any():
+ for i in range(n_views):
+ gt_pts[i] = no_norm_gt_pts[i]
+ gt_pts_cam[i] = no_norm_gt_pts_cam[i]
+ gt_depth[i] = no_norm_gt_depth[i]
+ gt_pose_trans[i] = no_norm_gt_pose_trans[i]
+ pr_pts[i][~non_metric_scale_mask] = no_norm_pr_pts[i][
+ ~non_metric_scale_mask
+ ]
+ pr_pts_cam[i][~non_metric_scale_mask] = no_norm_pr_pts_cam[i][
+ ~non_metric_scale_mask
+ ]
+ pr_depth[i][~non_metric_scale_mask] = no_norm_pr_depth[i][
+ ~non_metric_scale_mask
+ ]
+ pr_pose_trans[i][~non_metric_scale_mask] = no_norm_pr_pose_trans[i][
+ ~non_metric_scale_mask
+ ]
+ else:
+ for i in range(n_views):
+ gt_pts[i] = no_norm_gt_pts[i]
+ gt_pts_cam[i] = no_norm_gt_pts_cam[i]
+ gt_depth[i] = no_norm_gt_depth[i]
+ gt_pose_trans[i] = no_norm_gt_pose_trans[i]
+
+ # Get ambiguous masks
+ ambiguous_masks = []
+ for i in range(n_views):
+ ambiguous_masks.append(
+ (~batch[i]["non_ambiguous_mask"]) & (~valid_masks[i])
+ )
+
+ # Pack into info dicts
+ gt_info = []
+ pred_info = []
+ for i in range(n_views):
+ gt_info.append(
+ {
+ "ray_directions": gt_ray_directions[i],
+ self.depth_type_for_loss: gt_depth[i],
+ "pose_trans": gt_pose_trans[i],
+ "pose_quats": gt_pose_quats[i],
+ "pts3d": gt_pts[i],
+ "pts3d_cam": gt_pts_cam[i],
+ }
+ )
+ pred_info.append(
+ {
+ "ray_directions": pr_ray_directions[i],
+ self.depth_type_for_loss: pr_depth[i],
+ "pose_trans": pr_pose_trans[i],
+ "pose_quats": pr_pose_quats[i],
+ "pts3d": pr_pts[i],
+ "pts3d_cam": pr_pts_cam[i],
+ }
+ )
+
+ return gt_info, pred_info, valid_masks, ambiguous_masks
+
+ def compute_loss(self, batch, preds, **kw):
+ gt_info, pred_info, valid_masks, ambiguous_masks = self.get_all_info(
+ batch, preds, **kw
+ )
+ n_views = len(batch)
+
+ # Mask out samples in the batch where the gt depth validity mask is entirely zero
+ valid_norm_factor_masks = [
+ mask.sum(dim=(1, 2)) > 0 for mask in valid_masks
+ ] # List of (B,)
+
+ if self.ambiguous_loss_value > 0:
+ assert self.criterion.reduction == "none", (
+ "ambiguous_loss_value should be 0 if no conf loss"
+ )
+ # Add the ambiguous pixel as "valid" pixels...
+ valid_masks = [
+ mask | ambig_mask
+ for mask, ambig_mask in zip(valid_masks, ambiguous_masks)
+ ]
+
+ pose_trans_losses = []
+ pose_quats_losses = []
+ ray_directions_losses = []
+ depth_losses = []
+ cam_pts3d_losses = []
+ if self.compute_world_frame_points_loss:
+ pts3d_losses = []
+
+ for i in range(n_views):
+ # Get the predicted dense quantities
+ if not self.flatten_across_image_only:
+ # Flatten the points across the entire batch with the masks
+ pred_ray_directions = pred_info[i]["ray_directions"]
+ gt_ray_directions = gt_info[i]["ray_directions"]
+ pred_depth = pred_info[i][self.depth_type_for_loss][valid_masks[i]]
+ gt_depth = gt_info[i][self.depth_type_for_loss][valid_masks[i]]
+ pred_cam_pts3d = pred_info[i]["pts3d_cam"][valid_masks[i]]
+ gt_cam_pts3d = gt_info[i]["pts3d_cam"][valid_masks[i]]
+ if self.compute_world_frame_points_loss:
+ pred_pts3d = pred_info[i]["pts3d"][valid_masks[i]]
+ gt_pts3d = gt_info[i]["pts3d"][valid_masks[i]]
+ else:
+ # Flatten the H x W dimensions to H*W
+ batch_size, _, _, direction_dim = gt_info[i]["ray_directions"].shape
+ gt_ray_directions = gt_info[i]["ray_directions"].view(
+ batch_size, -1, direction_dim
+ )
+ pred_ray_directions = pred_info[i]["ray_directions"].view(
+ batch_size, -1, direction_dim
+ )
+ depth_dim = gt_info[i][self.depth_type_for_loss].shape[-1]
+ gt_depth = gt_info[i][self.depth_type_for_loss].view(
+ batch_size, -1, depth_dim
+ )
+ pred_depth = pred_info[i][self.depth_type_for_loss].view(
+ batch_size, -1, depth_dim
+ )
+ cam_pts_dim = gt_info[i]["pts3d_cam"].shape[-1]
+ gt_cam_pts3d = gt_info[i]["pts3d_cam"].view(batch_size, -1, cam_pts_dim)
+ pred_cam_pts3d = pred_info[i]["pts3d_cam"].view(
+ batch_size, -1, cam_pts_dim
+ )
+ if self.compute_world_frame_points_loss:
+ pts_dim = gt_info[i]["pts3d"].shape[-1]
+ gt_pts3d = gt_info[i]["pts3d"].view(batch_size, -1, pts_dim)
+ pred_pts3d = pred_info[i]["pts3d"].view(batch_size, -1, pts_dim)
+ valid_masks[i] = valid_masks[i].view(batch_size, -1)
+
+ # Apply loss in log space for depth if specified
+ if self.loss_in_log:
+ gt_depth = apply_log_to_norm(gt_depth)
+ pred_depth = apply_log_to_norm(pred_depth)
+ gt_cam_pts3d = apply_log_to_norm(gt_cam_pts3d)
+ pred_cam_pts3d = apply_log_to_norm(pred_cam_pts3d)
+ if self.compute_world_frame_points_loss:
+ gt_pts3d = apply_log_to_norm(gt_pts3d)
+ pred_pts3d = apply_log_to_norm(pred_pts3d)
+
+ if self.compute_pairwise_relative_pose_loss:
+ # Get the inverse of current view predicted pose
+ pred_inv_curr_view_pose_quats = quaternion_inverse(
+ pred_info[i]["pose_quats"]
+ )
+ pred_inv_curr_view_pose_rot_mat = quaternion_to_rotation_matrix(
+ pred_inv_curr_view_pose_quats
+ )
+ pred_inv_curr_view_pose_trans = -1 * ein.einsum(
+ pred_inv_curr_view_pose_rot_mat,
+ pred_info[i]["pose_trans"],
+ "b i j, b j -> b i",
+ )
+
+ # Get the inverse of the current view GT pose
+ gt_inv_curr_view_pose_quats = quaternion_inverse(
+ gt_info[i]["pose_quats"]
+ )
+ gt_inv_curr_view_pose_rot_mat = quaternion_to_rotation_matrix(
+ gt_inv_curr_view_pose_quats
+ )
+ gt_inv_curr_view_pose_trans = -1 * ein.einsum(
+ gt_inv_curr_view_pose_rot_mat,
+ gt_info[i]["pose_trans"],
+ "b i j, b j -> b i",
+ )
+
+ # Get the other N-1 relative poses using the current pose as reference frame
+ pred_rel_pose_quats = []
+ pred_rel_pose_trans = []
+ gt_rel_pose_quats = []
+ gt_rel_pose_trans = []
+ for ov_idx in range(n_views):
+ if ov_idx == i:
+ continue
+ # Get the relative predicted pose
+ pred_ov_rel_pose_quats = quaternion_multiply(
+ pred_inv_curr_view_pose_quats, pred_info[ov_idx]["pose_quats"]
+ )
+ pred_ov_rel_pose_trans = (
+ ein.einsum(
+ pred_inv_curr_view_pose_rot_mat,
+ pred_info[ov_idx]["pose_trans"],
+ "b i j, b j -> b i",
+ )
+ + pred_inv_curr_view_pose_trans
+ )
+
+ # Get the relative GT pose
+ gt_ov_rel_pose_quats = quaternion_multiply(
+ gt_inv_curr_view_pose_quats, gt_info[ov_idx]["pose_quats"]
+ )
+ gt_ov_rel_pose_trans = (
+ ein.einsum(
+ gt_inv_curr_view_pose_rot_mat,
+ gt_info[ov_idx]["pose_trans"],
+ "b i j, b j -> b i",
+ )
+ + gt_inv_curr_view_pose_trans
+ )
+
+ # Get the valid translations using valid_norm_factor_masks for current view and other view
+ overall_valid_mask_for_trans = (
+ valid_norm_factor_masks[i] & valid_norm_factor_masks[ov_idx]
+ )
+
+ # Append the relative poses
+ pred_rel_pose_quats.append(pred_ov_rel_pose_quats)
+ pred_rel_pose_trans.append(
+ pred_ov_rel_pose_trans[overall_valid_mask_for_trans]
+ )
+ gt_rel_pose_quats.append(gt_ov_rel_pose_quats)
+ gt_rel_pose_trans.append(
+ gt_ov_rel_pose_trans[overall_valid_mask_for_trans]
+ )
+
+ # Cat the N-1 relative poses along the batch dimension
+ pred_rel_pose_quats = torch.cat(pred_rel_pose_quats, dim=0)
+ pred_rel_pose_trans = torch.cat(pred_rel_pose_trans, dim=0)
+ gt_rel_pose_quats = torch.cat(gt_rel_pose_quats, dim=0)
+ gt_rel_pose_trans = torch.cat(gt_rel_pose_trans, dim=0)
+
+ # Compute pose translation loss
+ pose_trans_loss = self.criterion(
+ pred_rel_pose_trans, gt_rel_pose_trans, factor="pose_trans"
+ )
+ pose_trans_loss = pose_trans_loss * self.pose_trans_loss_weight
+ pose_trans_losses.append(pose_trans_loss)
+
+ # Compute pose rotation loss
+ # Handle quaternion two-to-one mapping
+ pose_quats_loss = torch.minimum(
+ self.criterion(
+ pred_rel_pose_quats, gt_rel_pose_quats, factor="pose_quats"
+ ),
+ self.criterion(
+ pred_rel_pose_quats, -gt_rel_pose_quats, factor="pose_quats"
+ ),
+ )
+ pose_quats_loss = pose_quats_loss * self.pose_quats_loss_weight
+ pose_quats_losses.append(pose_quats_loss)
+ else:
+ # Get the pose info for the current view
+ pred_pose_trans = pred_info[i]["pose_trans"][valid_norm_factor_masks[i]]
+ gt_pose_trans = gt_info[i]["pose_trans"][valid_norm_factor_masks[i]]
+ pred_pose_quats = pred_info[i]["pose_quats"]
+ gt_pose_quats = gt_info[i]["pose_quats"]
+
+ # Compute pose translation loss
+ pose_trans_loss = self.criterion(
+ pred_pose_trans, gt_pose_trans, factor="pose_trans"
+ )
+ pose_trans_loss = pose_trans_loss * self.pose_trans_loss_weight
+ pose_trans_losses.append(pose_trans_loss)
+
+ # Compute pose rotation loss
+ # Handle quaternion two-to-one mapping
+ pose_quats_loss = torch.minimum(
+ self.criterion(pred_pose_quats, gt_pose_quats, factor="pose_quats"),
+ self.criterion(
+ pred_pose_quats, -gt_pose_quats, factor="pose_quats"
+ ),
+ )
+ pose_quats_loss = pose_quats_loss * self.pose_quats_loss_weight
+ pose_quats_losses.append(pose_quats_loss)
+
+ # Compute ray direction loss
+ ray_directions_loss = self.criterion(
+ pred_ray_directions, gt_ray_directions, factor="ray_directions"
+ )
+ ray_directions_loss = ray_directions_loss * self.ray_directions_loss_weight
+ ray_directions_losses.append(ray_directions_loss)
+
+ # Compute depth loss
+ depth_loss = self.criterion(pred_depth, gt_depth, factor="depth")
+ depth_loss = depth_loss * self.depth_loss_weight
+ depth_losses.append(depth_loss)
+
+ # Compute camera frame point loss
+ cam_pts3d_loss = self.criterion(
+ pred_cam_pts3d, gt_cam_pts3d, factor="points"
+ )
+ cam_pts3d_loss = cam_pts3d_loss * self.cam_frame_points_loss_weight
+ cam_pts3d_losses.append(cam_pts3d_loss)
+
+ if self.compute_world_frame_points_loss:
+ # Compute point loss
+ pts3d_loss = self.criterion(pred_pts3d, gt_pts3d, factor="points")
+ pts3d_loss = pts3d_loss * self.world_frame_points_loss_weight
+ pts3d_losses.append(pts3d_loss)
+
+ # Handle ambiguous pixels
+ if self.ambiguous_loss_value > 0:
+ if not self.flatten_across_image_only:
+ depth_losses[i] = torch.where(
+ ambiguous_masks[i][valid_masks[i]],
+ self.ambiguous_loss_value,
+ depth_losses[i],
+ )
+ cam_pts3d_losses[i] = torch.where(
+ ambiguous_masks[i][valid_masks[i]],
+ self.ambiguous_loss_value,
+ cam_pts3d_losses[i],
+ )
+ if self.compute_world_frame_points_loss:
+ pts3d_losses[i] = torch.where(
+ ambiguous_masks[i][valid_masks[i]],
+ self.ambiguous_loss_value,
+ pts3d_losses[i],
+ )
+ else:
+ depth_losses[i] = torch.where(
+ ambiguous_masks[i].view(ambiguous_masks[i].shape[0], -1),
+ self.ambiguous_loss_value,
+ depth_losses[i],
+ )
+ cam_pts3d_losses[i] = torch.where(
+ ambiguous_masks[i].view(ambiguous_masks[i].shape[0], -1),
+ self.ambiguous_loss_value,
+ cam_pts3d_losses[i],
+ )
+ if self.compute_world_frame_points_loss:
+ pts3d_losses[i] = torch.where(
+ ambiguous_masks[i].view(ambiguous_masks[i].shape[0], -1),
+ self.ambiguous_loss_value,
+ pts3d_losses[i],
+ )
+
+ # Use helper function to generate loss terms and details
+ if self.compute_world_frame_points_loss:
+ losses_dict = {
+ "pts3d": {
+ "values": pts3d_losses,
+ "use_mask": True,
+ "is_multi_view": True,
+ },
+ }
+ else:
+ losses_dict = {}
+ losses_dict.update(
+ {
+ "cam_pts3d": {
+ "values": cam_pts3d_losses,
+ "use_mask": True,
+ "is_multi_view": True,
+ },
+ self.depth_type_for_loss: {
+ "values": depth_losses,
+ "use_mask": True,
+ "is_multi_view": True,
+ },
+ "ray_directions": {
+ "values": ray_directions_losses,
+ "use_mask": False,
+ "is_multi_view": True,
+ },
+ "pose_quats": {
+ "values": pose_quats_losses,
+ "use_mask": False,
+ "is_multi_view": True,
+ },
+ "pose_trans": {
+ "values": pose_trans_losses,
+ "use_mask": False,
+ "is_multi_view": True,
+ },
+ }
+ )
+ loss_terms, details = get_loss_terms_and_details(
+ losses_dict,
+ valid_masks,
+ type(self).__name__,
+ n_views,
+ self.flatten_across_image_only,
+ )
+ losses = Sum(*loss_terms)
+
+ return losses, (details | {})
+
+
+class FactoredGeometryRegr3DPlusNormalGMLoss(FactoredGeometryRegr3D):
+ """
+ Regression, Normals & Gradient Matching Loss for Factored Geometry.
+ """
+
+ def __init__(
+ self,
+ criterion,
+ norm_mode="?avg_dis",
+ gt_scale=False,
+ ambiguous_loss_value=0,
+ max_metric_scale=False,
+ loss_in_log=True,
+ flatten_across_image_only=False,
+ depth_type_for_loss="depth_along_ray",
+ cam_frame_points_loss_weight=1,
+ depth_loss_weight=1,
+ ray_directions_loss_weight=1,
+ pose_quats_loss_weight=1,
+ pose_trans_loss_weight=1,
+ compute_pairwise_relative_pose_loss=False,
+ compute_world_frame_points_loss=True,
+ world_frame_points_loss_weight=1,
+ apply_normal_and_gm_loss_to_synthetic_data_only=True,
+ normal_loss_weight=1,
+ gm_loss_weight=1,
+ ):
+ """
+ Initialize the loss criterion for Factored Geometry (see parent class for details).
+ Additionally computes:
+ (1) Normal Loss over the Camera Frame Pointmaps in euclidean coordinates,
+ (2) Gradient Matching (GM) Loss over the Depth Z in log space. (MiDAS applied GM loss in disparity space)
+
+ Args:
+ criterion (BaseCriterion): The base criterion to use for computing the loss.
+ norm_mode (str): Normalization mode for scene representation. Default: "avg_dis".
+ If prefixed with "?", normalization is only applied to non-metric scale data.
+ gt_scale (bool): If True, enforce predictions to have the same scale as ground truth.
+ If False, both GT and predictions are normalized independently. Default: False.
+ ambiguous_loss_value (float): Value to use for ambiguous pixels in the loss.
+ If 0, ambiguous pixels are ignored. Default: 0.
+ max_metric_scale (float): Maximum scale for metric scale data. If data exceeds this
+ value, it will be treated as non-metric. Default: False (no limit).
+ loss_in_log (bool): If True, apply logarithmic transformation to input before
+ computing the loss for depth and pointmaps. Default: True.
+ flatten_across_image_only (bool): If True, flatten H x W dimensions only when computing
+ the loss. If False, flatten across batch and spatial dimensions. Default: False.
+ depth_type_for_loss (str): Type of depth to use for loss computation. Default: "depth_along_ray".
+ Options: "depth_along_ray", "depth_z"
+ cam_frame_points_loss_weight (float): Weight to use for the camera frame pointmap loss. Default: 1.
+ depth_loss_weight (float): Weight to use for the depth loss. Default: 1.
+ ray_directions_loss_weight (float): Weight to use for the ray directions loss. Default: 1.
+ pose_quats_loss_weight (float): Weight to use for the pose quats loss. Default: 1.
+ pose_trans_loss_weight (float): Weight to use for the pose trans loss. Default: 1.
+ compute_pairwise_relative_pose_loss (bool): If True, the pose loss is computed on the
+ exhaustive pairwise relative poses. Default: False.
+ compute_world_frame_points_loss (bool): If True, compute the world frame pointmap loss. Default: True.
+ world_frame_points_loss_weight (float): Weight to use for the world frame pointmap loss. Default: 1.
+ apply_normal_and_gm_loss_to_synthetic_data_only (bool): If True, apply the normal and gm loss only to synthetic data.
+ If False, apply the normal and gm loss to all data. Default: True.
+ normal_loss_weight (float): Weight to use for the normal loss. Default: 1.
+ gm_loss_weight (float): Weight to use for the gm loss. Default: 1.
+ """
+ super().__init__(
+ criterion=criterion,
+ norm_mode=norm_mode,
+ gt_scale=gt_scale,
+ ambiguous_loss_value=ambiguous_loss_value,
+ max_metric_scale=max_metric_scale,
+ loss_in_log=loss_in_log,
+ flatten_across_image_only=flatten_across_image_only,
+ depth_type_for_loss=depth_type_for_loss,
+ cam_frame_points_loss_weight=cam_frame_points_loss_weight,
+ depth_loss_weight=depth_loss_weight,
+ ray_directions_loss_weight=ray_directions_loss_weight,
+ pose_quats_loss_weight=pose_quats_loss_weight,
+ pose_trans_loss_weight=pose_trans_loss_weight,
+ compute_pairwise_relative_pose_loss=compute_pairwise_relative_pose_loss,
+ compute_world_frame_points_loss=compute_world_frame_points_loss,
+ world_frame_points_loss_weight=world_frame_points_loss_weight,
+ )
+ self.apply_normal_and_gm_loss_to_synthetic_data_only = (
+ apply_normal_and_gm_loss_to_synthetic_data_only
+ )
+ self.normal_loss_weight = normal_loss_weight
+ self.gm_loss_weight = gm_loss_weight
+
+ def compute_loss(self, batch, preds, **kw):
+ gt_info, pred_info, valid_masks, ambiguous_masks = self.get_all_info(
+ batch, preds, **kw
+ )
+ n_views = len(batch)
+
+ # Mask out samples in the batch where the gt depth validity mask is entirely zero
+ valid_norm_factor_masks = [
+ mask.sum(dim=(1, 2)) > 0 for mask in valid_masks
+ ] # List of (B,)
+
+ if self.ambiguous_loss_value > 0:
+ assert self.criterion.reduction == "none", (
+ "ambiguous_loss_value should be 0 if no conf loss"
+ )
+ # Add the ambiguous pixel as "valid" pixels...
+ valid_masks = [
+ mask | ambig_mask
+ for mask, ambig_mask in zip(valid_masks, ambiguous_masks)
+ ]
+
+ normal_losses = []
+ gradient_matching_losses = []
+ pose_trans_losses = []
+ pose_quats_losses = []
+ ray_directions_losses = []
+ depth_losses = []
+ cam_pts3d_losses = []
+ if self.compute_world_frame_points_loss:
+ pts3d_losses = []
+
+ for i in range(n_views):
+ # Get the camera frame points, log space depth_z & valid masks
+ pred_local_pts3d = pred_info[i]["pts3d_cam"]
+ pred_depth_z = pred_local_pts3d[..., 2:]
+ pred_depth_z = apply_log_to_norm(pred_depth_z)
+ gt_local_pts3d = gt_info[i]["pts3d_cam"]
+ gt_depth_z = gt_local_pts3d[..., 2:]
+ gt_depth_z = apply_log_to_norm(gt_depth_z)
+ valid_mask_for_normal_gm_loss = valid_masks[i].clone()
+
+ # Update the validity mask for normal & gm loss based on the synthetic data mask if required
+ if self.apply_normal_and_gm_loss_to_synthetic_data_only:
+ synthetic_mask = batch[i]["is_synthetic"] # (B, )
+ synthetic_mask = synthetic_mask.unsqueeze(-1).unsqueeze(-1) # (B, 1, 1)
+ synthetic_mask = synthetic_mask.expand(
+ -1, pred_depth_z.shape[1], pred_depth_z.shape[2]
+ ) # (B, H, W)
+ valid_mask_for_normal_gm_loss = (
+ valid_mask_for_normal_gm_loss & synthetic_mask
+ )
+
+ # Compute the normal loss
+ normal_loss = compute_normal_loss(
+ pred_local_pts3d, gt_local_pts3d, valid_mask_for_normal_gm_loss.clone()
+ )
+ normal_loss = normal_loss * self.normal_loss_weight
+ normal_losses.append(normal_loss)
+
+ # Compute the gradient matching loss
+ gradient_matching_loss = compute_gradient_matching_loss(
+ pred_depth_z, gt_depth_z, valid_mask_for_normal_gm_loss.clone()
+ )
+ gradient_matching_loss = gradient_matching_loss * self.gm_loss_weight
+ gradient_matching_losses.append(gradient_matching_loss)
+
+ # Get the predicted dense quantities
+ if not self.flatten_across_image_only:
+ # Flatten the points across the entire batch with the masks
+ pred_ray_directions = pred_info[i]["ray_directions"]
+ gt_ray_directions = gt_info[i]["ray_directions"]
+ pred_depth = pred_info[i][self.depth_type_for_loss][valid_masks[i]]
+ gt_depth = gt_info[i][self.depth_type_for_loss][valid_masks[i]]
+ pred_cam_pts3d = pred_info[i]["pts3d_cam"][valid_masks[i]]
+ gt_cam_pts3d = gt_info[i]["pts3d_cam"][valid_masks[i]]
+ if self.compute_world_frame_points_loss:
+ pred_pts3d = pred_info[i]["pts3d"][valid_masks[i]]
+ gt_pts3d = gt_info[i]["pts3d"][valid_masks[i]]
+ else:
+ # Flatten the H x W dimensions to H*W
+ batch_size, _, _, direction_dim = gt_info[i]["ray_directions"].shape
+ gt_ray_directions = gt_info[i]["ray_directions"].view(
+ batch_size, -1, direction_dim
+ )
+ pred_ray_directions = pred_info[i]["ray_directions"].view(
+ batch_size, -1, direction_dim
+ )
+ depth_dim = gt_info[i][self.depth_type_for_loss].shape[-1]
+ gt_depth = gt_info[i][self.depth_type_for_loss].view(
+ batch_size, -1, depth_dim
+ )
+ pred_depth = pred_info[i][self.depth_type_for_loss].view(
+ batch_size, -1, depth_dim
+ )
+ cam_pts_dim = gt_info[i]["pts3d_cam"].shape[-1]
+ gt_cam_pts3d = gt_info[i]["pts3d_cam"].view(batch_size, -1, cam_pts_dim)
+ pred_cam_pts3d = pred_info[i]["pts3d_cam"].view(
+ batch_size, -1, cam_pts_dim
+ )
+ if self.compute_world_frame_points_loss:
+ pts_dim = gt_info[i]["pts3d"].shape[-1]
+ gt_pts3d = gt_info[i]["pts3d"].view(batch_size, -1, pts_dim)
+ pred_pts3d = pred_info[i]["pts3d"].view(batch_size, -1, pts_dim)
+ valid_masks[i] = valid_masks[i].view(batch_size, -1)
+
+ # Apply loss in log space for depth if specified
+ if self.loss_in_log:
+ gt_depth = apply_log_to_norm(gt_depth)
+ pred_depth = apply_log_to_norm(pred_depth)
+ gt_cam_pts3d = apply_log_to_norm(gt_cam_pts3d)
+ pred_cam_pts3d = apply_log_to_norm(pred_cam_pts3d)
+ if self.compute_world_frame_points_loss:
+ gt_pts3d = apply_log_to_norm(gt_pts3d)
+ pred_pts3d = apply_log_to_norm(pred_pts3d)
+
+ if self.compute_pairwise_relative_pose_loss:
+ # Get the inverse of current view predicted pose
+ pred_inv_curr_view_pose_quats = quaternion_inverse(
+ pred_info[i]["pose_quats"]
+ )
+ pred_inv_curr_view_pose_rot_mat = quaternion_to_rotation_matrix(
+ pred_inv_curr_view_pose_quats
+ )
+ pred_inv_curr_view_pose_trans = -1 * ein.einsum(
+ pred_inv_curr_view_pose_rot_mat,
+ pred_info[i]["pose_trans"],
+ "b i j, b j -> b i",
+ )
+
+ # Get the inverse of the current view GT pose
+ gt_inv_curr_view_pose_quats = quaternion_inverse(
+ gt_info[i]["pose_quats"]
+ )
+ gt_inv_curr_view_pose_rot_mat = quaternion_to_rotation_matrix(
+ gt_inv_curr_view_pose_quats
+ )
+ gt_inv_curr_view_pose_trans = -1 * ein.einsum(
+ gt_inv_curr_view_pose_rot_mat,
+ gt_info[i]["pose_trans"],
+ "b i j, b j -> b i",
+ )
+
+ # Get the other N-1 relative poses using the current pose as reference frame
+ pred_rel_pose_quats = []
+ pred_rel_pose_trans = []
+ gt_rel_pose_quats = []
+ gt_rel_pose_trans = []
+ for ov_idx in range(n_views):
+ if ov_idx == i:
+ continue
+ # Get the relative predicted pose
+ pred_ov_rel_pose_quats = quaternion_multiply(
+ pred_inv_curr_view_pose_quats, pred_info[ov_idx]["pose_quats"]
+ )
+ pred_ov_rel_pose_trans = (
+ ein.einsum(
+ pred_inv_curr_view_pose_rot_mat,
+ pred_info[ov_idx]["pose_trans"],
+ "b i j, b j -> b i",
+ )
+ + pred_inv_curr_view_pose_trans
+ )
+
+ # Get the relative GT pose
+ gt_ov_rel_pose_quats = quaternion_multiply(
+ gt_inv_curr_view_pose_quats, gt_info[ov_idx]["pose_quats"]
+ )
+ gt_ov_rel_pose_trans = (
+ ein.einsum(
+ gt_inv_curr_view_pose_rot_mat,
+ gt_info[ov_idx]["pose_trans"],
+ "b i j, b j -> b i",
+ )
+ + gt_inv_curr_view_pose_trans
+ )
+
+ # Get the valid translations using valid_norm_factor_masks for current view and other view
+ overall_valid_mask_for_trans = (
+ valid_norm_factor_masks[i] & valid_norm_factor_masks[ov_idx]
+ )
+
+ # Append the relative poses
+ pred_rel_pose_quats.append(pred_ov_rel_pose_quats)
+ pred_rel_pose_trans.append(
+ pred_ov_rel_pose_trans[overall_valid_mask_for_trans]
+ )
+ gt_rel_pose_quats.append(gt_ov_rel_pose_quats)
+ gt_rel_pose_trans.append(
+ gt_ov_rel_pose_trans[overall_valid_mask_for_trans]
+ )
+
+ # Cat the N-1 relative poses along the batch dimension
+ pred_rel_pose_quats = torch.cat(pred_rel_pose_quats, dim=0)
+ pred_rel_pose_trans = torch.cat(pred_rel_pose_trans, dim=0)
+ gt_rel_pose_quats = torch.cat(gt_rel_pose_quats, dim=0)
+ gt_rel_pose_trans = torch.cat(gt_rel_pose_trans, dim=0)
+
+ # Compute pose translation loss
+ pose_trans_loss = self.criterion(
+ pred_rel_pose_trans, gt_rel_pose_trans, factor="pose_trans"
+ )
+ pose_trans_loss = pose_trans_loss * self.pose_trans_loss_weight
+ pose_trans_losses.append(pose_trans_loss)
+
+ # Compute pose rotation loss
+ # Handle quaternion two-to-one mapping
+ pose_quats_loss = torch.minimum(
+ self.criterion(
+ pred_rel_pose_quats, gt_rel_pose_quats, factor="pose_quats"
+ ),
+ self.criterion(
+ pred_rel_pose_quats, -gt_rel_pose_quats, factor="pose_quats"
+ ),
+ )
+ pose_quats_loss = pose_quats_loss * self.pose_quats_loss_weight
+ pose_quats_losses.append(pose_quats_loss)
+ else:
+ # Get the pose info for the current view
+ pred_pose_trans = pred_info[i]["pose_trans"][valid_norm_factor_masks[i]]
+ gt_pose_trans = gt_info[i]["pose_trans"][valid_norm_factor_masks[i]]
+ pred_pose_quats = pred_info[i]["pose_quats"]
+ gt_pose_quats = gt_info[i]["pose_quats"]
+
+ # Compute pose translation loss
+ pose_trans_loss = self.criterion(
+ pred_pose_trans, gt_pose_trans, factor="pose_trans"
+ )
+ pose_trans_loss = pose_trans_loss * self.pose_trans_loss_weight
+ pose_trans_losses.append(pose_trans_loss)
+
+ # Compute pose rotation loss
+ # Handle quaternion two-to-one mapping
+ pose_quats_loss = torch.minimum(
+ self.criterion(pred_pose_quats, gt_pose_quats, factor="pose_quats"),
+ self.criterion(
+ pred_pose_quats, -gt_pose_quats, factor="pose_quats"
+ ),
+ )
+ pose_quats_loss = pose_quats_loss * self.pose_quats_loss_weight
+ pose_quats_losses.append(pose_quats_loss)
+
+ # Compute ray direction loss
+ ray_directions_loss = self.criterion(
+ pred_ray_directions, gt_ray_directions, factor="ray_directions"
+ )
+ ray_directions_loss = ray_directions_loss * self.ray_directions_loss_weight
+ ray_directions_losses.append(ray_directions_loss)
+
+ # Compute depth loss
+ depth_loss = self.criterion(pred_depth, gt_depth, factor="depth")
+ depth_loss = depth_loss * self.depth_loss_weight
+ depth_losses.append(depth_loss)
+
+ # Compute camera frame point loss
+ cam_pts3d_loss = self.criterion(
+ pred_cam_pts3d, gt_cam_pts3d, factor="points"
+ )
+ cam_pts3d_loss = cam_pts3d_loss * self.cam_frame_points_loss_weight
+ cam_pts3d_losses.append(cam_pts3d_loss)
+
+ if self.compute_world_frame_points_loss:
+ # Compute point loss
+ pts3d_loss = self.criterion(pred_pts3d, gt_pts3d, factor="points")
+ pts3d_loss = pts3d_loss * self.world_frame_points_loss_weight
+ pts3d_losses.append(pts3d_loss)
+
+ # Handle ambiguous pixels
+ if self.ambiguous_loss_value > 0:
+ if not self.flatten_across_image_only:
+ depth_losses[i] = torch.where(
+ ambiguous_masks[i][valid_masks[i]],
+ self.ambiguous_loss_value,
+ depth_losses[i],
+ )
+ cam_pts3d_losses[i] = torch.where(
+ ambiguous_masks[i][valid_masks[i]],
+ self.ambiguous_loss_value,
+ cam_pts3d_losses[i],
+ )
+ if self.compute_world_frame_points_loss:
+ pts3d_losses[i] = torch.where(
+ ambiguous_masks[i][valid_masks[i]],
+ self.ambiguous_loss_value,
+ pts3d_losses[i],
+ )
+ else:
+ depth_losses[i] = torch.where(
+ ambiguous_masks[i].view(ambiguous_masks[i].shape[0], -1),
+ self.ambiguous_loss_value,
+ depth_losses[i],
+ )
+ cam_pts3d_losses[i] = torch.where(
+ ambiguous_masks[i].view(ambiguous_masks[i].shape[0], -1),
+ self.ambiguous_loss_value,
+ cam_pts3d_losses[i],
+ )
+ if self.compute_world_frame_points_loss:
+ pts3d_losses[i] = torch.where(
+ ambiguous_masks[i].view(ambiguous_masks[i].shape[0], -1),
+ self.ambiguous_loss_value,
+ pts3d_losses[i],
+ )
+
+ # Use helper function to generate loss terms and details
+ if self.compute_world_frame_points_loss:
+ losses_dict = {
+ "pts3d": {
+ "values": pts3d_losses,
+ "use_mask": True,
+ "is_multi_view": True,
+ },
+ }
+ else:
+ losses_dict = {}
+ losses_dict.update(
+ {
+ "cam_pts3d": {
+ "values": cam_pts3d_losses,
+ "use_mask": True,
+ "is_multi_view": True,
+ },
+ self.depth_type_for_loss: {
+ "values": depth_losses,
+ "use_mask": True,
+ "is_multi_view": True,
+ },
+ "ray_directions": {
+ "values": ray_directions_losses,
+ "use_mask": False,
+ "is_multi_view": True,
+ },
+ "pose_quats": {
+ "values": pose_quats_losses,
+ "use_mask": False,
+ "is_multi_view": True,
+ },
+ "pose_trans": {
+ "values": pose_trans_losses,
+ "use_mask": False,
+ "is_multi_view": True,
+ },
+ "normal": {
+ "values": normal_losses,
+ "use_mask": False,
+ "is_multi_view": True,
+ },
+ "gradient_matching": {
+ "values": gradient_matching_losses,
+ "use_mask": False,
+ "is_multi_view": True,
+ },
+ }
+ )
+ loss_terms, details = get_loss_terms_and_details(
+ losses_dict,
+ valid_masks,
+ type(self).__name__,
+ n_views,
+ self.flatten_across_image_only,
+ )
+ losses = Sum(*loss_terms)
+
+ return losses, (details | {})
+
+
+class FactoredGeometryScaleRegr3D(Criterion, MultiLoss):
+ """
+ Regression Loss for Factored Geometry & Scale.
+ """
+
+ def __init__(
+ self,
+ criterion,
+ norm_predictions=True,
+ norm_mode="avg_dis",
+ ambiguous_loss_value=0,
+ loss_in_log=True,
+ flatten_across_image_only=False,
+ depth_type_for_loss="depth_along_ray",
+ cam_frame_points_loss_weight=1,
+ depth_loss_weight=1,
+ ray_directions_loss_weight=1,
+ pose_quats_loss_weight=1,
+ pose_trans_loss_weight=1,
+ scale_loss_weight=1,
+ compute_pairwise_relative_pose_loss=False,
+ compute_world_frame_points_loss=True,
+ world_frame_points_loss_weight=1,
+ ):
+ """
+ Initialize the loss criterion for Factored Geometry (Ray Directions, Depth, Pose), Scale
+ and the Collective Geometry i.e. Local Frame Pointmaps & optionally World Frame Pointmaps.
+ If world-frame pointmap loss is computed, the pixel-level losses are computed in the following order:
+ (1) world points, (2) cam points, (3) depth, (4) ray directions, (5) pose quats, (6) pose trans, (7) scale.
+ Else, the pixel-level losses are returned in the following order:
+ (1) cam points, (2) depth, (3) ray directions, (4) pose quats, (5) pose trans, (6) scale.
+ The predicited scene representation is always normalized w.r.t. the frame of view0.
+ Loss is applied between the predicted metric scale and the ground truth metric scale.
+
+ Args:
+ criterion (BaseCriterion): The base criterion to use for computing the loss.
+ norm_predictions (bool): If True, normalize the predictions before computing the loss.
+ norm_mode (str): Normalization mode for the gt and predicted (optional) scene representation. Default: "avg_dis".
+ ambiguous_loss_value (float): Value to use for ambiguous pixels in the loss.
+ If 0, ambiguous pixels are ignored. Default: 0.
+ loss_in_log (bool): If True, apply logarithmic transformation to input before
+ computing the loss for depth, pointmaps and scale. Default: True.
+ flatten_across_image_only (bool): If True, flatten H x W dimensions only when computing
+ the loss. If False, flatten across batch and spatial dimensions. Default: False.
+ depth_type_for_loss (str): Type of depth to use for loss computation. Default: "depth_along_ray".
+ Options: "depth_along_ray", "depth_z"
+ cam_frame_points_loss_weight (float): Weight to use for the camera frame pointmap loss. Default: 1.
+ depth_loss_weight (float): Weight to use for the depth loss. Default: 1.
+ ray_directions_loss_weight (float): Weight to use for the ray directions loss. Default: 1.
+ pose_quats_loss_weight (float): Weight to use for the pose quats loss. Default: 1.
+ pose_trans_loss_weight (float): Weight to use for the pose trans loss. Default: 1.
+ scale_loss_weight (float): Weight to use for the scale loss. Default: 1.
+ compute_pairwise_relative_pose_loss (bool): If True, the pose loss is computed on the
+ exhaustive pairwise relative poses. Default: False.
+ compute_world_frame_points_loss (bool): If True, compute the world frame pointmap loss. Default: True.
+ world_frame_points_loss_weight (float): Weight to use for the world frame pointmap loss. Default: 1.
+ """
+ super().__init__(criterion)
+ self.norm_predictions = norm_predictions
+ self.norm_mode = norm_mode
+ self.ambiguous_loss_value = ambiguous_loss_value
+ self.loss_in_log = loss_in_log
+ self.flatten_across_image_only = flatten_across_image_only
+ self.depth_type_for_loss = depth_type_for_loss
+ assert self.depth_type_for_loss in ["depth_along_ray", "depth_z"], (
+ "depth_type_for_loss must be one of ['depth_along_ray', 'depth_z']"
+ )
+ self.cam_frame_points_loss_weight = cam_frame_points_loss_weight
+ self.depth_loss_weight = depth_loss_weight
+ self.ray_directions_loss_weight = ray_directions_loss_weight
+ self.pose_quats_loss_weight = pose_quats_loss_weight
+ self.pose_trans_loss_weight = pose_trans_loss_weight
+ self.scale_loss_weight = scale_loss_weight
+ self.compute_pairwise_relative_pose_loss = compute_pairwise_relative_pose_loss
+ self.compute_world_frame_points_loss = compute_world_frame_points_loss
+ self.world_frame_points_loss_weight = world_frame_points_loss_weight
+
+ def get_all_info(self, batch, preds, dist_clip=None):
+ """
+ Function to get all the information needed to compute the loss.
+ Returns all quantities normalized w.r.t. camera of view0.
+ """
+ n_views = len(batch)
+
+ # Everything is normalized w.r.t. camera of view0
+ # Intialize lists to store data for all views
+ # Ground truth quantities
+ in_camera0 = closed_form_pose_inverse(batch[0]["camera_pose"])
+ no_norm_gt_pts = []
+ no_norm_gt_pts_cam = []
+ no_norm_gt_depth = []
+ no_norm_gt_pose_trans = []
+ valid_masks = []
+ gt_ray_directions = []
+ gt_pose_quats = []
+ # Predicted quantities
+ no_norm_pr_pts = []
+ no_norm_pr_pts_cam = []
+ no_norm_pr_depth = []
+ no_norm_pr_pose_trans = []
+ pr_ray_directions = []
+ pr_pose_quats = []
+ metric_pr_pts_to_compute_scale = []
+
+ # Get ground truth & prediction info for all views
+ for i in range(n_views):
+ # Get the ground truth
+ no_norm_gt_pts.append(geotrf(in_camera0, batch[i]["pts3d"]))
+ valid_masks.append(batch[i]["valid_mask"].clone())
+ no_norm_gt_pts_cam.append(batch[i]["pts3d_cam"])
+ gt_ray_directions.append(batch[i]["ray_directions_cam"])
+ if self.depth_type_for_loss == "depth_along_ray":
+ no_norm_gt_depth.append(batch[i]["depth_along_ray"])
+ elif self.depth_type_for_loss == "depth_z":
+ no_norm_gt_depth.append(batch[i]["pts3d_cam"][..., 2:])
+ if i == 0:
+ # For view0, initialize identity pose
+ gt_pose_quats.append(
+ torch.tensor(
+ [0, 0, 0, 1],
+ dtype=gt_ray_directions[0].dtype,
+ device=gt_ray_directions[0].device,
+ )
+ .unsqueeze(0)
+ .repeat(gt_ray_directions[0].shape[0], 1)
+ )
+ no_norm_gt_pose_trans.append(
+ torch.tensor(
+ [0, 0, 0],
+ dtype=gt_ray_directions[0].dtype,
+ device=gt_ray_directions[0].device,
+ )
+ .unsqueeze(0)
+ .repeat(gt_ray_directions[0].shape[0], 1)
+ )
+ else:
+ # For other views, transform pose to view0's frame
+ gt_pose_quats_world = batch[i]["camera_pose_quats"]
+ no_norm_gt_pose_trans_world = batch[i]["camera_pose_trans"]
+ gt_pose_quats_in_view0, no_norm_gt_pose_trans_in_view0 = (
+ transform_pose_using_quats_and_trans_2_to_1(
+ batch[0]["camera_pose_quats"],
+ batch[0]["camera_pose_trans"],
+ gt_pose_quats_world,
+ no_norm_gt_pose_trans_world,
+ )
+ )
+ gt_pose_quats.append(gt_pose_quats_in_view0)
+ no_norm_gt_pose_trans.append(no_norm_gt_pose_trans_in_view0)
+
+ # Get predictions for normalized loss
+ if self.depth_type_for_loss == "depth_along_ray":
+ curr_view_no_norm_depth = preds[i]["depth_along_ray"]
+ elif self.depth_type_for_loss == "depth_z":
+ curr_view_no_norm_depth = preds[i]["pts3d_cam"][..., 2:]
+ if "metric_scaling_factor" in preds[i].keys():
+ # Divide by the predicted metric scaling factor to get the raw predicted points, depth_along_ray, and pose_trans
+ # This detaches the predicted metric scaling factor from the geometry based loss
+ curr_view_no_norm_pr_pts = preds[i]["pts3d"] / preds[i][
+ "metric_scaling_factor"
+ ].unsqueeze(-1).unsqueeze(-1)
+ curr_view_no_norm_pr_pts_cam = preds[i]["pts3d_cam"] / preds[i][
+ "metric_scaling_factor"
+ ].unsqueeze(-1).unsqueeze(-1)
+ curr_view_no_norm_depth = curr_view_no_norm_depth / preds[i][
+ "metric_scaling_factor"
+ ].unsqueeze(-1).unsqueeze(-1)
+ curr_view_no_norm_pr_pose_trans = (
+ preds[i]["cam_trans"] / preds[i]["metric_scaling_factor"]
+ )
+ else:
+ curr_view_no_norm_pr_pts = preds[i]["pts3d"]
+ curr_view_no_norm_pr_pts_cam = preds[i]["pts3d_cam"]
+ curr_view_no_norm_depth = curr_view_no_norm_depth
+ curr_view_no_norm_pr_pose_trans = preds[i]["cam_trans"]
+ no_norm_pr_pts.append(curr_view_no_norm_pr_pts)
+ no_norm_pr_pts_cam.append(curr_view_no_norm_pr_pts_cam)
+ no_norm_pr_depth.append(curr_view_no_norm_depth)
+ no_norm_pr_pose_trans.append(curr_view_no_norm_pr_pose_trans)
+ pr_ray_directions.append(preds[i]["ray_directions"])
+ pr_pose_quats.append(preds[i]["cam_quats"])
+
+ # Get the predicted metric scale points
+ if "metric_scaling_factor" in preds[i].keys():
+ # Detach the raw predicted points so that the scale loss is only applied to the scaling factor
+ curr_view_metric_pr_pts_to_compute_scale = (
+ curr_view_no_norm_pr_pts.detach()
+ * preds[i]["metric_scaling_factor"].unsqueeze(-1).unsqueeze(-1)
+ )
+ else:
+ curr_view_metric_pr_pts_to_compute_scale = (
+ curr_view_no_norm_pr_pts.clone()
+ )
+ metric_pr_pts_to_compute_scale.append(
+ curr_view_metric_pr_pts_to_compute_scale
+ )
+
+ if dist_clip is not None:
+ # Points that are too far-away == invalid
+ for i in range(n_views):
+ dis = no_norm_gt_pts[i].norm(dim=-1)
+ valid_masks[i] = valid_masks[i] & (dis <= dist_clip)
+
+ # Initialize normalized tensors
+ gt_pts = [torch.zeros_like(pts) for pts in no_norm_gt_pts]
+ gt_pts_cam = [torch.zeros_like(pts_cam) for pts_cam in no_norm_gt_pts_cam]
+ gt_depth = [torch.zeros_like(depth) for depth in no_norm_gt_depth]
+ gt_pose_trans = [torch.zeros_like(trans) for trans in no_norm_gt_pose_trans]
+
+ pr_pts = [torch.zeros_like(pts) for pts in no_norm_pr_pts]
+ pr_pts_cam = [torch.zeros_like(pts_cam) for pts_cam in no_norm_pr_pts_cam]
+ pr_depth = [torch.zeros_like(depth) for depth in no_norm_pr_depth]
+ pr_pose_trans = [torch.zeros_like(trans) for trans in no_norm_pr_pose_trans]
+
+ # Normalize the predicted points if specified
+ if self.norm_predictions:
+ pr_normalization_output = normalize_multiple_pointclouds(
+ no_norm_pr_pts,
+ valid_masks,
+ self.norm_mode,
+ ret_factor=True,
+ )
+ pr_pts_norm = pr_normalization_output[:-1]
+ pr_norm_factor = pr_normalization_output[-1]
+
+ # Normalize the ground truth points
+ gt_normalization_output = normalize_multiple_pointclouds(
+ no_norm_gt_pts, valid_masks, self.norm_mode, ret_factor=True
+ )
+ gt_pts_norm = gt_normalization_output[:-1]
+ gt_norm_factor = gt_normalization_output[-1]
+
+ for i in range(n_views):
+ if self.norm_predictions:
+ # Assign the normalized predictions
+ pr_pts[i] = pr_pts_norm[i]
+ pr_pts_cam[i] = no_norm_pr_pts_cam[i] / pr_norm_factor
+ pr_depth[i] = no_norm_pr_depth[i] / pr_norm_factor
+ pr_pose_trans[i] = no_norm_pr_pose_trans[i] / pr_norm_factor[:, :, 0, 0]
+ else:
+ pr_pts[i] = no_norm_pr_pts[i]
+ pr_pts_cam[i] = no_norm_pr_pts_cam[i]
+ pr_depth[i] = no_norm_pr_depth[i]
+ pr_pose_trans[i] = no_norm_pr_pose_trans[i]
+ # Assign the normalized ground truth quantities
+ gt_pts[i] = gt_pts_norm[i]
+ gt_pts_cam[i] = no_norm_gt_pts_cam[i] / gt_norm_factor
+ gt_depth[i] = no_norm_gt_depth[i] / gt_norm_factor
+ gt_pose_trans[i] = no_norm_gt_pose_trans[i] / gt_norm_factor[:, :, 0, 0]
+
+ # Get the mask indicating ground truth metric scale quantities
+ metric_scale_mask = batch[0]["is_metric_scale"]
+ valid_gt_norm_factor_mask = (
+ gt_norm_factor[:, 0, 0, 0] > 1e-8
+ ) # Mask out cases where depth for all views is invalid
+ valid_metric_scale_mask = metric_scale_mask & valid_gt_norm_factor_mask
+
+ if valid_metric_scale_mask.any():
+ # Compute the scale norm factor using the predicted metric scale points
+ metric_pr_normalization_output = normalize_multiple_pointclouds(
+ metric_pr_pts_to_compute_scale,
+ valid_masks,
+ self.norm_mode,
+ ret_factor=True,
+ )
+ pr_metric_norm_factor = metric_pr_normalization_output[-1]
+
+ # Get the valid ground truth and predicted scale norm factors for the metric ground truth quantities
+ gt_metric_norm_factor = gt_norm_factor[valid_metric_scale_mask]
+ pr_metric_norm_factor = pr_metric_norm_factor[valid_metric_scale_mask]
+ else:
+ gt_metric_norm_factor = None
+ pr_metric_norm_factor = None
+
+ # Get ambiguous masks
+ ambiguous_masks = []
+ for i in range(n_views):
+ ambiguous_masks.append(
+ (~batch[i]["non_ambiguous_mask"]) & (~valid_masks[i])
+ )
+
+ # Pack into info dicts
+ gt_info = []
+ pred_info = []
+ for i in range(n_views):
+ gt_info.append(
+ {
+ "ray_directions": gt_ray_directions[i],
+ self.depth_type_for_loss: gt_depth[i],
+ "pose_trans": gt_pose_trans[i],
+ "pose_quats": gt_pose_quats[i],
+ "pts3d": gt_pts[i],
+ "pts3d_cam": gt_pts_cam[i],
+ }
+ )
+ pred_info.append(
+ {
+ "ray_directions": pr_ray_directions[i],
+ self.depth_type_for_loss: pr_depth[i],
+ "pose_trans": pr_pose_trans[i],
+ "pose_quats": pr_pose_quats[i],
+ "pts3d": pr_pts[i],
+ "pts3d_cam": pr_pts_cam[i],
+ }
+ )
+
+ return (
+ gt_info,
+ pred_info,
+ valid_masks,
+ ambiguous_masks,
+ gt_metric_norm_factor,
+ pr_metric_norm_factor,
+ )
+
+ def compute_loss(self, batch, preds, **kw):
+ (
+ gt_info,
+ pred_info,
+ valid_masks,
+ ambiguous_masks,
+ gt_metric_norm_factor,
+ pr_metric_norm_factor,
+ ) = self.get_all_info(batch, preds, **kw)
+ n_views = len(batch)
+
+ # Mask out samples in the batch where the gt depth validity mask is entirely zero
+ valid_norm_factor_masks = [
+ mask.sum(dim=(1, 2)) > 0 for mask in valid_masks
+ ] # List of (B,)
+
+ if self.ambiguous_loss_value > 0:
+ assert self.criterion.reduction == "none", (
+ "ambiguous_loss_value should be 0 if no conf loss"
+ )
+ # Add the ambiguous pixel as "valid" pixels...
+ valid_masks = [
+ mask | ambig_mask
+ for mask, ambig_mask in zip(valid_masks, ambiguous_masks)
+ ]
+
+ pose_trans_losses = []
+ pose_quats_losses = []
+ ray_directions_losses = []
+ depth_losses = []
+ cam_pts3d_losses = []
+ if self.compute_world_frame_points_loss:
+ pts3d_losses = []
+
+ for i in range(n_views):
+ # Get the predicted dense quantities
+ if not self.flatten_across_image_only:
+ # Flatten the points across the entire batch with the masks
+ pred_ray_directions = pred_info[i]["ray_directions"]
+ gt_ray_directions = gt_info[i]["ray_directions"]
+ pred_depth = pred_info[i][self.depth_type_for_loss][valid_masks[i]]
+ gt_depth = gt_info[i][self.depth_type_for_loss][valid_masks[i]]
+ pred_cam_pts3d = pred_info[i]["pts3d_cam"][valid_masks[i]]
+ gt_cam_pts3d = gt_info[i]["pts3d_cam"][valid_masks[i]]
+ if self.compute_world_frame_points_loss:
+ pred_pts3d = pred_info[i]["pts3d"][valid_masks[i]]
+ gt_pts3d = gt_info[i]["pts3d"][valid_masks[i]]
+ else:
+ # Flatten the H x W dimensions to H*W
+ batch_size, _, _, direction_dim = gt_info[i]["ray_directions"].shape
+ gt_ray_directions = gt_info[i]["ray_directions"].view(
+ batch_size, -1, direction_dim
+ )
+ pred_ray_directions = pred_info[i]["ray_directions"].view(
+ batch_size, -1, direction_dim
+ )
+ depth_dim = gt_info[i][self.depth_type_for_loss].shape[-1]
+ gt_depth = gt_info[i][self.depth_type_for_loss].view(
+ batch_size, -1, depth_dim
+ )
+ pred_depth = pred_info[i][self.depth_type_for_loss].view(
+ batch_size, -1, depth_dim
+ )
+ cam_pts_dim = gt_info[i]["pts3d_cam"].shape[-1]
+ gt_cam_pts3d = gt_info[i]["pts3d_cam"].view(batch_size, -1, cam_pts_dim)
+ pred_cam_pts3d = pred_info[i]["pts3d_cam"].view(
+ batch_size, -1, cam_pts_dim
+ )
+ if self.compute_world_frame_points_loss:
+ pts_dim = gt_info[i]["pts3d"].shape[-1]
+ gt_pts3d = gt_info[i]["pts3d"].view(batch_size, -1, pts_dim)
+ pred_pts3d = pred_info[i]["pts3d"].view(batch_size, -1, pts_dim)
+ valid_masks[i] = valid_masks[i].view(batch_size, -1)
+
+ # Apply loss in log space for depth if specified
+ if self.loss_in_log:
+ gt_depth = apply_log_to_norm(gt_depth)
+ pred_depth = apply_log_to_norm(pred_depth)
+ gt_cam_pts3d = apply_log_to_norm(gt_cam_pts3d)
+ pred_cam_pts3d = apply_log_to_norm(pred_cam_pts3d)
+ if self.compute_world_frame_points_loss:
+ gt_pts3d = apply_log_to_norm(gt_pts3d)
+ pred_pts3d = apply_log_to_norm(pred_pts3d)
+
+ if self.compute_pairwise_relative_pose_loss:
+ # Get the inverse of current view predicted pose
+ pred_inv_curr_view_pose_quats = quaternion_inverse(
+ pred_info[i]["pose_quats"]
+ )
+ pred_inv_curr_view_pose_rot_mat = quaternion_to_rotation_matrix(
+ pred_inv_curr_view_pose_quats
+ )
+ pred_inv_curr_view_pose_trans = -1 * ein.einsum(
+ pred_inv_curr_view_pose_rot_mat,
+ pred_info[i]["pose_trans"],
+ "b i j, b j -> b i",
+ )
+
+ # Get the inverse of the current view GT pose
+ gt_inv_curr_view_pose_quats = quaternion_inverse(
+ gt_info[i]["pose_quats"]
+ )
+ gt_inv_curr_view_pose_rot_mat = quaternion_to_rotation_matrix(
+ gt_inv_curr_view_pose_quats
+ )
+ gt_inv_curr_view_pose_trans = -1 * ein.einsum(
+ gt_inv_curr_view_pose_rot_mat,
+ gt_info[i]["pose_trans"],
+ "b i j, b j -> b i",
+ )
+
+ # Get the other N-1 relative poses using the current pose as reference frame
+ pred_rel_pose_quats = []
+ pred_rel_pose_trans = []
+ gt_rel_pose_quats = []
+ gt_rel_pose_trans = []
+ for ov_idx in range(n_views):
+ if ov_idx == i:
+ continue
+ # Get the relative predicted pose
+ pred_ov_rel_pose_quats = quaternion_multiply(
+ pred_inv_curr_view_pose_quats, pred_info[ov_idx]["pose_quats"]
+ )
+ pred_ov_rel_pose_trans = (
+ ein.einsum(
+ pred_inv_curr_view_pose_rot_mat,
+ pred_info[ov_idx]["pose_trans"],
+ "b i j, b j -> b i",
+ )
+ + pred_inv_curr_view_pose_trans
+ )
+
+ # Get the relative GT pose
+ gt_ov_rel_pose_quats = quaternion_multiply(
+ gt_inv_curr_view_pose_quats, gt_info[ov_idx]["pose_quats"]
+ )
+ gt_ov_rel_pose_trans = (
+ ein.einsum(
+ gt_inv_curr_view_pose_rot_mat,
+ gt_info[ov_idx]["pose_trans"],
+ "b i j, b j -> b i",
+ )
+ + gt_inv_curr_view_pose_trans
+ )
+
+ # Get the valid translations using valid_norm_factor_masks for current view and other view
+ overall_valid_mask_for_trans = (
+ valid_norm_factor_masks[i] & valid_norm_factor_masks[ov_idx]
+ )
+
+ # Append the relative poses
+ pred_rel_pose_quats.append(pred_ov_rel_pose_quats)
+ pred_rel_pose_trans.append(
+ pred_ov_rel_pose_trans[overall_valid_mask_for_trans]
+ )
+ gt_rel_pose_quats.append(gt_ov_rel_pose_quats)
+ gt_rel_pose_trans.append(
+ gt_ov_rel_pose_trans[overall_valid_mask_for_trans]
+ )
+
+ # Cat the N-1 relative poses along the batch dimension
+ pred_rel_pose_quats = torch.cat(pred_rel_pose_quats, dim=0)
+ pred_rel_pose_trans = torch.cat(pred_rel_pose_trans, dim=0)
+ gt_rel_pose_quats = torch.cat(gt_rel_pose_quats, dim=0)
+ gt_rel_pose_trans = torch.cat(gt_rel_pose_trans, dim=0)
+
+ # Compute pose translation loss
+ pose_trans_loss = self.criterion(
+ pred_rel_pose_trans, gt_rel_pose_trans, factor="pose_trans"
+ )
+ pose_trans_loss = pose_trans_loss * self.pose_trans_loss_weight
+ pose_trans_losses.append(pose_trans_loss)
+
+ # Compute pose rotation loss
+ # Handle quaternion two-to-one mapping
+ pose_quats_loss = torch.minimum(
+ self.criterion(
+ pred_rel_pose_quats, gt_rel_pose_quats, factor="pose_quats"
+ ),
+ self.criterion(
+ pred_rel_pose_quats, -gt_rel_pose_quats, factor="pose_quats"
+ ),
+ )
+ pose_quats_loss = pose_quats_loss * self.pose_quats_loss_weight
+ pose_quats_losses.append(pose_quats_loss)
+ else:
+ # Get the pose info for the current view
+ pred_pose_trans = pred_info[i]["pose_trans"][valid_norm_factor_masks[i]]
+ gt_pose_trans = gt_info[i]["pose_trans"][valid_norm_factor_masks[i]]
+ pred_pose_quats = pred_info[i]["pose_quats"]
+ gt_pose_quats = gt_info[i]["pose_quats"]
+
+ # Compute pose translation loss
+ pose_trans_loss = self.criterion(
+ pred_pose_trans, gt_pose_trans, factor="pose_trans"
+ )
+ pose_trans_loss = pose_trans_loss * self.pose_trans_loss_weight
+ pose_trans_losses.append(pose_trans_loss)
+
+ # Compute pose rotation loss
+ # Handle quaternion two-to-one mapping
+ pose_quats_loss = torch.minimum(
+ self.criterion(pred_pose_quats, gt_pose_quats, factor="pose_quats"),
+ self.criterion(
+ pred_pose_quats, -gt_pose_quats, factor="pose_quats"
+ ),
+ )
+ pose_quats_loss = pose_quats_loss * self.pose_quats_loss_weight
+ pose_quats_losses.append(pose_quats_loss)
+
+ # Compute ray direction loss
+ ray_directions_loss = self.criterion(
+ pred_ray_directions, gt_ray_directions, factor="ray_directions"
+ )
+ ray_directions_loss = ray_directions_loss * self.ray_directions_loss_weight
+ ray_directions_losses.append(ray_directions_loss)
+
+ # Compute depth loss
+ depth_loss = self.criterion(pred_depth, gt_depth, factor="depth")
+ depth_loss = depth_loss * self.depth_loss_weight
+ depth_losses.append(depth_loss)
+
+ # Compute camera frame point loss
+ cam_pts3d_loss = self.criterion(
+ pred_cam_pts3d, gt_cam_pts3d, factor="points"
+ )
+ cam_pts3d_loss = cam_pts3d_loss * self.cam_frame_points_loss_weight
+ cam_pts3d_losses.append(cam_pts3d_loss)
+
+ if self.compute_world_frame_points_loss:
+ # Compute point loss
+ pts3d_loss = self.criterion(pred_pts3d, gt_pts3d, factor="points")
+ pts3d_loss = pts3d_loss * self.world_frame_points_loss_weight
+ pts3d_losses.append(pts3d_loss)
+
+ # Handle ambiguous pixels
+ if self.ambiguous_loss_value > 0:
+ if not self.flatten_across_image_only:
+ depth_losses[i] = torch.where(
+ ambiguous_masks[i][valid_masks[i]],
+ self.ambiguous_loss_value,
+ depth_losses[i],
+ )
+ cam_pts3d_losses[i] = torch.where(
+ ambiguous_masks[i][valid_masks[i]],
+ self.ambiguous_loss_value,
+ cam_pts3d_losses[i],
+ )
+ if self.compute_world_frame_points_loss:
+ pts3d_losses[i] = torch.where(
+ ambiguous_masks[i][valid_masks[i]],
+ self.ambiguous_loss_value,
+ pts3d_losses[i],
+ )
+ else:
+ depth_losses[i] = torch.where(
+ ambiguous_masks[i].view(ambiguous_masks[i].shape[0], -1),
+ self.ambiguous_loss_value,
+ depth_losses[i],
+ )
+ cam_pts3d_losses[i] = torch.where(
+ ambiguous_masks[i].view(ambiguous_masks[i].shape[0], -1),
+ self.ambiguous_loss_value,
+ cam_pts3d_losses[i],
+ )
+ if self.compute_world_frame_points_loss:
+ pts3d_losses[i] = torch.where(
+ ambiguous_masks[i].view(ambiguous_masks[i].shape[0], -1),
+ self.ambiguous_loss_value,
+ pts3d_losses[i],
+ )
+
+ # Compute the scale loss
+ if gt_metric_norm_factor is not None:
+ if self.loss_in_log:
+ gt_metric_norm_factor = apply_log_to_norm(gt_metric_norm_factor)
+ pr_metric_norm_factor = apply_log_to_norm(pr_metric_norm_factor)
+ scale_loss = (
+ self.criterion(
+ pr_metric_norm_factor, gt_metric_norm_factor, factor="scale"
+ )
+ * self.scale_loss_weight
+ )
+ else:
+ scale_loss = None
+
+ # Use helper function to generate loss terms and details
+ if self.compute_world_frame_points_loss:
+ losses_dict = {
+ "pts3d": {
+ "values": pts3d_losses,
+ "use_mask": True,
+ "is_multi_view": True,
+ },
+ }
+ else:
+ losses_dict = {}
+ losses_dict.update(
+ {
+ "cam_pts3d": {
+ "values": cam_pts3d_losses,
+ "use_mask": True,
+ "is_multi_view": True,
+ },
+ self.depth_type_for_loss: {
+ "values": depth_losses,
+ "use_mask": True,
+ "is_multi_view": True,
+ },
+ "ray_directions": {
+ "values": ray_directions_losses,
+ "use_mask": False,
+ "is_multi_view": True,
+ },
+ "pose_quats": {
+ "values": pose_quats_losses,
+ "use_mask": False,
+ "is_multi_view": True,
+ },
+ "pose_trans": {
+ "values": pose_trans_losses,
+ "use_mask": False,
+ "is_multi_view": True,
+ },
+ "scale": {
+ "values": scale_loss,
+ "use_mask": False,
+ "is_multi_view": False,
+ },
+ }
+ )
+ loss_terms, details = get_loss_terms_and_details(
+ losses_dict,
+ valid_masks,
+ type(self).__name__,
+ n_views,
+ self.flatten_across_image_only,
+ )
+ losses = Sum(*loss_terms)
+
+ return losses, (details | {})
+
+
+class FactoredGeometryScaleRegr3DPlusNormalGMLoss(FactoredGeometryScaleRegr3D):
+ """
+ Regression, Normals & Gradient Matching Loss for Factored Geometry & Scale.
+ """
+
+ def __init__(
+ self,
+ criterion,
+ norm_predictions=True,
+ norm_mode="avg_dis",
+ ambiguous_loss_value=0,
+ loss_in_log=True,
+ flatten_across_image_only=False,
+ depth_type_for_loss="depth_along_ray",
+ cam_frame_points_loss_weight=1,
+ depth_loss_weight=1,
+ ray_directions_loss_weight=1,
+ pose_quats_loss_weight=1,
+ pose_trans_loss_weight=1,
+ scale_loss_weight=1,
+ compute_pairwise_relative_pose_loss=False,
+ compute_world_frame_points_loss=True,
+ world_frame_points_loss_weight=1,
+ apply_normal_and_gm_loss_to_synthetic_data_only=True,
+ normal_loss_weight=1,
+ gm_loss_weight=1,
+ ):
+ """
+ Initialize the loss criterion for Ray Directions, Depth, Pose, Pointmaps & Scale.
+ Additionally computes:
+ (1) Normal Loss over the Camera Frame Pointmaps in euclidean coordinates,
+ (2) Gradient Matching (GM) Loss over the Depth Z in log space. (MiDAS applied GM loss in disparity space)
+
+ Args:
+ criterion (BaseCriterion): The base criterion to use for computing the loss.
+ norm_predictions (bool): If True, normalize the predictions before computing the loss.
+ norm_mode (str): Normalization mode for the gt and predicted (optional) scene representation. Default: "avg_dis".
+ ambiguous_loss_value (float): Value to use for ambiguous pixels in the loss.
+ If 0, ambiguous pixels are ignored. Default: 0.
+ loss_in_log (bool): If True, apply logarithmic transformation to input before
+ computing the loss for depth, pointmaps and scale. Default: True.
+ flatten_across_image_only (bool): If True, flatten H x W dimensions only when computing
+ the loss. If False, flatten across batch and spatial dimensions. Default: False.
+ depth_type_for_loss (str): Type of depth to use for loss computation. Default: "depth_along_ray".
+ Options: "depth_along_ray", "depth_z"
+ cam_frame_points_loss_weight (float): Weight to use for the camera frame pointmap loss. Default: 1.
+ depth_loss_weight (float): Weight to use for the depth loss. Default: 1.
+ ray_directions_loss_weight (float): Weight to use for the ray directions loss. Default: 1.
+ pose_quats_loss_weight (float): Weight to use for the pose quats loss. Default: 1.
+ pose_trans_loss_weight (float): Weight to use for the pose trans loss. Default: 1.
+ scale_loss_weight (float): Weight to use for the scale loss. Default: 1.
+ compute_pairwise_relative_pose_loss (bool): If True, the pose loss is computed on the
+ exhaustive pairwise relative poses. Default: False.
+ compute_world_frame_points_loss (bool): If True, compute the world frame pointmap loss. Default: True.
+ world_frame_points_loss_weight (float): Weight to use for the world frame pointmap loss. Default: 1.
+ apply_normal_and_gm_loss_to_synthetic_data_only (bool): If True, apply the normal and gm loss only to synthetic data.
+ If False, apply the normal and gm loss to all data. Default: True.
+ normal_loss_weight (float): Weight to use for the normal loss. Default: 1.
+ gm_loss_weight (float): Weight to use for the gm loss. Default: 1.
+ """
+ super().__init__(
+ criterion=criterion,
+ norm_predictions=norm_predictions,
+ norm_mode=norm_mode,
+ ambiguous_loss_value=ambiguous_loss_value,
+ loss_in_log=loss_in_log,
+ flatten_across_image_only=flatten_across_image_only,
+ depth_type_for_loss=depth_type_for_loss,
+ cam_frame_points_loss_weight=cam_frame_points_loss_weight,
+ depth_loss_weight=depth_loss_weight,
+ ray_directions_loss_weight=ray_directions_loss_weight,
+ pose_quats_loss_weight=pose_quats_loss_weight,
+ pose_trans_loss_weight=pose_trans_loss_weight,
+ scale_loss_weight=scale_loss_weight,
+ compute_pairwise_relative_pose_loss=compute_pairwise_relative_pose_loss,
+ compute_world_frame_points_loss=compute_world_frame_points_loss,
+ world_frame_points_loss_weight=world_frame_points_loss_weight,
+ )
+ self.apply_normal_and_gm_loss_to_synthetic_data_only = (
+ apply_normal_and_gm_loss_to_synthetic_data_only
+ )
+ self.normal_loss_weight = normal_loss_weight
+ self.gm_loss_weight = gm_loss_weight
+
+ def compute_loss(self, batch, preds, **kw):
+ (
+ gt_info,
+ pred_info,
+ valid_masks,
+ ambiguous_masks,
+ gt_metric_norm_factor,
+ pr_metric_norm_factor,
+ ) = self.get_all_info(batch, preds, **kw)
+ n_views = len(batch)
+
+ # Mask out samples in the batch where the gt depth validity mask is entirely zero
+ valid_norm_factor_masks = [
+ mask.sum(dim=(1, 2)) > 0 for mask in valid_masks
+ ] # List of (B,)
+
+ if self.ambiguous_loss_value > 0:
+ assert self.criterion.reduction == "none", (
+ "ambiguous_loss_value should be 0 if no conf loss"
+ )
+ # Add the ambiguous pixel as "valid" pixels...
+ valid_masks = [
+ mask | ambig_mask
+ for mask, ambig_mask in zip(valid_masks, ambiguous_masks)
+ ]
+
+ normal_losses = []
+ gradient_matching_losses = []
+ pose_trans_losses = []
+ pose_quats_losses = []
+ ray_directions_losses = []
+ depth_losses = []
+ cam_pts3d_losses = []
+ if self.compute_world_frame_points_loss:
+ pts3d_losses = []
+
+ for i in range(n_views):
+ # Get the camera frame points, log space depth_z & valid masks
+ pred_local_pts3d = pred_info[i]["pts3d_cam"]
+ pred_depth_z = pred_local_pts3d[..., 2:]
+ pred_depth_z = apply_log_to_norm(pred_depth_z)
+ gt_local_pts3d = gt_info[i]["pts3d_cam"]
+ gt_depth_z = gt_local_pts3d[..., 2:]
+ gt_depth_z = apply_log_to_norm(gt_depth_z)
+ valid_mask_for_normal_gm_loss = valid_masks[i].clone()
+
+ # Update the validity mask for normal & gm loss based on the synthetic data mask if required
+ if self.apply_normal_and_gm_loss_to_synthetic_data_only:
+ synthetic_mask = batch[i]["is_synthetic"] # (B, )
+ synthetic_mask = synthetic_mask.unsqueeze(-1).unsqueeze(-1) # (B, 1, 1)
+ synthetic_mask = synthetic_mask.expand(
+ -1, pred_depth_z.shape[1], pred_depth_z.shape[2]
+ ) # (B, H, W)
+ valid_mask_for_normal_gm_loss = (
+ valid_mask_for_normal_gm_loss & synthetic_mask
+ )
+
+ # Compute the normal loss
+ normal_loss = compute_normal_loss(
+ pred_local_pts3d, gt_local_pts3d, valid_mask_for_normal_gm_loss.clone()
+ )
+ normal_loss = normal_loss * self.normal_loss_weight
+ normal_losses.append(normal_loss)
+
+ # Compute the gradient matching loss
+ gradient_matching_loss = compute_gradient_matching_loss(
+ pred_depth_z, gt_depth_z, valid_mask_for_normal_gm_loss.clone()
+ )
+ gradient_matching_loss = gradient_matching_loss * self.gm_loss_weight
+ gradient_matching_losses.append(gradient_matching_loss)
+
+ # Get the predicted dense quantities
+ if not self.flatten_across_image_only:
+ # Flatten the points across the entire batch with the masks and compute the metrics
+ pred_ray_directions = pred_info[i]["ray_directions"]
+ gt_ray_directions = gt_info[i]["ray_directions"]
+ pred_depth = pred_info[i][self.depth_type_for_loss][valid_masks[i]]
+ gt_depth = gt_info[i][self.depth_type_for_loss][valid_masks[i]]
+ pred_cam_pts3d = pred_info[i]["pts3d_cam"][valid_masks[i]]
+ gt_cam_pts3d = gt_info[i]["pts3d_cam"][valid_masks[i]]
+ if self.compute_world_frame_points_loss:
+ pred_pts3d = pred_info[i]["pts3d"][valid_masks[i]]
+ gt_pts3d = gt_info[i]["pts3d"][valid_masks[i]]
+ else:
+ # Flatten the H x W dimensions to H*W and compute the metrics
+ batch_size, _, _, direction_dim = gt_info[i]["ray_directions"].shape
+ gt_ray_directions = gt_info[i]["ray_directions"].view(
+ batch_size, -1, direction_dim
+ )
+ pred_ray_directions = pred_info[i]["ray_directions"].view(
+ batch_size, -1, direction_dim
+ )
+ depth_dim = gt_info[i][self.depth_type_for_loss].shape[-1]
+ gt_depth = gt_info[i][self.depth_type_for_loss].view(
+ batch_size, -1, depth_dim
+ )
+ pred_depth = pred_info[i][self.depth_type_for_loss].view(
+ batch_size, -1, depth_dim
+ )
+ cam_pts_dim = gt_info[i]["pts3d_cam"].shape[-1]
+ gt_cam_pts3d = gt_info[i]["pts3d_cam"].view(batch_size, -1, cam_pts_dim)
+ pred_cam_pts3d = pred_info[i]["pts3d_cam"].view(
+ batch_size, -1, cam_pts_dim
+ )
+ if self.compute_world_frame_points_loss:
+ pts_dim = gt_info[i]["pts3d"].shape[-1]
+ gt_pts3d = gt_info[i]["pts3d"].view(batch_size, -1, pts_dim)
+ pred_pts3d = pred_info[i]["pts3d"].view(batch_size, -1, pts_dim)
+ valid_masks[i] = valid_masks[i].view(batch_size, -1)
+
+ # Apply loss in log space for depth if specified
+ if self.loss_in_log:
+ gt_depth = apply_log_to_norm(gt_depth)
+ pred_depth = apply_log_to_norm(pred_depth)
+ gt_cam_pts3d = apply_log_to_norm(gt_cam_pts3d)
+ pred_cam_pts3d = apply_log_to_norm(pred_cam_pts3d)
+ if self.compute_world_frame_points_loss:
+ gt_pts3d = apply_log_to_norm(gt_pts3d)
+ pred_pts3d = apply_log_to_norm(pred_pts3d)
+
+ if self.compute_pairwise_relative_pose_loss:
+ # Get the inverse of current view predicted pose
+ pred_inv_curr_view_pose_quats = quaternion_inverse(
+ pred_info[i]["pose_quats"]
+ )
+ pred_inv_curr_view_pose_rot_mat = quaternion_to_rotation_matrix(
+ pred_inv_curr_view_pose_quats
+ )
+ pred_inv_curr_view_pose_trans = -1 * ein.einsum(
+ pred_inv_curr_view_pose_rot_mat,
+ pred_info[i]["pose_trans"],
+ "b i j, b j -> b i",
+ )
+
+ # Get the inverse of the current view GT pose
+ gt_inv_curr_view_pose_quats = quaternion_inverse(
+ gt_info[i]["pose_quats"]
+ )
+ gt_inv_curr_view_pose_rot_mat = quaternion_to_rotation_matrix(
+ gt_inv_curr_view_pose_quats
+ )
+ gt_inv_curr_view_pose_trans = -1 * ein.einsum(
+ gt_inv_curr_view_pose_rot_mat,
+ gt_info[i]["pose_trans"],
+ "b i j, b j -> b i",
+ )
+
+ # Get the other N-1 relative poses using the current pose as reference frame
+ pred_rel_pose_quats = []
+ pred_rel_pose_trans = []
+ gt_rel_pose_quats = []
+ gt_rel_pose_trans = []
+ for ov_idx in range(n_views):
+ if ov_idx == i:
+ continue
+ # Get the relative predicted pose
+ pred_ov_rel_pose_quats = quaternion_multiply(
+ pred_inv_curr_view_pose_quats, pred_info[ov_idx]["pose_quats"]
+ )
+ pred_ov_rel_pose_trans = (
+ ein.einsum(
+ pred_inv_curr_view_pose_rot_mat,
+ pred_info[ov_idx]["pose_trans"],
+ "b i j, b j -> b i",
+ )
+ + pred_inv_curr_view_pose_trans
+ )
+
+ # Get the relative GT pose
+ gt_ov_rel_pose_quats = quaternion_multiply(
+ gt_inv_curr_view_pose_quats, gt_info[ov_idx]["pose_quats"]
+ )
+ gt_ov_rel_pose_trans = (
+ ein.einsum(
+ gt_inv_curr_view_pose_rot_mat,
+ gt_info[ov_idx]["pose_trans"],
+ "b i j, b j -> b i",
+ )
+ + gt_inv_curr_view_pose_trans
+ )
+
+ # Get the valid translations using valid_norm_factor_masks for current view and other view
+ overall_valid_mask_for_trans = (
+ valid_norm_factor_masks[i] & valid_norm_factor_masks[ov_idx]
+ )
+
+ # Append the relative poses
+ pred_rel_pose_quats.append(pred_ov_rel_pose_quats)
+ pred_rel_pose_trans.append(
+ pred_ov_rel_pose_trans[overall_valid_mask_for_trans]
+ )
+ gt_rel_pose_quats.append(gt_ov_rel_pose_quats)
+ gt_rel_pose_trans.append(
+ gt_ov_rel_pose_trans[overall_valid_mask_for_trans]
+ )
+
+ # Cat the N-1 relative poses along the batch dimension
+ pred_rel_pose_quats = torch.cat(pred_rel_pose_quats, dim=0)
+ pred_rel_pose_trans = torch.cat(pred_rel_pose_trans, dim=0)
+ gt_rel_pose_quats = torch.cat(gt_rel_pose_quats, dim=0)
+ gt_rel_pose_trans = torch.cat(gt_rel_pose_trans, dim=0)
+
+ # Compute pose translation loss
+ pose_trans_loss = self.criterion(
+ pred_rel_pose_trans, gt_rel_pose_trans, factor="pose_trans"
+ )
+ pose_trans_loss = pose_trans_loss * self.pose_trans_loss_weight
+ pose_trans_losses.append(pose_trans_loss)
+
+ # Compute pose rotation loss
+ # Handle quaternion two-to-one mapping
+ pose_quats_loss = torch.minimum(
+ self.criterion(
+ pred_rel_pose_quats, gt_rel_pose_quats, factor="pose_quats"
+ ),
+ self.criterion(
+ pred_rel_pose_quats, -gt_rel_pose_quats, factor="pose_quats"
+ ),
+ )
+ pose_quats_loss = pose_quats_loss * self.pose_quats_loss_weight
+ pose_quats_losses.append(pose_quats_loss)
+ else:
+ # Get the pose info for the current view
+ pred_pose_trans = pred_info[i]["pose_trans"][valid_norm_factor_masks[i]]
+ gt_pose_trans = gt_info[i]["pose_trans"][valid_norm_factor_masks[i]]
+ pred_pose_quats = pred_info[i]["pose_quats"]
+ gt_pose_quats = gt_info[i]["pose_quats"]
+
+ # Compute pose translation loss
+ pose_trans_loss = self.criterion(
+ pred_pose_trans, gt_pose_trans, factor="pose_trans"
+ )
+ pose_trans_loss = pose_trans_loss * self.pose_trans_loss_weight
+ pose_trans_losses.append(pose_trans_loss)
+
+ # Compute pose rotation loss
+ # Handle quaternion two-to-one mapping
+ pose_quats_loss = torch.minimum(
+ self.criterion(pred_pose_quats, gt_pose_quats, factor="pose_quats"),
+ self.criterion(
+ pred_pose_quats, -gt_pose_quats, factor="pose_quats"
+ ),
+ )
+ pose_quats_loss = pose_quats_loss * self.pose_quats_loss_weight
+ pose_quats_losses.append(pose_quats_loss)
+
+ # Compute ray direction loss
+ ray_directions_loss = self.criterion(
+ pred_ray_directions, gt_ray_directions, factor="ray_directions"
+ )
+ ray_directions_loss = ray_directions_loss * self.ray_directions_loss_weight
+ ray_directions_losses.append(ray_directions_loss)
+
+ # Compute depth loss
+ depth_loss = self.criterion(pred_depth, gt_depth, factor="depth")
+ depth_loss = depth_loss * self.depth_loss_weight
+ depth_losses.append(depth_loss)
+
+ # Compute camera frame point loss
+ cam_pts3d_loss = self.criterion(
+ pred_cam_pts3d, gt_cam_pts3d, factor="points"
+ )
+ cam_pts3d_loss = cam_pts3d_loss * self.cam_frame_points_loss_weight
+ cam_pts3d_losses.append(cam_pts3d_loss)
+
+ if self.compute_world_frame_points_loss:
+ # Compute point loss
+ pts3d_loss = self.criterion(pred_pts3d, gt_pts3d, factor="points")
+ pts3d_loss = pts3d_loss * self.world_frame_points_loss_weight
+ pts3d_losses.append(pts3d_loss)
+
+ # Handle ambiguous pixels
+ if self.ambiguous_loss_value > 0:
+ if not self.flatten_across_image_only:
+ depth_losses[i] = torch.where(
+ ambiguous_masks[i][valid_masks[i]],
+ self.ambiguous_loss_value,
+ depth_losses[i],
+ )
+ cam_pts3d_losses[i] = torch.where(
+ ambiguous_masks[i][valid_masks[i]],
+ self.ambiguous_loss_value,
+ cam_pts3d_losses[i],
+ )
+ if self.compute_world_frame_points_loss:
+ pts3d_losses[i] = torch.where(
+ ambiguous_masks[i][valid_masks[i]],
+ self.ambiguous_loss_value,
+ pts3d_losses[i],
+ )
+ else:
+ depth_losses[i] = torch.where(
+ ambiguous_masks[i].view(ambiguous_masks[i].shape[0], -1),
+ self.ambiguous_loss_value,
+ depth_losses[i],
+ )
+ cam_pts3d_losses[i] = torch.where(
+ ambiguous_masks[i].view(ambiguous_masks[i].shape[0], -1),
+ self.ambiguous_loss_value,
+ cam_pts3d_losses[i],
+ )
+ if self.compute_world_frame_points_loss:
+ pts3d_losses[i] = torch.where(
+ ambiguous_masks[i].view(ambiguous_masks[i].shape[0], -1),
+ self.ambiguous_loss_value,
+ pts3d_losses[i],
+ )
+
+ # Compute the scale loss
+ if gt_metric_norm_factor is not None:
+ if self.loss_in_log:
+ gt_metric_norm_factor = apply_log_to_norm(gt_metric_norm_factor)
+ pr_metric_norm_factor = apply_log_to_norm(pr_metric_norm_factor)
+ scale_loss = (
+ self.criterion(
+ pr_metric_norm_factor, gt_metric_norm_factor, factor="scale"
+ )
+ * self.scale_loss_weight
+ )
+ else:
+ scale_loss = None
+
+ # Use helper function to generate loss terms and details
+ if self.compute_world_frame_points_loss:
+ losses_dict = {
+ "pts3d": {
+ "values": pts3d_losses,
+ "use_mask": True,
+ "is_multi_view": True,
+ },
+ }
+ else:
+ losses_dict = {}
+ losses_dict.update(
+ {
+ "cam_pts3d": {
+ "values": cam_pts3d_losses,
+ "use_mask": True,
+ "is_multi_view": True,
+ },
+ self.depth_type_for_loss: {
+ "values": depth_losses,
+ "use_mask": True,
+ "is_multi_view": True,
+ },
+ "ray_directions": {
+ "values": ray_directions_losses,
+ "use_mask": False,
+ "is_multi_view": True,
+ },
+ "pose_quats": {
+ "values": pose_quats_losses,
+ "use_mask": False,
+ "is_multi_view": True,
+ },
+ "pose_trans": {
+ "values": pose_trans_losses,
+ "use_mask": False,
+ "is_multi_view": True,
+ },
+ "scale": {
+ "values": scale_loss,
+ "use_mask": False,
+ "is_multi_view": False,
+ },
+ "normal": {
+ "values": normal_losses,
+ "use_mask": False,
+ "is_multi_view": True,
+ },
+ "gradient_matching": {
+ "values": gradient_matching_losses,
+ "use_mask": False,
+ "is_multi_view": True,
+ },
+ }
+ )
+ loss_terms, details = get_loss_terms_and_details(
+ losses_dict,
+ valid_masks,
+ type(self).__name__,
+ n_views,
+ self.flatten_across_image_only,
+ )
+ losses = Sum(*loss_terms)
+
+ return losses, (details | {})
+
+
+class DisentangledFactoredGeometryScaleRegr3D(Criterion, MultiLoss):
+ """
+ Disentangled Regression Loss for Factored Geometry & Scale.
+ """
+
+ def __init__(
+ self,
+ criterion,
+ norm_predictions=True,
+ norm_mode="avg_dis",
+ loss_in_log=True,
+ flatten_across_image_only=False,
+ depth_type_for_loss="depth_along_ray",
+ depth_loss_weight=1,
+ ray_directions_loss_weight=1,
+ pose_quats_loss_weight=1,
+ pose_trans_loss_weight=1,
+ scale_loss_weight=1,
+ ):
+ """
+ Initialize the disentangled loss criterion for Factored Geometry (Ray Directions, Depth, Pose) & Scale.
+ It isolates/disentangles the contribution of each factor to the final task of 3D reconstruction.
+ All the losses are in the same space where the loss for each factor is computed by constructing world-frame pointmaps.
+ This sidesteps the difficulty of finding a proper weighting.
+ For insance, for predicted rays, the GT depth & pose is used to construct the predicted world-frame pointmaps on which the loss is computed.
+ Inspired by https://openaccess.thecvf.com/content_ICCV_2019/papers/Simonelli_Disentangling_Monocular_3D_Object_Detection_ICCV_2019_paper.pdf
+
+ The pixel-level losses are computed in the following order:
+ (1) depth, (2) ray directions, (3) pose quats, (4) pose trans, (5) scale.
+ The predicited scene representation is always normalized w.r.t. the frame of view0.
+ Loss is applied between the predicted metric scale and the ground truth metric scale.
+
+ Args:
+ criterion (BaseCriterion): The base criterion to use for computing the loss.
+ norm_predictions (bool): If True, normalize the predictions before computing the loss.
+ norm_mode (str): Normalization mode for the gt and predicted (optional) scene representation. Default: "avg_dis".
+ loss_in_log (bool): If True, apply logarithmic transformation to input before
+ computing the loss for depth, pointmaps and scale. Default: True.
+ flatten_across_image_only (bool): If True, flatten H x W dimensions only when computing
+ the loss. If False, flatten across batch and spatial dimensions. Default: False.
+ depth_type_for_loss (str): Type of depth to use for loss computation. Default: "depth_along_ray".
+ Options: "depth_along_ray", "depth_z"
+ depth_loss_weight (float): Weight to use for the depth loss. Default: 1.
+ ray_directions_loss_weight (float): Weight to use for the ray directions loss. Default: 1.
+ pose_quats_loss_weight (float): Weight to use for the pose quats loss. Default: 1.
+ pose_trans_loss_weight (float): Weight to use for the pose trans loss. Default: 1.
+ scale_loss_weight (float): Weight to use for the scale loss. Default: 1.
+ """
+ super().__init__(criterion)
+ self.norm_predictions = norm_predictions
+ self.norm_mode = norm_mode
+ self.loss_in_log = loss_in_log
+ self.flatten_across_image_only = flatten_across_image_only
+ self.depth_type_for_loss = depth_type_for_loss
+ assert self.depth_type_for_loss in ["depth_along_ray", "depth_z"], (
+ "depth_type_for_loss must be one of ['depth_along_ray', 'depth_z']"
+ )
+ self.depth_loss_weight = depth_loss_weight
+ self.ray_directions_loss_weight = ray_directions_loss_weight
+ self.pose_quats_loss_weight = pose_quats_loss_weight
+ self.pose_trans_loss_weight = pose_trans_loss_weight
+ self.scale_loss_weight = scale_loss_weight
+
+ def get_all_info(self, batch, preds, dist_clip=None):
+ """
+ Function to get all the information needed to compute the loss.
+ Returns all quantities normalized w.r.t. camera of view0.
+ """
+ n_views = len(batch)
+
+ # Everything is normalized w.r.t. camera of view0
+ # Intialize lists to store data for all views
+ # Ground truth quantities
+ in_camera0 = closed_form_pose_inverse(batch[0]["camera_pose"])
+ no_norm_gt_pts = []
+ no_norm_gt_pts_cam = []
+ no_norm_gt_depth = []
+ no_norm_gt_pose_trans = []
+ valid_masks = []
+ gt_ray_directions = []
+ gt_pose_quats = []
+ # Predicted quantities
+ no_norm_pr_pts = []
+ no_norm_pr_pts_cam = []
+ no_norm_pr_depth = []
+ no_norm_pr_pose_trans = []
+ pr_ray_directions = []
+ pr_pose_quats = []
+ metric_pr_pts_to_compute_scale = []
+
+ # Get ground truth & prediction info for all views
+ for i in range(n_views):
+ # Get the ground truth
+ no_norm_gt_pts.append(geotrf(in_camera0, batch[i]["pts3d"]))
+ valid_masks.append(batch[i]["valid_mask"].clone())
+ no_norm_gt_pts_cam.append(batch[i]["pts3d_cam"])
+ gt_ray_directions.append(batch[i]["ray_directions_cam"])
+ if self.depth_type_for_loss == "depth_along_ray":
+ no_norm_gt_depth.append(batch[i]["depth_along_ray"])
+ elif self.depth_type_for_loss == "depth_z":
+ no_norm_gt_depth.append(batch[i]["pts3d_cam"][..., 2:])
+ if i == 0:
+ # For view0, initialize identity pose
+ gt_pose_quats.append(
+ torch.tensor(
+ [0, 0, 0, 1],
+ dtype=gt_ray_directions[0].dtype,
+ device=gt_ray_directions[0].device,
+ )
+ .unsqueeze(0)
+ .repeat(gt_ray_directions[0].shape[0], 1)
+ )
+ no_norm_gt_pose_trans.append(
+ torch.tensor(
+ [0, 0, 0],
+ dtype=gt_ray_directions[0].dtype,
+ device=gt_ray_directions[0].device,
+ )
+ .unsqueeze(0)
+ .repeat(gt_ray_directions[0].shape[0], 1)
+ )
+ else:
+ # For other views, transform pose to view0's frame
+ gt_pose_quats_world = batch[i]["camera_pose_quats"]
+ no_norm_gt_pose_trans_world = batch[i]["camera_pose_trans"]
+ gt_pose_quats_in_view0, no_norm_gt_pose_trans_in_view0 = (
+ transform_pose_using_quats_and_trans_2_to_1(
+ batch[0]["camera_pose_quats"],
+ batch[0]["camera_pose_trans"],
+ gt_pose_quats_world,
+ no_norm_gt_pose_trans_world,
+ )
+ )
+ gt_pose_quats.append(gt_pose_quats_in_view0)
+ no_norm_gt_pose_trans.append(no_norm_gt_pose_trans_in_view0)
+
+ # Get predictions for normalized loss
+ if self.depth_type_for_loss == "depth_along_ray":
+ curr_view_no_norm_depth = preds[i]["depth_along_ray"]
+ elif self.depth_type_for_loss == "depth_z":
+ curr_view_no_norm_depth = preds[i]["pts3d_cam"][..., 2:]
+ if "metric_scaling_factor" in preds[i].keys():
+ # Divide by the predicted metric scaling factor to get the raw predicted points, depth_along_ray, and pose_trans
+ # This detaches the predicted metric scaling factor from the geometry based loss
+ curr_view_no_norm_pr_pts = preds[i]["pts3d"] / preds[i][
+ "metric_scaling_factor"
+ ].unsqueeze(-1).unsqueeze(-1)
+ curr_view_no_norm_pr_pts_cam = preds[i]["pts3d_cam"] / preds[i][
+ "metric_scaling_factor"
+ ].unsqueeze(-1).unsqueeze(-1)
+ curr_view_no_norm_depth = curr_view_no_norm_depth / preds[i][
+ "metric_scaling_factor"
+ ].unsqueeze(-1).unsqueeze(-1)
+ curr_view_no_norm_pr_pose_trans = (
+ preds[i]["cam_trans"] / preds[i]["metric_scaling_factor"]
+ )
+ else:
+ curr_view_no_norm_pr_pts = preds[i]["pts3d"]
+ curr_view_no_norm_pr_pts_cam = preds[i]["pts3d_cam"]
+ curr_view_no_norm_depth = curr_view_no_norm_depth
+ curr_view_no_norm_pr_pose_trans = preds[i]["cam_trans"]
+ no_norm_pr_pts.append(curr_view_no_norm_pr_pts)
+ no_norm_pr_pts_cam.append(curr_view_no_norm_pr_pts_cam)
+ no_norm_pr_depth.append(curr_view_no_norm_depth)
+ no_norm_pr_pose_trans.append(curr_view_no_norm_pr_pose_trans)
+ pr_ray_directions.append(preds[i]["ray_directions"])
+ pr_pose_quats.append(preds[i]["cam_quats"])
+
+ # Get the predicted metric scale points
+ if "metric_scaling_factor" in preds[i].keys():
+ # Detach the raw predicted points so that the scale loss is only applied to the scaling factor
+ curr_view_metric_pr_pts_to_compute_scale = (
+ curr_view_no_norm_pr_pts.detach()
+ * preds[i]["metric_scaling_factor"].unsqueeze(-1).unsqueeze(-1)
+ )
+ else:
+ curr_view_metric_pr_pts_to_compute_scale = (
+ curr_view_no_norm_pr_pts.clone()
+ )
+ metric_pr_pts_to_compute_scale.append(
+ curr_view_metric_pr_pts_to_compute_scale
+ )
+
+ if dist_clip is not None:
+ # Points that are too far-away == invalid
+ for i in range(n_views):
+ dis = no_norm_gt_pts[i].norm(dim=-1)
+ valid_masks[i] = valid_masks[i] & (dis <= dist_clip)
+
+ # Initialize normalized tensors
+ gt_pts = [torch.zeros_like(pts) for pts in no_norm_gt_pts]
+ gt_pts_cam = [torch.zeros_like(pts_cam) for pts_cam in no_norm_gt_pts_cam]
+ gt_depth = [torch.zeros_like(depth) for depth in no_norm_gt_depth]
+ gt_pose_trans = [torch.zeros_like(trans) for trans in no_norm_gt_pose_trans]
+
+ pr_pts = [torch.zeros_like(pts) for pts in no_norm_pr_pts]
+ pr_pts_cam = [torch.zeros_like(pts_cam) for pts_cam in no_norm_pr_pts_cam]
+ pr_depth = [torch.zeros_like(depth) for depth in no_norm_pr_depth]
+ pr_pose_trans = [torch.zeros_like(trans) for trans in no_norm_pr_pose_trans]
+
+ # Normalize the predicted points if specified
+ if self.norm_predictions:
+ pr_normalization_output = normalize_multiple_pointclouds(
+ no_norm_pr_pts,
+ valid_masks,
+ self.norm_mode,
+ ret_factor=True,
+ )
+ pr_pts_norm = pr_normalization_output[:-1]
+ pr_norm_factor = pr_normalization_output[-1]
+
+ # Normalize the ground truth points
+ gt_normalization_output = normalize_multiple_pointclouds(
+ no_norm_gt_pts, valid_masks, self.norm_mode, ret_factor=True
+ )
+ gt_pts_norm = gt_normalization_output[:-1]
+ gt_norm_factor = gt_normalization_output[-1]
+
+ for i in range(n_views):
+ if self.norm_predictions:
+ # Assign the normalized predictions
+ pr_pts[i] = pr_pts_norm[i]
+ pr_pts_cam[i] = no_norm_pr_pts_cam[i] / pr_norm_factor
+ pr_depth[i] = no_norm_pr_depth[i] / pr_norm_factor
+ pr_pose_trans[i] = no_norm_pr_pose_trans[i] / pr_norm_factor[:, :, 0, 0]
+ else:
+ pr_pts[i] = no_norm_pr_pts[i]
+ pr_pts_cam[i] = no_norm_pr_pts_cam[i]
+ pr_depth[i] = no_norm_pr_depth[i]
+ pr_pose_trans[i] = no_norm_pr_pose_trans[i]
+ # Assign the normalized ground truth quantities
+ gt_pts[i] = gt_pts_norm[i]
+ gt_pts_cam[i] = no_norm_gt_pts_cam[i] / gt_norm_factor
+ gt_depth[i] = no_norm_gt_depth[i] / gt_norm_factor
+ gt_pose_trans[i] = no_norm_gt_pose_trans[i] / gt_norm_factor[:, :, 0, 0]
+
+ # Get the mask indicating ground truth metric scale quantities
+ metric_scale_mask = batch[0]["is_metric_scale"]
+ valid_gt_norm_factor_mask = (
+ gt_norm_factor[:, 0, 0, 0] > 1e-8
+ ) # Mask out cases where depth for all views is invalid
+ valid_metric_scale_mask = metric_scale_mask & valid_gt_norm_factor_mask
+
+ if valid_metric_scale_mask.any():
+ # Compute the scale norm factor using the predicted metric scale points
+ metric_pr_normalization_output = normalize_multiple_pointclouds(
+ metric_pr_pts_to_compute_scale,
+ valid_masks,
+ self.norm_mode,
+ ret_factor=True,
+ )
+ pr_metric_norm_factor = metric_pr_normalization_output[-1]
+
+ # Get the valid ground truth and predicted scale norm factors for the metric ground truth quantities
+ gt_metric_norm_factor = gt_norm_factor[valid_metric_scale_mask]
+ pr_metric_norm_factor = pr_metric_norm_factor[valid_metric_scale_mask]
+ else:
+ gt_metric_norm_factor = None
+ pr_metric_norm_factor = None
+
+ # Get ambiguous masks
+ ambiguous_masks = []
+ for i in range(n_views):
+ ambiguous_masks.append(
+ (~batch[i]["non_ambiguous_mask"]) & (~valid_masks[i])
+ )
+
+ # Pack into info dicts
+ gt_info = []
+ pred_info = []
+ for i in range(n_views):
+ gt_info.append(
+ {
+ "ray_directions": gt_ray_directions[i],
+ self.depth_type_for_loss: gt_depth[i],
+ "pose_trans": gt_pose_trans[i],
+ "pose_quats": gt_pose_quats[i],
+ "pts3d": gt_pts[i],
+ "pts3d_cam": gt_pts_cam[i],
+ }
+ )
+ pred_info.append(
+ {
+ "ray_directions": pr_ray_directions[i],
+ self.depth_type_for_loss: pr_depth[i],
+ "pose_trans": pr_pose_trans[i],
+ "pose_quats": pr_pose_quats[i],
+ "pts3d": pr_pts[i],
+ "pts3d_cam": pr_pts_cam[i],
+ }
+ )
+
+ return (
+ gt_info,
+ pred_info,
+ valid_masks,
+ ambiguous_masks,
+ gt_metric_norm_factor,
+ pr_metric_norm_factor,
+ )
+
+ def compute_loss(self, batch, preds, **kw):
+ (
+ gt_info,
+ pred_info,
+ valid_masks,
+ ambiguous_masks,
+ gt_metric_norm_factor,
+ pr_metric_norm_factor,
+ ) = self.get_all_info(batch, preds, **kw)
+ n_views = len(batch)
+
+ pose_trans_losses = []
+ pose_quats_losses = []
+ ray_directions_losses = []
+ depth_losses = []
+
+ for i in range(n_views):
+ # Get the GT factored quantities for the current view
+ gt_pts3d = gt_info[i]["pts3d"]
+ gt_ray_directions = gt_info[i]["ray_directions"]
+ gt_depth = gt_info[i][self.depth_type_for_loss]
+ gt_pose_trans = gt_info[i]["pose_trans"]
+ gt_pose_quats = gt_info[i]["pose_quats"]
+
+ # Get the predicted factored quantities for the current view
+ pred_ray_directions = pred_info[i]["ray_directions"]
+ pred_depth = pred_info[i][self.depth_type_for_loss]
+ pred_pose_trans = pred_info[i]["pose_trans"]
+ pred_pose_quats = pred_info[i]["pose_quats"]
+
+ # Get the predicted world-frame pointmaps using the different factors
+ if self.depth_type_for_loss == "depth_along_ray":
+ pred_ray_directions_pts3d = (
+ convert_ray_dirs_depth_along_ray_pose_trans_quats_to_pointmap(
+ pred_ray_directions,
+ gt_depth,
+ gt_pose_trans,
+ gt_pose_quats,
+ )
+ )
+ pred_depth_pts3d = (
+ convert_ray_dirs_depth_along_ray_pose_trans_quats_to_pointmap(
+ gt_ray_directions,
+ pred_depth,
+ gt_pose_trans,
+ gt_pose_quats,
+ )
+ )
+ pred_pose_trans_pts3d = (
+ convert_ray_dirs_depth_along_ray_pose_trans_quats_to_pointmap(
+ gt_ray_directions,
+ gt_depth,
+ pred_pose_trans,
+ gt_pose_quats,
+ )
+ )
+ pred_pose_quats_pts3d = (
+ convert_ray_dirs_depth_along_ray_pose_trans_quats_to_pointmap(
+ gt_ray_directions,
+ gt_depth,
+ gt_pose_trans,
+ pred_pose_quats,
+ )
+ )
+ else:
+ raise NotImplementedError
+
+ # Mask out the valid quantities as required
+ if not self.flatten_across_image_only:
+ # Flatten the points across the entire batch with the masks
+ pred_ray_directions_pts3d = pred_ray_directions_pts3d[valid_masks[i]]
+ pred_depth_pts3d = pred_depth_pts3d[valid_masks[i]]
+ pred_pose_trans_pts3d = pred_pose_trans_pts3d[valid_masks[i]]
+ pred_pose_quats_pts3d = pred_pose_quats_pts3d[valid_masks[i]]
+ gt_pts3d = gt_pts3d[valid_masks[i]]
+ else:
+ # Flatten the H x W dimensions to H*W
+ batch_size, _, _, pts_dim = gt_pts3d.shape
+ pred_ray_directions_pts3d = pred_ray_directions_pts3d.view(
+ batch_size, -1, pts_dim
+ )
+ pred_depth_pts3d = pred_depth_pts3d.view(batch_size, -1, pts_dim)
+ pred_pose_trans_pts3d = pred_pose_trans_pts3d.view(
+ batch_size, -1, pts_dim
+ )
+ pred_pose_quats_pts3d = pred_pose_quats_pts3d.view(
+ batch_size, -1, pts_dim
+ )
+ gt_pts3d = gt_pts3d.view(batch_size, -1, pts_dim)
+ valid_masks[i] = valid_masks[i].view(batch_size, -1)
+
+ # Apply loss in log space if specified
+ if self.loss_in_log:
+ gt_pts3d = apply_log_to_norm(gt_pts3d)
+ pred_ray_directions_pts3d = apply_log_to_norm(pred_ray_directions_pts3d)
+ pred_depth_pts3d = apply_log_to_norm(pred_depth_pts3d)
+ pred_pose_trans_pts3d = apply_log_to_norm(pred_pose_trans_pts3d)
+ pred_pose_quats_pts3d = apply_log_to_norm(pred_pose_quats_pts3d)
+
+ # Compute pose translation loss
+ pose_trans_loss = self.criterion(
+ pred_pose_trans_pts3d, gt_pts3d, factor="pose_trans"
+ )
+ pose_trans_loss = pose_trans_loss * self.pose_trans_loss_weight
+ pose_trans_losses.append(pose_trans_loss)
+
+ # Compute pose rotation loss
+ pose_quats_loss = self.criterion(
+ pred_pose_quats_pts3d, gt_pts3d, factor="pose_quats"
+ )
+ pose_quats_loss = pose_quats_loss * self.pose_quats_loss_weight
+ pose_quats_losses.append(pose_quats_loss)
+
+ # Compute ray direction loss
+ ray_directions_loss = self.criterion(
+ pred_ray_directions_pts3d, gt_pts3d, factor="ray_directions"
+ )
+ ray_directions_loss = ray_directions_loss * self.ray_directions_loss_weight
+ ray_directions_losses.append(ray_directions_loss)
+
+ # Compute depth loss
+ depth_loss = self.criterion(pred_depth_pts3d, gt_pts3d, factor="depth")
+ depth_loss = depth_loss * self.depth_loss_weight
+ depth_losses.append(depth_loss)
+
+ # Compute the scale loss
+ if gt_metric_norm_factor is not None:
+ if self.loss_in_log:
+ gt_metric_norm_factor = apply_log_to_norm(gt_metric_norm_factor)
+ pr_metric_norm_factor = apply_log_to_norm(pr_metric_norm_factor)
+ scale_loss = (
+ self.criterion(
+ pr_metric_norm_factor, gt_metric_norm_factor, factor="scale"
+ )
+ * self.scale_loss_weight
+ )
+ else:
+ scale_loss = None
+
+ # Use helper function to generate loss terms and details
+ losses_dict = {}
+ losses_dict.update(
+ {
+ self.depth_type_for_loss: {
+ "values": depth_losses,
+ "use_mask": True,
+ "is_multi_view": True,
+ },
+ "ray_directions": {
+ "values": ray_directions_losses,
+ "use_mask": True,
+ "is_multi_view": True,
+ },
+ "pose_quats": {
+ "values": pose_quats_losses,
+ "use_mask": True,
+ "is_multi_view": True,
+ },
+ "pose_trans": {
+ "values": pose_trans_losses,
+ "use_mask": True,
+ "is_multi_view": True,
+ },
+ "scale": {
+ "values": scale_loss,
+ "use_mask": False,
+ "is_multi_view": False,
+ },
+ }
+ )
+ loss_terms, details = get_loss_terms_and_details(
+ losses_dict,
+ valid_masks,
+ type(self).__name__,
+ n_views,
+ self.flatten_across_image_only,
+ )
+ losses = Sum(*loss_terms)
+
+ return losses, (details | {})
+
+
+class DisentangledFactoredGeometryScaleRegr3DPlusNormalGMLoss(
+ DisentangledFactoredGeometryScaleRegr3D
+):
+ """
+ Disentangled Regression, Normals & Gradient Matching Loss for Factored Geometry & Scale.
+ """
+
+ def __init__(
+ self,
+ criterion,
+ norm_predictions=True,
+ norm_mode="avg_dis",
+ loss_in_log=True,
+ flatten_across_image_only=False,
+ depth_type_for_loss="depth_along_ray",
+ depth_loss_weight=1,
+ ray_directions_loss_weight=1,
+ pose_quats_loss_weight=1,
+ pose_trans_loss_weight=1,
+ scale_loss_weight=1,
+ apply_normal_and_gm_loss_to_synthetic_data_only=True,
+ normal_loss_weight=1,
+ gm_loss_weight=1,
+ ):
+ """
+ Initialize the disentangled loss criterion for Factored Geometry (Ray Directions, Depth, Pose) & Scale.
+ See parent class (DisentangledFactoredGeometryScaleRegr3D) for more details.
+ Additionally computes:
+ (1) Normal Loss over the Camera Frame Pointmaps in euclidean coordinates,
+ (2) Gradient Matching (GM) Loss over the Depth Z in log space. (MiDAS applied GM loss in disparity space)
+
+ Args:
+ criterion (BaseCriterion): The base criterion to use for computing the loss.
+ norm_predictions (bool): If True, normalize the predictions before computing the loss.
+ norm_mode (str): Normalization mode for the gt and predicted (optional) scene representation. Default: "avg_dis".
+ loss_in_log (bool): If True, apply logarithmic transformation to input before
+ computing the loss for depth, pointmaps and scale. Default: True.
+ flatten_across_image_only (bool): If True, flatten H x W dimensions only when computing
+ the loss. If False, flatten across batch and spatial dimensions. Default: False.
+ depth_type_for_loss (str): Type of depth to use for loss computation. Default: "depth_along_ray".
+ Options: "depth_along_ray", "depth_z"
+ depth_loss_weight (float): Weight to use for the depth loss. Default: 1.
+ ray_directions_loss_weight (float): Weight to use for the ray directions loss. Default: 1.
+ pose_quats_loss_weight (float): Weight to use for the pose quats loss. Default: 1.
+ pose_trans_loss_weight (float): Weight to use for the pose trans loss. Default: 1.
+ scale_loss_weight (float): Weight to use for the scale loss. Default: 1.
+ apply_normal_and_gm_loss_to_synthetic_data_only (bool): If True, apply the normal and gm loss only to synthetic data.
+ If False, apply the normal and gm loss to all data. Default: True.
+ normal_loss_weight (float): Weight to use for the normal loss. Default: 1.
+ gm_loss_weight (float): Weight to use for the gm loss. Default: 1.
+ """
+ super().__init__(
+ criterion=criterion,
+ norm_predictions=norm_predictions,
+ norm_mode=norm_mode,
+ loss_in_log=loss_in_log,
+ flatten_across_image_only=flatten_across_image_only,
+ depth_type_for_loss=depth_type_for_loss,
+ depth_loss_weight=depth_loss_weight,
+ ray_directions_loss_weight=ray_directions_loss_weight,
+ pose_quats_loss_weight=pose_quats_loss_weight,
+ pose_trans_loss_weight=pose_trans_loss_weight,
+ scale_loss_weight=scale_loss_weight,
+ )
+ self.apply_normal_and_gm_loss_to_synthetic_data_only = (
+ apply_normal_and_gm_loss_to_synthetic_data_only
+ )
+ self.normal_loss_weight = normal_loss_weight
+ self.gm_loss_weight = gm_loss_weight
+
+ def compute_loss(self, batch, preds, **kw):
+ (
+ gt_info,
+ pred_info,
+ valid_masks,
+ ambiguous_masks,
+ gt_metric_norm_factor,
+ pr_metric_norm_factor,
+ ) = self.get_all_info(batch, preds, **kw)
+ n_views = len(batch)
+
+ normal_losses = []
+ gradient_matching_losses = []
+ pose_trans_losses = []
+ pose_quats_losses = []
+ ray_directions_losses = []
+ depth_losses = []
+
+ for i in range(n_views):
+ # Get the camera frame points, log space depth_z & valid masks
+ pred_local_pts3d = pred_info[i]["pts3d_cam"]
+ pred_depth_z = pred_local_pts3d[..., 2:]
+ pred_depth_z = apply_log_to_norm(pred_depth_z)
+ gt_local_pts3d = gt_info[i]["pts3d_cam"]
+ gt_depth_z = gt_local_pts3d[..., 2:]
+ gt_depth_z = apply_log_to_norm(gt_depth_z)
+ valid_mask_for_normal_gm_loss = valid_masks[i].clone()
+
+ # Update the validity mask for normal & gm loss based on the synthetic data mask if required
+ if self.apply_normal_and_gm_loss_to_synthetic_data_only:
+ synthetic_mask = batch[i]["is_synthetic"] # (B, )
+ synthetic_mask = synthetic_mask.unsqueeze(-1).unsqueeze(-1) # (B, 1, 1)
+ synthetic_mask = synthetic_mask.expand(
+ -1, pred_depth_z.shape[1], pred_depth_z.shape[2]
+ ) # (B, H, W)
+ valid_mask_for_normal_gm_loss = (
+ valid_mask_for_normal_gm_loss & synthetic_mask
+ )
+
+ # Compute the normal loss
+ normal_loss = compute_normal_loss(
+ pred_local_pts3d, gt_local_pts3d, valid_mask_for_normal_gm_loss.clone()
+ )
+ normal_loss = normal_loss * self.normal_loss_weight
+ normal_losses.append(normal_loss)
+
+ # Compute the gradient matching loss
+ gradient_matching_loss = compute_gradient_matching_loss(
+ pred_depth_z, gt_depth_z, valid_mask_for_normal_gm_loss.clone()
+ )
+ gradient_matching_loss = gradient_matching_loss * self.gm_loss_weight
+ gradient_matching_losses.append(gradient_matching_loss)
+
+ # Get the GT factored quantities for the current view
+ gt_pts3d = gt_info[i]["pts3d"]
+ gt_ray_directions = gt_info[i]["ray_directions"]
+ gt_depth = gt_info[i][self.depth_type_for_loss]
+ gt_pose_trans = gt_info[i]["pose_trans"]
+ gt_pose_quats = gt_info[i]["pose_quats"]
+
+ # Get the predicted factored quantities for the current view
+ pred_ray_directions = pred_info[i]["ray_directions"]
+ pred_depth = pred_info[i][self.depth_type_for_loss]
+ pred_pose_trans = pred_info[i]["pose_trans"]
+ pred_pose_quats = pred_info[i]["pose_quats"]
+
+ # Get the predicted world-frame pointmaps using the different factors
+ if self.depth_type_for_loss == "depth_along_ray":
+ pred_ray_directions_pts3d = (
+ convert_ray_dirs_depth_along_ray_pose_trans_quats_to_pointmap(
+ pred_ray_directions,
+ gt_depth,
+ gt_pose_trans,
+ gt_pose_quats,
+ )
+ )
+ pred_depth_pts3d = (
+ convert_ray_dirs_depth_along_ray_pose_trans_quats_to_pointmap(
+ gt_ray_directions,
+ pred_depth,
+ gt_pose_trans,
+ gt_pose_quats,
+ )
+ )
+ pred_pose_trans_pts3d = (
+ convert_ray_dirs_depth_along_ray_pose_trans_quats_to_pointmap(
+ gt_ray_directions,
+ gt_depth,
+ pred_pose_trans,
+ gt_pose_quats,
+ )
+ )
+ pred_pose_quats_pts3d = (
+ convert_ray_dirs_depth_along_ray_pose_trans_quats_to_pointmap(
+ gt_ray_directions,
+ gt_depth,
+ gt_pose_trans,
+ pred_pose_quats,
+ )
+ )
+ else:
+ raise NotImplementedError
+
+ # Mask out the valid quantities as required
+ if not self.flatten_across_image_only:
+ # Flatten the points across the entire batch with the masks
+ pred_ray_directions_pts3d = pred_ray_directions_pts3d[valid_masks[i]]
+ pred_depth_pts3d = pred_depth_pts3d[valid_masks[i]]
+ pred_pose_trans_pts3d = pred_pose_trans_pts3d[valid_masks[i]]
+ pred_pose_quats_pts3d = pred_pose_quats_pts3d[valid_masks[i]]
+ gt_pts3d = gt_pts3d[valid_masks[i]]
+ else:
+ # Flatten the H x W dimensions to H*W
+ batch_size, _, _, pts_dim = gt_pts3d.shape
+ pred_ray_directions_pts3d = pred_ray_directions_pts3d.view(
+ batch_size, -1, pts_dim
+ )
+ pred_depth_pts3d = pred_depth_pts3d.view(batch_size, -1, pts_dim)
+ pred_pose_trans_pts3d = pred_pose_trans_pts3d.view(
+ batch_size, -1, pts_dim
+ )
+ pred_pose_quats_pts3d = pred_pose_quats_pts3d.view(
+ batch_size, -1, pts_dim
+ )
+ gt_pts3d = gt_pts3d.view(batch_size, -1, pts_dim)
+ valid_masks[i] = valid_masks[i].view(batch_size, -1)
+
+ # Apply loss in log space if specified
+ if self.loss_in_log:
+ gt_pts3d = apply_log_to_norm(gt_pts3d)
+ pred_ray_directions_pts3d = apply_log_to_norm(pred_ray_directions_pts3d)
+ pred_depth_pts3d = apply_log_to_norm(pred_depth_pts3d)
+ pred_pose_trans_pts3d = apply_log_to_norm(pred_pose_trans_pts3d)
+ pred_pose_quats_pts3d = apply_log_to_norm(pred_pose_quats_pts3d)
+
+ # Compute pose translation loss
+ pose_trans_loss = self.criterion(
+ pred_pose_trans_pts3d, gt_pts3d, factor="pose_trans"
+ )
+ pose_trans_loss = pose_trans_loss * self.pose_trans_loss_weight
+ pose_trans_losses.append(pose_trans_loss)
+
+ # Compute pose rotation loss
+ pose_quats_loss = self.criterion(
+ pred_pose_quats_pts3d, gt_pts3d, factor="pose_quats"
+ )
+ pose_quats_loss = pose_quats_loss * self.pose_quats_loss_weight
+ pose_quats_losses.append(pose_quats_loss)
+
+ # Compute ray direction loss
+ ray_directions_loss = self.criterion(
+ pred_ray_directions_pts3d, gt_pts3d, factor="ray_directions"
+ )
+ ray_directions_loss = ray_directions_loss * self.ray_directions_loss_weight
+ ray_directions_losses.append(ray_directions_loss)
+
+ # Compute depth loss
+ depth_loss = self.criterion(pred_depth_pts3d, gt_pts3d, factor="depth")
+ depth_loss = depth_loss * self.depth_loss_weight
+ depth_losses.append(depth_loss)
+
+ # Compute the scale loss
+ if gt_metric_norm_factor is not None:
+ if self.loss_in_log:
+ gt_metric_norm_factor = apply_log_to_norm(gt_metric_norm_factor)
+ pr_metric_norm_factor = apply_log_to_norm(pr_metric_norm_factor)
+ scale_loss = (
+ self.criterion(
+ pr_metric_norm_factor, gt_metric_norm_factor, factor="scale"
+ )
+ * self.scale_loss_weight
+ )
+ else:
+ scale_loss = None
+
+ # Use helper function to generate loss terms and details
+ losses_dict = {}
+ losses_dict.update(
+ {
+ self.depth_type_for_loss: {
+ "values": depth_losses,
+ "use_mask": True,
+ "is_multi_view": True,
+ },
+ "ray_directions": {
+ "values": ray_directions_losses,
+ "use_mask": True,
+ "is_multi_view": True,
+ },
+ "pose_quats": {
+ "values": pose_quats_losses,
+ "use_mask": True,
+ "is_multi_view": True,
+ },
+ "pose_trans": {
+ "values": pose_trans_losses,
+ "use_mask": True,
+ "is_multi_view": True,
+ },
+ "scale": {
+ "values": scale_loss,
+ "use_mask": False,
+ "is_multi_view": False,
+ },
+ "normal": {
+ "values": normal_losses,
+ "use_mask": False,
+ "is_multi_view": True,
+ },
+ "gradient_matching": {
+ "values": gradient_matching_losses,
+ "use_mask": False,
+ "is_multi_view": True,
+ },
+ }
+ )
+ loss_terms, details = get_loss_terms_and_details(
+ losses_dict,
+ valid_masks,
+ type(self).__name__,
+ n_views,
+ self.flatten_across_image_only,
+ )
+ losses = Sum(*loss_terms)
+
+ return losses, (details | {})
diff --git a/mapanything/train/profile_dataloading.py b/mapanything/train/profile_dataloading.py
new file mode 100644
index 0000000000000000000000000000000000000000..8fb9b7615be4cbb56e07d2d182cbbd16d103070e
--- /dev/null
+++ b/mapanything/train/profile_dataloading.py
@@ -0,0 +1,285 @@
+"""
+Debug script to profile dataloading for MapAnything training.
+
+This script measures and analyzes the performance of data loading operations
+for MapAnything training workflows. It simulates the training process without
+actual model training to isolate and profile the data loading components.
+"""
+
+import datetime
+import json
+import os
+import time
+from pathlib import Path
+from typing import Sized
+
+import numpy as np
+import torch
+import torch.backends.cudnn as cudnn
+from torch.utils.tensorboard import SummaryWriter
+
+import mapanything.utils.train_tools as train_tools
+from mapanything.datasets import get_test_data_loader, get_train_data_loader
+from mapanything.datasets.base.base_dataset import view_name
+
+# Enable TF32 precision if supported (for GPU >= Ampere and PyTorch >= 1.12)
+if hasattr(torch.backends.cuda, "matmul") and hasattr(
+ torch.backends.cuda.matmul, "allow_tf32"
+):
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+
+def profile_dataloading(args):
+ """
+ Main profiling function that simulates the training process to measure data loading performance.
+
+ This function initializes the distributed environment, sets up datasets and data loaders,
+ and runs through training epochs to profile the data loading operations. It measures
+ the time taken for data loading without performing actual model training or optimization.
+
+ In this simulation, an epoch represents a complete pass through a chunk of the dataset.
+
+ Args:
+ args: Configuration object containing all parameters including:
+ - dataset: Dataset configuration (train_dataset, test_dataset, num_workers)
+ - train_params: Training parameters (batch_size, epochs, seed, etc.)
+ - distributed: Distributed training configuration
+ - output_dir: Directory for saving logs and profiling results
+ """
+ # Initialize distributed training if required
+ train_tools.init_distributed_mode(args.distributed)
+ global_rank = train_tools.get_rank()
+ world_size = train_tools.get_world_size() # noqa
+
+ # Init output directory and device
+ print("output_dir: " + args.output_dir)
+ if args.output_dir:
+ Path(args.output_dir).mkdir(parents=True, exist_ok=True)
+
+ print("job dir: {}".format(os.path.dirname(os.path.realpath(__file__))))
+ print("{}".format(args).replace(", ", ",\n"))
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ device = torch.device(device)
+
+ # Fix the seed
+ seed = args.train_params.seed + train_tools.get_rank()
+ torch.manual_seed(seed)
+ np.random.seed(seed)
+
+ cudnn.benchmark = not args.train_params.disable_cudnn_benchmark
+
+ # Datasets and Dataloaders
+ print("Building train dataset {:s}".format(args.dataset.train_dataset))
+ data_loader_train = build_dataset(
+ dataset=args.dataset.train_dataset,
+ num_workers=args.dataset.num_workers,
+ test=False,
+ max_num_of_imgs_per_gpu=args.train_params.max_num_of_imgs_per_gpu,
+ )
+ print("Building test dataset {:s}".format(args.dataset.test_dataset))
+ test_batch_size = 2 * (
+ args.train_params.max_num_of_imgs_per_gpu // args.dataset.num_views
+ ) # Since we don't have any backward overhead
+ data_loader_test = {
+ dataset.split("(")[0]: build_dataset(
+ dataset=dataset,
+ num_workers=args.dataset.num_workers,
+ test=True,
+ batch_size=test_batch_size,
+ )
+ for dataset in args.dataset.test_dataset.split("+")
+ if "(" in dataset
+ }
+
+ def write_log_stats(epoch, train_stats, test_stats):
+ """
+ Writes profiling statistics to log files and TensorBoard.
+
+ This function collects metrics from the training and testing phases and writes them
+ to log files and TensorBoard for visualization and analysis. It only executes on the
+ main process in a distributed setting.
+
+ Args:
+ epoch: int, current epoch number
+ train_stats: dict, containing training metrics and timing information
+ test_stats: dict, containing testing metrics for each test dataset
+ """
+ if train_tools.is_main_process():
+ if log_writer is not None:
+ log_writer.flush()
+
+ log_stats = dict(
+ epoch=epoch, **{f"train_{k}": v for k, v in train_stats.items()}
+ )
+ for test_name in data_loader_test:
+ if test_name not in test_stats:
+ continue
+ log_stats.update(
+ {test_name + "_" + k: v for k, v in test_stats[test_name].items()}
+ )
+
+ with open(
+ os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8"
+ ) as f:
+ f.write(json.dumps(log_stats) + "\n")
+
+ if global_rank == 0 and args.output_dir is not None:
+ log_writer = SummaryWriter(log_dir=args.output_dir)
+ else:
+ log_writer = None
+
+ print(f"Start training for {args.train_params.epochs} epochs")
+ start_time = time.time()
+ train_stats = test_stats = {}
+ args.train_params.start_epoch = 0
+ for epoch in range(args.train_params.start_epoch, args.train_params.epochs + 1):
+ # Save more stuff
+ write_log_stats(epoch, train_stats, test_stats)
+
+ if epoch >= args.train_params.epochs:
+ break # exit after writing last test to disk
+
+ # Train
+ train_stats = train_one_epoch(
+ data_loader_train,
+ device,
+ epoch,
+ log_writer=log_writer,
+ args=args,
+ )
+
+ total_time = time.time() - start_time
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+ print("Training time {}".format(total_time_str))
+
+
+def build_dataset(
+ dataset, num_workers, test, batch_size=None, max_num_of_imgs_per_gpu=None
+):
+ """
+ Builds data loaders for training or testing.
+
+ Args:
+ dataset: Dataset specification string.
+ num_workers: Number of worker processes for data loading.
+ test: Boolean flag indicating whether this is a test dataset.
+ batch_size: Number of samples per batch. Defaults to None. Used only for testing.
+ max_num_of_imgs_per_gpu: Maximum number of images per GPU. Defaults to None. Used only for training.
+
+ Returns:
+ DataLoader: PyTorch DataLoader configured for the specified dataset.
+ """
+ split = ["Train", "Test"][test]
+ print(f"Building {split} Data loader for dataset: ", dataset)
+ if test:
+ assert batch_size is not None, (
+ "batch_size must be specified for testing dataloader"
+ )
+ loader = get_test_data_loader(
+ dataset=dataset,
+ batch_size=batch_size,
+ num_workers=num_workers,
+ pin_mem=True,
+ shuffle=False,
+ drop_last=False,
+ )
+ else:
+ assert max_num_of_imgs_per_gpu is not None, (
+ "max_num_of_imgs_per_gpu must be specified for training dataloader"
+ )
+ loader = get_train_data_loader(
+ dataset=dataset,
+ max_num_of_imgs_per_gpu=max_num_of_imgs_per_gpu,
+ num_workers=num_workers,
+ pin_mem=True,
+ shuffle=True,
+ drop_last=True,
+ )
+
+ print(f"{split} dataset length: ", len(loader))
+ return loader
+
+
+def train_one_epoch(
+ data_loader: Sized,
+ device: torch.device,
+ epoch: int,
+ args,
+ log_writer=None,
+):
+ """
+ Simulates training for one epoch to profile data loading performance.
+
+ This function runs through a single epoch, simulating the data loading and device transfer
+ operations that would occur during actual training. It measures and logs the time taken
+ for these operations without performing actual model training.
+
+ Args:
+ data_loader: Sized, DataLoader providing the training data
+ device: torch.device, device to transfer data to (CPU or GPU)
+ epoch: int, current epoch number
+ args: object, configuration object containing training parameters including:
+ - train_params.print_freq: frequency of logging during the epoch
+ log_writer: Optional[SummaryWriter], TensorBoard SummaryWriter for logging metrics
+
+ Returns:
+ dict: Dictionary containing profiling metrics averaged over the epoch
+ """
+ metric_logger = train_tools.MetricLogger(delimiter=" ")
+ header = "Epoch: [{}]".format(epoch)
+
+ if log_writer is not None:
+ print("log_dir: {}".format(log_writer.log_dir))
+
+ if hasattr(data_loader, "dataset") and hasattr(data_loader.dataset, "set_epoch"):
+ data_loader.dataset.set_epoch(epoch)
+ if hasattr(data_loader, "sampler") and hasattr(data_loader.sampler, "set_epoch"):
+ data_loader.sampler.set_epoch(epoch)
+ if hasattr(data_loader, "batch_sampler") and hasattr(
+ data_loader.batch_sampler, "set_epoch"
+ ):
+ data_loader.batch_sampler.set_epoch(epoch)
+
+ for data_iter_step, batch in enumerate(
+ metric_logger.log_every(data_loader, args.train_params.print_freq, header)
+ ):
+ epoch_f = epoch + data_iter_step / len(data_loader)
+
+ # Simulate the device loading in loss_of_one_batch_multi_view
+ ignore_keys = set(
+ [
+ "depthmap",
+ "dataset",
+ "label",
+ "instance",
+ "idx",
+ "true_shape",
+ "rng",
+ "data_norm_type",
+ ]
+ )
+ for view in batch:
+ for name in view.keys():
+ if name in ignore_keys:
+ continue
+ view[name] = view[name].to(device, non_blocking=True)
+
+ local_rank = train_tools.get_rank()
+ n_views = len(batch)
+ batch_shape = batch[0]["img"].shape
+ first_sample_name = view_name(batch[0], batch_index=0)
+ print(
+ f"Rank: {local_rank}, Num views: {n_views}, Batch Shape: {batch_shape}, First Sample Name: {first_sample_name}",
+ force=True,
+ )
+
+ del batch
+
+ metric_logger.update(epoch=epoch_f)
+ metric_logger.update(loss=0)
+
+ # # Gather the stats from all processes
+ # metric_logger.synchronize_between_processes()
+ # print("Averaged stats:", metric_logger)
+ return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
diff --git a/mapanything/train/training.py b/mapanything/train/training.py
new file mode 100644
index 0000000000000000000000000000000000000000..f5dd83cd08caa07c9a893aba6a8e07ebf7c05c26
--- /dev/null
+++ b/mapanything/train/training.py
@@ -0,0 +1,659 @@
+"""
+Training Code for MapAnything.
+
+References:
+DUSt3R: https://github.com/naver/dust3r
+"""
+
+import datetime
+import json
+import math
+import os
+import pickle
+import sys
+import time
+from collections import defaultdict
+from pathlib import Path
+from typing import Sized
+
+import numpy as np
+import torch
+import torch.backends.cudnn as cudnn
+from torch.utils.tensorboard import SummaryWriter
+
+import mapanything.utils.train_tools as train_tools
+from mapanything.datasets import get_test_data_loader, get_train_data_loader
+from mapanything.models import init_model
+from mapanything.train.losses import * # noqa
+from mapanything.utils.inference import loss_of_one_batch_multi_view
+from mapanything.utils.train_tools import NativeScalerWithGradNormCount as NativeScaler
+
+# Enable TF32 precision if supported (for GPU >= Ampere and PyTorch >= 1.12)
+if hasattr(torch.backends.cuda, "matmul") and hasattr(
+ torch.backends.cuda.matmul, "allow_tf32"
+):
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+
+def train(args):
+ """
+ Main training function that handles the entire training process.
+
+ This function initializes the distributed training environment, sets up datasets,
+ initializes the model, optimizer, and loss functions, and manages the training
+ and evaluation loop across multiple epochs.
+
+ In this training, an epoch is just a chunk of the entire dataset.
+
+ Args:
+ args: Configuration object containing all training parameters including
+ dataset configs, model configs, training parameters, and loss functions.
+ """
+ # Initialize distributed training if required
+ train_tools.init_distributed_mode(args.distributed)
+ global_rank = train_tools.get_rank()
+ world_size = train_tools.get_world_size() # noqa
+
+ # Init output directory and device
+ print("output_dir: " + args.output_dir)
+ if args.output_dir:
+ Path(args.output_dir).mkdir(parents=True, exist_ok=True)
+
+ print("job dir: {}".format(os.path.dirname(os.path.realpath(__file__))))
+ print("{}".format(args).replace(", ", ",\n"))
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ device = torch.device(device)
+
+ # Fix the seed
+ seed = args.train_params.seed + train_tools.get_rank()
+ torch.manual_seed(seed)
+ np.random.seed(seed)
+
+ cudnn.benchmark = not args.train_params.disable_cudnn_benchmark
+
+ # Datasets and Dataloaders
+ print("Building train dataset {:s}".format(args.dataset.train_dataset))
+ data_loader_train = build_dataset(
+ dataset=args.dataset.train_dataset,
+ num_workers=args.dataset.num_workers,
+ test=False,
+ max_num_of_imgs_per_gpu=args.train_params.max_num_of_imgs_per_gpu,
+ )
+ print("Building test dataset {:s}".format(args.dataset.test_dataset))
+ test_batch_size = 2 * (
+ args.train_params.max_num_of_imgs_per_gpu // args.dataset.num_views
+ ) # Since we don't have any backward overhead
+ data_loader_test = {
+ dataset.split("(")[0]: build_dataset(
+ dataset=dataset,
+ num_workers=args.dataset.num_workers,
+ test=True,
+ batch_size=test_batch_size,
+ )
+ for dataset in args.dataset.test_dataset.split("+")
+ if "(" in dataset
+ }
+
+ # Load Model
+ if global_rank == 0:
+ model = init_model(
+ args.model.model_str,
+ args.model.model_config,
+ torch_hub_force_reload=args.model.torch_hub_force_reload,
+ )
+ if torch.distributed.is_initialized():
+ torch.distributed.barrier() # Make sure the model is initialized before proceeding
+ if global_rank != 0:
+ model = init_model(
+ args.model.model_str, args.model.model_config, torch_hub_force_reload=False
+ )
+ model.to(device) # Move model to device
+ model_without_ddp = model
+ print("Model = %s" % str(model_without_ddp))
+
+ # Criterion
+ print(f">> Creating train criterion = {args.loss.train_criterion}")
+ train_criterion = eval(args.loss.train_criterion).to(device)
+ print(
+ f">> Creating test criterion = {args.loss.test_criterion or args.loss.train_criterion}"
+ )
+ test_criterion = eval(args.loss.test_criterion or args.loss.train_criterion).to(
+ device
+ )
+
+ # Load pretrained model if provided
+ if args.model.pretrained:
+ print("Loading pretrained: ", args.model.pretrained)
+ ckpt = torch.load(
+ args.model.pretrained, map_location=device, weights_only=False
+ )
+ print(model.load_state_dict(ckpt["model"], strict=False))
+ del ckpt # in case it occupies memory
+
+ # Init model for DDP training
+ if args.distributed.distributed:
+ model = torch.nn.parallel.DistributedDataParallel(
+ model,
+ device_ids=[args.distributed.gpu],
+ find_unused_parameters=True,
+ static_graph=False,
+ )
+ model_without_ddp = model.module
+
+ # Optimizer and loss scaler for gradient accumulation
+ # Following timm: set wd as 0 for bias and norm layers
+ param_groups, param_groups_name_to_idx_map, param_groups_idx_to_name_map = (
+ train_tools.get_parameter_groups(
+ model_without_ddp,
+ args.train_params.lr,
+ args.train_params.weight_decay,
+ submodule_configs=args.train_params.submodule_configs,
+ warn_not_in_submodule=args.train_params.warn_not_in_submodule,
+ )
+ )
+ optimizer = torch.optim.AdamW(
+ param_groups, lr=args.train_params.lr, betas=(0.9, 0.95)
+ )
+ print(optimizer)
+ loss_scaler = NativeScaler()
+
+ def write_log_stats(epoch, train_stats, test_stats):
+ """
+ Writes training and testing statistics to log files and TensorBoard.
+
+ Args:
+ epoch: Current epoch number.
+ train_stats: Dictionary containing training metrics.
+ test_stats: Dictionary containing testing metrics for each test dataset.
+ """
+ if train_tools.is_main_process():
+ if log_writer is not None:
+ log_writer.flush()
+
+ log_stats = dict(
+ epoch=epoch, **{f"train_{k}": v for k, v in train_stats.items()}
+ )
+ for test_name in data_loader_test:
+ if test_name not in test_stats:
+ continue
+ log_stats.update(
+ {test_name + "_" + k: v for k, v in test_stats[test_name].items()}
+ )
+
+ with open(
+ os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8"
+ ) as f:
+ f.write(json.dumps(log_stats) + "\n")
+
+ def save_model(epoch, fname, best_so_far):
+ """
+ Saves model checkpoint to disk.
+
+ Args:
+ epoch: Current epoch number.
+ fname: Filename or identifier for the checkpoint.
+ best_so_far: Best validation metric achieved so far.
+ """
+ train_tools.save_model(
+ args=args,
+ model_without_ddp=model_without_ddp,
+ optimizer=optimizer,
+ loss_scaler=loss_scaler,
+ epoch=epoch,
+ fname=fname,
+ best_so_far=best_so_far,
+ )
+
+ # Resume from a checkpoint if needed
+ last_ckpt_fname = os.path.join(args.output_dir, "checkpoint-last.pth")
+ if args.train_params.resume and os.path.isfile(last_ckpt_fname):
+ args.train_params.resume_ckpt = last_ckpt_fname
+ else:
+ args.train_params.resume_ckpt = None
+ best_so_far = train_tools.load_model(
+ train_args=args.train_params,
+ model_without_ddp=model_without_ddp,
+ optimizer=optimizer,
+ loss_scaler=loss_scaler,
+ )
+ if best_so_far is None:
+ best_so_far = float("inf")
+
+ if global_rank == 0 and args.output_dir is not None:
+ log_writer = SummaryWriter(log_dir=args.output_dir)
+ else:
+ log_writer = None
+
+ print(f"Start training for {args.train_params.epochs} epochs")
+ start_time = time.time()
+ train_stats = test_stats = {}
+ for epoch in range(args.train_params.start_epoch, args.train_params.epochs + 1):
+ # Save immediately the last checkpoint
+ if epoch > args.train_params.start_epoch:
+ if (
+ args.train_params.save_freq
+ and epoch % args.train_params.save_freq == 0
+ or epoch == args.train_params.epochs
+ ):
+ save_model(epoch - 1, "last", best_so_far)
+
+ # Test on multiple datasets
+ new_best = False
+ test_stats = {}
+ if (
+ args.train_params.eval_freq > 0
+ and epoch % args.train_params.eval_freq == 0
+ and epoch > 0
+ ):
+ for test_name, testset in data_loader_test.items():
+ print(f"Testing on {test_name} ...")
+ stats = test_one_epoch(
+ model,
+ test_criterion,
+ testset,
+ device,
+ epoch,
+ log_writer=log_writer,
+ args=args,
+ prefix=test_name,
+ )
+ test_stats[test_name] = stats
+
+ # Calculate average test loss median
+ avg_test_loss_med = np.mean(
+ [stats["loss_med"] for stats in test_stats.values()]
+ )
+ test_stats["Average Test Loss Median"] = avg_test_loss_med
+ # Save best
+ if avg_test_loss_med < best_so_far:
+ best_so_far = avg_test_loss_med
+ new_best = True
+
+ # Save more stuff
+ write_log_stats(epoch, train_stats, test_stats)
+
+ if epoch > args.train_params.start_epoch:
+ if args.train_params.keep_freq and epoch % args.train_params.keep_freq == 0:
+ save_model(epoch - 1, str(epoch), best_so_far)
+ if new_best:
+ save_model(epoch - 1, "best", best_so_far)
+ if epoch >= args.train_params.epochs:
+ break # exit after writing last test to disk
+
+ # Train
+ train_stats = train_one_epoch(
+ model,
+ train_criterion,
+ data_loader_train,
+ optimizer,
+ device,
+ epoch,
+ loss_scaler,
+ log_writer=log_writer,
+ args=args,
+ param_groups_name_to_idx_map=param_groups_name_to_idx_map,
+ param_groups_idx_to_name_map=param_groups_idx_to_name_map,
+ model_without_ddp=model_without_ddp,
+ )
+
+ total_time = time.time() - start_time
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+ print("Training time {}".format(total_time_str))
+
+ save_final_model(
+ args, args.train_params.epochs, model_without_ddp, best_so_far=best_so_far
+ )
+
+
+def save_final_model(args, epoch, model_without_ddp, best_so_far=None):
+ """
+ Saves the final model checkpoint after training completion.
+
+ Args:
+ args: Configuration object containing output directory information.
+ epoch: Current epoch number.
+ model_without_ddp: Model state dictionary or model instance without DistributedDataParallel wrapper.
+ best_so_far: Optional; Best validation metric achieved during training.
+ """
+ output_dir = Path(args.output_dir)
+ checkpoint_path = output_dir / "checkpoint-final.pth"
+ to_save = {
+ "args": args,
+ "model": model_without_ddp
+ if isinstance(model_without_ddp, dict)
+ else model_without_ddp.cpu().state_dict(),
+ "epoch": epoch,
+ }
+ if best_so_far is not None:
+ to_save["best_so_far"] = best_so_far
+ print(f">> Saving model to {checkpoint_path} ...")
+ train_tools.save_on_master(to_save, checkpoint_path)
+
+
+def build_dataset(
+ dataset, num_workers, test, batch_size=None, max_num_of_imgs_per_gpu=None
+):
+ """
+ Builds data loaders for training or testing.
+
+ Args:
+ dataset: Dataset specification string.
+ num_workers: Number of worker processes for data loading.
+ test: Boolean flag indicating whether this is a test dataset.
+ batch_size: Number of samples per batch. Defaults to None. Used only for testing.
+ max_num_of_imgs_per_gpu: Maximum number of images per GPU. Defaults to None. Used only for training.
+
+ Returns:
+ DataLoader: PyTorch DataLoader configured for the specified dataset.
+ """
+ split = ["Train", "Test"][test]
+ print(f"Building {split} Data loader for dataset: ", dataset)
+ if test:
+ assert batch_size is not None, (
+ "batch_size must be specified for testing dataloader"
+ )
+ loader = get_test_data_loader(
+ dataset=dataset,
+ batch_size=batch_size,
+ num_workers=num_workers,
+ pin_mem=True,
+ shuffle=False,
+ drop_last=False,
+ )
+ else:
+ assert max_num_of_imgs_per_gpu is not None, (
+ "max_num_of_imgs_per_gpu must be specified for training dataloader"
+ )
+ loader = get_train_data_loader(
+ dataset=dataset,
+ max_num_of_imgs_per_gpu=max_num_of_imgs_per_gpu,
+ num_workers=num_workers,
+ pin_mem=True,
+ shuffle=True,
+ drop_last=True,
+ )
+
+ print(f"{split} dataset length: ", len(loader))
+ return loader
+
+
+def train_one_epoch(
+ model: torch.nn.Module,
+ criterion: torch.nn.Module,
+ data_loader: Sized,
+ optimizer: torch.optim.Optimizer,
+ device: torch.device,
+ epoch: int,
+ loss_scaler,
+ args,
+ log_writer=None,
+ param_groups_name_to_idx_map=None,
+ param_groups_idx_to_name_map=None,
+ model_without_ddp=None,
+):
+ """
+ Trains the model for one epoch.
+ Epoch is just a chunk of the entire dataset.
+
+ This function handles the training loop for a single epoch, including forward/backward passes,
+ gradient accumulation, learning rate scheduling, and logging metrics.
+
+ Args:
+ model: The neural network model to train.
+ criterion: Loss function to optimize.
+ data_loader: DataLoader providing the training data.
+ optimizer: Optimizer for updating model parameters.
+ device: Device to run training on (CPU or GPU).
+ epoch: Current epoch number.
+ loss_scaler: Scaler for gradient accumulation and mixed precision training.
+ args: Configuration object containing training parameters.
+ log_writer: Optional; TensorBoard SummaryWriter for logging.
+ param_groups_name_to_idx_map: Mapping from parameter group names to indices.
+ param_groups_idx_to_name_map: Mapping from parameter group indices to names.
+ model_without_ddp: Model without DistributedDataParallel wrapper for debugging.
+
+ Returns:
+ dict: Dictionary containing training metrics averaged over the epoch.
+ """
+ model.train(True)
+ metric_logger = train_tools.MetricLogger(delimiter=" ")
+ for submodule_name in param_groups_name_to_idx_map:
+ lr_name = f"lr_{submodule_name}" if submodule_name != "default" else "lr"
+ metric_logger.add_meter(
+ lr_name, train_tools.SmoothedValue(window_size=1, fmt="{value:.6f}")
+ )
+ header = "Epoch: [{}]".format(epoch)
+ accum_iter = args.train_params.accum_iter
+
+ if log_writer is not None:
+ print("log_dir: {}".format(log_writer.log_dir))
+
+ if hasattr(data_loader, "dataset") and hasattr(data_loader.dataset, "set_epoch"):
+ data_loader.dataset.set_epoch(epoch)
+ if hasattr(data_loader, "sampler") and hasattr(data_loader.sampler, "set_epoch"):
+ data_loader.sampler.set_epoch(epoch)
+ if hasattr(data_loader, "batch_sampler") and hasattr(
+ data_loader.batch_sampler, "set_epoch"
+ ):
+ data_loader.batch_sampler.set_epoch(epoch)
+
+ optimizer.zero_grad()
+
+ for data_iter_step, batch in enumerate(
+ metric_logger.log_every(data_loader, args.train_params.print_freq, header)
+ ):
+ n_views = len(batch)
+ epoch_f = epoch + data_iter_step / len(data_loader)
+
+ # We use a per iteration (instead of per epoch) lr scheduler
+ if data_iter_step % accum_iter == 0:
+ train_tools.adjust_learning_rate(
+ optimizer,
+ epoch_f,
+ args.train_params,
+ param_groups_idx_to_name_map,
+ args.train_params.submodule_configs,
+ )
+
+ loss_tuple = loss_of_one_batch_multi_view(
+ batch,
+ model,
+ criterion,
+ device,
+ use_amp=bool(args.train_params.amp),
+ amp_dtype=args.train_params.amp_dtype,
+ ret="loss",
+ )
+ loss, loss_details = loss_tuple # criterion returns two values
+ if n_views > 2:
+ loss = loss * (
+ 2 / n_views
+ ) # scale the loss relative to the number of views (base is 2 views)
+ loss_value = float(loss)
+
+ if not math.isfinite(loss_value) or (loss_value > 1000):
+ print("Loss is {}, stopping training".format(loss_value), force=True)
+ print(f"Loss Details: {loss_details}", force=True)
+ print(f"Epoch: {epoch}, Data Iteration: {data_iter_step}", force=True)
+ # Save the current batch to the output folder for further inspection
+ for view_idx, view in enumerate(batch):
+ view_cpu = {}
+ for k, v in view.items():
+ view_cpu[k] = v.cpu() if isinstance(v, torch.Tensor) else v
+ with open(
+ os.path.join(args.output_dir, f"batch_view_{view_idx}.pkl"), "wb"
+ ) as f:
+ pickle.dump(view_cpu, f)
+ # Save the model to the output folder for further inspection
+ checkpoint_debug_path = os.path.join(
+ args.output_dir, "checkpoint-debug.pth"
+ )
+ to_save_debug = {
+ "args": args,
+ "model": (
+ model_without_ddp
+ if isinstance(model_without_ddp, dict)
+ else model_without_ddp.cpu().state_dict()
+ ),
+ "epoch": epoch,
+ "data_iter_step": data_iter_step,
+ }
+ torch.save(to_save_debug, checkpoint_debug_path)
+ print(f"Saved debugging material to {args.output_dir}", force=True)
+ sys.exit(1)
+
+ # Scale the loss by the number of gradient accumulation iterations
+ loss /= accum_iter
+
+ # Compute the scaled gradients (also clip the gradients to max norm of 1)
+ gradient_norm = loss_scaler(
+ loss,
+ optimizer,
+ parameters=model.parameters(),
+ update_grad=(data_iter_step + 1) % accum_iter == 0,
+ clip_grad=1.0,
+ )
+
+ # Zero out the gradients to prepare for the next iteration of gradient descent
+ if (data_iter_step + 1) % accum_iter == 0:
+ optimizer.zero_grad()
+
+ del loss
+ del batch
+
+ metric_logger.update(epoch=epoch_f)
+ for submodule_name in param_groups_name_to_idx_map:
+ lr_name = f"lr_{submodule_name}" if submodule_name != "default" else "lr"
+ log_lr = optimizer.param_groups[
+ param_groups_name_to_idx_map[submodule_name][0]
+ ]["lr"]
+ metric_logger.meters[lr_name].update(log_lr)
+ metric_logger.update(loss=loss_value, **loss_details)
+
+ if (data_iter_step + 1) % accum_iter == 0 and (
+ (data_iter_step + 1) % (accum_iter * args.train_params.print_freq)
+ ) == 0:
+ loss_value_reduce = train_tools.all_reduce_mean(
+ loss_value
+ ) # MUST BE EXECUTED BY ALL NODES
+ if log_writer is None:
+ continue
+ """
+ We use epoch_1000x as the x-axis in tensorboard.
+ This calibrates different curves when batch size changes.
+ """
+ epoch_1000x = int(epoch_f * 1000)
+ log_writer.add_scalar("train_loss", loss_value_reduce, epoch_1000x)
+ if gradient_norm is not None:
+ log_writer.add_scalar("train_grad_norm", gradient_norm, epoch_1000x)
+ for submodule_name in param_groups_name_to_idx_map:
+ lr_name = (
+ f"train_lr_{submodule_name}"
+ if submodule_name != "default"
+ else "train_lr"
+ )
+ log_lr = optimizer.param_groups[
+ param_groups_name_to_idx_map[submodule_name][0]
+ ]["lr"]
+ log_writer.add_scalar(lr_name, log_lr, epoch_1000x)
+ log_writer.add_scalar("train_iter", epoch_1000x, epoch_1000x)
+ for name, val in loss_details.items():
+ log_writer.add_scalar("train_" + name, val, epoch_1000x)
+
+ # # Gather the stats from all processes
+ # metric_logger.synchronize_between_processes()
+ # print("Averaged stats:", metric_logger)
+ return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
+
+
+@torch.no_grad()
+def test_one_epoch(
+ model: torch.nn.Module,
+ criterion: torch.nn.Module,
+ data_loader: Sized,
+ device: torch.device,
+ epoch: int,
+ args,
+ log_writer=None,
+ prefix="test",
+):
+ """
+ Evaluates the model on a test dataset for one epoch.
+ Epoch is just a chunk of the entire dataset.
+
+ This function runs evaluation on the test dataset without computing gradients,
+ and collects metrics for model performance assessment.
+
+ Args:
+ model: The neural network model to evaluate.
+ criterion: Loss function for evaluation.
+ data_loader: DataLoader providing the test data.
+ device: Device to run evaluation on (CPU or GPU).
+ epoch: Current epoch number.
+ args: Configuration object containing evaluation parameters.
+ log_writer: Optional; TensorBoard SummaryWriter for logging.
+ prefix: String prefix for logging metrics.
+
+ Returns:
+ dict: Dictionary containing evaluation metrics (average and median values).
+ """
+ model.eval()
+ metric_logger = train_tools.MetricLogger(delimiter=" ")
+ metric_logger.meters = defaultdict(
+ lambda: train_tools.SmoothedValue(window_size=9**9)
+ )
+ header = "Test Epoch: [{}]".format(epoch)
+
+ if log_writer is not None:
+ print("log_dir: {}".format(log_writer.log_dir))
+
+ if args.train_params.freeze_val_samples_across_all_epochs:
+ dataloader_epoch = 0
+ else:
+ dataloader_epoch = epoch
+ if hasattr(data_loader, "dataset") and hasattr(data_loader.dataset, "set_epoch"):
+ data_loader.dataset.set_epoch(dataloader_epoch)
+ if hasattr(data_loader, "sampler") and hasattr(data_loader.sampler, "set_epoch"):
+ data_loader.sampler.set_epoch(dataloader_epoch)
+ if hasattr(data_loader, "batch_sampler") and hasattr(
+ data_loader.batch_sampler, "set_epoch"
+ ):
+ data_loader.batch_sampler.set_epoch(dataloader_epoch)
+
+ for _, batch in enumerate(
+ metric_logger.log_every(data_loader, args.train_params.print_freq, header)
+ ):
+ n_views = len(batch)
+ loss_tuple = loss_of_one_batch_multi_view(
+ batch,
+ model,
+ criterion,
+ device,
+ use_amp=bool(args.train_params.amp),
+ amp_dtype=args.train_params.amp_dtype,
+ ret="loss",
+ )
+ loss_value, loss_details = loss_tuple # criterion returns two values
+ if n_views > 2:
+ loss_value = loss_value * (
+ 2 / n_views
+ ) # scale the loss relative to the number of views (base is 2 views)
+ metric_logger.update(loss=float(loss_value), **loss_details)
+
+ # # Gather the stats from all processes
+ # metric_logger.synchronize_between_processes()
+ # print("Averaged stats:", metric_logger)
+
+ aggs = [("avg", "global_avg"), ("med", "median")]
+ results = {
+ f"{k}_{tag}": getattr(meter, attr)
+ for k, meter in metric_logger.meters.items()
+ for tag, attr in aggs
+ }
+
+ if log_writer is not None:
+ for name, val in results.items():
+ log_writer.add_scalar(prefix + "_" + name, val, 1000 * epoch)
+
+ return results
diff --git a/mapanything/utils/__init__.py b/mapanything/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/mapanything/utils/cropping.py b/mapanything/utils/cropping.py
new file mode 100644
index 0000000000000000000000000000000000000000..709ff34f56276e80dc842d6689a88bc84363679b
--- /dev/null
+++ b/mapanything/utils/cropping.py
@@ -0,0 +1,462 @@
+"""
+Utility functions for cropping and resizing data while maintaining proper cameras.
+
+References: DUSt3R
+"""
+
+import cv2
+import numpy as np
+import PIL.Image
+
+try:
+ lanczos = PIL.Image.Resampling.LANCZOS
+ bicubic = PIL.Image.Resampling.BICUBIC
+except AttributeError:
+ lanczos = PIL.Image.LANCZOS
+ bicubic = PIL.Image.BICUBIC
+
+from mapanything.utils.geometry import (
+ colmap_to_opencv_intrinsics,
+ opencv_to_colmap_intrinsics,
+)
+
+
+class ImageList:
+ """
+ Convenience class to apply the same operation to a whole set of images.
+
+ This class wraps a list of PIL.Image objects and provides methods to perform
+ operations on all images simultaneously.
+ """
+
+ def __init__(self, images):
+ if not isinstance(images, (tuple, list, set)):
+ images = [images]
+ self.images = []
+ for image in images:
+ if not isinstance(image, PIL.Image.Image):
+ image = PIL.Image.fromarray(image)
+ self.images.append(image)
+
+ def __len__(self):
+ """Return the number of images in the list."""
+ return len(self.images)
+
+ def to_pil(self):
+ """
+ Convert ImageList back to PIL Image(s).
+
+ Returns:
+ PIL.Image.Image or tuple: Single PIL Image if list contains one image,
+ or tuple of PIL Images if multiple images
+ """
+ return tuple(self.images) if len(self.images) > 1 else self.images[0]
+
+ @property
+ def size(self):
+ """
+ Get the size of images in the list.
+
+ Returns:
+ tuple: (width, height) of the images
+
+ Raises:
+ AssertionError: If images have different sizes
+ """
+ sizes = [im.size for im in self.images]
+ assert all(sizes[0] == s for s in sizes), "All images must have the same size"
+ return sizes[0]
+
+ def resize(self, *args, **kwargs):
+ """
+ Resize all images with the same parameters.
+
+ Args:
+ *args, **kwargs: Arguments passed to PIL.Image.resize()
+
+ Returns:
+ ImageList: New ImageList containing resized images
+ """
+ return ImageList(self._dispatch("resize", *args, **kwargs))
+
+ def crop(self, *args, **kwargs):
+ """
+ Crop all images with the same parameters.
+
+ Args:
+ *args, **kwargs: Arguments passed to PIL.Image.crop()
+
+ Returns:
+ ImageList: New ImageList containing cropped images
+ """
+ return ImageList(self._dispatch("crop", *args, **kwargs))
+
+ def _dispatch(self, func, *args, **kwargs):
+ """
+ Apply a PIL.Image method to all images in the list.
+
+ Args:
+ func (str): Name of the PIL.Image method to call
+ *args, **kwargs: Arguments to pass to the method
+
+ Returns:
+ list: List of results from applying the method to each image
+ """
+ return [getattr(im, func)(*args, **kwargs) for im in self.images]
+
+
+def resize_with_nearest_interpolation_to_match_aspect_ratio(input_data, img_h, img_w):
+ """
+ Resize input map to match the aspect ratio of an image while ensuring
+ the input resolution never increases beyond the original.
+ Uses nearest interpolation for resizing.
+
+ Args:
+ input_data (np.ndarray): The input map to resize
+ img_h (int): Height of the target image
+ img_w (int): Width of the target image
+
+ Returns:
+ tuple: (resized_input, target_h, target_w)
+ - resized_input: The resized input map
+ - target_h: The target height used for resizing
+ - target_w: The target width used for resizing
+ """
+ # Get the dimensions of the input map
+ input_h, input_w = input_data.shape[:2]
+
+ # Calculate aspect ratios
+ img_aspect = img_w / img_h
+
+ # Option 1: Keep input_w fixed and calculate new height
+ option1_h = int(input_w / img_aspect)
+ # Option 2: Keep input_h fixed and calculate new width
+ option2_w = int(input_h * img_aspect)
+
+ # Check if either option would increase a dimension
+ option1_increases = option1_h > input_h
+ option2_increases = option2_w > input_w
+
+ if option1_increases and option2_increases:
+ # Both options would increase a dimension, so we need to scale down both dimensions
+ # Find the scaling factor that preserves aspect ratio and ensures no dimension increases
+ scale_h = input_h / img_h
+ scale_w = input_w / img_w
+ scale = min(scale_h, scale_w)
+
+ target_input_h = int(img_h * scale)
+ target_input_w = int(img_w * scale)
+ elif option1_increases:
+ # Option 1 would increase height, so use option 2
+ target_input_h = input_h
+ target_input_w = option2_w
+ elif option2_increases:
+ # Option 2 would increase width, so use option 1
+ target_input_w = input_w
+ target_input_h = option1_h
+ else:
+ # Neither option increases dimensions, choose the one that maintains resolution better
+ if abs(input_h * input_w - input_w * option1_h) < abs(
+ input_h * input_w - option2_w * input_h
+ ):
+ # Option 1 is better: keep width fixed, adjust height
+ target_input_w = input_w
+ target_input_h = option1_h
+ else:
+ # Option 2 is better: keep height fixed, adjust width
+ target_input_h = input_h
+ target_input_w = option2_w
+
+ # Resize input using nearest interpolation to maintain input values
+ if target_input_h != input_h or target_input_w != input_w:
+ resized_input = cv2.resize(
+ input_data,
+ (target_input_w, target_input_h),
+ interpolation=cv2.INTER_NEAREST,
+ )
+ else:
+ resized_input = input_data
+
+ return resized_input, target_input_h, target_input_w
+
+
+def rescale_image_and_other_optional_info(
+ image,
+ output_resolution,
+ depthmap=None,
+ camera_intrinsics=None,
+ force=True,
+ additional_quantities_to_be_resized_with_nearest=None,
+):
+ """
+ Rescale the image and depthmap to the output resolution.
+ If the image is larger than the output resolution, it is rescaled with lanczos interpolation.
+ If force is false and the image is smaller than the output resolution, it is not rescaled.
+ If force is true and the image is smaller than the output resolution, it is rescaled with bicubic interpolation.
+ Depth and other quantities are rescaled with nearest interpolation.
+
+ Args:
+ image (PIL.Image.Image or np.ndarray): The input image to be rescaled.
+ output_resolution (tuple): The desired output resolution as a tuple (width, height).
+ depthmap (np.ndarray, optional): The depth map associated with the image. Defaults to None.
+ camera_intrinsics (np.ndarray, optional): The camera intrinsics matrix. Defaults to None.
+ force (bool, optional): If True, force rescaling even if the image is smaller than the output resolution. Defaults to True.
+ additional_quantities_to_be_resized_with_nearest (list of np.ndarray, optional): Additional quantities to be rescaled using nearest interpolation. Defaults to None.
+
+ Returns:
+ tuple: A tuple containing:
+ - The rescaled image (PIL.Image.Image)
+ - The rescaled depthmap (numpy.ndarray or None)
+ - The updated camera intrinsics (numpy.ndarray or None)
+ - The list of rescaled additional quantities (list of numpy.ndarray or None)
+ """
+ image = ImageList(image)
+ input_resolution = np.array(image.size) # (W, H)
+ output_resolution = np.array(output_resolution)
+ if depthmap is not None:
+ assert tuple(depthmap.shape[:2]) == image.size[::-1]
+ if additional_quantities_to_be_resized_with_nearest is not None:
+ assert all(
+ tuple(additional_quantity.shape[:2]) == image.size[::-1]
+ for additional_quantity in additional_quantities_to_be_resized_with_nearest
+ )
+
+ # Define output resolution
+ assert output_resolution.shape == (2,)
+ scale_final = max(output_resolution / image.size) + 1e-8
+ if scale_final >= 1 and not force: # image is already smaller than what is asked
+ output = (
+ image.to_pil(),
+ depthmap,
+ camera_intrinsics,
+ additional_quantities_to_be_resized_with_nearest,
+ )
+ return output
+ output_resolution = np.floor(input_resolution * scale_final).astype(int)
+
+ # First rescale the image so that it contains the crop
+ image = image.resize(
+ tuple(output_resolution), resample=lanczos if scale_final < 1 else bicubic
+ )
+ if depthmap is not None:
+ depthmap = cv2.resize(
+ depthmap,
+ output_resolution,
+ fx=scale_final,
+ fy=scale_final,
+ interpolation=cv2.INTER_NEAREST,
+ )
+ if additional_quantities_to_be_resized_with_nearest is not None:
+ resized_additional_quantities = []
+ for quantity in additional_quantities_to_be_resized_with_nearest:
+ resized_additional_quantities.append(
+ cv2.resize(
+ quantity,
+ output_resolution,
+ fx=scale_final,
+ fy=scale_final,
+ interpolation=cv2.INTER_NEAREST,
+ )
+ )
+ additional_quantities_to_be_resized_with_nearest = resized_additional_quantities
+
+ # No offset here; simple rescaling
+ if camera_intrinsics is not None:
+ camera_intrinsics = camera_matrix_of_crop(
+ camera_intrinsics, input_resolution, output_resolution, scaling=scale_final
+ )
+
+ # Return
+ return (
+ image.to_pil(),
+ depthmap,
+ camera_intrinsics,
+ additional_quantities_to_be_resized_with_nearest,
+ )
+
+
+def camera_matrix_of_crop(
+ input_camera_matrix,
+ input_resolution,
+ output_resolution,
+ scaling=1,
+ offset_factor=0.5,
+ offset=None,
+):
+ """
+ Calculate the camera matrix for a cropped image.
+
+ Args:
+ input_camera_matrix (numpy.ndarray): Original camera intrinsics matrix
+ input_resolution (tuple or numpy.ndarray): Original image resolution as (width, height)
+ output_resolution (tuple or numpy.ndarray): Target image resolution as (width, height)
+ scaling (float, optional): Scaling factor for the image. Defaults to 1.
+ offset_factor (float, optional): Factor to determine crop offset. Defaults to 0.5 (centered).
+ offset (tuple or numpy.ndarray, optional): Explicit offset to use. If None, calculated from offset_factor.
+
+ Returns:
+ numpy.ndarray: Updated camera matrix for the cropped image
+ """
+ # Margins to offset the origin
+ margins = np.asarray(input_resolution) * scaling - output_resolution
+ assert np.all(margins >= 0.0)
+ if offset is None:
+ offset = offset_factor * margins
+
+ # Generate new camera parameters
+ output_camera_matrix_colmap = opencv_to_colmap_intrinsics(input_camera_matrix)
+ output_camera_matrix_colmap[:2, :] *= scaling
+ output_camera_matrix_colmap[:2, 2] -= offset
+ output_camera_matrix = colmap_to_opencv_intrinsics(output_camera_matrix_colmap)
+
+ return output_camera_matrix
+
+
+def crop_image_and_other_optional_info(
+ image,
+ crop_bbox,
+ depthmap=None,
+ camera_intrinsics=None,
+ additional_quantities=None,
+):
+ """
+ Return a crop of the input view and associated data.
+
+ Args:
+ image (PIL.Image.Image or numpy.ndarray): The input image to be cropped
+ crop_bbox (tuple): Crop bounding box as (left, top, right, bottom)
+ depthmap (numpy.ndarray, optional): Depth map associated with the image
+ camera_intrinsics (numpy.ndarray, optional): Camera intrinsics matrix
+ additional_quantities (list of numpy.ndarray, optional): Additional data arrays to crop
+
+ Returns:
+ tuple: A tuple containing:
+ - The cropped image
+ - The cropped depth map (if provided or None)
+ - Updated camera intrinsics (if provided or None)
+ - List of cropped additional quantities (if provided or None)
+ """
+ image = ImageList(image)
+ left, top, right, bottom = crop_bbox
+
+ image = image.crop((left, top, right, bottom))
+ if depthmap is not None:
+ depthmap = depthmap[top:bottom, left:right]
+ if additional_quantities is not None:
+ additional_quantities = [
+ quantity[top:bottom, left:right] for quantity in additional_quantities
+ ]
+
+ if camera_intrinsics is not None:
+ camera_intrinsics = camera_intrinsics.copy()
+ camera_intrinsics[0, 2] -= left
+ camera_intrinsics[1, 2] -= top
+
+ return (image.to_pil(), depthmap, camera_intrinsics, additional_quantities)
+
+
+def bbox_from_intrinsics_in_out(
+ input_camera_matrix, output_camera_matrix, output_resolution
+):
+ """
+ Calculate the bounding box for cropping based on input and output camera intrinsics.
+
+ Args:
+ input_camera_matrix (numpy.ndarray): Original camera intrinsics matrix
+ output_camera_matrix (numpy.ndarray): Target camera intrinsics matrix
+ output_resolution (tuple): Target resolution as (width, height)
+
+ Returns:
+ tuple: Crop bounding box as (left, top, right, bottom)
+ """
+ out_width, out_height = output_resolution
+ left, top = np.int32(
+ np.round(input_camera_matrix[:2, 2] - output_camera_matrix[:2, 2])
+ )
+ crop_bbox = (left, top, left + out_width, top + out_height)
+ return crop_bbox
+
+
+def crop_resize_if_necessary(
+ image,
+ resolution,
+ depthmap=None,
+ intrinsics=None,
+ additional_quantities=None,
+):
+ """
+ First downsample image using LANCZOS and then crop if necessary to achieve target resolution.
+
+ This function performs high-quality downsampling followed by cropping to achieve the
+ desired output resolution while maintaining proper camera intrinsics.
+
+ Args:
+ image (PIL.Image.Image or numpy.ndarray): The input image to be processed
+ resolution (tuple): Target resolution as (width, height)
+ depthmap (numpy.ndarray, optional): Depth map associated with the image
+ intrinsics (numpy.ndarray, optional): Camera intrinsics matrix
+ additional_quantities (list of numpy.ndarray, optional): Additional data arrays to process
+
+ Returns:
+ tuple: A tuple containing the processed image and any provided additional data
+ (depthmap, intrinsics, additional_quantities) that have been similarly processed
+ """
+ # Convert image to PIL.Image.Image if necessary
+ if not isinstance(image, PIL.Image.Image):
+ image = PIL.Image.fromarray(image)
+
+ # Get width and height of image
+ original_width, original_height = image.size
+
+ # High-quality Lanczos down-scaling
+ target_rescale_resolution = np.array(resolution)
+ image, depthmap, intrinsics, additional_quantities = (
+ rescale_image_and_other_optional_info(
+ image=image,
+ output_resolution=target_rescale_resolution,
+ depthmap=depthmap,
+ camera_intrinsics=intrinsics,
+ additional_quantities_to_be_resized_with_nearest=additional_quantities,
+ )
+ )
+
+ # Actual cropping (if necessary)
+ if intrinsics is not None:
+ new_intrinsics = camera_matrix_of_crop(
+ input_camera_matrix=intrinsics,
+ input_resolution=image.size,
+ output_resolution=resolution,
+ offset_factor=0.5,
+ )
+ crop_bbox = bbox_from_intrinsics_in_out(
+ input_camera_matrix=intrinsics,
+ output_camera_matrix=new_intrinsics,
+ output_resolution=resolution,
+ )
+ else:
+ # Create a centered crop if no intrinsics are available
+ w, h = image.size
+ target_w, target_h = resolution
+ left = (w - target_w) // 2
+ top = (h - target_h) // 2
+ crop_bbox = (left, top, left + target_w, top + target_h)
+
+ image, depthmap, new_intrinsics, additional_quantities = (
+ crop_image_and_other_optional_info(
+ image=image,
+ crop_bbox=crop_bbox,
+ depthmap=depthmap,
+ camera_intrinsics=intrinsics,
+ additional_quantities=additional_quantities,
+ )
+ )
+
+ # Return the output
+ output = (image,)
+ if depthmap is not None:
+ output += (depthmap,)
+ if new_intrinsics is not None:
+ output += (new_intrinsics,)
+ if additional_quantities is not None:
+ output += (additional_quantities,)
+ return output
diff --git a/mapanything/utils/device.py b/mapanything/utils/device.py
new file mode 100644
index 0000000000000000000000000000000000000000..dc9ca662767457f89d8801bdfba786bb99c6ec2a
--- /dev/null
+++ b/mapanything/utils/device.py
@@ -0,0 +1,83 @@
+"""
+Utility functions for managing computation device
+"""
+
+import numpy as np
+import torch
+
+
+def to_device(batch, device, callback=None, non_blocking=False):
+ """
+ Transfer data to another device (i.e. GPU, CPU:torch, CPU:numpy).
+
+ This function recursively processes nested data structures (lists, tuples, dicts)
+ and transfers each tensor to the specified device.
+
+ Args:
+ batch: Data to transfer (list, tuple, dict of tensors or other objects)
+ device: Target device - pytorch device (e.g., 'cuda', 'cpu') or 'numpy'
+ callback: Optional function that would be called on every element before processing
+ non_blocking: If True, allows asynchronous copy to GPU (may be faster)
+
+ Returns:
+ Data with the same structure as input but with tensors transferred to target device
+ """
+ if callback:
+ batch = callback(batch)
+
+ if isinstance(batch, dict):
+ return {
+ k: to_device(v, device, non_blocking=non_blocking) for k, v in batch.items()
+ }
+
+ if isinstance(batch, (tuple, list)):
+ return type(batch)(
+ to_device(x, device, non_blocking=non_blocking) for x in batch
+ )
+
+ x = batch
+ if device == "numpy":
+ if isinstance(x, torch.Tensor):
+ x = x.detach().cpu().numpy()
+ elif x is not None:
+ if isinstance(x, np.ndarray):
+ x = torch.from_numpy(x)
+ if torch.is_tensor(x):
+ x = x.to(device, non_blocking=non_blocking)
+ return x
+
+
+def to_numpy(x):
+ """Convert data to numpy arrays.
+
+ Args:
+ x: Input data (can be tensor, array, or nested structure)
+
+ Returns:
+ Data with the same structure but with tensors converted to numpy arrays
+ """
+ return to_device(x, "numpy")
+
+
+def to_cpu(x):
+ """Transfer data to CPU.
+
+ Args:
+ x: Input data (can be tensor, array, or nested structure)
+
+ Returns:
+ Data with the same structure but with tensors moved to CPU
+ """
+ return to_device(x, "cpu")
+
+
+def to_cuda(x):
+ """Transfer data to CUDA device (GPU).
+
+ Args:
+ x: Input data (can be tensor, array, or nested structure)
+
+ Returns:
+ Data with the same structure but with tensors moved to GPU
+ """
+ return to_device(x, "cuda")
diff --git a/mapanything/utils/geometry.py b/mapanything/utils/geometry.py
new file mode 100644
index 0000000000000000000000000000000000000000..774f090f813c1770aca791e049fd940700cee7a5
--- /dev/null
+++ b/mapanything/utils/geometry.py
@@ -0,0 +1,2092 @@
+"""
+Utilities for geometry operations.
+
+References: DUSt3R, MoGe
+"""
+
+from numbers import Number
+from typing import Tuple, Union
+
+import einops as ein
+import numpy as np
+import torch
+
+from mapanything.utils.misc import invalid_to_zeros
+from mapanything.utils.warnings import no_warnings
+
+
+def depthmap_to_camera_frame(depthmap, intrinsics):
+ """
+ Convert depth image to a pointcloud in camera frame.
+
+ Args:
+ - depthmap: HxW or BxHxW torch tensor
+ - intrinsics: 3x3 or Bx3x3 torch tensor
+
+ Returns:
+ pointmap in camera frame (HxWx3 or BxHxWx3 tensor), and a mask specifying valid pixels.
+ """
+ # Add batch dimension if not present
+ if depthmap.dim() == 2:
+ depthmap = depthmap.unsqueeze(0)
+ intrinsics = intrinsics.unsqueeze(0)
+ squeeze_batch_dim = True
+ else:
+ squeeze_batch_dim = False
+
+ batch_size, height, width = depthmap.shape
+ device = depthmap.device
+
+ # Compute 3D point in camera frame associated with each pixel
+ x_grid, y_grid = torch.meshgrid(
+ torch.arange(width, device=device).float(),
+ torch.arange(height, device=device).float(),
+ indexing="xy",
+ )
+ x_grid = x_grid.unsqueeze(0).expand(batch_size, -1, -1)
+ y_grid = y_grid.unsqueeze(0).expand(batch_size, -1, -1)
+
+ fx = intrinsics[:, 0, 0].view(-1, 1, 1)
+ fy = intrinsics[:, 1, 1].view(-1, 1, 1)
+ cx = intrinsics[:, 0, 2].view(-1, 1, 1)
+ cy = intrinsics[:, 1, 2].view(-1, 1, 1)
+
+ depth_z = depthmap
+ xx = (x_grid - cx) * depth_z / fx
+ yy = (y_grid - cy) * depth_z / fy
+ pts3d_cam = torch.stack((xx, yy, depth_z), dim=-1)
+
+ # Compute mask of valid non-zero depth pixels
+ valid_mask = depthmap > 0.0
+
+ # Remove batch dimension if it was added
+ if squeeze_batch_dim:
+ pts3d_cam = pts3d_cam.squeeze(0)
+ valid_mask = valid_mask.squeeze(0)
+
+ return pts3d_cam, valid_mask
+
+
+def depthmap_to_world_frame(depthmap, intrinsics, camera_pose=None):
+ """
+ Convert depth image to a pointcloud in world frame.
+
+ Args:
+ - depthmap: HxW or BxHxW torch tensor
+ - intrinsics: 3x3 or Bx3x3 torch tensor
+ - camera_pose: 4x4 or Bx4x4 torch tensor
+
+ Returns:
+ pointmap in world frame (HxWx3 or BxHxWx3 tensor), and a mask specifying valid pixels.
+ """
+ pts3d_cam, valid_mask = depthmap_to_camera_frame(depthmap, intrinsics)
+
+ if camera_pose is not None:
+ # Add batch dimension if not present
+ if camera_pose.dim() == 2:
+ camera_pose = camera_pose.unsqueeze(0)
+ pts3d_cam = pts3d_cam.unsqueeze(0)
+ squeeze_batch_dim = True
+ else:
+ squeeze_batch_dim = False
+
+ # Convert points from camera frame to world frame
+ pts3d_cam_homo = torch.cat(
+ [pts3d_cam, torch.ones_like(pts3d_cam[..., :1])], dim=-1
+ )
+ pts3d_world = ein.einsum(
+ camera_pose, pts3d_cam_homo, "b i k, b h w k -> b h w i"
+ )
+ pts3d_world = pts3d_world[..., :3]
+
+ # Remove batch dimension if it was added
+ if squeeze_batch_dim:
+ pts3d_world = pts3d_world.squeeze(0)
+ else:
+ pts3d_world = pts3d_cam
+
+ return pts3d_world, valid_mask
+
+
+def transform_pts3d(pts3d, transformation):
+ """
+ Transform 3D points using a 4x4 transformation matrix.
+
+ Args:
+ - pts3d: HxWx3 or BxHxWx3 torch tensor
+ - transformation: 4x4 or Bx4x4 torch tensor
+
+ Returns:
+ transformed points (HxWx3 or BxHxWx3 tensor)
+ """
+ # Add batch dimension if not present
+ if pts3d.dim() == 3:
+ pts3d = pts3d.unsqueeze(0)
+ transformation = transformation.unsqueeze(0)
+ squeeze_batch_dim = True
+ else:
+ squeeze_batch_dim = False
+
+ # Convert points to homogeneous coordinates
+ pts3d_homo = torch.cat([pts3d, torch.ones_like(pts3d[..., :1])], dim=-1)
+
+ # Transform points
+ transformed_pts3d = ein.einsum(
+ transformation, pts3d_homo, "b i k, b h w k -> b h w i"
+ )
+ transformed_pts3d = transformed_pts3d[..., :3]
+
+ # Remove batch dimension if it was added
+ if squeeze_batch_dim:
+ transformed_pts3d = transformed_pts3d.squeeze(0)
+
+ return transformed_pts3d
+
+
+def project_pts3d_to_image(pts3d, intrinsics, return_z_dim):
+ """
+ Project 3D points to image plane (assumes pinhole camera model with no distortion).
+
+ Args:
+ - pts3d: HxWx3 or BxHxWx3 torch tensor
+ - intrinsics: 3x3 or Bx3x3 torch tensor
+ - return_z_dim: bool, whether to return the third dimension of the projected points
+
+ Returns:
+ projected points (HxWx2)
+ """
+ if pts3d.dim() == 3:
+ pts3d = pts3d.unsqueeze(0)
+ intrinsics = intrinsics.unsqueeze(0)
+ squeeze_batch_dim = True
+ else:
+ squeeze_batch_dim = False
+
+ # Project points to image plane
+ projected_pts2d = ein.einsum(intrinsics, pts3d, "b i k, b h w k -> b h w i")
+ projected_pts2d[..., :2] /= projected_pts2d[..., 2].unsqueeze(-1).clamp(min=1e-6)
+
+ # Remove the z dimension if not required
+ if not return_z_dim:
+ projected_pts2d = projected_pts2d[..., :2]
+
+ # Remove batch dimension if it was added
+ if squeeze_batch_dim:
+ projected_pts2d = projected_pts2d.squeeze(0)
+
+ return projected_pts2d
+
+
+def get_rays_in_camera_frame(intrinsics, height, width, normalize_to_unit_sphere):
+ """
+ Convert camera intrinsics to a raymap (ray origins + directions) in camera frame.
+ Note: Currently only supports pinhole camera model.
+
+ Args:
+ - intrinsics: 3x3 or Bx3x3 torch tensor
+ - height: int
+ - width: int
+ - normalize_to_unit_sphere: bool
+
+ Returns:
+ - ray_origins: (HxWx3 or BxHxWx3) tensor
+ - ray_directions: (HxWx3 or BxHxWx3) tensor
+ """
+ # Add batch dimension if not present
+ if intrinsics.dim() == 2:
+ intrinsics = intrinsics.unsqueeze(0)
+ squeeze_batch_dim = True
+ else:
+ squeeze_batch_dim = False
+
+ batch_size = intrinsics.shape[0]
+ device = intrinsics.device
+
+ # Compute rays in camera frame associated with each pixel
+ x_grid, y_grid = torch.meshgrid(
+ torch.arange(width, device=device).float(),
+ torch.arange(height, device=device).float(),
+ indexing="xy",
+ )
+ x_grid = x_grid.unsqueeze(0).expand(batch_size, -1, -1)
+ y_grid = y_grid.unsqueeze(0).expand(batch_size, -1, -1)
+
+ fx = intrinsics[:, 0, 0].view(-1, 1, 1)
+ fy = intrinsics[:, 1, 1].view(-1, 1, 1)
+ cx = intrinsics[:, 0, 2].view(-1, 1, 1)
+ cy = intrinsics[:, 1, 2].view(-1, 1, 1)
+
+ ray_origins = torch.zeros((batch_size, height, width, 3), device=device)
+ xx = (x_grid - cx) / fx
+ yy = (y_grid - cy) / fy
+ ray_directions = torch.stack((xx, yy, torch.ones_like(xx)), dim=-1)
+
+ # Normalize ray directions to unit sphere if required (else rays will lie on unit plane)
+ if normalize_to_unit_sphere:
+ ray_directions = ray_directions / torch.norm(
+ ray_directions, dim=-1, keepdim=True
+ )
+
+ # Remove batch dimension if it was added
+ if squeeze_batch_dim:
+ ray_origins = ray_origins.squeeze(0)
+ ray_directions = ray_directions.squeeze(0)
+
+ return ray_origins, ray_directions
+
+
+def get_rays_in_world_frame(
+ intrinsics, height, width, normalize_to_unit_sphere, camera_pose=None
+):
+ """
+ Convert camera intrinsics & camera_pose (if provided) to a raymap (ray origins + directions) in camera or world frame (if camera_pose is provided).
+ Note: Currently only supports pinhole camera model.
+
+ Args:
+ - intrinsics: 3x3 or Bx3x3 torch tensor
+ - height: int
+ - width: int
+ - normalize_to_unit_sphere: bool
+ - camera_pose: 4x4 or Bx4x4 torch tensor
+
+ Returns:
+ - ray_origins: (HxWx3 or BxHxWx3) tensor
+ - ray_directions: (HxWx3 or BxHxWx3) tensor
+ """
+ # Get rays in camera frame
+ ray_origins, ray_directions = get_rays_in_camera_frame(
+ intrinsics, height, width, normalize_to_unit_sphere
+ )
+
+ if camera_pose is not None:
+ # Add batch dimension if not present
+ if camera_pose.dim() == 2:
+ camera_pose = camera_pose.unsqueeze(0)
+ ray_origins = ray_origins.unsqueeze(0)
+ ray_directions = ray_directions.unsqueeze(0)
+ squeeze_batch_dim = True
+ else:
+ squeeze_batch_dim = False
+
+ # Convert rays from camera frame to world frame
+ ray_origins_homo = torch.cat(
+ [ray_origins, torch.ones_like(ray_origins[..., :1])], dim=-1
+ )
+ ray_directions_homo = torch.cat(
+ [ray_directions, torch.zeros_like(ray_directions[..., :1])], dim=-1
+ )
+ ray_origins_world = ein.einsum(
+ camera_pose, ray_origins_homo, "b i k, b h w k -> b h w i"
+ )
+ ray_directions_world = ein.einsum(
+ camera_pose, ray_directions_homo, "b i k, b h w k -> b h w i"
+ )
+ ray_origins_world = ray_origins_world[..., :3]
+ ray_directions_world = ray_directions_world[..., :3]
+
+ # Remove batch dimension if it was added
+ if squeeze_batch_dim:
+ ray_origins_world = ray_origins_world.squeeze(0)
+ ray_directions_world = ray_directions_world.squeeze(0)
+ else:
+ ray_origins_world = ray_origins
+ ray_directions_world = ray_directions
+
+ return ray_origins_world, ray_directions_world
+
+
+def recover_pinhole_intrinsics_from_ray_directions(
+ ray_directions, use_geometric_calculation=False
+):
+ """
+ Recover pinhole camera intrinsics from ray directions, supporting both batched and non-batched inputs.
+
+ Args:
+ ray_directions: Tensor of shape [H, W, 3] or [B, H, W, 3] containing unit normalized ray directions
+
+ Returns:
+ Dictionary containing camera intrinsics (fx, fy, cx, cy) as tensors
+ """
+ # Add batch dimension if not present
+ if ray_directions.dim() == 3: # [H, W, 3]
+ ray_directions = ray_directions.unsqueeze(0) # [1, H, W, 3]
+ squeeze_batch_dim = True
+ else:
+ squeeze_batch_dim = False
+
+ batch_size, height, width, _ = ray_directions.shape
+ device = ray_directions.device
+
+ # Create pixel coordinate grid
+ x_grid, y_grid = torch.meshgrid(
+ torch.arange(width, device=device).float(),
+ torch.arange(height, device=device).float(),
+ indexing="xy",
+ )
+
+ # Expand grid for all batches
+ x_grid = x_grid.unsqueeze(0).expand(batch_size, -1, -1) # [B, H, W]
+ y_grid = y_grid.unsqueeze(0).expand(batch_size, -1, -1) # [B, H, W]
+
+ # Determine if high resolution or not
+ is_high_res = height * width > 1000000
+
+ if is_high_res or use_geometric_calculation:
+ # For high-resolution cases, use direct geometric calculation
+ # Define key points
+ center_h, center_w = height // 2, width // 2
+ quarter_w, three_quarter_w = width // 4, 3 * width // 4
+ quarter_h, three_quarter_h = height // 4, 3 * height // 4
+
+ # Get rays at key points
+ center_rays = ray_directions[:, center_h, center_w, :].clone() # [B, 3]
+ left_rays = ray_directions[:, center_h, quarter_w, :].clone() # [B, 3]
+ right_rays = ray_directions[:, center_h, three_quarter_w, :].clone() # [B, 3]
+ top_rays = ray_directions[:, quarter_h, center_w, :].clone() # [B, 3]
+ bottom_rays = ray_directions[:, three_quarter_h, center_w, :].clone() # [B, 3]
+
+ # Normalize rays to have dz = 1
+ center_rays = center_rays / center_rays[:, 2].unsqueeze(1) # [B, 3]
+ left_rays = left_rays / left_rays[:, 2].unsqueeze(1) # [B, 3]
+ right_rays = right_rays / right_rays[:, 2].unsqueeze(1) # [B, 3]
+ top_rays = top_rays / top_rays[:, 2].unsqueeze(1) # [B, 3]
+ bottom_rays = bottom_rays / bottom_rays[:, 2].unsqueeze(1) # [B, 3]
+
+ # Calculate fx directly (vectorized across batch)
+ fx_left = (quarter_w - center_w) / (left_rays[:, 0] - center_rays[:, 0])
+ fx_right = (three_quarter_w - center_w) / (right_rays[:, 0] - center_rays[:, 0])
+ fx = (fx_left + fx_right) / 2 # Average for robustness
+
+ # Calculate cx
+ cx = center_w - fx * center_rays[:, 0]
+
+ # Calculate fy and cy
+ fy_top = (quarter_h - center_h) / (top_rays[:, 1] - center_rays[:, 1])
+ fy_bottom = (three_quarter_h - center_h) / (
+ bottom_rays[:, 1] - center_rays[:, 1]
+ )
+ fy = (fy_top + fy_bottom) / 2
+
+ cy = center_h - fy * center_rays[:, 1]
+ else:
+ # For standard resolution, use regression with sampling for efficiency
+ # Sample a grid of points (but more dense than for high-res)
+ step_h = max(1, height // 50)
+ step_w = max(1, width // 50)
+
+ h_indices = torch.arange(0, height, step_h, device=device)
+ w_indices = torch.arange(0, width, step_w, device=device)
+
+ # Extract subset of coordinates
+ x_sampled = x_grid[:, h_indices[:, None], w_indices[None, :]] # [B, H', W']
+ y_sampled = y_grid[:, h_indices[:, None], w_indices[None, :]] # [B, H', W']
+ rays_sampled = ray_directions[
+ :, h_indices[:, None], w_indices[None, :], :
+ ] # [B, H', W', 3]
+
+ # Reshape for linear regression
+ x_flat = x_sampled.reshape(batch_size, -1) # [B, N]
+ y_flat = y_sampled.reshape(batch_size, -1) # [B, N]
+
+ # Extract ray direction components
+ dx = rays_sampled[..., 0].reshape(batch_size, -1) # [B, N]
+ dy = rays_sampled[..., 1].reshape(batch_size, -1) # [B, N]
+ dz = rays_sampled[..., 2].reshape(batch_size, -1) # [B, N]
+
+ # Compute ratios for linear regression
+ ratio_x = dx / dz # [B, N]
+ ratio_y = dy / dz # [B, N]
+
+ # Since torch.linalg.lstsq doesn't support batched input, we'll use a different approach
+ # For x-direction: x = cx + fx * (dx/dz)
+ # We can solve this using normal equations: A^T A x = A^T b
+ # Create design matrices
+ ones = torch.ones_like(x_flat) # [B, N]
+ A_x = torch.stack([ones, ratio_x], dim=2) # [B, N, 2]
+ b_x = x_flat.unsqueeze(2) # [B, N, 1]
+
+ # Compute A^T A and A^T b for each batch
+ ATA_x = torch.bmm(A_x.transpose(1, 2), A_x) # [B, 2, 2]
+ ATb_x = torch.bmm(A_x.transpose(1, 2), b_x) # [B, 2, 1]
+
+ # Solve the system for each batch
+ solution_x = torch.linalg.solve(ATA_x, ATb_x).squeeze(2) # [B, 2]
+ cx, fx = solution_x[:, 0], solution_x[:, 1]
+
+ # Repeat for y-direction
+ A_y = torch.stack([ones, ratio_y], dim=2) # [B, N, 2]
+ b_y = y_flat.unsqueeze(2) # [B, N, 1]
+
+ ATA_y = torch.bmm(A_y.transpose(1, 2), A_y) # [B, 2, 2]
+ ATb_y = torch.bmm(A_y.transpose(1, 2), b_y) # [B, 2, 1]
+
+ solution_y = torch.linalg.solve(ATA_y, ATb_y).squeeze(2) # [B, 2]
+ cy, fy = solution_y[:, 0], solution_y[:, 1]
+
+ # Create intrinsics matrices
+ batch_size = fx.shape[0]
+ intrinsics = torch.zeros(batch_size, 3, 3, device=ray_directions.device)
+
+ # Fill in the intrinsics matrices
+ intrinsics[:, 0, 0] = fx # focal length x
+ intrinsics[:, 1, 1] = fy # focal length y
+ intrinsics[:, 0, 2] = cx # principal point x
+ intrinsics[:, 1, 2] = cy # principal point y
+ intrinsics[:, 2, 2] = 1.0 # bottom-right element is always 1
+
+ # Remove batch dimension if it was added
+ if squeeze_batch_dim:
+ intrinsics = intrinsics.squeeze(0)
+
+ return intrinsics
+
+
+def transform_rays(ray_origins, ray_directions, transformation):
+ """
+ Transform 6D rays (ray origins and ray directions) using a 4x4 transformation matrix.
+
+ Args:
+ - ray_origins: HxWx3 or BxHxWx3 torch tensor
+ - ray_directions: HxWx3 or BxHxWx3 torch tensor
+ - transformation: 4x4 or Bx4x4 torch tensor
+ - normalize_to_unit_sphere: bool, whether to normalize the transformed ray directions to unit length
+
+ Returns:
+ transformed ray_origins (HxWx3 or BxHxWx3 tensor) and ray_directions (HxWx3 or BxHxWx3 tensor)
+ """
+ # Add batch dimension if not present
+ if ray_origins.dim() == 3:
+ ray_origins = ray_origins.unsqueeze(0)
+ ray_directions = ray_directions.unsqueeze(0)
+ transformation = transformation.unsqueeze(0)
+ squeeze_batch_dim = True
+ else:
+ squeeze_batch_dim = False
+
+ # Transform ray origins and directions
+ ray_origins_homo = torch.cat(
+ [ray_origins, torch.ones_like(ray_origins[..., :1])], dim=-1
+ )
+ ray_directions_homo = torch.cat(
+ [ray_directions, torch.zeros_like(ray_directions[..., :1])], dim=-1
+ )
+ transformed_ray_origins = ein.einsum(
+ transformation, ray_origins_homo, "b i k, b h w k -> b h w i"
+ )
+ transformed_ray_directions = ein.einsum(
+ transformation, ray_directions_homo, "b i k, b h w k -> b h w i"
+ )
+ transformed_ray_origins = transformed_ray_origins[..., :3]
+ transformed_ray_directions = transformed_ray_directions[..., :3]
+
+ # Remove batch dimension if it was added
+ if squeeze_batch_dim:
+ transformed_ray_origins = transformed_ray_origins.squeeze(0)
+ transformed_ray_directions = transformed_ray_directions.squeeze(0)
+
+ return transformed_ray_origins, transformed_ray_directions
+
+
+def convert_z_depth_to_depth_along_ray(z_depth, intrinsics):
+ """
+ Convert z-depth image to depth along camera rays.
+
+ Args:
+ - z_depth: HxW or BxHxW torch tensor
+ - intrinsics: 3x3 or Bx3x3 torch tensor
+
+ Returns:
+ - depth_along_ray: HxW or BxHxW torch tensor
+ """
+ # Add batch dimension if not present
+ if z_depth.dim() == 2:
+ z_depth = z_depth.unsqueeze(0)
+ intrinsics = intrinsics.unsqueeze(0)
+ squeeze_batch_dim = True
+ else:
+ squeeze_batch_dim = False
+
+ # Get rays in camera frame
+ batch_size, height, width = z_depth.shape
+ _, ray_directions = get_rays_in_camera_frame(
+ intrinsics, height, width, normalize_to_unit_sphere=False
+ )
+
+ # Compute depth along ray
+ pts3d_cam = z_depth[..., None] * ray_directions
+ depth_along_ray = torch.norm(pts3d_cam, dim=-1)
+
+ # Remove batch dimension if it was added
+ if squeeze_batch_dim:
+ depth_along_ray = depth_along_ray.squeeze(0)
+
+ return depth_along_ray
+
+
+def convert_raymap_z_depth_quats_to_pointmap(ray_origins, ray_directions, depth, quats):
+ """
+ Convert raymap (ray origins + directions on unit plane), z-depth and
+ unit quaternions (representing rotation) to a pointmap in world frame.
+
+ Args:
+ - ray_origins: (HxWx3 or BxHxWx3) torch tensor
+ - ray_directions: (HxWx3 or BxHxWx3) torch tensor
+ - depth: (HxWx1 or BxHxWx1) torch tensor
+ - quats: (HxWx4 or BxHxWx4) torch tensor (unit quaternions and notation is (x, y, z, w))
+
+ Returns:
+ - pointmap: (HxWx3 or BxHxWx3) torch tensor
+ """
+ # Add batch dimension if not present
+ if ray_origins.dim() == 3:
+ ray_origins = ray_origins.unsqueeze(0)
+ ray_directions = ray_directions.unsqueeze(0)
+ depth = depth.unsqueeze(0)
+ quats = quats.unsqueeze(0)
+ squeeze_batch_dim = True
+ else:
+ squeeze_batch_dim = False
+
+ batch_size, height, width, _ = depth.shape
+ device = depth.device
+
+ # Normalize the quaternions to ensure they are unit quaternions
+ quats = quats / torch.norm(quats, dim=-1, keepdim=True)
+
+ # Convert quaternions to pixel-wise rotation matrices
+ qx, qy, qz, qw = quats[..., 0], quats[..., 1], quats[..., 2], quats[..., 3]
+ rot_mat = (
+ torch.stack(
+ [
+ qw**2 + qx**2 - qy**2 - qz**2,
+ 2 * (qx * qy - qw * qz),
+ 2 * (qw * qy + qx * qz),
+ 2 * (qw * qz + qx * qy),
+ qw**2 - qx**2 + qy**2 - qz**2,
+ 2 * (qy * qz - qw * qx),
+ 2 * (qx * qz - qw * qy),
+ 2 * (qw * qx + qy * qz),
+ qw**2 - qx**2 - qy**2 + qz**2,
+ ],
+ dim=-1,
+ )
+ .reshape(batch_size, height, width, 3, 3)
+ .to(device)
+ )
+
+ # Compute 3D points in local camera frame
+ pts3d_local = depth * ray_directions
+
+ # Rotate the local points using the quaternions
+ rotated_pts3d_local = ein.einsum(
+ rot_mat, pts3d_local, "b h w i k, b h w k -> b h w i"
+ )
+
+ # Compute 3D point in world frame associated with each pixel
+ pts3d = ray_origins + rotated_pts3d_local
+
+ # Remove batch dimension if it was added
+ if squeeze_batch_dim:
+ pts3d = pts3d.squeeze(0)
+
+ return pts3d
+
+
+def quaternion_to_rotation_matrix(quat):
+ """
+ Convert a quaternion into a 3x3 rotation matrix.
+
+ Args:
+ - quat: 4 or Bx4 torch tensor (unit quaternions and notation is (x, y, z, w))
+
+ Returns:
+ - rot_matrix: 3x3 or Bx3x3 torch tensor
+ """
+ if quat.dim() == 1:
+ quat = quat.unsqueeze(0)
+ squeeze_batch_dim = True
+ else:
+ squeeze_batch_dim = False
+
+ # Ensure the quaternion is normalized
+ quat = quat / quat.norm(dim=1, keepdim=True)
+ x, y, z, w = quat.unbind(dim=1)
+
+ # Compute the rotation matrix elements
+ xx = x * x
+ yy = y * y
+ zz = z * z
+ xy = x * y
+ xz = x * z
+ yz = y * z
+ wx = w * x
+ wy = w * y
+ wz = w * z
+
+ # Construct the rotation matrix
+ rot_matrix = torch.stack(
+ [
+ 1 - 2 * (yy + zz),
+ 2 * (xy - wz),
+ 2 * (xz + wy),
+ 2 * (xy + wz),
+ 1 - 2 * (xx + zz),
+ 2 * (yz - wx),
+ 2 * (xz - wy),
+ 2 * (yz + wx),
+ 1 - 2 * (xx + yy),
+ ],
+ dim=1,
+ ).view(-1, 3, 3)
+
+ # Squeeze batch dimension if it was unsqueezed
+ if squeeze_batch_dim:
+ rot_matrix = rot_matrix.squeeze(0)
+
+ return rot_matrix
+
+
+def quaternion_inverse(quat):
+ """
+ Compute the inverse of a quaternion.
+
+ Args:
+ - quat: 4 or Bx4 torch tensor (unit quaternions and notation is (x, y, z, w))
+
+ Returns:
+ - inv_quat: 4 or Bx4 torch tensor (unit quaternions and notation is (x, y, z, w))
+ """
+ # Unsqueeze batch dimension if not present
+ if quat.dim() == 1:
+ quat = quat.unsqueeze(0)
+ squeeze_batch_dim = True
+ else:
+ squeeze_batch_dim = False
+
+ # Compute the inverse
+ quat_conj = quat.clone()
+ quat_conj[:, :3] = -quat_conj[:, :3]
+ quat_norm = torch.sum(quat * quat, dim=1, keepdim=True)
+ inv_quat = quat_conj / quat_norm
+
+ # Squeeze batch dimension if it was unsqueezed
+ if squeeze_batch_dim:
+ inv_quat = inv_quat.squeeze(0)
+
+ return inv_quat
+
+
+def quaternion_multiply(q1, q2):
+ """
+ Multiply two quaternions.
+
+ Args:
+ - q1: 4 or Bx4 torch tensor (unit quaternions and notation is (x, y, z, w))
+ - q2: 4 or Bx4 torch tensor (unit quaternions and notation is (x, y, z, w))
+
+ Returns:
+ - qm: 4 or Bx4 torch tensor (unit quaternions and notation is (x, y, z, w))
+ """
+ # Unsqueeze batch dimension if not present
+ if q1.dim() == 1:
+ q1 = q1.unsqueeze(0)
+ q2 = q2.unsqueeze(0)
+ squeeze_batch_dim = True
+ else:
+ squeeze_batch_dim = False
+
+ # Unbind the quaternions
+ x1, y1, z1, w1 = q1.unbind(dim=1)
+ x2, y2, z2, w2 = q2.unbind(dim=1)
+
+ # Compute the product
+ x = w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2
+ y = w1 * y2 - x1 * z2 + y1 * w2 + z1 * x2
+ z = w1 * z2 + x1 * y2 - y1 * x2 + z1 * w2
+ w = w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2
+
+ # Stack the components
+ qm = torch.stack([x, y, z, w], dim=1)
+
+ # Squeeze batch dimension if it was unsqueezed
+ if squeeze_batch_dim:
+ qm = qm.squeeze(0)
+
+ return qm
+
+
+def transform_pose_using_quats_and_trans_2_to_1(quats1, trans1, quats2, trans2):
+ """
+ Transform quats and translation of pose2 from absolute frame (pose2 to world) to relative frame (pose2 to pose1).
+
+ Args:
+ - quats1: 4 or Bx4 torch tensor (unit quaternions and notation is (x, y, z, w))
+ - trans1: 3 or Bx3 torch tensor
+ - quats2: 4 or Bx4 torch tensor (unit quaternions and notation is (x, y, z, w))
+ - trans2: 3 or Bx3 torch tensor
+
+ Returns:
+ - quats: 4 or Bx4 torch tensor (unit quaternions and notation is (x, y, z, w))
+ - trans: 3 or Bx3 torch tensor
+ """
+ # Unsqueeze batch dimension if not present
+ if quats1.dim() == 1:
+ quats1 = quats1.unsqueeze(0)
+ trans1 = trans1.unsqueeze(0)
+ quats2 = quats2.unsqueeze(0)
+ trans2 = trans2.unsqueeze(0)
+ squeeze_batch_dim = True
+ else:
+ squeeze_batch_dim = False
+
+ # Compute the inverse of view1's pose
+ inv_quats1 = quaternion_inverse(quats1)
+ R1_inv = quaternion_to_rotation_matrix(inv_quats1)
+ t1_inv = -1 * ein.einsum(R1_inv, trans1, "b i j, b j -> b i")
+
+ # Transform view2's pose to view1's frame
+ quats = quaternion_multiply(inv_quats1, quats2)
+ trans = ein.einsum(R1_inv, trans2, "b i j, b j -> b i") + t1_inv
+
+ # Squeeze batch dimension if it was unsqueezed
+ if squeeze_batch_dim:
+ quats = quats.squeeze(0)
+ trans = trans.squeeze(0)
+
+ return quats, trans
+
+
+def convert_ray_dirs_depth_along_ray_pose_trans_quats_to_pointmap(
+ ray_directions, depth_along_ray, pose_trans, pose_quats
+):
+ """
+ Convert ray directions, depth along ray, pose translation, and
+ unit quaternions (representing pose rotation) to a pointmap in world frame.
+
+ Args:
+ - ray_directions: (HxWx3 or BxHxWx3) torch tensor
+ - depth_along_ray: (HxWx1 or BxHxWx1) torch tensor
+ - pose_trans: (3 or Bx3) torch tensor
+ - pose_quats: (4 or Bx4) torch tensor (unit quaternions and notation is (x, y, z, w))
+
+ Returns:
+ - pointmap: (HxWx3 or BxHxWx3) torch tensor
+ """
+ # Add batch dimension if not present
+ if ray_directions.dim() == 3:
+ ray_directions = ray_directions.unsqueeze(0)
+ depth_along_ray = depth_along_ray.unsqueeze(0)
+ pose_trans = pose_trans.unsqueeze(0)
+ pose_quats = pose_quats.unsqueeze(0)
+ squeeze_batch_dim = True
+ else:
+ squeeze_batch_dim = False
+
+ batch_size, height, width, _ = depth_along_ray.shape
+ device = depth_along_ray.device
+
+ # Normalize the quaternions to ensure they are unit quaternions
+ pose_quats = pose_quats / torch.norm(pose_quats, dim=-1, keepdim=True)
+
+ # Convert quaternions to rotation matrices (B x 3 x 3)
+ rot_mat = quaternion_to_rotation_matrix(pose_quats)
+
+ # Get pose matrix (B x 4 x 4)
+ pose_mat = torch.eye(4, device=device).unsqueeze(0).repeat(batch_size, 1, 1)
+ pose_mat[:, :3, :3] = rot_mat
+ pose_mat[:, :3, 3] = pose_trans
+
+ # Compute 3D points in local camera frame
+ pts3d_local = depth_along_ray * ray_directions
+
+ # Compute 3D points in world frame
+ pts3d_homo = torch.cat([pts3d_local, torch.ones_like(pts3d_local[..., :1])], dim=-1)
+ pts3d_world = ein.einsum(pose_mat, pts3d_homo, "b i k, b h w k -> b h w i")
+ pts3d_world = pts3d_world[..., :3]
+
+ # Remove batch dimension if it was added
+ if squeeze_batch_dim:
+ pts3d_world = pts3d_world.squeeze(0)
+
+ return pts3d_world
+
+
+def xy_grid(
+ W,
+ H,
+ device=None,
+ origin=(0, 0),
+ unsqueeze=None,
+ cat_dim=-1,
+ homogeneous=False,
+ **arange_kw,
+):
+ """
+ Generate a coordinate grid of shape (H,W,2) or (H,W,3) if homogeneous=True.
+
+ Args:
+ W (int): Width of the grid
+ H (int): Height of the grid
+ device (torch.device, optional): Device to place the grid on. If None, uses numpy arrays
+ origin (tuple, optional): Origin coordinates (x,y) for the grid. Default is (0,0)
+ unsqueeze (int, optional): Dimension to unsqueeze in the output tensors
+ cat_dim (int, optional): Dimension to concatenate the x,y coordinates. If None, returns tuple
+ homogeneous (bool, optional): If True, adds a third dimension of ones to make homogeneous coordinates
+ **arange_kw: Additional keyword arguments passed to np.arange or torch.arange
+
+ Returns:
+ numpy.ndarray or torch.Tensor: Coordinate grid where:
+ - output[j,i,0] = i + origin[0] (x-coordinate)
+ - output[j,i,1] = j + origin[1] (y-coordinate)
+ - output[j,i,2] = 1 (if homogeneous=True)
+ """
+ if device is None:
+ # numpy
+ arange, meshgrid, stack, ones = np.arange, np.meshgrid, np.stack, np.ones
+ else:
+ # torch
+ def arange(*a, **kw):
+ return torch.arange(*a, device=device, **kw)
+
+ meshgrid, stack = torch.meshgrid, torch.stack
+
+ def ones(*a):
+ return torch.ones(*a, device=device)
+
+ tw, th = [arange(o, o + s, **arange_kw) for s, o in zip((W, H), origin)]
+ grid = meshgrid(tw, th, indexing="xy")
+ if homogeneous:
+ grid = grid + (ones((H, W)),)
+ if unsqueeze is not None:
+ grid = (grid[0].unsqueeze(unsqueeze), grid[1].unsqueeze(unsqueeze))
+ if cat_dim is not None:
+ grid = stack(grid, cat_dim)
+
+ return grid
+
+
+def geotrf(Trf, pts, ncol=None, norm=False):
+ """
+ Apply a geometric transformation to a set of 3-D points.
+
+ Args:
+ Trf: 3x3 or 4x4 projection matrix (typically a Homography) or batch of matrices
+ with shape (B, 3, 3) or (B, 4, 4)
+ pts: numpy/torch/tuple of coordinates with shape (..., 2) or (..., 3)
+ ncol: int, number of columns of the result (2 or 3)
+ norm: float, if not 0, the result is projected on the z=norm plane
+ (homogeneous normalization)
+
+ Returns:
+ Array or tensor of projected points with the same type as input and shape (..., ncol)
+ """
+ assert Trf.ndim >= 2
+ if isinstance(Trf, np.ndarray):
+ pts = np.asarray(pts)
+ elif isinstance(Trf, torch.Tensor):
+ pts = torch.as_tensor(pts, dtype=Trf.dtype)
+
+ # Adapt shape if necessary
+ output_reshape = pts.shape[:-1]
+ ncol = ncol or pts.shape[-1]
+
+ # Optimized code
+ if (
+ isinstance(Trf, torch.Tensor)
+ and isinstance(pts, torch.Tensor)
+ and Trf.ndim == 3
+ and pts.ndim == 4
+ ):
+ d = pts.shape[3]
+ if Trf.shape[-1] == d:
+ pts = torch.einsum("bij, bhwj -> bhwi", Trf, pts)
+ elif Trf.shape[-1] == d + 1:
+ pts = (
+ torch.einsum("bij, bhwj -> bhwi", Trf[:, :d, :d], pts)
+ + Trf[:, None, None, :d, d]
+ )
+ else:
+ raise ValueError(f"bad shape, not ending with 3 or 4, for {pts.shape=}")
+ else:
+ if Trf.ndim >= 3:
+ n = Trf.ndim - 2
+ assert Trf.shape[:n] == pts.shape[:n], "batch size does not match"
+ Trf = Trf.reshape(-1, Trf.shape[-2], Trf.shape[-1])
+
+ if pts.ndim > Trf.ndim:
+ # Trf == (B,d,d) & pts == (B,H,W,d) --> (B, H*W, d)
+ pts = pts.reshape(Trf.shape[0], -1, pts.shape[-1])
+ elif pts.ndim == 2:
+ # Trf == (B,d,d) & pts == (B,d) --> (B, 1, d)
+ pts = pts[:, None, :]
+
+ if pts.shape[-1] + 1 == Trf.shape[-1]:
+ Trf = Trf.swapaxes(-1, -2) # transpose Trf
+ pts = pts @ Trf[..., :-1, :] + Trf[..., -1:, :]
+ elif pts.shape[-1] == Trf.shape[-1]:
+ Trf = Trf.swapaxes(-1, -2) # transpose Trf
+ pts = pts @ Trf
+ else:
+ pts = Trf @ pts.T
+ if pts.ndim >= 2:
+ pts = pts.swapaxes(-1, -2)
+
+ if norm:
+ pts = pts / pts[..., -1:] # DONT DO /=, it will lead to a bug
+ if norm != 1:
+ pts *= norm
+
+ res = pts[..., :ncol].reshape(*output_reshape, ncol)
+
+ return res
+
+
+def inv(mat):
+ """
+ Invert a torch or numpy matrix
+ """
+ if isinstance(mat, torch.Tensor):
+ return torch.linalg.inv(mat)
+ if isinstance(mat, np.ndarray):
+ return np.linalg.inv(mat)
+ raise ValueError(f"bad matrix type = {type(mat)}")
+
+
+def closed_form_pose_inverse(
+ pose_matrices, rotation_matrices=None, translation_vectors=None
+):
+ """
+ Compute the inverse of each 4x4 (or 3x4) SE3 pose matrices in a batch.
+
+ If `rotation_matrices` and `translation_vectors` are provided, they must correspond to the rotation and translation
+ components of `pose_matrices`. Otherwise, they will be extracted from `pose_matrices`.
+
+ Args:
+ pose_matrices: Nx4x4 or Nx3x4 array or tensor of SE3 matrices.
+ rotation_matrices (optional): Nx3x3 array or tensor of rotation matrices.
+ translation_vectors (optional): Nx3x1 array or tensor of translation vectors.
+
+ Returns:
+ Inverted SE3 matrices with the same type and device as input `pose_matrices`.
+
+ Shapes:
+ pose_matrices: (N, 4, 4)
+ rotation_matrices: (N, 3, 3)
+ translation_vectors: (N, 3, 1)
+ """
+ # Check if pose_matrices is a numpy array or a torch tensor
+ is_numpy = isinstance(pose_matrices, np.ndarray)
+
+ # Validate shapes
+ if pose_matrices.shape[-2:] != (4, 4) and pose_matrices.shape[-2:] != (3, 4):
+ raise ValueError(
+ f"pose_matrices must be of shape (N,4,4), got {pose_matrices.shape}."
+ )
+
+ # Extract rotation_matrices and translation_vectors if not provided
+ if rotation_matrices is None:
+ rotation_matrices = pose_matrices[:, :3, :3]
+ if translation_vectors is None:
+ translation_vectors = pose_matrices[:, :3, 3:]
+
+ # Compute the inverse of input SE3 matrices
+ if is_numpy:
+ rotation_transposed = np.transpose(rotation_matrices, (0, 2, 1))
+ new_translation = -np.matmul(rotation_transposed, translation_vectors)
+ inverted_matrix = np.tile(np.eye(4), (len(rotation_matrices), 1, 1))
+ else:
+ rotation_transposed = rotation_matrices.transpose(1, 2)
+ new_translation = -torch.bmm(rotation_transposed, translation_vectors)
+ inverted_matrix = torch.eye(4, 4)[None].repeat(len(rotation_matrices), 1, 1)
+ inverted_matrix = inverted_matrix.to(rotation_matrices.dtype).to(
+ rotation_matrices.device
+ )
+ inverted_matrix[:, :3, :3] = rotation_transposed
+ inverted_matrix[:, :3, 3:] = new_translation
+
+ return inverted_matrix
+
+
+def relative_pose_transformation(trans_01, trans_02):
+ r"""
+ Function that computes the relative homogenous transformation from a
+ reference transformation :math:`T_1^{0} = \begin{bmatrix} R_1 & t_1 \\
+ \mathbf{0} & 1 \end{bmatrix}` to destination :math:`T_2^{0} =
+ \begin{bmatrix} R_2 & t_2 \\ \mathbf{0} & 1 \end{bmatrix}`.
+
+ The relative transformation is computed as follows:
+
+ .. math::
+
+ T_1^{2} = (T_0^{1})^{-1} \cdot T_0^{2}
+
+ Arguments:
+ trans_01 (torch.Tensor): reference transformation tensor of shape
+ :math:`(N, 4, 4)` or :math:`(4, 4)`.
+ trans_02 (torch.Tensor): destination transformation tensor of shape
+ :math:`(N, 4, 4)` or :math:`(4, 4)`.
+
+ Shape:
+ - Output: :math:`(N, 4, 4)` or :math:`(4, 4)`.
+
+ Returns:
+ torch.Tensor: the relative transformation between the transformations.
+
+ Example::
+ >>> trans_01 = torch.eye(4) # 4x4
+ >>> trans_02 = torch.eye(4) # 4x4
+ >>> trans_12 = relative_transformation(trans_01, trans_02) # 4x4
+ """
+ if not torch.is_tensor(trans_01):
+ raise TypeError(
+ "Input trans_01 type is not a torch.Tensor. Got {}".format(type(trans_01))
+ )
+ if not torch.is_tensor(trans_02):
+ raise TypeError(
+ "Input trans_02 type is not a torch.Tensor. Got {}".format(type(trans_02))
+ )
+ if trans_01.dim() not in (2, 3) and trans_01.shape[-2:] == (4, 4):
+ raise ValueError(
+ "Input must be a of the shape Nx4x4 or 4x4. Got {}".format(trans_01.shape)
+ )
+ if trans_02.dim() not in (2, 3) and trans_02.shape[-2:] == (4, 4):
+ raise ValueError(
+ "Input must be a of the shape Nx4x4 or 4x4. Got {}".format(trans_02.shape)
+ )
+ if not trans_01.dim() == trans_02.dim():
+ raise ValueError(
+ "Input number of dims must match. Got {} and {}".format(
+ trans_01.dim(), trans_02.dim()
+ )
+ )
+
+ # Convert to Nx4x4 if inputs are 4x4
+ squeeze_batch_dim = False
+ if trans_01.dim() == 2:
+ trans_01 = trans_01.unsqueeze(0)
+ trans_02 = trans_02.unsqueeze(0)
+ squeeze_batch_dim = True
+
+ # Compute inverse of trans_01 using closed form
+ trans_10 = closed_form_pose_inverse(trans_01)
+
+ # Compose transformations using matrix multiplication
+ trans_12 = torch.matmul(trans_10, trans_02)
+
+ # Remove batch dimension if it was added
+ if squeeze_batch_dim:
+ trans_12 = trans_12.squeeze(0)
+
+ return trans_12
+
+
+def depthmap_to_pts3d(depth, pseudo_focal, pp=None, **_):
+ """
+ Args:
+ - depthmap (BxHxW array):
+ - pseudo_focal: [B,H,W] ; [B,2,H,W] or [B,1,H,W]
+ Returns:
+ pointmap of absolute coordinates (BxHxWx3 array)
+ """
+
+ if len(depth.shape) == 4:
+ B, H, W, n = depth.shape
+ else:
+ B, H, W = depth.shape
+ n = None
+
+ if len(pseudo_focal.shape) == 3: # [B,H,W]
+ pseudo_focalx = pseudo_focaly = pseudo_focal
+ elif len(pseudo_focal.shape) == 4: # [B,2,H,W] or [B,1,H,W]
+ pseudo_focalx = pseudo_focal[:, 0]
+ if pseudo_focal.shape[1] == 2:
+ pseudo_focaly = pseudo_focal[:, 1]
+ else:
+ pseudo_focaly = pseudo_focalx
+ else:
+ raise NotImplementedError("Error, unknown input focal shape format.")
+
+ assert pseudo_focalx.shape == depth.shape[:3]
+ assert pseudo_focaly.shape == depth.shape[:3]
+ grid_x, grid_y = xy_grid(W, H, cat_dim=0, device=depth.device)[:, None]
+
+ # set principal point
+ if pp is None:
+ grid_x = grid_x - (W - 1) / 2
+ grid_y = grid_y - (H - 1) / 2
+ else:
+ grid_x = grid_x.expand(B, -1, -1) - pp[:, 0, None, None]
+ grid_y = grid_y.expand(B, -1, -1) - pp[:, 1, None, None]
+
+ if n is None:
+ pts3d = torch.empty((B, H, W, 3), device=depth.device)
+ pts3d[..., 0] = depth * grid_x / pseudo_focalx
+ pts3d[..., 1] = depth * grid_y / pseudo_focaly
+ pts3d[..., 2] = depth
+ else:
+ pts3d = torch.empty((B, H, W, 3, n), device=depth.device)
+ pts3d[..., 0, :] = depth * (grid_x / pseudo_focalx)[..., None]
+ pts3d[..., 1, :] = depth * (grid_y / pseudo_focaly)[..., None]
+ pts3d[..., 2, :] = depth
+ return pts3d
+
+
+def depthmap_to_camera_coordinates(depthmap, camera_intrinsics, pseudo_focal=None):
+ """
+ Args:
+ - depthmap (HxW array):
+ - camera_intrinsics: a 3x3 matrix
+ Returns:
+ pointmap of absolute coordinates (HxWx3 array), and a mask specifying valid pixels.
+ """
+ camera_intrinsics = np.float32(camera_intrinsics)
+ H, W = depthmap.shape
+
+ # Compute 3D ray associated with each pixel
+ # Strong assumption: there are no skew terms
+ assert camera_intrinsics[0, 1] == 0.0
+ assert camera_intrinsics[1, 0] == 0.0
+ if pseudo_focal is None:
+ fu = camera_intrinsics[0, 0]
+ fv = camera_intrinsics[1, 1]
+ else:
+ assert pseudo_focal.shape == (H, W)
+ fu = fv = pseudo_focal
+ cu = camera_intrinsics[0, 2]
+ cv = camera_intrinsics[1, 2]
+
+ u, v = np.meshgrid(np.arange(W), np.arange(H))
+ z_cam = depthmap
+ x_cam = (u - cu) * z_cam / fu
+ y_cam = (v - cv) * z_cam / fv
+ X_cam = np.stack((x_cam, y_cam, z_cam), axis=-1).astype(np.float32)
+
+ # Mask for valid coordinates
+ valid_mask = depthmap > 0.0
+
+ return X_cam, valid_mask
+
+
+def depthmap_to_absolute_camera_coordinates(
+ depthmap, camera_intrinsics, camera_pose, **kw
+):
+ """
+ Args:
+ - depthmap (HxW array):
+ - camera_intrinsics: a 3x3 matrix
+ - camera_pose: a 4x3 or 4x4 cam2world matrix
+ Returns:
+ pointmap of absolute coordinates (HxWx3 array), and a mask specifying valid pixels.
+ """
+ X_cam, valid_mask = depthmap_to_camera_coordinates(depthmap, camera_intrinsics)
+
+ X_world = X_cam # default
+ if camera_pose is not None:
+ # R_cam2world = np.float32(camera_params["R_cam2world"])
+ # t_cam2world = np.float32(camera_params["t_cam2world"]).squeeze()
+ R_cam2world = camera_pose[:3, :3]
+ t_cam2world = camera_pose[:3, 3]
+
+ # Express in absolute coordinates (invalid depth values)
+ X_world = (
+ np.einsum("ik, vuk -> vui", R_cam2world, X_cam) + t_cam2world[None, None, :]
+ )
+
+ return X_world, valid_mask
+
+
+def get_absolute_pointmaps_and_rays_info(
+ depthmap, camera_intrinsics, camera_pose, **kw
+):
+ """
+ Args:
+ - depthmap (HxW array):
+ - camera_intrinsics: a 3x3 matrix
+ - camera_pose: a 4x3 or 4x4 cam2world matrix
+ Returns:
+ pointmap of absolute coordinates (HxWx3 array),
+ a mask specifying valid pixels,
+ ray origins of absolute coordinates (HxWx3 array),
+ ray directions of absolute coordinates (HxWx3 array),
+ depth along ray (HxWx1 array),
+ ray directions of camera/local coordinates (HxWx3 array),
+ pointmap of camera/local coordinates (HxWx3 array).
+ """
+ camera_intrinsics = np.float32(camera_intrinsics)
+ H, W = depthmap.shape
+
+ # Compute 3D ray associated with each pixel
+ # Strong assumption: pinhole & there are no skew terms
+ assert camera_intrinsics[0, 1] == 0.0
+ assert camera_intrinsics[1, 0] == 0.0
+ fu = camera_intrinsics[0, 0]
+ fv = camera_intrinsics[1, 1]
+ cu = camera_intrinsics[0, 2]
+ cv = camera_intrinsics[1, 2]
+
+ # Get the rays on the unit plane
+ u, v = np.meshgrid(np.arange(W), np.arange(H))
+ x_cam = (u - cu) / fu
+ y_cam = (v - cv) / fv
+ z_cam = np.ones_like(x_cam)
+ ray_dirs_cam_on_unit_plane = np.stack((x_cam, y_cam, z_cam), axis=-1).astype(
+ np.float32
+ )
+
+ # Compute the 3d points in the local camera coordinate system
+ pts_cam = depthmap[..., None] * ray_dirs_cam_on_unit_plane
+
+ # Get the depth along the ray and compute the ray directions on the unit sphere
+ depth_along_ray = np.linalg.norm(pts_cam, axis=-1, keepdims=True)
+ ray_directions_cam = ray_dirs_cam_on_unit_plane / np.linalg.norm(
+ ray_dirs_cam_on_unit_plane, axis=-1, keepdims=True
+ )
+
+ # Mask for valid coordinates
+ valid_mask = depthmap > 0.0
+
+ # Get the ray origins in absolute coordinates and the ray directions in absolute coordinates
+ ray_origins_world = np.zeros_like(ray_directions_cam)
+ ray_directions_world = ray_directions_cam
+ pts_world = pts_cam
+ if camera_pose is not None:
+ R_cam2world = camera_pose[:3, :3]
+ t_cam2world = camera_pose[:3, 3]
+
+ # Express in absolute coordinates
+ ray_origins_world = ray_origins_world + t_cam2world[None, None, :]
+ ray_directions_world = np.einsum(
+ "ik, vuk -> vui", R_cam2world, ray_directions_cam
+ )
+ pts_world = ray_origins_world + ray_directions_world * depth_along_ray
+
+ return (
+ pts_world,
+ valid_mask,
+ ray_origins_world,
+ ray_directions_world,
+ depth_along_ray,
+ ray_directions_cam,
+ pts_cam,
+ )
+
+
+def adjust_camera_params_for_rotation(camera_params, original_size, k):
+ """
+ Adjust camera parameters for rotation.
+
+ Args:
+ camera_params: Camera parameters [fx, fy, cx, cy, ...]
+ original_size: Original image size as (width, height)
+ k: Number of 90-degree rotations counter-clockwise (k=3 means 90 degrees clockwise)
+
+ Returns:
+ Adjusted camera parameters
+ """
+ fx, fy, cx, cy = camera_params[:4]
+ width, height = original_size
+
+ if k % 4 == 1: # 90 degrees counter-clockwise
+ new_fx, new_fy = fy, fx
+ new_cx, new_cy = height - cy, cx
+ elif k % 4 == 2: # 180 degrees
+ new_fx, new_fy = fx, fy
+ new_cx, new_cy = width - cx, height - cy
+ elif k % 4 == 3: # 90 degrees clockwise (270 counter-clockwise)
+ new_fx, new_fy = fy, fx
+ new_cx, new_cy = cy, width - cx
+ else: # No rotation
+ return camera_params
+
+ adjusted_params = [new_fx, new_fy, new_cx, new_cy]
+ if len(camera_params) > 4:
+ adjusted_params.extend(camera_params[4:])
+
+ return adjusted_params
+
+
+def adjust_pose_for_rotation(pose, k):
+ """
+ Adjust camera pose for rotation.
+
+ Args:
+ pose: 4x4 camera pose matrix (camera-to-world, OpenCV convention - X right, Y down, Z forward)
+ k: Number of 90-degree rotations counter-clockwise (k=3 means 90 degrees clockwise)
+
+ Returns:
+ Adjusted 4x4 camera pose matrix
+ """
+ # Create rotation matrices for different rotations
+ if k % 4 == 1: # 90 degrees counter-clockwise
+ rot_transform = np.array([[0, -1, 0], [1, 0, 0], [0, 0, 1]])
+ elif k % 4 == 2: # 180 degrees
+ rot_transform = np.array([[-1, 0, 0], [0, -1, 0], [0, 0, 1]])
+ elif k % 4 == 3: # 90 degrees clockwise (270 counter-clockwise)
+ rot_transform = np.array([[0, 1, 0], [-1, 0, 0], [0, 0, 1]])
+ else: # No rotation
+ return pose
+
+ # Apply the transformation to the pose
+ adjusted_pose = pose
+ adjusted_pose[:3, :3] = adjusted_pose[:3, :3] @ rot_transform.T
+
+ return adjusted_pose
+
+
+def crop_to_aspect_ratio(image, depth, camera_params, target_ratio=1.5):
+ """
+ Crop image and depth to the largest possible target aspect ratio while
+ keeping the left side if aspect ratio is wider and the bottom of image if the aspect ratio is taller.
+
+ Args:
+ image: PIL image
+ depth: Depth map as numpy array
+ camera_params: Camera parameters [fx, fy, cx, cy, ...]
+ target_ratio: Target width/height ratio
+
+ Returns:
+ Cropped image, cropped depth, adjusted camera parameters
+ """
+ width, height = image.size
+ fx, fy, cx, cy = camera_params[:4]
+ current_ratio = width / height
+
+ if abs(current_ratio - target_ratio) < 1e-6:
+ # Already at target ratio
+ return image, depth, camera_params
+
+ if current_ratio > target_ratio:
+ # Image is wider than target ratio, crop width
+ new_width = int(height * target_ratio)
+ left = 0
+ right = new_width
+
+ # Crop image
+ cropped_image = image.crop((left, 0, right, height))
+
+ # Crop depth
+ if len(depth.shape) == 3:
+ cropped_depth = depth[:, left:right, :]
+ else:
+ cropped_depth = depth[:, left:right]
+
+ # Adjust camera parameters
+ new_cx = cx - left
+ adjusted_params = [fx, fy, new_cx, cy] + list(camera_params[4:])
+
+ else:
+ # Image is taller than target ratio, crop height
+ new_height = int(width / target_ratio)
+ top = max(0, height - new_height)
+ bottom = height
+
+ # Crop image
+ cropped_image = image.crop((0, top, width, bottom))
+
+ # Crop depth
+ if len(depth.shape) == 3:
+ cropped_depth = depth[top:bottom, :, :]
+ else:
+ cropped_depth = depth[top:bottom, :]
+
+ # Adjust camera parameters
+ new_cy = cy - top
+ adjusted_params = [fx, fy, cx, new_cy] + list(camera_params[4:])
+
+ return cropped_image, cropped_depth, adjusted_params
+
+
+def colmap_to_opencv_intrinsics(K):
+ """
+ Modify camera intrinsics to follow a different convention.
+ Coordinates of the center of the top-left pixels are by default:
+ - (0.5, 0.5) in Colmap
+ - (0,0) in OpenCV
+ """
+ K = K.copy()
+ K[0, 2] -= 0.5
+ K[1, 2] -= 0.5
+
+ return K
+
+
+def opencv_to_colmap_intrinsics(K):
+ """
+ Modify camera intrinsics to follow a different convention.
+ Coordinates of the center of the top-left pixels are by default:
+ - (0.5, 0.5) in Colmap
+ - (0,0) in OpenCV
+ """
+ K = K.copy()
+ K[0, 2] += 0.5
+ K[1, 2] += 0.5
+
+ return K
+
+
+def normalize_depth_using_non_zero_pixels(depth, return_norm_factor=False):
+ """
+ Normalize the depth by the average depth of non-zero depth pixels.
+
+ Args:
+ depth (torch.Tensor): Depth tensor of size [B, H, W, 1].
+ Returns:
+ normalized_depth (torch.Tensor): Normalized depth tensor.
+ norm_factor (torch.Tensor): Norm factor tensor of size B.
+ """
+ assert depth.ndim == 4 and depth.shape[3] == 1
+ # Calculate the sum and count of non-zero depth pixels for each batch
+ valid_depth_mask = depth > 0
+ valid_sum = torch.sum(depth * valid_depth_mask, dim=(1, 2, 3))
+ valid_count = torch.sum(valid_depth_mask, dim=(1, 2, 3))
+
+ # Calculate the norm factor
+ norm_factor = valid_sum / (valid_count + 1e-8)
+ while norm_factor.ndim < depth.ndim:
+ norm_factor.unsqueeze_(-1)
+
+ # Normalize the depth by the norm factor
+ norm_factor = norm_factor.clip(min=1e-8)
+ normalized_depth = depth / norm_factor
+
+ # Create the output tuple
+ output = (
+ (normalized_depth, norm_factor.squeeze(-1).squeeze(-1).squeeze(-1))
+ if return_norm_factor
+ else normalized_depth
+ )
+
+ return output
+
+
+def normalize_pose_translations(pose_translations, return_norm_factor=False):
+ """
+ Normalize the pose translations by the average norm of the non-zero pose translations.
+
+ Args:
+ pose_translations (torch.Tensor): Pose translations tensor of size [B, V, 3]. B is the batch size, V is the number of views.
+ Returns:
+ normalized_pose_translations (torch.Tensor): Normalized pose translations tensor of size [B, V, 3].
+ norm_factor (torch.Tensor): Norm factor tensor of size B.
+ """
+ assert pose_translations.ndim == 3 and pose_translations.shape[2] == 3
+ # Compute distance of all pose translations to origin
+ pose_translations_dis = pose_translations.norm(dim=-1) # [B, V]
+ non_zero_pose_translations_dis = pose_translations_dis > 0 # [B, V]
+
+ # Calculate the average norm of the translations across all views (considering only views with non-zero translations)
+ sum_of_all_views_pose_translations = pose_translations_dis.sum(dim=1) # [B]
+ count_of_all_views_with_non_zero_pose_translations = (
+ non_zero_pose_translations_dis.sum(dim=1)
+ ) # [B]
+ norm_factor = sum_of_all_views_pose_translations / (
+ count_of_all_views_with_non_zero_pose_translations + 1e-8
+ ) # [B]
+
+ # Normalize the pose translations by the norm factor
+ norm_factor = norm_factor.clip(min=1e-8)
+ normalized_pose_translations = pose_translations / norm_factor.unsqueeze(
+ -1
+ ).unsqueeze(-1)
+
+ # Create the output tuple
+ output = (
+ (normalized_pose_translations, norm_factor)
+ if return_norm_factor
+ else normalized_pose_translations
+ )
+
+ return output
+
+
+def normalize_multiple_pointclouds(
+ pts_list, valid_masks=None, norm_mode="avg_dis", ret_factor=False
+):
+ """
+ Normalize multiple point clouds using a joint normalization strategy.
+
+ Args:
+ pts_list: List of point clouds, each with shape (..., H, W, 3) or (B, H, W, 3)
+ valid_masks: Optional list of masks indicating valid points in each point cloud
+ norm_mode: String in format "{norm}_{dis}" where:
+ - norm: Normalization strategy (currently only "avg" is supported)
+ - dis: Distance transformation ("dis" for raw distance, "log1p" for log(1+distance),
+ "warp-log1p" to warp points using log distance)
+ ret_factor: If True, return the normalization factor as the last element in the result list
+
+ Returns:
+ List of normalized point clouds with the same shapes as inputs.
+ If ret_factor is True, the last element is the normalization factor.
+ """
+ assert all(pts.ndim >= 3 and pts.shape[-1] == 3 for pts in pts_list)
+ if valid_masks is not None:
+ assert len(pts_list) == len(valid_masks)
+
+ norm_mode, dis_mode = norm_mode.split("_")
+
+ # Gather all points together (joint normalization)
+ nan_pts_list = [
+ invalid_to_zeros(pts, valid_masks[i], ndim=3)
+ if valid_masks
+ else invalid_to_zeros(pts, None, ndim=3)
+ for i, pts in enumerate(pts_list)
+ ]
+ all_pts = torch.cat([nan_pts for nan_pts, _ in nan_pts_list], dim=1)
+ nnz_list = [nnz for _, nnz in nan_pts_list]
+
+ # Compute distance to origin
+ all_dis = all_pts.norm(dim=-1)
+ if dis_mode == "dis":
+ pass # do nothing
+ elif dis_mode == "log1p":
+ all_dis = torch.log1p(all_dis)
+ elif dis_mode == "warp-log1p":
+ # Warp input points before normalizing them
+ log_dis = torch.log1p(all_dis)
+ warp_factor = log_dis / all_dis.clip(min=1e-8)
+ for i, pts in enumerate(pts_list):
+ H, W = pts.shape[1:-1]
+ pts_list[i] = pts * warp_factor[:, i * (H * W) : (i + 1) * (H * W)].view(
+ -1, H, W, 1
+ )
+ all_dis = log_dis
+ else:
+ raise ValueError(f"bad {dis_mode=}")
+
+ # Compute normalization factor
+ norm_factor = all_dis.sum(dim=1) / (sum(nnz_list) + 1e-8)
+ norm_factor = norm_factor.clip(min=1e-8)
+ while norm_factor.ndim < pts_list[0].ndim:
+ norm_factor.unsqueeze_(-1)
+
+ # Normalize points
+ res = [pts / norm_factor for pts in pts_list]
+ if ret_factor:
+ res.append(norm_factor)
+
+ return res
+
+
+def apply_log_to_norm(input_data):
+ """
+ Normalize the input data and apply a logarithmic transformation based on the normalization factor.
+
+ Args:
+ input_data (torch.Tensor): The input tensor to be normalized and transformed.
+
+ Returns:
+ torch.Tensor: The transformed tensor after normalization and logarithmic scaling.
+ """
+ org_d = input_data.norm(dim=-1, keepdim=True)
+ input_data = input_data / org_d.clip(min=1e-8)
+ input_data = input_data * torch.log1p(org_d)
+ return input_data
+
+
+def angle_diff_vec3(v1, v2, eps=1e-12):
+ """
+ Compute angle difference between 3D vectors.
+
+ Args:
+ v1: torch.Tensor of shape (..., 3)
+ v2: torch.Tensor of shape (..., 3)
+ eps: Small epsilon value for numerical stability
+
+ Returns:
+ torch.Tensor: Angle differences in radians
+ """
+ cross_norm = torch.cross(v1, v2, dim=-1).norm(dim=-1) + eps
+ dot_prod = (v1 * v2).sum(dim=-1)
+ return torch.atan2(cross_norm, dot_prod)
+
+
+def angle_diff_vec3_numpy(v1: np.ndarray, v2: np.ndarray, eps: float = 1e-12):
+ """
+ Compute angle difference between 3D vectors using NumPy.
+
+ Args:
+ v1 (np.ndarray): First vector of shape (..., 3)
+ v2 (np.ndarray): Second vector of shape (..., 3)
+ eps (float, optional): Small epsilon value for numerical stability. Defaults to 1e-12.
+
+ Returns:
+ np.ndarray: Angle differences in radians
+ """
+ return np.arctan2(
+ np.linalg.norm(np.cross(v1, v2, axis=-1), axis=-1) + eps, (v1 * v2).sum(axis=-1)
+ )
+
+
+@no_warnings(category=RuntimeWarning)
+def points_to_normals(
+ point: np.ndarray, mask: np.ndarray = None, edge_threshold: float = None
+) -> np.ndarray:
+ """
+ Calculate normal map from point map. Value range is [-1, 1].
+
+ Args:
+ point (np.ndarray): shape (height, width, 3), point map
+ mask (optional, np.ndarray): shape (height, width), dtype=bool. Mask of valid depth pixels. Defaults to None.
+ edge_threshold (optional, float): threshold for the angle (in degrees) between the normal and the view direction. Defaults to None.
+
+ Returns:
+ normal (np.ndarray): shape (height, width, 3), normal map.
+ """
+ height, width = point.shape[-3:-1]
+ has_mask = mask is not None
+
+ if mask is None:
+ mask = np.ones_like(point[..., 0], dtype=bool)
+ mask_pad = np.zeros((height + 2, width + 2), dtype=bool)
+ mask_pad[1:-1, 1:-1] = mask
+ mask = mask_pad
+
+ pts = np.zeros((height + 2, width + 2, 3), dtype=point.dtype)
+ pts[1:-1, 1:-1, :] = point
+ up = pts[:-2, 1:-1, :] - pts[1:-1, 1:-1, :]
+ left = pts[1:-1, :-2, :] - pts[1:-1, 1:-1, :]
+ down = pts[2:, 1:-1, :] - pts[1:-1, 1:-1, :]
+ right = pts[1:-1, 2:, :] - pts[1:-1, 1:-1, :]
+ normal = np.stack(
+ [
+ np.cross(up, left, axis=-1),
+ np.cross(left, down, axis=-1),
+ np.cross(down, right, axis=-1),
+ np.cross(right, up, axis=-1),
+ ]
+ )
+ normal = normal / (np.linalg.norm(normal, axis=-1, keepdims=True) + 1e-12)
+
+ valid = (
+ np.stack(
+ [
+ mask[:-2, 1:-1] & mask[1:-1, :-2],
+ mask[1:-1, :-2] & mask[2:, 1:-1],
+ mask[2:, 1:-1] & mask[1:-1, 2:],
+ mask[1:-1, 2:] & mask[:-2, 1:-1],
+ ]
+ )
+ & mask[None, 1:-1, 1:-1]
+ )
+ if edge_threshold is not None:
+ view_angle = angle_diff_vec3_numpy(pts[None, 1:-1, 1:-1, :], normal)
+ view_angle = np.minimum(view_angle, np.pi - view_angle)
+ valid = valid & (view_angle < np.deg2rad(edge_threshold))
+
+ normal = (normal * valid[..., None]).sum(axis=0)
+ normal = normal / (np.linalg.norm(normal, axis=-1, keepdims=True) + 1e-12)
+
+ if has_mask:
+ normal_mask = valid.any(axis=0)
+ normal = np.where(normal_mask[..., None], normal, 0)
+ return normal, normal_mask
+ else:
+ return normal
+
+
+def sliding_window_1d(x: np.ndarray, window_size: int, stride: int, axis: int = -1):
+ """
+ Create a sliding window view of the input array along a specified axis.
+
+ This function creates a memory-efficient view of the input array with sliding windows
+ of the specified size and stride. The window dimension is appended to the end of the
+ output array's shape. This is useful for operations like convolution, pooling, or
+ any analysis that requires examining local neighborhoods in the data.
+
+ Args:
+ x (np.ndarray): Input array with shape (..., axis_size, ...)
+ window_size (int): Size of the sliding window
+ stride (int): Stride of the sliding window (step size between consecutive windows)
+ axis (int, optional): Axis to perform sliding window over. Defaults to -1 (last axis)
+
+ Returns:
+ np.ndarray: View of the input array with shape (..., n_windows, ..., window_size),
+ where n_windows = (axis_size - window_size + 1) // stride
+
+ Raises:
+ AssertionError: If window_size is larger than the size of the specified axis
+
+ Example:
+ >>> x = np.array([1, 2, 3, 4, 5, 6])
+ >>> sliding_window_1d(x, window_size=3, stride=2)
+ array([[1, 2, 3],
+ [3, 4, 5]])
+ """
+ assert x.shape[axis] >= window_size, (
+ f"kernel_size ({window_size}) is larger than axis_size ({x.shape[axis]})"
+ )
+ axis = axis % x.ndim
+ shape = (
+ *x.shape[:axis],
+ (x.shape[axis] - window_size + 1) // stride,
+ *x.shape[axis + 1 :],
+ window_size,
+ )
+ strides = (
+ *x.strides[:axis],
+ stride * x.strides[axis],
+ *x.strides[axis + 1 :],
+ x.strides[axis],
+ )
+ x_sliding = np.lib.stride_tricks.as_strided(x, shape=shape, strides=strides)
+ return x_sliding
+
+
+def sliding_window_nd(
+ x: np.ndarray,
+ window_size: Tuple[int, ...],
+ stride: Tuple[int, ...],
+ axis: Tuple[int, ...],
+) -> np.ndarray:
+ """
+ Create sliding windows along multiple dimensions of the input array.
+
+ This function applies sliding_window_1d sequentially along multiple axes to create
+ N-dimensional sliding windows. This is useful for operations that need to examine
+ local neighborhoods in multiple dimensions simultaneously.
+
+ Args:
+ x (np.ndarray): Input array
+ window_size (Tuple[int, ...]): Size of the sliding window for each axis
+ stride (Tuple[int, ...]): Stride of the sliding window for each axis
+ axis (Tuple[int, ...]): Axes to perform sliding window over
+
+ Returns:
+ np.ndarray: Array with sliding windows along the specified dimensions.
+ The window dimensions are appended to the end of the shape.
+
+ Note:
+ The length of window_size, stride, and axis tuples must be equal.
+
+ Example:
+ >>> x = np.random.rand(10, 10)
+ >>> windows = sliding_window_nd(x, window_size=(3, 3), stride=(2, 2), axis=(-2, -1))
+ >>> # Creates 3x3 sliding windows with stride 2 in both dimensions
+ """
+ axis = [axis[i] % x.ndim for i in range(len(axis))]
+ for i in range(len(axis)):
+ x = sliding_window_1d(x, window_size[i], stride[i], axis[i])
+ return x
+
+
+def sliding_window_2d(
+ x: np.ndarray,
+ window_size: Union[int, Tuple[int, int]],
+ stride: Union[int, Tuple[int, int]],
+ axis: Tuple[int, int] = (-2, -1),
+) -> np.ndarray:
+ """
+ Create 2D sliding windows over the input array.
+
+ Convenience function for creating 2D sliding windows, commonly used for image
+ processing operations like convolution, pooling, or patch extraction.
+
+ Args:
+ x (np.ndarray): Input array
+ window_size (Union[int, Tuple[int, int]]): Size of the 2D sliding window.
+ If int, same size is used for both dimensions.
+ stride (Union[int, Tuple[int, int]]): Stride of the 2D sliding window.
+ If int, same stride is used for both dimensions.
+ axis (Tuple[int, int], optional): Two axes to perform sliding window over.
+ Defaults to (-2, -1) (last two dimensions).
+
+ Returns:
+ np.ndarray: Array with 2D sliding windows. The window dimensions (height, width)
+ are appended to the end of the shape.
+
+ Example:
+ >>> image = np.random.rand(100, 100)
+ >>> patches = sliding_window_2d(image, window_size=8, stride=4)
+ >>> # Creates 8x8 patches with stride 4 from the image
+ """
+ if isinstance(window_size, int):
+ window_size = (window_size, window_size)
+ if isinstance(stride, int):
+ stride = (stride, stride)
+ return sliding_window_nd(x, window_size, stride, axis)
+
+
+def max_pool_1d(
+ x: np.ndarray, kernel_size: int, stride: int, padding: int = 0, axis: int = -1
+):
+ """
+ Perform 1D max pooling on the input array.
+
+ Max pooling reduces the dimensionality of the input by taking the maximum value
+ within each sliding window. This is commonly used in neural networks and signal
+ processing for downsampling and feature extraction.
+
+ Args:
+ x (np.ndarray): Input array
+ kernel_size (int): Size of the pooling kernel
+ stride (int): Stride of the pooling operation
+ padding (int, optional): Amount of padding to add on both sides. Defaults to 0.
+ axis (int, optional): Axis to perform max pooling over. Defaults to -1.
+
+ Returns:
+ np.ndarray: Max pooled array with reduced size along the specified axis
+
+ Note:
+ - For floating point arrays, padding is done with np.nan values
+ - For integer arrays, padding is done with the minimum value of the dtype
+ - np.nanmax is used to handle NaN values in the computation
+
+ Example:
+ >>> x = np.array([1, 3, 2, 4, 5, 1, 2])
+ >>> max_pool_1d(x, kernel_size=3, stride=2)
+ array([3, 5, 2])
+ """
+ axis = axis % x.ndim
+ if padding > 0:
+ fill_value = np.nan if x.dtype.kind == "f" else np.iinfo(x.dtype).min
+ padding_arr = np.full(
+ (*x.shape[:axis], padding, *x.shape[axis + 1 :]),
+ fill_value=fill_value,
+ dtype=x.dtype,
+ )
+ x = np.concatenate([padding_arr, x, padding_arr], axis=axis)
+ a_sliding = sliding_window_1d(x, kernel_size, stride, axis)
+ max_pool = np.nanmax(a_sliding, axis=-1)
+ return max_pool
+
+
+def max_pool_nd(
+ x: np.ndarray,
+ kernel_size: Tuple[int, ...],
+ stride: Tuple[int, ...],
+ padding: Tuple[int, ...],
+ axis: Tuple[int, ...],
+) -> np.ndarray:
+ """
+ Perform N-dimensional max pooling on the input array.
+
+ This function applies max_pool_1d sequentially along multiple axes to perform
+ multi-dimensional max pooling. This is useful for downsampling multi-dimensional
+ data while preserving the most important features.
+
+ Args:
+ x (np.ndarray): Input array
+ kernel_size (Tuple[int, ...]): Size of the pooling kernel for each axis
+ stride (Tuple[int, ...]): Stride of the pooling operation for each axis
+ padding (Tuple[int, ...]): Amount of padding for each axis
+ axis (Tuple[int, ...]): Axes to perform max pooling over
+
+ Returns:
+ np.ndarray: Max pooled array with reduced size along the specified axes
+
+ Note:
+ The length of kernel_size, stride, padding, and axis tuples must be equal.
+ Max pooling is applied sequentially along each axis in the order specified.
+
+ Example:
+ >>> x = np.random.rand(10, 10, 10)
+ >>> pooled = max_pool_nd(x, kernel_size=(2, 2, 2), stride=(2, 2, 2),
+ ... padding=(0, 0, 0), axis=(-3, -2, -1))
+ >>> # Reduces each dimension by half with 2x2x2 max pooling
+ """
+ for i in range(len(axis)):
+ x = max_pool_1d(x, kernel_size[i], stride[i], padding[i], axis[i])
+ return x
+
+
+def max_pool_2d(
+ x: np.ndarray,
+ kernel_size: Union[int, Tuple[int, int]],
+ stride: Union[int, Tuple[int, int]],
+ padding: Union[int, Tuple[int, int]],
+ axis: Tuple[int, int] = (-2, -1),
+):
+ """
+ Perform 2D max pooling on the input array.
+
+ Convenience function for 2D max pooling, commonly used in computer vision
+ and image processing for downsampling images while preserving important features.
+
+ Args:
+ x (np.ndarray): Input array
+ kernel_size (Union[int, Tuple[int, int]]): Size of the 2D pooling kernel.
+ If int, same size is used for both dimensions.
+ stride (Union[int, Tuple[int, int]]): Stride of the 2D pooling operation.
+ If int, same stride is used for both dimensions.
+ padding (Union[int, Tuple[int, int]]): Amount of padding for both dimensions.
+ If int, same padding is used for both dimensions.
+ axis (Tuple[int, int], optional): Two axes to perform max pooling over.
+ Defaults to (-2, -1) (last two dimensions).
+
+ Returns:
+ np.ndarray: 2D max pooled array with reduced size along the specified axes
+
+ Example:
+ >>> image = np.random.rand(64, 64)
+ >>> pooled = max_pool_2d(image, kernel_size=2, stride=2, padding=0)
+ >>> # Reduces image size from 64x64 to 32x32 with 2x2 max pooling
+ """
+ if isinstance(kernel_size, Number):
+ kernel_size = (kernel_size, kernel_size)
+ if isinstance(stride, Number):
+ stride = (stride, stride)
+ if isinstance(padding, Number):
+ padding = (padding, padding)
+ axis = tuple(axis)
+ return max_pool_nd(x, kernel_size, stride, padding, axis)
+
+
+@no_warnings(category=RuntimeWarning)
+def depth_edge(
+ depth: np.ndarray,
+ atol: float = None,
+ rtol: float = None,
+ kernel_size: int = 3,
+ mask: np.ndarray = None,
+) -> np.ndarray:
+ """
+ Compute the edge mask from depth map. The edge is defined as the pixels whose neighbors have large difference in depth.
+
+ Args:
+ depth (np.ndarray): shape (..., height, width), linear depth map
+ atol (float): absolute tolerance
+ rtol (float): relative tolerance
+
+ Returns:
+ edge (np.ndarray): shape (..., height, width) of dtype torch.bool
+ """
+ if mask is None:
+ diff = max_pool_2d(
+ depth, kernel_size, stride=1, padding=kernel_size // 2
+ ) + max_pool_2d(-depth, kernel_size, stride=1, padding=kernel_size // 2)
+ else:
+ diff = max_pool_2d(
+ np.where(mask, depth, -np.inf),
+ kernel_size,
+ stride=1,
+ padding=kernel_size // 2,
+ ) + max_pool_2d(
+ np.where(mask, -depth, -np.inf),
+ kernel_size,
+ stride=1,
+ padding=kernel_size // 2,
+ )
+
+ edge = np.zeros_like(depth, dtype=bool)
+ if atol is not None:
+ edge |= diff > atol
+
+ if rtol is not None:
+ edge |= diff / depth > rtol
+ return edge
+
+
+def depth_aliasing(
+ depth: np.ndarray,
+ atol: float = None,
+ rtol: float = None,
+ kernel_size: int = 3,
+ mask: np.ndarray = None,
+) -> np.ndarray:
+ """
+ Compute the map that indicates the aliasing of x depth map. The aliasing is defined as the pixels which neither close to the maximum nor the minimum of its neighbors.
+ Args:
+ depth (np.ndarray): shape (..., height, width), linear depth map
+ atol (float): absolute tolerance
+ rtol (float): relative tolerance
+
+ Returns:
+ edge (np.ndarray): shape (..., height, width) of dtype torch.bool
+ """
+ if mask is None:
+ diff_max = (
+ max_pool_2d(depth, kernel_size, stride=1, padding=kernel_size // 2) - depth
+ )
+ diff_min = (
+ max_pool_2d(-depth, kernel_size, stride=1, padding=kernel_size // 2) + depth
+ )
+ else:
+ diff_max = (
+ max_pool_2d(
+ np.where(mask, depth, -np.inf),
+ kernel_size,
+ stride=1,
+ padding=kernel_size // 2,
+ )
+ - depth
+ )
+ diff_min = (
+ max_pool_2d(
+ np.where(mask, -depth, -np.inf),
+ kernel_size,
+ stride=1,
+ padding=kernel_size // 2,
+ )
+ + depth
+ )
+ diff = np.minimum(diff_max, diff_min)
+
+ edge = np.zeros_like(depth, dtype=bool)
+ if atol is not None:
+ edge |= diff > atol
+ if rtol is not None:
+ edge |= diff / depth > rtol
+ return edge
+
+
+@no_warnings(category=RuntimeWarning)
+def normals_edge(
+ normals: np.ndarray, tol: float, kernel_size: int = 3, mask: np.ndarray = None
+) -> np.ndarray:
+ """
+ Compute the edge mask from normal map.
+
+ Args:
+ normal (np.ndarray): shape (..., height, width, 3), normal map
+ tol (float): tolerance in degrees
+
+ Returns:
+ edge (np.ndarray): shape (..., height, width) of dtype torch.bool
+ """
+ assert normals.ndim >= 3 and normals.shape[-1] == 3, (
+ "normal should be of shape (..., height, width, 3)"
+ )
+ normals = normals / (np.linalg.norm(normals, axis=-1, keepdims=True) + 1e-12)
+
+ padding = kernel_size // 2
+ normals_window = sliding_window_2d(
+ np.pad(
+ normals,
+ (
+ *([(0, 0)] * (normals.ndim - 3)),
+ (padding, padding),
+ (padding, padding),
+ (0, 0),
+ ),
+ mode="edge",
+ ),
+ window_size=kernel_size,
+ stride=1,
+ axis=(-3, -2),
+ )
+ if mask is None:
+ angle_diff = np.arccos(
+ (normals[..., None, None] * normals_window).sum(axis=-3)
+ ).max(axis=(-2, -1))
+ else:
+ mask_window = sliding_window_2d(
+ np.pad(
+ mask,
+ (*([(0, 0)] * (mask.ndim - 3)), (padding, padding), (padding, padding)),
+ mode="edge",
+ ),
+ window_size=kernel_size,
+ stride=1,
+ axis=(-3, -2),
+ )
+ angle_diff = np.where(
+ mask_window,
+ np.arccos((normals[..., None, None] * normals_window).sum(axis=-3)),
+ 0,
+ ).max(axis=(-2, -1))
+
+ angle_diff = max_pool_2d(
+ angle_diff, kernel_size, stride=1, padding=kernel_size // 2
+ )
+ edge = angle_diff > np.deg2rad(tol)
+ return edge
diff --git a/mapanything/utils/image.py b/mapanything/utils/image.py
new file mode 100644
index 0000000000000000000000000000000000000000..b5483edbd5320f292749051be0752f28882ebdbe
--- /dev/null
+++ b/mapanything/utils/image.py
@@ -0,0 +1,325 @@
+"""
+Utility functions for loading, converting, and manipulating images.
+
+This module provides functions for:
+- Converting between different image formats and representations
+- Resizing and cropping images to specific resolutions
+- Loading and normalizing images for model input
+- Handling various image file formats including HEIF/HEIC when available
+"""
+
+import os
+
+import numpy as np
+import PIL.Image
+import torch
+import torchvision.transforms as tvf
+from PIL.ImageOps import exif_transpose
+
+os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
+import cv2
+
+try:
+ from pillow_heif import register_heif_opener
+
+ register_heif_opener()
+ heif_support_enabled = True
+except ImportError:
+ heif_support_enabled = False
+
+from mapanything.utils.cropping import crop_resize_if_necessary
+from uniception.models.encoders.image_normalizations import IMAGE_NORMALIZATION_DICT
+
+# Fixed resolution mappings with precomputed aspect ratios as keys
+RESOLUTION_MAPPINGS = {
+ 518: {
+ 1.000: (518, 518), # 1:1
+ 1.321: (518, 392), # 4:3
+ 1.542: (518, 336), # 3:2
+ 1.762: (518, 294), # 16:9
+ 2.056: (518, 252), # 2:1
+ 3.083: (518, 168), # 3.2:1
+ 0.757: (392, 518), # 3:4
+ 0.649: (336, 518), # 2:3
+ 0.567: (294, 518), # 9:16
+ 0.486: (252, 518), # 1:2
+ },
+ 512: {
+ 1.000: (512, 512), # 1:1
+ 1.333: (512, 384), # 4:3
+ 1.524: (512, 336), # 3:2
+ 1.778: (512, 288), # 16:9
+ 2.000: (512, 256), # 2:1
+ 3.200: (512, 160), # 3.2:1
+ 0.750: (384, 512), # 3:4
+ 0.656: (336, 512), # 2:3
+ 0.562: (288, 512), # 9:16
+ 0.500: (256, 512), # 1:2
+ },
+}
+
+# Precomputed sorted aspect ratio keys for efficient lookup
+ASPECT_RATIO_KEYS = {
+ 518: sorted(RESOLUTION_MAPPINGS[518].keys()),
+ 512: sorted(RESOLUTION_MAPPINGS[512].keys()),
+}
+
+
+def find_closest_aspect_ratio(aspect_ratio, resolution_set):
+ """
+ Find the closest aspect ratio from the resolution mappings using efficient key lookup.
+
+ Args:
+ aspect_ratio (float): Target aspect ratio
+ resolution_set (int): Resolution set to use (518 or 512)
+
+ Returns:
+ tuple: (target_width, target_height) from the resolution mapping
+ """
+ aspect_keys = ASPECT_RATIO_KEYS[resolution_set]
+
+ # Find the closest aspect ratio key using binary search approach
+ closest_key = min(aspect_keys, key=lambda x: abs(x - aspect_ratio))
+
+ return RESOLUTION_MAPPINGS[resolution_set][closest_key]
+
+
+def rgb(ftensor, norm_type, true_shape=None):
+ """
+ Convert normalized image tensor to RGB image for visualization.
+
+ Args:
+ ftensor (torch.Tensor or numpy.ndarray or list): Image tensor or list of image tensors
+ norm_type (str): Normalization type, see UniCeption IMAGE_NORMALIZATION_DICT keys or use "identity"
+ true_shape (tuple, optional): If provided, the image will be cropped to this shape (H, W)
+
+ Returns:
+ numpy.ndarray: RGB image with values in range [0, 1]
+ """
+ if isinstance(ftensor, list):
+ return [rgb(x, norm_type, true_shape=true_shape) for x in ftensor]
+ if isinstance(ftensor, torch.Tensor):
+ ftensor = ftensor.detach().cpu().numpy() # H,W,3
+ if ftensor.ndim == 3 and ftensor.shape[0] == 3:
+ ftensor = ftensor.transpose(1, 2, 0)
+ elif ftensor.ndim == 4 and ftensor.shape[1] == 3:
+ ftensor = ftensor.transpose(0, 2, 3, 1)
+ if true_shape is not None:
+ H, W = true_shape
+ ftensor = ftensor[:H, :W]
+ if ftensor.dtype == np.uint8:
+ img = np.float32(ftensor) / 255
+ else:
+ if norm_type in IMAGE_NORMALIZATION_DICT.keys():
+ img_norm = IMAGE_NORMALIZATION_DICT[norm_type]
+ mean = img_norm.mean.numpy()
+ std = img_norm.std.numpy()
+ elif norm_type == "identity":
+ mean = 0.0
+ std = 1.0
+ else:
+ raise ValueError(
+ f"Unknown image normalization type: {norm_type}. Available types: identity or {IMAGE_NORMALIZATION_DICT.keys()}"
+ )
+ img = ftensor * std + mean
+ return img.clip(min=0, max=1)
+
+
+def load_images(
+ folder_or_list,
+ resize_mode="fixed_mapping",
+ size=None,
+ norm_type="dinov2",
+ patch_size=14,
+ verbose=False,
+ bayer_format=False,
+ resolution_set=518,
+ stride=1,
+):
+ """
+ Open and convert all images in a list or folder to proper input format for model
+
+ Args:
+ folder_or_list (str or list): Path to folder or list of image paths.
+ resize_mode (str): Resize mode - "fixed_mapping", "longest_side", "square", or "fixed_size". Defaults to "fixed_mapping".
+ size (int or tuple, optional): Required for "longest_side", "square", and "fixed_size" modes.
+ - For "longest_side" and "square": int value for resize dimension
+ - For "fixed_size": tuple of (width, height)
+ norm_type (str, optional): Image normalization type. See UniCeption IMAGE_NORMALIZATION_DICT keys. Defaults to "dinov2".
+ patch_size (int, optional): Patch size for image processing. Defaults to 14.
+ verbose (bool, optional): If True, print progress messages. Defaults to False.
+ bayer_format (bool, optional): If True, read images in Bayer format. Defaults to False.
+ resolution_set (int, optional): Resolution set to use for "fixed_mapping" mode (518 or 512). Defaults to 518.
+ stride (int, optional): Load every nth image from the input. stride=1 loads all images, stride=2 loads every 2nd image, etc. Defaults to 1.
+
+ Returns:
+ list: List of dictionaries containing image data and metadata
+ """
+ # Validate resize_mode and size parameter requirements
+ valid_resize_modes = ["fixed_mapping", "longest_side", "square", "fixed_size"]
+ if resize_mode not in valid_resize_modes:
+ raise ValueError(
+ f"Resize_mode must be one of {valid_resize_modes}, got '{resize_mode}'"
+ )
+
+ if resize_mode in ["longest_side", "square", "fixed_size"] and size is None:
+ raise ValueError(f"Size parameter is required for resize_mode='{resize_mode}'")
+
+ # Validate size type based on resize mode
+ if resize_mode in ["longest_side", "square"]:
+ if not isinstance(size, int):
+ raise ValueError(
+ f"Size must be an int for resize_mode='{resize_mode}', got {type(size)}"
+ )
+ elif resize_mode == "fixed_size":
+ if not isinstance(size, (tuple, list)) or len(size) != 2:
+ raise ValueError(
+ f"Size must be a tuple/list of (width, height) for resize_mode='fixed_size', got {size}"
+ )
+ if not all(isinstance(x, int) for x in size):
+ raise ValueError(
+ f"Size values must be integers for resize_mode='fixed_size', got {size}"
+ )
+
+ # Get list of image paths
+ if isinstance(folder_or_list, str):
+ # If folder_or_list is a string, assume it's a path to a folder
+ if verbose:
+ print(f"Loading images from {folder_or_list}")
+ root, folder_content = folder_or_list, sorted(os.listdir(folder_or_list))
+ elif isinstance(folder_or_list, list):
+ # If folder_or_list is a list, assume it's a list of image paths
+ if verbose:
+ print(f"Loading a list of {len(folder_or_list)} images")
+ root, folder_content = "", folder_or_list
+ else:
+ # If folder_or_list is neither a string nor a list, raise an error
+ raise ValueError(f"Bad {folder_or_list=} ({type(folder_or_list)})")
+
+ # Define supported image extensions
+ supported_images_extensions = [".jpg", ".jpeg", ".png"]
+ if heif_support_enabled:
+ supported_images_extensions += [".heic", ".heif"]
+ supported_images_extensions = tuple(supported_images_extensions)
+
+ # First pass: Load all images and collect aspect ratios
+ loaded_images = []
+ aspect_ratios = []
+ for i, path in enumerate(folder_content):
+ # Skip images based on stride
+ if i % stride != 0:
+ continue
+
+ # Check if the file has a supported image extension
+ if not path.lower().endswith(supported_images_extensions):
+ continue
+
+ try:
+ if bayer_format:
+ # If bayer_format is True, read the image in Bayer format
+ color_bayer = cv2.imread(os.path.join(root, path), cv2.IMREAD_UNCHANGED)
+ color = cv2.cvtColor(color_bayer, cv2.COLOR_BAYER_RG2BGR)
+ img = PIL.Image.fromarray(color)
+ img = exif_transpose(img).convert("RGB")
+ else:
+ # Otherwise, read the image normally
+ img = exif_transpose(PIL.Image.open(os.path.join(root, path))).convert(
+ "RGB"
+ )
+
+ W1, H1 = img.size
+ aspect_ratios.append(W1 / H1)
+ loaded_images.append((path, img, W1, H1))
+
+ except Exception as e:
+ if verbose:
+ print(f"Warning: Could not load {path}: {e}")
+ continue
+
+ # Check if any images were loaded
+ if not loaded_images:
+ raise ValueError("No valid images found")
+
+ # Calculate average aspect ratio and determine target size
+ average_aspect_ratio = sum(aspect_ratios) / len(aspect_ratios)
+ if verbose:
+ print(
+ f"Calculated average aspect ratio: {average_aspect_ratio:.3f} from {len(aspect_ratios)} images"
+ )
+
+ # Determine target size for all images based on resize mode
+ if resize_mode == "fixed_mapping":
+ # Resolution mappings are already compatible with their respective patch sizes
+ # 518 mappings are divisible by 14, 512 mappings are divisible by 16
+ target_width, target_height = find_closest_aspect_ratio(
+ average_aspect_ratio, resolution_set
+ )
+ target_size = (target_width, target_height)
+ elif resize_mode == "square":
+ target_size = (
+ round((size // patch_size)) * patch_size,
+ round((size // patch_size)) * patch_size,
+ )
+ elif resize_mode == "longest_side":
+ # Use average aspect ratio to determine size for all images
+ # Longest side should be the input size
+ if average_aspect_ratio >= 1: # Landscape or square
+ # Width is the longest side
+ target_size = (
+ size,
+ round((size // patch_size) / average_aspect_ratio) * patch_size,
+ )
+ else: # Portrait
+ # Height is the longest side
+ target_size = (
+ round((size // patch_size) * average_aspect_ratio) * patch_size,
+ size,
+ )
+ elif resize_mode == "fixed_size":
+ # Use exact size provided, aligned to patch_size
+ target_size = (
+ (size[0] // patch_size) * patch_size,
+ (size[1] // patch_size) * patch_size,
+ )
+
+ if verbose:
+ print(
+ f"Using target resolution {target_size[0]}x{target_size[1]} (W x H) for all images"
+ )
+
+ # Second pass: Resize all images to the same target size
+ imgs = []
+ for path, img, W1, H1 in loaded_images:
+ # Resize and crop the image to the target size
+ img = crop_resize_if_necessary(img, resolution=target_size)[0]
+
+ # Normalize image and add it to the list
+ W2, H2 = img.size
+ if verbose:
+ print(f" - Adding {path} with resolution {W1}x{H1} --> {W2}x{H2}")
+
+ if norm_type in IMAGE_NORMALIZATION_DICT.keys():
+ img_norm = IMAGE_NORMALIZATION_DICT[norm_type]
+ ImgNorm = tvf.Compose(
+ [tvf.ToTensor(), tvf.Normalize(mean=img_norm.mean, std=img_norm.std)]
+ )
+ else:
+ raise ValueError(
+ f"Unknown image normalization type: {norm_type}. Available options: {list(IMAGE_NORMALIZATION_DICT.keys())}"
+ )
+
+ imgs.append(
+ dict(
+ img=ImgNorm(img)[None],
+ true_shape=np.int32([img.size[::-1]]),
+ idx=len(imgs),
+ instance=str(len(imgs)),
+ data_norm_type=[norm_type],
+ )
+ )
+
+ assert imgs, "No images foud at " + root
+ if verbose:
+ print(f" (Found {len(imgs)} images)")
+
+ return imgs
diff --git a/mapanything/utils/inference.py b/mapanything/utils/inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..39e59cda0f9b89ffe76cb23d258e26e6a3a410f7
--- /dev/null
+++ b/mapanything/utils/inference.py
@@ -0,0 +1,86 @@
+"""
+Inference utilities.
+"""
+
+import warnings
+
+import torch
+
+
+def loss_of_one_batch_multi_view(
+ batch,
+ model,
+ criterion,
+ device,
+ use_amp=False,
+ amp_dtype="bf16",
+ ret=None,
+ ignore_keys=None,
+):
+ """
+ Calculate loss for a batch with multiple views.
+
+ Args:
+ batch (list): List of view dictionaries containing input data.
+ model (torch.nn.Module): Model to run inference with.
+ criterion (callable, optional): Loss function to compute the loss.
+ device (torch.device): Device to run the computation on.
+ use_amp (bool, optional): Whether to use automatic mixed precision. Defaults to False.
+ amp_dtype (str, optional): Floating point type to use for automatic mixed precision. Options: ["fp32", "fp16", "bf16"]. Defaults to "bf16".
+ ret (str, optional): If provided, return only the specified key from the result dictionary.
+ ignore_keys (set, optional): Set of keys to ignore when moving tensors to device.
+ Defaults to {"dataset", "label", "instance",
+ "idx", "true_shape", "rng", "data_norm_type"}.
+
+ Returns:
+ dict or Any: If ret is None, returns a dictionary containing views, predictions, and loss.
+ Otherwise, returns the value associated with the ret key.
+ """
+ # Move necessary tensors to device
+ if ignore_keys is None:
+ ignore_keys = set(
+ [
+ "depthmap",
+ "dataset",
+ "label",
+ "instance",
+ "idx",
+ "true_shape",
+ "rng",
+ "data_norm_type",
+ ]
+ )
+ for view in batch:
+ for name in view.keys():
+ if name in ignore_keys:
+ continue
+ view[name] = view[name].to(device, non_blocking=True)
+
+ # Determine the mixed precision floating point type
+ if use_amp:
+ if amp_dtype == "fp16":
+ amp_dtype = torch.float16
+ elif amp_dtype == "bf16":
+ if torch.cuda.is_bf16_supported():
+ amp_dtype = torch.bfloat16
+ else:
+ warnings.warn(
+ "bf16 is not supported on this device. Using fp16 instead."
+ )
+ amp_dtype = torch.float16
+ elif amp_dtype == "fp32":
+ amp_dtype = torch.float32
+ else:
+ amp_dtype = torch.float32
+
+ # Run model and compute loss
+ with torch.autocast("cuda", enabled=bool(use_amp), dtype=amp_dtype):
+ preds = model(batch)
+ with torch.autocast("cuda", enabled=False):
+ loss = criterion(batch, preds) if criterion is not None else None
+
+ result = {f"view{i + 1}": view for i, view in enumerate(batch)}
+ result.update({f"pred{i + 1}": pred for i, pred in enumerate(preds)})
+ result["loss"] = loss
+
+ return result[ret] if ret else result
diff --git a/mapanything/utils/metrics.py b/mapanything/utils/metrics.py
new file mode 100644
index 0000000000000000000000000000000000000000..29d28b5a6856daface74ddebe50d51ca9548ee51
--- /dev/null
+++ b/mapanything/utils/metrics.py
@@ -0,0 +1,504 @@
+"""
+Utils for Metrics
+Source for Pose AUC Metrics: VGGT
+"""
+
+import math
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+
+
+def l2_distance_of_unit_quats_to_angular_error(l2_distance):
+ """
+ Converts a given L2 distance (for unit quaternions) to the angular error in degrees.
+ For two quaternions differing by an angle θ the relationship is:
+ L2 distance = 2 * sin(θ/4)
+ Hence, the angular error in degrees is computed as:
+ 4 * asin(l2_distance / 2) * (180/π)
+
+ Args:
+ l2_distance: L2 distance between two unit quaternions (torch.Tensor, shape: (N,))
+ Returns:
+ angular_error_degrees: Angular error in degrees (torch.Tensor, shape: (N,))
+ """
+ angular_error_radians = 4 * torch.asin(l2_distance / 2)
+ angular_error_degrees = angular_error_radians * 180.0 / math.pi
+
+ return angular_error_degrees
+
+
+def l2_distance_of_unit_ray_directions_to_angular_error(l2_distance):
+ """
+ Converts a given L2 distance (for unit ray directions) to the angular error in degrees.
+ For two unit ray directions differing by an angle θ the relationship is:
+ L2 distance = 2 * sin(θ/2)
+ Hence, the angular error in degrees is computed as:
+ 2 * asin(l2_distance / 2) * (180/π)
+
+ Args:
+ l2_distance: L2 distance between two unit ray directions (torch.Tensor, shape: (N,))
+ Returns:
+ angular_error_degrees: Angular error in degrees (torch.Tensor, shape: (N,))
+ """
+ angular_error_radians = 2 * torch.asin(l2_distance / 2)
+ angular_error_degrees = angular_error_radians * 180.0 / math.pi
+
+ return angular_error_degrees
+
+
+def valid_mean(arr, mask, axis=None, keepdims=np._NoValue):
+ """Compute mean of elements across given dimensions of an array, considering only valid elements.
+
+ Args:
+ arr: The array to compute the mean.
+ mask: Array with numerical or boolean values for element weights or validity. For bool, False means invalid.
+ axis: Dimensions to reduce.
+ keepdims: If true, retains reduced dimensions with length 1.
+
+ Returns:
+ Mean array/scalar and a valid array/scalar that indicates where the mean could be computed successfully.
+ """
+
+ mask = mask.astype(arr.dtype) if mask.dtype == bool else mask
+ num_valid = np.sum(mask, axis=axis, keepdims=keepdims)
+ masked_arr = arr * mask
+ masked_arr_sum = np.sum(masked_arr, axis=axis, keepdims=keepdims)
+
+ with np.errstate(divide="ignore", invalid="ignore"):
+ valid_mean = masked_arr_sum / num_valid
+ is_valid = np.isfinite(valid_mean)
+ valid_mean = np.nan_to_num(valid_mean, nan=0, posinf=0, neginf=0)
+
+ return valid_mean, is_valid
+
+
+def thresh_inliers(gt, pred, thresh=1.03, mask=None, output_scaling_factor=1.0):
+ """Computes the inlier (=error within a threshold) ratio for a predicted and ground truth dense map of size H x W x C.
+
+ Args:
+ gt: Ground truth depth map as numpy array of shape HxW. Negative or 0 values are invalid and ignored.
+ pred: Predicted depth map as numpy array of shape HxW.
+ thresh: Threshold for the relative difference between the prediction and ground truth. Default: 1.03
+ mask: Array of shape HxW with boolean values to indicate validity. For bool, False means invalid. Default: None
+ output_scaling_factor: Scaling factor that is applied after computing the metrics (e.g. to get [%]). Default: 1
+
+ Returns:
+ Scalar that indicates the inlier ratio. Scalar is np.nan if the result is invalid.
+ """
+ # Compute the norms
+ gt_norm = np.linalg.norm(gt, axis=-1)
+ pred_norm = np.linalg.norm(pred, axis=-1)
+
+ gt_norm_valid = (gt_norm) > 0
+ if mask is not None:
+ combined_mask = mask & gt_norm_valid
+ else:
+ combined_mask = gt_norm_valid
+
+ with np.errstate(divide="ignore", invalid="ignore"):
+ rel_1 = np.nan_to_num(
+ gt_norm / pred_norm, nan=thresh + 1, posinf=thresh + 1, neginf=thresh + 1
+ ) # pred=0 should be an outlier
+ rel_2 = np.nan_to_num(
+ pred_norm / gt_norm, nan=0, posinf=0, neginf=0
+ ) # gt=0 is masked out anyways
+
+ max_rel = np.maximum(rel_1, rel_2)
+ inliers = ((0 < max_rel) & (max_rel < thresh)).astype(
+ np.float32
+ ) # 1 for inliers, 0 for outliers
+
+ inlier_ratio, valid = valid_mean(inliers, combined_mask)
+
+ inlier_ratio = inlier_ratio * output_scaling_factor
+ inlier_ratio = inlier_ratio if valid else np.nan
+
+ return inlier_ratio
+
+
+def m_rel_ae(gt, pred, mask=None, output_scaling_factor=1.0):
+ """Computes the mean-relative-absolute-error for a predicted and ground truth dense map of size HxWxC.
+
+ Args:
+ gt: Ground truth map as numpy array of shape H x W x C.
+ pred: Predicted map as numpy array of shape H x W x C.
+ mask: Array of shape HxW with boolean values to indicate validity. For bool, False means invalid. Default: None
+ output_scaling_factor: Scaling factor that is applied after computing the metrics (e.g. to get [%]). Default: 1
+
+ Returns:
+ Scalar that indicates the mean-relative-absolute-error. Scalar is np.nan if the result is invalid.
+ """
+ error_norm = np.linalg.norm(pred - gt, axis=-1)
+ gt_norm = np.linalg.norm(gt, axis=-1)
+
+ gt_norm_valid = (gt_norm) > 0
+ if mask is not None:
+ combined_mask = mask & gt_norm_valid
+ else:
+ combined_mask = gt_norm_valid
+
+ with np.errstate(divide="ignore", invalid="ignore"):
+ rel_ae = np.nan_to_num(error_norm / gt_norm, nan=0, posinf=0, neginf=0)
+
+ m_rel_ae, valid = valid_mean(rel_ae, combined_mask)
+
+ m_rel_ae = m_rel_ae * output_scaling_factor
+ m_rel_ae = m_rel_ae if valid else np.nan
+
+ return m_rel_ae
+
+
+def align(model, data):
+ """Align two trajectories using the method of Horn (closed-form).
+
+ Args:
+ model -- first trajectory (3xn)
+ data -- second trajectory (3xn)
+
+ Returns:
+ rot -- rotation matrix (3x3)
+ trans -- translation vector (3x1)
+ trans_error -- translational error per point (1xn)
+
+ """
+ np.set_printoptions(precision=3, suppress=True)
+ model_zerocentered = model - model.mean(1).reshape((3, -1))
+ data_zerocentered = data - data.mean(1).reshape((3, -1))
+
+ W = np.zeros((3, 3))
+ for column in range(model.shape[1]):
+ W += np.outer(model_zerocentered[:, column], data_zerocentered[:, column])
+ U, d, Vh = np.linalg.linalg.svd(W.transpose())
+ S = np.matrix(np.identity(3))
+ if np.linalg.det(U) * np.linalg.det(Vh) < 0:
+ S[2, 2] = -1
+ rot = U * S * Vh
+ trans = data.mean(1).reshape((3, -1)) - rot * model.mean(1).reshape((3, -1))
+
+ model_aligned = rot * model + trans
+ alignment_error = model_aligned - data
+
+ trans_error = np.sqrt(np.sum(np.multiply(alignment_error, alignment_error), 0)).A[0]
+
+ return rot, trans, trans_error
+
+
+def evaluate_ate(gt_traj, est_traj):
+ """
+ Input :
+ gt_traj: list of 4x4 matrices
+ est_traj: list of 4x4 matrices
+ len(gt_traj) == len(est_traj)
+ """
+ gt_traj_pts = [gt_traj[idx][:3, 3] for idx in range(len(gt_traj))]
+ est_traj_pts = [est_traj[idx][:3, 3] for idx in range(len(est_traj))]
+
+ gt_traj_pts = torch.stack(gt_traj_pts).detach().cpu().numpy().T
+ est_traj_pts = torch.stack(est_traj_pts).detach().cpu().numpy().T
+
+ _, _, trans_error = align(gt_traj_pts, est_traj_pts)
+
+ avg_trans_error = trans_error.mean()
+
+ return avg_trans_error
+
+
+def build_pair_index(N, B=1):
+ """
+ Build indices for all possible pairs of frames.
+
+ Args:
+ N: Number of frames
+ B: Batch size
+
+ Returns:
+ i1, i2: Indices for all possible pairs
+ """
+ i1_, i2_ = torch.combinations(torch.arange(N), 2, with_replacement=False).unbind(-1)
+ i1, i2 = [(i[None] + torch.arange(B)[:, None] * N).reshape(-1) for i in [i1_, i2_]]
+ return i1, i2
+
+
+def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor:
+ """
+ Returns torch.sqrt(torch.max(0, x))
+ but with a zero subgradient where x is 0.
+ """
+ ret = torch.zeros_like(x)
+ positive_mask = x > 0
+ if torch.is_grad_enabled():
+ ret[positive_mask] = torch.sqrt(x[positive_mask])
+ else:
+ ret = torch.where(positive_mask, torch.sqrt(x), ret)
+ return ret
+
+
+def mat_to_quat(matrix: torch.Tensor) -> torch.Tensor:
+ """
+ Convert rotations given as rotation matrices to quaternions.
+
+ Args:
+ matrix: Rotation matrices as tensor of shape (..., 3, 3).
+
+ Returns:
+ quaternions with real part last, as tensor of shape (..., 4).
+ Quaternion Order: XYZW or say ijkr, scalar-last
+ """
+ if matrix.size(-1) != 3 or matrix.size(-2) != 3:
+ raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")
+
+ batch_dim = matrix.shape[:-2]
+ m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(
+ matrix.reshape(batch_dim + (9,)), dim=-1
+ )
+
+ q_abs = _sqrt_positive_part(
+ torch.stack(
+ [
+ 1.0 + m00 + m11 + m22,
+ 1.0 + m00 - m11 - m22,
+ 1.0 - m00 + m11 - m22,
+ 1.0 - m00 - m11 + m22,
+ ],
+ dim=-1,
+ )
+ )
+
+ # we produce the desired quaternion multiplied by each of r, i, j, k
+ quat_by_rijk = torch.stack(
+ [
+ # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
+ # `int`.
+ torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1),
+ # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
+ # `int`.
+ torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1),
+ # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
+ # `int`.
+ torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1),
+ # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
+ # `int`.
+ torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1),
+ ],
+ dim=-2,
+ )
+
+ # We floor here at 0.1 but the exact level is not important; if q_abs is small,
+ # the candidate won't be picked.
+ flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device)
+ quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr))
+
+ # if not for numerical problems, quat_candidates[i] should be same (up to a sign),
+ # forall i; we pick the best-conditioned one (with the largest denominator)
+ out = quat_candidates[
+ F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, :
+ ].reshape(batch_dim + (4,)) # pylint: disable=not-callable
+
+ # Convert from rijk to ijkr
+ out = out[..., [1, 2, 3, 0]]
+
+ out = standardize_quaternion(out)
+
+ return out
+
+
+def standardize_quaternion(quaternions: torch.Tensor) -> torch.Tensor:
+ """
+ Convert a unit quaternion to a standard form: one in which the real
+ part is non negative.
+
+ Args:
+ quaternions: Quaternions with real part last,
+ as tensor of shape (..., 4).
+
+ Returns:
+ Standardized quaternions as tensor of shape (..., 4).
+ """
+ return torch.where(quaternions[..., 3:4] < 0, -quaternions, quaternions)
+
+
+def rotation_angle(rot_gt, rot_pred, batch_size=None, eps=1e-15):
+ """
+ Calculate rotation angle error between ground truth and predicted rotations.
+
+ Args:
+ rot_gt: Ground truth rotation matrices
+ rot_pred: Predicted rotation matrices
+ batch_size: Batch size for reshaping the result
+ eps: Small value to avoid numerical issues
+
+ Returns:
+ Rotation angle error in degrees
+ """
+ q_pred = mat_to_quat(rot_pred)
+ q_gt = mat_to_quat(rot_gt)
+
+ loss_q = (1 - (q_pred * q_gt).sum(dim=1) ** 2).clamp(min=eps)
+ err_q = torch.arccos(1 - 2 * loss_q)
+
+ rel_rangle_deg = err_q * 180 / np.pi
+
+ if batch_size is not None:
+ rel_rangle_deg = rel_rangle_deg.reshape(batch_size, -1)
+
+ return rel_rangle_deg
+
+
+def translation_angle(tvec_gt, tvec_pred, batch_size=None, ambiguity=True):
+ """
+ Calculate translation angle error between ground truth and predicted translations.
+
+ Args:
+ tvec_gt: Ground truth translation vectors
+ tvec_pred: Predicted translation vectors
+ batch_size: Batch size for reshaping the result
+ ambiguity: Whether to handle direction ambiguity
+
+ Returns:
+ Translation angle error in degrees
+ """
+ rel_tangle_deg = compare_translation_by_angle(tvec_gt, tvec_pred)
+ rel_tangle_deg = rel_tangle_deg * 180.0 / np.pi
+
+ if ambiguity:
+ rel_tangle_deg = torch.min(rel_tangle_deg, (180 - rel_tangle_deg).abs())
+
+ if batch_size is not None:
+ rel_tangle_deg = rel_tangle_deg.reshape(batch_size, -1)
+
+ return rel_tangle_deg
+
+
+def compare_translation_by_angle(t_gt, t, eps=1e-15, default_err=1e6):
+ """
+ Normalize the translation vectors and compute the angle between them.
+
+ Args:
+ t_gt: Ground truth translation vectors
+ t: Predicted translation vectors
+ eps: Small value to avoid division by zero
+ default_err: Default error value for invalid cases
+
+ Returns:
+ Angular error between translation vectors in radians
+ """
+ t_norm = torch.norm(t, dim=1, keepdim=True)
+ t = t / (t_norm + eps)
+
+ t_gt_norm = torch.norm(t_gt, dim=1, keepdim=True)
+ t_gt = t_gt / (t_gt_norm + eps)
+
+ loss_t = torch.clamp_min(1.0 - torch.sum(t * t_gt, dim=1) ** 2, eps)
+ err_t = torch.acos(torch.sqrt(1 - loss_t))
+
+ err_t[torch.isnan(err_t) | torch.isinf(err_t)] = default_err
+ return err_t
+
+
+def calculate_auc_np(r_error, t_error, max_threshold=30):
+ """
+ Calculate the Area Under the Curve (AUC) for the given error arrays using NumPy.
+
+ Args:
+ r_error: numpy array representing R error values (Degree)
+ t_error: numpy array representing T error values (Degree)
+ max_threshold: Maximum threshold value for binning the histogram
+
+ Returns:
+ AUC value and the normalized histogram
+ """
+ error_matrix = np.concatenate((r_error[:, None], t_error[:, None]), axis=1)
+ max_errors = np.max(error_matrix, axis=1)
+ bins = np.arange(max_threshold + 1)
+ histogram, _ = np.histogram(max_errors, bins=bins)
+ num_pairs = float(len(max_errors))
+ normalized_histogram = histogram.astype(float) / num_pairs
+ return np.mean(np.cumsum(normalized_histogram)), normalized_histogram
+
+
+def closed_form_inverse_se3(se3, R=None, T=None):
+ """
+ Compute the inverse of each 4x4 (or 3x4) SE3 matrix in a batch.
+
+ If `R` and `T` are provided, they must correspond to the rotation and translation
+ components of `se3`. Otherwise, they will be extracted from `se3`.
+
+ Args:
+ se3: Nx4x4 or Nx3x4 array or tensor of SE3 matrices.
+ R (optional): Nx3x3 array or tensor of rotation matrices.
+ T (optional): Nx3x1 array or tensor of translation vectors.
+
+ Returns:
+ Inverted SE3 matrices with the same type and device as `se3`.
+
+ Shapes:
+ se3: (N, 4, 4)
+ R: (N, 3, 3)
+ T: (N, 3, 1)
+ """
+ # Check if se3 is a numpy array or a torch tensor
+ is_numpy = isinstance(se3, np.ndarray)
+
+ # Validate shapes
+ if se3.shape[-2:] != (4, 4) and se3.shape[-2:] != (3, 4):
+ raise ValueError(f"se3 must be of shape (N,4,4), got {se3.shape}.")
+
+ # Extract R and T if not provided
+ if R is None:
+ R = se3[:, :3, :3] # (N,3,3)
+ if T is None:
+ T = se3[:, :3, 3:] # (N,3,1)
+
+ # Transpose R
+ if is_numpy:
+ # Compute the transpose of the rotation for NumPy
+ R_transposed = np.transpose(R, (0, 2, 1))
+ # -R^T t for NumPy
+ top_right = -np.matmul(R_transposed, T)
+ inverted_matrix = np.tile(np.eye(4), (len(R), 1, 1))
+ else:
+ R_transposed = R.transpose(1, 2) # (N,3,3)
+ top_right = -torch.bmm(R_transposed, T) # (N,3,1)
+ inverted_matrix = torch.eye(4, 4)[None].repeat(len(R), 1, 1)
+ inverted_matrix = inverted_matrix.to(R.dtype).to(R.device)
+
+ inverted_matrix[:, :3, :3] = R_transposed
+ inverted_matrix[:, :3, 3:] = top_right
+
+ return inverted_matrix
+
+
+def se3_to_relative_pose_error(pred_se3, gt_se3, num_frames):
+ """
+ Compute rotation and translation errors between predicted and ground truth poses.
+
+ Args:
+ pred_se3: Predicted SE(3) transformations
+ gt_se3: Ground truth SE(3) transformations
+ num_frames: Number of frames
+
+ Returns:
+ Rotation and translation angle errors in degrees
+ """
+ pair_idx_i1, pair_idx_i2 = build_pair_index(num_frames)
+
+ # Compute relative camera poses between pairs
+ # We use closed_form_inverse to avoid potential numerical loss by torch.inverse()
+ relative_pose_gt = closed_form_inverse_se3(gt_se3[pair_idx_i1]).bmm(
+ gt_se3[pair_idx_i2]
+ )
+ relative_pose_pred = closed_form_inverse_se3(pred_se3[pair_idx_i1]).bmm(
+ pred_se3[pair_idx_i2]
+ )
+
+ # Compute the difference in rotation and translation
+ rel_rangle_deg = rotation_angle(
+ relative_pose_gt[:, :3, :3], relative_pose_pred[:, :3, :3]
+ )
+ rel_tangle_deg = translation_angle(
+ relative_pose_gt[:, :3, 3], relative_pose_pred[:, :3, 3]
+ )
+
+ return rel_rangle_deg, rel_tangle_deg
diff --git a/mapanything/utils/misc.py b/mapanything/utils/misc.py
new file mode 100644
index 0000000000000000000000000000000000000000..c6b976fa1df7430465961bc946d1744ae5665faa
--- /dev/null
+++ b/mapanything/utils/misc.py
@@ -0,0 +1,109 @@
+"""
+Miscellaneous utility functions.
+"""
+
+import logging
+import os
+import random
+
+import numpy as np
+import torch
+
+
+class StreamToLogger:
+ """
+ A class that redirects stream writes to a logger.
+
+ This class can be used to redirect stdout or stderr to a logger
+ by implementing a file-like interface with write and flush methods.
+
+ Parameters:
+ - logger: A logger instance that will receive the log messages
+ - log_level: The logging level to use (default: logging.INFO)
+ """
+
+ def __init__(self, logger, log_level=logging.INFO):
+ self.logger = logger
+ self.log_level = log_level
+ self.linebuf = ""
+
+ def write(self, buf):
+ """
+ Write the buffer content to the logger.
+
+ Parameters:
+ - buf: The string buffer to write
+ """
+ for line in buf.rstrip().splitlines():
+ self.logger.log(self.log_level, line.rstrip())
+
+ def flush(self):
+ """
+ Flush method to comply with file-like object interface.
+ This method is required but does nothing in this implementation.
+ """
+ pass
+
+
+def seed_everything(seed: int = 42):
+ """
+ Set the `seed` value for torch and numpy seeds. Also turns on
+ deterministic execution for cudnn.
+
+ Parameters:
+ - seed: A hashable seed value
+ """
+ random.seed(seed)
+ os.environ["PYTHONHASHSEED"] = str(seed)
+ np.random.seed(seed)
+ torch.manual_seed(seed)
+ torch.backends.cudnn.deterministic = True
+ torch.backends.cudnn.benchmark = False
+ print(f"Seed set to: {seed}")
+
+
+def invalid_to_nans(arr, valid_mask, ndim=999):
+ """
+ Replace invalid values in an array with NaN values based on a validity mask.
+
+ Parameters:
+ - arr: Input array (typically a PyTorch tensor)
+ - valid_mask: Boolean mask indicating valid elements (True) and invalid elements (False)
+ - ndim: Maximum number of dimensions to keep; flattens dimensions if arr.ndim > ndim
+
+ Returns:
+ - Modified array with invalid values replaced by NaN
+ """
+ if valid_mask is not None:
+ arr = arr.clone()
+ arr[~valid_mask] = float("nan")
+ if arr.ndim > ndim:
+ arr = arr.flatten(-2 - (arr.ndim - ndim), -2)
+ return arr
+
+
+def invalid_to_zeros(arr, valid_mask, ndim=999):
+ """
+ Replace invalid values in an array with zeros based on a validity mask.
+
+ Parameters:
+ - arr: Input array (typically a PyTorch tensor)
+ - valid_mask: Boolean mask indicating valid elements (True) and invalid elements (False)
+ - ndim: Maximum number of dimensions to keep; flattens dimensions if arr.ndim > ndim
+
+ Returns:
+ - Tuple containing:
+ - Modified array with invalid values replaced by zeros
+ - nnz: Number of non-zero (valid) elements per sample in the batch
+ """
+ if valid_mask is not None:
+ arr = arr.clone()
+ arr[~valid_mask] = 0
+ nnz = valid_mask.view(len(valid_mask), -1).sum(1)
+ else:
+ nnz = (
+ arr[..., 0].numel() // len(arr) if len(arr) else 0
+ ) # Number of pixels per image
+ if arr.ndim > ndim:
+ arr = arr.flatten(-2 - (arr.ndim - ndim), -2)
+ return arr, nnz
diff --git a/mapanything/utils/parallel.py b/mapanything/utils/parallel.py
new file mode 100644
index 0000000000000000000000000000000000000000..e5c6f826291646a6c4059d3b6ae4c3d71d5e0104
--- /dev/null
+++ b/mapanything/utils/parallel.py
@@ -0,0 +1,158 @@
+"""
+Utility functions for multiprocessing
+"""
+
+import os
+from multiprocessing.dummy import Pool as ThreadPool
+
+import torch
+from torch.multiprocessing import Pool as TorchPool, set_start_method
+from tqdm import tqdm
+
+
+def cpu_count():
+ """
+ Returns the number of available CPUs for the python process
+ """
+ return len(os.sched_getaffinity(0))
+
+
+def parallel_threads(
+ function,
+ args,
+ workers=0,
+ star_args=False,
+ kw_args=False,
+ front_num=1,
+ Pool=ThreadPool,
+ ordered_res=True,
+ **tqdm_kw,
+):
+ """tqdm but with parallel execution.
+
+ Will essentially return
+ res = [ function(arg) # default
+ function(*arg) # if star_args is True
+ function(**arg) # if kw_args is True
+ for arg in args]
+
+ Note:
+ the first elements of args will not be parallelized.
+ This can be useful for debugging.
+ """
+ # Determine the number of workers
+ while workers <= 0:
+ workers += cpu_count()
+
+ # Convert args to an iterable
+ try:
+ n_args_parallel = len(args) - front_num
+ except TypeError:
+ n_args_parallel = None
+ args = iter(args)
+
+ # Sequential execution for the first few elements (useful for debugging)
+ front = []
+ while len(front) < front_num:
+ try:
+ a = next(args)
+ except StopIteration:
+ return front # end of the iterable
+ front.append(
+ function(*a) if star_args else function(**a) if kw_args else function(a)
+ )
+
+ # Parallel execution using multiprocessing.dummy
+ out = []
+ with Pool(workers) as pool:
+ if star_args:
+ map_func = pool.imap if ordered_res else pool.imap_unordered
+ futures = map_func(starcall, [(function, a) for a in args])
+ elif kw_args:
+ map_func = pool.imap if ordered_res else pool.imap_unordered
+ futures = map_func(starstarcall, [(function, a) for a in args])
+ else:
+ map_func = pool.imap if ordered_res else pool.imap_unordered
+ futures = map_func(function, args)
+ # Track progress with tqdm
+ for f in tqdm(futures, total=n_args_parallel, **tqdm_kw):
+ out.append(f)
+ return front + out
+
+
+def cuda_parallel_threads(
+ function,
+ args,
+ workers=0,
+ star_args=False,
+ kw_args=False,
+ front_num=1,
+ Pool=TorchPool,
+ ordered_res=True,
+ **tqdm_kw,
+):
+ """
+ Parallel execution of a function using torch.multiprocessing with CUDA support.
+ This is the CUDA variant of the parallel_threads function.
+ """
+ # Set the start method for multiprocessing
+ set_start_method("spawn", force=True)
+
+ # Determine the number of workers
+ while workers <= 0:
+ workers += torch.multiprocessing.cpu_count()
+
+ # Convert args to an iterable
+ try:
+ n_args_parallel = len(args) - front_num
+ except TypeError:
+ n_args_parallel = None
+ args = iter(args)
+
+ # Sequential execution for the first few elements (useful for debugging)
+ front = []
+ while len(front) < front_num:
+ try:
+ a = next(args)
+ except StopIteration:
+ return front # End of the iterable
+ front.append(
+ function(*a) if star_args else function(**a) if kw_args else function(a)
+ )
+
+ # Parallel execution using torch.multiprocessing
+ out = []
+ with Pool(workers) as pool:
+ if star_args:
+ map_func = pool.imap if ordered_res else pool.imap_unordered
+ futures = map_func(starcall, [(function, a) for a in args])
+ elif kw_args:
+ map_func = pool.imap if ordered_res else pool.imap_unordered
+ futures = map_func(starstarcall, [(function, a) for a in args])
+ else:
+ map_func = pool.imap if ordered_res else pool.imap_unordered
+ futures = map_func(function, args)
+ # Track progress with tqdm
+ for f in tqdm(futures, total=n_args_parallel, **tqdm_kw):
+ out.append(f)
+ return front + out
+
+
+def parallel_processes(*args, **kwargs):
+ """Same as parallel_threads, with processes"""
+ import multiprocessing as mp
+
+ kwargs["Pool"] = mp.Pool
+ return parallel_threads(*args, **kwargs)
+
+
+def starcall(args):
+ """convenient wrapper for Process.Pool"""
+ function, args = args
+ return function(*args)
+
+
+def starstarcall(args):
+ """convenient wrapper for Process.Pool"""
+ function, args = args
+ return function(**args)
diff --git a/mapanything/utils/timing.py b/mapanything/utils/timing.py
new file mode 100644
index 0000000000000000000000000000000000000000..ba9ee891d33420c51644f11f38bf169a1a610c27
--- /dev/null
+++ b/mapanything/utils/timing.py
@@ -0,0 +1,304 @@
+"""
+Utility functions for timing code blocks
+"""
+
+import time
+from contextlib import ContextDecorator
+
+import numpy as np
+
+
+class BlockTimeManager:
+ """
+ Manages a collection of timers and their formatting options.
+
+ This class serves as a central registry for Timer objects, allowing them to be
+ accessed by name and maintaining their formatting preferences.
+
+ Attributes:
+ timers (dict): Dictionary mapping timer names to Timer objects
+ timer_fmts (dict): Dictionary mapping timer names to their display formats
+ window_size (int): Default window size for calculating windowed averages
+ buf_size (int): Default buffer size for storing timing measurements
+ """
+
+ def __init__(self, window_size=10, buf_size=100000):
+ self.timers = dict()
+ self.timer_fmts = dict()
+ self.window_size = window_size
+ self.buf_size = buf_size
+
+
+btm = BlockTimeManager(window_size=100000)
+
+
+class Timer:
+ """
+ Core timing class that tracks execution times.
+
+ This class provides the fundamental timing functionality, storing timing measurements
+ and calculating various statistics.
+
+ Attributes:
+ name (str): Identifier for this timer
+ buf_size (int): Maximum number of timing measurements to store
+ window_size (int): Number of most recent measurements to use for windowed statistics
+ measures_arr (numpy.ndarray): Array storing start and end times of measurements
+ current_start (float or None): Start time of current measurement
+ current_end (float or None): End time of current measurement
+ """
+
+ def __init__(self, name, window_size, buf_size=100000):
+ self.name = name
+ self.buf_size = buf_size
+ self.window_size = window_size
+ self.init()
+
+ def init(self):
+ """Initialize or reset the timer's state."""
+ self.measures_arr = np.empty((0, 2)) # LIFO
+ self.current_start = None
+ self.current_end = None
+
+ def reset(self):
+ """Reset the timer to its initial state."""
+ self.init()
+
+ def tic(self):
+ """Start a new timing measurement."""
+ if self.current_start is not None:
+ # another tic executed before a toc
+ self.toc()
+ self.current_start = time.perf_counter()
+
+ def toc(self):
+ """End the current timing measurement."""
+ self.current_end = time.perf_counter()
+ self._add_current_measure()
+
+ def _add_current_measure(self):
+ """Add the current timing measurement to the measurements array."""
+ self.measures_arr = np.concatenate(
+ [
+ np.array([[self.current_start, self.current_end]]),
+ self.measures_arr[: self.buf_size],
+ ]
+ )
+ self.current_start = None
+ self.current_end = None
+
+ @property
+ def avg(self) -> float:
+ """Calculate the average execution time across all measurements."""
+ return np.mean(self.measures_arr[:, 1] - self.measures_arr[:, 0])
+
+ @property
+ def wavg(self) -> float:
+ """Calculate the windowed average execution time using the most recent measurements."""
+ return np.mean(
+ self.measures_arr[: self.window_size, 1]
+ - self.measures_arr[: self.window_size, 0]
+ )
+
+ @property
+ def max(self) -> float:
+ """Return the maximum execution time."""
+ return np.max(self.measures_arr[:, 1] - self.measures_arr[:, 0])
+
+ @property
+ def min(self) -> float:
+ """Return the minimum execution time."""
+ return np.min(self.measures_arr[:, 1] - self.measures_arr[:, 0])
+
+ @property
+ def total(self) -> float:
+ """Return the total execution time across all measurements."""
+ return np.sum(self.measures_arr[:, 1] - self.measures_arr[:, 0])
+
+ @property
+ def latest(self) -> float:
+ """Return the most recent execution time."""
+ return self.measures_arr[0, 1] - self.measures_arr[0, 0]
+
+ @property
+ def median(self) -> float:
+ """Return the median execution time."""
+ return np.median(self.measures_arr[:, 1] - self.measures_arr[:, 0])
+
+ @property
+ def var(self) -> float:
+ """Return the variance of execution times."""
+ return np.var(self.measures_arr[:, 1] - self.measures_arr[:, 0])
+
+
+class BlockTimer(ContextDecorator):
+ """
+ A context manager and decorator for timing code blocks.
+
+ This class provides a convenient interface for timing code execution, either as a
+ context manager (with statement) or as a decorator. It uses the Timer class for
+ the actual timing functionality.
+
+ Attributes:
+ name (str): Identifier for this timer
+ fmt (str or None): Format string for displaying timing information
+ timer (Timer): The underlying Timer object
+ num_calls (int): Number of times this timer has been called
+ """
+
+ @staticmethod
+ def timers():
+ """Return a list of all registered timer names."""
+ return list(btm.timers.keys())
+
+ def __init__(self, name, fmt=None, window_size=100):
+ self.name = name
+ if name in btm.timers:
+ self.timer = btm.timers[name]
+ # restore format
+ self.fmt = fmt if fmt is not None else btm.timer_fmts[name]
+ else:
+ self.timer = Timer(name, btm.window_size, btm.buf_size)
+ btm.timers[name] = self.timer
+ btm.timer_fmts[name] = fmt
+ self.timer.window_size = window_size
+ self._default_fmt = "[{name}] num: {num} latest: {latest:.4f} --wind_avg: {wavg:.4f} -- avg: {avg:.4f} --var: {var:.4f} -- total: {total:.4f}"
+ if fmt == "default":
+ self.fmt = self._default_fmt
+ # extend here for new formats
+ else:
+ self.fmt = None
+
+ self.num_calls = 0
+
+ def __enter__(self) -> "Timer":
+ """Start timing when entering a context."""
+ self.tic()
+ return self
+
+ def __exit__(self, *args):
+ """End timing when exiting a context and optionally display results."""
+ self.toc()
+ if self.fmt is not None:
+ print(str(self))
+
+ def __str__(self) -> str:
+ """Return a string representation of the timer."""
+ return self.display()
+
+ def reset(self):
+ """Reset the timer and call counter."""
+ self.timer.reset()
+ self.num_calls = 0
+
+ def display(self, fmt=None):
+ """
+ Format and return timing information.
+
+ Args:
+ fmt (str, optional): Format string to use. If None, uses the timer's format.
+
+ Returns:
+ str: Formatted timing information
+ """
+ if fmt is None:
+ if self.fmt is not None:
+ fmt = self.fmt
+ else:
+ fmt = self._default_fmt
+ return fmt.format(
+ name=self.name,
+ num=self.num_calls,
+ latest=self.latest,
+ wavg=self.wavg,
+ avg=self.avg,
+ var=self.var,
+ total=self.total,
+ )
+
+ def tic(self):
+ """Start a new timing measurement and increment the call counter."""
+ self.timer.tic()
+ self.num_calls += 1
+
+ def toc(self, display=False):
+ """
+ End the current timing measurement.
+
+ Args:
+ display (bool): Whether to return a formatted display string
+
+ Returns:
+ str or None: Formatted timing information if display is True
+ """
+ self.timer.toc()
+ if display:
+ return self.display()
+
+ @property
+ def latest(self) -> float:
+ """Return the most recent execution time."""
+ return self.timer.latest
+
+ @property
+ def avg(self) -> float:
+ """Return the average execution time."""
+ return self.timer.avg
+
+ @property
+ def wavg(self) -> float:
+ """Return the windowed average execution time."""
+ return self.timer.wavg
+
+ @property
+ def max(self) -> float:
+ """Return the maximum execution time."""
+ return self.timer.max
+
+ @property
+ def min(self) -> float:
+ """Return the minimum execution time."""
+ return self.timer.min
+
+ @property
+ def total(self) -> float:
+ """Return the total execution time."""
+ return self.timer.total
+
+ @property
+ def median(self) -> float:
+ """Return the median execution time."""
+ return self.timer.median
+
+ @property
+ def var(self) -> float:
+ """Return the variance of execution times."""
+ return self.timer.var
+
+
+if __name__ == "__main__":
+
+ @BlockTimer("fct", "default")
+ def fct(bobo):
+ time.sleep(0.5)
+
+ fct(2)
+
+ for i in range(10):
+ with BlockTimer("affe", "default"):
+ time.sleep(0.1)
+ for i in range(1000):
+ with BlockTimer("test", None):
+ time.sleep(0.001)
+
+ # BlockTimer("test").display = f"""avg: {BlockTimer("test").avg} total: {BlockTimer("test").total}"""
+ # print(str(BlockTimer("test")))
+
+ print(BlockTimer("test"))
+ BlockTimer("test").tic()
+ BlockTimer("t2", "default").tic()
+ time.sleep(0.4)
+ print(BlockTimer("t2").toc(True))
+
+ time.sleep(0.4)
+ print(BlockTimer("test").toc(True))
diff --git a/mapanything/utils/train_tools.py b/mapanything/utils/train_tools.py
new file mode 100644
index 0000000000000000000000000000000000000000..0f95d207f541a4d4de61f9111014d0927f91cd08
--- /dev/null
+++ b/mapanything/utils/train_tools.py
@@ -0,0 +1,978 @@
+"""
+Utility functions for training deep learning models, particularly focused on distributed training,
+metric logging, and gradient handling.
+
+This module provides tools for:
+- Tracking and logging metrics during training
+- Setting up distributed training environments
+- Handling gradient scaling and normalization
+- Managing learning rates and parameter groups
+- Saving and loading model checkpoints
+
+References: CroCo (https://github.com/naver/croco)
+"""
+
+import builtins
+import datetime
+import json
+import math
+import os
+import time
+from collections import defaultdict, deque
+from pathlib import Path
+
+import torch
+import torch.distributed as dist
+from torch import inf
+
+
+class SmoothedValue(object):
+ """
+ Track a series of values and provide access to smoothed values over a
+ window or the global series average.
+ """
+
+ def __init__(self, window_size=20, fmt=None):
+ if fmt is None:
+ fmt = "{median:.4f} ({global_avg:.4f})"
+ self.deque = deque(maxlen=window_size)
+ self.total = 0.0
+ self.count = 0
+ self.fmt = fmt
+
+ def update(self, value, n=1):
+ self.deque.append(value)
+ self.count += n
+ self.total += value * n
+
+ def synchronize_between_processes(self):
+ """
+ Warning: does not synchronize the deque!
+ """
+ if not is_dist_avail_and_initialized():
+ return
+ t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda")
+ dist.barrier()
+ dist.all_reduce(t)
+ t = t.tolist()
+ self.count = int(t[0])
+ self.total = t[1]
+
+ @property
+ def median(self):
+ d = torch.tensor(list(self.deque))
+ return d.median().item()
+
+ @property
+ def avg(self):
+ d = torch.tensor(list(self.deque), dtype=torch.float32)
+ return d.mean().item()
+
+ @property
+ def global_avg(self):
+ return self.total / self.count
+
+ @property
+ def max(self):
+ return max(self.deque)
+
+ @property
+ def value(self):
+ return self.deque[-1]
+
+ def __str__(self):
+ return self.fmt.format(
+ median=self.median,
+ avg=self.avg,
+ global_avg=self.global_avg,
+ max=self.max,
+ value=self.value,
+ )
+
+
+class MetricLogger(object):
+ """
+ Logger for tracking and displaying training metrics.
+
+ This class maintains a collection of metrics during training, provides
+ methods to update them, and formats them for display. It also handles
+ synchronization of metrics across processes in distributed training.
+ """
+
+ def __init__(self, delimiter="\t", print_per_view_stats=False):
+ """
+ Initialize the MetricLogger.
+
+ Args:
+ delimiter (str, optional): Delimiter for formatting output. Defaults to "\t".
+ print_per_view_stats (bool, optional): Whether to print per-view statistics. Defaults to False.
+ """
+ self.meters = defaultdict(SmoothedValue)
+ self.delimiter = delimiter
+ self.print_per_view_stats = print_per_view_stats
+
+ def update(self, **kwargs):
+ """
+ Update metrics with new values.
+
+ Args:
+ **kwargs: Key-value pairs where keys are metric names and values are metric values
+ Values can be tensors or numbers
+
+ Raises:
+ AssertionError: If a value is not a float or int after conversion from tensor
+ """
+ for k, v in kwargs.items():
+ if v is None:
+ continue
+ if isinstance(v, torch.Tensor):
+ v = v.item()
+ assert isinstance(v, (float, int))
+ self.meters[k].update(v)
+
+ def __getattr__(self, attr):
+ """
+ Get a meter by attribute name.
+
+ This allows accessing meters as attributes of the logger.
+
+ Args:
+ attr (str): Name of the attribute to get
+
+ Returns:
+ SmoothedValue: The meter corresponding to the attribute name
+
+ Raises:
+ AttributeError: If the attribute doesn't exist as a meter or regular attribute
+ """
+ if attr in self.meters:
+ return self.meters[attr]
+ if attr in self.__dict__:
+ return self.__dict__[attr]
+ raise AttributeError(
+ "'{}' object has no attribute '{}'".format(type(self).__name__, attr)
+ )
+
+ def __str__(self):
+ """
+ Format all metrics as a string.
+
+ Returns:
+ str: Formatted string containing all metrics
+ """
+ loss_str = []
+ for name, meter in self.meters.items():
+ # Skip printing per-view stats if not enabled
+ if not self.print_per_view_stats and "view" in name:
+ continue
+ loss_str.append("{}: {}".format(name, str(meter)))
+ return self.delimiter.join(loss_str)
+
+ def synchronize_between_processes(self):
+ """
+ Synchronize metrics across processes in distributed training.
+
+ This method calls synchronize_between_processes on each meter to
+ ensure consistent values across all processes.
+ """
+ for meter in self.meters.values():
+ meter.synchronize_between_processes()
+
+ def add_meter(self, name, meter):
+ """
+ Add a custom meter to the logger.
+
+ Args:
+ name (str): Name of the meter
+ meter (SmoothedValue): The meter to add
+ """
+ self.meters[name] = meter
+
+ def log_every(self, iterable, print_freq, header=None, max_iter=None):
+ """
+ Log metrics at regular intervals while iterating.
+
+ This method wraps an iterable and logs metrics every print_freq iterations.
+ It also tracks iteration time, data loading time, and memory usage.
+
+ Args:
+ iterable: Iterable to iterate over (typically a data loader)
+ print_freq (int): How often to log metrics (in iterations)
+ header (str, optional): Header string to print before metrics. Defaults to None.
+ max_iter (int, optional): Maximum number of iterations. Defaults to None.
+
+ Yields:
+ object: Items from the original iterable
+ """
+ i = 0
+ if not header:
+ header = ""
+ start_time = time.time()
+ end = time.time()
+ iter_time = SmoothedValue(fmt="{avg:.4f}")
+ data_time = SmoothedValue(fmt="{avg:.4f}")
+ len_iterable = min(len(iterable), max_iter) if max_iter else len(iterable)
+ space_fmt = ":" + str(len(str(len_iterable))) + "d"
+ log_msg = [
+ header,
+ "[{0" + space_fmt + "}/{1}]",
+ "eta: {eta}",
+ "{meters}",
+ "time: {time}",
+ "data: {data}",
+ ]
+ if torch.cuda.is_available():
+ log_msg.append("max mem: {memory:.0f}")
+ log_msg = self.delimiter.join(log_msg)
+ MB = 1024.0 * 1024.0
+ for it, obj in enumerate(iterable):
+ data_time.update(time.time() - end)
+ yield obj
+ iter_time.update(time.time() - end)
+ if i % print_freq == 0 or i == len_iterable - 1:
+ eta_seconds = iter_time.global_avg * (len_iterable - i)
+ eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
+ if torch.cuda.is_available():
+ print(
+ log_msg.format(
+ i,
+ len_iterable,
+ eta=eta_string,
+ meters=str(self),
+ time=str(iter_time),
+ data=str(data_time),
+ memory=torch.cuda.max_memory_allocated() / MB,
+ )
+ )
+ else:
+ print(
+ log_msg.format(
+ i,
+ len_iterable,
+ eta=eta_string,
+ meters=str(self),
+ time=str(iter_time),
+ data=str(data_time),
+ )
+ )
+ i += 1
+ end = time.time()
+ if max_iter and it >= max_iter:
+ break
+ total_time = time.time() - start_time
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+ print(
+ "{} Total time: {} ({:.4f} s / it)".format(
+ header, total_time_str, total_time / len_iterable
+ )
+ )
+
+
+def setup_for_distributed(is_master):
+ """
+ This function disables printing when not in master process.
+
+ It replaces the built-in print function with a custom version that only prints
+ when the current process is the master process or when explicitly forced.
+
+ Args:
+ is_master (bool): Whether the current process is the master process
+ """
+ builtin_print = builtins.print
+
+ def print(*args, **kwargs):
+ force = kwargs.pop("force", False)
+ # force = force or (get_world_size() > 8)
+ if is_master or force:
+ now = datetime.datetime.now().time()
+ builtin_print("[{}] ".format(now), end="") # print with time stamp
+ builtin_print(*args, **kwargs)
+
+ builtins.print = print
+
+
+def is_dist_avail_and_initialized():
+ """
+ Check if distributed training is available and initialized.
+
+ Returns:
+ bool: True if distributed training is available and initialized, False otherwise
+ """
+ if not dist.is_available():
+ return False
+ if not dist.is_initialized():
+ return False
+ return True
+
+
+def get_world_size():
+ """
+ Get the number of processes in the distributed training group.
+
+ Returns:
+ int: Number of processes in the distributed group, or 1 if not using distributed training
+ """
+ if not is_dist_avail_and_initialized():
+ return 1
+ return dist.get_world_size()
+
+
+def get_rank():
+ """
+ Get the rank of the current process in the distributed training group.
+
+ Returns:
+ int: Rank of the current process, or 0 if not using distributed training
+ """
+ if not is_dist_avail_and_initialized():
+ return 0
+ return dist.get_rank()
+
+
+def is_main_process():
+ """
+ Check if the current process is the main process (rank 0).
+
+ Returns:
+ bool: True if the current process is the main process, False otherwise
+ """
+ return get_rank() == 0
+
+
+def save_on_master(*args, **kwargs):
+ """
+ Save a PyTorch object only on the master process.
+
+ This function is useful in distributed training to avoid multiple processes
+ trying to save the same file simultaneously.
+
+ Args:
+ *args: Positional arguments to pass to torch.save()
+ **kwargs: Keyword arguments to pass to torch.save()
+ """
+ if is_main_process():
+ torch.save(*args, **kwargs)
+
+
+def init_distributed_mode(args):
+ """
+ Initialize distributed training mode.
+
+ This function sets up the distributed training environment based on environment
+ variables and command-line arguments. It initializes the process group,
+ sets the appropriate device, and configures printing for the distributed setup.
+
+ Args:
+ args: Arguments object containing distributed training configuration.
+ Expected to have attributes like dist_url, and will be modified
+ to include rank, world_size, gpu, and distributed flag.
+ """
+ nodist = args.nodist if hasattr(args, "nodist") else False
+ if "RANK" in os.environ and "WORLD_SIZE" in os.environ and not nodist:
+ args.rank = int(os.environ["RANK"])
+ args.world_size = int(os.environ["WORLD_SIZE"])
+ args.gpu = int(os.environ["LOCAL_RANK"])
+ else:
+ print("Not using distributed mode")
+ setup_for_distributed(is_master=True) # hack
+ args.distributed = False
+ return
+
+ args.distributed = True
+
+ torch.cuda.set_device(args.gpu)
+ args.dist_backend = "nccl"
+ print(
+ "| distributed init (rank {}): {}, gpu {}".format(
+ args.rank, args.dist_url, args.gpu
+ ),
+ flush=True,
+ )
+ torch.distributed.init_process_group(
+ backend=args.dist_backend,
+ init_method=args.dist_url,
+ world_size=args.world_size,
+ rank=args.rank,
+ )
+ torch.distributed.barrier()
+ setup_for_distributed(args.rank == 0)
+
+
+class NativeScalerWithGradNormCount:
+ """
+ A gradient scaler that handles gradient scaling and norm computation for mixed precision training.
+
+ This class wraps PyTorch's GradScaler to provide additional functionality for gradient norm tracking
+ and clipping during mixed precision training.
+ """
+
+ state_dict_key = "amp_scaler"
+
+ def __init__(self, enabled=True):
+ """Initialize the scaler.
+
+ Args:
+ enabled (bool): Whether to enable gradient scaling. Default: True
+ """
+ self._scaler = torch.GradScaler("cuda", enabled=enabled)
+
+ def __call__(
+ self,
+ loss,
+ optimizer,
+ clip_grad=None,
+ parameters=None,
+ create_graph=False,
+ update_grad=True,
+ ):
+ """Scales loss and performs backward pass with optional gradient clipping.
+
+ Args:
+ loss: The loss to backpropagate
+ optimizer: The optimizer being used
+ clip_grad: Max norm for gradient clipping. None means no clipping
+ parameters: Model parameters or list of parameters for gradient norm computation
+ create_graph: Whether to create graph during backward pass
+ update_grad: Whether to update gradients
+
+ Returns:
+ norm: The gradient norm if computed, else None. Returns list of norms if parameters is a list.
+ """
+ self._scaler.scale(loss).backward(create_graph=create_graph)
+ if update_grad:
+ if clip_grad is not None:
+ assert parameters is not None
+ self._scaler.unscale_(
+ optimizer
+ ) # unscale the gradients of optimizer's assigned params in-place
+ if isinstance(parameters, (list, tuple)):
+ norm = [
+ torch.nn.utils.clip_grad_norm_(p, clip_grad) for p in parameters
+ ]
+ else:
+ norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad)
+ else:
+ self._scaler.unscale_(optimizer)
+ norm = get_grad_norm_(parameters)
+ self._scaler.step(optimizer)
+ self._scaler.update()
+ else:
+ norm = None
+ return norm
+
+ def state_dict(self):
+ """Returns the state dict of the underlying scaler.
+
+ Returns:
+ dict: The state dict of the gradient scaler
+ """
+ return self._scaler.state_dict()
+
+ def load_state_dict(self, state_dict):
+ """Loads the state dict into the underlying scaler.
+
+ Args:
+ state_dict: The state dict to load
+ """
+ self._scaler.load_state_dict(state_dict)
+
+
+def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor:
+ """
+ Calculate the gradient norm of parameters.
+
+ This function computes the norm of gradients for a set of parameters. It can handle
+ both single parameter groups and multiple parameter groups (list/tuple of parameters).
+
+ Args:
+ parameters: A tensor or iterable of tensors or iterable of iterables of tensors
+ containing model parameters for which to compute gradient norms
+ norm_type (float): Type of norm to use (e.g., 2.0 for L2 norm, inf for infinity norm)
+
+ Returns:
+ torch.Tensor: The computed gradient norm. If parameters is a list/tuple of parameter
+ groups, returns a list of norms, one for each group.
+ """
+ if isinstance(parameters, (list, tuple)):
+ # If parameters is already a list/tuple, process each parameter group
+ all_norms = []
+ for params in parameters:
+ if isinstance(params, torch.Tensor):
+ params = [params]
+ params = [p for p in params if p.grad is not None]
+ if len(params) > 0:
+ device = params[0].grad.device
+ if norm_type == inf:
+ group_norm = max(
+ p.grad.detach().abs().max().to(device) for p in params
+ )
+ else:
+ group_norm = torch.norm(
+ torch.stack(
+ [
+ torch.norm(p.grad.detach(), norm_type).to(device)
+ for p in params
+ ]
+ ),
+ norm_type,
+ )
+ else:
+ group_norm = torch.tensor(0.0)
+ all_norms.append(group_norm)
+ return all_norms
+
+ # Original logic for single parameter group
+ if isinstance(parameters, torch.Tensor):
+ parameters = [parameters]
+ parameters = [p for p in parameters if p.grad is not None]
+ norm_type = float(norm_type)
+ if len(parameters) == 0:
+ return torch.tensor(0.0)
+ device = parameters[0].grad.device
+ if norm_type == inf:
+ total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters)
+ else:
+ total_norm = torch.norm(
+ torch.stack(
+ [torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]
+ ),
+ norm_type,
+ )
+ return total_norm
+
+
+def save_model(
+ args, epoch, model_without_ddp, optimizer, loss_scaler, fname=None, best_so_far=None
+):
+ """
+ Save model checkpoint to disk.
+
+ This function saves the model state, optimizer state, loss scaler state,
+ training arguments, current epoch, and optionally the best metric value so far.
+ The checkpoint is only saved on the master process in distributed training.
+
+ Args:
+ args: Arguments containing output directory information
+ epoch (int): Current training epoch
+ model_without_ddp (torch.nn.Module): Model without DistributedDataParallel wrapper
+ optimizer (torch.optim.Optimizer): Optimizer instance
+ loss_scaler: Gradient scaler for mixed precision training
+ fname (str, optional): Custom filename suffix. If None, uses the epoch number. Defaults to None.
+ best_so_far (float, optional): Best metric value achieved so far. Defaults to None.
+ """
+ output_dir = Path(args.output_dir)
+ if fname is None:
+ fname = str(epoch)
+ checkpoint_path = output_dir / ("checkpoint-%s.pth" % fname)
+ to_save = {
+ "model": model_without_ddp.state_dict(),
+ "optimizer": optimizer.state_dict(),
+ "scaler": loss_scaler.state_dict(),
+ "args": args,
+ "epoch": epoch,
+ }
+ if best_so_far is not None:
+ to_save["best_so_far"] = best_so_far
+ print(f">> Saving model to {checkpoint_path} ...")
+ save_on_master(to_save, checkpoint_path)
+
+
+def load_model(train_args, model_without_ddp, optimizer, loss_scaler):
+ """
+ Load model checkpoint from disk or URL.
+
+ This function loads a saved checkpoint, restoring the model state, optimizer state,
+ loss scaler state, and training epoch. It can load from a local file or a URL.
+
+ Args:
+ train_args: Training arguments containing resume information
+ model_without_ddp (torch.nn.Module): Model without DistributedDataParallel wrapper
+ optimizer (torch.optim.Optimizer): Optimizer instance
+ loss_scaler: Gradient scaler for mixed precision training
+
+ Returns:
+ float or None: Best metric value from the checkpoint if available, otherwise None
+ """
+ train_args.start_epoch = 0
+ best_so_far = None
+ if train_args.resume and train_args.resume_ckpt is not None:
+ if train_args.resume_ckpt.startswith("https"):
+ checkpoint = torch.hub.load_state_dict_from_url(
+ train_args.resume_ckpt, map_location="cpu", check_hash=True
+ )
+ else:
+ checkpoint = torch.load(
+ train_args.resume_ckpt, map_location="cpu", weights_only=False
+ )
+ print("Resume checkpoint %s" % train_args.resume_ckpt)
+ model_without_ddp.load_state_dict(checkpoint["model"], strict=False)
+ train_args.start_epoch = checkpoint["epoch"] + 1
+ optimizer.load_state_dict(checkpoint["optimizer"])
+ if "scaler" in checkpoint:
+ loss_scaler.load_state_dict(checkpoint["scaler"])
+ if "best_so_far" in checkpoint:
+ best_so_far = checkpoint["best_so_far"]
+ print(" & best_so_far={:g}".format(best_so_far))
+ else:
+ print("")
+ print(
+ "With optim & sched! start_epoch={:d}".format(train_args.start_epoch),
+ end="",
+ )
+ return best_so_far
+
+
+def all_reduce_mean(x):
+ """
+ Compute the mean of a value across all processes in distributed training.
+
+ This function takes a value, reduces it across all processes using all_reduce,
+ and returns the mean value.
+
+ Args:
+ x: The value to reduce (typically a scalar)
+
+ Returns:
+ float: The mean value across all processes
+ """
+ world_size = get_world_size()
+ if world_size > 1:
+ x_reduce = torch.tensor(x).cuda()
+ dist.all_reduce(x_reduce)
+ x_reduce /= world_size
+ return x_reduce.item()
+ else:
+ return x
+
+
+def _replace(text, src, tgt, rm=""):
+ """
+ Advanced string replacement utility.
+
+ Given a text:
+ - replace all elements in src by the corresponding element in tgt
+ - remove all elements in rm
+
+ Args:
+ text (str): The input text to modify
+ src (str): String of characters to replace
+ tgt (str): String of replacement characters (must be same length as src or length 1)
+ rm (str, optional): String of characters to remove. Defaults to "".
+
+ Returns:
+ str: The modified text after replacements and removals
+
+ Raises:
+ AssertionError: If src and tgt have different lengths (unless tgt has length 1)
+ """
+ if len(tgt) == 1:
+ tgt = tgt * len(src)
+ assert len(src) == len(tgt), f"'{src}' and '{tgt}' should have the same len"
+ for s, t in zip(src, tgt):
+ text = text.replace(s, t)
+ for c in rm:
+ text = text.replace(c, "")
+ return text
+
+
+def filename(obj):
+ """
+ Transform a Python object or command into a proper filename.
+
+ This function converts a Python object or command string into a valid filename
+ by replacing special characters and ensuring the filename is not too long.
+
+ Special replacements:
+ - \1 gets replaced by slash '/'
+ - \2 gets replaced by comma ','
+
+ Args:
+ obj: The Python object or string to convert to a filename
+
+ Returns:
+ str: A valid filename derived from the input object
+
+ Raises:
+ AssertionError: If any part of the resulting path is longer than 256 characters
+ """
+ if not isinstance(obj, str):
+ obj = repr(obj)
+ obj = str(obj).replace("()", "")
+ obj = _replace(obj, "_,(*/\1\2", "-__x%/,", rm=" )'\"")
+ assert all(len(s) < 256 for s in obj.split(os.sep)), (
+ "filename too long (>256 characters):\n" + obj
+ )
+ return obj
+
+
+def compute_effective_lrs(train_args):
+ """
+ Compute the effective learning rates based on batch size scaling.
+
+ This function calculates the effective learning rates for the main model and
+ any submodules based on the effective batch size (accounting for gradient accumulation
+ and distributed training) and the base learning rates.
+
+ Args:
+ train_args: Training arguments containing batch size, accumulation iterations,
+ learning rates, and submodule configurations
+
+ Returns:
+ train_args: Updated training arguments with computed effective learning rates
+ """
+
+ # Compute the effective batch size
+ eff_batch_size = train_args.batch_size * train_args.accum_iter * get_world_size()
+ print("Accumulate grad iterations: %d" % train_args.accum_iter)
+ print("Effective batch size: %d" % eff_batch_size)
+ # Compute the effective default learning rate
+ if train_args.lr is None: # only base_lr is specified
+ train_args.lr = train_args.blr * math.sqrt(
+ eff_batch_size / train_args.base_eff_batch_size
+ )
+ print(
+ f"Base default lr for effective batch size {eff_batch_size}: %.2e"
+ % (train_args.lr * math.sqrt(train_args.base_eff_batch_size / eff_batch_size))
+ )
+ print("Actual default lr: %.2e" % train_args.lr)
+ for submodule, config in train_args.submodule_configs.items():
+ if config.get("lr") is None: # only base_lr is specified
+ config["lr"] = config["blr"] * math.sqrt(
+ eff_batch_size / train_args.base_eff_batch_size
+ )
+ print(
+ f"Submodule {submodule} base lr for effective batch size {eff_batch_size}: %.2e"
+ % (
+ config["lr"]
+ * math.sqrt(train_args.base_eff_batch_size / eff_batch_size)
+ )
+ )
+ print(f"Submodule {submodule} actual lr: %.2e" % config["lr"])
+
+ return train_args
+
+
+def get_parameter_groups(
+ model,
+ lr,
+ weight_decay,
+ skip_list=[],
+ submodule_configs=None,
+ warn_not_in_submodule=False,
+):
+ """
+ Get parameter groups for optimizer with customized learning rates and weight decay.
+
+ This function organizes model parameters into groups for the optimizer, allowing
+ different learning rates and weight decay values for different parts of the model.
+ Parameters are grouped by:
+ 1. Whether they should have weight decay applied (bias terms and 1D tensors typically don't)
+ 2. Which submodule they belong to (if submodule_configs is provided)
+
+ Args:
+ model (torch.nn.Module): Model to get parameter groups for
+ lr (float): Default learning rate for parameters not in submodule_configs
+ weight_decay (float): Default weight decay for parameters not in submodule_configs
+ skip_list (list): List of parameter names to skip weight decay for
+ submodule_configs (dict, optional): Dictionary mapping submodule prefixes to configs
+ with 'lr' and 'weight_decay' keys
+ warn_not_in_submodule (bool, optional): Whether to warn if a parameter does not
+ belong to any submodule. Defaults to False.
+
+ Returns:
+ tuple: A tuple containing:
+ - parameter_group_vars (list): List of parameter groups for optimizer
+ - parameter_group_name_to_idx_map (dict): Mapping from submodule name to parameter group indices
+ - parameter_group_idx_to_name_map (dict): Mapping from parameter group index to submodule name
+ """
+
+ if submodule_configs is None:
+ submodule_configs = {}
+
+ parameter_group_names = {}
+ parameter_group_vars = {}
+ parameter_group_name_to_idx_map = {}
+ parameter_group_idx_to_name_map = {}
+ mapping_index = 0
+
+ for name, param in model.named_parameters():
+ # Skip frozen parameters
+ if not param.requires_grad:
+ continue
+
+ # Determine the submodule this parameter belongs to
+ submodule_name = None
+ for submodule, config in submodule_configs.items():
+ if name.startswith(submodule):
+ submodule_name = submodule
+ break
+
+ if submodule_name:
+ config = submodule_configs[submodule_name]
+ this_weight_decay = config.get("weight_decay", weight_decay)
+ this_lr = config.get("lr", lr)
+ # Freeze the parameters if lr is 0
+ if this_lr == 0:
+ param.requires_grad = False
+ continue
+ else:
+ this_weight_decay = weight_decay
+ this_lr = lr
+ if warn_not_in_submodule and submodule_configs is not None:
+ print(
+ f"Warning: Parameter {name} does not belong to any submodule in {submodule_configs.keys()}."
+ )
+
+ # Assign weight decay values
+ if len(param.shape) == 1 or name.endswith(".bias") or name in skip_list:
+ group_name = f"{submodule_name}_no_decay" if submodule_name else "no_decay"
+ this_weight_decay = 0.0
+ else:
+ group_name = f"{submodule_name}_decay" if submodule_name else "decay"
+
+ if group_name not in parameter_group_names:
+ parameter_group_names[group_name] = {
+ "weight_decay": this_weight_decay,
+ "lr": this_lr,
+ "params": [],
+ }
+ parameter_group_vars[group_name] = {
+ "weight_decay": this_weight_decay,
+ "lr": this_lr,
+ "params": [],
+ }
+ submodule_name_mapping = submodule_name if submodule_name else "default"
+ if submodule_name_mapping not in parameter_group_name_to_idx_map:
+ parameter_group_name_to_idx_map[submodule_name_mapping] = [
+ mapping_index
+ ]
+ else:
+ parameter_group_name_to_idx_map[submodule_name_mapping].append(
+ mapping_index
+ )
+ parameter_group_idx_to_name_map[mapping_index] = submodule_name_mapping
+ mapping_index += 1
+
+ parameter_group_vars[group_name]["params"].append(param)
+ parameter_group_names[group_name]["params"].append(name)
+
+ # Print the parameter groups
+ print("Param groups = %s" % json.dumps(parameter_group_names, indent=2))
+
+ return (
+ list(parameter_group_vars.values()),
+ parameter_group_name_to_idx_map,
+ parameter_group_idx_to_name_map,
+ )
+
+
+def adjust_learning_rate(
+ optimizer,
+ epoch,
+ train_args,
+ parameter_group_idx_to_name_map,
+ submodule_configs=None,
+):
+ """
+ Adjust the learning rate based on the schedule type and current epoch.
+
+ This function updates the learning rates for all parameter groups in the optimizer
+ according to the specified learning rate schedule. Different submodules can have
+ different learning rate schedules.
+
+ Currently supported schedule types:
+ - linear_warmup_half_cycle_cosine_decay: Linear warmup followed by cosine decay
+
+ Args:
+ optimizer (torch.optim.Optimizer): The optimizer to update
+ epoch (int): Current training epoch
+ train_args: Training arguments containing schedule type, warmup epochs, etc.
+ parameter_group_idx_to_name_map (dict): Mapping from parameter group index to submodule name
+ submodule_configs (dict, optional): Dictionary of submodule-specific configurations
+ for learning rate schedules
+
+ Raises:
+ ValueError: If an unsupported schedule type is specified
+ """
+
+ if submodule_configs is None:
+ submodule_configs = {}
+
+ for group_num, param_group in enumerate(optimizer.param_groups):
+ submodule_name = parameter_group_idx_to_name_map.get(group_num)
+
+ if submodule_name in submodule_configs:
+ config = submodule_configs[submodule_name]
+ lr = config.get("lr", train_args.lr)
+ warmup_epochs = config.get("warmup_epochs", train_args.warmup_epochs)
+ min_lr = config.get("min_lr", train_args.min_lr)
+ schedule_type = config.get("schedule_type", train_args.schedule_type)
+ else:
+ lr = train_args.lr
+ warmup_epochs = train_args.warmup_epochs
+ min_lr = train_args.min_lr
+ schedule_type = train_args.schedule_type
+
+ if schedule_type == "linear_warmup_half_cycle_cosine_decay":
+ if epoch < warmup_epochs:
+ lr = lr * epoch / warmup_epochs
+ else:
+ lr = min_lr + (lr - min_lr) * 0.5 * (
+ 1.0
+ + math.cos(
+ math.pi
+ * (epoch - warmup_epochs)
+ / (train_args.epochs - warmup_epochs)
+ )
+ )
+ else:
+ raise ValueError(f"Schedule type {schedule_type} not implemented")
+
+ param_group["lr"] = lr
+
+
+def debug_after_backward(
+ model,
+ check_missing_gradients=True,
+ check_gradient_mismatch=False,
+ target_size=(256, 256, 1, 1),
+ target_stride=(256, 1, 256, 256),
+):
+ """
+ Debugging function to check for gradient issues after backward pass.
+
+ This function performs two types of gradient debugging:
+ 1. Gradient mismatch: Checks for parameters with specific gradient shapes and strides
+ that might indicate incorrect gradient computation.
+ 2. Missing gradients: Identifies parameters that require gradients but didn't receive any.
+
+ Args:
+ model (torch.nn.Module): The model to check gradients for
+ check_missing_gradients (bool, optional): Whether to check for missing gradients. Defaults to True.
+ check_gradient_mismatch (bool, optional): Whether to check for gradient mismatches. Defaults to False.
+ target_size (tuple, optional): Target tensor size to check for gradient mismatch. Defaults to (256, 256, 1, 1).
+ target_stride (tuple, optional): Target tensor stride to check for gradient mismatch. Defaults to (256, 1, 256, 256).
+ """
+ # Debug for missing gradients
+ if check_missing_gradients:
+ missing_grad_params = []
+ for name, param in model.named_parameters():
+ if param.requires_grad and param.grad is None:
+ missing_grad_params.append(name)
+
+ if missing_grad_params:
+ print("Parameters requiring gradients but missing gradients:")
+ for name in missing_grad_params:
+ print(f" - {name}")
+ else:
+ print("All parameters requiring gradients received gradients!")
+
+ # Debug for gradient mismatch
+ if check_gradient_mismatch:
+ for name, param in model.named_parameters():
+ grad = param.grad
+ if grad is None:
+ continue
+ if grad.size() == target_size and grad.stride() == target_stride:
+ print(f"Found parameter with incorrect gradient: '{name}'")
+ print(f"Gradient shape: {grad.size()}, strides: {grad.stride()}")
diff --git a/mapanything/utils/viz.py b/mapanything/utils/viz.py
new file mode 100644
index 0000000000000000000000000000000000000000..10864b78b731f68bafc3608836dd4f3475375be1
--- /dev/null
+++ b/mapanything/utils/viz.py
@@ -0,0 +1,167 @@
+"""
+Utility functions for visualization
+"""
+
+from argparse import ArgumentParser, Namespace
+from distutils.util import strtobool
+
+import rerun as rr
+
+
+def log_data_to_rerun(image, depthmap, pose, intrinsics, base_name, mask=None):
+ """
+ Log camera and image data to Rerun visualization tool.
+
+ Parameters
+ ----------
+ image : numpy.ndarray
+ RGB image to be logged
+ depthmap : numpy.ndarray
+ Depth map corresponding to the image
+ pose : numpy.ndarray
+ 4x4 camera pose matrix with rotation (3x3) and translation (3x1)
+ intrinsics : numpy.ndarray
+ Camera intrinsic matrix
+ base_name : str
+ Base name for the logged entities in Rerun
+ mask : numpy.ndarray, optional
+ Optional segmentation mask for the depth image
+ """
+ # Log camera info and loaded data
+ height, width = image.shape[0], image.shape[1]
+ rr.log(
+ base_name,
+ rr.Transform3D(
+ translation=pose[:3, 3],
+ mat3x3=pose[:3, :3],
+ ),
+ )
+ rr.log(
+ f"{base_name}/pinhole",
+ rr.Pinhole(
+ image_from_camera=intrinsics,
+ height=height,
+ width=width,
+ camera_xyz=rr.ViewCoordinates.RDF,
+ ),
+ )
+ rr.log(
+ f"{base_name}/pinhole/rgb",
+ rr.Image(image),
+ )
+ rr.log(
+ f"{base_name}/pinhole/depth",
+ rr.DepthImage(depthmap),
+ )
+ if mask is not None:
+ rr.log(
+ f"{base_name}/pinhole/depth_mask",
+ rr.SegmentationImage(mask),
+ )
+
+
+def str2bool(v):
+ return bool(strtobool(v))
+
+
+def script_add_rerun_args(parser: ArgumentParser) -> None:
+ """
+ Add common Rerun script arguments to `parser`.
+
+ Change Log from https://github.com/rerun-io/rerun/blob/29eb8954b08e59ff96943dc0677f46f7ea4ea734/rerun_py/rerun_sdk/rerun/script_helpers.py#L65:
+ - Added default portforwarding url for ease of use
+ - Update parser types
+
+ Parameters
+ ----------
+ parser : ArgumentParser
+ The parser to add arguments to.
+
+ Returns
+ -------
+ None
+ """
+ parser.add_argument(
+ "--headless",
+ type=str2bool,
+ nargs="?",
+ const=True,
+ default=True,
+ help="Don't show GUI",
+ )
+ parser.add_argument(
+ "--connect",
+ dest="connect",
+ type=str2bool,
+ nargs="?",
+ const=True,
+ default=True,
+ help="Connect to an external viewer",
+ )
+ parser.add_argument(
+ "--serve",
+ dest="serve",
+ type=str2bool,
+ nargs="?",
+ const=True,
+ default=False,
+ help="Serve a web viewer (WARNING: experimental feature)",
+ )
+ parser.add_argument(
+ "--url",
+ type=str,
+ default="rerun+http://127.0.0.1:9081/proxy",
+ help="Connect to this HTTP(S) URL",
+ )
+ parser.add_argument(
+ "--save", type=str, default=None, help="Save data to a .rrd file at this path"
+ )
+ parser.add_argument(
+ "-o",
+ "--stdout",
+ dest="stdout",
+ action="store_true",
+ help="Log data to standard output, to be piped into a Rerun Viewer",
+ )
+
+
+def init_rerun_args(
+ headless=True,
+ connect=True,
+ serve=False,
+ url="rerun+http://127.0.0.1:9081/proxy",
+ save=None,
+ stdout=False,
+) -> Namespace:
+ """
+ Initialize common Rerun script arguments.
+
+ Parameters
+ ----------
+ headless : bool, optional
+ Don't show GUI, by default True
+ connect : bool, optional
+ Connect to an external viewer, by default True
+ serve : bool, optional
+ Serve a web viewer (WARNING: experimental feature), by default False
+ url : str, optional
+ Connect to this HTTP(S) URL, by default rerun+http://127.0.0.1:2004/proxy
+ save : str, optional
+ Save data to a .rrd file at this path, by default None
+ stdout : bool, optional
+ Log data to standard output, to be piped into a Rerun Viewer, by default False
+
+ Returns
+ -------
+ Namespace
+ The parsed arguments.
+ """
+ rerun_args = Namespace()
+ rerun_args.headless = headless
+ rerun_args.connect = connect
+ rerun_args.serve = serve
+ rerun_args.url = url
+ rerun_args.save = save
+ rerun_args.stdout = stdout
+
+ return rerun_args
diff --git a/mapanything/utils/warnings.py b/mapanything/utils/warnings.py
new file mode 100644
index 0000000000000000000000000000000000000000..8422416bac8ba5893f6a50f2b32125f4f9ab65bb
--- /dev/null
+++ b/mapanything/utils/warnings.py
@@ -0,0 +1,41 @@
+"""
+Wrapper utilities for warnings.
+"""
+
+import warnings
+from functools import wraps
+
+
+def suppress_traceback(fn):
+ @wraps(fn)
+ def wrapper(*args, **kwargs):
+ try:
+ return fn(*args, **kwargs)
+ except Exception as e:
+ e.__traceback__ = e.__traceback__.tb_next.tb_next
+ raise
+
+ return wrapper
+
+
+class no_warnings:
+ def __init__(self, action: str = "ignore", **kwargs):
+ self.action = action
+ self.filter_kwargs = kwargs
+
+ def __call__(self, fn):
+ @wraps(fn)
+ def wrapper(*args, **kwargs):
+ with warnings.catch_warnings():
+ warnings.simplefilter(self.action, **self.filter_kwargs)
+ return fn(*args, **kwargs)
+
+ return wrapper
+
+ def __enter__(self):
+ self.warnings_manager = warnings.catch_warnings()
+ self.warnings_manager.__enter__()
+ warnings.simplefilter(self.action, **self.filter_kwargs)
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ self.warnings_manager.__exit__(exc_type, exc_val, exc_tb)
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 0000000000000000000000000000000000000000..b190d517a2fe9214b78e6781885fb5de23ff9b07
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,75 @@
+[build-system]
+requires = ["setuptools>=61.0", "wheel"]
+build-backend = "setuptools.build_meta"
+
+[project]
+name = "mapanything"
+version = "0.1"
+description = "Metric Universal 3D Reconstruction"
+readme = "readme.md"
+authors = [{ name = "Nikhil Keetha", email = "keethanikhil@gmail.com" }]
+requires-python = ">=3.10.0"
+dependencies = [
+ "torch~=2.6.0", ## Temp for AWS
+ "torchvision", ## Temp for AWS
+ "torchaudio", ## Temp for AWS
+ "einops",
+ "huggingface_hub",
+ "hydra-core",
+ "tensorboard",
+ "uniception",
+ "wai-core",
+]
+
+[project.optional-dependencies]
+data = ["natsort", "pandas"]
+dev = ["pre-commit", "pytest", "pytest-cov", "ruff"]
+radio = ["timm"]
+viz = ["natsort"]
+# External models for benchmarking
+anycalib = ["anycalib @ git+https://github.com/javrtg/AnyCalib.git@main#egg=anycalib"]
+dust3r = [
+ "croco @ git+https://github.com/naver/croco.git@croco_module#egg=croco",
+ "dust3r @ git+https://github.com/naver/dust3r.git@dust3r_setup#egg=dust3r",
+]
+mast3r = ["mast3r @ git+https://github.com/Nik-V9/mast3r.git@main#egg=mast3r"]
+must3r = ["must3r @ git+https://github.com/naver/must3r.git@main#egg=must3r"]
+pow3r = ["pow3r @ git+https://github.com/Nik-V9/pow3r.git@main#egg=pow3r"]
+
+# Install all optional dependencies
+all = [
+ "mapanything[anycalib]",
+ "mapanything[data]",
+ "mapanything[dev]",
+ "mapanything[dust3r]",
+ "mapanything[mast3r]",
+ "mapanything[must3r]",
+ "mapanything[pow3r]",
+ "mapanything[radio]",
+ "mapanything[viz]",
+]
+
+# Setuptools configuration
+[tool.setuptools]
+# Disable automatic package discovery to avoid conflicts
+packages = ["mapanything"]
+
+# Include package data
+[tool.setuptools.package-data]
+"mapanything" = ["**/*"]
+
+# Ruff for linting
+[tool.ruff]
+# Enable the isort rules.
+lint.extend-select = ["I"]
+
+# Following https://www.internalfb.com/wiki/Python/code_formatting/pyfmt/
+target-version = "py310"
+
+# Following https://www.internalfb.com/wiki/Python/code_formatting/pyfmt/
+[tool.ruff.lint.isort]
+case-sensitive = false
+combine-as-imports = true
+detect-same-package = false
+order-by-type = false
+known-first-party = ["mapanything", "uniception", "wai"]
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..4d54fcab10d33081a6964a4b42561c6bc99ebc63
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,21 @@
+--extra-index-url https://download.pytorch.org/whl/cu113
+torch
+torchvision
+torchaudio
+gradio
+huggingface-hub
+numpy
+opencv-python-headless
+Pillow
+matplotlib
+scikit-learn
+scipy
+spaces
+hydra-core
+omegaconf
+trimesh
+einops
+requests
+psutil
+tqdm
+uniception==0.1.4
\ No newline at end of file
diff --git a/setup.py b/setup.py
new file mode 100644
index 0000000000000000000000000000000000000000..d38baa1a5ad55829b70e84b0b924678ef07aff40
--- /dev/null
+++ b/setup.py
@@ -0,0 +1,5 @@
+"""Package installation setup."""
+
+from setuptools import setup
+
+setup()